Source code for catalyst.data.nlp.dataset.language_modeling
from typing import Iterable, Union

import torch
from torch.utils.data import Dataset
from tqdm.auto import tqdm
from transformers import AutoTokenizer, PreTrainedTokenizer


class LanguageModelingDataset(Dataset):
    """
    Dataset for (masked) language model task.
    Can sort sequences by length for efficient padding.
    """

    def __init__(
        self,
        texts: Iterable[str],
        tokenizer: Union[str, PreTrainedTokenizer],
        max_seq_length: int = None,
        sort: bool = True,
        lazy: bool = False,
    ):
        """
        Args:
            texts (Iterable): Iterable object with text
            tokenizer (str or tokenizer): pre trained
                huggingface tokenizer or model name
            max_seq_length (int): max sequence length to tokenize
            sort (bool): If True then sort all sequences by length
                for efficient padding
            lazy (bool): If True then tokenize and encode sequence
                in __getitem__ method
                else will tokenize in __init__ also
                if set to true sorting is unavialible
        """
        if sort and lazy:
            raise ValueError(
                "lazy is set to True so sequences cannot be"
                " sorted by length.\n"
                "Set sort=False and lazy=True"
                " if you want to encode texts in __getitem__"
            )
        if isinstance(tokenizer, str):
            self.tokenizer = AutoTokenizer.from_pretrained(tokenizer)
        elif isinstance(tokenizer, PreTrainedTokenizer):
            self.tokenizer = tokenizer
        else:
            raise TypeError(
                "tokenizer argument should be a model name"
                " or a HuggingFace PreTrainedTokenizer"
            )
        self.max_seq_length = max_seq_length
        self.lazy = lazy
        if lazy:
            self.texts = texts
            self.length = len(texts)
        else:
            pbar = tqdm(texts, desc="tokenizing texts")
            self.encoded = [
                # truncation=True is required by recent versions of
                # transformers to cut sequences to max_seq_length
                self.tokenizer.encode(
                    text, max_length=max_seq_length, truncation=True
                )
                for text in pbar
            ]
            if sort:
                # sort by token count so that batches hold sequences
                # of similar length and need less padding
                self.encoded.sort(key=len)
            self.length = len(self.encoded)
        self._getitem_fn = (
            self._getitem_lazy if lazy else self._getitem_encoded
        )

    def __len__(self):
        """Return the number of texts in the dataset."""
        return self.length

    def _getitem_encoded(self, idx) -> torch.Tensor:
        """Return an already encoded sequence as a tensor."""
        return torch.tensor(self.encoded[idx])

    def _getitem_lazy(self, idx) -> torch.Tensor:
        """Tokenize and encode a sequence on the fly."""
        encoded = self.tokenizer.encode(
            self.texts[idx],
            max_length=self.max_seq_length,
            truncation=True,
        )
        return torch.tensor(encoded)

    def __getitem__(self, idx):
        """Return a tokenized and encoded sequence as a tensor."""
        return self._getitem_fn(idx)

__all__ = ["LanguageModelingDataset"]