rxnn 0.2.70__tar.gz → 0.2.72__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {rxnn-0.2.70 → rxnn-0.2.72}/PKG-INFO +1 -1
- {rxnn-0.2.70 → rxnn-0.2.72}/pyproject.toml +1 -1
- {rxnn-0.2.70 → rxnn-0.2.72}/src/rxnn/memory/attention.py +3 -3
- {rxnn-0.2.70 → rxnn-0.2.72}/src/rxnn/training/mrl.py +30 -13
- {rxnn-0.2.70 → rxnn-0.2.72}/src/rxnn/transformers/layers.py +7 -7
- {rxnn-0.2.70 → rxnn-0.2.72}/src/rxnn/transformers/models.py +4 -4
- {rxnn-0.2.70 → rxnn-0.2.72}/src/rxnn/utils.py +5 -0
- {rxnn-0.2.70 → rxnn-0.2.72}/LICENSE +0 -0
- {rxnn-0.2.70 → rxnn-0.2.72}/README.md +0 -0
- {rxnn-0.2.70 → rxnn-0.2.72}/src/rxnn/.DS_Store +0 -0
- {rxnn-0.2.70 → rxnn-0.2.72}/src/rxnn/__init__.py +0 -0
- {rxnn-0.2.70 → rxnn-0.2.72}/src/rxnn/experimental/__init__.py +0 -0
- {rxnn-0.2.70 → rxnn-0.2.72}/src/rxnn/experimental/attention.py +0 -0
- {rxnn-0.2.70 → rxnn-0.2.72}/src/rxnn/experimental/models.py +0 -0
- {rxnn-0.2.70 → rxnn-0.2.72}/src/rxnn/experimental/moe.py +0 -0
- {rxnn-0.2.70 → rxnn-0.2.72}/src/rxnn/memory/__init__.py +0 -0
- {rxnn-0.2.70 → rxnn-0.2.72}/src/rxnn/memory/norm.py +0 -0
- {rxnn-0.2.70 → rxnn-0.2.72}/src/rxnn/memory/stm.py +0 -0
- {rxnn-0.2.70 → rxnn-0.2.72}/src/rxnn/rxt/__init__.py +0 -0
- {rxnn-0.2.70 → rxnn-0.2.72}/src/rxnn/rxt/models.py +0 -0
- {rxnn-0.2.70 → rxnn-0.2.72}/src/rxnn/training/__init__.py +0 -0
- {rxnn-0.2.70 → rxnn-0.2.72}/src/rxnn/training/base.py +0 -0
- {rxnn-0.2.70 → rxnn-0.2.72}/src/rxnn/training/bml.py +0 -0
- {rxnn-0.2.70 → rxnn-0.2.72}/src/rxnn/training/callbacks.py +0 -0
- {rxnn-0.2.70 → rxnn-0.2.72}/src/rxnn/training/dataset.py +0 -0
- {rxnn-0.2.70 → rxnn-0.2.72}/src/rxnn/training/ddp.py +0 -0
- {rxnn-0.2.70 → rxnn-0.2.72}/src/rxnn/training/models.py +0 -0
- {rxnn-0.2.70 → rxnn-0.2.72}/src/rxnn/training/reward.py +0 -0
- {rxnn-0.2.70 → rxnn-0.2.72}/src/rxnn/training/rl.py +0 -0
- {rxnn-0.2.70 → rxnn-0.2.72}/src/rxnn/training/scheduler.py +0 -0
- {rxnn-0.2.70 → rxnn-0.2.72}/src/rxnn/training/tokenizer.py +0 -0
- {rxnn-0.2.70 → rxnn-0.2.72}/src/rxnn/training/utils.py +0 -0
- {rxnn-0.2.70 → rxnn-0.2.72}/src/rxnn/transformers/__init__.py +0 -0
- {rxnn-0.2.70 → rxnn-0.2.72}/src/rxnn/transformers/attention.py +0 -0
- {rxnn-0.2.70 → rxnn-0.2.72}/src/rxnn/transformers/ff.py +0 -0
- {rxnn-0.2.70 → rxnn-0.2.72}/src/rxnn/transformers/mask.py +0 -0
- {rxnn-0.2.70 → rxnn-0.2.72}/src/rxnn/transformers/moe.py +0 -0
- {rxnn-0.2.70 → rxnn-0.2.72}/src/rxnn/transformers/positional.py +0 -0
- {rxnn-0.2.70 → rxnn-0.2.72}/src/rxnn/transformers/sampler.py +0 -0
src/rxnn/memory/attention.py
@@ -65,7 +65,7 @@ class StmMemoryAttention(nn.Module):
             encoded_layer_data = x[i]
             normalized_layer_stm = self.memory_norm_layers[i](layer_stm)
             if torch.isnan(normalized_layer_stm).any():
-                print(f"NaN detected in {i} layer memory norm output")
+                print(f"!!!!!!!!!!!!!!!!!!!!!! NaN detected in {i} layer memory norm output")
 
             if self.debug_mode and self.training:
                 if self.debug_step != 0 and self.debug_step % self.debug_interval == 0:
@@ -75,11 +75,11 @@ class StmMemoryAttention(nn.Module):
                     self.debug_step += 1
 
             if torch.isnan(encoded_layer_data).any():
-                print(f"NaN detected in {i} layer encoded data input")
+                print(f"!!!!!!!!!!!!!!!!!!!!!! NaN detected in {i} layer encoded data input")
 
             new_layer_stm = self.attention_layers[i](normalized_layer_stm, encoded_layer_data, encoded_layer_data, mask=attention_mask)
             if torch.isnan(new_layer_stm).any():
-                print(f"NaN detected in {i} layer memory attention output")
+                print(f"!!!!!!!!!!!!!!!!!!!!!! NaN detected in {i} layer memory attention output")
 
             if self.use_gated_residual:
                 new_stm[i] = self._residual_gate(self.gate[i], layer_stm, new_layer_stm)  # gated residual
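The hunk above only shows `self._residual_gate(...)` being called; the gate itself is defined elsewhere in the package and is not part of this diff. As a rough illustration of what a gated residual over STM slots typically computes (the names and the sigmoid formulation below are assumptions, not taken from rxnn):

import torch

def gated_residual(gate_param: torch.Tensor, layer_stm: torch.Tensor, new_layer_stm: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for StmMemoryAttention._residual_gate: a learned gate,
    # squashed through a sigmoid, interpolates between the previous STM slot and
    # the freshly attended update.
    gate = torch.sigmoid(gate_param)
    return gate * new_layer_stm + (1 - gate) * layer_stm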
src/rxnn/training/mrl.py
@@ -592,7 +592,7 @@ class MRLTrainer:
 
         router_loss = actor.moe_router_loss()
         if torch.isnan(router_loss).any():
-            print("NaN detected in router loss")
+            print("!!!!!!!!!!!!!!!!!!!!!! NaN detected in router loss")
         if router_loss is not None:
             return main_loss + self.moe_aux_loss_scale * router_loss
         else:
@@ -607,21 +607,38 @@ class MRLTrainer:
         print(f"Encoder grad norm - total: {encoder_total:.6f}, mean: {encoder_mean:.6f}")
         print(f"Decoder grad norm - total: {decoder_total:.6f}, mean: {decoder_mean:.6f}")
         print(f"Memory attention grad norm - total: {mem_att_total:.6f}, mean: {mem_att_mean:.6f}")
-        # decoder's cross att
-        dec_x_att_norms = [get_gradient_norms(layer.memory_cross_attention)[1] for layer in self.actor.decoder.model.layers]
-        print(f"Decoder cross-att mean norm: {(sum(dec_x_att_norms) / len(dec_x_att_norms)):.6f}, all: {dec_x_att_norms}")
 
+        dec_x_att_norms = [get_gradient_norms(layer.memory_cross_attention)[1] for layer in self.actor.decoder.model.layers]
         mem_att_norms = [get_gradient_norms(layer)[1] for layer in self.actor.memory_attention.model.attention_layers]
-        print(f"Memory attention layers mean norm: {(sum(mem_att_norms) / len(mem_att_norms)):.6f}, all: {mem_att_norms}")
-
         enc_ff_norms = [get_gradient_norms(layer.ff)[1] for layer in self.actor.encoder.model.layers]
-        print(f"Encoder ff mean norm: {(sum(enc_ff_norms) / len(enc_ff_norms)):.6f}, all: {enc_ff_norms}")
-
         enc_self_att_norms = [get_gradient_norms(layer.attention)[1] for layer in self.actor.encoder.model.layers]
-
+        enc_x_att_norms = [get_gradient_norms(layer.memory_cross_attention)[1] for layer in
+                           self.actor.encoder.model.layers]
+
+        calc_mean = lambda x: sum(x) / len(x)
+
+        dec_x_att_norms_mean = calc_mean(dec_x_att_norms)
+        mem_att_norms_mean = calc_mean(mem_att_norms)
+        enc_ff_norms_mean = calc_mean(enc_ff_norms)
+        enc_self_att_norms_mean = calc_mean(enc_self_att_norms)
+        enc_x_att_norms_mean = calc_mean(enc_x_att_norms)
+
+        print(f"Decoder cross-att mean norm: {dec_x_att_norms_mean:.6f}, all: {dec_x_att_norms}")
+        print(f"Memory attention layers mean norm: {mem_att_norms_mean:.6f}, all: {mem_att_norms}")
+        print(f"Encoder ff mean norm: {enc_ff_norms_mean:.6f}, all: {enc_ff_norms}")
+        print(f"Encoder self-att mean norm: {enc_self_att_norms_mean:.6f}, all: {enc_self_att_norms}")
+        print(f"Encoder cross-att mean norm: {enc_x_att_norms_mean:.6f}, all: {enc_x_att_norms}")
+
+        if self.writer is not None:
+            self.writer.add_scalar('Gradient/encoder', encoder_mean, self.global_step['train'])
+            self.writer.add_scalar('Gradient/decoder', decoder_mean, self.global_step['train'])
+            self.writer.add_scalar('Gradient/mem-att', mem_att_mean, self.global_step['train'])
+            self.writer.add_scalar('Gradient/decoder x-att', dec_x_att_norms_mean, self.global_step['train'])
+            self.writer.add_scalar('Gradient/mem-att layers', mem_att_norms_mean, self.global_step['train'])
+            self.writer.add_scalar('Gradient/encoder ff', enc_ff_norms_mean, self.global_step['train'])
+            self.writer.add_scalar('Gradient/encoder self-att', enc_self_att_norms_mean, self.global_step['train'])
+            self.writer.add_scalar('Gradient/encoder x-att', enc_x_att_norms_mean, self.global_step['train'])
 
-        enc_att_norms = [get_gradient_norms(layer.memory_cross_attention)[1] for layer in self.actor.encoder.model.layers]
-        print(f"Encoder cross-att mean norm: {(sum(enc_att_norms) / len(enc_att_norms)):.6f}, all: {enc_att_norms}")
 
     def update_actor(self, state: tuple[TokenizedDict, TokenizedDict, TokenizedDict], action: TokenizedDict,
                      advantages: torch.Tensor, old_log_probs: torch.Tensor, epoch: int) -> float:
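The reworked logging keeps calling `get_gradient_norms(module)[1]`, so the helper evidently returns a pair whose second element is a per-module mean; the helper itself is outside this diff. A minimal sketch consistent with that call shape, assuming plain L2 norms over parameters that currently hold gradients:

import torch.nn as nn

def get_gradient_norms(module: nn.Module) -> tuple[float, float]:
    # Sketch only (the real helper ships with rxnn and is not shown here):
    # gather the L2 norm of each parameter gradient and report (total, mean),
    # matching how the diff indexes [1] as a mean norm.
    norms = [p.grad.detach().norm(2).item() for p in module.parameters() if p.grad is not None]
    if not norms:
        return 0.0, 0.0
    return sum(norms), sum(norms) / len(norms)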
@@ -654,7 +671,7 @@ class MRLTrainer:
             # 4.4 Unscale and clip gradient norms
             self.scaler.unscale_(self.optimizer)
             torch.nn.utils.clip_grad_norm_(self.actor.unique_parameters(), max_norm=1.0,
-                                           error_if_nonfinite=
+                                           error_if_nonfinite=False)
             if self.debug_mode and self.epoch_step['train'] % self.debug_interval == 0:
                 self._log_gradients(logits)
             # 4.5 Run scaled optimization step
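This change (and the identical one in the non-AMP branch below) passes `error_if_nonfinite=False` to `torch.nn.utils.clip_grad_norm_`, which matches PyTorch's default: a NaN/Inf total gradient norm no longer raises a RuntimeError during clipping. A small standalone illustration of the flag, unrelated to rxnn's trainer:

import torch

params = [torch.nn.Parameter(torch.randn(4, 4)) for _ in range(2)]
loss = sum((p ** 2).sum() for p in params)
loss.backward()
# With error_if_nonfinite=False a non-finite total norm does not raise;
# with True the call would raise a RuntimeError instead of proceeding.
total_norm = torch.nn.utils.clip_grad_norm_(params, max_norm=1.0, error_if_nonfinite=False)
print(float(total_norm))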
@@ -675,7 +692,7 @@ class MRLTrainer:
             policy_loss.backward(retain_graph=True)
             # 4.4 Clip gradient norms
             torch.nn.utils.clip_grad_norm_(self.actor.unique_parameters(), max_norm=1.0,
-                                           error_if_nonfinite=
+                                           error_if_nonfinite=False)
             if self.debug_mode and self.epoch_step['train'] % self.debug_interval == 0:
                 self._log_gradients(logits)
             # 4.5 Run scaled optimization step
src/rxnn/transformers/layers.py
@@ -103,10 +103,10 @@ class ReactiveTransformerLayer(nn.Module):
         if not self.use_post_norm:
             x = self.norm1(x)
         if torch.isnan(x).any():
-            print("NaN detected in pre-norm (self-attention) output")
+            print("!!!!!!!!!!!!!!!!!!!!!! NaN detected in pre-norm (self-attention) output")
         x = self.attention(x, x, x, mask=mask)
         if torch.isnan(x).any():
-            print("NaN detected in self-attention output")
+            print("!!!!!!!!!!!!!!!!!!!!!! NaN detected in self-attention output")
         x = residual + x
         if self.use_post_norm:
             x = self.norm1(x)
@@ -115,17 +115,17 @@ class ReactiveTransformerLayer(nn.Module):
         if not self.use_post_norm:
             x = self.norm2(x)
         if torch.isnan(x).any():
-            print("NaN detected in pre-norm (cross-attention) output")
+            print("!!!!!!!!!!!!!!!!!!!!!! NaN detected in pre-norm (cross-attention) output")
 
         mem_mask = mask.squeeze(1).unsqueeze(-1).expand(-1, -1, -1, stm.size(1)) \
             if mask is not None else None
 
         if torch.isnan(stm).any():
-            print("NaN detected in STM cross-attention input")
+            print("!!!!!!!!!!!!!!!!!!!!!! NaN detected in STM cross-attention input")
 
         x = self.memory_cross_attention(x, stm, stm, mask=mem_mask)
         if torch.isnan(x).any():
-            print("NaN detected in cross-attention output")
+            print("!!!!!!!!!!!!!!!!!!!!!! NaN detected in cross-attention output")
         x = residual + x
         if self.use_post_norm:
             x = self.norm2(x)
@@ -135,10 +135,10 @@ class ReactiveTransformerLayer(nn.Module):
         if not self.use_post_norm:
             x = self.norm3(x)
         if torch.isnan(x).any():
-            print("NaN detected in pre-norm (ff) output")
+            print("!!!!!!!!!!!!!!!!!!!!!! NaN detected in pre-norm (ff) output")
         x = self.ff(x)
         if torch.isnan(x).any():
-            print("NaN detected in ff output")
+            print("!!!!!!!!!!!!!!!!!!!!!! NaN detected in ff output")
         x = residual + x
         if self.use_post_norm:
             x = self.norm3(x)
src/rxnn/transformers/models.py
@@ -94,7 +94,7 @@ class ReactiveTransformerDecoder(ReactiveTransformerBase):
     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
         x = super().forward(x)  # apply embeddings
         if torch.isnan(x).any():
-            print("NaN detected in decoder embedding output")
+            print("!!!!!!!!!!!!!!!!!!!!!! NaN detected in decoder embedding output")
         seq_len = x.size(1)
         if not self.use_flash_attention and self.use_relative_embedding:
             mask = create_causal_mask(seq_len, device=x.device)
@@ -112,7 +112,7 @@ class ReactiveTransformerDecoder(ReactiveTransformerBase):
         for i in range(self.num_own_layers):
             x = self._handle_layer(i, x, mask=mask)
             if torch.isnan(x).any():
-                print(f"NaN detected in {i}. decoder layer output")
+                print(f"!!!!!!!!!!!!!!!!!!!!!! NaN detected in {i}. decoder layer output")
         return self.head(self.head_norm(x) if self.use_head_norm else x)
 
 
@@ -122,7 +122,7 @@ class ReactiveTransformerEncoder(ReactiveTransformerBase):
     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> tuple[torch.Tensor, torch.Tensor]:
         x = super().forward(x)  # apply embeddings
         if torch.isnan(x).any():
-            print("NaN detected in encoder embedding output")
+            print("!!!!!!!!!!!!!!!!!!!!!! NaN detected in encoder embedding output")
         if attention_mask is not None:
             attention_mask = attention_mask.unsqueeze(1).unsqueeze(1).bool()
 
@@ -136,7 +136,7 @@ class ReactiveTransformerEncoder(ReactiveTransformerBase):
         for i in range(self.num_own_layers):
             x = self._handle_layer(i, x, mask=attention_mask)
             if torch.isnan(x).any():
-                print(f"NaN detected in {i}. encoder layer output")
+                print(f"!!!!!!!!!!!!!!!!!!!!!! NaN detected in {i}. encoder layer output")
             hidden_states.append(x)
         return x, torch.stack(hidden_states)
 
src/rxnn/utils.py
@@ -1,6 +1,11 @@
 import random, gc
+from typing import Optional, Union, List, Dict, Any
+
 import torch
 import numpy as np
+from huggingface_hub import PyTorchModelHubMixin
+from huggingface_hub.hub_mixin import DataclassInstance
+
 
 def human_format(num: int):
     """Format numbers to human-readable format."""
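The new imports bring in `PyTorchModelHubMixin` and the `DataclassInstance` typing protocol from `huggingface_hub`; how utils.py uses them is not visible in this hunk. For orientation, a minimal, hypothetical example of the mixin's standard save/load workflow (the `TinyModel` class below is illustrative only, not part of rxnn):

import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin

class TinyModel(nn.Module, PyTorchModelHubMixin):
    # Toy model: the mixin contributes save_pretrained / from_pretrained / push_to_hub.
    def __init__(self, hidden_dim: int = 32):
        super().__init__()
        self.linear = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x):
        return self.linear(x)

model = TinyModel(hidden_dim=32)
model.save_pretrained("tiny-model")                 # writes config and weights locally
reloaded = TinyModel.from_pretrained("tiny-model")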