from typing import Dict, Iterable, List, Optional
import torch
from catalyst.metrics._metric import AccumulationMetric
from catalyst.metrics.functional._cmc_score import cmc_score, masked_cmc_score
from catalyst.utils.distributed import get_rank


class CMCMetric(AccumulationMetric):
"""Cumulative Matching Characteristics
Args:
embeddings_key: key of embedding tensor in batch
labels_key: key of label tensor in batch
is_query_key: key of query flag tensor in batch
        topk_args: list of k values that specify which cmc@k metrics should be calculated
        compute_on_call: if True, allows computing the metric's value on call
prefix: metric prefix
suffix: metric suffix

    Examples:
>>> from collections import OrderedDict
>>> from torch.optim import Adam
>>> from torch.utils.data import DataLoader
>>> from catalyst.contrib import nn
>>> from catalyst.contrib.datasets import MnistMLDataset, MnistQGDataset
>>> from catalyst.data import BalanceBatchSampler, HardTripletsSampler
>>> from catalyst.dl import ControlFlowCallback, LoaderMetricCallback, SupervisedRunner
>>> from catalyst.metrics import CMCMetric
>>>
>>> dataset_root = "."
>>>
>>> # download dataset for train and val, create loaders
>>> dataset_train = MnistMLDataset(root=dataset_root, download=True, transform=None)
>>> sampler = BalanceBatchSampler(labels=dataset_train.get_labels(), p=5, k=10)
>>> train_loader = DataLoader(
>>> dataset=dataset_train, sampler=sampler, batch_size=sampler.batch_size
>>> )
>>> dataset_valid = MnistQGDataset(root=dataset_root, transform=None, gallery_fraq=0.2)
>>> valid_loader = DataLoader(dataset=dataset_valid, batch_size=1024)
>>>
>>> # model, optimizer, criterion
>>> model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 100))
>>> optimizer = Adam(model.parameters())
>>> sampler_inbatch = HardTripletsSampler(norm_required=False)
>>> criterion = nn.TripletMarginLossWithSampler(
>>> margin=0.5, sampler_inbatch=sampler_inbatch
>>> )
>>>
>>> # batch data processing
>>> class CustomRunner(SupervisedRunner):
>>> def handle_batch(self, batch):
>>> if self.is_train_loader:
>>> images, targets = batch["features"].float(), batch["targets"].long()
>>> features = model(images)
>>> self.batch = {
>>> "embeddings": features,
>>> "targets": targets,
>>> }
>>> else:
>>> images, targets, is_query = (
>>> batch["features"].float(),
>>> batch["targets"].long(),
>>> batch["is_query"].bool(),
>>> )
>>> features = model(images)
>>> self.batch = {
>>> "embeddings": features,
>>> "targets": targets,
>>> "is_query": is_query,
>>> }
>>>
>>> # training
>>> runner = CustomRunner(input_key="features", output_key="embeddings")
>>> runner.train(
>>> model=model,
>>> criterion=criterion,
>>> optimizer=optimizer,
>>> callbacks=OrderedDict(
>>> {
>>> "cmc": ControlFlowCallback(
>>> LoaderMetricCallback(
>>> CMCMetric(
>>> embeddings_key="embeddings",
>>> labels_key="targets",
>>> is_query_key="is_query",
>>> topk_args=(1, 3)
>>> ),
>>> input_key=["embeddings", "is_query"],
>>> target_key=["targets"]
>>> ),
>>> loaders="valid",
>>> ),
>>> }
>>> ),
>>> loaders=OrderedDict({"train": train_loader, "valid": valid_loader}),
>>> valid_loader="valid",
>>> valid_metric="cmc01",
>>> minimize_valid_metric=False,
>>> logdir="./logs",
>>> verbose=True,
>>> num_epochs=3,
>>> )
"""

    def __init__(
self,
embeddings_key: str,
labels_key: str,
is_query_key: str,
        topk_args: Optional[Iterable[int]] = None,
compute_on_call: bool = True,
prefix: Optional[str] = None,
suffix: Optional[str] = None,
) -> None:
"""Init CMCMetric"""
super().__init__(
compute_on_call=compute_on_call,
prefix=prefix,
suffix=suffix,
accumulative_fields=[embeddings_key, labels_key, is_query_key],
)
self.embeddings_key = embeddings_key
self.labels_key = labels_key
self.is_query_key = is_query_key
self.topk_args = topk_args or (1,)
self.metric_name = f"{self.prefix}cmc{self.suffix}"

    def reset(self, num_batches: int, num_samples: int) -> None:
"""
Reset metrics fields
Args:
num_batches: expected number of batches
num_samples: expected number of samples to accumulate
"""
super().reset(num_batches, num_samples)
assert get_rank() < 0, "No DDP support implemented yet"

    def compute(self) -> List[float]:
"""
Compute cmc@k metrics with all the accumulated data for all k.
Returns:
list of metrics values
"""
query_mask = (self.storage[self.is_query_key] == 1).to(torch.bool)
embeddings = self.storage[self.embeddings_key].float()
labels = self.storage[self.labels_key]
query_embeddings = embeddings[query_mask]
query_labels = labels[query_mask]
gallery_embeddings = embeddings[~query_mask]
gallery_labels = labels[~query_mask]
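        # conformity_matrix has shape [n_query, n_gallery];
        # conformity_matrix[i, j] is True iff query sample i and gallery
        # sample j share the same label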
conformity_matrix = (gallery_labels == query_labels.reshape(-1, 1)).to(torch.bool)
metrics = []
for k in self.topk_args:
value = cmc_score(
query_embeddings=query_embeddings,
gallery_embeddings=gallery_embeddings,
conformity_matrix=conformity_matrix,
topk=k,
)
metrics.append(value)
return metrics

    def compute_key_value(self) -> Dict[str, float]:
"""
Compute cmc@k metrics with all the accumulated data for all k.
Returns:
metrics values in key-value format
"""
values = self.compute()
kv_metrics = {
f"{self.metric_name}{k:02d}": value for k, value in zip(self.topk_args, values)
}
return kv_metrics
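
# A minimal direct-usage sketch of CMCMetric outside of a Runner (illustrative
# only; it assumes that ``AccumulationMetric.update`` accepts the accumulated
# fields as keyword arguments):
#
#     metric = CMCMetric(
#         embeddings_key="embeddings",
#         labels_key="targets",
#         is_query_key="is_query",
#         topk_args=(1, 3),
#     )
#     metric.reset(num_batches=1, num_samples=6)
#     metric.update(
#         embeddings=torch.rand(6, 8),
#         targets=torch.tensor([0, 0, 1, 0, 0, 1]),
#         is_query=torch.tensor([1, 1, 1, 0, 0, 0]),
#     )
#     metric.compute_key_value()  # -> {"cmc01": ..., "cmc03": ...}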


class ReidCMCMetric(AccumulationMetric):
"""Cumulative Matching Characteristics for Reid case
Args:
embeddings_key: key of embedding tensor in batch
        pids_key: key of the person ids (pids) tensor in batch
        cids_key: key of the camera ids (cids) tensor in batch
is_query_key: key of query flag tensor in batch
        topk_args: list of k values that specify which cmc@k metrics should be calculated
        compute_on_call: if True, allows computing the metric's value on call
prefix: metric prefix
suffix: metric suffix
"""

    def __init__(
self,
embeddings_key: str,
pids_key: str,
cids_key: str,
is_query_key: str,
        topk_args: Optional[Iterable[int]] = None,
compute_on_call: bool = True,
prefix: Optional[str] = None,
suffix: Optional[str] = None,
) -> None:
"""Init CMCMetric"""
super().__init__(
compute_on_call=compute_on_call,
prefix=prefix,
suffix=suffix,
accumulative_fields=[embeddings_key, pids_key, cids_key, is_query_key],
)
self.embeddings_key = embeddings_key
self.pids_key = pids_key
self.cids_key = cids_key
self.is_query_key = is_query_key
self.topk_args = topk_args or (1,)
self.metric_name = f"{self.prefix}cmc{self.suffix}"

    def reset(self, num_batches: int, num_samples: int) -> None:
"""
Reset metrics fields
Args:
num_batches: expected number of batches
num_samples: expected number of samples to accumulate
"""
super().reset(num_batches, num_samples)
assert get_rank() < 0, "No DDP support implemented yet"

    def compute(self) -> List[float]:
"""
Compute cmc@k metrics with all the accumulated data for all k.
Returns:
list of metrics values
Raises:
ValueError: if there are samples in query that have no relevant samples in gallery
"""
query_mask = (self.storage[self.is_query_key] == 1).to(torch.bool)
embeddings = self.storage[self.embeddings_key].float()
pids = self.storage[self.pids_key]
cids = self.storage[self.cids_key]
query_embeddings = embeddings[query_mask]
query_pids = pids[query_mask]
query_cids = cids[query_mask]
gallery_embeddings = embeddings[~query_mask]
gallery_pids = pids[~query_mask]
gallery_cids = cids[~query_mask]
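        # Both matrices below have shape [n_query, n_gallery]:
        # pid_conformity_matrix[i, j] is True iff query i and gallery j show
        # the same person, cid_conformity_matrix[i, j] is True iff they were
        # taken from the same camera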
pid_conformity_matrix = (gallery_pids == query_pids.reshape(-1, 1)).bool()
cid_conformity_matrix = (gallery_cids == query_cids.reshape(-1, 1)).bool()
# Now we are going to generate a mask that should show if
# a sample from gallery can be used during model scoring on the query
# sample.
# There is only one case when the label shouldn't be used for:
# if query sample is a photo of the person pid_i taken from camera
# cam_j and the gallery sample is a photo of the same person pid_i
# from the same camera cam_j. All other cases are available.
available_samples = ~(pid_conformity_matrix * cid_conformity_matrix).bool()
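        # E.g. for a query photo with (pid=3, cid=1): a gallery photo with
        # (pid=3, cid=1) is masked out (same person, same camera), while a
        # gallery photo with (pid=3, cid=2) stays available and counts as a
        # correct match for cmc@k.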
if (available_samples.max(dim=1).values == 0).any():
raise ValueError("There is a sample in query that has no relevant samples in gallery.")
metrics = []
for k in self.topk_args:
value = masked_cmc_score(
query_embeddings=query_embeddings,
gallery_embeddings=gallery_embeddings,
conformity_matrix=pid_conformity_matrix,
available_samples=available_samples,
topk=k,
)
metrics.append(value)
return metrics

    def compute_key_value(self) -> Dict[str, float]:
"""
Compute cmc@k metrics with all the accumulated data for all k.
Returns:
metrics values in key-value format
"""
values = self.compute()
kv_metrics = {
f"{self.metric_name}{k:02d}": value for k, value in zip(self.topk_args, values)
}
return kv_metrics
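
# A minimal direct-usage sketch of ReidCMCMetric (illustrative only; as above,
# it assumes that ``AccumulationMetric.update`` accepts the accumulated fields
# as keyword arguments). Note that every query sample must keep at least one
# available gallery sample, otherwise ``compute`` raises a ValueError:
#
#     metric = ReidCMCMetric(
#         embeddings_key="embeddings",
#         pids_key="pids",
#         cids_key="cids",
#         is_query_key="is_query",
#         topk_args=(1,),
#     )
#     metric.reset(num_batches=1, num_samples=6)
#     metric.update(
#         embeddings=torch.rand(6, 8),
#         pids=torch.tensor([0, 0, 1, 0, 0, 1]),
#         cids=torch.tensor([0, 1, 0, 1, 0, 1]),
#         is_query=torch.tensor([1, 1, 1, 0, 0, 0]),
#     )
#     metric.compute_key_value()  # -> {"cmc01": ...}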


__all__ = ["CMCMetric", "ReidCMCMetric"]