rxnn 0.2.62__py3-none-any.whl → 0.2.64__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rxnn/experimental/models.py CHANGED
@@ -34,6 +34,8 @@ class ExperimentalAttentionTransformerConfig(TypedDict):
     att_num_query_groups: int
     att_num_global_tokens: int
     att_window_size: int
+    use_head_norm: bool
+    init_identity_norm: bool
 
 
 class ExperimentalAttentionTransformer(nn.Module, PyTorchModelHubMixin, pipeline_tag="text-generation", license="apache-2.0"):
@@ -67,6 +69,8 @@ class ExperimentalAttentionTransformer(nn.Module, PyTorchModelHubMixin, pipeline
         att_num_query_groups: int = None,
         att_num_global_tokens: int = 16,
         att_window_size: int = 128,
+        use_head_norm: bool = False,
+        init_identity_norm: bool = False,
         **kwargs
     ):
         super(ExperimentalAttentionTransformer, self).__init__(**kwargs)
@@ -110,6 +114,8 @@ class ExperimentalAttentionTransformer(nn.Module, PyTorchModelHubMixin, pipeline
                 use_rms_norm=use_rms_norm,
                 self_attention=att_init(),
                 use_moe_att=use_moe_att,
+                use_head_norm=use_head_norm,
+                init_identity_norm=init_identity_norm,
             ) for _ in range(num_layers)
         ]),
         use_flash_attention=use_flash_attention,
rxnn/rxt/models.py CHANGED
@@ -39,6 +39,8 @@ class RxTAlphaComponentConfig(TypedDict):
     att_query_groups: int
     cross_att_groups: int
     cross_att_query_groups: int
+    use_head_norm: bool
+    init_identity_norm: bool
 
 
 class RxTAlphaComponentBase(nn.Module, PyTorchModelHubMixin):
@@ -71,6 +73,8 @@ class RxTAlphaComponentBase(nn.Module, PyTorchModelHubMixin):
         att_query_groups: int = None,
         cross_att_groups: int = None,
         cross_att_query_groups: int = None,
+        use_head_norm: bool = False,
+        init_identity_norm: bool = False,
         **kwargs
     ):
         super(RxTAlphaComponentBase, self).__init__(**kwargs)
@@ -130,10 +134,14 @@ class RxTAlphaComponentBase(nn.Module, PyTorchModelHubMixin):
                 memory_cross_attention=cross_att_init(),
             ) for _ in range(num_layers)
         ])
-        self.model = self._init_model(stm, layers, embedding, use_flash_attention, embed_dim, vocab_size, use_moe)
+        self.model = self._init_model(
+            stm, layers, embedding, use_flash_attention, embed_dim, vocab_size, use_moe,
+            use_head_norm=use_head_norm, init_identity_norm=init_identity_norm,
+        )
 
     def _init_model(self, stm: ShortTermMemory, layers: nn.ModuleList, embedding: nn.Embedding,
-                    use_flash_attention: bool, embed_dim: int, vocab_size: int, use_moe: bool) -> ReactiveTransformerBase:
+                    use_flash_attention: bool, embed_dim: int, vocab_size: int, use_moe: bool,
+                    use_head_norm: bool = False, init_identity_norm: bool = False) -> ReactiveTransformerBase:
         pass
 
     def params_count(self):
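Note on the hunk above: `_init_model` is the hook that the encoder, decoder, and critic-encoder subclasses override, and it now receives the two new flags as keyword arguments. A reduced, hypothetical sketch of that template shape (the class and argument names below are illustrative, not the real rxnn API):

```python
from torch import nn

class ComponentBase(nn.Module):
    """Illustrative stand-in for the base component: build shared parts, delegate assembly."""

    def __init__(self, embed_dim: int, vocab_size: int,
                 use_head_norm: bool = False, init_identity_norm: bool = False):
        super().__init__()
        embedding = nn.Embedding(vocab_size, embed_dim)
        # The new flags are forwarded as keyword arguments to the subclass hook.
        self.model = self._init_model(
            embedding, embed_dim, vocab_size,
            use_head_norm=use_head_norm, init_identity_norm=init_identity_norm,
        )

    def _init_model(self, embedding: nn.Embedding, embed_dim: int, vocab_size: int,
                    use_head_norm: bool = False, init_identity_norm: bool = False) -> nn.Module:
        raise NotImplementedError

class DecoderComponent(ComponentBase):
    def _init_model(self, embedding: nn.Embedding, embed_dim: int, vocab_size: int,
                    use_head_norm: bool = False, init_identity_norm: bool = False) -> nn.Module:
        # The real override returns a ReactiveTransformerDecoder; a simple head stands in here.
        norm = nn.LayerNorm(embed_dim) if use_head_norm else nn.Identity()
        return nn.Sequential(embedding, norm, nn.Linear(embed_dim, vocab_size))

model = DecoderComponent(embed_dim=32, vocab_size=100, use_head_norm=True)
```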
@@ -187,6 +195,8 @@ class RxTAlphaEncoder(RxTAlphaComponentBase, pipeline_tag="fill-mask", license="
                     embed_dim: int,
                     vocab_size: int,
                     use_moe: bool,
+                    use_head_norm: bool = False,
+                    init_identity_norm: bool = False,
     ) -> ReactiveTransformerEncoder:
         return ReactiveTransformerEncoder(
             stm=stm,
@@ -214,6 +224,8 @@ class RxTAlphaDecoder(RxTAlphaComponentBase, pipeline_tag="text-generation", lic
                     embed_dim: int,
                     vocab_size: int,
                     use_moe: bool,
+                    use_head_norm: bool = False,
+                    init_identity_norm: bool = False,
     ) -> ReactiveTransformerDecoder:
         return ReactiveTransformerDecoder(
             embed_dim,
@@ -223,6 +235,8 @@ class RxTAlphaDecoder(RxTAlphaComponentBase, pipeline_tag="text-generation", lic
             own_layers=layers,
             use_flash_attention=use_flash_attention,
             use_moe=use_moe,
+            use_head_norm=use_head_norm,
+            init_identity_norm=init_identity_norm,
         )
 
     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> tuple[torch.Tensor, torch.Tensor]:
@@ -327,6 +341,8 @@ class RxTAlphaCriticEncoder(RxTAlphaComponentBase, pipeline_tag="text-classifica
                     embed_dim: int,
                     vocab_size: int,
                     use_moe: bool = False,
+                    use_head_norm: bool = False,
+                    init_identity_norm: bool = False,
     ) -> ReactiveTransformerEncoderDetachStm:
         return ReactiveTransformerEncoderDetachStm(
             stm=stm,
rxnn/training/mrl.py CHANGED
@@ -243,20 +243,31 @@ class MRLTrainer:
             critic_weight_decay: float,
             critic_encoder_lr: float,
             embedding_lr: float,
+            encoder_lr: float,
             memory_lr: Optional[float] = None,
+            encoder_memory_lr: Optional[float] = None,
+            memory_attn_lr: Optional[float] = None,
     ) -> tuple[torch.optim.Optimizer, torch.optim.Optimizer]:
         if memory_lr is not None:
             optimizer = torch.optim.AdamW([
                 {'params': self.actor.embedding_parameters(), 'lr': embedding_lr},
-                {'params': self.actor.not_memory_parameters(), 'lr': lr},
-                {'params': self.actor.memory_parameters(), 'lr': memory_lr},
+                {'params': self.actor.encoder.not_memory_parameters(), 'lr': encoder_lr},
+                {'params': self.actor.encoder.memory_parameters(), 'lr': encoder_memory_lr},
+                {'params': self.actor.memory_attention_parameters(), 'lr': memory_attn_lr},
+                {'params': self.actor.decoder.memory_parameters(), 'lr': memory_lr},
+                {'params': self.actor.decoder.not_memory_parameters(), 'lr': lr},
             ],
                 weight_decay=weight_decay,
             )
         else:
             optimizer = torch.optim.AdamW([
                 {'params': self.actor.embedding_parameters(), 'lr': embedding_lr},
-                {'params': self.actor.unique_parameters(with_embedding=False), 'lr': lr},
+                {'params': self.actor.embedding_parameters(), 'lr': embedding_lr},
+                {'params': self.actor.encoder.not_memory_parameters(), 'lr': encoder_lr},
+                {'params': self.actor.encoder.memory_parameters(), 'lr': encoder_lr},
+                {'params': self.actor.memory_attention_parameters(), 'lr': lr},
+                {'params': self.actor.decoder.memory_parameters(), 'lr': lr},
+                {'params': self.actor.decoder.not_memory_parameters(), 'lr': lr},
             ],
                 weight_decay=weight_decay,
            )
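The hunk above replaces the actor-wide `memory_parameters()` / `not_memory_parameters()` split with per-component parameter groups, each with its own learning rate. A minimal sketch of the same AdamW multi-group pattern, using a hypothetical toy actor and made-up learning-rate values rather than MRLTrainer's real components:

```python
import torch
from torch import nn

# Hypothetical toy model standing in for the actor; the real code selects parameters
# via methods like encoder.memory_parameters() instead of whole submodules.
class ToyActor(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(16, 16)
        self.memory_attention = nn.Linear(16, 16)
        self.decoder = nn.Linear(16, 16)

actor = ToyActor()

# One AdamW optimizer, several parameter groups, each with its own learning rate,
# mirroring the group structure used in the hunk above.
optimizer = torch.optim.AdamW(
    [
        {'params': actor.encoder.parameters(), 'lr': 2e-4},           # would be encoder_lr
        {'params': actor.memory_attention.parameters(), 'lr': 5e-4},  # would be memory_attn_lr
        {'params': actor.decoder.parameters(), 'lr': 1e-4},           # would be lr
    ],
    weight_decay=0.01,
)
```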
@@ -591,14 +602,17 @@
         print(f"Decoder grad norm - total: {decoder_total:.6f}, mean: {decoder_mean:.6f}")
         print(f"Memory attention grad norm - total: {mem_att_total:.6f}, mean: {mem_att_mean:.6f}")
         # decoder's cross att
-        dec_x_att_norms = [get_gradient_norms(layer.memory_cross_attention)[0] for layer in self.actor.decoder.model.layers]
-        print(f"Decoder cross-att mean total norm: {(sum(dec_x_att_norms) / len(dec_x_att_norms)):.6f}, all: {dec_x_att_norms}")
+        dec_x_att_norms = [get_gradient_norms(layer.memory_cross_attention)[1] for layer in self.actor.decoder.model.layers]
+        print(f"Decoder cross-att mean norm: {(sum(dec_x_att_norms) / len(dec_x_att_norms)):.6f}, all: {dec_x_att_norms}")
+
+        mem_att_norms = [get_gradient_norms(layer)[1] for layer in self.actor.memory_attention.model.attention_layers]
+        print(f"Memory attention layers mean norm: {(sum(mem_att_norms) / len(mem_att_norms)):.6f}, all: {mem_att_norms}")
 
-        mem_att_norms = [get_gradient_norms(layer)[0] for layer in self.actor.memory_attention.model.attention_layers]
-        print(f"Memory attention layers mean total norm: {(sum(mem_att_norms) / len(mem_att_norms)):.6f}, all: {mem_att_norms}")
+        enc_ff_norms = [get_gradient_norms(layer.ff)[1] for layer in self.actor.encoder.model.layers]
+        print(f"Encoder ff mean norm: {(sum(enc_ff_norms) / len(enc_ff_norms)):.6f}, all: {enc_ff_norms}")
 
-        enc_ff_norms = [get_gradient_norms(layer.ff)[0] for layer in self.actor.encoder.model.layers]
-        print(f"Encoder ff mean total norm: {(sum(enc_ff_norms) / len(enc_ff_norms)):.6f}, all: {enc_ff_norms}")
+        enc_ff_norms = [get_gradient_norms(layer.memory_cross_attention)[1] for layer in self.actor.encoder.model.layers]
+        print(f"Encoder cross-att mean norm: {(sum(enc_ff_norms) / len(enc_ff_norms)):.6f}, all: {enc_ff_norms}")
 
     def update_actor(self, state: tuple[TokenizedDict, TokenizedDict, TokenizedDict], action: TokenizedDict,
                      advantages: torch.Tensor, old_log_probs: torch.Tensor, epoch: int) -> float:
rxnn/training/utils.py CHANGED
@@ -146,10 +146,10 @@ def smart_concat(query: TokenizedDict, answer: TokenizedDict, max_length: int, p
 
 def get_gradient_norms(model: nn.Module):
     total_norm = 0
-    for p in model.parameters():
-        if p.grad is not None:
-            param_norm = p.grad.data.norm(2)
-            total_norm += param_norm.item() ** 2
+    grad_params = list(filter(lambda p: p.requires_grad and p.grad is not None, model.parameters()))
+    for p in grad_params:
+        param_norm = p.grad.data.norm(2)
+        total_norm += param_norm.item() ** 2
     total_norm = total_norm ** 0.5
-    mean_norm = total_norm / len(list(model.parameters()))
+    mean_norm = total_norm / len(grad_params)
     return total_norm, mean_norm
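With this change `get_gradient_norms` counts only parameters that require grad and actually received a gradient, so the second return value (the mean used by the new logging above) is no longer diluted by frozen or gradient-less parameters. A short self-contained usage sketch; the toy model and input are assumptions, while the function body mirrors the updated utils.py:

```python
import torch
from torch import nn

def get_gradient_norms(model: nn.Module):
    # Same logic as the updated rxnn/training/utils.py: accumulate the L2 norm over
    # parameters that require grad and have a gradient, then average over that same set.
    total_norm = 0
    grad_params = list(filter(lambda p: p.requires_grad and p.grad is not None, model.parameters()))
    for p in grad_params:
        param_norm = p.grad.data.norm(2)
        total_norm += param_norm.item() ** 2
    total_norm = total_norm ** 0.5
    mean_norm = total_norm / len(grad_params)
    return total_norm, mean_norm

# Toy example: run a backward pass, then inspect gradient norms.
model = nn.Linear(8, 4)
loss = model(torch.randn(2, 8)).sum()
loss.backward()
total, mean = get_gradient_norms(model)
print(f"grad norm - total: {total:.6f}, mean: {mean:.6f}")
```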
rxnn/transformers/models.py CHANGED
@@ -69,9 +69,18 @@ class ReactiveTransformerBase(nn.Module):
 class ReactiveTransformerDecoder(ReactiveTransformerBase):
     """Reactive Transformer decoder - extending the classic Transformer decoder with Memory Cross-Attention"""
 
-    def __init__(self, embed_dim: int, vocab_size: int, *args, **kwargs):
+    def __init__(self, embed_dim: int, vocab_size: int, use_head_norm: bool = False, init_identity_norm: bool = False, *args, **kwargs):
         super(ReactiveTransformerDecoder, self).__init__(*args, **kwargs)
+
         self.head = nn.Linear(embed_dim, vocab_size)
+        self.use_head_norm = use_head_norm
+        if use_head_norm:
+            self.head_norm = nn.LayerNorm(embed_dim)
+            if init_identity_norm:
+                self.head_norm.weight.data.fill_(1.0)
+                self.head_norm.bias.data.fill_(0.0)
+        else:
+            self.head_norm = None
 
     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
         x = super().forward(x)  # apply embeddings
@@ -99,7 +108,7 @@ class ReactiveTransformerDecoder(ReactiveTransformerBase):
             if layer_stm.size(0) == 1:
                 layer_stm = layer_stm.expand(x.size(0), -1, -1)
             x = self.layers[i](x, layer_stm, mask=mask)
-        return self.head(x)
+        return self.head(self.head_norm(x) if self.use_head_norm else x)
 
 
 class ReactiveTransformerEncoder(ReactiveTransformerBase):
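The two hunks above show what the new flags do in `ReactiveTransformerDecoder`: with `use_head_norm=True` a LayerNorm is applied right before the LM head, and `init_identity_norm=True` explicitly sets that norm's affine parameters to weight 1 and bias 0 (which is also LayerNorm's default initialization). A standalone sketch of the same pattern outside the rxnn class hierarchy; the class name `HeadWithOptionalNorm` and the dimensions below are illustrative:

```python
import torch
from torch import nn

class HeadWithOptionalNorm(nn.Module):
    """Minimal sketch of the use_head_norm / init_identity_norm pattern from the diff."""

    def __init__(self, embed_dim: int, vocab_size: int,
                 use_head_norm: bool = False, init_identity_norm: bool = False):
        super().__init__()
        self.head = nn.Linear(embed_dim, vocab_size)
        self.use_head_norm = use_head_norm
        if use_head_norm:
            self.head_norm = nn.LayerNorm(embed_dim)
            if init_identity_norm:
                # LayerNorm already starts at weight=1, bias=0; the explicit fill mirrors the diff
                # and makes the identity-affine starting point visible.
                self.head_norm.weight.data.fill_(1.0)
                self.head_norm.bias.data.fill_(0.0)
        else:
            self.head_norm = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize hidden states before projecting to vocabulary logits, if enabled.
        return self.head(self.head_norm(x) if self.use_head_norm else x)

# Usage: logits over a hypothetical vocabulary of 1000 tokens.
head = HeadWithOptionalNorm(embed_dim=64, vocab_size=1000, use_head_norm=True, init_identity_norm=True)
logits = head(torch.randn(2, 10, 64))  # (batch, seq, vocab)
```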
@@ -201,9 +210,17 @@ class ClassicTransformerBase(nn.Module):
 class ClassicTransformerDecoder(ClassicTransformerBase):
     """Classic Transformer decoder - for decoder-only Transformer models"""
 
-    def __init__(self, embed_dim: int, vocab_size: int, *args, **kwargs):
+    def __init__(self, embed_dim: int, vocab_size: int, use_head_norm: bool = False, init_identity_norm: bool = False, *args, **kwargs):
         super(ClassicTransformerDecoder, self).__init__(*args, **kwargs)
         self.head = nn.Linear(embed_dim, vocab_size)
+        self.use_head_norm = use_head_norm
+        if use_head_norm:
+            self.head_norm = nn.LayerNorm(embed_dim)
+            if init_identity_norm:
+                self.head_norm.weight.data.fill_(1.0)
+                self.head_norm.bias.data.fill_(0.0)
+        else:
+            self.head_norm = None
 
     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
         x = super().forward(x)  # apply embeddings
@@ -213,7 +230,6 @@ class ClassicTransformerDecoder(ClassicTransformerBase):
             if attention_mask is not None:
                 mask &= attention_mask.unsqueeze(1).unsqueeze(1).bool()
         elif attention_mask is not None:
-            print(attention_mask.size())
             mask = attention_mask.unsqueeze(1).unsqueeze(1).bool()
         else:
             mask = None
@@ -221,7 +237,7 @@ class ClassicTransformerDecoder(ClassicTransformerBase):
         # Process layers
         for i in range(self.num_layers):
             x = self.layers[i](x, mask=mask)
-        return self.head(x)
+        return self.head(self.head_norm(x) if self.use_head_norm else x)
 
 
 class ClassicTransformerEncoder(ClassicTransformerBase):
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.2.62
+Version: 0.2.64
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning
@@ -2,14 +2,14 @@ rxnn/.DS_Store,sha256=BxZLo9tFs48JMq6jhumiCnCPLTeCwl619CFSg4ClRAY,6148
 rxnn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/experimental/attention.py,sha256=jlNS82INjycNEfmk3HtkIacUvT_ELhaCO2g-kZTvhXg,28455
-rxnn/experimental/models.py,sha256=oJWd56LUsLc9S8eCZw-ShvuWjoQxj4C9GitbohlQ0ok,5139
+rxnn/experimental/models.py,sha256=KheR1zSNJIaeVvpVAkEJwcuM5nOqQP0ZF08XhrtGJ8E,5387
 rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/memory/attention.py,sha256=lSniKrf_skiM1V1zbfmV84PbKoQ-t_fVcKfwNKW3_OY,3844
 rxnn/memory/norm.py,sha256=cVjjhCLqR5K6-321SP_ObG17y-ddlcTJeCTXvW4vpk0,6675
 rxnn/memory/stm.py,sha256=jv57gsH9XW19sLbxpRDqsp1yfsii_4Ef4Ncr_ztk-i4,3937
 rxnn/rxt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/rxt/models.py,sha256=mLHvb3ablQK9UtupuOHmLlG440Q_NW-OuLWcxGMfGuY,14807
+rxnn/rxt/models.py,sha256=JrZQ78F4HGGklAy6mML4fbqdsMOcGSDRZpjhX55VXb8,15486
 rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/training/base.py,sha256=CqaArEZYOdH64nmKfx28U3GI46TzO4oNkjf_hrF23Cw,11835
 rxnn/training/bml.py,sha256=hw6gLpLkGvqLzxIvBg4MvCc5r8cHpEm2RDyh7nH6CtE,16914
@@ -17,23 +17,23 @@ rxnn/training/callbacks.py,sha256=rS8leuVFPVVfE5Zc8DMkUZhRIPN-vpPbUjowXE5TSBw,36
 rxnn/training/dataset.py,sha256=tbtOSYldHnQB6SWgee_yUj9zTbgoEoLFNa6wvUS6Apg,51292
 rxnn/training/ddp.py,sha256=VsNBjn3cY-uUj8hbsW7oKvb0_ZKnXnJ2KgObm-Mr9i4,836
 rxnn/training/models.py,sha256=KIiOCW0VgKtMA4EMQ---xsVExdI1mBsgWjtRSmJpecA,9033
-rxnn/training/mrl.py,sha256=BWp87Lj4epjTlROmrQK8RnS_83IucqS7XWI1cBae7BM,64424
+rxnn/training/mrl.py,sha256=2J6Wh4xtsVoE6duEevmovDpmSsMkEoH39Ru0bE8lhFo,65481
 rxnn/training/reward.py,sha256=uiSsBXmjMw2yv-1Bssy3RTlpU6zP8ape3490Sl-aT0M,16144
 rxnn/training/rl.py,sha256=hWtExxY-_pAmTOGYxyCNounUbaGWvLDVltC4sRC7MN4,7175
 rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
 rxnn/training/tokenizer.py,sha256=umaLByMBx_NMrQElA45HLm9gkuzyKWDTFaKVd-CjXl0,8344
-rxnn/training/utils.py,sha256=C0OS2RAGQ3L7D_G3CWupu_BpAFhkovMByBKm355Ibfc,6087
+rxnn/training/utils.py,sha256=QMNkJPQBY04DX9WN7GHnI2EZTBbAzWkjt2W-798oUII,6129
 rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/transformers/attention.py,sha256=KRnKT6XUqAXElxV9y72mSpdTeiMgCKCCLqqxCFNTHmA,16372
 rxnn/transformers/ff.py,sha256=WDjO-H9XWInoWnUnxiseIH6Kx5GlHP0zGJygwhcb1gc,2589
 rxnn/transformers/layers.py,sha256=OlbqD5kKygn5WZziLbU3jZjhr8hBrxLpqlCjJ_BNCW0,8119
 rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
-rxnn/transformers/models.py,sha256=tT0W5inG4EtjEHNutG77Wcws2fJzLJs-iFDP3hX3D2Q,10761
+rxnn/transformers/models.py,sha256=7xKixlBN5uZckcOXfukZeA4f7R34A35Gk98lNzew37o,11559
 rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
 rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
 rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
-rxnn-0.2.62.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
-rxnn-0.2.62.dist-info/METADATA,sha256=A40GBcyyy0ZxkHxFDXQVc7Ghrz9pvlYFDwPHEAbLuFI,25997
-rxnn-0.2.62.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
-rxnn-0.2.62.dist-info/RECORD,,
+rxnn-0.2.64.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.2.64.dist-info/METADATA,sha256=QaDWd-8W0vs3povCgRAUXXSmTNN8gkEJ1dY6mA7n9kQ,25997
+rxnn-0.2.64.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+rxnn-0.2.64.dist-info/RECORD,,