Safemotion Lib
Loading...
Searching...
No Matches
smreid
fastreid
layers
gather_layer.py
Go to the documentation of this file.
1
# encoding: utf-8
2
"""
3
@author: xingyu liao
4
@contact: sherlockliao01@gmail.com
5
"""
6
7
# based on: https://github.com/open-mmlab/OpenSelfSup/blob/master/openselfsup/models/utils/gather_layer.py
8
9
import
torch
10
import
torch.distributed
as
dist
11
12
13
class GatherLayer(torch.autograd.Function):
    """Gather tensors from all processes, supporting backward propagation.

    ``forward`` collects one tensor per process with ``dist.all_gather`` and
    returns them as a tuple. ``backward`` returns the gradient that arrived
    for the tuple entry at index ``dist.get_rank()``.

    The default process group must be initialised before this function is
    applied, because both passes query ``torch.distributed``.

    NOTE(review): ``backward`` performs no communication, so gradients that
    other processes compute for this process's entry are not summed in here.
    Confirm that the loss scaling used by callers accounts for that.
    """

    @staticmethod
    def forward(ctx, input):
        """Gather ``input`` from every process in the default group.

        Args:
            ctx: Autograd context object (nothing is stored on it).
            input: Local tensor to share. Presumably every process passes a
                tensor of the same shape and dtype -- verify against callers.

        Returns:
            A tuple holding ``dist.get_world_size()`` tensors, each shaped
            like ``input``.
        """
        # Collective backends reject strided views (slices, transposes).
        # ``contiguous()`` returns the tensor itself when it is already
        # densely packed, so the common case pays nothing.
        local = input.contiguous()
        world_size = dist.get_world_size()
        gathered = [torch.zeros_like(local) for _ in range(world_size)]
        dist.all_gather(gathered, local)
        return tuple(gathered)

    @staticmethod
    def backward(ctx, *grads):
        """Return the gradient belonging to this process's gathered entry.

        Args:
            ctx: Autograd context object (unused).
            *grads: One gradient per tensor returned by ``forward``.

        Returns:
            The gradient at index ``dist.get_rank()``. It already has the
            shape of ``input`` because every gathered tensor was allocated
            with that shape, so no saved tensor or extra copy is required.
        """
        return grads[dist.get_rank()]
fastreid.layers.gather_layer.GatherLayer
Definition
gather_layer.py:13
fastreid.layers.gather_layer.GatherLayer.forward
forward(ctx, input)
Definition
gather_layer.py:18
fastreid.layers.gather_layer.GatherLayer.backward
backward(ctx, *grads)
Definition
gather_layer.py:26
Generated by
1.10.0