# Source code for catalyst.metrics._cmc_score

from typing import Dict, Iterable, List, Optional

import torch

from catalyst.metrics._metric import AccumulationMetric
from catalyst.metrics.functional._cmc_score import cmc_score, masked_cmc_score
from catalyst.utils.distributed import get_rank


class CMCMetric(AccumulationMetric):
    """Cumulative Matching Characteristics

    Args:
        embeddings_key: key of embedding tensor in batch
        labels_key: key of label tensor in batch
        is_query_key: key of query flag tensor in batch
        topk_args: list of k, specifies which cmc@k should be calculated
        compute_on_call: if True, allows compute metric's value on call
        prefix: metric prefix
        suffix: metric suffix

    Examples:
        >>> from collections import OrderedDict
        >>> from torch.optim import Adam
        >>> from torch.utils.data import DataLoader
        >>> from catalyst.contrib import nn
        >>> from catalyst.contrib.datasets import MnistMLDataset, MnistQGDataset
        >>> from catalyst.data import BalanceBatchSampler, HardTripletsSampler
        >>> from catalyst.dl import ControlFlowCallback, LoaderMetricCallback, SupervisedRunner
        >>> from catalyst.metrics import CMCMetric
        >>>
        >>> dataset_root = "."
        >>>
        >>> # download dataset for train and val, create loaders
        >>> dataset_train = MnistMLDataset(root=dataset_root, download=True, transform=None)
        >>> sampler = BalanceBatchSampler(labels=dataset_train.get_labels(), p=5, k=10)
        >>> train_loader = DataLoader(
        >>>     dataset=dataset_train, sampler=sampler, batch_size=sampler.batch_size
        >>> )
        >>> dataset_valid = MnistQGDataset(root=dataset_root, transform=None, gallery_fraq=0.2)
        >>> valid_loader = DataLoader(dataset=dataset_valid, batch_size=1024)
        >>>
        >>> # model, optimizer, criterion
        >>> model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 100))
        >>> optimizer = Adam(model.parameters())
        >>> sampler_inbatch = HardTripletsSampler(norm_required=False)
        >>> criterion = nn.TripletMarginLossWithSampler(
        >>>     margin=0.5, sampler_inbatch=sampler_inbatch
        >>> )
        >>>
        >>> # batch data processing
        >>> class CustomRunner(SupervisedRunner):
        >>>     def handle_batch(self, batch):
        >>>         if self.is_train_loader:
        >>>             images, targets = batch["features"].float(), batch["targets"].long()
        >>>             features = model(images)
        >>>             self.batch = {
        >>>                 "embeddings": features,
        >>>                 "targets": targets,
        >>>             }
        >>>         else:
        >>>             images, targets, is_query = (
        >>>                 batch["features"].float(),
        >>>                 batch["targets"].long(),
        >>>                 batch["is_query"].bool(),
        >>>             )
        >>>             features = model(images)
        >>>             self.batch = {
        >>>                 "embeddings": features,
        >>>                 "targets": targets,
        >>>                 "is_query": is_query,
        >>>             }
        >>>
        >>> # training
        >>> runner = CustomRunner(input_key="features", output_key="embeddings")
        >>> runner.train(
        >>>     model=model,
        >>>     criterion=criterion,
        >>>     optimizer=optimizer,
        >>>     callbacks=OrderedDict(
        >>>         {
        >>>             "cmc": ControlFlowCallback(
        >>>                 LoaderMetricCallback(
        >>>                     CMCMetric(
        >>>                         embeddings_key="embeddings",
        >>>                         labels_key="targets",
        >>>                         is_query_key="is_query",
        >>>                         topk_args=(1, 3)
        >>>                     ),
        >>>                     input_key=["embeddings", "is_query"],
        >>>                     target_key=["targets"]
        >>>                 ),
        >>>                 loaders="valid",
        >>>             ),
        >>>         }
        >>>     ),
        >>>     loaders=OrderedDict({"train": train_loader, "valid": valid_loader}),
        >>>     valid_loader="valid",
        >>>     valid_metric="cmc01",
        >>>     minimize_valid_metric=False,
        >>>     logdir="./logs",
        >>>     verbose=True,
        >>>     num_epochs=3,
        >>> )
    """

    def __init__(
        self,
        embeddings_key: str,
        labels_key: str,
        is_query_key: str,
        # fix: was `Iterable[int] = None` — implicit Optional is invalid per PEP 484
        topk_args: Optional[Iterable[int]] = None,
        compute_on_call: bool = True,
        prefix: Optional[str] = None,
        suffix: Optional[str] = None,
    ) -> None:
        """Init CMCMetric"""
        super().__init__(
            compute_on_call=compute_on_call,
            prefix=prefix,
            suffix=suffix,
            # these batch keys are accumulated across the whole loader
            accumulative_fields=[embeddings_key, labels_key, is_query_key],
        )
        self.embeddings_key = embeddings_key
        self.labels_key = labels_key
        self.is_query_key = is_query_key
        # default to cmc@1 when no explicit k values were requested
        self.topk_args = topk_args or (1,)
        self.metric_name = f"{self.prefix}cmc{self.suffix}"

    def reset(self, num_batches: int, num_samples: int) -> None:
        """
        Reset metrics fields

        Args:
            num_batches: expected number of batches
            num_samples: expected number of samples to accumulate
        """
        super().reset(num_batches, num_samples)
        assert get_rank() < 0, "No DDP support implemented yet"

    def compute(self) -> List[float]:
        """
        Compute cmc@k metrics with all the accumulated data for all k.

        Returns:
            list of metrics values
        """
        # samples flagged with is_query == 1 form the query set; the rest form the gallery
        query_mask = (self.storage[self.is_query_key] == 1).to(torch.bool)

        embeddings = self.storage[self.embeddings_key].float()
        labels = self.storage[self.labels_key]

        query_embeddings = embeddings[query_mask]
        query_labels = labels[query_mask]
        gallery_embeddings = embeddings[~query_mask]
        gallery_labels = labels[~query_mask]

        # (num_query, num_gallery) boolean matrix: True where labels match
        conformity_matrix = (gallery_labels == query_labels.reshape(-1, 1)).to(torch.bool)

        metrics = []
        for k in self.topk_args:
            value = cmc_score(
                query_embeddings=query_embeddings,
                gallery_embeddings=gallery_embeddings,
                conformity_matrix=conformity_matrix,
                topk=k,
            )
            metrics.append(value)
        return metrics

    def compute_key_value(self) -> Dict[str, float]:
        """
        Compute cmc@k metrics with all the accumulated data for all k.

        Returns:
            metrics values in key-value format
        """
        values = self.compute()
        # e.g. topk_args=(1, 3) -> keys "cmc01", "cmc03" (zero-padded k)
        kv_metrics = {
            f"{self.metric_name}{k:02d}": value for k, value in zip(self.topk_args, values)
        }
        return kv_metrics
class ReidCMCMetric(AccumulationMetric):
    """Cumulative Matching Characteristics for Reid case

    Args:
        embeddings_key: key of embedding tensor in batch
        pids_key: key of pids tensor in batch
        cids_key: key of cids tensor in batch
        is_query_key: key of query flag tensor in batch
        topk_args: list of k, specifies which cmc@k should be calculated
        compute_on_call: if True, allows compute metric's value on call
        prefix: metric prefix
        suffix: metric suffix
    """

    def __init__(
        self,
        embeddings_key: str,
        pids_key: str,
        cids_key: str,
        is_query_key: str,
        # fix: was `Iterable[int] = None` — implicit Optional is invalid per PEP 484
        topk_args: Optional[Iterable[int]] = None,
        compute_on_call: bool = True,
        prefix: Optional[str] = None,
        suffix: Optional[str] = None,
    ) -> None:
        """Init CMCMetric"""
        super().__init__(
            compute_on_call=compute_on_call,
            prefix=prefix,
            suffix=suffix,
            # these batch keys are accumulated across the whole loader
            accumulative_fields=[embeddings_key, pids_key, cids_key, is_query_key],
        )
        self.embeddings_key = embeddings_key
        self.pids_key = pids_key
        self.cids_key = cids_key
        self.is_query_key = is_query_key
        # default to cmc@1 when no explicit k values were requested
        self.topk_args = topk_args or (1,)
        self.metric_name = f"{self.prefix}cmc{self.suffix}"

    def reset(self, num_batches: int, num_samples: int) -> None:
        """
        Reset metrics fields

        Args:
            num_batches: expected number of batches
            num_samples: expected number of samples to accumulate
        """
        super().reset(num_batches, num_samples)
        assert get_rank() < 0, "No DDP support implemented yet"

    def compute(self) -> List[float]:
        """
        Compute cmc@k metrics with all the accumulated data for all k.

        Returns:
            list of metrics values

        Raises:
            ValueError: if there are samples in query that have no relevant
                samples in gallery
        """
        # samples flagged with is_query == 1 form the query set; the rest form the gallery
        query_mask = (self.storage[self.is_query_key] == 1).to(torch.bool)

        embeddings = self.storage[self.embeddings_key].float()
        pids = self.storage[self.pids_key]
        cids = self.storage[self.cids_key]

        query_embeddings = embeddings[query_mask]
        query_pids = pids[query_mask]
        query_cids = cids[query_mask]
        gallery_embeddings = embeddings[~query_mask]
        gallery_pids = pids[~query_mask]
        gallery_cids = cids[~query_mask]

        # (num_query, num_gallery) boolean matrices: True where person / camera ids match
        pid_conformity_matrix = (gallery_pids == query_pids.reshape(-1, 1)).bool()
        cid_conformity_matrix = (gallery_cids == query_cids.reshape(-1, 1)).bool()

        # Now we are going to generate a mask that should show if
        # a sample from gallery can be used during model scoring on the query
        # sample.
        # There is only one case when the label shouldn't be used for:
        # if query sample is a photo of the person pid_i taken from camera
        # cam_j and the gallery sample is a photo of the same person pid_i
        # from the same camera cam_j. All other cases are available.
        available_samples = ~(pid_conformity_matrix * cid_conformity_matrix).bool()

        if (available_samples.max(dim=1).values == 0).any():
            raise ValueError("There is a sample in query that has no relevant samples in gallery.")

        metrics = []
        for k in self.topk_args:
            value = masked_cmc_score(
                query_embeddings=query_embeddings,
                gallery_embeddings=gallery_embeddings,
                conformity_matrix=pid_conformity_matrix,
                available_samples=available_samples,
                topk=k,
            )
            metrics.append(value)
        return metrics

    def compute_key_value(self) -> Dict[str, float]:
        """
        Compute cmc@k metrics with all the accumulated data for all k.

        Returns:
            metrics values in key-value format
        """
        values = self.compute()
        # e.g. topk_args=(1, 3) -> keys "cmc01", "cmc03" (zero-padded k)
        kv_metrics = {
            f"{self.metric_name}{k:02d}": value for k, value in zip(self.topk_args, values)
        }
        return kv_metrics
# Public API of this module.
__all__ = ["CMCMetric", "ReidCMCMetric"]