def forward(self, features, targets):
    """Return scaled cosine logits with an additive angular margin on the target class.

    The non-target reweighting ``cos * (t + cos)`` combined with a running
    average ``t`` of the target cosine looks like the CurricularFace
    formulation -- NOTE(review): confirm against the enclosing class.

    Args:
        features: Embeddings; assumed to be 2-D ``(batch, dim)`` since they are
            L2-normalised along dim 1 and projected with ``F.linear``.
        targets: Class index per sample. Flattened and cast to ``long`` here,
            so ``(batch,)`` or ``(batch, 1)`` integer tensors are accepted.

    Returns:
        Tensor of logits (cosines multiplied by ``self.s``), one row per
        sample and one column per row of ``self.weight``.

    Side effects:
        Rebinds ``self.t`` to an exponential moving average (momentum 0.01)
        of the mean target cosine of this batch.
    """
    # Cosine between every normalised feature and every normalised class weight.
    cos_theta = F.linear(F.normalize(features), F.normalize(self.weight))
    cos_theta = cos_theta.clamp(-1, 1)

    # Normalise the label tensor once so that the gather and the scatter
    # below agree on shape and dtype.
    class_idx = targets.view(-1).long()
    row_idx = torch.arange(features.size(0), device=cos_theta.device)
    target_logit = cos_theta[row_idx, class_idx].view(-1, 1)

    # sqrt has an infinite derivative at 0, which is reached whenever the
    # clamped target cosine is exactly +/-1.  Flooring the radicand at the
    # smallest normal number keeps the forward value practically unchanged
    # while making the backward pass finite.
    tiny = torch.finfo(target_logit.dtype).tiny
    sin_theta = torch.sqrt((1.0 - target_logit.pow(2)).clamp_min(tiny))

    # cos(theta + m) for the target class.
    cos_theta_m = target_logit * self.cos_m - sin_theta * self.sin_m
    # Entries whose cosine exceeds the margin-adjusted target cosine.
    hard_mask = cos_theta > cos_theta_m
    final_target_logit = torch.where(
        target_logit > self.threshold, cos_theta_m, target_logit - self.mm
    )

    # Running statistic only; it must not take part in autograd.  An empty
    # batch is skipped because its mean is NaN and would poison ``self.t``.
    if target_logit.numel() > 0:
        with torch.no_grad():
            self.t = target_logit.mean() * 0.01 + (1 - 0.01) * self.t

    # Out-of-place equivalents of the masked assignment and ``scatter_``:
    # same values and gradients, without mutating an autograd intermediate.
    cos_theta = torch.where(hard_mask, cos_theta * (self.t + cos_theta), cos_theta)
    cos_theta = cos_theta.scatter(1, class_idx.view(-1, 1), final_target_logit)

    pred_class_logits = cos_theta * self.s
    return pred_class_logits
51