Source code for catalyst.dl.meters.msemeter
import math
import torch
from . import meter
class MSEMeter(meter.Meter):
    """Running (root) mean squared error between outputs and targets.

    Squared errors and element counts are accumulated across calls to
    :meth:`add`; :meth:`value` reports their ratio, or the square root of
    it when the meter was created with ``root=True``.

    Args:
        root: if true, :meth:`value` returns the square root of the mean
            squared error instead of the mean squared error itself.
    """

    def __init__(self, root=False):
        super(MSEMeter, self).__init__()
        self.reset()
        self.root = root

    def reset(self):
        """Clear the accumulated element count and squared-error sum."""
        self.n = 0
        self.sesum = 0.0

    def add(self, output, target):
        """Accumulate the squared error of one batch.

        Args:
            output: predictions, as a ``torch.Tensor`` or anything
                ``torch.as_tensor`` accepts (e.g. a numpy array).
            target: ground truth, same accepted kinds as ``output``.
        """
        # Convert each argument on its own: a tensor/ndarray mix must not
        # skip the conversion of the non-tensor side.
        if not torch.is_tensor(output):
            output = torch.as_tensor(output)
        if not torch.is_tensor(target):
            target = torch.as_tensor(target)

        self.n += output.numel()
        # Store a plain Python number so the meter neither keeps an
        # autograd graph alive between batches nor turns `sesum` into a
        # tensor (which made `value()` return inconsistent types).
        self.sesum += torch.sum((output - target) ** 2).item()

    def value(self):
        """Return the current MSE, or RMSE when ``root`` is set.

        The denominator is clamped to at least 1, so the result is ``0.0``
        when nothing has been added since the last reset.
        """
        mse = self.sesum / max(1, self.n)
        return math.sqrt(mse) if self.root else mse