from typing import Iterable, Union

from import tqdm

import torch
from import Dataset
import transformers
from transformers import AutoTokenizer, PreTrainedTokenizer

class LanguageModelingDataset(Dataset):
    """Dataset for the (masked) language modeling task.

    Texts are encoded with a HuggingFace tokenizer, either eagerly in
    ``__init__`` or lazily in ``__getitem__``. Eagerly encoded sequences
    can be sorted by length for efficient padding.
    """

    def __init__(
        self,
        texts: Iterable[str],
        tokenizer: Union[str, "PreTrainedTokenizer"],
        max_seq_length: Union[int, None] = None,
        sort: bool = True,
        lazy: bool = False,
    ):
        """
        Args:
            texts: Iterable object with text
            tokenizer (str or tokenizer): pre trained
                huggingface tokenizer or model name
            max_seq_length: max sequence length to tokenize
            sort: If True then sort all sequences by length
                for efficient padding
            lazy: If True then tokenize and encode sequence
                in __getitem__ method
                else will tokenize in __init__
                also if set to true sorting is unavailable

        Raises:
            ValueError: if both ``sort`` and ``lazy`` are True
            TypeError: if ``tokenizer`` is neither a model name
                nor a ``PreTrainedTokenizer`` instance
        """
        if sort and lazy:
            # ValueError is a subclass of Exception, so callers that catch
            # the previously raised bare Exception keep working.
            raise ValueError(
                "lazy is set to True so we can't sort"
                " sequences by length.\n"
                "You should set sort=False and lazy=True"
                " if you want to encode text in __get_item__ function"
            )

        if isinstance(tokenizer, str):
            self.tokenizer = AutoTokenizer.from_pretrained(tokenizer)
        elif isinstance(tokenizer, PreTrainedTokenizer):
            # NOTE(review): tokenizer classes that do not subclass
            # PreTrainedTokenizer are rejected here - confirm this is
            # intended for "fast" tokenizers.
            self.tokenizer = tokenizer
        else:
            raise TypeError(
                "tokenizer argument should be a model name"
                + " or huggingface PreTrainedTokenizer"
            )

        self.max_seq_length = max_seq_length
        self.lazy = lazy

        if lazy:
            # __getitem__ indexes the texts by position, so a one-shot
            # iterable (e.g. a generator) has to be materialized first.
            # Inputs that already support indexing are stored as-is.
            indexable = hasattr(texts, "__getitem__") and hasattr(
                texts, "__len__"
            )
            self.texts = texts if indexable else list(texts)
            self.length = len(self.texts)
        else:
            pbar = tqdm(texts, desc="tokenizing texts")
            self.encoded = [self._encode(text) for text in pbar]
            if sort:
                self.encoded.sort(key=len)
            # Count the encoded sequences rather than ``texts``: the input
            # is only required to be iterable and may not support len().
            self.length = len(self.encoded)

        self._getitem_fn = (
            self._getitem_lazy if lazy else self._getitem_encoded
        )

    def _encode(self, text: str):
        """Encode one text with the tokenizer, passing ``max_seq_length``."""
        return self.tokenizer.encode(text, max_length=self.max_seq_length)

    def __len__(self):
        """Return the number of sequences in the dataset."""
        return self.length

    def _getitem_encoded(self, idx) -> torch.Tensor:
        """Return the pre-encoded sequence at ``idx`` as a tensor."""
        return torch.tensor(self.encoded[idx])

    def _getitem_lazy(self, idx) -> torch.Tensor:
        """Encode the text at ``idx`` on demand and return it as a tensor."""
        return torch.tensor(self._encode(self.texts[idx]))

    def __getitem__(self, idx):
        """Return tokenized and encoded sequence"""
        return self._getitem_fn(idx)
# Names exported by ``from <module> import *``.
__all__ = [
    "LanguageModelingDataset",
]