Safemotion Lib
Loading...
Searching...
No Matches
batch_drop.py
Go to the documentation of this file.
1# encoding: utf-8
2"""
3@author: liaoxingyu
4@contact: sherlockliao01@gmail.com
5"""
6
7import random
8
9from torch import nn
10
11
class BatchDrop(nn.Module):
    """Batch DropBlock: zero out one random rectangle of every feature map.

    ref: https://github.com/daizuozhuo/batch-dropblock-network/blob/master/models/networks.py

    During training, one rectangle of size
    ``round(h_ratio * H) x round(w_ratio * W)`` (clamped to ``H x W``) is
    placed at a random position. That same region is zeroed for every
    sample and channel in the batch. In eval mode the input passes through
    unchanged.

    Args:
        h_ratio: fraction of the height (second-to-last dim) to drop.
        w_ratio: fraction of the width (last dim) to drop.
    """

    def __init__(self, h_ratio, w_ratio):
        super(BatchDrop, self).__init__()
        self.h_ratio = h_ratio
        self.w_ratio = w_ratio

    def forward(self, x):
        """Apply the batch drop mask to ``x``.

        Args:
            x: tensor whose last two dims are (H, W), e.g. (N, C, H, W).
                Any rank >= 2 is accepted.

        Returns:
            ``x * mask`` in training mode, otherwise ``x`` unchanged.
        """
        if self.training:
            h, w = x.size()[-2:]
            # Clamp so that ratios rounding above the map size drop the full
            # extent instead of making random.randint raise on an empty range.
            rh = min(round(self.h_ratio * h), h)
            rw = min(round(self.w_ratio * w), w)
            # Python's `random` (not torch's RNG) picks the block position.
            sx = random.randint(0, h - rh)
            sy = random.randint(0, w - rw)
            # An (H, W) mask broadcasts over all leading dims. The same region
            # is then dropped for every sample/channel, with no full-size
            # mask allocation, and inputs of any rank >= 2 work.
            mask = x.new_ones((h, w))
            mask[sx:sx + rh, sy:sy + rw] = 0
            x = x * mask
        return x
__init__(self, h_ratio, w_ratio)
Definition batch_drop.py:17