from typing import Callable, List, Type  # isort:skip
import functools
import numpy as np
from catalyst.utils import get_one_hot, imread, mimread
class ReaderSpec:
    """Reader abstraction for all Readers.

    Applies a function to an element of your data,
    for example to a row from csv, or to an image, etc.
    All inherited classes have to implement `__call__`.
    """

    def __init__(self, input_key: str, output_key: str):
        """
        Args:
            input_key (str): input key to use from annotation dict
            output_key (str): output key to use to store the result
        """
        self.input_key = input_key
        self.output_key = output_key

    def __call__(self, row):
        """Reads a row from your annotations dict and
        transfers it to data needed by your network,
        for example opens an image by path, or reads a string
        and tokenizes it.

        Args:
            row: elem in your dataset.

        Returns:
            Data object used for your neural network

        Raises:
            NotImplementedError: always; subclasses must override
                this method.
        """
        # The base class only stores the keys; the message names this class
        # so the error points at the actual abstraction being misused.
        raise NotImplementedError(
            "You cannot apply a transformation using `ReaderSpec`"
        )
class ImageReader(ReaderSpec):
    """
    Image reader abstraction. Reads images from a `csv` dataset.
    """

    def __init__(
        self,
        input_key: str,
        output_key: str,
        datapath: str = None,
        grayscale: bool = False
    ):
        """
        Args:
            input_key (str): key to use from annotation dict
            output_key (str): key to use to store the result
            datapath (str): path to images dataset
                (so you can use relative paths in annotations)
            grayscale (bool): flag if you need to work only
                with grayscale images
        """
        super().__init__(input_key, output_key)
        self.datapath = datapath
        self.grayscale = grayscale

    def __call__(self, row):
        """Reads a row from your annotations dict with filename and
        transfers it to an image.

        Args:
            row: elem in your dataset; ``row[input_key]`` is converted
                to ``str`` and used as the image name.

        Returns:
            dict: ``{output_key: image}``, where ``image`` is whatever
            ``imread`` returns for that name.
        """
        image_name = str(row[self.input_key])
        # `datapath` is forwarded as `rootpath`, `grayscale` as-is.
        img = imread(
            image_name, rootpath=self.datapath, grayscale=self.grayscale
        )
        result = {self.output_key: img}
        return result
class MaskReader(ReaderSpec):
    """
    Mask reader abstraction. Reads masks from a `csv` dataset.
    """

    def __init__(self, input_key: str, output_key: str, datapath: str = None):
        """
        Args:
            input_key (str): annotation-dict key holding the mask name
            output_key (str): key under which the loaded mask is returned
            datapath (str): root path of the dataset, forwarded to
                ``mimread`` as ``rootpath``
        """
        super().__init__(input_key, output_key)
        self.datapath = datapath

    def __call__(self, row):
        """Loads the mask referenced by ``row[input_key]``.

        Args:
            row: elem in your dataset.

        Returns:
            dict: ``{output_key: mask}``, where ``mask`` is the result of
            ``mimread`` called with ``clip_range=(0, 1)``.
        """
        path = str(row[self.input_key])
        loaded = mimread(path, rootpath=self.datapath, clip_range=(0, 1))
        return {self.output_key: loaded}
class ScalarReader(ReaderSpec):
    """
    Numeric data reader abstraction.
    Reads a single float, int, str or other from data
    """

    def __init__(
        self,
        input_key: str,
        output_key: str,
        dtype: Type = np.float32,
        default_value: float = None,
        one_hot_classes: int = None,
        smoothing: float = None,
    ):
        """
        Args:
            input_key (str): input key to use from annotation dict
            output_key (str): output key to use to store the result
            dtype (type): datatype of scalar values to use
            default_value: value used when ``input_key`` is missing
                from the row
            one_hot_classes (int): number of one-hot classes
            smoothing (float, optional): if specified applies label smoothing
                to one_hot classes

        Raises:
            AssertionError: if both ``one_hot_classes`` and ``smoothing``
                are given and ``smoothing`` is not strictly inside (0, 1).
        """
        super().__init__(input_key, output_key)
        self.dtype = dtype
        self.default_value = default_value
        self.one_hot_classes = one_hot_classes
        self.smoothing = smoothing
        # Smoothing is only validated when one-hot encoding is requested.
        if self.one_hot_classes is not None and self.smoothing is not None:
            assert 0.0 < smoothing < 1.0, (
                "If smoothing is specified it must be in (0; 1), "
                f"got {smoothing}"
            )

    def __call__(self, row):
        """Reads a row from your annotations dict and
        transfers it to a single value.

        Args:
            row: elem in your dataset; must provide ``.get``.

        Returns:
            dict: ``{output_key: value}``, where ``value`` is the row entry
            cast with ``dtype``, or the output of ``get_one_hot`` for it
            when ``one_hot_classes`` is set.
        """
        scalar = self.dtype(row.get(self.input_key, self.default_value))
        if self.one_hot_classes is not None:
            scalar = get_one_hot(
                scalar, self.one_hot_classes, smoothing=self.smoothing
            )
        result = {self.output_key: scalar}
        return result
class LambdaReader(ReaderSpec):
    """
    Reader abstraction with a lambda encoder.
    Can read an elem from dataset and apply `encode_fn` function to it
    """

    def __init__(
        self,
        input_key: str,
        output_key: str,
        encode_fn: Callable = lambda x: x,
        **kwargs
    ):
        """
        Args:
            input_key (str): input key to use from annotation dict
            output_key (str): output key to use to store the result
            encode_fn (callable): encode function to use to prepare your data
                (for example convert chars/words/tokens to indices, etc);
                defaults to the identity function
            kwargs: kwargs for encode function
        """
        super().__init__(input_key, output_key)
        # Bind the extra kwargs once so `__call__` only passes the element.
        self.encode_fn = functools.partial(encode_fn, **kwargs)

    def __call__(self, row):
        """Reads a row from your annotations dict
        and applies `encode_fn` function.

        Args:
            row: elem in your dataset.

        Returns:
            dict: ``{output_key: encoded}``, where ``encoded`` is the value
            of ``row[input_key]`` after applying `encode_fn`.
        """
        elem = row[self.input_key]
        elem = self.encode_fn(elem)
        result = {self.output_key: elem}
        return result
class ReaderCompose(object):
    """
    Abstraction to compose several readers into one open function.
    """

    def __init__(self, readers: List[ReaderSpec], mixins: List[Callable] = None):
        """
        Args:
            readers (List[ReaderSpec]): list of reader to compose
            mixins (List[Callable]): list of mixins to use; each one is
                called with the merged readers' output. ``None`` (default)
                means no mixins.
        """
        self.readers = readers
        self.mixins = mixins or []

    def __call__(self, row):
        """Reads a row from your annotations dict
        and applies all readers and mixins.

        Args:
            row: elem in your dataset.

        Returns:
            dict: outputs of all readers merged in order (later keys
            override earlier ones), then updated with the output of each
            mixin applied to the merged dict so far.
        """
        result = {}
        # Readers all see the original row.
        for fn in self.readers:
            result = {**result, **fn(row)}
        # Mixins see the accumulated result; a fresh dict is built each step
        # so only what a mixin returns is merged in.
        for fn in self.mixins:
            result = {**result, **fn(result)}
        return result
# Public API of this module; every reader class defined here is exported.
__all__ = [
    "ReaderSpec", "ImageReader", "MaskReader", "ScalarReader",
    "LambdaReader", "ReaderCompose"
]