Safemotion Lib
fastreid.layers.gather_layer.GatherLayer Class Reference

Static Public Member Functions

forward(ctx, input)
backward(ctx, *grads)

Detailed Description

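Gathers tensors from all processes while supporting backward propagation. The forward pass returns one tensor per rank. The backward pass routes the gradient for this rank's slice back to the local input.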

Definition at line 13 of file gather_layer.py.
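The following is a minimal single-process sketch of how the class can be used. It assumes PyTorch is installed and the gloo backend is available. The class body reproduces the listings on this page, and the `@staticmethod` decorators are assumed because the page marks both members static. `gather_and_concat` and `run_demo` are hypothetical helpers for illustration only and are not part of fastreid. With `world_size=1` the gather is effectively an identity, so the sketch checks the autograd plumbing rather than multi-rank behaviour.

```python
import os
import tempfile

import torch
import torch.distributed as dist


class GatherLayer(torch.autograd.Function):
    """Gather tensors from all processes, supporting backward propagation."""

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        # One receive buffer per rank, same shape/dtype/device as the local input.
        output = [torch.zeros_like(input) for _ in range(dist.get_world_size())]
        dist.all_gather(output, input)
        return tuple(output)

    @staticmethod
    def backward(ctx, *grads):
        (input,) = ctx.saved_tensors
        grad_out = torch.zeros_like(input)
        # Keep only the gradient that flows into this rank's own slice.
        grad_out[:] = grads[dist.get_rank()]
        return grad_out


def gather_and_concat(features):
    """Hypothetical helper: concatenate the features of every rank along dim 0."""
    return torch.cat(GatherLayer.apply(features), dim=0)


def run_demo():
    """Hypothetical driver: one-process gloo group so the sketch runs on a CPU."""
    init_file = os.path.join(tempfile.mkdtemp(), "dist_init")
    dist.init_process_group(
        backend="gloo", init_method=f"file://{init_file}", rank=0, world_size=1
    )
    try:
        local = torch.randn(4, 8, requires_grad=True)
        global_feats = gather_and_concat(local)
        # Any differentiable loss over the gathered batch works here.
        loss = global_feats.pow(2).sum()
        loss.backward()
        return global_feats.shape, local.grad, local
    finally:
        dist.destroy_process_group()
```

In a real multi-process job, each rank would call `gather_and_concat` on its local batch and obtain a tensor with `world_size` times as many rows. A bare `dist.all_gather` would return tensors that are disconnected from the autograd graph.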

Member Function Documentation

◆ backward()

static fastreid.layers.gather_layer.GatherLayer.backward(ctx, *grads)

Definition at line 26 of file gather_layer.py.

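def backward(ctx, *grads):
    input, = ctx.saved_tensors
    # Zero tensor with the same shape, dtype, and device as the local input.
    grad_out = torch.zeros_like(input)
    # grads holds one gradient per rank; keep only the one for this rank's input.
    grad_out[:] = grads[dist.get_rank()]
    return grad_out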
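In the listed backward, each rank keeps only `grads[rank]`, which is the gradient its own loss produced for its own slice. Suppose the loss computed on another rank also depends on this rank's features, for example a contrastive loss over the gathered batch. In that case, that contribution is not sent back to this rank. Some other implementations therefore sum the per-slice gradients across ranks before selecting the local one.

The sketch below shows that variant under the same assumptions as before (PyTorch with gloo, single process). `GatherLayerAllReduce` and `run_variant_demo` are hypothetical names, and this is not fastreid's code.

```python
import os
import tempfile

import torch
import torch.distributed as dist


class GatherLayerAllReduce(torch.autograd.Function):
    """Hypothetical variant: same forward, backward sums gradients across ranks."""

    @staticmethod
    def forward(ctx, input):
        output = [torch.zeros_like(input) for _ in range(dist.get_world_size())]
        dist.all_gather(output, input)
        return tuple(output)

    @staticmethod
    def backward(ctx, *grads):
        # Shape (world_size, *input.shape): one collective covers every slice.
        all_grads = torch.stack(grads)
        # Sum, over all ranks, the gradient each rank computed for each slice.
        dist.all_reduce(all_grads, op=dist.ReduceOp.SUM)
        # Return the summed gradient for this rank's own input.
        return all_grads[dist.get_rank()]


def run_variant_demo():
    """Hypothetical driver: one-process gloo group so the sketch runs on a CPU."""
    init_file = os.path.join(tempfile.mkdtemp(), "dist_init")
    dist.init_process_group(
        backend="gloo", init_method=f"file://{init_file}", rank=0, world_size=1
    )
    try:
        local = torch.randn(3, 5, requires_grad=True)
        gathered = torch.cat(GatherLayerAllReduce.apply(local), dim=0)
        (3 * gathered).sum().backward()
        return local.grad
    finally:
        dist.destroy_process_group()
```

With a single process, the variant and the listed backward give identical gradients. The difference only appears with two or more ranks whose losses read each other's gathered features.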

◆ forward()

static fastreid.layers.gather_layer.GatherLayer.forward(ctx, input)

Definition at line 18 of file gather_layer.py.

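def forward(ctx, input):
    ctx.save_for_backward(input)
    # One receive buffer per rank, matching the local input's shape, dtype, and device.
    output = [torch.zeros_like(input)
              for _ in range(dist.get_world_size())]
    # Fill output[i] with the tensor held by rank i.
    dist.all_gather(output, input)
    return tuple(output)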

The documentation for this class was generated from the following file: