rxnn 0.2.63__tar.gz → 0.2.64__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. {rxnn-0.2.63 → rxnn-0.2.64}/PKG-INFO +1 -1
  2. {rxnn-0.2.63 → rxnn-0.2.64}/pyproject.toml +1 -1
  3. {rxnn-0.2.63 → rxnn-0.2.64}/src/rxnn/experimental/models.py +6 -0
  4. {rxnn-0.2.63 → rxnn-0.2.64}/src/rxnn/rxt/models.py +18 -2
  5. {rxnn-0.2.63 → rxnn-0.2.64}/src/rxnn/training/mrl.py +9 -6
  6. {rxnn-0.2.63 → rxnn-0.2.64}/src/rxnn/training/utils.py +5 -5
  7. {rxnn-0.2.63 → rxnn-0.2.64}/src/rxnn/transformers/models.py +21 -5
  8. {rxnn-0.2.63 → rxnn-0.2.64}/LICENSE +0 -0
  9. {rxnn-0.2.63 → rxnn-0.2.64}/README.md +0 -0
  10. {rxnn-0.2.63 → rxnn-0.2.64}/src/rxnn/.DS_Store +0 -0
  11. {rxnn-0.2.63 → rxnn-0.2.64}/src/rxnn/__init__.py +0 -0
  12. {rxnn-0.2.63 → rxnn-0.2.64}/src/rxnn/experimental/__init__.py +0 -0
  13. {rxnn-0.2.63 → rxnn-0.2.64}/src/rxnn/experimental/attention.py +0 -0
  14. {rxnn-0.2.63 → rxnn-0.2.64}/src/rxnn/experimental/moe.py +0 -0
  15. {rxnn-0.2.63 → rxnn-0.2.64}/src/rxnn/memory/__init__.py +0 -0
  16. {rxnn-0.2.63 → rxnn-0.2.64}/src/rxnn/memory/attention.py +0 -0
  17. {rxnn-0.2.63 → rxnn-0.2.64}/src/rxnn/memory/norm.py +0 -0
  18. {rxnn-0.2.63 → rxnn-0.2.64}/src/rxnn/memory/stm.py +0 -0
  19. {rxnn-0.2.63 → rxnn-0.2.64}/src/rxnn/rxt/__init__.py +0 -0
  20. {rxnn-0.2.63 → rxnn-0.2.64}/src/rxnn/training/__init__.py +0 -0
  21. {rxnn-0.2.63 → rxnn-0.2.64}/src/rxnn/training/base.py +0 -0
  22. {rxnn-0.2.63 → rxnn-0.2.64}/src/rxnn/training/bml.py +0 -0
  23. {rxnn-0.2.63 → rxnn-0.2.64}/src/rxnn/training/callbacks.py +0 -0
  24. {rxnn-0.2.63 → rxnn-0.2.64}/src/rxnn/training/dataset.py +0 -0
  25. {rxnn-0.2.63 → rxnn-0.2.64}/src/rxnn/training/ddp.py +0 -0
  26. {rxnn-0.2.63 → rxnn-0.2.64}/src/rxnn/training/models.py +0 -0
  27. {rxnn-0.2.63 → rxnn-0.2.64}/src/rxnn/training/reward.py +0 -0
  28. {rxnn-0.2.63 → rxnn-0.2.64}/src/rxnn/training/rl.py +0 -0
  29. {rxnn-0.2.63 → rxnn-0.2.64}/src/rxnn/training/scheduler.py +0 -0
  30. {rxnn-0.2.63 → rxnn-0.2.64}/src/rxnn/training/tokenizer.py +0 -0
  31. {rxnn-0.2.63 → rxnn-0.2.64}/src/rxnn/transformers/__init__.py +0 -0
  32. {rxnn-0.2.63 → rxnn-0.2.64}/src/rxnn/transformers/attention.py +0 -0
  33. {rxnn-0.2.63 → rxnn-0.2.64}/src/rxnn/transformers/ff.py +0 -0
  34. {rxnn-0.2.63 → rxnn-0.2.64}/src/rxnn/transformers/layers.py +0 -0
  35. {rxnn-0.2.63 → rxnn-0.2.64}/src/rxnn/transformers/mask.py +0 -0
  36. {rxnn-0.2.63 → rxnn-0.2.64}/src/rxnn/transformers/moe.py +0 -0
  37. {rxnn-0.2.63 → rxnn-0.2.64}/src/rxnn/transformers/positional.py +0 -0
  38. {rxnn-0.2.63 → rxnn-0.2.64}/src/rxnn/transformers/sampler.py +0 -0
  39. {rxnn-0.2.63 → rxnn-0.2.64}/src/rxnn/utils.py +0 -0
{rxnn-0.2.63 → rxnn-0.2.64}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.2.63
+Version: 0.2.64
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning
{rxnn-0.2.63 → rxnn-0.2.64}/pyproject.toml

@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
 
 [tool.poetry]
 name = "rxnn"
-version = "0.2.63"
+version = "0.2.64"
 description = "RxNN: Reactive Neural Networks Platform"
 
 license = "Apache-2.0"
{rxnn-0.2.63 → rxnn-0.2.64}/src/rxnn/experimental/models.py

@@ -34,6 +34,8 @@ class ExperimentalAttentionTransformerConfig(TypedDict):
     att_num_query_groups: int
     att_num_global_tokens: int
     att_window_size: int
+    use_head_norm: bool
+    init_identity_norm: bool
 
 
 class ExperimentalAttentionTransformer(nn.Module, PyTorchModelHubMixin, pipeline_tag="text-generation", license="apache-2.0"):
@@ -67,6 +69,8 @@ class ExperimentalAttentionTransformer(nn.Module, PyTorchModelHubMixin, pipeline
             att_num_query_groups: int = None,
             att_num_global_tokens: int = 16,
             att_window_size: int = 128,
+            use_head_norm: bool = False,
+            init_identity_norm: bool = False,
             **kwargs
     ):
         super(ExperimentalAttentionTransformer, self).__init__(**kwargs)
@@ -110,6 +114,8 @@ class ExperimentalAttentionTransformer(nn.Module, PyTorchModelHubMixin, pipeline
                     use_rms_norm=use_rms_norm,
                     self_attention=att_init(),
                     use_moe_att=use_moe_att,
+                    use_head_norm=use_head_norm,
+                    init_identity_norm=init_identity_norm,
                 ) for _ in range(num_layers)
             ]),
             use_flash_attention=use_flash_attention,
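Both new config keys default to False, so existing configs keep working and the feature only activates when set explicitly. A hedged fragment showing how the keys would sit alongside the rest of ExperimentalAttentionTransformerConfig (all other required fields are omitted here, so this is not a complete, instantiable config):

# Fragment only: the remaining ExperimentalAttentionTransformerConfig fields are omitted.
head_norm_options = {
    "use_head_norm": True,       # enable the pre-head LayerNorm added in this release
    "init_identity_norm": True,  # explicitly start that norm from weight=1.0 / bias=0.0
}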
{rxnn-0.2.63 → rxnn-0.2.64}/src/rxnn/rxt/models.py

@@ -39,6 +39,8 @@ class RxTAlphaComponentConfig(TypedDict):
     att_query_groups: int
     cross_att_groups: int
     cross_att_query_groups: int
+    use_head_norm: bool
+    init_identity_norm: bool
 
 
 class RxTAlphaComponentBase(nn.Module, PyTorchModelHubMixin):
@@ -71,6 +73,8 @@ class RxTAlphaComponentBase(nn.Module, PyTorchModelHubMixin):
             att_query_groups: int = None,
             cross_att_groups: int = None,
             cross_att_query_groups: int = None,
+            use_head_norm: bool = False,
+            init_identity_norm: bool = False,
             **kwargs
     ):
         super(RxTAlphaComponentBase, self).__init__(**kwargs)
@@ -130,10 +134,14 @@ class RxTAlphaComponentBase(nn.Module, PyTorchModelHubMixin):
                 memory_cross_attention=cross_att_init(),
             ) for _ in range(num_layers)
         ])
-        self.model = self._init_model(stm, layers, embedding, use_flash_attention, embed_dim, vocab_size, use_moe)
+        self.model = self._init_model(
+            stm, layers, embedding, use_flash_attention, embed_dim, vocab_size, use_moe,
+            use_head_norm=use_head_norm, init_identity_norm=init_identity_norm,
+        )
 
     def _init_model(self, stm: ShortTermMemory, layers: nn.ModuleList, embedding: nn.Embedding,
-                    use_flash_attention: bool, embed_dim: int, vocab_size: int, use_moe: bool) -> ReactiveTransformerBase:
+                    use_flash_attention: bool, embed_dim: int, vocab_size: int, use_moe: bool,
+                    use_head_norm: bool = False, init_identity_norm: bool = False) -> ReactiveTransformerBase:
         pass
 
     def params_count(self):
@@ -187,6 +195,8 @@ class RxTAlphaEncoder(RxTAlphaComponentBase, pipeline_tag="fill-mask", license="
             embed_dim: int,
             vocab_size: int,
             use_moe: bool,
+            use_head_norm: bool = False,
+            init_identity_norm: bool = False,
     ) -> ReactiveTransformerEncoder:
         return ReactiveTransformerEncoder(
             stm=stm,
@@ -214,6 +224,8 @@ class RxTAlphaDecoder(RxTAlphaComponentBase, pipeline_tag="text-generation", lic
             embed_dim: int,
             vocab_size: int,
             use_moe: bool,
+            use_head_norm: bool = False,
+            init_identity_norm: bool = False,
     ) -> ReactiveTransformerDecoder:
         return ReactiveTransformerDecoder(
             embed_dim,
@@ -223,6 +235,8 @@ class RxTAlphaDecoder(RxTAlphaComponentBase, pipeline_tag="text-generation", lic
             own_layers=layers,
             use_flash_attention=use_flash_attention,
             use_moe=use_moe,
+            use_head_norm=use_head_norm,
+            init_identity_norm=init_identity_norm,
         )
 
     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> tuple[torch.Tensor, torch.Tensor]:
@@ -327,6 +341,8 @@ class RxTAlphaCriticEncoder(RxTAlphaComponentBase, pipeline_tag="text-classifica
             embed_dim: int,
             vocab_size: int,
             use_moe: bool = False,
+            use_head_norm: bool = False,
+            init_identity_norm: bool = False,
     ) -> ReactiveTransformerEncoderDetachStm:
         return ReactiveTransformerEncoderDetachStm(
             stm=stm,
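The plumbing here follows a simple pattern: the base constructor accepts the new flags with safe defaults and forwards them to _init_model as keyword arguments, so every subclass override receives them whether or not it uses them. A stripped-down illustration of that pattern (these are stand-in classes, not the actual rxnn components):

import torch.nn as nn

class ComponentBase(nn.Module):
    """Illustrative base: optional flags flow into the subclass hook as keyword arguments."""
    def __init__(self, embed_dim: int, vocab_size: int,
                 use_head_norm: bool = False, init_identity_norm: bool = False):
        super().__init__()
        self.model = self._init_model(
            embed_dim, vocab_size,
            use_head_norm=use_head_norm, init_identity_norm=init_identity_norm,
        )

    def _init_model(self, embed_dim: int, vocab_size: int,
                    use_head_norm: bool = False, init_identity_norm: bool = False) -> nn.Module:
        raise NotImplementedError

class DecoderComponent(ComponentBase):
    def _init_model(self, embed_dim: int, vocab_size: int,
                    use_head_norm: bool = False, init_identity_norm: bool = False) -> nn.Module:
        # A subclass that cares about the flags consumes them here...
        layers = [nn.LayerNorm(embed_dim)] if use_head_norm else []
        return nn.Sequential(*layers, nn.Linear(embed_dim, vocab_size))

class EncoderComponent(ComponentBase):
    def _init_model(self, embed_dim: int, vocab_size: int,
                    use_head_norm: bool = False, init_identity_norm: bool = False) -> nn.Module:
        # ...while one that ignores them still works, thanks to the keyword defaults.
        return nn.Linear(embed_dim, embed_dim)

print(DecoderComponent(8, 100, use_head_norm=True).model)

Passing the flags as keywords (rather than extending the positional list) keeps any external subclasses of the base compatible with both signatures.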
{rxnn-0.2.63 → rxnn-0.2.64}/src/rxnn/training/mrl.py

@@ -602,14 +602,17 @@ class MRLTrainer:
         print(f"Decoder grad norm - total: {decoder_total:.6f}, mean: {decoder_mean:.6f}")
         print(f"Memory attention grad norm - total: {mem_att_total:.6f}, mean: {mem_att_mean:.6f}")
         # decoder's cross att
-        dec_x_att_norms = [get_gradient_norms(layer.memory_cross_attention)[0] for layer in self.actor.decoder.model.layers]
-        print(f"Decoder cross-att mean total norm: {(sum(dec_x_att_norms) / len(dec_x_att_norms)):.6f}, all: {dec_x_att_norms}")
+        dec_x_att_norms = [get_gradient_norms(layer.memory_cross_attention)[1] for layer in self.actor.decoder.model.layers]
+        print(f"Decoder cross-att mean norm: {(sum(dec_x_att_norms) / len(dec_x_att_norms)):.6f}, all: {dec_x_att_norms}")
 
-        mem_att_norms = [get_gradient_norms(layer)[0] for layer in self.actor.memory_attention.model.attention_layers]
-        print(f"Memory attention layers mean total norm: {(sum(mem_att_norms) / len(mem_att_norms)):.6f}, all: {mem_att_norms}")
+        mem_att_norms = [get_gradient_norms(layer)[1] for layer in self.actor.memory_attention.model.attention_layers]
+        print(f"Memory attention layers mean norm: {(sum(mem_att_norms) / len(mem_att_norms)):.6f}, all: {mem_att_norms}")
 
-        enc_ff_norms = [get_gradient_norms(layer.ff)[0] for layer in self.actor.encoder.model.layers]
-        print(f"Encoder ff mean total norm: {(sum(enc_ff_norms) / len(enc_ff_norms)):.6f}, all: {enc_ff_norms}")
+        enc_ff_norms = [get_gradient_norms(layer.ff)[1] for layer in self.actor.encoder.model.layers]
+        print(f"Encoder ff mean norm: {(sum(enc_ff_norms) / len(enc_ff_norms)):.6f}, all: {enc_ff_norms}")
+
+        enc_ff_norms = [get_gradient_norms(layer.memory_cross_attention)[1] for layer in self.actor.encoder.model.layers]
+        print(f"Encoder cross-att mean norm: {(sum(enc_ff_norms) / len(enc_ff_norms)):.6f}, all: {enc_ff_norms}")
 
     def update_actor(self, state: tuple[TokenizedDict, TokenizedDict, TokenizedDict], action: TokenizedDict,
                      advantages: torch.Tensor, old_log_probs: torch.Tensor, epoch: int) -> float:
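For readers following the logging change: get_gradient_norms returns a (total_norm, mean_norm) tuple, so switching the index from [0] to [1] makes these per-layer diagnostics report the mean rather than the total L2 norm, matching the updated message text. A minimal sketch of the same aggregation pattern on stand-in modules (the actor/decoder attributes used above are rxnn-specific and not reproduced here; rxnn 0.2.64 and torch are assumed installed):

import torch
import torch.nn as nn
from rxnn.training.utils import get_gradient_norms

# Stand-in "layers" mimicking the per-layer loop in the trainer diagnostics.
layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(3)])
loss = sum(layer(torch.randn(4, 8)).sum() for layer in layers)
loss.backward()

# Index [1] selects mean_norm from the (total_norm, mean_norm) tuple.
mean_norms = [get_gradient_norms(layer)[1] for layer in layers]
print(f"mean norm per layer: {mean_norms}, overall mean: {sum(mean_norms) / len(mean_norms):.6f}")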
{rxnn-0.2.63 → rxnn-0.2.64}/src/rxnn/training/utils.py

@@ -146,10 +146,10 @@ def smart_concat(query: TokenizedDict, answer: TokenizedDict, max_length: int, p
 
 def get_gradient_norms(model: nn.Module):
     total_norm = 0
-    for p in model.parameters():
-        if p.grad is not None:
-            param_norm = p.grad.data.norm(2)
-            total_norm += param_norm.item() ** 2
+    grad_params = list(filter(lambda p: p.requires_grad and p.grad is not None, model.parameters()))
+    for p in grad_params:
+        param_norm = p.grad.data.norm(2)
+        total_norm += param_norm.item() ** 2
     total_norm = total_norm ** 0.5
-    mean_norm = total_norm / len(list(model.parameters()))
+    mean_norm = total_norm / len(grad_params)
     return total_norm, mean_norm
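The practical effect of this change: previously, frozen or otherwise gradient-less parameters inflated the denominator of mean_norm; the revised helper divides only by parameters that actually received gradients. A small self-contained check, assuming rxnn 0.2.64 and torch are installed:

import torch
import torch.nn as nn
from rxnn.training.utils import get_gradient_norms

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
for p in model[0].parameters():
    p.requires_grad_(False)  # freeze the first layer; it never receives a .grad

model(torch.randn(3, 4)).sum().backward()

total, mean = get_gradient_norms(model)
# mean is now total / 2 (weight + bias of the trainable layer),
# not total / 4 as the previous implementation would have reported.
print(f"total: {total:.6f}, mean: {mean:.6f}")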
{rxnn-0.2.63 → rxnn-0.2.64}/src/rxnn/transformers/models.py

@@ -69,9 +69,18 @@ class ReactiveTransformerBase(nn.Module):
 class ReactiveTransformerDecoder(ReactiveTransformerBase):
     """Reactive Transformer decoder - extending the classic Transformer decoder with Memory Cross-Attention"""
 
-    def __init__(self, embed_dim: int, vocab_size: int, *args, **kwargs):
+    def __init__(self, embed_dim: int, vocab_size: int, use_head_norm: bool = False, init_identity_norm: bool = False, *args, **kwargs):
         super(ReactiveTransformerDecoder, self).__init__(*args, **kwargs)
+
         self.head = nn.Linear(embed_dim, vocab_size)
+        self.use_head_norm = use_head_norm
+        if use_head_norm:
+            self.head_norm = nn.LayerNorm(embed_dim)
+            if init_identity_norm:
+                self.head_norm.weight.data.fill_(1.0)
+                self.head_norm.bias.data.fill_(0.0)
+        else:
+            self.head_norm = None
 
     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
         x = super().forward(x)  # apply embeddings
@@ -99,7 +108,7 @@ class ReactiveTransformerDecoder(ReactiveTransformerBase):
             if layer_stm.size(0) == 1:
                 layer_stm = layer_stm.expand(x.size(0), -1, -1)
             x = self.layers[i](x, layer_stm, mask=mask)
-        return self.head(x)
+        return self.head(self.head_norm(x) if self.use_head_norm else x)
 
 
 class ReactiveTransformerEncoder(ReactiveTransformerBase):
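The behavioural change is easiest to see in isolation: with use_head_norm set, hidden states pass through a LayerNorm before the vocabulary projection, and init_identity_norm explicitly resets that norm's affine parameters to weight=1, bias=0 (which is also PyTorch's default LayerNorm initialization, made explicit here). A self-contained sketch of the same head path outside rxnn (class name and dimensions are illustrative, not part of the package):

import torch
import torch.nn as nn

class TinyDecoderHead(nn.Module):
    """Illustrative stand-in for the patched head path, not the rxnn class itself."""
    def __init__(self, embed_dim: int, vocab_size: int,
                 use_head_norm: bool = False, init_identity_norm: bool = False):
        super().__init__()
        self.head = nn.Linear(embed_dim, vocab_size)
        self.use_head_norm = use_head_norm
        if use_head_norm:
            self.head_norm = nn.LayerNorm(embed_dim)
            if init_identity_norm:
                # Same values as PyTorch's default LayerNorm init, set explicitly as in the diff.
                self.head_norm.weight.data.fill_(1.0)
                self.head_norm.bias.data.fill_(0.0)
        else:
            self.head_norm = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize hidden states right before the vocabulary projection, as in the patched decoders.
        return self.head(self.head_norm(x) if self.use_head_norm else x)

logits = TinyDecoderHead(16, 100, use_head_norm=True, init_identity_norm=True)(torch.randn(2, 5, 16))
print(logits.shape)  # torch.Size([2, 5, 100])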
@@ -201,9 +210,17 @@ class ClassicTransformerBase(nn.Module):
 class ClassicTransformerDecoder(ClassicTransformerBase):
     """Classic Transformer decoder - for decoder-only Transformer models"""
 
-    def __init__(self, embed_dim: int, vocab_size: int, *args, **kwargs):
+    def __init__(self, embed_dim: int, vocab_size: int, use_head_norm: bool = False, init_identity_norm: bool = False, *args, **kwargs):
         super(ClassicTransformerDecoder, self).__init__(*args, **kwargs)
         self.head = nn.Linear(embed_dim, vocab_size)
+        self.use_head_norm = use_head_norm
+        if use_head_norm:
+            self.head_norm = nn.LayerNorm(embed_dim)
+            if init_identity_norm:
+                self.head_norm.weight.data.fill_(1.0)
+                self.head_norm.bias.data.fill_(0.0)
+        else:
+            self.head_norm = None
 
     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
         x = super().forward(x)  # apply embeddings
@@ -213,7 +230,6 @@ class ClassicTransformerDecoder(ClassicTransformerBase):
             if attention_mask is not None:
                 mask &= attention_mask.unsqueeze(1).unsqueeze(1).bool()
         elif attention_mask is not None:
-            print(attention_mask.size())
             mask = attention_mask.unsqueeze(1).unsqueeze(1).bool()
         else:
             mask = None
@@ -221,7 +237,7 @@ class ClassicTransformerDecoder(ClassicTransformerBase):
         # Process layers
         for i in range(self.num_layers):
            x = self.layers[i](x, mask=mask)
-        return self.head(x)
+        return self.head(self.head_norm(x) if self.use_head_norm else x)
 
 
 class ClassicTransformerEncoder(ClassicTransformerBase):
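Aside from dropping the stray debug print, the mask handling in the context lines is unchanged: a [batch, seq] padding mask is broadcast to [batch, 1, 1, seq] so it can be combined with a causal mask over the attention scores. A generic illustration of that broadcast (plain torch, not rxnn code):

import torch

batch, seq = 2, 5
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]])

# [batch, seq] -> [batch, 1, 1, seq]: one mask per batch item, broadcast over heads and query positions.
padding_mask = attention_mask.unsqueeze(1).unsqueeze(1).bool()

# Lower-triangular causal mask, broadcastable to [batch, heads, seq, seq].
causal_mask = torch.tril(torch.ones(seq, seq, dtype=torch.bool)).view(1, 1, seq, seq)

mask = causal_mask & padding_mask
print(mask.shape)  # torch.Size([2, 1, 5, 5])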