Source code for catalyst.dl.utils.text

from typing import Dict, List, Union  # isort:skip
import string

import numpy as np

import torch

from catalyst.contrib.nn.modules import LamaPooling


def tokenize_text(
    text: str,
    tokenizer,  # HuggingFace tokenizer, ex: BertTokenizer
    max_length: int,
    strip: bool = True,
    lowercase: bool = True,
    remove_punctuation: bool = True,
) -> Dict[str, np.ndarray]:
    """Tokenizes given text and pads the result to ``max_length``.

    Args:
        text (str): text to tokenize
        tokenizer: Tokenizer instance from HuggingFace; must provide
            ``encode_plus`` returning ``"input_ids"`` and
            ``"token_type_ids"``
        max_length (int): maximum length of tokens
        strip (bool): if true strips text before tokenizing
        lowercase (bool): if true makes text lowercase before tokenizing
        remove_punctuation (bool): if true removes every character of
            ``string.punctuation`` from text before tokenizing

    Returns:
        Dict[str, np.ndarray]: ``np.int64`` arrays under the keys
        ``"input_ids"``, ``"token_type_ids"`` and ``"attention_mask"``.
        Each array is right-padded with zeros up to ``max_length``;
        ``"attention_mask"`` holds 1 for tokens returned by the tokenizer
        and 0 for padding. If the tokenizer returns more than
        ``max_length`` ids, no padding is added and the arrays keep the
        tokenizer's length.
    """
    if strip:
        text = text.strip()
    if lowercase:
        text = text.lower()
    if remove_punctuation:
        # ``str`` is immutable and ``str.replace`` matches a whole substring,
        # so drop each punctuation character individually and rebind ``text``.
        text = text.translate(str.maketrans("", "", string.punctuation))

    inputs = tokenizer.encode_plus(
        text, "", add_special_tokens=True, max_length=max_length
    )
    input_ids = inputs["input_ids"]
    token_type_ids = inputs["token_type_ids"]
    attention_mask = [1] * len(input_ids)

    # A non-positive padding length yields an empty list, i.e. no padding.
    padding = [0] * (max_length - len(input_ids))
    input_ids = input_ids + padding
    attention_mask = attention_mask + padding
    token_type_ids = token_type_ids + padding

    return {
        "input_ids": np.array(input_ids, dtype=np.int64),
        "token_type_ids": np.array(token_type_ids, dtype=np.int64),
        "attention_mask": np.array(attention_mask, dtype=np.int64),
    }
def process_bert_output(
    bert_output,
    hidden_size: int,
    output_hidden_states: bool = False,
    pooling_groups: List[str] = None,
    mask: torch.Tensor = None,
    level: Union[int, str] = None,
):
    """Selects and optionally pools parts of a BERT model output.

    Args:
        bert_output: indexable model output; item ``0`` is used for the
            ``"pooling"`` features, item ``1`` for the ``"class"`` features
            and item ``2`` as a sequence of per-layer features
            (presumably HuggingFace's sequence output, pooled output and
            hidden states -- confirm against the caller)
        hidden_size (int): passed to ``LamaPooling`` as ``in_features``
        output_hidden_states (bool): if true and ``level`` is not set,
            every item of ``bert_output[2]`` is added to the result under
            its integer index
        pooling_groups (List[str]): groups passed to ``LamaPooling``;
            if ``None`` no pooling is applied
        mask (torch.Tensor): mask forwarded to the pooling call
        level (Union[int, str]): ``"pooling"`` or ``"class"`` to return a
            single named output, or an integer to return the matching
            item of ``bert_output[2]``; if ``None`` a dict is returned

    Returns:
        a single feature object when ``level`` is given, otherwise a dict
        with keys ``"pooling"`` and ``"class"`` (plus integer keys when
        ``output_hidden_states`` is true). The ``"class"`` output is never
        pooled.
    """
    # @TODO: make this functional
    pooling = None
    if pooling_groups is not None:
        # A fresh pooling module is built on every call.
        pooling = LamaPooling(groups=pooling_groups, in_features=hidden_size)

    def _pool(features):
        if pooling is None:
            return features
        return pooling(features, mask=mask)

    if isinstance(level, str):
        assert level in ("pooling", "class")
        if level == "class":
            return bert_output[1]
        return _pool(bert_output[0])

    if isinstance(level, int):
        return _pool(bert_output[2][level])

    result = {
        "pooling": _pool(bert_output[0]),
        "class": bert_output[1],
    }
    if output_hidden_states:
        result.update(
            (index, _pool(hidden))
            for index, hidden in enumerate(bert_output[2])
        )
    return result