Source code for catalyst.utils.meters.msemeter
import math
import torch
from . import meter
class MSEMeter(meter.Meter):
    """Running mean squared error (optionally its square root) over batches.

    The meter keeps a count of compared elements and the running sum of
    squared differences; ``value()`` divides one by the other.

    Args:
        root: if True, ``value()`` returns the square root of the mean
            squared error instead of the mean squared error itself.
    """

    def __init__(self, root=False):
        super(MSEMeter, self).__init__()
        self.reset()
        self.root = root

    def reset(self):
        """Clear the element count and the accumulated squared error."""
        self.n = 0
        self.sesum = 0.0

    def add(self, output, target):
        """Accumulate the squared error between ``output`` and ``target``.

        Args:
            output: predictions, either a ``torch.Tensor`` or an object
                accepted by ``torch.from_numpy``.
            target: reference values, same accepted types as ``output``.
        """
        # Convert each argument independently so that a mixed call
        # (one tensor, one NumPy array) is handled as well.
        if not torch.is_tensor(output):
            output = torch.from_numpy(output)
        if not torch.is_tensor(target):
            target = torch.from_numpy(target)
        self.n += output.numel()
        # .item() keeps the accumulator a plain Python number: no autograd
        # graph or device tensor is retained between batches, and value()
        # returns the same type whether or not data has been added.
        self.sesum += torch.sum((output - target) ** 2).item()

    def value(self):
        """Return the mean squared error, or its square root if ``root``.

        The denominator is clamped to at least 1, so an empty meter
        reports 0.0 rather than raising ``ZeroDivisionError``.
        """
        mse = self.sesum / max(1, self.n)
        return math.sqrt(mse) if self.root else mse