Safemotion Lib
am_softmax.py
# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""

import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import Parameter


class AMSoftmax(nn.Module):
    r"""Implementation of the large margin cosine distance (AM-Softmax).
    Args:
        in_feat: size of each input sample
        num_classes: number of classes (size of each output sample)
    """

    def __init__(self, cfg, in_feat, num_classes):
        super().__init__()
        self.in_features = in_feat
        self._num_classes = num_classes
        self.s = cfg.MODEL.HEADS.SCALE
        self.m = cfg.MODEL.HEADS.MARGIN
        self.weight = Parameter(torch.Tensor(num_classes, in_feat))
        nn.init.xavier_uniform_(self.weight)

    def forward(self, features, targets):
        # --------------------------- cos(theta) & phi(theta) ---------------------------
        # Cosine similarity between L2-normalized features and class weights.
        cosine = F.linear(F.normalize(features), F.normalize(self.weight))
        # Subtract the additive margin from the cosine of every class.
        phi = cosine - self.m
        # --------------------------- convert label to one-hot ---------------------------
        targets = F.one_hot(targets, num_classes=self._num_classes)
        # Apply the margin only to the ground-truth class, then scale.
        output = (targets * phi) + ((1.0 - targets) * cosine)
        output *= self.s

        return output

    def extra_repr(self):
        return 'in_features={}, num_classes={}, scale={}, margin={}'.format(
            self.in_features, self._num_classes, self.s, self.m
        )
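The forward pass above produces logits of the form s * (cos(theta_y) - m) for the ground-truth class y and s * cos(theta_j) for every other class j, which are intended to be fed into a standard cross-entropy loss. Below is a minimal usage sketch, not part of the original file: the SimpleNamespace cfg is a hypothetical stand-in for the project's config object (which supplies MODEL.HEADS.SCALE and MODEL.HEADS.MARGIN), and the feature size, class count, scale, and margin values are illustrative assumptions.

import torch
import torch.nn.functional as F
from types import SimpleNamespace

# Hypothetical stand-in for the project's cfg; only the two fields read by AMSoftmax are provided.
cfg = SimpleNamespace(
    MODEL=SimpleNamespace(HEADS=SimpleNamespace(SCALE=30.0, MARGIN=0.35))
)

head = AMSoftmax(cfg, in_feat=512, num_classes=1000)

features = torch.randn(8, 512)           # batch of embedding vectors
targets = torch.randint(0, 1000, (8,))   # integer class labels

logits = head(features, targets)         # scaled, margin-adjusted cosine logits, shape (8, 1000)
loss = F.cross_entropy(logits, targets)  # cross-entropy over the adjusted logits
loss.backward()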