from typing import Any, Dict, List, Optional, Union # isort:skip
import logging
import torch
from catalyst import utils
from catalyst.core import _State, Callback, CallbackOrder
# Module-level logger; used below to warn on loss-key / loss-type mismatches.
logger = logging.getLogger(__name__)
def _add_loss_to_state(
    loss_key: Optional[str],
    state: _State,
    loss: torch.Tensor
):
    # Register ``loss`` on ``state.loss``.
    #
    # With a key: ``state.loss`` is (or becomes) a dict of named losses.
    # Without a key: the first loss is stored as-is; subsequent calls
    # promote the storage to a list and append to it.
    if loss_key is not None:
        if state.loss is None:
            state.loss = {loss_key: loss}
        else:
            assert isinstance(state.loss, dict)
            state.loss[loss_key] = loss
        return

    if state.loss is None:
        state.loss = loss
    elif isinstance(state.loss, list):
        state.loss.append(loss)
    else:
        # second unkeyed loss: turn the single stored value into a list
        state.loss = [state.loss, loss]
class CriterionCallback(Callback):
    """
    Callback that measures loss with the specified criterion
    and registers the result in the state.
    """
    def __init__(
        self,
        input_key: Union[str, List[str], Dict[str, str]] = "targets",
        output_key: Union[str, List[str], Dict[str, str]] = "logits",
        prefix: str = "loss",
        criterion_key: str = None,
        multiplier: float = 1.0
    ):
        """
        Args:
            input_key (Union[str, List[str], Dict[str, str]]): key/list/dict
                of keys that takes values from the input dictionary
                If None, the whole input will be passed to the criterion.
            output_key (Union[str, List[str], Dict[str, str]]): key/list/dict
                of keys that takes values from the input dictionary
                If None, the whole output will be passed to the criterion.
            prefix (str): prefix for metrics and output key for loss
                in ``state.loss`` dictionary
            criterion_key (str): A key to take a criterion in case
                there are several of them and they are in a dictionary format.
            multiplier (float): scale factor for the output loss.
        """
        super().__init__(CallbackOrder.Criterion)
        self.input_key = input_key
        self.output_key = output_key
        self.prefix = prefix
        self.criterion_key = criterion_key
        self.multiplier = multiplier

        # accessor functions resolving str / list / dict key specs
        self._get_input = utils.get_dictkey_auto_fn(self.input_key)
        self._get_output = utils.get_dictkey_auto_fn(self.output_key)

        kv_types = (dict, tuple, list, type(None))
        # @TODO: fix to only KV usage
        if hasattr(self, "_compute_loss"):
            # a descendant already provides its own implementation
            pass
        else:
            value_mode = isinstance(self.input_key, str) \
                and isinstance(self.output_key, str)
            kv_mode = isinstance(self.input_key, kv_types) \
                and isinstance(self.output_key, kv_types)
            if value_mode:
                self._compute_loss = self._compute_loss_value
            elif kv_mode:
                self._compute_loss = self._compute_loss_key_value
            else:
                raise NotImplementedError()

    def _compute_loss_value(self, state: _State, criterion):
        # positional call: criterion(output, target)
        model_output = self._get_output(state.output, self.output_key)
        model_target = self._get_input(state.input, self.input_key)
        return criterion(model_output, model_target)

    def _compute_loss_key_value(self, state: _State, criterion):
        # keyword call: selected outputs and inputs merged into kwargs
        model_output = self._get_output(state.output, self.output_key)
        model_target = self._get_input(state.input, self.input_key)
        return criterion(**model_output, **model_target)

    def on_stage_start(self, state: _State):
        """
        Checks that the current stage has correct criterion
        """
        assert state.criterion is not None

    def on_batch_end(self, state: _State):
        """
        Computes the loss and add it to the metrics
        """
        criterion = state.get_key(
            key="criterion", inner_key=self.criterion_key
        )
        loss = self._compute_loss(state, criterion) * self.multiplier
        state.metric_manager.add_batch_value(
            metrics_dict={self.prefix: loss.item()}
        )
        _add_loss_to_state(self.prefix, state, loss)
class CriterionOutputOnlyCallback(CriterionCallback):
    """
    Callback that measures loss with the specified criterion,
    based on model output only.

    @TODO: merge logic with CriterionCallback.
    """
    def __init__(
        self,
        output_key: Union[Dict[str, str], List[str]],
        **kwargs
    ):
        """
        Args:
            output_key (Union[List[str]], Dict[str, str]): dict or list of keys
                that takes values from the output dictionary
                If None, the whole output will be passed to the criterion.
            **kwargs: CriterionCallback init parameters
        """
        # no input keys: the criterion consumes model output only
        super().__init__(input_key=None, output_key=output_key, **kwargs)

    def _compute_loss_value(self, state: _State, criterion):
        # criterion receives the selected output positionally
        return criterion(self._get_output(state.output, self.output_key))

    def _compute_loss_key_value(self, state: _State, criterion):
        # criterion receives the selected outputs as keyword arguments
        return criterion(**self._get_output(state.output, self.output_key))
class CriterionAggregatorCallback(Callback):
    """
    This callback allows you to aggregate the values of the loss
    (with different aggregation strategies)
    and put the value back into ``state.loss``.
    """
    def __init__(
        self,
        prefix: str,
        loss_keys: Union[str, List[str], Dict[str, float]] = None,
        loss_aggregate_fn: str = "sum",
        multiplier: float = 1.0
    ) -> None:
        """
        Args:
            prefix (str): new key for aggregated loss.
            loss_keys (Union[str, List[str], Dict[str, float]]): If not empty,
                it aggregates only the values from the loss by these keys.
                for ``weighted_sum`` aggregation it must be a Dict[str, float].
            loss_aggregate_fn (str): function for aggregation.
                Must be either ``sum``, ``mean``, ``weighted_sum``
                or ``weighted_mean``.
            multiplier (float): scale factor for the aggregated loss.

        Raises:
            ValueError: if ``prefix`` is not a string, if ``loss_keys`` does
                not match the aggregation mode, or if ``loss_aggregate_fn``
                is unknown.
        """
        super().__init__(CallbackOrder.Criterion + 1)
        if prefix is None or not isinstance(prefix, str):
            raise ValueError("prefix must be str")
        self.prefix = prefix

        # normalize a single key to a one-element list
        if isinstance(loss_keys, str):
            loss_keys = [loss_keys]
        self.loss_keys = loss_keys
        self.multiplier = multiplier

        # BUGFIX: validation must inspect the aggregation mode; the original
        # compared ``loss_keys`` against the mode names, so neither branch
        # could ever fire and invalid configurations were silently accepted.
        if loss_aggregate_fn in ("sum", "mean"):
            if loss_keys is not None and not isinstance(loss_keys, list):
                raise ValueError(
                    "For `sum` or `mean` mode the loss_keys must be "
                    "None or list or str (not dict)"
                )
        elif loss_aggregate_fn in ("weighted_sum", "weighted_mean"):
            if loss_keys is None or not isinstance(loss_keys, dict):
                raise ValueError(
                    "For `weighted_sum` or `weighted_mean` mode "
                    "the loss_keys must be specified "
                    "and must be a dict"
                )

        if loss_aggregate_fn in ("sum", "weighted_sum", "weighted_mean"):
            self.loss_fn = lambda x: torch.sum(torch.stack(x)) * multiplier
            if loss_aggregate_fn == "weighted_mean":
                # BUGFIX: sum the weight values, not the (key, weight) items
                # (``sum(dict.items())`` raises TypeError on tuples).
                weights_sum = sum(loss_keys.values())
                # normalize weights so the weighted sum becomes a mean
                self.loss_keys = {
                    key: weight / weights_sum
                    for key, weight in loss_keys.items()
                }
        elif loss_aggregate_fn == "mean":
            self.loss_fn = lambda x: torch.mean(torch.stack(x)) * multiplier
        else:
            raise ValueError(
                "loss_aggregate_fn must be `sum`, `mean` "
                "or `weighted_sum` or `weighted_mean`"
            )
        self.loss_aggregate_name = loss_aggregate_fn

    def _preprocess_loss(self, loss: Any) -> List[torch.Tensor]:
        """Extract the list of tensors to aggregate from ``state.loss``,
        which may be a single tensor, a list, or a dict of named losses."""
        if isinstance(loss, list):
            if self.loss_keys is not None:
                logger.warning(
                    f"Trying to get {self.loss_keys} keys from the losses, "
                    "but the loss is a list. All values will be aggregated."
                )
            result = loss
        elif isinstance(loss, dict):
            if self.loss_keys is not None:
                # BUGFIX: apply weights for weighted_mean as well; the
                # weights were normalized in __init__ but the original only
                # multiplied them in for weighted_sum.
                if self.loss_aggregate_name in ("weighted_sum",
                                                "weighted_mean"):
                    result = [
                        loss[key] * value
                        for key, value in self.loss_keys.items()
                    ]
                else:
                    result = [loss[key] for key in self.loss_keys]
            else:
                result = list(loss.values())
        else:
            # single tensor -> aggregate a one-element list
            result = [loss]
        return result

    def on_batch_end(self, state: _State) -> None:
        """
        Computes the loss and add it to the metrics
        """
        loss = state.get_key(key="loss")
        loss = self._preprocess_loss(loss)
        loss = self.loss_fn(loss)
        state.metric_manager.add_batch_value(
            metrics_dict={
                self.prefix: loss.item(),
            }
        )
        _add_loss_to_state(self.prefix, state, loss)
# Explicit public API of this module.
__all__ = [
    "CriterionCallback",
    "CriterionOutputOnlyCallback",
    "CriterionAggregatorCallback"
]