Source code for catalyst.dl.utils.trace

from typing import (  # isort:skip
    Any,
    Callable,
    Dict,
    List,
    Union,
)
import inspect
from pathlib import Path

from torch import nn
from torch.jit import load, save, ScriptModule, trace

from catalyst.core.runner import IRunner
from catalyst.dl.experiment.config import ConfigExperiment
from catalyst.tools.typing import Device, Model
from catalyst.utils import (
    any2device,
    assert_fp16_available,
    get_fn_argsnames,
    get_native_batch_from_loaders,
    get_nn_from_ddp_module,
    get_requires_grad,
    import_experiment_and_runner,
    load_checkpoint,
    load_config,
    pack_checkpoint,
    set_requires_grad,
    unpack_checkpoint,
)


def _get_input_argnames(
    fn: Callable[..., Any], exclude: List[str] = None
) -> List[str]:
    """
    Function to get input argument names of function.

    Args:
        fn (Callable[..., Any]): Function to get argument names from
        exclude (List[str]): List of string of names to exclude

    Returns:
        (List[str]): List of input argument names
    """
    argspec = inspect.getfullargspec(fn)
    assert (
        argspec.varargs is None and argspec.varkw is None
    ), "not supported by PyTorch"

    return get_fn_argsnames(fn, exclude=exclude)
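
# Illustration (assumption, not part of the original module): for a method
# ``def forward(self, x, mask=None)``, calling
# ``_get_input_argnames(fn=model.forward, exclude=["self"])`` is expected
# to return ``["x", "mask"]``.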


class _ForwardOverrideModel(nn.Module):
    """Model that calls specified method instead of forward.

    (Workaround, single method tracing is not supported)
    """

    def __init__(self, model, method_name):
        super().__init__()
        self.model = model
        self.method_name = method_name

    def forward(self, *args, **kwargs):
        return getattr(self.model, self.method_name)(*args, **kwargs)
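
# Illustration (assumption, not part of the original module):
# ``torch.jit.trace`` can only enter a model through ``forward``, so tracing
# another method, e.g. ``encode``, is routed through this override:
#
#     wrapped = _ForwardOverrideModel(model, "encode")
#     traced_encode = trace(wrapped, example_inputs)  # ``trace`` as imported above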


class _TracingModelWrapper(nn.Module):
    """Wrapper that traces model with batch instead of calling it.

    (Workaround, to use native model batch handler)
    """

    def __init__(self, model, method_name):
        super().__init__()
        self.model = model
        self.method_name = method_name
        self.tracing_result: ScriptModule

    def __call__(self, *args, **kwargs):
        method_model = _ForwardOverrideModel(self.model, self.method_name)

        try:
            assert len(args) == 0, "only KV support implemented"

            fn = getattr(self.model, self.method_name)
            method_argnames = _get_input_argnames(fn=fn, exclude=["self"])
            method_input = tuple(kwargs[name] for name in method_argnames)

            self.tracing_result = trace(method_model, method_input)
        except Exception:
            # for backward compatibility
            self.tracing_result = trace(method_model, *args, **kwargs)
        output = self.model.forward(*args, **kwargs)

        return output


def trace_model(
    model: Model,
    predict_fn: Callable,
    batch=None,
    method_name: str = "forward",
    mode: str = "eval",
    requires_grad: bool = False,
    opt_level: str = None,
    device: Device = "cpu",
    predict_params: dict = None,
) -> ScriptModule:
    """Traces model using runner and batch.

    Args:
        model: Model to trace
        predict_fn: Function to run prediction with the model provided,
            takes ``model``, ``inputs`` parameters
        batch: Batch to trace the model
        method_name (str): Model's method name that will be used
            as entrypoint during tracing
        mode (str): Mode for model to trace (``train`` or ``eval``)
        requires_grad (bool): Flag to use grads
        opt_level (str): Apex FP16 init level, optional
        device (str): Torch device
        predict_params (dict): additional parameters for model forward

    Returns:
        (ScriptModule): Traced model
    """
    if batch is None or predict_fn is None:
        raise ValueError("Both batch and predict_fn must be specified.")
    if mode not in ["train", "eval"]:
        raise ValueError(f"Unknown mode '{mode}'. Must be 'eval' or 'train'")

    predict_params = predict_params or {}

    tracer = _TracingModelWrapper(model, method_name)
    if opt_level is not None:
        assert_fp16_available()
        # If traced in AMP we need to initialize the model before calling
        # the jit
        # https://github.com/NVIDIA/apex/issues/303#issuecomment-493142950
        from apex import amp

        model = model.to(device)
        model = amp.initialize(model, optimizers=None, opt_level=opt_level)

    getattr(model, mode)()
    set_requires_grad(model, requires_grad=requires_grad)

    predict_fn(tracer, batch, **predict_params)

    return tracer.tracing_result
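
# Usage sketch (assumption, not part of the original module): ``TinyNet`` and
# the batch below are hypothetical. The only hard requirement is that the
# batch keys match the traced method's argument names, because the tracing
# wrapper calls the model with keyword arguments only.
#
#     import torch
#     from torch import nn
#
#     class TinyNet(nn.Module):
#         def __init__(self):
#             super().__init__()
#             self.fc = nn.Linear(10, 2)
#
#         def forward(self, features):
#             return self.fc(features)
#
#     traced = trace_model(
#         model=TinyNet(),
#         predict_fn=lambda model, inputs: model(**inputs),
#         batch={"features": torch.randn(4, 10)},
#     )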
def trace_model_from_checkpoint(
    logdir: Path,
    method_name: str,
    checkpoint_name: str,
    stage: str = None,
    loader: Union[str, int] = None,
    mode: str = "eval",
    requires_grad: bool = False,
    opt_level: str = None,
    device: Device = "cpu",
):
    """
    Traces model using created experiment and runner.

    Args:
        logdir (Union[str, Path]): Path to Catalyst logdir with model
        method_name (str): Model's method name that will be used
            as entrypoint during tracing
        checkpoint_name (str): Name of model checkpoint to use
        stage (str): experiment's stage name
        loader (Union[str, int]): experiment's loader name or its index
        mode (str): Mode for model to trace (``train`` or ``eval``)
        requires_grad (bool): Flag to use grads
        opt_level (str): AMP FP16 init level
        device (str): Torch device

    Returns:
        the traced model
    """
    config_path = logdir / "configs" / "_config.json"
    checkpoint_path = logdir / "checkpoints" / f"{checkpoint_name}.pth"
    print("Load config")
    config: Dict[str, dict] = load_config(config_path)
    runner_params = config.get("runner_params", {}) or {}

    # Get expdir name
    config_expdir = Path(config["args"]["expdir"])
    # We will use copy of expdir from logs for reproducibility
    expdir = Path(logdir) / "code" / config_expdir.name

    print("Import experiment and runner from logdir")
    ExperimentType, RunnerType = import_experiment_and_runner(expdir)
    experiment: ConfigExperiment = ExperimentType(config)

    print(f"Load model state from checkpoints/{checkpoint_name}.pth")
    if stage is None:
        stage = list(experiment.stages)[0]

    model = experiment.get_model(stage)
    checkpoint = load_checkpoint(checkpoint_path)
    unpack_checkpoint(checkpoint, model=model)

    runner: RunnerType = RunnerType(**runner_params)
    runner.model, runner.device = model, device

    if loader is None:
        loader = 0
    batch = get_native_batch_from_loaders(
        loaders=experiment.get_loaders(stage), loader=loader
    )

    # function to run prediction on batch
    def predict_fn(model, inputs, **kwargs):
        _model = runner.model
        runner.model = model
        result = runner.predict_batch(inputs, **kwargs)
        runner.model = _model
        return result

    print("Tracing")
    traced_model = trace_model(
        model=model,
        predict_fn=predict_fn,
        batch=batch,
        method_name=method_name,
        mode=mode,
        requires_grad=requires_grad,
        opt_level=opt_level,
        device=device,
    )

    print("Done")
    return traced_model
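
# Usage sketch (assumption, hypothetical paths): ``logdir`` must be a
# ``Path`` and is expected to contain ``configs/_config.json``,
# ``checkpoints/`` and a ``code/`` copy of the experiment dir, as read above.
#
#     traced = trace_model_from_checkpoint(
#         logdir=Path("./logs/experiment"),
#         method_name="forward",
#         checkpoint_name="best",
#     )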
def trace_model_from_runner(
    runner: IRunner,
    checkpoint_name: str = None,
    method_name: str = "forward",
    mode: str = "eval",
    requires_grad: bool = False,
    opt_level: str = None,
    device: Device = "cpu",
) -> ScriptModule:
    """
    Traces model using created experiment and runner.

    Args:
        runner (IRunner): Current runner.
        checkpoint_name (str): Name of model checkpoint to use; if None,
            traces current model from runner
        method_name (str): Model's method name that will be used
            as entrypoint during tracing
        mode (str): Mode for model to trace (``train`` or ``eval``)
        requires_grad (bool): Flag to use grads
        opt_level (str): AMP FP16 init level
        device (str): Torch device

    Returns:
        (ScriptModule): Traced model
    """
    logdir = runner.logdir
    model = get_nn_from_ddp_module(runner.model)

    if checkpoint_name is not None:
        dumped_checkpoint = pack_checkpoint(model=model)
        checkpoint_path = logdir / "checkpoints" / f"{checkpoint_name}.pth"
        checkpoint = load_checkpoint(filepath=checkpoint_path)
        unpack_checkpoint(checkpoint=checkpoint, model=model)

    # getting input names of args for method since we don't have Runner
    # and we don't know input_key to preprocess batch for method call
    fn = getattr(model, method_name)
    method_argnames = _get_input_argnames(fn=fn, exclude=["self"])

    batch = {}
    for name in method_argnames:
        # TODO: We don't know input_keys without runner
        assert name in runner.input, (
            "Input batch should contain the same keys as input argument "
            "names of `forward` function to be traced correctly"
        )
        batch[name] = runner.input[name]

    batch = any2device(batch, device)

    # Dumping previous state of the model, we will need it to restore
    _device, _is_training, _requires_grad = (
        runner.device,
        model.training,
        get_requires_grad(model),
    )

    model.to(device)

    # Function to run prediction on batch
    def predict_fn(model: Model, inputs, **kwargs):
        return model(**inputs, **kwargs)

    traced_model = trace_model(
        model=model,
        predict_fn=predict_fn,
        batch=batch,
        method_name=method_name,
        mode=mode,
        requires_grad=requires_grad,
        opt_level=opt_level,
        device=device,
    )

    if checkpoint_name is not None:
        unpack_checkpoint(checkpoint=dumped_checkpoint, model=model)

    # Restore previous state of the model
    getattr(model, "train" if _is_training else "eval")()
    set_requires_grad(model, _requires_grad)
    model.to(_device)

    return traced_model
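
# Usage sketch (assumption, not part of the original module): this function
# reads the current batch from ``runner.input``, so it is meant to be called
# while a run is in flight (e.g. from a callback); the names are hypothetical:
#
#     traced = trace_model_from_runner(
#         runner=runner, checkpoint_name="best", method_name="forward"
#     )
#     save_traced_model(
#         model=traced, logdir=runner.logdir, checkpoint_name="best"
#     )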
def get_trace_name(
    method_name: str,
    mode: str = "eval",
    requires_grad: bool = False,
    opt_level: str = None,
    additional_string: str = None,
) -> str:
    """Creates a file name for the traced model.

    Args:
        method_name (str): model's method name
        mode (str): ``train`` or ``eval``
        requires_grad (bool): flag if model was traced with gradients
        opt_level (str): opt_level if model was traced in FP16
        additional_string (str): any additional information

    Returns:
        file_name (str): Filename for traced model to be saved.
    """
    file_name = "traced"
    if additional_string is not None:
        file_name += f"-{additional_string}"

    file_name += f"-{method_name}"
    if mode == "train":
        file_name += "-in_train"

    if requires_grad:
        file_name += "-with_grad"

    if opt_level is not None:
        file_name += f"-opt_{opt_level}"

    file_name += ".pth"

    return file_name
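
# Examples of generated names (derived from the logic above):
#
#     get_trace_name("forward")
#     # -> "traced-forward.pth"
#     get_trace_name("forward", mode="train", requires_grad=True,
#                    opt_level="O1", additional_string="best")
#     # -> "traced-best-forward-in_train-with_grad-opt_O1.pth"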
def save_traced_model(
    model: ScriptModule,
    logdir: Union[str, Path] = None,
    method_name: str = "forward",
    mode: str = "eval",
    requires_grad: bool = False,
    opt_level: str = None,
    out_dir: Union[str, Path] = None,
    out_model: Union[str, Path] = None,
    checkpoint_name: str = None,
):
    """Saves traced model.

    Args:
        model (ScriptModule): Traced model
        logdir (Union[str, Path]): Path to experiment
        method_name (str): Name of the method that was traced
        mode (str): Model's mode - ``train`` or ``eval``
        requires_grad (bool): Whether the model was traced with
            ``requires_grad`` or not
        opt_level (str): Apex FP16 init level used during tracing
        out_dir (Union[str, Path]): Directory to save model to
            (overrides logdir)
        out_model (Union[str, Path]): Path to save model to
            (overrides logdir & out_dir)
        checkpoint_name (str): Checkpoint name used to restore the model
    """
    if out_model is None:
        file_name = get_trace_name(
            method_name=method_name,
            mode=mode,
            requires_grad=requires_grad,
            opt_level=opt_level,
            additional_string=checkpoint_name,
        )

        output: Path = out_dir
        if output is None:
            if logdir is None:
                raise ValueError(
                    "One of `logdir`, `out_dir` or `out_model` "
                    "should be specified"
                )
            output: Path = Path(logdir) / "trace"

        output.mkdir(exist_ok=True, parents=True)
        out_model = str(output / file_name)
    else:
        out_model = str(out_model)

    save(model, out_model)
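
# Usage sketch (assumption, hypothetical paths): the three destinations
# override each other as documented above. Note that ``out_dir`` is used as
# a ``Path`` directly, so pass a ``Path`` rather than a string.
#
#     save_traced_model(model=traced, logdir="./logs")
#     # -> ./logs/trace/traced-forward.pth
#     save_traced_model(model=traced, out_dir=Path("./out"))
#     # -> ./out/traced-forward.pth
#     save_traced_model(model=traced, out_model="./model.pth")
#     # -> ./model.pth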
def load_traced_model(
    model_path: Union[str, Path],
    device: Device = "cpu",
    opt_level: str = None,
) -> ScriptModule:
    """Loads a traced model.

    Args:
        model_path: Path to traced model
        device (str): Torch device
        opt_level (str): Apex FP16 init level, optional

    Returns:
        (ScriptModule): Traced model
    """
    # jit.load doesn't work with pathlib.Path
    model_path = str(model_path)
    if opt_level is not None:
        device = "cuda"

    model = load(model_path, map_location=device)

    if opt_level is not None:
        assert_fp16_available()
        from apex import amp

        model = amp.initialize(model, optimizers=None, opt_level=opt_level)

    return model
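
# Usage sketch (assumption, hypothetical path; assumes the traced model
# takes a single (4, 10) tensor):
#
#     import torch
#
#     model = load_traced_model("./logs/trace/traced-forward.pth")
#     with torch.no_grad():
#         output = model(torch.randn(4, 10))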
__all__ = [
    "trace_model",
    "trace_model_from_checkpoint",
    "trace_model_from_runner",
    "get_trace_name",
    "save_traced_model",
    "load_traced_model",
]