rxnn 0.2.57__py3-none-any.whl → 0.2.58__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rxnn/memory/attention.py +3 -2
- rxnn/rxt/models.py +2 -2
- rxnn/training/models.py +1 -1
- rxnn/transformers/layers.py +5 -1
- {rxnn-0.2.57.dist-info → rxnn-0.2.58.dist-info}/METADATA +1 -1
- {rxnn-0.2.57.dist-info → rxnn-0.2.58.dist-info}/RECORD +8 -8
- {rxnn-0.2.57.dist-info → rxnn-0.2.58.dist-info}/LICENSE +0 -0
- {rxnn-0.2.57.dist-info → rxnn-0.2.58.dist-info}/WHEEL +0 -0
rxnn/memory/attention.py
CHANGED
@@ -47,7 +47,8 @@ class StmMemoryAttention(nn.Module):
         else:
             return layer_gate * new_layer_stm + (1 - layer_gate) * layer_stm

-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
+        mem_mask = attention_mask.unsqueeze(1).unsqueeze(1).bool() if attention_mask else None
         new_stm = torch.zeros_like(self.stm.memory)
         for i in range(self.num_layers):
             layer_stm = self.stm(i)
@@ -56,7 +57,7 @@ class StmMemoryAttention(nn.Module):
             layer_stm = layer_stm.expand(x.size(0), -1, -1)
             encoded_layer_data = x[i]
             normalized_layer_stm = self.memory_norm_layers[i](layer_stm)
-            new_layer_stm = self.attention_layers[i](normalized_layer_stm, encoded_layer_data, encoded_layer_data)
+            new_layer_stm = self.attention_layers[i](normalized_layer_stm, encoded_layer_data, encoded_layer_data, mask=mem_mask)
             if self.use_gated_residual:
                 new_stm[i] = self._residual_gate(self.gate[i], layer_stm, new_layer_stm)  # gated residual
             else:
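For readers following the change: the new attention_mask handling reshapes a standard padding mask of shape [batch, seq_len] into [batch, 1, 1, seq_len] so it can broadcast over attention heads and query positions. In this module normalized_layer_stm is passed first (as the query) and encoded_layer_data as key/value, so masking the last (key) axis hides padded tokens from every memory slot. A minimal shape sketch follows; all sizes and tensors are made up for illustration and are not taken from the package:

import torch

# Illustrative sizes only (not taken from the package).
batch, seq_len, num_heads, stm_slots = 2, 6, 4, 8

# 1 = real token, 0 = padding, as produced by typical tokenizers.
attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0],
                               [1, 1, 1, 0, 0, 0]])

# Same reshaping as in the diff: [batch, seq_len] -> [batch, 1, 1, seq_len].
mem_mask = attention_mask.unsqueeze(1).unsqueeze(1).bool()
print(mem_mask.shape)  # torch.Size([2, 1, 1, 6])

# It broadcasts against attention scores of shape
# [batch, num_heads, stm_slots, seq_len], masking padded key positions before softmax.
scores = torch.randn(batch, num_heads, stm_slots, seq_len)
scores = scores.masked_fill(~mem_mask, float('-inf'))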
rxnn/rxt/models.py
CHANGED
@@ -306,8 +306,8 @@ class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2
     def clone_reset_memory(self):
         self.model.stm.clone_detach_reset()

-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return self.model(x)
+    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
+        return self.model(x, attention_mask=attention_mask)

 class RxTAlphaCriticEncoder(RxTAlphaComponentBase, pipeline_tag="text-classification", license="apache-2.0"):
     """RxT-Alpha (Reactive Transformer) encoder model"""
rxnn/training/models.py
CHANGED
@@ -204,7 +204,7 @@ class MrlActorModel(nn.Module):
             return self.decoder(x, attention_mask=attention_mask)
         else:
             _, ed = self.encoder(x, attention_mask=attention_mask)
-            return self.memory_attention(ed)
+            return self.memory_attention(ed, attention_mask=attention_mask)


 class MrlCriticModel(nn.Module, PyTorchModelHubMixin, license="apache-2.0", pipeline_tag="text-classification"):
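Both wrapper changes (RxTAlphaMemoryAttention.forward above and this MrlActorModel branch) simply thread the same attention_mask down to StmMemoryAttention. A rough, self-contained sketch of that call pattern; the encoder and memory_attention callables stand in for the actual modules and are assumptions, not package API:

import torch
from typing import Callable, Tuple

def memory_update_step(
    encoder: Callable[..., Tuple[torch.Tensor, torch.Tensor]],  # assumed to return (logits, encoded layer data)
    memory_attention: Callable[..., torch.Tensor],              # accepts attention_mask as of this release
    input_ids: torch.Tensor,        # [batch, seq_len]
    attention_mask: torch.Tensor,   # [batch, seq_len], 1 = token, 0 = padding
) -> torch.Tensor:
    # Mirrors the updated MrlActorModel branch: encode the interaction, then
    # update short-term memory while ignoring padded positions.
    _, encoded_layer_data = encoder(input_ids, attention_mask=attention_mask)
    return memory_attention(encoded_layer_data, attention_mask=attention_mask)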
rxnn/transformers/layers.py
CHANGED
@@ -110,7 +110,11 @@ class ReactiveTransformerLayer(nn.Module):
         residual = x
         if not self.use_post_norm:
             x = self.norm2(x)
-        x = self.memory_cross_attention(x, stm, stm)
+
+        if mask is not None:
+            mem_mask = mask.squeeze(1).unsqueeze(-1).expand(-1, -1, -1, stm.size(1))
+
+        x = self.memory_cross_attention(x, stm, stm, mask=mem_mask)
         x = residual + x
         if self.use_post_norm:
             x = self.norm2(x)
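In ReactiveTransformerLayer the sequence x is passed as the query and stm as key/value, so the mask is reshaped to match attention scores whose query axis is the sequence and whose key axis is the STM slots. A shape-only sketch, assuming the incoming mask uses the common [batch, 1, 1, seq_len] self-attention layout; all sizes are made up for illustration:

import torch

batch, seq_len, stm_slots, dim = 2, 6, 8, 512

# Assumed input layout: broadcastable padding mask as used for self-attention.
mask = torch.randint(0, 2, (batch, 1, 1, seq_len), dtype=torch.bool)
stm = torch.randn(batch, stm_slots, dim)  # short-term memory (key/value side)

# Same reshaping as in the diff:
# [batch, 1, 1, seq_len] -> [batch, 1, seq_len] -> [batch, 1, seq_len, 1]
#                        -> [batch, 1, seq_len, stm_slots]
mem_mask = mask.squeeze(1).unsqueeze(-1).expand(-1, -1, -1, stm.size(1))
print(mem_mask.shape)  # torch.Size([2, 1, 6, 8])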
{rxnn-0.2.57.dist-info → rxnn-0.2.58.dist-info}/RECORD
CHANGED
@@ -5,18 +5,18 @@ rxnn/experimental/attention.py,sha256=jlNS82INjycNEfmk3HtkIacUvT_ELhaCO2g-kZTvhX
 rxnn/experimental/models.py,sha256=oJWd56LUsLc9S8eCZw-ShvuWjoQxj4C9GitbohlQ0ok,5139
 rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/memory/attention.py,sha256=
+rxnn/memory/attention.py,sha256=N78kzcqXXfcq5v43LUsLheVUs4JrKPeQissl-_XKXdk,3241
 rxnn/memory/norm.py,sha256=cVjjhCLqR5K6-321SP_ObG17y-ddlcTJeCTXvW4vpk0,6675
 rxnn/memory/stm.py,sha256=jv57gsH9XW19sLbxpRDqsp1yfsii_4Ef4Ncr_ztk-i4,3937
 rxnn/rxt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/rxt/models.py,sha256=
+rxnn/rxt/models.py,sha256=4MbCL4xGY3ceewZQmopjmwAyLQS92L6KLOPqaW7-Fho,14673
 rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/training/base.py,sha256=CqaArEZYOdH64nmKfx28U3GI46TzO4oNkjf_hrF23Cw,11835
 rxnn/training/bml.py,sha256=hw6gLpLkGvqLzxIvBg4MvCc5r8cHpEm2RDyh7nH6CtE,16914
 rxnn/training/callbacks.py,sha256=rS8leuVFPVVfE5Zc8DMkUZhRIPN-vpPbUjowXE5TSBw,36779
 rxnn/training/dataset.py,sha256=tbtOSYldHnQB6SWgee_yUj9zTbgoEoLFNa6wvUS6Apg,51292
 rxnn/training/ddp.py,sha256=VsNBjn3cY-uUj8hbsW7oKvb0_ZKnXnJ2KgObm-Mr9i4,836
-rxnn/training/models.py,sha256=
+rxnn/training/models.py,sha256=KIiOCW0VgKtMA4EMQ---xsVExdI1mBsgWjtRSmJpecA,9033
 rxnn/training/mrl.py,sha256=H2JcamaJv19vKqOgdoyhcCBwu1lb_aKfCmR_MuuvmS0,62085
 rxnn/training/reward.py,sha256=uiSsBXmjMw2yv-1Bssy3RTlpU6zP8ape3490Sl-aT0M,16144
 rxnn/training/rl.py,sha256=hWtExxY-_pAmTOGYxyCNounUbaGWvLDVltC4sRC7MN4,7175
@@ -26,14 +26,14 @@ rxnn/training/utils.py,sha256=C0OS2RAGQ3L7D_G3CWupu_BpAFhkovMByBKm355Ibfc,6087
 rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/transformers/attention.py,sha256=KRnKT6XUqAXElxV9y72mSpdTeiMgCKCCLqqxCFNTHmA,16372
 rxnn/transformers/ff.py,sha256=WDjO-H9XWInoWnUnxiseIH6Kx5GlHP0zGJygwhcb1gc,2589
-rxnn/transformers/layers.py,sha256=
+rxnn/transformers/layers.py,sha256=OlbqD5kKygn5WZziLbU3jZjhr8hBrxLpqlCjJ_BNCW0,8119
 rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
 rxnn/transformers/models.py,sha256=7ypPNFFnacdZjvaLVue1KR2PmMSdVYsbCMQSunXDL70,10720
 rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
 rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
 rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
-rxnn-0.2.
-rxnn-0.2.
-rxnn-0.2.
-rxnn-0.2.
+rxnn-0.2.58.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.2.58.dist-info/METADATA,sha256=_NKVaBMYEbJadkBcUvDX1UprMlUu92JqgnBYx7R1J1c,25997
+rxnn-0.2.58.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+rxnn-0.2.58.dist-info/RECORD,,
{rxnn-0.2.57.dist-info → rxnn-0.2.58.dist-info}/LICENSE
File without changes

{rxnn-0.2.57.dist-info → rxnn-0.2.58.dist-info}/WHEEL
File without changes