# Source code for catalyst.contrib.criterion.center
import torch
from torch.autograd import Variable
from torch.autograd.function import Function
import torch.nn as nn
class CenterLoss(nn.Module):
    """Center loss with one learnable center vector per class.

    ``centers`` is a ``(num_classes, feature_dim)`` parameter. The loss itself
    is computed by ``CenterLossFunc``, which sums
    ``0.5 * ||feature_i - centers[label_i]||^2`` over the batch.

    Args:
        num_classes: number of classes, i.e. number of center vectors.
        feature_dim: size of each (flattened) feature vector and center.
    """

    def __init__(self, num_classes, feature_dim):
        super().__init__()
        # Centers are initialised from a standard normal distribution.
        self.centers = nn.Parameter(torch.randn(num_classes, feature_dim))
        self.loss_fn = CenterLossFunc.apply
        self.feature_dim = feature_dim

    def forward(self, feature, label):
        """Compute the center loss for a batch.

        Args:
            feature: tensor whose first dimension is the batch; all remaining
                dimensions are flattened and must total ``feature_dim``.
            label: 1-D tensor of class indices, one per sample.

        Returns:
            Scalar tensor holding the loss summed over the batch.

        Raises:
            ValueError: if the flattened feature size is not ``feature_dim``.
        """
        batch_size = feature.size(0)
        # ``reshape(batch_size, -1)`` always yields a 2-D (batch, dim) tensor.
        # The former ``view(...).squeeze()`` also dropped the batch axis when
        # batch_size == 1 (or the feature axis when feature_dim == 1), which
        # made ``feature.size(1)`` below raise IndexError. ``reshape`` also
        # accepts non-contiguous inputs, unlike ``view``.
        feature = feature.reshape(batch_size, -1)
        # Check that the feature size matches the size of the centers.
        if feature.size(1) != self.feature_dim:
            raise ValueError(
                "Center's dim: {0} "
                "should be equal to input feature's dim: {1}".format(
                    self.feature_dim, feature.size(1)
                )
            )
        return self.loss_fn(feature, label, self.centers)
class CenterLossFunc(Function):
    """Autograd function computing the center loss with custom gradients.

    Forward returns ``0.5 * sum_i ||feature_i - centers[label_i]||^2``.

    Backward returns:

    * for ``feature``: ``grad_output * (feature_i - centers[label_i])``;
    * for ``label``: ``None``;
    * for ``centers``: for each class ``j``,
      ``sum_{i: label_i == j} (centers[j] - feature_i) / (1 + n_j)``, where
      ``n_j`` is the number of samples of class ``j`` in the batch.
    """

    @staticmethod
    def forward(ctx, feature, label, centers):
        """Return the summed half squared distance to each sample's center.

        Args:
            ctx: autograd context used to stash tensors for ``backward``.
            feature: ``(batch, dim)`` feature tensor.
            label: 1-D tensor of class indices (cast to ``long``).
            centers: ``(num_classes, dim)`` tensor of class centers.

        Returns:
            Scalar tensor.
        """
        ctx.save_for_backward(feature, label, centers)
        centers_batch = centers.index_select(0, label.long())
        return (feature - centers_batch).pow(2).sum(1).sum(0) / 2.0

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for ``(feature, label, centers)``."""
        feature, label, centers = ctx.saved_tensors
        label = label.long()
        centers_batch = centers.index_select(0, label)
        diff = centers_batch - feature

        # counts[j] = 1 + (number of samples of class j). Starting from 1
        # keeps the division below finite for classes absent from the batch.
        counts = centers.new_ones(centers.size(0))
        counts.scatter_add_(0, label, centers.new_ones(label.size(0)))

        # Per-class sum of (center - feature), then averaged by counts.
        grad_centers = centers.new_zeros(centers.size())
        grad_centers.scatter_add_(0, label.unsqueeze(1).expand_as(feature), diff)
        grad_centers = grad_centers / counts.view(-1, 1)

        # Plain tensors (no legacy ``Variable``/``.data``) keep the graph
        # intact for double backward. The center gradient is deliberately
        # NOT multiplied by ``grad_output``: scaling the loss changes only
        # the feature gradient (behaviour kept from the original).
        return -grad_output * diff, None, grad_centers