# Source code for catalyst.callbacks.metrics.perplexity
from functools import partial
from torch import nn
from catalyst.callbacks.metric import BatchMetricCallback
def _perplexity_metric(outputs, targets, criterion):
cross_entropy = criterion(outputs, targets).detach()
perplexity = 2 ** cross_entropy
return perplexity
class PerplexityCallback(BatchMetricCallback):
    """
    Perplexity is a very popular metric in NLP,
    especially in the Language Modeling task.
    It is the exponentiated cross-entropy.
    """

    def __init__(
        self,
        input_key: str = "targets",
        output_key: str = "logits",
        prefix: str = "perplexity",
        ignore_index: int = None,
        **kwargs,
    ):
        """
        Args:
            input_key: input key to use for perplexity calculation,
                target tokens
            output_key: output key to use for perplexity calculation,
                logits of the predicted tokens
            prefix: key to store the metric under in the logs
            ignore_index: index to ignore, usually pad_index
            **kwargs: extra arguments forwarded to ``BatchMetricCallback``
        """
        # Explicit `is None` check: `ignore_index or default` would silently
        # replace the perfectly valid pad index 0 with the default (-100).
        if ignore_index is None:
            ignore_index = nn.CrossEntropyLoss().ignore_index
        self.ignore_index = ignore_index
        self.cross_entropy_loss = nn.CrossEntropyLoss(
            ignore_index=self.ignore_index
        )
        metric_fn = partial(
            _perplexity_metric, criterion=self.cross_entropy_loss
        )
        super().__init__(
            metric_fn=metric_fn,
            input_key=input_key,
            output_key=output_key,
            prefix=prefix,
            **kwargs,
        )
# backward compatibility: old callback name kept as an alias
PerplexityMetricCallback = PerplexityCallback
__all__ = ["PerplexityCallback", "PerplexityMetricCallback"]