# Source code for catalyst.dl.meters.classerrormeter
import numbers
import numpy as np
import torch
from . import meter
class ClassErrorMeter(meter.Meter):
    """Accumulate top-k classification error (or accuracy) over batches.

    For every ``k`` in ``topk`` the meter counts the samples whose target
    class is NOT among the ``k`` highest-scoring predictions. ``value``
    reports that count as a percentage of all samples added since the last
    ``reset``; with ``accuracy=True`` it reports ``100 - error`` instead.

    Args:
        topk: iterable of ``k`` values to track. A sorted copy is stored,
            so the caller's object is never retained or mutated.
        accuracy: if True, ``value`` returns accuracy instead of error.
    """

    def __init__(self, topk=(1,), accuracy=False):
        super().__init__()
        # np.sort returns a sorted copy; ``add`` relies on the largest k
        # being last when it asks torch for the top ``maxk`` scores.
        self.topk = np.sort(topk)
        self.accuracy = accuracy
        self.reset()

    def reset(self):
        """Clear all accumulated counts."""
        # per-k number of samples whose target was not in the top-k
        self.sum = {v: 0 for v in self.topk}
        # total number of samples added
        self.n = 0

    def add(self, output, target):
        """Update the counts with one batch.

        Args:
            output: class scores as a torch tensor, numpy array or nested
                sequence of shape (N, C), or (C,) for a single sample.
                Tensors are detached, moved to CPU and squeezed first.
            target: true class indices of shape (N,) as a torch tensor,
                numpy array, sequence, or a single number for one sample.

        Raises:
            AssertionError: if ``output`` and ``target`` have incompatible
                dimensions or sample counts.
        """
        if torch.is_tensor(output):
            # detach() so outputs that require grad can become numpy arrays
            output = output.detach().cpu().squeeze().numpy()
        else:
            output = np.asarray(output)
        if torch.is_tensor(target):
            target = target.detach().cpu().squeeze().numpy()
        # Covers Python numbers, 0-d arrays, squeezed one-element tensors
        # and plain sequences alike: the result always has ndim >= 1.
        target = np.atleast_1d(np.asarray(target))

        if np.ndim(output) == 1:
            # a single sample: promote to a batch of one
            output = output[np.newaxis]
        else:
            assert np.ndim(output) == 2, \
                "wrong output size (1D or 2D expected)"
            assert np.ndim(target) == 1, \
                "target and output do not match"
        assert target.shape[0] == output.shape[0], \
            "target and output do not match"

        topk = self.topk
        maxk = int(topk[-1])  # torch.topk wants a Python int, not np.int64
        n_samples = output.shape[0]

        # indices of the maxk largest scores per row, sorted descending
        pred = torch.from_numpy(output).topk(maxk, 1, True, True)[1].numpy()
        # correct[i, j] is True when the j-th best guess for sample i equals
        # its target; broadcasting compares each row against its own target.
        correct = pred == target[:, np.newaxis]

        for k in topk:
            # topk indices within a row are distinct, so each row holds at
            # most one hit and this counts the misses among the first k.
            self.sum[k] += n_samples - correct[:, :k].sum()
        self.n += n_samples

    def value(self, k=-1):
        """Return the error (or accuracy) in percent.

        Args:
            k: one of the ``topk`` values given at construction, or -1
                (default) for a list with one value per tracked ``k`` in
                ascending order.

        Raises:
            AssertionError: if ``k`` was not tracked.
            ZeroDivisionError: if no samples were added since the last
                ``reset``.
        """
        if k != -1:
            assert k in self.sum.keys(), \
                "invalid k (this k was not provided at construction time)"
            if self.accuracy:
                return (1. - float(self.sum[k]) / self.n) * 100.0
            else:
                return float(self.sum[k]) / self.n * 100.0
        else:
            return [self.value(k_) for k_ in self.topk]