Source code for catalyst.dl.meters.msemeter

import math

import torch

from . import meter


class MSEMeter(meter.Meter):
    """Accumulate the (root) mean squared error over a stream of batches.

    Call :meth:`add` once per batch, then :meth:`value` to read the metric
    over everything added since the last :meth:`reset`.

    Args:
        root: if ``True``, :meth:`value` returns the square root of the
            mean squared error (RMSE) instead of the MSE.
    """

    def __init__(self, root=False):
        super().__init__()
        self.reset()
        self.root = root

    def reset(self):
        """Clear the accumulated element count and squared-error sum."""
        # Number of elements whose squared error has been accumulated.
        self.n = 0
        # Running sum of squared errors; becomes a 0-dim tensor after add().
        self.sesum = 0.0

    def add(self, output, target):
        """Accumulate the squared error between ``output`` and ``target``.

        Args:
            output: predictions, as a ``torch.Tensor`` or anything accepted
                by :func:`torch.as_tensor` (e.g. a numpy array or a list).
            target: ground-truth values, same accepted types as ``output``.
                Presumably the same shape as ``output``; if the two
                broadcast, every broadcast element is counted.
        """
        # Convert each argument independently so that mixing a tensor with
        # a numpy array works (previously both had to be non-tensors).
        if not torch.is_tensor(output):
            output = torch.as_tensor(output)
        if not torch.is_tensor(target):
            target = torch.as_tensor(target)
        # A metric must not extend the autograd graph: accumulating a
        # grad-tracking tensor across batches would keep every batch's
        # graph alive.
        with torch.no_grad():
            squared_error = (output - target) ** 2
            # Count the elements actually summed (handles broadcasting).
            self.n += squared_error.numel()
            self.sesum += torch.sum(squared_error)

    def value(self):
        """Return the MSE (or RMSE when ``root=True``) of all added data.

        Returns ``0.0`` if nothing has been added since the last reset,
        because the divisor is clamped to 1. Otherwise the non-root result
        is a 0-dim tensor, and the root result is a Python float produced by
        :func:`math.sqrt`.
        """
        mse = self.sesum / max(1, self.n)
        return math.sqrt(mse) if self.root else mse