Source code for catalyst.data.reader

from typing import Callable, List, Type, Tuple, Union  # isort:skip
import functools

import numpy as np

from catalyst.utils import get_one_hot, imread, mimread


class ReaderSpec:
    """Reader abstraction for all Readers.

    Applies a function to an element of your data,
    for example to a row from csv, or to an image, etc.

    All inherited classes have to implement `__call__`.
    """

    def __init__(self, input_key: str, output_key: str):
        """
        Args:
            input_key (str): input key to use from annotation dict
            output_key (str): output key to use to store the result
        """
        self.input_key = input_key
        self.output_key = output_key

    def __call__(self, element):
        """Reads a row from your annotations dict and transfers it to data
        needed by your network, for example opens an image by path,
        or reads a string and tokenizes it.

        Args:
            element: elem in your dataset.

        Returns:
            Data object used for your neural network

        Raises:
            NotImplementedError: always; this base class performs no
                reading itself, so subclasses must override `__call__`
        """
        # The message names this class (the original referred to a
        # non-existent `BaseReader`), so the traceback points at the
        # abstraction that was called by mistake.
        raise NotImplementedError(
            "You cannot apply a transformation using `ReaderSpec`"
        )
class ImageReader(ReaderSpec):
    """
    Image reader abstraction. Loads the image referenced by one annotation
    row (for example a row of a `csv` dataset).
    """

    def __init__(
        self,
        input_key: str,
        output_key: str,
        rootpath: str = None,
        grayscale: bool = False
    ):
        """
        Args:
            input_key (str): key of the annotation dict that holds
                the image file name
            output_key (str): key under which the loaded image is returned
            rootpath (str): path to images dataset root directory,
                forwarded to ``imread``
                (so you can use relative paths in annotations)
            grayscale (bool): flag forwarded to ``imread``;
                set it if you need to work only with grayscale images
        """
        super().__init__(input_key, output_key)
        self.grayscale = grayscale
        self.rootpath = rootpath

    def __call__(self, element):
        """Takes the file name stored in an annotation row
        and loads the corresponding image.

        Args:
            element: elem in your dataset; indexed with ``input_key``.

        Returns:
            dict: ``{output_key: image}``, where ``image`` is whatever
            ``imread`` returns (presumably an ``np.ndarray``)
        """
        # The annotation value is coerced to ``str`` before being passed on.
        path = str(element[self.input_key])
        loaded = imread(
            path, rootpath=self.rootpath, grayscale=self.grayscale
        )
        return {self.output_key: loaded}
class MaskReader(ReaderSpec):
    """
    Mask reader abstraction. Loads the mask referenced by one annotation
    row (for example a row of a `csv` dataset).
    """

    def __init__(
        self,
        input_key: str,
        output_key: str,
        rootpath: str = None,
        clip_range: Tuple[Union[int, float], Union[int, float]] = (0, 1)
    ):
        """
        Args:
            input_key (str): key of the annotation dict that holds
                the mask file name
            output_key (str): key under which the loaded mask is returned
            rootpath (str): path to images dataset root directory,
                forwarded to ``mimread``
                (so you can use relative paths in annotations)
            clip_range (Tuple[int, int]): lower and upper interval edges,
                forwarded to ``mimread``; image values outside
                the interval are clipped to the interval edges
        """
        super().__init__(input_key, output_key)
        # Stored under the shorter attribute name ``clip``.
        self.clip = clip_range
        self.rootpath = rootpath

    def __call__(self, element):
        """Takes the file name stored in an annotation row
        and loads the corresponding mask.

        Args:
            element: elem in your dataset; indexed with ``input_key``.

        Returns:
            dict: ``{output_key: mask}``, where ``mask`` is whatever
            ``mimread`` returns (presumably an ``np.ndarray``)
        """
        # The annotation value is coerced to ``str`` before being passed on.
        path = str(element[self.input_key])
        loaded = mimread(path, rootpath=self.rootpath, clip_range=self.clip)
        return {self.output_key: loaded}
class ScalarReader(ReaderSpec):
    """
    Numeric data reader abstraction.
    Reads a single float, int, str or other value from data.
    """

    def __init__(
        self,
        input_key: str,
        output_key: str,
        dtype: Type = np.float32,
        default_value: float = None,
        one_hot_classes: int = None,
        smoothing: float = None,
    ):
        """
        Args:
            input_key (str): input key to use from annotation dict
            output_key (str): output key to use to store the result
            dtype (type): callable used to convert the raw value
            default_value: value used when ``input_key`` is missing
                from the element
            one_hot_classes (int): number of one-hot classes; when given,
                the converted value is passed through ``get_one_hot``
            smoothing (float, optional): if specified applies label
                smoothing to one_hot classes; must lie in (0; 1)
        """
        super().__init__(input_key, output_key)
        self.dtype = dtype
        self.default_value = default_value
        self.one_hot_classes = one_hot_classes
        self.smoothing = smoothing

        # Smoothing is only validated when one-hot encoding is requested.
        wants_smoothed_one_hot = (
            one_hot_classes is not None and smoothing is not None
        )
        if wants_smoothed_one_hot:
            assert 0.0 < smoothing < 1.0, (
                "If smoothing is specified it must be in (0; 1), "
                f"got {smoothing}"
            )

    def __call__(self, element):
        """Reads a row from your annotations dict
        and converts it to a single value.

        Args:
            element: elem in your dataset; must provide ``.get``.

        Returns:
            dict: ``{output_key: value}`` where ``value`` is the result of
            ``dtype(...)``, or of ``get_one_hot`` when ``one_hot_classes``
            is set
        """
        raw = element.get(self.input_key, self.default_value)
        value = self.dtype(raw)
        if self.one_hot_classes is None:
            return {self.output_key: value}

        encoded = get_one_hot(
            value, self.one_hot_classes, smoothing=self.smoothing
        )
        return {self.output_key: encoded}
class LambdaReader(ReaderSpec):
    """
    Reader abstraction with a lambda encoder.
    Can read an elem from dataset and apply `lambda_fn` function to it.
    """

    def __init__(
        self,
        input_key: str,
        output_key: str,
        lambda_fn: Callable = lambda x: x,
        **kwargs
    ):
        """
        Args:
            input_key (str): input key to use from annotation dict
            output_key (str): output key to use to store the result
            lambda_fn (callable): function applied to the selected data
                (for example convert chars/words/tokens to indices, etc);
                defaults to the identity function
            kwargs: keyword arguments bound to ``lambda_fn``
        """
        super().__init__(input_key, output_key)
        # Bind the extra keyword arguments once, at construction time.
        self.lambda_fn = functools.partial(lambda_fn, **kwargs)

    def __call__(self, element):
        """Reads a row from your annotations dict
        and applies `lambda_fn` function to it.

        Args:
            element: elem in your dataset.

        Returns:
            Value after applying `lambda_fn` function; wrapped into
            ``{output_key: value}`` unless ``output_key`` is ``None``
        """
        # With ``input_key`` of ``None`` the whole element is passed through.
        if self.input_key is None:
            source = element
        else:
            source = element[self.input_key]

        result = self.lambda_fn(source)
        if self.output_key is None:
            return result
        return {self.output_key: result}
class ReaderCompose(object):
    """
    Abstraction to compose several readers into one open function.
    """

    def __init__(
        self,
        readers: List["ReaderSpec"],
        mixins: List[Callable] = None
    ):
        """
        Args:
            readers (List[ReaderSpec]): list of readers to compose; each is
                called with the dataset element and must return a dict
            mixins (List[Callable]): list of mixins to use; each is called
                with the dict accumulated so far and must return a dict.
                ``None`` (the default) means no mixins
        """
        self.readers = readers
        # Normalise a missing (or empty) argument to a fresh empty list.
        self.mixins = mixins or []

    def __call__(self, element):
        """Reads a row from your annotations dict
        and applies all readers and mixins.

        Args:
            element: elem in your dataset.

        Returns:
            dict: merged outputs of all readers and mixins; on key
            collisions the later reader/mixin wins
        """
        result = {}
        # Readers all see the original element.
        for fn in self.readers:
            result = {**result, **fn(element)}
        # Mixins see the merged result built so far, not the raw element.
        for fn in self.mixins:
            result = {**result, **fn(result)}
        return result
# Public API of the module; every reader class defined above is exported,
# including `MaskReader`.
__all__ = [
    "ReaderSpec", "ImageReader", "MaskReader", "ScalarReader",
    "LambdaReader", "ReaderCompose"
]