Shortcuts

Source code for catalyst.contrib.nn.modules.curricularface

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class CurricularFace(nn.Module):
    r"""Implementation of
    `CurricularFace: Adaptive Curriculum Learning Loss for Deep Face Recognition`_.

    .. _CurricularFace\: Adaptive Curriculum Learning Loss for Deep Face Recognition:
        https://arxiv.org/abs/2004.00288

    Official `pytorch implementation`_.

    .. _pytorch implementation:
        https://github.com/HuangYG123/CurricularFace

    The layer keeps a non-trainable buffer ``t`` (initialised to zero) that
    is updated on every labelled forward pass as an exponential moving
    average (momentum ``0.01``) of the mean target cosine. ``t`` modulates
    the logits of "hard" negative classes.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        s: norm of input feature.
            Default: ``64.0``.
        m: margin.
            Default: ``0.5``.

    Shape:
        - Input: :math:`(batch, H_{in})` where
          :math:`H_{in} = in\_features`.
        - Output: :math:`(batch, H_{out})` where
          :math:`H_{out} = out\_features`.

    Example:
        >>> layer = CurricularFace(5, 10, s=1.31, m=0.5)
        >>> loss_fn = nn.CrossEntropyLoss()
        >>> embedding = torch.randn(3, 5, requires_grad=True)
        >>> target = torch.empty(3, dtype=torch.long).random_(10)
        >>> output = layer(embedding, target)
        >>> loss = loss_fn(output, target)
        >>> loss.backward()

    """  # noqa: RST215

    def __init__(  # noqa: D107
        self, in_features: int, out_features: int, s: float = 64.0, m: float = 0.5,
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.m = m
        self.s = s

        # Constants derived from the margin, computed once.
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        # Target cosines at or below this value use the linear fallback
        # ``cos(theta) - mm`` instead of ``cos(theta + m)`` (see ``forward``).
        self.threshold = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

        # One column per class; columns are L2-normalised in ``forward``.
        self.weight = nn.Parameter(torch.empty(in_features, out_features))
        # Running average of the mean target cosine (not a trainable parameter).
        self.register_buffer("t", torch.zeros(1))

        nn.init.normal_(self.weight, std=0.01)

    def __repr__(self) -> str:  # noqa: D105
        rep = (
            "CurricularFace("
            f"in_features={self.in_features},"
            f"out_features={self.out_features},"
            f"m={self.m},s={self.s}"
            ")"
        )
        return rep

    def forward(self, input: torch.Tensor, label: torch.LongTensor = None) -> torch.Tensor:
        """
        Compute cosine projections or, when labels are given, margin logits.

        When ``label`` is provided the buffer ``t`` is updated as a side
        effect of the call.

        Args:
            input: input features,
                expected shapes ``BxF`` where ``B``
                is batch dimension and ``F`` is an
                input feature dimension.
            label: target classes,
                expected shapes ``B`` where
                ``B`` is batch dimension.
                If `None` then will be returned
                projection on centroids.
                Default is `None`.

        Returns:
            tensor (logits) with shapes ``BxC``
            where ``C`` is a number of classes.
        """
        cos_theta = torch.mm(F.normalize(input), F.normalize(self.weight, dim=0))
        cos_theta = cos_theta.clamp(-1, 1)  # for numerical stability

        if label is None:
            return cos_theta

        # Flatten the labels so that each row picks exactly one column.
        cols = label.view(-1).long()
        # Build the row index on the same device as the logits.
        rows = torch.arange(cos_theta.size(0), device=cos_theta.device)
        # Advanced indexing (not ``gather``) is used on purpose: ``cos_theta``
        # is modified in place further down.
        target_logit = cos_theta[rows, cols].view(-1, 1)

        # Floor the radicand at machine epsilon: the derivative of sqrt at 0
        # is infinite, so a target cosine of exactly +/-1 would otherwise
        # produce NaN gradients during the backward pass.
        eps = torch.finfo(target_logit.dtype).eps
        sin_theta = torch.sqrt((1.0 - torch.pow(target_logit, 2)).clamp(min=eps))
        cos_theta_m = target_logit * self.cos_m - sin_theta * self.sin_m  # cos(target+margin)

        # Entries whose cosine exceeds the margin-penalised target cosine.
        mask = cos_theta > cos_theta_m
        final_target_logit = torch.where(
            target_logit > self.threshold, cos_theta_m, target_logit - self.mm
        )

        hard_example = cos_theta[mask]
        with torch.no_grad():
            # Exponential moving average of the mean target cosine.
            self.t = target_logit.mean() * 0.01 + (1 - 0.01) * self.t
        cos_theta[mask] = hard_example * (self.t + hard_example)
        # Overwrite the target column with the margin-adjusted logit.
        cos_theta.scatter_(1, cols.view(-1, 1), final_target_logit)
        output = cos_theta * self.s
        return output
# Names exported by a star-import of this module.
__all__ = ["CurricularFace"]