Source code for catalyst.contrib.criterion.center

import torch
from torch.autograd import Variable
from torch.autograd.function import Function
import torch.nn as nn


class CenterLoss(nn.Module):
    """Center loss module that owns one learnable center per class.

    Holds a ``(num_classes, feature_dim)`` parameter of class centers and
    delegates the actual loss computation to ``CenterLossFunc``.

    Args:
        num_classes: number of centers (rows of ``self.centers``).
        feature_dim: size of each center; every input sample must flatten
            to exactly this many values.
    """

    def __init__(self, num_classes: int, feature_dim: int):
        super().__init__()
        # Centers are initialised from a standard normal distribution.
        self.centers = nn.Parameter(torch.randn(num_classes, feature_dim))
        self.loss_fn = CenterLossFunc.apply
        self.feature_dim = feature_dim

    def forward(self, feature, label):
        """Compute the center loss for one batch.

        Args:
            feature: tensor whose first dimension is the batch; all remaining
                dimensions are flattened into one vector per sample.
            label: per-sample labels, passed unchanged to the loss function
                together with ``self.centers``.

        Returns:
            The result of ``self.loss_fn(feature, label, self.centers)``.

        Raises:
            ValueError: if the flattened per-sample size differs from
                ``self.feature_dim``.
        """
        batch_size = feature.size(0)
        # Flatten explicitly to (batch, features). A blanket ``squeeze()``
        # would also drop the batch dimension when batch_size == 1 (or the
        # feature dimension when it is 1), leaving a 1-D tensor for which the
        # ``size(1)`` check below raises IndexError.
        feature = feature.view(batch_size, -1)
        # Check that the dims of the centers and the features agree.
        if feature.size(1) != self.feature_dim:
            raise ValueError(
                "Center's dim: {0} "
                "should be equal to input feature's dim: {1}".format(
                    self.feature_dim, feature.size(1)
                )
            )
        return self.loss_fn(feature, label, self.centers)
class CenterLossFunc(Function):
    """Autograd function for the center loss with a hand-written backward.

    The forward pass returns half of the summed squared distance between
    each feature row and the center selected by its label. The backward
    pass returns the analytic gradient for the features and a
    count-normalised gradient for the centers.
    """

    @staticmethod
    def forward(ctx, feature, label, centers):
        """Return ``sum((feature - centers[label]) ** 2) / 2`` as a scalar.

        Args:
            ctx: autograd context; the three inputs are saved on it.
            feature: 2-D tensor with one row per sample.
            label: 1-D tensor of row indices into ``centers`` (cast to long).
            centers: 2-D tensor with one row per class.
        """
        ctx.save_for_backward(feature, label, centers)
        centers_batch = centers.index_select(0, label.long())
        return (feature - centers_batch).pow(2).sum(1).sum(0) / 2.0

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for ``(feature, label, centers)``.

        The label gradient is ``None``. Both tensor gradients are returned
        detached, as the legacy ``Variable`` wrapping did.
        """
        feature, label, centers = ctx.saved_tensors
        index = label.long()
        diff = centers.index_select(0, index) - feature

        # Per-class divisor, rebuilt on every call: it starts at 1 and gains
        # one for each sample of that class, so classes absent from the
        # batch divide by 1 rather than 0.
        counts = centers.new_ones(centers.size(0))
        counts.scatter_add_(0, index, centers.new_ones(index.size(0)))

        # Sum (center - feature) over the samples of each class.
        grad_centers = centers.new_zeros(centers.size())
        grad_centers.scatter_add_(
            0, index.unsqueeze(1).expand(feature.size()), diff
        )
        grad_centers = grad_centers / counts.view(-1, 1)

        # d(loss)/d(feature) = feature - center = -diff, scaled by the
        # incoming gradient. The centers gradient is deliberately NOT scaled
        # by grad_output (unchanged from the original implementation).
        grad_feature = -grad_output * diff
        return grad_feature.detach(), None, grad_centers.detach()