# Source code for catalyst.tools.meters.msemeter
"""
MSE and RMSE meters.
"""
import math
import torch
from catalyst.tools.meters import meter
class MSEMeter(meter.Meter):
    """Accumulate squared error across calls and report MSE or RMSE.

    The meter keeps a running sum of squared errors (``sesum``) and a running
    element count (``n``); ``value`` divides one by the other.
    Root calculation can be toggled (not calculated by default).
    """

    def __init__(self, root: bool = False):
        """
        Args:
            root: Toggle between calculation of
                RMSE (True) and MSE (False)
        """
        super().__init__()
        self.reset()
        self.root = root

    def reset(self) -> None:
        """Reset meter number of elements and squared error sum."""
        self.n = 0
        self.sesum = 0.0

    def add(self, output: torch.Tensor, target: torch.Tensor) -> None:
        """Update squared error stored sum and number of elements.

        Each argument may independently be a tensor or a numpy array;
        numpy arrays are converted with ``torch.from_numpy``.

        Args:
            output: Model output tensor or numpy array
            target: Target tensor or numpy array
        """
        # Convert each argument on its own so that a mixed
        # tensor / numpy pair is handled, not only numpy / numpy.
        if not torch.is_tensor(output):
            output = torch.from_numpy(output)
        if not torch.is_tensor(target):
            target = torch.from_numpy(target)

        self.n += output.numel()
        # ``.item()`` keeps the accumulator a plain Python number, so
        # ``value`` really returns a float and the meter does not hold on
        # to the input tensors' autograd graph between calls.
        self.sesum += torch.sum((output - target) ** 2).item()

    def value(self) -> float:
        """Calculate MSE and return RMSE or MSE.

        Returns:
            float: Root of MSE if `self.root` is True else MSE.
            ``0.0`` when nothing has been added yet.
        """
        # max(1, n) avoids dividing by zero on a freshly reset meter.
        mse = self.sesum / max(1, self.n)
        return math.sqrt(mse) if self.root else mse
# Names exported by a star-import of this module.
__all__ = ["MSEMeter"]