Safemotion Lib
Loading...
Searching...
No Matches
focal_loss.py
Go to the documentation of this file.
1# encoding: utf-8
2"""
3@author: xingyu liao
4@contact: sherlockliao01@gmail.com
5"""
6
7import torch
8import torch.nn.functional as F
9
10
11# based on:
12# https://github.com/kornia/kornia/blob/master/kornia/losses/focal.py
13
def focal_loss(
        input: torch.Tensor,
        target: torch.Tensor,
        alpha: float,
        gamma: float = 2.0,
        reduction: str = 'mean') -> torch.Tensor:
    r"""Criterion that computes Focal loss.

    See :class:`fastreid.modeling.losses.FocalLoss` for details.
    According to [1], the Focal loss is computed as follows:

    .. math::

        \text{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \, \text{log}(p_t)

    where:
        - :math:`p_t` is the model's estimated (softmax) probability for the
          ground-truth class at each position.

    Arguments:
        input (torch.Tensor): Raw, unnormalized class scores (logits); the
            softmax is taken over dim 1.
        target (torch.Tensor): Integer (``torch.long``) class indices.
        alpha (float): Weighting factor :math:`\alpha \in [0, 1]`. It is a
            single scalar applied to every class, so here it only rescales
            the loss.
        gamma (float): Focusing parameter :math:`\gamma >= 0`.
        reduction (str, optional): Specifies the reduction to apply to the
            output: 'none' | 'mean' | 'sum'. 'none': no reduction will be
            applied, 'mean': the sum of the output will be divided by the
            number of elements in the output, 'sum': the output will be
            summed. Default: 'mean'.

    Returns:
        torch.Tensor: Per-position loss of shape :math:`(N, *)` when
        ``reduction='none'``, otherwise a scalar tensor.

    Raises:
        TypeError: If ``input`` is not a tensor.
        ValueError: If ``input`` has fewer than 2 dims, if the batch or
            trailing dims of ``input`` and ``target`` disagree, or if they
            live on different devices.
        NotImplementedError: If ``reduction`` is not a supported mode.

    Shape:
        - Input: :math:`(N, C, *)` where C = number of classes.
        - Target: :math:`(N, *)` where each value is
          :math:`0 \leq targets[i] \leq C-1`.

    Examples:
        >>> C = 5  # num_classes
        >>> input = torch.randn(1, C, 3, 5, requires_grad=True)
        >>> target = torch.empty(1, 3, 5, dtype=torch.long).random_(C)
        >>> output = focal_loss(input, target, alpha=0.25)
        >>> output.backward()

    References:
        [1] https://arxiv.org/abs/1708.02002
    """
    if not torch.is_tensor(input):
        raise TypeError("Input type is not a torch.Tensor. Got {}"
                        .format(type(input)))

    if not len(input.shape) >= 2:
        raise ValueError("Invalid input shape, we expect BxCx*. Got: {}"
                         .format(input.shape))

    if input.size(0) != target.size(0):
        raise ValueError('Expected input batch_size ({}) to match target batch_size ({}).'
                         .format(input.size(0), target.size(0)))

    n = input.size(0)
    out_size = (n,) + input.size()[2:]
    if target.size()[1:] != input.size()[2:]:
        raise ValueError('Expected target size {}, got {}'.format(
            out_size, target.size()))

    if not input.device == target.device:
        raise ValueError(
            "input and target must be in the same device. Got: {} and {}".format(
                input.device, target.device))

    # log_softmax is numerically stable. log(softmax(x)) returns -inf once a
    # probability underflows to 0 (large logits), which then turns the loss
    # into NaN (0 * inf) or inf.
    log_prob = F.log_softmax(input, dim=1)

    # Select the log-probability of the target class at every position.
    # gather on the class axis works for (N, C) and (N, C, *) alike. F.one_hot
    # would put the class axis LAST, i.e. (N, *, C), which does not broadcast
    # against (N, C, *) when the input has extra (e.g. spatial) dimensions.
    log_pt = log_prob.gather(1, target.unsqueeze(1)).squeeze(1)
    pt = log_pt.exp()

    # FL(p_t) = -alpha * (1 - p_t)^gamma * log(p_t); shape (N, *)
    loss_tmp = -alpha * torch.pow(1. - pt, gamma) * log_pt

    if reduction == 'none':
        loss = loss_tmp
    elif reduction == 'mean':
        loss = torch.mean(loss_tmp)
    elif reduction == 'sum':
        loss = torch.sum(loss_tmp)
    else:
        raise NotImplementedError("Invalid reduction mode: {}"
                                  .format(reduction))
    return loss