Safemotion Lib
activation.py
# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = [
    'Mish',
    'Swish',
    'MemoryEfficientSwish',
    'GELU',
]


class Mish(nn.Module):
    """Mish activation: x * tanh(softplus(x)) (Misra, 2019)."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Inlining this expression (rather than storing a temporary and returning it)
        # saved about 1 second per epoch on a V100 GPU.
        return x * torch.tanh(F.softplus(x))


class Swish(nn.Module):
    """Swish / SiLU activation: x * sigmoid(x)."""

    def forward(self, x):
        return x * torch.sigmoid(x)


class SwishImplementation(torch.autograd.Function):
    """Custom autograd Function for Swish that saves only the input for backward."""

    @staticmethod
    def forward(ctx, i):
        result = i * torch.sigmoid(i)
        ctx.save_for_backward(i)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        i = ctx.saved_tensors[0]
        sigmoid_i = torch.sigmoid(i)
        # d/dx [x * sigmoid(x)] = sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x))
        #                       = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
        return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))


class MemoryEfficientSwish(nn.Module):
    """Swish computed with a hand-written backward, so autograd does not need to
    keep the intermediate sigmoid activation in the graph."""

    def forward(self, x):
        return SwishImplementation.apply(x)


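# Minimal sanity-check sketch (illustrative; the helper name `_check_swish_grad`
# is not part of the library API): the hand-written backward above should match
# numerical gradients, which torch.autograd.gradcheck verifies in double precision.
def _check_swish_grad():
    x = torch.randn(8, dtype=torch.double, requires_grad=True)
    return torch.autograd.gradcheck(SwishImplementation.apply, (x,))

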
class GELU(nn.Module):
    """
    Tanh approximation of the GELU (Hendrycks & Gimpel, 2016).
    Paper Section 3.4, last paragraph: note that BERT uses GELU instead of ReLU.
    """

    def forward(self, x):
        return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
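

# Minimal usage sketch (illustrative only; assumes torch.erf is available):
# applies each activation to a random tensor and checks that the tanh-based
# GELU above stays close to the exact erf-based GELU.
if __name__ == '__main__':
    x = torch.randn(4, 8)
    for act in (Mish(), Swish(), MemoryEfficientSwish(), GELU()):
        print(act.__class__.__name__, act(x).shape)

    exact_gelu = 0.5 * x * (1 + torch.erf(x / math.sqrt(2)))
    print('max |tanh GELU - exact GELU|:', (GELU()(x) - exact_gelu).abs().max().item())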