rxnn 0.2.57__py3-none-any.whl → 0.2.59__py3-none-any.whl
This diff shows the changes between publicly released package versions as they appear in their public registry and is provided for informational purposes only.
- rxnn/memory/attention.py +5 -2
- rxnn/rxt/models.py +4 -4
- rxnn/training/models.py +1 -1
- rxnn/transformers/layers.py +5 -1
- rxnn/transformers/models.py +2 -0
- {rxnn-0.2.57.dist-info → rxnn-0.2.59.dist-info}/METADATA +1 -1
- {rxnn-0.2.57.dist-info → rxnn-0.2.59.dist-info}/RECORD +9 -9
- {rxnn-0.2.57.dist-info → rxnn-0.2.59.dist-info}/LICENSE +0 -0
- {rxnn-0.2.57.dist-info → rxnn-0.2.59.dist-info}/WHEEL +0 -0
rxnn/memory/attention.py
CHANGED
@@ -47,7 +47,10 @@ class StmMemoryAttention(nn.Module):
         else:
             return layer_gate * new_layer_stm + (1 - layer_gate) * layer_stm

-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
+        if attention_mask is not None:
+            print(attention_mask.size())
+            attention_mask = attention_mask.unsqueeze(1).unsqueeze(1).bool()
         new_stm = torch.zeros_like(self.stm.memory)
         for i in range(self.num_layers):
             layer_stm = self.stm(i)
@@ -56,7 +59,7 @@ class StmMemoryAttention(nn.Module):
             layer_stm = layer_stm.expand(x.size(0), -1, -1)
             encoded_layer_data = x[i]
             normalized_layer_stm = self.memory_norm_layers[i](layer_stm)
-            new_layer_stm = self.attention_layers[i](normalized_layer_stm, encoded_layer_data, encoded_layer_data)
+            new_layer_stm = self.attention_layers[i](normalized_layer_stm, encoded_layer_data, encoded_layer_data, mask=attention_mask)
             if self.use_gated_residual:
                 new_stm[i] = self._residual_gate(self.gate[i], layer_stm, new_layer_stm)  # gated residual
             else:
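For context, the reshaped mask is meant to broadcast against per-head attention scores so that padded encoder positions cannot contribute to the short-term memory update. Below is a minimal, self-contained sketch of that broadcast; it is not part of the package, the shapes and names are illustrative, and it uses PyTorch's scaled_dot_product_attention rather than rxnn's own attention modules.

import torch
import torch.nn.functional as F

# Illustrative sizes only: batch of 2, 4 heads, 6 STM slots as queries,
# encoded sequence of 10 tokens as keys/values, head dim 8.
B, H, SLOTS, SEQ, HEAD_DIM = 2, 4, 6, 10, 8

q = torch.randn(B, H, SLOTS, HEAD_DIM)   # queries built from the STM layer
k = torch.randn(B, H, SEQ, HEAD_DIM)     # keys from the encoded layer data
v = torch.randn(B, H, SEQ, HEAD_DIM)     # values from the encoded layer data

# A [B, SEQ] padding mask (1 = real token, 0 = padding); the last three
# positions of the second example are padding.
attention_mask = torch.ones(B, SEQ, dtype=torch.long)
attention_mask[1, -3:] = 0

# Same reshape as in the diff: [B, SEQ] -> [B, 1, 1, SEQ], which broadcasts
# over heads and over the query (STM slot) dimension.
mask = attention_mask.unsqueeze(1).unsqueeze(1).bool()

out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
print(out.shape)  # torch.Size([2, 4, 6, 8]); padded tokens are never attended to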
rxnn/rxt/models.py
CHANGED
@@ -103,13 +103,13 @@ class RxTAlphaComponentBase(nn.Module, PyTorchModelHubMixin):
         if cross_att_type in ['mha', 'gqa', 'mqa']:
             cross_att_init = lambda: init_attention(embed_dim, att_heads, cross_att_type, att_groups, rope=rope,
                                                     use_flash_attention=use_flash_attention, dropout=att_dropout,
-                                                    max_seq_len=seq_len, is_causal=
+                                                    max_seq_len=seq_len, is_causal=False, rope_only_for_query=True)
         else:
             cross_att_init = lambda: init_experimental_attention(embed_dim, att_heads, cross_att_type,
                                                                  cross_att_groups or att_groups, rope=rope,
                                                                  use_flash_attention=use_flash_attention,
                                                                  dropout=att_dropout,
-                                                                 max_seq_len=seq_len, is_causal=
+                                                                 max_seq_len=seq_len, is_causal=False,
                                                                  num_experts=att_experts,
                                                                  num_query_experts=att_query_experts,
                                                                  num_query_groups=cross_att_query_groups or att_query_groups,
@@ -306,8 +306,8 @@ class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2
     def clone_reset_memory(self):
         self.model.stm.clone_detach_reset()

-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return self.model(x)
+    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
+        return self.model(x, attention_mask=attention_mask)

 class RxTAlphaCriticEncoder(RxTAlphaComponentBase, pipeline_tag="text-classification", license="apache-2.0"):
     """RxT-Alpha (Reactive Transformer) encoder model"""
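The cross-attention layers are now explicitly initialized with is_causal=False: their queries and keys/values come from different sequences (token states on one side, memory slots on the other), so a causal triangle over positions has no meaning there. A minimal sketch of non-causal cross-attention with mismatched query and key lengths, using PyTorch's scaled_dot_product_attention and illustrative shapes that are not taken from the package:

import torch
import torch.nn.functional as F

# Illustrative sizes only: 10 query positions attending to 6 key/value slots.
B, H, Q_LEN, KV_LEN, HEAD_DIM = 2, 4, 10, 6, 8

q = torch.randn(B, H, Q_LEN, HEAD_DIM)
k = torch.randn(B, H, KV_LEN, HEAD_DIM)
v = torch.randn(B, H, KV_LEN, HEAD_DIM)

# Non-causal: every query position may attend to every key/value position,
# which is the behavior selected by is_causal=False in the diff above.
out = F.scaled_dot_product_attention(q, k, v, is_causal=False)
print(out.shape)  # torch.Size([2, 4, 10, 8])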
rxnn/training/models.py
CHANGED
@@ -204,7 +204,7 @@ class MrlActorModel(nn.Module):
             return self.decoder(x, attention_mask=attention_mask)
         else:
             _, ed = self.encoder(x, attention_mask=attention_mask)
-            return self.memory_attention(ed)
+            return self.memory_attention(ed, attention_mask=attention_mask)


 class MrlCriticModel(nn.Module, PyTorchModelHubMixin, license="apache-2.0", pipeline_tag="text-classification"):
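The actor change simply reuses the padding mask that the encoder consumed when it calls memory attention, so the memory update sees the same notion of which positions are real tokens. A hypothetical, stripped-down stand-in for that routing (names and structure are illustrative, not the package's classes):

import torch
import torch.nn as nn

class ToyActor(nn.Module):
    """Minimal sketch: encode, then update memory with the same padding mask."""

    def __init__(self, encoder: nn.Module, memory_attention: nn.Module):
        super().__init__()
        self.encoder = encoder
        self.memory_attention = memory_attention

    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
        _, ed = self.encoder(x, attention_mask=attention_mask)            # per-layer encoded states
        return self.memory_attention(ed, attention_mask=attention_mask)   # same mask reused here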
rxnn/transformers/layers.py
CHANGED
@@ -110,7 +110,11 @@ class ReactiveTransformerLayer(nn.Module):
         residual = x
         if not self.use_post_norm:
             x = self.norm2(x)
-        x = self.memory_cross_attention(x, stm, stm)
+
+        if mask is not None:
+            mem_mask = mask.squeeze(1).unsqueeze(-1).expand(-1, -1, -1, stm.size(1))
+
+        x = self.memory_cross_attention(x, stm, stm, mask=mem_mask)
         x = residual + x
         if self.use_post_norm:
             x = self.norm2(x)
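The reshaping above turns the [batch, 1, 1, seq_len] self-attention padding mask into a [batch, 1, seq_len, stm_size] cross-attention mask, so each token (query) position is masked against every memory slot (key) position. A small shape walkthrough with illustrative sizes that are not taken from the package:

import torch

B, SEQ, SLOTS = 2, 10, 6                            # illustrative sizes only

mask = torch.ones(B, 1, 1, SEQ, dtype=torch.bool)   # self-attention padding mask
stm = torch.randn(B, SLOTS, 32)                     # expanded STM for one layer

mem_mask = mask.squeeze(1)                            # [B, 1, SEQ]
mem_mask = mem_mask.unsqueeze(-1)                     # [B, 1, SEQ, 1]
mem_mask = mem_mask.expand(-1, -1, -1, stm.size(1))   # [B, 1, SEQ, SLOTS]

print(mem_mask.shape)  # torch.Size([2, 1, 10, 6])
# SEQ indexes the query (token) positions and SLOTS the key/value (memory)
# positions, so padded query tokens are excluded against every memory slot.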
rxnn/transformers/models.py
CHANGED
@@ -108,6 +108,7 @@ class ReactiveTransformerEncoder(ReactiveTransformerBase):
     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> tuple[torch.Tensor, torch.Tensor]:
         x = super().forward(x)  # apply embeddings
         if attention_mask is not None:
+            print(attention_mask.size())
             attention_mask = attention_mask.unsqueeze(1).unsqueeze(1).bool()

         hidden_states = []
@@ -213,6 +214,7 @@ class ClassicTransformerDecoder(ClassicTransformerBase):
             if attention_mask is not None:
                 mask &= attention_mask.unsqueeze(1).unsqueeze(1).bool()
         elif attention_mask is not None:
+            print(attention_mask.size())
             mask = attention_mask.unsqueeze(1).unsqueeze(1).bool()
         else:
             mask = None
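In the decoder branch above, the padding mask is combined with an existing causal mask via a boolean AND, so each query row keeps only earlier positions that are also real, non-padded tokens. A minimal sketch of that combination, with illustrative sizes and a locally built causal mask (the package builds its own; this is not its implementation):

import torch

B, SEQ = 2, 5                                        # illustrative sizes only

# Lower-triangular causal mask, broadcastable over batch and heads.
causal = torch.tril(torch.ones(SEQ, SEQ, dtype=torch.bool)).view(1, 1, SEQ, SEQ)

attention_mask = torch.ones(B, SEQ, dtype=torch.long)
attention_mask[1, -2:] = 0                           # last two tokens of example 2 are padding

mask = causal & attention_mask.unsqueeze(1).unsqueeze(1).bool()
print(mask.shape)  # torch.Size([2, 1, 5, 5])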
{rxnn-0.2.57.dist-info → rxnn-0.2.59.dist-info}/RECORD
CHANGED
@@ -5,18 +5,18 @@ rxnn/experimental/attention.py,sha256=jlNS82INjycNEfmk3HtkIacUvT_ELhaCO2g-kZTvhX
 rxnn/experimental/models.py,sha256=oJWd56LUsLc9S8eCZw-ShvuWjoQxj4C9GitbohlQ0ok,5139
 rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/memory/attention.py,sha256=
+rxnn/memory/attention.py,sha256=Mt9LC0DK5asjN4pTqWOLDAoGzAgg3FhQoUw6baDJ3NI,3309
 rxnn/memory/norm.py,sha256=cVjjhCLqR5K6-321SP_ObG17y-ddlcTJeCTXvW4vpk0,6675
 rxnn/memory/stm.py,sha256=jv57gsH9XW19sLbxpRDqsp1yfsii_4Ef4Ncr_ztk-i4,3937
 rxnn/rxt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/rxt/models.py,sha256=
+rxnn/rxt/models.py,sha256=jd1UVBUWJzWdw7Rjcvo9k5BXCJriQ0khuVszqEyfD7M,14665
 rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/training/base.py,sha256=CqaArEZYOdH64nmKfx28U3GI46TzO4oNkjf_hrF23Cw,11835
 rxnn/training/bml.py,sha256=hw6gLpLkGvqLzxIvBg4MvCc5r8cHpEm2RDyh7nH6CtE,16914
 rxnn/training/callbacks.py,sha256=rS8leuVFPVVfE5Zc8DMkUZhRIPN-vpPbUjowXE5TSBw,36779
 rxnn/training/dataset.py,sha256=tbtOSYldHnQB6SWgee_yUj9zTbgoEoLFNa6wvUS6Apg,51292
 rxnn/training/ddp.py,sha256=VsNBjn3cY-uUj8hbsW7oKvb0_ZKnXnJ2KgObm-Mr9i4,836
-rxnn/training/models.py,sha256=
+rxnn/training/models.py,sha256=KIiOCW0VgKtMA4EMQ---xsVExdI1mBsgWjtRSmJpecA,9033
 rxnn/training/mrl.py,sha256=H2JcamaJv19vKqOgdoyhcCBwu1lb_aKfCmR_MuuvmS0,62085
 rxnn/training/reward.py,sha256=uiSsBXmjMw2yv-1Bssy3RTlpU6zP8ape3490Sl-aT0M,16144
 rxnn/training/rl.py,sha256=hWtExxY-_pAmTOGYxyCNounUbaGWvLDVltC4sRC7MN4,7175
@@ -26,14 +26,14 @@ rxnn/training/utils.py,sha256=C0OS2RAGQ3L7D_G3CWupu_BpAFhkovMByBKm355Ibfc,6087
 rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/transformers/attention.py,sha256=KRnKT6XUqAXElxV9y72mSpdTeiMgCKCCLqqxCFNTHmA,16372
 rxnn/transformers/ff.py,sha256=WDjO-H9XWInoWnUnxiseIH6Kx5GlHP0zGJygwhcb1gc,2589
-rxnn/transformers/layers.py,sha256=
+rxnn/transformers/layers.py,sha256=OlbqD5kKygn5WZziLbU3jZjhr8hBrxLpqlCjJ_BNCW0,8119
 rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
-rxnn/transformers/models.py,sha256=
+rxnn/transformers/models.py,sha256=bHR5gy74aV20Sl0371vVCvb3Z2pcqYTdOSnrLSkAIiI,10802
 rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
 rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
 rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
-rxnn-0.2.
-rxnn-0.2.
-rxnn-0.2.
-rxnn-0.2.
+rxnn-0.2.59.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.2.59.dist-info/METADATA,sha256=yj3LmcYzHTe7O5Y3LuDkCbjoTcY01WRQxaPQ9G4a8GA,25997
+rxnn-0.2.59.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+rxnn-0.2.59.dist-info/RECORD,,
{rxnn-0.2.57.dist-info → rxnn-0.2.59.dist-info}/LICENSE
File without changes
{rxnn-0.2.57.dist-info → rxnn-0.2.59.dist-info}/WHEEL
File without changes