rxnn 0.2.74__tar.gz → 0.2.75__tar.gz
This diff shows the content of publicly released package versions as published to a supported registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in the public registry.
- {rxnn-0.2.74 → rxnn-0.2.75}/PKG-INFO +1 -1
- {rxnn-0.2.74 → rxnn-0.2.75}/pyproject.toml +1 -1
- {rxnn-0.2.74 → rxnn-0.2.75}/src/rxnn/training/mrl.py +33 -7
- {rxnn-0.2.74 → rxnn-0.2.75}/src/rxnn/transformers/models.py +0 -8
- {rxnn-0.2.74 → rxnn-0.2.75}/LICENSE +0 -0
- {rxnn-0.2.74 → rxnn-0.2.75}/README.md +0 -0
- {rxnn-0.2.74 → rxnn-0.2.75}/src/rxnn/.DS_Store +0 -0
- {rxnn-0.2.74 → rxnn-0.2.75}/src/rxnn/__init__.py +0 -0
- {rxnn-0.2.74 → rxnn-0.2.75}/src/rxnn/experimental/__init__.py +0 -0
- {rxnn-0.2.74 → rxnn-0.2.75}/src/rxnn/experimental/attention.py +0 -0
- {rxnn-0.2.74 → rxnn-0.2.75}/src/rxnn/experimental/models.py +0 -0
- {rxnn-0.2.74 → rxnn-0.2.75}/src/rxnn/experimental/moe.py +0 -0
- {rxnn-0.2.74 → rxnn-0.2.75}/src/rxnn/memory/__init__.py +0 -0
- {rxnn-0.2.74 → rxnn-0.2.75}/src/rxnn/memory/attention.py +0 -0
- {rxnn-0.2.74 → rxnn-0.2.75}/src/rxnn/memory/gate.py +0 -0
- {rxnn-0.2.74 → rxnn-0.2.75}/src/rxnn/memory/norm.py +0 -0
- {rxnn-0.2.74 → rxnn-0.2.75}/src/rxnn/memory/stm.py +0 -0
- {rxnn-0.2.74 → rxnn-0.2.75}/src/rxnn/rxt/__init__.py +0 -0
- {rxnn-0.2.74 → rxnn-0.2.75}/src/rxnn/rxt/models.py +0 -0
- {rxnn-0.2.74 → rxnn-0.2.75}/src/rxnn/training/__init__.py +0 -0
- {rxnn-0.2.74 → rxnn-0.2.75}/src/rxnn/training/base.py +0 -0
- {rxnn-0.2.74 → rxnn-0.2.75}/src/rxnn/training/bml.py +0 -0
- {rxnn-0.2.74 → rxnn-0.2.75}/src/rxnn/training/callbacks.py +0 -0
- {rxnn-0.2.74 → rxnn-0.2.75}/src/rxnn/training/dataset.py +0 -0
- {rxnn-0.2.74 → rxnn-0.2.75}/src/rxnn/training/ddp.py +0 -0
- {rxnn-0.2.74 → rxnn-0.2.75}/src/rxnn/training/models.py +0 -0
- {rxnn-0.2.74 → rxnn-0.2.75}/src/rxnn/training/reward.py +0 -0
- {rxnn-0.2.74 → rxnn-0.2.75}/src/rxnn/training/rl.py +0 -0
- {rxnn-0.2.74 → rxnn-0.2.75}/src/rxnn/training/scheduler.py +0 -0
- {rxnn-0.2.74 → rxnn-0.2.75}/src/rxnn/training/tokenizer.py +0 -0
- {rxnn-0.2.74 → rxnn-0.2.75}/src/rxnn/training/utils.py +0 -0
- {rxnn-0.2.74 → rxnn-0.2.75}/src/rxnn/transformers/__init__.py +0 -0
- {rxnn-0.2.74 → rxnn-0.2.75}/src/rxnn/transformers/attention.py +0 -0
- {rxnn-0.2.74 → rxnn-0.2.75}/src/rxnn/transformers/ff.py +0 -0
- {rxnn-0.2.74 → rxnn-0.2.75}/src/rxnn/transformers/layers.py +0 -0
- {rxnn-0.2.74 → rxnn-0.2.75}/src/rxnn/transformers/mask.py +0 -0
- {rxnn-0.2.74 → rxnn-0.2.75}/src/rxnn/transformers/moe.py +0 -0
- {rxnn-0.2.74 → rxnn-0.2.75}/src/rxnn/transformers/positional.py +0 -0
- {rxnn-0.2.74 → rxnn-0.2.75}/src/rxnn/transformers/sampler.py +0 -0
- {rxnn-0.2.74 → rxnn-0.2.75}/src/rxnn/utils.py +0 -0
src/rxnn/training/mrl.py

@@ -42,6 +42,7 @@ class MrlConfig(TypedDict):
     debug_mode: Optional[bool]
     debug_interval: Optional[int]
     clamp_logits: Optional[bool]
+    max_grad_norm: Optional[float]


 class MrlStrategy(Enum):
@@ -154,6 +155,7 @@ class MRLTrainer:
         self.debug_mode = config.get('debug_mode', False)
         self.debug_interval = config.get('debug_interval', 10)
         self.clamp_logits = config.get('clamp_logits', False)
+        self.max_grad_norm = config.get('max_grad_norm', 1.0)
         # Internal update epochs config
         self.shared_update_epochs = config.get('update_epochs', 10)
         self.update_epochs = self.shared_update_epochs
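The two hunks above add a single optional knob: max_grad_norm in MrlConfig, read by the trainer with a default of 1.0. A minimal sketch of how the option could be passed (only keys visible in this diff are shown; a real MrlConfig has more required fields):

# Sketch only - keys other than those appearing in this diff are omitted.
mrl_config = {
    'debug_mode': True,
    'debug_interval': 10,
    'clamp_logits': False,
    'max_grad_norm': 0.5,   # new in 0.2.75; config.get(...) falls back to 1.0 when absent
    'update_epochs': 10,
}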
@@ -591,13 +593,28 @@ class MRLTrainer:
         actor = next(self.actor.children()) if isinstance(self.actor, DistributedDataParallel) else self.actor

         router_loss = actor.moe_router_loss()
-        if torch.isnan(router_loss).any():
-            print("!!!!!!!!!!!!!!!!!!!!!! NaN detected in router loss")
         if router_loss is not None:
             return main_loss + self.moe_aux_loss_scale * router_loss
         else:
             return main_loss

+    def _clip_actor_grad_norms(self):
+        # Encoder with embedding
+        torch.nn.utils.clip_grad_norm_(
+            self.actor.encoder.parameters(),
+            max_norm=self.max_grad_norm, error_if_nonfinite=False
+        )
+        # Decoder
+        torch.nn.utils.clip_grad_norm_(
+            self.actor.decoder.memory_parameters() + self.actor.decoder.not_memory_parameters(),
+            max_norm=self.max_grad_norm, error_if_nonfinite=False
+        )
+        # Memory attention
+        torch.nn.utils.clip_grad_norm_(
+            self.actor.memory_attention.parameters(),
+            max_norm=self.max_grad_norm, error_if_nonfinite=False
+        )
+
     def _log_gradients(self, logits: torch.Tensor):
         print(
             f"Returned logits stats: min={logits.min().item():.4f}, max={logits.max().item():.4f}")
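The new _clip_actor_grad_norms helper clips each component of the actor (encoder, decoder, memory attention) to max_grad_norm separately, rather than taking a single global norm over all actor parameters, which is what the one-call version replaced in the training-step hunks further down appears to have done. A standalone sketch of the difference, using toy nn.Linear modules rather than the rxnn actor:

import torch
import torch.nn as nn

# Hypothetical stand-ins for the actor's components.
encoder, decoder = nn.Linear(8, 8), nn.Linear(8, 8)
loss = encoder(torch.randn(4, 8)).sum() + decoder(torch.randn(4, 8)).sum()
loss.backward()

# Global clipping: one norm computed over all parameters together,
# so a gradient spike in one module also scales down the other.
# torch.nn.utils.clip_grad_norm_(
#     list(encoder.parameters()) + list(decoder.parameters()), max_norm=1.0)

# Per-module clipping (the approach of _clip_actor_grad_norms): each group
# is bounded at max_norm independently.
for module in (encoder, decoder):
    torch.nn.utils.clip_grad_norm_(
        module.parameters(), max_norm=1.0, error_if_nonfinite=False)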
@@ -608,8 +625,13 @@ class MRLTrainer:
         print(f"Decoder grad norm - total: {decoder_total:.6f}, mean: {decoder_mean:.6f}")
         print(f"Memory attention grad norm - total: {mem_att_total:.6f}, mean: {mem_att_mean:.6f}")

+        dec_ff_norms = [get_gradient_norms(layer.ff)[1] for layer in self.actor.decoder.model.layers]
+        dec_self_att_norms = [get_gradient_norms(layer.attention)[1] for layer in self.actor.decoder.model.layers]
         dec_x_att_norms = [get_gradient_norms(layer.memory_cross_attention)[1] for layer in self.actor.decoder.model.layers]
+
+
         mem_att_norms = [get_gradient_norms(layer)[1] for layer in self.actor.memory_attention.model.attention_layers]
+
         enc_ff_norms = [get_gradient_norms(layer.ff)[1] for layer in self.actor.encoder.model.layers]
         enc_self_att_norms = [get_gradient_norms(layer.attention)[1] for layer in self.actor.encoder.model.layers]
         enc_x_att_norms = [get_gradient_norms(layer.memory_cross_attention)[1] for layer in
@@ -617,22 +639,28 @@ class MRLTrainer:

         calc_mean = lambda x: sum(x) / len(x)

+        dec_ff_norms_mean = calc_mean(dec_ff_norms)
+        dec_self_att_norms_mean = calc_mean(dec_self_att_norms)
         dec_x_att_norms_mean = calc_mean(dec_x_att_norms)
         mem_att_norms_mean = calc_mean(mem_att_norms)
         enc_ff_norms_mean = calc_mean(enc_ff_norms)
         enc_self_att_norms_mean = calc_mean(enc_self_att_norms)
         enc_x_att_norms_mean = calc_mean(enc_x_att_norms)

+        print(f"Decoder ff mean norm: {dec_ff_norms_mean:.6f}, all: {dec_ff_norms}")
+        print(f"Decoder self-att mean norm: {dec_self_att_norms_mean:.6f}, all: {dec_self_att_norms}")
         print(f"Decoder cross-att mean norm: {dec_x_att_norms_mean:.6f}, all: {dec_x_att_norms}")
-        print(f"Memory attention layers mean norm: {mem_att_norms_mean:.6f}, all: {mem_att_norms}")
         print(f"Encoder ff mean norm: {enc_ff_norms_mean:.6f}, all: {enc_ff_norms}")
         print(f"Encoder self-att mean norm: {enc_self_att_norms_mean:.6f}, all: {enc_self_att_norms}")
         print(f"Encoder cross-att mean norm: {enc_x_att_norms_mean:.6f}, all: {enc_x_att_norms}")
+        print(f"Memory attention layers mean norm: {mem_att_norms_mean:.6f}, all: {mem_att_norms}")

         if self.writer is not None:
             self.writer.add_scalar('Gradient/encoder', encoder_mean, self.global_step['train'])
             self.writer.add_scalar('Gradient/decoder', decoder_mean, self.global_step['train'])
             self.writer.add_scalar('Gradient/mem-att', mem_att_mean, self.global_step['train'])
+            self.writer.add_scalar('Gradient/decoder ff', dec_ff_norms_mean, self.global_step['train'])
+            self.writer.add_scalar('Gradient/decoder self-att', dec_self_att_norms_mean, self.global_step['train'])
             self.writer.add_scalar('Gradient/decoder x-att', dec_x_att_norms_mean, self.global_step['train'])
             self.writer.add_scalar('Gradient/mem-att layers', mem_att_norms_mean, self.global_step['train'])
             self.writer.add_scalar('Gradient/encoder ff', enc_ff_norms_mean, self.global_step['train'])
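The debug hunks above extend _log_gradients with per-layer norms for the decoder feed-forward and self-attention blocks, add matching TensorBoard scalars, and move the memory-attention print after the encoder ones. get_gradient_norms is a helper defined elsewhere in rxnn; judging from the [1] indexing here and the total/mean pairs printed earlier in this function, it returns a (total, mean) pair of gradient norms for a module. A minimal stand-in, assuming that return convention (names and exact norm arithmetic are assumptions, not rxnn's implementation):

import torch
import torch.nn as nn

def get_gradient_norms_sketch(module: nn.Module) -> tuple[float, float]:
    # Assumed behaviour: collect the L2 norm of each parameter's gradient
    # and return (total, mean), matching how index [1] is used above for the mean.
    norms = [p.grad.norm(2).item() for p in module.parameters() if p.grad is not None]
    total = sum(norms)
    mean = total / len(norms) if norms else 0.0
    return total, mean

# Usage sketch on an arbitrary module:
layer = nn.Linear(8, 8)
layer(torch.randn(2, 8)).sum().backward()
total, mean = get_gradient_norms_sketch(layer)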
@@ -670,8 +698,7 @@ class MRLTrainer:
             self.scaler.scale(policy_loss).backward(retain_graph=True)
             # 4.4 Unscale and clip gradient norms
             self.scaler.unscale_(self.optimizer)
-
-                                           error_if_nonfinite=False)
+            self._clip_actor_grad_norms()
             if self.debug_mode and self.epoch_step['train'] % self.debug_interval == 0:
                 self._log_gradients(logits)
             # 4.5 Run scaled optimization step
@@ -691,8 +718,7 @@ class MRLTrainer:
             # 4.3 Run backpropagation
             policy_loss.backward(retain_graph=True)
             # 4.4 Clip gradient norms
-
-                                           error_if_nonfinite=False)
+            self._clip_actor_grad_norms()
             if self.debug_mode and self.epoch_step['train'] % self.debug_interval == 0:
                 self._log_gradients(logits)
             # 4.5 Run scaled optimization step
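Both training-step variants now delegate clipping to the shared helper; in the AMP branch the gradients are unscaled first so clipping operates on true gradient values. For reference, the standard PyTorch AMP pattern the first hunk follows, as a generic sketch (not rxnn's full training step):

import torch
import torch.nn as nn

model = nn.Linear(8, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())
device_type = 'cuda' if torch.cuda.is_available() else 'cpu'

x, y = torch.randn(4, 8), torch.randn(4, 1)
with torch.autocast(device_type=device_type):
    loss = (model(x) - y).pow(2).mean()

scaler.scale(loss).backward()          # backprop on the scaled loss
scaler.unscale_(optimizer)             # bring gradients back to true scale
torch.nn.utils.clip_grad_norm_(        # clip only after unscaling
    model.parameters(), max_norm=1.0, error_if_nonfinite=False)
scaler.step(optimizer)                 # skips the step if inf/nan gradients remain
scaler.update()
optimizer.zero_grad()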
src/rxnn/transformers/models.py

@@ -93,8 +93,6 @@ class ReactiveTransformerDecoder(ReactiveTransformerBase):

     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
         x = super().forward(x)  # apply embeddings
-        if torch.isnan(x).any():
-            print("!!!!!!!!!!!!!!!!!!!!!! NaN detected in decoder embedding output")
         seq_len = x.size(1)
         if not self.use_flash_attention and self.use_relative_embedding:
             mask = create_causal_mask(seq_len, device=x.device)

@@ -111,8 +109,6 @@ class ReactiveTransformerDecoder(ReactiveTransformerBase):
         # Process own layers
         for i in range(self.num_own_layers):
             x = self._handle_layer(i, x, mask=mask)
-            if torch.isnan(x).any():
-                print(f"!!!!!!!!!!!!!!!!!!!!!! NaN detected in {i}. decoder layer output")
         return self.head(self.head_norm(x) if self.use_head_norm else x)


@@ -121,8 +117,6 @@ class ReactiveTransformerEncoder(ReactiveTransformerBase):

     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> tuple[torch.Tensor, torch.Tensor]:
         x = super().forward(x)  # apply embeddings
-        if torch.isnan(x).any():
-            print("!!!!!!!!!!!!!!!!!!!!!! NaN detected in encoder embedding output")
         if attention_mask is not None:
             attention_mask = attention_mask.unsqueeze(1).unsqueeze(1).bool()

@@ -135,8 +129,6 @@ class ReactiveTransformerEncoder(ReactiveTransformerBase):
         # Process own layers
         for i in range(self.num_own_layers):
             x = self._handle_layer(i, x, mask=attention_mask)
-            if torch.isnan(x).any():
-                print(f"!!!!!!!!!!!!!!!!!!!!!! NaN detected in {i}. encoder layer output")
             hidden_states.append(x)
         return x, torch.stack(hidden_states)

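The models.py changes are pure cleanup: what look like temporary NaN-detection prints are removed from the encoder and decoder forward passes. If similar ad-hoc checks are needed again, a less invasive pattern is a forward hook that can be attached and removed without editing forward(); a generic sketch, not part of rxnn:

import torch
import torch.nn as nn

def nan_check_hook(name: str):
    # Reports non-finite values in a module's output without editing forward().
    def hook(module, inputs, output):
        out = output[0] if isinstance(output, tuple) else output
        if not torch.isfinite(out).all():
            print(f"Non-finite values detected in output of {name}")
    return hook

# Usage sketch on an arbitrary module:
layer = nn.Linear(8, 8)
handle = layer.register_forward_hook(nan_check_hook("layer"))
layer(torch.randn(2, 8))
handle.remove()   # detach the check when debugging is done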