Shortcuts

Source code for catalyst.data.dataset.metric_learning

from typing import Dict, List
from abc import ABC, abstractmethod

import torch
from torch.utils.data import Dataset


class MetricLearningTrainDataset(Dataset, ABC):
    """
    Base class for datasets adapted for
    metric learning train stage.

    The class is abstract: a subclass has to implement ``get_labels``
    before it can be instantiated.
    """

    @abstractmethod
    def get_labels(self) -> List[int]:
        """
        Dataset for metric learning must provide
        label of each sample for forming positive
        and negative pairs during the training
        based on these labels.

        Returns:
            labels of the samples  # noqa: DAR202

        Raises:
            NotImplementedError: You should implement it  # noqa: DAR402
        """
        raise NotImplementedError()
class QueryGalleryDataset(Dataset, ABC):
    """
    QueryGalleryDataset for CMCScoreCallback

    The class is abstract: a subclass has to implement ``__getitem__``
    and the ``query_size`` / ``gallery_size`` properties before it can
    be instantiated.
    """

    @abstractmethod
    def __getitem__(self, item) -> Dict[str, torch.Tensor]:
        """
        Dataset for query/gallery split should
        return dict with `feature`, `targets` and
        `is_query` key. Value by key `is_query` should
        be boolean and indicate whether current object
        is in query or in gallery.

        Args:
            item: key of the requested sample

        Returns:
            dict describing the requested sample  # noqa: DAR202

        Raises:
            NotImplementedError: You should implement it  # noqa: DAR402
        """
        raise NotImplementedError()

    @property
    @abstractmethod
    def query_size(self) -> int:
        """
        Query/Gallery dataset should have property
        query size.

        Returns:
            query size  # noqa: DAR202

        Raises:
            NotImplementedError: You should implement it  # noqa: DAR402
        """
        raise NotImplementedError()

    @property
    @abstractmethod
    def gallery_size(self) -> int:
        """
        Query/Gallery dataset should have property
        gallery size.

        Returns:
            gallery size  # noqa: DAR202

        Raises:
            NotImplementedError: You should implement it  # noqa: DAR402
        """
        raise NotImplementedError()
# Public API of this module: the names exported by a star-import.
__all__ = ["MetricLearningTrainDataset", "QueryGalleryDataset"]