Shortcuts

Source code for catalyst.contrib.data.reader

from typing import Callable, List, Optional, Type
import functools

import numpy as np

from catalyst.utils.numpy import get_one_hot


class IReader:
    """Reader abstraction for all Readers.

    Applies a function to an element of your data,
    for example to a row from csv, or to an image, etc.

    All inherited classes have to implement ``__call__``.
    """

    def __init__(self, input_key: str, output_key: Optional[str]):
        """
        Args:
            input_key: input key to use from annotation dict
            output_key: output key to use to store the result; subclasses
                may pass ``None`` (this base class stores it unchanged)
        """
        self.input_key = input_key
        self.output_key = output_key

    def __call__(self, element):
        """
        Reads a row from your annotations dict and transforms it into data
        needed by your network, for example opens an image by path or
        reads a string and tokenizes it.

        Args:
            element: elem in your dataset

        Returns:
            Data object used for your neural network

        Raises:
            NotImplementedError: always; subclasses must override this method
        """
        # Report the concrete class so the error points at the reader that
        # is missing an implementation (the old text named a non-existent
        # ``BaseReader``).
        raise NotImplementedError(
            f"You cannot apply a transformation using `{type(self).__name__}`"
        )
class ScalarReader(IReader):
    """
    Numeric data reader abstraction.
    Reads a single float, int, str or other value from data.
    """

    def __init__(
        self,
        input_key: str,
        output_key: Optional[str] = None,
        dtype: Type = np.float32,
        default_value: Optional[float] = None,
        one_hot_classes: Optional[int] = None,
        smoothing: Optional[float] = None,
    ):
        """
        Args:
            input_key: input key to use from annotation dict
            output_key: output key to use to store the result,
                default: ``input_key``
            dtype: datatype (callable) used to convert the read value
            default_value: value used when ``input_key`` is absent
                from the element
            one_hot_classes: number of one-hot classes; if set, the value
                is passed to ``get_one_hot``
            smoothing: if specified, applies label smoothing to one-hot
                classes; must be in (0, 1) when ``one_hot_classes`` is set
        """
        super().__init__(input_key, output_key or input_key)
        self.dtype = dtype
        self.default_value = default_value
        self.one_hot_classes = one_hot_classes
        self.smoothing = smoothing
        # smoothing only matters for the one-hot branch of ``__call__``,
        # so it is validated only in that configuration.
        # NOTE(review): ``assert`` is skipped under ``python -O``; kept so
        # callers still observe ``AssertionError`` on invalid input.
        if self.one_hot_classes is not None and self.smoothing is not None:
            assert 0.0 < smoothing < 1.0, (
                "If smoothing is specified it must be in (0; 1), "
                + f"got {smoothing}"
            )

    def __call__(self, element):
        """
        Reads a single value from your annotations dict and casts it
        to ``dtype``.

        Args:
            element: elem in your dataset; must support ``.get``
                (e.g. a dict)

        Returns:
            dict: ``{output_key: value}`` where ``value`` is
            ``dtype(element[input_key])`` (or ``dtype(default_value)`` when
            the key is absent), passed through ``get_one_hot`` when
            ``one_hot_classes`` is set
        """
        scalar = self.dtype(element.get(self.input_key, self.default_value))
        if self.one_hot_classes is not None:
            scalar = get_one_hot(
                scalar, self.one_hot_classes, smoothing=self.smoothing
            )
        output = {self.output_key: scalar}
        return output
class LambdaReader(IReader):
    """
    Reader that applies a user-supplied callable to an element.

    Takes ``element[input_key]`` (or the whole element when ``input_key``
    is ``None``), runs it through ``lambda_fn`` and, when ``output_key``
    is set, wraps the result as ``{output_key: result}``.
    """

    def __init__(
        self,
        input_key: str,
        output_key: Optional[str] = None,
        lambda_fn: Optional[Callable] = None,
        **kwargs,
    ):
        """
        Args:
            input_key: key of the value to take from the element;
                ``None`` means the whole element is used
            output_key: key under which the result is stored; ``None``
                means the raw result is returned instead of a dict
            lambda_fn: function used to prepare your data (for example,
                convert chars/words/tokens to indices); a falsy value
                falls back to the identity function
            kwargs: extra keyword arguments bound to ``lambda_fn``
                via ``functools.partial``
        """
        super().__init__(input_key, output_key)

        def _identity(x):
            return x

        # Truthiness (not ``is None``) decides the fallback, as before.
        fn = lambda_fn if lambda_fn else _identity
        self.lambda_fn = functools.partial(fn, **kwargs)

    def __call__(self, element):
        """
        Applies ``lambda_fn`` to the selected part of ``element``.

        Args:
            element: elem in your dataset.

        Returns:
            The value produced by ``lambda_fn``, wrapped in
            ``{output_key: value}`` when ``output_key`` is not ``None``.
        """
        value = element if self.input_key is None else element[self.input_key]
        result = self.lambda_fn(value)
        if self.output_key is None:
            return result
        return {self.output_key: result}
class ReaderCompose(object):
    """Composes several readers into a single callable."""

    def __init__(self, transforms: List[IReader]):
        """
        Args:
            transforms: readers to apply, in order
        """
        self.transforms = transforms

    def __call__(self, element):
        """
        Runs every reader on ``element`` and merges their outputs.

        Args:
            element: elem in your dataset.

        Returns:
            dict: union of all reader outputs; on duplicate keys the
            reader that comes later in ``transforms`` wins
        """
        return functools.reduce(
            lambda merged, reader: {**merged, **reader(element)},
            self.transforms,
            {},
        )
# Public API of this module.
__all__ = ["IReader", "ScalarReader", "LambdaReader", "ReaderCompose"]