Safemotion Lib
fastreid.layers.arc_softmax.ArcSoftmax Class Reference

Public Member Functions

 __init__ (self, cfg, in_feat, num_classes)
 
 forward (self, features, targets)
 
 extra_repr (self)
 

Public Attributes

 in_feat
 
 s
 
 m
 
 cos_m
 
 sin_m
 
 threshold
 
 mm
 
 weight
 
 t
 

Protected Attributes

 _num_classes
 

Detailed Description

Definition at line 15 of file arc_softmax.py.

Constructor & Destructor Documentation

◆ __init__()

fastreid.layers.arc_softmax.ArcSoftmax.__init__(self, cfg, in_feat, num_classes)

Definition at line 16 of file arc_softmax.py.

def __init__(self, cfg, in_feat, num_classes):
    super().__init__()
    self.in_feat = in_feat
    self._num_classes = num_classes
    self.s = cfg.MODEL.HEADS.SCALE
    self.m = cfg.MODEL.HEADS.MARGIN

    self.cos_m = math.cos(self.m)
    self.sin_m = math.sin(self.m)
    self.threshold = math.cos(math.pi - self.m)
    self.mm = math.sin(math.pi - self.m) * self.m

    self.weight = Parameter(torch.Tensor(num_classes, in_feat))
    nn.init.xavier_uniform_(self.weight)
    self.register_buffer('t', torch.zeros(1))
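The margin constants computed in the constructor can be checked in isolation. The sketch below uses only the standard library and assumes a margin of 0.5 radians, which is an illustrative value and not one read from any configuration. The variable names mirror the attributes listed above.

```python
import math

# Assumed margin in radians; illustrative only, not read from any cfg.
m = 0.5

cos_m = math.cos(m)
sin_m = math.sin(m)

# cos(pi - m) equals -cos(m). A target cosine at or below this value
# corresponds to an angle theta with theta + m >= pi.
threshold = math.cos(math.pi - m)

# sin(pi - m) equals sin(m), so mm is sin(m) * m. forward() subtracts it
# from the target logit when the threshold test fails.
mm = math.sin(math.pi - m) * m

print(cos_m, sin_m, threshold, mm)
```

The two identities in the comments mean that threshold and mm depend on the margin alone, so they can be computed once at construction time.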

Member Function Documentation

◆ extra_repr()

fastreid.layers.arc_softmax.ArcSoftmax.extra_repr(self)

Definition at line 52 of file arc_softmax.py.

def extra_repr(self):
    return 'in_features={}, num_classes={}, scale={}, margin={}'.format(
        self.in_feat, self._num_classes, self.s, self.m
    )
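Assuming the class derives from torch.nn.Module, as the register_buffer call in the constructor suggests, extra_repr supplies the text shown inside the parentheses when the module is printed. The snippet below reproduces only the format string, using placeholder values that are hypothetical and not taken from any configuration.

```python
# Placeholder values, chosen for illustration only.
in_feat, num_classes, s, m = 2048, 751, 30, 0.5

summary = 'in_features={}, num_classes={}, scale={}, margin={}'.format(
    in_feat, num_classes, s, m
)
print(summary)
```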

◆ forward()

fastreid.layers.arc_softmax.ArcSoftmax.forward(self, features, targets)

Definition at line 32 of file arc_softmax.py.

def forward(self, features, targets):
    # get cos(theta)
    cos_theta = F.linear(F.normalize(features), F.normalize(self.weight))
    cos_theta = cos_theta.clamp(-1, 1)  # for numerical stability

    target_logit = cos_theta[torch.arange(0, features.size(0)), targets].view(-1, 1)

    sin_theta = torch.sqrt(1.0 - torch.pow(target_logit, 2))
    cos_theta_m = target_logit * self.cos_m - sin_theta * self.sin_m  # cos(target+margin)
    mask = cos_theta > cos_theta_m
    final_target_logit = torch.where(target_logit > self.threshold, cos_theta_m, target_logit - self.mm)

    hard_example = cos_theta[mask]
    with torch.no_grad():
        self.t = target_logit.mean() * 0.01 + (1 - 0.01) * self.t
    cos_theta[mask] = hard_example * (self.t + hard_example)
    cos_theta.scatter_(1, targets.view(-1, 1).long(), final_target_logit)
    pred_class_logits = cos_theta * self.s
    return pred_class_logits
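The per-element rules applied in forward() can be easier to follow on scalars. The following is a minimal sketch in plain Python, not the tensor implementation. The helper names margin_target_logit and reweight_hard_negative are hypothetical, and the margin and t values in the usage lines are assumptions chosen for illustration.

```python
import math

def margin_target_logit(cos_t, m):
    """Scalar version of the target-class rule in forward() (hypothetical helper)."""
    cos_t = max(-1.0, min(1.0, cos_t))                   # clamp, as forward() does
    sin_t = math.sqrt(1.0 - cos_t * cos_t)
    cos_t_m = cos_t * math.cos(m) - sin_t * math.sin(m)  # cos(theta + m)
    threshold = math.cos(math.pi - m)
    mm = math.sin(math.pi - m) * m
    # Use cos(theta + m) while theta + m stays below pi, otherwise fall back
    # to subtracting the constant mm from the target cosine.
    return cos_t_m if cos_t > threshold else cos_t - mm

def reweight_hard_negative(cos_j, cos_t_m, t):
    """Scalar version of the masked update: rescale a logit only when it
    exceeds cos(theta + m) of the target class (hypothetical helper)."""
    return cos_j * (t + cos_j) if cos_j > cos_t_m else cos_j

m = 0.5                                                  # assumed margin, radians
easy = margin_target_logit(math.cos(1.0), m)             # equals cos(1.0 + 0.5)
fallback = margin_target_logit(-0.95, m)                 # below threshold
hard = reweight_hard_negative(0.9, 0.5, 0.2)             # 0.9 * (0.2 + 0.9)
untouched = reweight_hard_negative(0.1, 0.5, 0.2)        # left as 0.1
```

In forward() the same rules run over the whole batch at once: the mask selects every logit above the target's cos(theta + m), the target column is then overwritten by scatter_, and the result is multiplied by the scale s.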

Member Data Documentation

◆ _num_classes

fastreid.layers.arc_softmax.ArcSoftmax._num_classes
protected

Definition at line 19 of file arc_softmax.py.

◆ cos_m

fastreid.layers.arc_softmax.ArcSoftmax.cos_m

Definition at line 23 of file arc_softmax.py.

◆ in_feat

fastreid.layers.arc_softmax.ArcSoftmax.in_feat

Definition at line 18 of file arc_softmax.py.

◆ m

fastreid.layers.arc_softmax.ArcSoftmax.m

Definition at line 21 of file arc_softmax.py.

◆ mm

fastreid.layers.arc_softmax.ArcSoftmax.mm

Definition at line 26 of file arc_softmax.py.

◆ s

fastreid.layers.arc_softmax.ArcSoftmax.s

Definition at line 20 of file arc_softmax.py.

◆ sin_m

fastreid.layers.arc_softmax.ArcSoftmax.sin_m

Definition at line 24 of file arc_softmax.py.

◆ t

fastreid.layers.arc_softmax.ArcSoftmax.t

Definition at line 46 of file arc_softmax.py.
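The buffer t is registered as zeros in the constructor and reassigned on each call to forward() as an exponential moving average of the mean target logit. The sketch below replays that update on plain floats. The helper name update_t and the constant mean of 0.5 are assumptions made for illustration, while the 0.01 factor is the one used in forward().

```python
def update_t(t, target_logit_mean, momentum=0.01):
    """One step of the running average applied to t (hypothetical helper)."""
    return target_logit_mean * momentum + (1 - momentum) * t

t = 0.0                      # the buffer starts at zero
for _ in range(3):           # three forward passes with an assumed mean of 0.5
    t = update_t(t, 0.5)

print(t)
```

With a constant input the value approaches that input geometrically, so t stays close to zero during the first iterations and rises as training proceeds.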

◆ threshold

fastreid.layers.arc_softmax.ArcSoftmax.threshold

Definition at line 25 of file arc_softmax.py.

◆ weight

fastreid.layers.arc_softmax.ArcSoftmax.weight

Definition at line 28 of file arc_softmax.py.


The documentation for this class was generated from the following file: