# Source code for catalyst.contrib.dl.callbacks.perplexity_metric
from torch import nn
from catalyst.core.callbacks import MetricCallback
class PerplexityMetricCallback(MetricCallback):
    """Perplexity metric callback.

    Perplexity is a very popular metric in NLP, especially in the
    Language Modeling task.  It is defined as ``exp(cross_entropy)``:
    ``nn.CrossEntropyLoss`` computes the *natural*-log cross-entropy,
    so the exponentiation base must be ``e``, not 2.
    """

    def __init__(
        self,
        input_key: str = "targets",
        output_key: str = "logits",
        prefix: str = "perplexity",
        ignore_index: int = None,
    ):
        """
        Args:
            input_key (str): input key to use for perplexity calculation,
                target tokens
            output_key (str): output key to use for perplexity calculation,
                logits of the predicted tokens
            prefix (str): name under which the metric is logged
            ignore_index (int): index to ignore, usually pad_index
        """
        # Use an explicit `is not None` check: the previous
        # `ignore_index or default` form silently replaced an explicit
        # `ignore_index=0` with CrossEntropyLoss's default (-100),
        # because 0 is falsy.
        if ignore_index is None:
            ignore_index = nn.CrossEntropyLoss().ignore_index
        self.ignore_index = ignore_index
        self.cross_entropy_loss = nn.CrossEntropyLoss(
            ignore_index=self.ignore_index
        )
        super().__init__(
            metric_fn=self.metric_fn,
            input_key=input_key,
            output_key=output_key,
            prefix=prefix,
        )

    def metric_fn(self, outputs, targets):
        """Calculate perplexity as ``exp`` of the cross-entropy.

        Args:
            outputs (torch.Tensor): predicted token logits
            targets (torch.Tensor): target token indices

        Returns:
            float: perplexity value
        """
        cross_entropy = (
            self.cross_entropy_loss(outputs, targets).detach().cpu()
        )
        # CrossEntropyLoss uses the natural log, so perplexity = e^CE.
        # The original `2 ** CE` understated the metric (it mixed bases).
        return cross_entropy.exp().item()