rxnn-0.2.62-py3-none-any.whl → rxnn-0.2.64-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rxnn/experimental/models.py +6 -0
- rxnn/rxt/models.py +18 -2
- rxnn/training/mrl.py +23 -9
- rxnn/training/utils.py +5 -5
- rxnn/transformers/models.py +21 -5
- {rxnn-0.2.62.dist-info → rxnn-0.2.64.dist-info}/METADATA +1 -1
- {rxnn-0.2.62.dist-info → rxnn-0.2.64.dist-info}/RECORD +9 -9
- {rxnn-0.2.62.dist-info → rxnn-0.2.64.dist-info}/LICENSE +0 -0
- {rxnn-0.2.62.dist-info → rxnn-0.2.64.dist-info}/WHEEL +0 -0
rxnn/experimental/models.py
CHANGED
@@ -34,6 +34,8 @@ class ExperimentalAttentionTransformerConfig(TypedDict):
     att_num_query_groups: int
     att_num_global_tokens: int
     att_window_size: int
+    use_head_norm: bool
+    init_identity_norm: bool


 class ExperimentalAttentionTransformer(nn.Module, PyTorchModelHubMixin, pipeline_tag="text-generation", license="apache-2.0"):
@@ -67,6 +69,8 @@ class ExperimentalAttentionTransformer(nn.Module, PyTorchModelHubMixin, pipeline
             att_num_query_groups: int = None,
             att_num_global_tokens: int = 16,
             att_window_size: int = 128,
+            use_head_norm: bool = False,
+            init_identity_norm: bool = False,
             **kwargs
     ):
         super(ExperimentalAttentionTransformer, self).__init__(**kwargs)
@@ -110,6 +114,8 @@ class ExperimentalAttentionTransformer(nn.Module, PyTorchModelHubMixin, pipeline
                 use_rms_norm=use_rms_norm,
                 self_attention=att_init(),
                 use_moe_att=use_moe_att,
+                use_head_norm=use_head_norm,
+                init_identity_norm=init_identity_norm,
             ) for _ in range(num_layers)
         ]),
         use_flash_attention=use_flash_attention,
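The two new keys surface both in `ExperimentalAttentionTransformerConfig` and as constructor defaults. A minimal sketch of how they appear in a config dict, assuming placeholder values for the pre-existing keys (only the new flags and the `16`/`128` defaults are taken from this diff):

```python
from typing import TypedDict

# Trimmed sketch of the config TypedDict: only keys visible in this hunk are
# reproduced; 'att_num_query_groups': 4 is a placeholder, not a library default.
class ExperimentalAttentionTransformerConfigSketch(TypedDict):
    att_num_query_groups: int
    att_num_global_tokens: int
    att_window_size: int
    use_head_norm: bool
    init_identity_norm: bool

config: ExperimentalAttentionTransformerConfigSketch = {
    'att_num_query_groups': 4,       # placeholder
    'att_num_global_tokens': 16,
    'att_window_size': 128,
    'use_head_norm': True,           # new: LayerNorm in front of the LM head
    'init_identity_norm': True,      # new: start that norm with unit scale / zero shift
}
print(config)
```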
rxnn/rxt/models.py
CHANGED
@@ -39,6 +39,8 @@ class RxTAlphaComponentConfig(TypedDict):
     att_query_groups: int
     cross_att_groups: int
     cross_att_query_groups: int
+    use_head_norm: bool
+    init_identity_norm: bool


 class RxTAlphaComponentBase(nn.Module, PyTorchModelHubMixin):
@@ -71,6 +73,8 @@ class RxTAlphaComponentBase(nn.Module, PyTorchModelHubMixin):
             att_query_groups: int = None,
             cross_att_groups: int = None,
             cross_att_query_groups: int = None,
+            use_head_norm: bool = False,
+            init_identity_norm: bool = False,
             **kwargs
     ):
         super(RxTAlphaComponentBase, self).__init__(**kwargs)
@@ -130,10 +134,14 @@ class RxTAlphaComponentBase(nn.Module, PyTorchModelHubMixin):
                 memory_cross_attention=cross_att_init(),
             ) for _ in range(num_layers)
         ])
-        self.model = self._init_model(
+        self.model = self._init_model(
+            stm, layers, embedding, use_flash_attention, embed_dim, vocab_size, use_moe,
+            use_head_norm=use_head_norm, init_identity_norm=init_identity_norm,
+        )

     def _init_model(self, stm: ShortTermMemory, layers: nn.ModuleList, embedding: nn.Embedding,
-                    use_flash_attention: bool, embed_dim: int, vocab_size: int, use_moe: bool
+                    use_flash_attention: bool, embed_dim: int, vocab_size: int, use_moe: bool,
+                    use_head_norm: bool = False, init_identity_norm: bool = False) -> ReactiveTransformerBase:
         pass

     def params_count(self):
@@ -187,6 +195,8 @@ class RxTAlphaEncoder(RxTAlphaComponentBase, pipeline_tag="fill-mask", license="
             embed_dim: int,
             vocab_size: int,
             use_moe: bool,
+            use_head_norm: bool = False,
+            init_identity_norm: bool = False,
     ) -> ReactiveTransformerEncoder:
         return ReactiveTransformerEncoder(
             stm=stm,
@@ -214,6 +224,8 @@ class RxTAlphaDecoder(RxTAlphaComponentBase, pipeline_tag="text-generation", lic
             embed_dim: int,
             vocab_size: int,
             use_moe: bool,
+            use_head_norm: bool = False,
+            init_identity_norm: bool = False,
     ) -> ReactiveTransformerDecoder:
         return ReactiveTransformerDecoder(
             embed_dim,
@@ -223,6 +235,8 @@ class RxTAlphaDecoder(RxTAlphaComponentBase, pipeline_tag="text-generation", lic
             own_layers=layers,
             use_flash_attention=use_flash_attention,
             use_moe=use_moe,
+            use_head_norm=use_head_norm,
+            init_identity_norm=init_identity_norm,
         )

     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> tuple[torch.Tensor, torch.Tensor]:
@@ -327,6 +341,8 @@ class RxTAlphaCriticEncoder(RxTAlphaComponentBase, pipeline_tag="text-classifica
             embed_dim: int,
             vocab_size: int,
             use_moe: bool = False,
+            use_head_norm: bool = False,
+            init_identity_norm: bool = False,
     ) -> ReactiveTransformerEncoderDetachStm:
         return ReactiveTransformerEncoderDetachStm(
             stm=stm,
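The same pair of flags is threaded through `RxTAlphaComponentBase.__init__` into the `_init_model` hook that each concrete component overrides. A minimal sketch of that pass-through pattern, using simplified stand-in classes rather than the real rxnn constructors:

```python
import torch.nn as nn

class ToyComponentBase(nn.Module):
    """Simplified stand-in for RxTAlphaComponentBase: forwards the head-norm
    flags to the _init_model hook implemented by subclasses."""

    def __init__(self, embed_dim: int, vocab_size: int,
                 use_head_norm: bool = False, init_identity_norm: bool = False):
        super().__init__()
        self.model = self._init_model(
            embed_dim, vocab_size,
            use_head_norm=use_head_norm, init_identity_norm=init_identity_norm,
        )

    def _init_model(self, embed_dim: int, vocab_size: int,
                    use_head_norm: bool = False, init_identity_norm: bool = False) -> nn.Module:
        raise NotImplementedError

class ToyDecoderComponent(ToyComponentBase):
    """Stand-in for RxTAlphaDecoder: builds the inner model with the flags applied."""

    def _init_model(self, embed_dim, vocab_size, use_head_norm=False, init_identity_norm=False):
        head = nn.Linear(embed_dim, vocab_size)
        if not use_head_norm:
            return head
        norm = nn.LayerNorm(embed_dim)
        if init_identity_norm:
            norm.weight.data.fill_(1.0)
            norm.bias.data.fill_(0.0)
        return nn.Sequential(norm, head)

decoder = ToyDecoderComponent(32, 100, use_head_norm=True, init_identity_norm=True)
print(decoder.model)
```

In the real classes the hook returns a `ReactiveTransformerDecoder`/`ReactiveTransformerEncoder`; the sketch only mirrors how the keyword arguments travel.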
rxnn/training/mrl.py
CHANGED
@@ -243,20 +243,31 @@ class MRLTrainer:
             critic_weight_decay: float,
             critic_encoder_lr: float,
             embedding_lr: float,
+            encoder_lr: float,
             memory_lr: Optional[float] = None,
+            encoder_memory_lr: Optional[float] = None,
+            memory_attn_lr: Optional[float] = None,
     ) -> tuple[torch.optim.Optimizer, torch.optim.Optimizer]:
         if memory_lr is not None:
             optimizer = torch.optim.AdamW([
                 {'params': self.actor.embedding_parameters(), 'lr': embedding_lr},
-                {'params': self.actor.not_memory_parameters(), 'lr':
-                {'params': self.actor.memory_parameters(), 'lr':
+                {'params': self.actor.encoder.not_memory_parameters(), 'lr': encoder_lr},
+                {'params': self.actor.encoder.memory_parameters(), 'lr': encoder_memory_lr},
+                {'params': self.actor.memory_attention_parameters(), 'lr': memory_attn_lr},
+                {'params': self.actor.decoder.memory_parameters(), 'lr': memory_lr},
+                {'params': self.actor.decoder.not_memory_parameters(), 'lr': lr},
             ],
                 weight_decay=weight_decay,
             )
         else:
             optimizer = torch.optim.AdamW([
                 {'params': self.actor.embedding_parameters(), 'lr': embedding_lr},
-                {'params': self.actor.
+                {'params': self.actor.embedding_parameters(), 'lr': embedding_lr},
+                {'params': self.actor.encoder.not_memory_parameters(), 'lr': encoder_lr},
+                {'params': self.actor.encoder.memory_parameters(), 'lr': encoder_lr},
+                {'params': self.actor.memory_attention_parameters(), 'lr': lr},
+                {'params': self.actor.decoder.memory_parameters(), 'lr': lr},
+                {'params': self.actor.decoder.not_memory_parameters(), 'lr': lr},
             ],
                 weight_decay=weight_decay,
             )
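The new arguments give each actor component its own learning rate through AdamW parameter groups. A minimal, self-contained sketch of that pattern with a toy actor (the module and attribute names below are illustrative, not the real `MRLTrainer` API):

```python
import torch
import torch.nn as nn

class ToyActor(nn.Module):
    """Toy stand-in for the MRL actor, split into separately tunable parts."""
    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(100, 32)
        self.encoder = nn.Linear(32, 32)
        self.decoder = nn.Linear(32, 32)
        self.memory_attention = nn.MultiheadAttention(32, 4, batch_first=True)

actor = ToyActor()
lr, embedding_lr, encoder_lr, memory_attn_lr = 3e-4, 1e-4, 2e-4, 5e-4

# One AdamW instance with several parameter groups; each group carries its own
# learning rate, while the weight_decay passed at construction applies to every
# group that does not override it.
optimizer = torch.optim.AdamW([
    {'params': actor.embedding.parameters(), 'lr': embedding_lr},
    {'params': actor.encoder.parameters(), 'lr': encoder_lr},
    {'params': actor.memory_attention.parameters(), 'lr': memory_attn_lr},
    {'params': actor.decoder.parameters(), 'lr': lr},
], weight_decay=0.01)

for group in optimizer.param_groups:
    print(group['lr'], sum(p.numel() for p in group['params']))
```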
@@ -591,14 +602,17 @@ class MRLTrainer:
         print(f"Decoder grad norm - total: {decoder_total:.6f}, mean: {decoder_mean:.6f}")
         print(f"Memory attention grad norm - total: {mem_att_total:.6f}, mean: {mem_att_mean:.6f}")
         # decoder's cross att
-        dec_x_att_norms = [get_gradient_norms(layer.memory_cross_attention)[
-        print(f"Decoder cross-att mean
+        dec_x_att_norms = [get_gradient_norms(layer.memory_cross_attention)[1] for layer in self.actor.decoder.model.layers]
+        print(f"Decoder cross-att mean norm: {(sum(dec_x_att_norms) / len(dec_x_att_norms)):.6f}, all: {dec_x_att_norms}")
+
+        mem_att_norms = [get_gradient_norms(layer)[1] for layer in self.actor.memory_attention.model.attention_layers]
+        print(f"Memory attention layers mean norm: {(sum(mem_att_norms) / len(mem_att_norms)):.6f}, all: {mem_att_norms}")

-
-        print(f"
+        enc_ff_norms = [get_gradient_norms(layer.ff)[1] for layer in self.actor.encoder.model.layers]
+        print(f"Encoder ff mean norm: {(sum(enc_ff_norms) / len(enc_ff_norms)):.6f}, all: {enc_ff_norms}")

-        enc_ff_norms = [get_gradient_norms(layer.
-        print(f"Encoder
+        enc_ff_norms = [get_gradient_norms(layer.memory_cross_attention)[1] for layer in self.actor.encoder.model.layers]
+        print(f"Encoder cross-att mean norm: {(sum(enc_ff_norms) / len(enc_ff_norms)):.6f}, all: {enc_ff_norms}")

     def update_actor(self, state: tuple[TokenizedDict, TokenizedDict, TokenizedDict], action: TokenizedDict,
                      advantages: torch.Tensor, old_log_probs: torch.Tensor, epoch: int) -> float:
rxnn/training/utils.py
CHANGED
@@ -146,10 +146,10 @@ def smart_concat(query: TokenizedDict, answer: TokenizedDict, max_length: int, p

 def get_gradient_norms(model: nn.Module):
     total_norm = 0
-
-
-
-
+    grad_params = list(filter(lambda p: p.requires_grad and p.grad is not None, model.parameters()))
+    for p in grad_params:
+        param_norm = p.grad.data.norm(2)
+        total_norm += param_norm.item() ** 2
     total_norm = total_norm ** 0.5
-    mean_norm = total_norm / len(
+    mean_norm = total_norm / len(grad_params)
     return total_norm, mean_norm
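For reference, the updated helper can be exercised on its own. The function body below copies the new implementation from this hunk; the toy model and driver code around it are illustrative:

```python
import torch
import torch.nn as nn

def get_gradient_norms(model: nn.Module):
    # Same logic as the updated rxnn/training/utils.py helper: accumulate the
    # squared L2 norms of all gradient-carrying parameters, take the square
    # root, and divide the total by the number of those parameter tensors.
    total_norm = 0
    grad_params = list(filter(lambda p: p.requires_grad and p.grad is not None, model.parameters()))
    for p in grad_params:
        param_norm = p.grad.data.norm(2)
        total_norm += param_norm.item() ** 2
    total_norm = total_norm ** 0.5
    mean_norm = total_norm / len(grad_params)
    return total_norm, mean_norm

# Tiny model and one backward pass so gradients exist before measuring.
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
model(torch.randn(2, 8)).sum().backward()

total, mean = get_gradient_norms(model)
print(f"grad norm - total: {total:.6f}, mean: {mean:.6f}")
```

This is the helper the new MRL diagnostics call per layer (e.g. `get_gradient_norms(layer.memory_cross_attention)[1]`) to log mean cross-attention and feed-forward gradient norms.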
rxnn/transformers/models.py
CHANGED
@@ -69,9 +69,18 @@ class ReactiveTransformerBase(nn.Module):
 class ReactiveTransformerDecoder(ReactiveTransformerBase):
     """Reactive Transformer decoder - extending the classic Transformer decoder with Memory Cross-Attention"""

-    def __init__(self, embed_dim: int, vocab_size: int, *args, **kwargs):
+    def __init__(self, embed_dim: int, vocab_size: int, use_head_norm: bool = False, init_identity_norm: bool = False, *args, **kwargs):
         super(ReactiveTransformerDecoder, self).__init__(*args, **kwargs)
+
         self.head = nn.Linear(embed_dim, vocab_size)
+        self.use_head_norm = use_head_norm
+        if use_head_norm:
+            self.head_norm = nn.LayerNorm(embed_dim)
+            if init_identity_norm:
+                self.head_norm.weight.data.fill_(1.0)
+                self.head_norm.bias.data.fill_(0.0)
+        else:
+            self.head_norm = None

     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
         x = super().forward(x)  # apply embeddings
@@ -99,7 +108,7 @@ class ReactiveTransformerDecoder(ReactiveTransformerBase):
             if layer_stm.size(0) == 1:
                 layer_stm = layer_stm.expand(x.size(0), -1, -1)
             x = self.layers[i](x, layer_stm, mask=mask)
-        return self.head(x)
+        return self.head(self.head_norm(x) if self.use_head_norm else x)


 class ReactiveTransformerEncoder(ReactiveTransformerBase):
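The head-norm change itself is small: an optional `nn.LayerNorm` applied right before the vocabulary projection, with an option to (re)initialize its affine parameters to unit scale and zero shift. A minimal standalone sketch of just that mechanism (the surrounding class is illustrative, not the full decoder):

```python
import torch
import torch.nn as nn

class HeadWithOptionalNorm(nn.Module):
    """Illustrative module: optional LayerNorm followed by the LM head,
    matching the pattern added to the decoder classes in this diff."""

    def __init__(self, embed_dim: int, vocab_size: int,
                 use_head_norm: bool = False, init_identity_norm: bool = False):
        super().__init__()
        self.head = nn.Linear(embed_dim, vocab_size)
        self.use_head_norm = use_head_norm
        if use_head_norm:
            self.head_norm = nn.LayerNorm(embed_dim)
            if init_identity_norm:
                # Unit scale and zero shift for the affine part (this matches
                # LayerNorm's default init; the norm still standardizes features).
                self.head_norm.weight.data.fill_(1.0)
                self.head_norm.bias.data.fill_(0.0)
        else:
            self.head_norm = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.head(self.head_norm(x) if self.use_head_norm else x)

logits = HeadWithOptionalNorm(32, 100, use_head_norm=True)(torch.randn(2, 16, 32))
print(logits.shape)  # torch.Size([2, 16, 100])
```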
@@ -201,9 +210,17 @@ class ClassicTransformerBase(nn.Module):
 class ClassicTransformerDecoder(ClassicTransformerBase):
     """Classic Transformer decoder - for decoder-only Transformer models"""

-    def __init__(self, embed_dim: int, vocab_size: int, *args, **kwargs):
+    def __init__(self, embed_dim: int, vocab_size: int, use_head_norm: bool = False, init_identity_norm: bool = False, *args, **kwargs):
         super(ClassicTransformerDecoder, self).__init__(*args, **kwargs)
         self.head = nn.Linear(embed_dim, vocab_size)
+        self.use_head_norm = use_head_norm
+        if use_head_norm:
+            self.head_norm = nn.LayerNorm(embed_dim)
+            if init_identity_norm:
+                self.head_norm.weight.data.fill_(1.0)
+                self.head_norm.bias.data.fill_(0.0)
+        else:
+            self.head_norm = None

     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
         x = super().forward(x)  # apply embeddings
@@ -213,7 +230,6 @@ class ClassicTransformerDecoder(ClassicTransformerBase):
             if attention_mask is not None:
                 mask &= attention_mask.unsqueeze(1).unsqueeze(1).bool()
         elif attention_mask is not None:
-            print(attention_mask.size())
             mask = attention_mask.unsqueeze(1).unsqueeze(1).bool()
         else:
             mask = None
@@ -221,7 +237,7 @@ class ClassicTransformerDecoder(ClassicTransformerBase):
         # Process layers
         for i in range(self.num_layers):
             x = self.layers[i](x, mask=mask)
-        return self.head(x)
+        return self.head(self.head_norm(x) if self.use_head_norm else x)


 class ClassicTransformerEncoder(ClassicTransformerBase):
{rxnn-0.2.62.dist-info → rxnn-0.2.64.dist-info}/RECORD
CHANGED
@@ -2,14 +2,14 @@ rxnn/.DS_Store,sha256=BxZLo9tFs48JMq6jhumiCnCPLTeCwl619CFSg4ClRAY,6148
 rxnn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/experimental/attention.py,sha256=jlNS82INjycNEfmk3HtkIacUvT_ELhaCO2g-kZTvhXg,28455
-rxnn/experimental/models.py,sha256=
+rxnn/experimental/models.py,sha256=KheR1zSNJIaeVvpVAkEJwcuM5nOqQP0ZF08XhrtGJ8E,5387
 rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/memory/attention.py,sha256=lSniKrf_skiM1V1zbfmV84PbKoQ-t_fVcKfwNKW3_OY,3844
 rxnn/memory/norm.py,sha256=cVjjhCLqR5K6-321SP_ObG17y-ddlcTJeCTXvW4vpk0,6675
 rxnn/memory/stm.py,sha256=jv57gsH9XW19sLbxpRDqsp1yfsii_4Ef4Ncr_ztk-i4,3937
 rxnn/rxt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/rxt/models.py,sha256=
+rxnn/rxt/models.py,sha256=JrZQ78F4HGGklAy6mML4fbqdsMOcGSDRZpjhX55VXb8,15486
 rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/training/base.py,sha256=CqaArEZYOdH64nmKfx28U3GI46TzO4oNkjf_hrF23Cw,11835
 rxnn/training/bml.py,sha256=hw6gLpLkGvqLzxIvBg4MvCc5r8cHpEm2RDyh7nH6CtE,16914
@@ -17,23 +17,23 @@ rxnn/training/callbacks.py,sha256=rS8leuVFPVVfE5Zc8DMkUZhRIPN-vpPbUjowXE5TSBw,36
 rxnn/training/dataset.py,sha256=tbtOSYldHnQB6SWgee_yUj9zTbgoEoLFNa6wvUS6Apg,51292
 rxnn/training/ddp.py,sha256=VsNBjn3cY-uUj8hbsW7oKvb0_ZKnXnJ2KgObm-Mr9i4,836
 rxnn/training/models.py,sha256=KIiOCW0VgKtMA4EMQ---xsVExdI1mBsgWjtRSmJpecA,9033
-rxnn/training/mrl.py,sha256=
+rxnn/training/mrl.py,sha256=2J6Wh4xtsVoE6duEevmovDpmSsMkEoH39Ru0bE8lhFo,65481
 rxnn/training/reward.py,sha256=uiSsBXmjMw2yv-1Bssy3RTlpU6zP8ape3490Sl-aT0M,16144
 rxnn/training/rl.py,sha256=hWtExxY-_pAmTOGYxyCNounUbaGWvLDVltC4sRC7MN4,7175
 rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
 rxnn/training/tokenizer.py,sha256=umaLByMBx_NMrQElA45HLm9gkuzyKWDTFaKVd-CjXl0,8344
-rxnn/training/utils.py,sha256=
+rxnn/training/utils.py,sha256=QMNkJPQBY04DX9WN7GHnI2EZTBbAzWkjt2W-798oUII,6129
 rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/transformers/attention.py,sha256=KRnKT6XUqAXElxV9y72mSpdTeiMgCKCCLqqxCFNTHmA,16372
 rxnn/transformers/ff.py,sha256=WDjO-H9XWInoWnUnxiseIH6Kx5GlHP0zGJygwhcb1gc,2589
 rxnn/transformers/layers.py,sha256=OlbqD5kKygn5WZziLbU3jZjhr8hBrxLpqlCjJ_BNCW0,8119
 rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
-rxnn/transformers/models.py,sha256=
+rxnn/transformers/models.py,sha256=7xKixlBN5uZckcOXfukZeA4f7R34A35Gk98lNzew37o,11559
 rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
 rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
 rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
-rxnn-0.2.
-rxnn-0.2.
-rxnn-0.2.
-rxnn-0.2.
+rxnn-0.2.64.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.2.64.dist-info/METADATA,sha256=QaDWd-8W0vs3povCgRAUXXSmTNN8gkEJ1dY6mA7n9kQ,25997
+rxnn-0.2.64.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+rxnn-0.2.64.dist-info/RECORD,,
{rxnn-0.2.62.dist-info → rxnn-0.2.64.dist-info}/LICENSE
File without changes
{rxnn-0.2.62.dist-info → rxnn-0.2.64.dist-info}/WHEEL
File without changes