rxnn 0.2.70__py3-none-any.whl → 0.2.72__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rxnn/memory/attention.py CHANGED
@@ -65,7 +65,7 @@ class StmMemoryAttention(nn.Module):
             encoded_layer_data = x[i]
             normalized_layer_stm = self.memory_norm_layers[i](layer_stm)
             if torch.isnan(normalized_layer_stm).any():
-                print(f"NaN detected in {i} layer memory norm output")
+                print(f"!!!!!!!!!!!!!!!!!!!!!! NaN detected in {i} layer memory norm output")
 
             if self.debug_mode and self.training:
                 if self.debug_step != 0 and self.debug_step % self.debug_interval == 0:
@@ -75,11 +75,11 @@ class StmMemoryAttention(nn.Module):
                 self.debug_step += 1
 
             if torch.isnan(encoded_layer_data).any():
-                print(f"NaN detected in {i} layer encoded data input")
+                print(f"!!!!!!!!!!!!!!!!!!!!!! NaN detected in {i} layer encoded data input")
 
             new_layer_stm = self.attention_layers[i](normalized_layer_stm, encoded_layer_data, encoded_layer_data, mask=attention_mask)
             if torch.isnan(new_layer_stm).any():
-                print(f"NaN detected in {i} layer memory attention output")
+                print(f"!!!!!!!!!!!!!!!!!!!!!! NaN detected in {i} layer memory attention output")
 
             if self.use_gated_residual:
                 new_stm[i] = self._residual_gate(self.gate[i], layer_stm, new_layer_stm) # gated residual
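Note: the only functional change in this file is making the existing NaN debug prints more visible. The probe itself is the standard torch.isnan(...).any() check; below is a minimal, self-contained sketch of that pattern (the warn_if_nan helper name is hypothetical, not part of rxnn):

    import torch

    def warn_if_nan(t: torch.Tensor, label: str) -> None:
        # torch.isnan(t).any() is True if any element of the tensor is NaN
        if torch.isnan(t).any():
            print(f"!!!!!!!!!!!!!!!!!!!!!! NaN detected in {label}")

    warn_if_nan(torch.tensor([1.0, float("nan")]), "example tensor")  # prints the warning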
rxnn/training/mrl.py CHANGED
@@ -592,7 +592,7 @@ class MRLTrainer:
 
         router_loss = actor.moe_router_loss()
         if torch.isnan(router_loss).any():
-            print("NaN detected in router loss")
+            print("!!!!!!!!!!!!!!!!!!!!!! NaN detected in router loss")
         if router_loss is not None:
             return main_loss + self.moe_aux_loss_scale * router_loss
         else:
@@ -607,21 +607,38 @@ class MRLTrainer:
             print(f"Encoder grad norm - total: {encoder_total:.6f}, mean: {encoder_mean:.6f}")
             print(f"Decoder grad norm - total: {decoder_total:.6f}, mean: {decoder_mean:.6f}")
             print(f"Memory attention grad norm - total: {mem_att_total:.6f}, mean: {mem_att_mean:.6f}")
-            # decoder's cross att
-            dec_x_att_norms = [get_gradient_norms(layer.memory_cross_attention)[1] for layer in self.actor.decoder.model.layers]
-            print(f"Decoder cross-att mean norm: {(sum(dec_x_att_norms) / len(dec_x_att_norms)):.6f}, all: {dec_x_att_norms}")
 
+            dec_x_att_norms = [get_gradient_norms(layer.memory_cross_attention)[1] for layer in self.actor.decoder.model.layers]
             mem_att_norms = [get_gradient_norms(layer)[1] for layer in self.actor.memory_attention.model.attention_layers]
-            print(f"Memory attention layers mean norm: {(sum(mem_att_norms) / len(mem_att_norms)):.6f}, all: {mem_att_norms}")
-
             enc_ff_norms = [get_gradient_norms(layer.ff)[1] for layer in self.actor.encoder.model.layers]
-            print(f"Encoder ff mean norm: {(sum(enc_ff_norms) / len(enc_ff_norms)):.6f}, all: {enc_ff_norms}")
-
             enc_self_att_norms = [get_gradient_norms(layer.attention)[1] for layer in self.actor.encoder.model.layers]
-            print(f"Encoder self-att mean norm: {(sum(enc_self_att_norms) / len(enc_self_att_norms)):.6f}, all: {enc_self_att_norms}")
+            enc_x_att_norms = [get_gradient_norms(layer.memory_cross_attention)[1] for layer in
+                               self.actor.encoder.model.layers]
+
+            calc_mean = lambda x: sum(x) / len(x)
+
+            dec_x_att_norms_mean = calc_mean(dec_x_att_norms)
+            mem_att_norms_mean = calc_mean(mem_att_norms)
+            enc_ff_norms_mean = calc_mean(enc_ff_norms)
+            enc_self_att_norms_mean = calc_mean(enc_self_att_norms)
+            enc_x_att_norms_mean = calc_mean(enc_x_att_norms)
+
+            print(f"Decoder cross-att mean norm: {dec_x_att_norms_mean:.6f}, all: {dec_x_att_norms}")
+            print(f"Memory attention layers mean norm: {mem_att_norms_mean:.6f}, all: {mem_att_norms}")
+            print(f"Encoder ff mean norm: {enc_ff_norms_mean:.6f}, all: {enc_ff_norms}")
+            print(f"Encoder self-att mean norm: {enc_self_att_norms_mean:.6f}, all: {enc_self_att_norms}")
+            print(f"Encoder cross-att mean norm: {enc_x_att_norms_mean:.6f}, all: {enc_x_att_norms}")
+
+            if self.writer is not None:
+                self.writer.add_scalar('Gradient/encoder', encoder_mean, self.global_step['train'])
+                self.writer.add_scalar('Gradient/decoder', decoder_mean, self.global_step['train'])
+                self.writer.add_scalar('Gradient/mem-att', mem_att_mean, self.global_step['train'])
+                self.writer.add_scalar('Gradient/decoder x-att', dec_x_att_norms_mean, self.global_step['train'])
+                self.writer.add_scalar('Gradient/mem-att layers', mem_att_norms_mean, self.global_step['train'])
+                self.writer.add_scalar('Gradient/encoder ff', enc_ff_norms_mean, self.global_step['train'])
+                self.writer.add_scalar('Gradient/encoder self-att', enc_self_att_norms_mean, self.global_step['train'])
+                self.writer.add_scalar('Gradient/encoder x-att', enc_x_att_norms_mean, self.global_step['train'])
 
-            enc_att_norms = [get_gradient_norms(layer.memory_cross_attention)[1] for layer in self.actor.encoder.model.layers]
-            print(f"Encoder cross-att mean norm: {(sum(enc_att_norms) / len(enc_att_norms)):.6f}, all: {enc_att_norms}")
 
     def update_actor(self, state: tuple[TokenizedDict, TokenizedDict, TokenizedDict], action: TokenizedDict,
                      advantages: torch.Tensor, old_log_probs: torch.Tensor, epoch: int) -> float:
@@ -654,7 +671,7 @@ class MRLTrainer:
             # 4.4 Unscale and clip gradient norms
             self.scaler.unscale_(self.optimizer)
             torch.nn.utils.clip_grad_norm_(self.actor.unique_parameters(), max_norm=1.0,
-                                           error_if_nonfinite=self.debug_mode)
+                                           error_if_nonfinite=False)
            if self.debug_mode and self.epoch_step['train'] % self.debug_interval == 0:
                 self._log_gradients(logits)
             # 4.5 Run scaled optimization step
@@ -675,7 +692,7 @@ class MRLTrainer:
             policy_loss.backward(retain_graph=True)
             # 4.4 Clip gradient norms
             torch.nn.utils.clip_grad_norm_(self.actor.unique_parameters(), max_norm=1.0,
-                                           error_if_nonfinite=self.debug_mode)
+                                           error_if_nonfinite=False)
             if self.debug_mode and self.epoch_step['train'] % self.debug_interval == 0:
                 self._log_gradients(logits)
             # 4.5 Run scaled optimization step
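Note: this file now computes per-layer gradient-norm means (via the calc_mean lambda), logs them to TensorBoard when a writer is configured, and no longer raises during gradient clipping on non-finite values (error_if_nonfinite=False). Below is a minimal sketch of the same logging pattern, assuming get_gradient_norms returns a (total, mean) pair as its usage above suggests; the toy model, writer and step values are illustrative, not rxnn code:

    import torch
    import torch.nn as nn
    from torch.utils.tensorboard import SummaryWriter

    def get_gradient_norms(module: nn.Module) -> tuple[float, float]:
        # Total and mean L2 gradient norm over a module's parameters (assumed contract).
        norms = [p.grad.norm(2).item() for p in module.parameters() if p.grad is not None]
        total = sum(norms)
        return total, (total / len(norms) if norms else 0.0)

    model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 1))
    model(torch.randn(4, 8)).sum().backward()

    calc_mean = lambda x: sum(x) / len(x)
    layer_norms = [get_gradient_norms(layer)[1] for layer in model]

    writer = SummaryWriter()
    writer.add_scalar('Gradient/layers', calc_mean(layer_norms), 0)  # tag, value, global step

    # Clip without raising on inf/NaN gradients, mirroring error_if_nonfinite=False above.
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0, error_if_nonfinite=False)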
rxnn/transformers/layers.py CHANGED
@@ -103,10 +103,10 @@ class ReactiveTransformerLayer(nn.Module):
         if not self.use_post_norm:
             x = self.norm1(x)
         if torch.isnan(x).any():
-            print("NaN detected in pre-norm (self-attention) output")
+            print("!!!!!!!!!!!!!!!!!!!!!! !!!!!!!!!!!!!!!!!!!!!! NaN detected in pre-norm (self-attention) output")
         x = self.attention(x, x, x, mask=mask)
         if torch.isnan(x).any():
-            print("NaN detected in self-attention output")
+            print("!!!!!!!!!!!!!!!!!!!!!! !!!!!!!!!!!!!!!!!!!!!! NaN detected in self-attention output")
         x = residual + x
         if self.use_post_norm:
             x = self.norm1(x)
@@ -115,17 +115,17 @@ class ReactiveTransformerLayer(nn.Module):
         if not self.use_post_norm:
             x = self.norm2(x)
         if torch.isnan(x).any():
-            print("NaN detected in pre-norm (cross-attention) output")
+            print("!!!!!!!!!!!!!!!!!!!!!! NaN detected in pre-norm (cross-attention) output")
 
         mem_mask = mask.squeeze(1).unsqueeze(-1).expand(-1, -1, -1, stm.size(1)) \
             if mask is not None else None
 
         if torch.isnan(stm).any():
-            print("NaN detected in STM cross-attention input")
+            print("!!!!!!!!!!!!!!!!!!!!!! NaN detected in STM cross-attention input")
 
         x = self.memory_cross_attention(x, stm, stm, mask=mem_mask)
         if torch.isnan(x).any():
-            print("NaN detected in cross-attention output")
+            print("!!!!!!!!!!!!!!!!!!!!!! NaN detected in cross-attention output")
         x = residual + x
         if self.use_post_norm:
             x = self.norm2(x)
@@ -135,10 +135,10 @@ class ReactiveTransformerLayer(nn.Module):
         if not self.use_post_norm:
             x = self.norm3(x)
         if torch.isnan(x).any():
-            print("NaN detected in pre-norm (ff) output")
+            print("!!!!!!!!!!!!!!!!!!!!!! NaN detected in pre-norm (ff) output")
         x = self.ff(x)
         if torch.isnan(x).any():
-            print("NaN detected in ff output")
+            print("!!!!!!!!!!!!!!!!!!!!!! NaN detected in ff output")
         x = residual + x
         if self.use_post_norm:
             x = self.norm3(x)
rxnn/transformers/models.py CHANGED
@@ -94,7 +94,7 @@ class ReactiveTransformerDecoder(ReactiveTransformerBase):
     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
         x = super().forward(x) # apply embeddings
         if torch.isnan(x).any():
-            print("NaN detected in decoder embedding output")
+            print("!!!!!!!!!!!!!!!!!!!!!! NaN detected in decoder embedding output")
         seq_len = x.size(1)
         if not self.use_flash_attention and self.use_relative_embedding:
             mask = create_causal_mask(seq_len, device=x.device)
@@ -112,7 +112,7 @@ class ReactiveTransformerDecoder(ReactiveTransformerBase):
         for i in range(self.num_own_layers):
             x = self._handle_layer(i, x, mask=mask)
             if torch.isnan(x).any():
-                print(f"NaN detected in {i}. decoder layer output")
+                print(f"!!!!!!!!!!!!!!!!!!!!!! NaN detected in {i}. decoder layer output")
         return self.head(self.head_norm(x) if self.use_head_norm else x)
 
 
@@ -122,7 +122,7 @@ class ReactiveTransformerEncoder(ReactiveTransformerBase):
     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> tuple[torch.Tensor, torch.Tensor]:
         x = super().forward(x) # apply embeddings
         if torch.isnan(x).any():
-            print("NaN detected in encoder embedding output")
+            print("!!!!!!!!!!!!!!!!!!!!!! NaN detected in encoder embedding output")
         if attention_mask is not None:
             attention_mask = attention_mask.unsqueeze(1).unsqueeze(1).bool()
 
@@ -136,7 +136,7 @@ class ReactiveTransformerEncoder(ReactiveTransformerBase):
         for i in range(self.num_own_layers):
            x = self._handle_layer(i, x, mask=attention_mask)
             if torch.isnan(x).any():
-                print(f"NaN detected in {i}. encoder layer output")
+                print(f"!!!!!!!!!!!!!!!!!!!!!! NaN detected in {i}. encoder layer output")
             hidden_states.append(x)
         return x, torch.stack(hidden_states)
 
rxnn/utils.py CHANGED
@@ -1,6 +1,11 @@
 import random, gc
+from typing import Optional, Union, List, Dict, Any
+
 import torch
 import numpy as np
+from huggingface_hub import PyTorchModelHubMixin
+from huggingface_hub.hub_mixin import DataclassInstance
+
 
 def human_format(num: int):
     """Format numbers to human-readable format."""
rxnn-0.2.70.dist-info/METADATA → rxnn-0.2.72.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.2.70
+Version: 0.2.72
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning
rxnn-0.2.70.dist-info/RECORD → rxnn-0.2.72.dist-info/RECORD RENAMED
@@ -5,7 +5,7 @@ rxnn/experimental/attention.py,sha256=jlNS82INjycNEfmk3HtkIacUvT_ELhaCO2g-kZTvhX
 rxnn/experimental/models.py,sha256=KheR1zSNJIaeVvpVAkEJwcuM5nOqQP0ZF08XhrtGJ8E,5387
 rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/memory/attention.py,sha256=O4ycW3KKP5hFYadgVh47LvGWJn9zNHz8vh9E9okC0h8,4223
+rxnn/memory/attention.py,sha256=el-vlkA7OYFNisdYbaQMxSphSG7Px6oDx1aO_3lFIs4,4316
 rxnn/memory/norm.py,sha256=cVjjhCLqR5K6-321SP_ObG17y-ddlcTJeCTXvW4vpk0,6675
 rxnn/memory/stm.py,sha256=jv57gsH9XW19sLbxpRDqsp1yfsii_4Ef4Ncr_ztk-i4,3937
 rxnn/rxt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -17,7 +17,7 @@ rxnn/training/callbacks.py,sha256=rS8leuVFPVVfE5Zc8DMkUZhRIPN-vpPbUjowXE5TSBw,36
 rxnn/training/dataset.py,sha256=ruU6k33pQmpTqhxpjLFNdDJnCjcrBcGeFOzJqFahJDM,51880
 rxnn/training/ddp.py,sha256=VsNBjn3cY-uUj8hbsW7oKvb0_ZKnXnJ2KgObm-Mr9i4,836
 rxnn/training/models.py,sha256=ILkcqBV1MImnULnq-YDSSEf8cUdEbUgQaH0FRTsa4LA,9069
-rxnn/training/mrl.py,sha256=eIMfR0Rp7d_nvrgP9E2F_o7h2Suc0IWTUP0AXHqp-6Q,66282
+rxnn/training/mrl.py,sha256=KUJAdUznquhf5UlcpV-QF5oKHDBEsDecMEVmMLQZw7w,67380
 rxnn/training/reward.py,sha256=uiSsBXmjMw2yv-1Bssy3RTlpU6zP8ape3490Sl-aT0M,16144
 rxnn/training/rl.py,sha256=hWtExxY-_pAmTOGYxyCNounUbaGWvLDVltC4sRC7MN4,7175
 rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
@@ -26,14 +26,14 @@ rxnn/training/utils.py,sha256=ngDCm654NL3UsPy190Er4XPc9HI-OyEV6tDLMgEEvQc,6219
 rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/transformers/attention.py,sha256=KRnKT6XUqAXElxV9y72mSpdTeiMgCKCCLqqxCFNTHmA,16372
 rxnn/transformers/ff.py,sha256=WDjO-H9XWInoWnUnxiseIH6Kx5GlHP0zGJygwhcb1gc,2589
-rxnn/transformers/layers.py,sha256=bcDP8vZ5dpTWWqMCkzrPG8yQA0D0G5VjnV2Nq9IO8Dc,8816
+rxnn/transformers/layers.py,sha256=VOKCURq9HQu0Uf0uk1cmyvj3u-rnoyJoZZ9Y-kSSihQ,9095
 rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
-rxnn/transformers/models.py,sha256=r4vNldYqCIpwMpXkFZvYbw0UBK3NE75qH7bc6OZ8YjE,11587
+rxnn/transformers/models.py,sha256=TP0H9do53Z0vd8kpHMISBzMpHE5X9QIHcy0B-iJHuNQ,11711
 rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
 rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
-rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
-rxnn-0.2.70.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
-rxnn-0.2.70.dist-info/METADATA,sha256=jrbxT7UcwiXy63xX3TDBD1V84INrbArX362nafwkp98,60420
-rxnn-0.2.70.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
-rxnn-0.2.70.dist-info/RECORD,,
+rxnn/utils.py,sha256=jnPmhehnRojRolgDxgRA_XPdcx_nUNT5tuDmrV0b-w0,1155
+rxnn-0.2.72.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.2.72.dist-info/METADATA,sha256=FyIccoN8UysI4TKozaOrjckG0rxSelStVz7Yi3y8wXM,60420
+rxnn-0.2.72.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+rxnn-0.2.72.dist-info/RECORD,,