# Source code for the text classification dataset (the original module path was lost during extraction).

import logging
from typing import List, Mapping, Optional, Tuple

import torch
from torch.utils.data import Dataset
from transformers import AutoTokenizer

class TextClassificationDataset(Dataset):
    """Wrapper around Torch Dataset to perform text classification.

    Each item is a dict with right-padded token ids (``"features"``), an
    attention mask (``"attention_mask"``) and, when labels were given, the
    encoded class id (``"targets"``; ``-1`` for labels missing from
    ``label_dict``).
    """

    def __init__(
        self,
        texts: List[str],
        labels: Optional[List[str]] = None,
        label_dict: Optional[Mapping[str, int]] = None,
        max_seq_length: int = 512,
        model_name: str = "distilbert-base-uncased",
    ):
        """
        Args:
            texts (List[str]): a list with texts to classify or to train the
                classifier on
            labels (List[str]): a list with classification labels (optional)
            label_dict (dict): a dictionary mapping class names to class ids,
                to be passed to the validation data (optional); when omitted
                and ``labels`` is given, it is built from the sorted unique
                labels
            max_seq_length (int): maximal sequence length in tokens, texts
                will be truncated to this length
            model_name (str): transformer model name, needed to perform
                appropriate tokenization

        Raises:
            ValueError: if the tokenizer defines no padding token
        """
        self.texts = texts
        self.labels = labels
        self.label_dict = label_dict
        self.max_seq_length = max_seq_length

        if self.label_dict is None and labels is not None:
            # {'class1': 0, 'class2': 1, 'class3': 2, ...}
            # using this instead of `sklearn.preprocessing.LabelEncoder`
            # to easily handle unknown target values (they become -1)
            self.label_dict = self._build_label_dict(labels)

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # suppresses tokenizer warnings
        # NOTE(review): newer transformers releases may log from
        # "transformers.tokenization_utils_base" instead — confirm
        logging.getLogger("transformers.tokenization_utils").setLevel(
            logging.FATAL
        )

        # special tokens for transformers
        # in the simplest case a [CLS] token is added in the beginning
        # and [SEP] token is added in the end of a piece of text
        # [CLS] <indexes text tokens> [SEP] .. <[PAD]>
        # Reading the ids from the tokenizer's special-token attributes
        # (instead of looking up "[SEP]"/"[CLS]"/"[PAD]" in the vocab) also
        # works for models whose special tokens are spelled differently.
        self.sep_vid = self.tokenizer.sep_token_id
        self.cls_vid = self.tokenizer.cls_token_id
        self.pad_vid = self.tokenizer.pad_token_id
        if self.pad_vid is None:
            raise ValueError(
                f"Tokenizer for {model_name!r} defines no padding token, "
                "so texts cannot be padded to max_seq_length"
            )

    @staticmethod
    def _build_label_dict(labels: List[str]) -> Mapping[str, int]:
        """Map each unique label to its position in sorted order."""
        return {label: idx for idx, label in enumerate(sorted(set(labels)))}

    def _pad_and_mask(
        self, token_ids: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Right-pad ``token_ids`` to ``max_seq_length`` and build its mask.

        Args:
            token_ids (torch.Tensor): 1-D tensor of token ids

        Returns:
            the padded ids, and an int8 attention mask holding a 1 for every
            input token and a 0 for every padding position. Sequences that
            are already ``max_seq_length`` long (or longer) are not padded.
        """
        seq_length = token_ids.size(0)
        # clamp at 0 so an over-long sequence simply gets no padding
        pad_size = max(self.max_seq_length - seq_length, 0)
        pad_ids = torch.full((pad_size,), self.pad_vid, dtype=torch.long)
        features = torch.cat((token_ids, pad_ids))
        # attention scores are computed only over real tokens; padding
        # positions are zeroed out in the mask
        mask = torch.cat(
            (
                torch.ones(seq_length, dtype=torch.int8),
                torch.zeros(pad_size, dtype=torch.int8),
            )
        )
        return features, mask

    def __len__(self) -> int:
        """
        Returns:
            int: length of the dataset
        """
        return len(self.texts)

    def __getitem__(self, index: int) -> Mapping[str, torch.Tensor]:
        """Gets element of the dataset.

        Args:
            index (int): index of the element in the dataset

        Returns:
            Single element by index: a dict with ``"features"`` and
            ``"attention_mask"`` tensors of length ``max_seq_length`` and,
            if labels were given, a 0-dim int64 ``"targets"`` tensor
        """
        # encoding the text; truncation is requested explicitly so the
        # tokenizer does not rely on (and warn about) implicit truncation
        x_encoded = self.tokenizer.encode(
            self.texts[index],
            add_special_tokens=True,
            max_length=self.max_seq_length,
            truncation=True,
            return_tensors="pt",
        ).squeeze(0)

        x_tensor, mask = self._pad_and_mask(x_encoded)
        output_dict = {"features": x_tensor, "attention_mask": mask}

        # encoding target; labels unknown to label_dict map to -1
        if self.labels is not None:
            y = self.labels[index]
            output_dict["targets"] = torch.tensor(
                self.label_dict.get(y, -1), dtype=torch.long
            )
        return output_dict
# Public API of this module: the names exported by a star-import.
__all__ = ["TextClassificationDataset"]