rxnn-0.2.73-py3-none-any.whl → rxnn-0.2.75-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rxnn/rxt/models.py CHANGED
@@ -15,7 +15,6 @@ from ..utils import get_model_size
 from ..experimental.attention import init_experimental_attention
 
 
-
 class RxTAlphaComponentConfig(TypedDict):
     num_layers: int
     vocab_size: int
rxnn/training/mrl.py CHANGED
@@ -42,6 +42,7 @@ class MrlConfig(TypedDict):
     debug_mode: Optional[bool]
     debug_interval: Optional[int]
     clamp_logits: Optional[bool]
+    max_grad_norm: Optional[float]
 
 
 class MrlStrategy(Enum):
@@ -154,6 +155,7 @@ class MRLTrainer:
         self.debug_mode = config.get('debug_mode', False)
         self.debug_interval = config.get('debug_interval', 10)
         self.clamp_logits = config.get('clamp_logits', False)
+        self.max_grad_norm = config.get('max_grad_norm', 1.0)
         # Internal update epochs config
         self.shared_update_epochs = config.get('update_epochs', 10)
         self.update_epochs = self.shared_update_epochs
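The new max_grad_norm key makes the clipping threshold configurable per run; when it is absent, config.get('max_grad_norm', 1.0) above keeps the previous behaviour of clipping at 1.0. A minimal sketch of a config dict using it (only keys visible in these hunks; how the dict reaches MRLTrainer is unchanged and not shown here):

mrl_config = {
    'update_epochs': 10,      # shared update epochs, read above
    'debug_mode': True,
    'debug_interval': 10,
    'clamp_logits': False,
    'max_grad_norm': 0.5,     # tighter than the 1.0 default
}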
@@ -591,13 +593,28 @@ class MRLTrainer:
         actor = next(self.actor.children()) if isinstance(self.actor, DistributedDataParallel) else self.actor
 
         router_loss = actor.moe_router_loss()
-        if torch.isnan(router_loss).any():
-            print("!!!!!!!!!!!!!!!!!!!!!! NaN detected in router loss")
         if router_loss is not None:
             return main_loss + self.moe_aux_loss_scale * router_loss
         else:
             return main_loss
 
+    def _clip_actor_grad_norms(self):
+        # Encoder with embedding
+        torch.nn.utils.clip_grad_norm_(
+            self.actor.encoder.parameters(),
+            max_norm=self.max_grad_norm, error_if_nonfinite=False
+        )
+        # Decoder
+        torch.nn.utils.clip_grad_norm_(
+            self.actor.decoder.memory_parameters() + self.actor.decoder.not_memory_parameters(),
+            max_norm=self.max_grad_norm, error_if_nonfinite=False
+        )
+        # Memory attention
+        torch.nn.utils.clip_grad_norm_(
+            self.actor.memory_attention.parameters(),
+            max_norm=self.max_grad_norm, error_if_nonfinite=False
+        )
+
     def _log_gradients(self, logits: torch.Tensor):
         print(
             f"Returned logits stats: min={logits.min().item():.4f}, max={logits.max().item():.4f}")
@@ -608,8 +625,13 @@ class MRLTrainer:
         print(f"Decoder grad norm - total: {decoder_total:.6f}, mean: {decoder_mean:.6f}")
         print(f"Memory attention grad norm - total: {mem_att_total:.6f}, mean: {mem_att_mean:.6f}")
 
+        dec_ff_norms = [get_gradient_norms(layer.ff)[1] for layer in self.actor.decoder.model.layers]
+        dec_self_att_norms = [get_gradient_norms(layer.attention)[1] for layer in self.actor.decoder.model.layers]
         dec_x_att_norms = [get_gradient_norms(layer.memory_cross_attention)[1] for layer in self.actor.decoder.model.layers]
+
+
         mem_att_norms = [get_gradient_norms(layer)[1] for layer in self.actor.memory_attention.model.attention_layers]
+
         enc_ff_norms = [get_gradient_norms(layer.ff)[1] for layer in self.actor.encoder.model.layers]
         enc_self_att_norms = [get_gradient_norms(layer.attention)[1] for layer in self.actor.encoder.model.layers]
         enc_x_att_norms = [get_gradient_norms(layer.memory_cross_attention)[1] for layer in
@@ -617,22 +639,28 @@ class MRLTrainer:
 
         calc_mean = lambda x: sum(x) / len(x)
 
+        dec_ff_norms_mean = calc_mean(dec_ff_norms)
+        dec_self_att_norms_mean = calc_mean(dec_self_att_norms)
         dec_x_att_norms_mean = calc_mean(dec_x_att_norms)
         mem_att_norms_mean = calc_mean(mem_att_norms)
         enc_ff_norms_mean = calc_mean(enc_ff_norms)
         enc_self_att_norms_mean = calc_mean(enc_self_att_norms)
         enc_x_att_norms_mean = calc_mean(enc_x_att_norms)
 
+        print(f"Decoder ff mean norm: {dec_ff_norms_mean:.6f}, all: {dec_ff_norms}")
+        print(f"Decoder self-att mean norm: {dec_self_att_norms_mean:.6f}, all: {dec_self_att_norms}")
         print(f"Decoder cross-att mean norm: {dec_x_att_norms_mean:.6f}, all: {dec_x_att_norms}")
-        print(f"Memory attention layers mean norm: {mem_att_norms_mean:.6f}, all: {mem_att_norms}")
         print(f"Encoder ff mean norm: {enc_ff_norms_mean:.6f}, all: {enc_ff_norms}")
         print(f"Encoder self-att mean norm: {enc_self_att_norms_mean:.6f}, all: {enc_self_att_norms}")
         print(f"Encoder cross-att mean norm: {enc_x_att_norms_mean:.6f}, all: {enc_x_att_norms}")
+        print(f"Memory attention layers mean norm: {mem_att_norms_mean:.6f}, all: {mem_att_norms}")
 
         if self.writer is not None:
             self.writer.add_scalar('Gradient/encoder', encoder_mean, self.global_step['train'])
             self.writer.add_scalar('Gradient/decoder', decoder_mean, self.global_step['train'])
             self.writer.add_scalar('Gradient/mem-att', mem_att_mean, self.global_step['train'])
+            self.writer.add_scalar('Gradient/decoder ff', dec_ff_norms_mean, self.global_step['train'])
+            self.writer.add_scalar('Gradient/decoder self-att', dec_self_att_norms_mean, self.global_step['train'])
             self.writer.add_scalar('Gradient/decoder x-att', dec_x_att_norms_mean, self.global_step['train'])
             self.writer.add_scalar('Gradient/mem-att layers', mem_att_norms_mean, self.global_step['train'])
             self.writer.add_scalar('Gradient/encoder ff', enc_ff_norms_mean, self.global_step['train'])
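These per-layer statistics use get_gradient_norms from rxnn's training utilities; its implementation is not part of this diff, but the [1] indexing together with the total/mean prints above suggests it returns a (total, mean) pair over a module's parameter gradients. A hypothetical equivalent, only to make the hunk easier to read (name and return convention are assumptions):

import torch.nn as nn

def get_gradient_norms_sketch(module: nn.Module) -> tuple[float, float]:
    # Assumed behaviour: per-parameter L2 gradient norms, returned as
    # (total, mean) so callers can take index [1] for the mean.
    norms = [p.grad.norm(2).item() for p in module.parameters() if p.grad is not None]
    if not norms:
        return 0.0, 0.0
    return sum(norms), sum(norms) / len(norms)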
@@ -670,8 +698,7 @@ class MRLTrainer:
             self.scaler.scale(policy_loss).backward(retain_graph=True)
             # 4.4 Unscale and clip gradient norms
             self.scaler.unscale_(self.optimizer)
-            torch.nn.utils.clip_grad_norm_(self.actor.unique_parameters(), max_norm=1.0,
-                                           error_if_nonfinite=False)
+            self._clip_actor_grad_norms()
             if self.debug_mode and self.epoch_step['train'] % self.debug_interval == 0:
                 self._log_gradients(logits)
             # 4.5 Run scaled optimization step
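Ordering matters in this AMP branch: gradients have to be unscaled before clipping, otherwise max_norm is compared against loss-scaled gradients and the effective threshold is wrong. A generic sketch of that sequence with a placeholder model and loss (standard torch.cuda.amp usage, independent of rxnn):

import torch
import torch.nn as nn

model = nn.Linear(16, 1)                                   # placeholder model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())

loss = model(torch.randn(8, 16)).mean()                    # placeholder loss
scaler.scale(loss).backward()                              # backward on the scaled loss
scaler.unscale_(optimizer)                                 # restore true gradient scale
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
scaler.step(optimizer)                                     # skipped if inf/nan gradients remain
scaler.update()
optimizer.zero_grad()

The non-AMP branch in the next hunk follows the same order without the scaler: backward, clip via _clip_actor_grad_norms, then the optimizer step.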
@@ -691,8 +718,7 @@ class MRLTrainer:
             # 4.3 Run backpropagation
             policy_loss.backward(retain_graph=True)
             # 4.4 Clip gradient norms
-            torch.nn.utils.clip_grad_norm_(self.actor.unique_parameters(), max_norm=1.0,
-                                           error_if_nonfinite=False)
+            self._clip_actor_grad_norms()
             if self.debug_mode and self.epoch_step['train'] % self.debug_interval == 0:
                 self._log_gradients(logits)
             # 4.5 Run scaled optimization step
rxnn/transformers/layers.py CHANGED
@@ -1,7 +1,5 @@
 import torch
 import torch.nn as nn
-from poetry.console.commands import self
-
 from .attention import MultiHeadAttention
 from .ff import FeedForward, GatedFeedForward
 from .moe import MoeFeedForward, GatedMoeFeedForward
rxnn/transformers/models.py CHANGED
@@ -93,8 +93,6 @@ class ReactiveTransformerDecoder(ReactiveTransformerBase):
 
     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
         x = super().forward(x) # apply embeddings
-        if torch.isnan(x).any():
-            print("!!!!!!!!!!!!!!!!!!!!!! NaN detected in decoder embedding output")
         seq_len = x.size(1)
         if not self.use_flash_attention and self.use_relative_embedding:
             mask = create_causal_mask(seq_len, device=x.device)
@@ -111,8 +109,6 @@ class ReactiveTransformerDecoder(ReactiveTransformerBase):
         # Process own layers
         for i in range(self.num_own_layers):
             x = self._handle_layer(i, x, mask=mask)
-            if torch.isnan(x).any():
-                print(f"!!!!!!!!!!!!!!!!!!!!!! NaN detected in {i}. decoder layer output")
         return self.head(self.head_norm(x) if self.use_head_norm else x)
 
 
@@ -121,8 +117,6 @@ class ReactiveTransformerEncoder(ReactiveTransformerBase):
 
     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> tuple[torch.Tensor, torch.Tensor]:
         x = super().forward(x) # apply embeddings
-        if torch.isnan(x).any():
-            print("!!!!!!!!!!!!!!!!!!!!!! NaN detected in encoder embedding output")
         if attention_mask is not None:
             attention_mask = attention_mask.unsqueeze(1).unsqueeze(1).bool()
 
@@ -135,8 +129,6 @@ class ReactiveTransformerEncoder(ReactiveTransformerBase):
         # Process own layers
         for i in range(self.num_own_layers):
             x = self._handle_layer(i, x, mask=attention_mask)
-            if torch.isnan(x).any():
-                print(f"!!!!!!!!!!!!!!!!!!!!!! NaN detected in {i}. encoder layer output")
             hidden_states.append(x)
         return x, torch.stack(hidden_states)
 
rxnn-0.2.75.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.2.73
+Version: 0.2.75
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning
rxnn-0.2.75.dist-info/RECORD CHANGED
@@ -10,7 +10,7 @@ rxnn/memory/gate.py,sha256=pR_H2y9C7S02QskoFAEC9Tmluut0k4GGlHgvZGiw6m4,2332
 rxnn/memory/norm.py,sha256=cVjjhCLqR5K6-321SP_ObG17y-ddlcTJeCTXvW4vpk0,6675
 rxnn/memory/stm.py,sha256=jv57gsH9XW19sLbxpRDqsp1yfsii_4Ef4Ncr_ztk-i4,3937
 rxnn/rxt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/rxt/models.py,sha256=Pb48Frl6HV4Wb9CZgYtmzch3k_4Jess3rhs7dY1I96k,22209
+rxnn/rxt/models.py,sha256=zgRgNUVYYuniiB1xt7HdQYgmhep6e5ybxv3PU0lcfoU,22208
 rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/training/base.py,sha256=CqaArEZYOdH64nmKfx28U3GI46TzO4oNkjf_hrF23Cw,11835
 rxnn/training/bml.py,sha256=hw6gLpLkGvqLzxIvBg4MvCc5r8cHpEm2RDyh7nH6CtE,16914
@@ -18,7 +18,7 @@ rxnn/training/callbacks.py,sha256=rS8leuVFPVVfE5Zc8DMkUZhRIPN-vpPbUjowXE5TSBw,36
 rxnn/training/dataset.py,sha256=ruU6k33pQmpTqhxpjLFNdDJnCjcrBcGeFOzJqFahJDM,51880
 rxnn/training/ddp.py,sha256=VsNBjn3cY-uUj8hbsW7oKvb0_ZKnXnJ2KgObm-Mr9i4,836
 rxnn/training/models.py,sha256=ILkcqBV1MImnULnq-YDSSEf8cUdEbUgQaH0FRTsa4LA,9069
-rxnn/training/mrl.py,sha256=KUJAdUznquhf5UlcpV-QF5oKHDBEsDecMEVmMLQZw7w,67380
+rxnn/training/mrl.py,sha256=XxO3LZ6Mxae5RBQJ-k2esU4oKS-BpGXC9mxVrg8MYaE,68527
 rxnn/training/reward.py,sha256=uiSsBXmjMw2yv-1Bssy3RTlpU6zP8ape3490Sl-aT0M,16144
 rxnn/training/rl.py,sha256=hWtExxY-_pAmTOGYxyCNounUbaGWvLDVltC4sRC7MN4,7175
 rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
@@ -27,14 +27,14 @@ rxnn/training/utils.py,sha256=ngDCm654NL3UsPy190Er4XPc9HI-OyEV6tDLMgEEvQc,6219
 rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/transformers/attention.py,sha256=KRnKT6XUqAXElxV9y72mSpdTeiMgCKCCLqqxCFNTHmA,16372
 rxnn/transformers/ff.py,sha256=WDjO-H9XWInoWnUnxiseIH6Kx5GlHP0zGJygwhcb1gc,2589
-rxnn/transformers/layers.py,sha256=fxjlbQG6cwxq-b2ei4DnohSQGH5gwy4GkfP9duTUvjw,8492
+rxnn/transformers/layers.py,sha256=7iwLZ4De4kw3-YA5p2-adCvTgeqeLC-lXcFAlhN_-AA,8450
 rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
-rxnn/transformers/models.py,sha256=TP0H9do53Z0vd8kpHMISBzMpHE5X9QIHcy0B-iJHuNQ,11711
+rxnn/transformers/models.py,sha256=IfBsxmtj3val0AOAZV0cJBi1v2Q68HJoyD1LyqRYljs,11187
 rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
 rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
 rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
-rxnn-0.2.73.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
-rxnn-0.2.73.dist-info/METADATA,sha256=gtoRMeFgBuOZs4lRKl9JGUxZ2X4C9K78Ee-NHLMqW4E,60420
-rxnn-0.2.73.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
-rxnn-0.2.73.dist-info/RECORD,,
+rxnn-0.2.75.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.2.75.dist-info/METADATA,sha256=4re6SGM_SP2BQVHtEEKrpl5l9adUYz2wq6KsBW92rag,60420
+rxnn-0.2.75.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+rxnn-0.2.75.dist-info/RECORD,,