def log_lamb_rs(optimizer: Optimizer, event_writer: SummaryWriter, token_count: int):
    """Log histograms of the per-parameter LAMB statistics across layers.

    Walks every parameter listed in ``optimizer.param_groups`` and collects the
    ``weight_norm``, ``adam_norm`` and ``trust_ratio`` entries present in that
    parameter's optimizer state. Every statistic found at least once is written
    as a single histogram tagged ``lamb/<statistic name>``.

    Args:
        optimizer: Optimizer whose ``param_groups`` and per-parameter ``state``
            are inspected.
        event_writer: Writer that receives one ``add_histogram`` call per
            collected statistic.
        token_count: Value passed as the step argument of ``add_histogram``.
    """
    stat_names = ('weight_norm', 'adam_norm', 'trust_ratio')
    results = collections.defaultdict(list)
    for group in optimizer.param_groups:
        for p in group['params']:
            state = optimizer.state[p]
            for name in stat_names:
                # Parameters that have no such entry in their state are skipped.
                if name in state:
                    # float() accepts plain numbers and one-element tensors
                    # alike, so the histogram input is always a uniform list
                    # of Python floats.
                    results[name].append(float(state[name]))

    for name, values in results.items():
        event_writer.add_histogram(f'lamb/{name}', torch.tensor(values), token_count)
24
25