"""Source code for ``catalyst.runners.self_supervised``."""

import collections
import collections.abc
from typing import Any, Mapping

from torch import nn

from catalyst.core.runner import IRunner


class ISelfSupervisedRunner(IRunner):
    """IRunner for experiments with contrastive model.

    Args:
        input_key: key in ``runner.batch`` dict mapping for model input
        target_key: key in ``runner.batch`` dict mapping for target
        loss_key: key for ``runner.batch_metrics`` to store criterion loss output
        augemention_prefix: key prefix in ``runner.batch`` for the augmented views;
            the views are read from ``{prefix}_left`` and ``{prefix}_right``
        projection_prefix: key prefix for ``runner.batch`` to store model projections
        embedding_prefix: key prefix for ``runner.batch`` to store model embeddings

    Abstraction, please check out implementations for more details:

        - :py:mod:`catalyst.runners.contrastive.ContrastiveRunner`

    .. note::
        ISelfSupervisedRunner contains only the logic with batch handling.

    ISelfSupervisedRunner logic pseudocode (with the default keys):

    .. code-block:: python

        batch = {"features": tensor, "augment_left": tensor, "augment_right": tensor, ...}
        emb_left, proj_left = model(batch["augment_left"])
        emb_right, proj_right = model(batch["augment_right"])
        # the un-augmented ``features`` are encoded too and stored under ``*_origin`` keys
        loss = criterion(proj_left, proj_right)
        batch_metrics[loss_key] = loss

    Examples:

    .. code-block:: python

        # 1. loader and transforms

        transforms = Compose(
            [
                ToTensor(),
                Normalize((0.1307,), (0.3081,)),
                torchvision.transforms.RandomCrop((28, 28)),
                torchvision.transforms.RandomVerticalFlip(),
                torchvision.transforms.RandomHorizontalFlip(),
            ]
        )

        mnist = MNIST("./logdir", train=True, download=True, transform=None)
        contrastive_mnist = ContrastiveDataset(mnist, transforms=transforms)

        train_loader = torch.utils.data.DataLoader(contrastive_mnist, batch_size=BATCH_SIZE)

        # 2. model and optimizer
        encoder = MnistSimpleNet(out_features=16)
        projection_head = nn.Sequential(
            nn.Linear(16, 16, bias=False), nn.ReLU(inplace=True), nn.Linear(16, 16, bias=True)
        )

        class ContrastiveModel(torch.nn.Module):
            def __init__(self, model, encoder):
                super(ContrastiveModel, self).__init__()
                self.model = model
                self.encoder = encoder

            def forward(self, x):
                emb = self.encoder(x)
                projection = self.model(emb)
                return emb, projection

        model = ContrastiveModel(model=projection_head, encoder=encoder)

        optimizer = Adam(model.parameters(), lr=LR)

        # 3. criterion and callbacks
        criterion = NTXentLoss(tau=0.1)

        callbacks = [
            dl.ControlFlowCallback(
                dl.CriterionCallback(
                    input_key="projection_left", target_key="projection_right", metric_key="loss"
                ),
                loaders="train",
            ),
            dl.SklearnModelCallback(
                feature_key="embedding_left",
                target_key="target",
                train_loader="train",
                valid_loaders="valid",
                model_fn=RandomForestClassifier,
                predict_method="predict_proba",
                predict_key="sklearn_predict",
                random_state=RANDOM_STATE,
                n_estimators=10,
            ),
            dl.ControlFlowCallback(
                dl.AccuracyCallback(
                    target_key="target", input_key="sklearn_predict", topk_args=(1, 3)
                ),
                loaders="valid",
            ),
        ]

        runner = dl.ContrastiveRunner()

        logdir = "./logdir"
        runner.train(
            model=model,
            engine=engine or dl.DeviceEngine(device),
            criterion=criterion,
            optimizer=optimizer,
            callbacks=callbacks,
            loaders={"train": train_loader, "valid": train_loader},
            verbose=True,
            logdir=logdir,
            valid_loader="train",
            valid_metric="loss",
            minimize_valid_metric=True,
            num_epochs=10,
        )

    .. note::
        Please follow the `minimal examples`_ sections for use cases.

        .. _`minimal examples`: https://github.com/catalyst-team/catalyst#minimal-examples
    """

    def __init__(
        self,
        input_key: str = "features",
        target_key: str = "target",
        loss_key: str = "loss",
        augemention_prefix: str = "augment",
        projection_prefix: str = "projection",
        embedding_prefix: str = "embedding",
    ):
        """Init."""
        # Explicit base call kept as in the original (not ``super()``); subclasses
        # may combine this class with other runners via multiple inheritance.
        IRunner.__init__(self)
        self._target_key = target_key
        self._loss_key = loss_key
        self._projection_prefix = projection_prefix
        self._augemention_prefix = augemention_prefix
        self._embedding_prefix = embedding_prefix
        self._input_key = input_key

    def _process_batch(self, batch):
        """Convert a tuple/list batch into a dict keyed by this runner's keys.

        A 3-element sequence is read as ``(input, aug_left, aug_right)`` and a
        4-element one as ``(input, aug_left, aug_right, target)``. Any other batch
        type (e.g. an already-built dict) is returned unchanged.

        Raises:
            ValueError: if a tuple/list batch does not have 3 or 4 elements.
        """
        if not isinstance(batch, (tuple, list)):
            return batch
        # Explicit check instead of ``assert``: asserts are stripped under ``python -O``,
        # which would let a malformed batch through and fail later with a KeyError.
        if len(batch) not in (3, 4):
            raise ValueError(
                "Expected a batch of 3 (input, aug_left, aug_right) or "
                f"4 (input, aug_left, aug_right, target) elements, got {len(batch)}"
            )
        processed = {
            self._input_key: batch[0],
            f"{self._augemention_prefix}_left": batch[1],
            f"{self._augemention_prefix}_right": batch[2],
        }
        if len(batch) == 4:
            processed[self._target_key] = batch[3]
        return processed

    def _is_key_value_model(self) -> bool:
        """Return True when ``self.model`` is a name -> encoder mapping."""
        # ``collections.Mapping`` was removed in Python 3.10; use ``collections.abc``.
        return isinstance(self.model, (collections.abc.Mapping, nn.ModuleDict))

    def on_stage_start(self, runner: "IRunner"):
        """on_stage_start event handler.

        Caches whether the model is a mapping of several encoders.
        """
        super().on_stage_start(runner)
        self.is_kv_model = self._is_key_value_model()

    def _process_input(self, batch: Mapping[str, Any], **kwargs):
        """Run every encoder on the left/right/origin views and collect the outputs.

        Each encoder is called as ``encoder(view, **kwargs)`` and must return an
        ``(embedding, projection)`` pair. For a key-value model, output keys get the
        ``"{encoder_name}_"`` prefix; for a single model they are unprefixed.

        Returns:
            a new dict with the original ``batch`` items plus
            ``{prefix}{projection_prefix}_{left,right,origin}`` and
            ``{prefix}{embedding_prefix}_{left,right,origin}`` entries.
        """
        is_kv_model = getattr(self, "is_kv_model", None)
        if is_kv_model is None:
            # ``forward`` called outside a stage (before ``on_stage_start``):
            # compute the flag here instead of failing with AttributeError.
            is_kv_model = self._is_key_value_model()

        if is_kv_model:
            encoders = [(encoder_name, self.model[encoder_name]) for encoder_name in self.model]
        else:
            encoders = [("", self.model)]

        left_key = f"{self._augemention_prefix}_left"
        right_key = f"{self._augemention_prefix}_right"
        for encoder_name, encoder in encoders:
            embedding_left, projection_left = encoder(batch[left_key], **kwargs)
            embedding_right, projection_right = encoder(batch[right_key], **kwargs)
            embedding_origin, projection_origin = encoder(batch[self._input_key], **kwargs)
            prefix = f"{encoder_name}_" if encoder_name else ""
            batch = {
                **batch,
                f"{prefix}{self._projection_prefix}_left": projection_left,
                f"{prefix}{self._projection_prefix}_right": projection_right,
                f"{prefix}{self._projection_prefix}_origin": projection_origin,
                f"{prefix}{self._embedding_prefix}_left": embedding_left,
                f"{prefix}{self._embedding_prefix}_right": embedding_right,
                f"{prefix}{self._embedding_prefix}_origin": embedding_origin,
            }
        return batch

    def on_batch_start(self, runner: "IRunner"):
        """Event handler: normalize tuple/list batches to dicts before the base handler."""
        self.batch = self._process_batch(self.batch)
        super().on_batch_start(runner)

    def forward(self, batch: Mapping[str, Any], **kwargs) -> Mapping[str, Any]:
        """
        Forward method for your Runner.
        Should not be called directly outside of runner.
        If your model has specific interface, override this method to use it

        Args:
            batch (Mapping[str, Any]): dictionary with data batches
                from DataLoaders.
            **kwargs: additional parameters to pass to the model

        Returns:
            dict with model output batch
        """
        return self._process_input(batch, **kwargs)

    def handle_batch(self, batch: Mapping[str, Any]) -> None:
        """
        Inner method to handle specified data batch.
        Used to make a train/valid/infer stage during Experiment run.

        Args:
            batch: dictionary with data batches from DataLoader.
        """
        self.batch = {**batch, **self.forward(batch)}
# Names exported by ``from catalyst.runners.self_supervised import *``.
__all__ = [
    "ISelfSupervisedRunner",
]