Gather tensors from all processes, supporting backward propagation.
Definition at line 13 of file gather_layer.py.
◆ backward()
static fastreid.layers.gather_layer.GatherLayer.backward(ctx, *grads)
Definition at line 26 of file gather_layer.py.
def backward(ctx, *grads):
    """Route the gradients of the gathered outputs back to this rank's input.

    ``forward`` returns one tensor per rank, and the loss computed on every
    rank may depend on every one of those tensors, including the slot that
    holds this rank's ``input``. The true gradient for ``input`` is therefore
    the sum, over all ranks, of the gradient each rank computed for output
    slot ``dist.get_rank()``. Taking only the local ``grads[rank]`` would
    silently drop the other ranks' contributions, so an all-reduce is needed.

    Args:
        ctx: autograd context. Not needed here; the gradient shape comes
            from ``grads`` themselves.
        *grads: one gradient tensor per output of ``forward``, i.e. one per
            rank, all with the shape of ``input``.

    Returns:
        The gradient tensor for ``input``.
    """
    # torch.stack makes a fresh tensor, so the in-place all_reduce below does
    # not mutate the gradient buffers autograd handed us.
    all_grads = torch.stack(grads)
    # Sum every rank's gradient w.r.t. each gathered slot.
    dist.all_reduce(all_grads, op=dist.ReduceOp.SUM)
    # Keep only the slot that corresponds to this rank's input.
    return all_grads[dist.get_rank()]
◆ forward()
static fastreid.layers.gather_layer.GatherLayer.forward(ctx, input)
Definition at line 18 of file gather_layer.py.
def forward(ctx, input):
    """Collect ``input`` from every process in the default process group.

    Args:
        ctx: autograd context; ``input`` is saved on it for the backward pass.
        input: this rank's tensor. The receive buffers are sized from it, so
            every rank is expected to pass a tensor of the same shape.

    Returns:
        A tuple with ``dist.get_world_size()`` tensors; entry ``i`` holds the
        tensor contributed by rank ``i``.
    """
    ctx.save_for_backward(input)
    world_size = dist.get_world_size()
    # Pre-allocate one receive buffer per rank, matching input's shape/dtype.
    gathered = []
    for _ in range(world_size):
        gathered.append(torch.zeros_like(input))
    dist.all_gather(gathered, input)
    return tuple(gathered)
24
The documentation for this class was generated from the following file: