from typing import Iterable, Optional, TYPE_CHECKING

from accelerate.state import DistributedType

from catalyst.callbacks.metric import LoaderMetricCallback
from catalyst.metrics._cmc_score import CMCMetric, ReidCMCMetric

if TYPE_CHECKING:
    from catalyst.core.runner import IRunner
class CMCScoreCallback(LoaderMetricCallback):
    """
    Cumulative Matching Characteristics callback.

    This callback was designed to count
    cumulative matching characteristics.
    If the current object is from the query set, your dataset
    should output ``True`` in ``is_query_key``,
    and ``False`` if the current object is from the gallery.
    You can see ``QueryGalleryDataset`` in
    ``catalyst.contrib.datasets.metric_learning`` for more information.
    On batch end the callback accumulates all embeddings.

    Args:
        embeddings_key: embeddings key in output dict
        labels_key: labels key in output dict
        is_query_key: bool key, True if current object is from query
        topk: specifies which cmc@K to log
        prefix: metric prefix
        suffix: metric suffix

    .. note::
        You should use it with ``ControlFlowCallback``
        and add all query/gallery sets to loaders.
        Loaders should contain "is_query" and "label" keys.

    Examples:

    .. code-block:: python

        import os
        from torch.optim import Adam
        from torch.utils.data import DataLoader
        from catalyst import data, dl
        from catalyst.contrib import data, datasets, models, nn

        # 1. train and valid loaders
        transforms = data.Compose([
            data.ImageToTensor(), data.NormalizeImage((0.1307,), (0.3081,))
        ])

        train_dataset = datasets.MnistMLDataset(
            root=os.getcwd(), download=True, transform=transforms
        )
        sampler = data.BatchBalanceClassSampler(
            labels=train_dataset.get_labels(), num_classes=5, num_samples=10
        )
        train_loader = DataLoader(dataset=train_dataset, batch_sampler=sampler)

        valid_dataset = datasets.MnistQGDataset(
            root=os.getcwd(), transform=transforms, gallery_fraq=0.2
        )
        valid_loader = DataLoader(dataset=valid_dataset, batch_size=1024)

        # 2. model and optimizer
        model = models.MnistSimpleNet(out_features=16)
        optimizer = Adam(model.parameters(), lr=0.001)

        # 3. criterion with triplets sampling
        sampler_inbatch = data.HardTripletsSampler(norm_required=False)
        criterion = nn.TripletMarginLossWithSampler(
            margin=0.5, sampler_inbatch=sampler_inbatch
        )

        # 4. training with catalyst Runner
        class CustomRunner(dl.SupervisedRunner):
            def handle_batch(self, batch) -> None:
                if self.is_train_loader:
                    images, targets = batch["features"].float(), batch["targets"].long()
                    features = self.model(images)
                    self.batch = {"embeddings": features, "targets": targets}
                else:
                    images, targets, is_query = (
                        batch["features"].float(),
                        batch["targets"].long(),
                        batch["is_query"].bool()
                    )
                    features = self.model(images)
                    self.batch = {
                        "embeddings": features, "targets": targets, "is_query": is_query
                    }

        callbacks = [
            dl.ControlFlowCallback(
                dl.CriterionCallback(
                    input_key="embeddings", target_key="targets", metric_key="loss"
                ),
                loaders="train",
            ),
            dl.ControlFlowCallback(
                dl.CMCScoreCallback(
                    embeddings_key="embeddings",
                    labels_key="targets",
                    is_query_key="is_query",
                    topk=[1],
                ),
                loaders="valid",
            ),
            dl.PeriodicLoaderCallback(
                valid_loader_key="valid",
                valid_metric_key="cmc01",
                minimize=False,
                valid=2
            ),
        ]

        runner = CustomRunner(input_key="features", output_key="embeddings")
        runner.train(
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            callbacks=callbacks,
            loaders={"train": train_loader, "valid": valid_loader},
            verbose=False,
            logdir="./logs",
            valid_loader="valid",
            valid_metric="cmc01",
            minimize_valid_metric=False,
            num_epochs=10,
        )

    .. note::
        Metric names depending on input parameters:

        - ``topk = (1,) or None`` ---> ``"cmc01"``
        - ``topk = (1, 3)`` ---> ``"cmc01"``, ``"cmc03"``
        - ``topk = (1, 3, 5)`` ---> ``"cmc01"``, ``"cmc03"``, ``"cmc05"``

        You can find them in ``runner.batch_metrics``, ``runner.loader_metrics`` or
        ``runner.epoch_metrics``.

    .. note::
        Please follow the `minimal examples`_ sections for more use cases.

        .. _`minimal examples`: https://github.com/catalyst-team/catalyst#minimal-examples  # noqa: E501, W505
    """

    def __init__(
        self,
        embeddings_key: str,
        labels_key: str,
        is_query_key: str,
        topk: Optional[Iterable[int]] = None,
        prefix: Optional[str] = None,
        suffix: Optional[str] = None,
    ):
        """Build the underlying ``CMCMetric`` and register the batch keys.

        The embeddings and query-flag keys are passed to the parent as
        ``input_key``; the labels key is passed as ``target_key``.
        """
        super().__init__(
            metric=CMCMetric(
                embeddings_key=embeddings_key,
                labels_key=labels_key,
                is_query_key=is_query_key,
                topk=topk,
                prefix=prefix,
                suffix=suffix,
            ),
            input_key=[embeddings_key, is_query_key],
            target_key=[labels_key],
        )

    def on_experiment_start(self, runner: "IRunner") -> None:
        """Reject multi-GPU/TPU engines, then delegate to the parent handler.

        Args:
            runner: runner whose ``engine.distributed_type`` is inspected

        Raises:
            AssertionError: if the engine's distributed type is
                ``MULTI_GPU`` or ``TPU``
        """
        # NOTE(review): presumably the metric needs every query/gallery
        # embedding in a single process, hence the ddp restriction - confirm.
        # The check is an ``assert`` and is therefore skipped under ``python -O``.
        assert runner.engine.distributed_type not in (
            DistributedType.MULTI_GPU,
            DistributedType.TPU,
        ), "CMCScoreCallback could not work within ddp training"
        return super().on_experiment_start(runner)
class ReidCMCScoreCallback(LoaderMetricCallback):
    """
    Cumulative Matching Characteristics callback for reID case.
    More information about cmc-based callbacks in CMCScoreCallback's docs.

    Args:
        embeddings_key: embeddings key in output dict
        pids_key: pids key in output dict
        cids_key: cids key in output dict
        is_query_key: bool key, True if current object is from query
        topk: specifies which cmc@K to log.
            [1] - cmc@1
            [1, 3] - cmc@1 and cmc@3
            [1, 3, 5] - cmc@1, cmc@3 and cmc@5
        prefix: metric prefix
        suffix: metric suffix
    """

    def __init__(
        self,
        embeddings_key: str,
        pids_key: str,
        cids_key: str,
        is_query_key: str,
        topk: Optional[Iterable[int]] = None,
        prefix: Optional[str] = None,
        suffix: Optional[str] = None,
    ):
        """Build the underlying ``ReidCMCMetric`` and register the batch keys.

        The embeddings and query-flag keys are passed to the parent as
        ``input_key``; the pids and cids keys are passed as ``target_key``.
        """
        super().__init__(
            metric=ReidCMCMetric(
                embeddings_key=embeddings_key,
                pids_key=pids_key,
                cids_key=cids_key,
                is_query_key=is_query_key,
                topk=topk,
                prefix=prefix,
                suffix=suffix,
            ),
            input_key=[embeddings_key, is_query_key],
            target_key=[pids_key, cids_key],
        )

    def on_experiment_start(self, runner: "IRunner") -> None:
        """Reject multi-GPU/TPU engines, then delegate to the parent handler.

        Args:
            runner: runner whose ``engine.distributed_type`` is inspected

        Raises:
            AssertionError: if the engine's distributed type is
                ``MULTI_GPU`` or ``TPU``
        """
        # NOTE(review): presumably the metric needs every query/gallery
        # embedding in a single process, hence the ddp restriction - confirm.
        # The check is an ``assert`` and is therefore skipped under ``python -O``.
        assert runner.engine.distributed_type not in (
            DistributedType.MULTI_GPU,
            DistributedType.TPU,
        ), "ReidCMCScoreCallback could not work within ddp training"
        return super().on_experiment_start(runner)
# Public API of this module: the two CMC-based loader callbacks.
__all__ = [
    "CMCScoreCallback",
    "ReidCMCScoreCallback",
]