Safemotion Lib
Loading...
Searching...
No Matches
circle_softmax.py
Go to the documentation of this file.
1# encoding: utf-8
2"""
3@author: liaoxingyu
4@contact: sherlockliao01@gmail.com
5"""
6
7import math
8
9import torch
10import torch.nn as nn
11import torch.nn.functional as F
12from torch.nn import Parameter
13
14
class CircleSoftmax(nn.Module):
    """Cosine-similarity classification head with circle-loss re-weighting.

    Features and class weights are both L2-normalised, so the raw scores are
    cosine similarities. Each score is then re-weighted and shifted following
    the Circle Loss formulation:

    * the target (positive) column uses optimum ``1 + m`` and margin ``1 - m``;
    * every other (negative) column uses optimum ``-m`` and margin ``m``;

    and everything is multiplied by the scale ``s``.

    Args:
        cfg: config object; ``cfg.MODEL.HEADS.SCALE`` supplies the scale ``s``
            and ``cfg.MODEL.HEADS.MARGIN`` supplies the margin ``m``.
        in_feat: dimensionality of the input feature vectors.
        num_classes: number of classes (rows of ``self.weight``).
    """

    def __init__(self, cfg, in_feat, num_classes):
        super().__init__()
        self.in_feat = in_feat
        self._num_classes = num_classes
        self.s = cfg.MODEL.HEADS.SCALE
        self.m = cfg.MODEL.HEADS.MARGIN

        # One weight row per class. ``torch.empty`` replaces the legacy
        # ``torch.Tensor(*sizes)`` constructor; the uninitialised storage is
        # filled by the Kaiming-uniform init on the next line.
        self.weight = Parameter(torch.empty(num_classes, in_feat))
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, features, targets):
        """Compute circle-softmax logits.

        Args:
            features: input features, presumably of shape ``(batch, in_feat)``
                (``F.normalize`` normalises along dim 1) — TODO confirm for
                inputs of other ranks.
            targets: integer class indices in ``[0, num_classes)``, one per
                row of ``features``.

        Returns:
            Logits of shape ``(batch, num_classes)``. The target column of
            each row uses the positive branch and every other column the
            negative branch, so the output depends on ``targets``.
        """
        sim_mat = F.linear(F.normalize(features), F.normalize(self.weight))

        # The adaptive weights use detached similarities, so no gradient
        # flows through them; clamping keeps them non-negative.
        alpha_p = torch.clamp_min(-sim_mat.detach() + 1 + self.m, min=0.)
        alpha_n = torch.clamp_min(sim_mat.detach() + self.m, min=0.)
        delta_p = 1 - self.m
        delta_n = self.m

        s_p = self.s * alpha_p * (sim_mat - delta_p)
        s_n = self.s * alpha_n * (sim_mat - delta_n)

        # Select the positive branch in each sample's target column and the
        # negative branch elsewhere. A boolean mask with ``torch.where`` avoids
        # pushing an int64 one-hot through two extra full-size float products.
        is_target = F.one_hot(targets, num_classes=self._num_classes).bool()
        pred_class_logits = torch.where(is_target, s_p, s_n)

        return pred_class_logits

    def extra_repr(self):
        """Return the hyper-parameters shown in this module's ``repr``."""
        return 'in_features={}, num_classes={}, scale={}, margin={}'.format(
            self.in_feat, self._num_classes, self.s, self.m
        )
__init__(self, cfg, in_feat, num_classes)