rxnn 0.2.58__py3-none-any.whl → 0.2.59__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rxnn/memory/attention.py +4 -2
- rxnn/rxt/models.py +2 -2
- rxnn/transformers/models.py +2 -0
- {rxnn-0.2.58.dist-info → rxnn-0.2.59.dist-info}/METADATA +1 -1
- {rxnn-0.2.58.dist-info → rxnn-0.2.59.dist-info}/RECORD +7 -7
- {rxnn-0.2.58.dist-info → rxnn-0.2.59.dist-info}/LICENSE +0 -0
- {rxnn-0.2.58.dist-info → rxnn-0.2.59.dist-info}/WHEEL +0 -0
rxnn/memory/attention.py
CHANGED
```diff
@@ -48,7 +48,9 @@ class StmMemoryAttention(nn.Module):
         return layer_gate * new_layer_stm + (1 - layer_gate) * layer_stm
 
     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
-
+        if attention_mask is not None:
+            print(attention_mask.size())
+            attention_mask = attention_mask.unsqueeze(1).unsqueeze(1).bool()
         new_stm = torch.zeros_like(self.stm.memory)
         for i in range(self.num_layers):
             layer_stm = self.stm(i)
@@ -57,7 +59,7 @@ class StmMemoryAttention(nn.Module):
             layer_stm = layer_stm.expand(x.size(0), -1, -1)
             encoded_layer_data = x[i]
             normalized_layer_stm = self.memory_norm_layers[i](layer_stm)
-            new_layer_stm = self.attention_layers[i](normalized_layer_stm, encoded_layer_data, encoded_layer_data, mask=
+            new_layer_stm = self.attention_layers[i](normalized_layer_stm, encoded_layer_data, encoded_layer_data, mask=attention_mask)
             if self.use_gated_residual:
                 new_stm[i] = self._residual_gate(self.gate[i], layer_stm, new_layer_stm)  # gated residual
             else:
```
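The substantive change above is that the padding mask passed to `forward` is reshaped with `unsqueeze(1).unsqueeze(1).bool()` and forwarded to each memory attention layer as `mask=attention_mask`. Below is a minimal sketch (not taken from the package; all shapes and values are hypothetical) of why that reshape lets a `(batch, seq_len)` padding mask broadcast over attention heads and memory-slot queries.

```python
import torch

# Sketch only: how the reshape used above lets a per-token padding mask
# broadcast over attention heads and memory slots. Shapes are hypothetical.
batch, heads, stm_slots, seq_len = 2, 4, 8, 6

# 1 = real token, 0 = padding
attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0],
                               [1, 1, 1, 0, 0, 0]])

# (batch, seq_len) -> (batch, 1, 1, seq_len), as in the diff above
mask = attention_mask.unsqueeze(1).unsqueeze(1).bool()

# Scores between memory slots (queries) and encoded tokens (keys/values)
scores = torch.randn(batch, heads, stm_slots, seq_len)

# The singleton dims broadcast over heads and query positions, so padded
# key positions can be masked out before the softmax.
masked_scores = scores.masked_fill(~mask, float('-inf'))
weights = torch.softmax(masked_scores, dim=-1)
print(weights.shape)  # torch.Size([2, 4, 8, 6])
```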
rxnn/rxt/models.py
CHANGED
```diff
@@ -103,13 +103,13 @@ class RxTAlphaComponentBase(nn.Module, PyTorchModelHubMixin):
         if cross_att_type in ['mha', 'gqa', 'mqa']:
             cross_att_init = lambda: init_attention(embed_dim, att_heads, cross_att_type, att_groups, rope=rope,
                                                     use_flash_attention=use_flash_attention, dropout=att_dropout,
-                                                    max_seq_len=seq_len, is_causal=
+                                                    max_seq_len=seq_len, is_causal=False, rope_only_for_query=True)
         else:
             cross_att_init = lambda: init_experimental_attention(embed_dim, att_heads, cross_att_type,
                                                                  cross_att_groups or att_groups, rope=rope,
                                                                  use_flash_attention=use_flash_attention,
                                                                  dropout=att_dropout,
-                                                                 max_seq_len=seq_len, is_causal=
+                                                                 max_seq_len=seq_len, is_causal=False,
                                                                  num_experts=att_experts,
                                                                  num_query_experts=att_query_experts,
                                                                  num_query_groups=cross_att_query_groups or att_query_groups,
```
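Both branches now fix `is_causal=False` for the cross-attention factories, and the standard-attention branch additionally sets `rope_only_for_query=True`. The sketch below illustrates the causal versus non-causal distinction using PyTorch's built-in `scaled_dot_product_attention` as a stand-in for the package's `init_attention`/`init_experimental_attention` helpers; it is an illustration under assumed shapes, not the package's implementation.

```python
import torch
import torch.nn.functional as F

# Sketch only: PyTorch's built-in attention stands in for the package's
# attention factories. All shapes here are hypothetical.
batch, heads, q_len, kv_len, head_dim = 2, 4, 8, 16, 32

q = torch.randn(batch, heads, q_len, head_dim)   # queries (e.g. memory / decoder side)
k = torch.randn(batch, heads, kv_len, head_dim)  # keys from the other sequence
v = torch.randn(batch, heads, kv_len, head_dim)

# Cross-attention: every query position may attend to every key position,
# which is what is_causal=False expresses.
cross = F.scaled_dot_product_attention(q, k, v, is_causal=False)

# Causal self-attention over a single sequence, for contrast.
self_attn = F.scaled_dot_product_attention(q, q, q, is_causal=True)

print(cross.shape, self_attn.shape)  # both (batch, heads, q_len, head_dim)
```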
rxnn/transformers/models.py
CHANGED
```diff
@@ -108,6 +108,7 @@ class ReactiveTransformerEncoder(ReactiveTransformerBase):
     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> tuple[torch.Tensor, torch.Tensor]:
         x = super().forward(x)  # apply embeddings
         if attention_mask is not None:
+            print(attention_mask.size())
             attention_mask = attention_mask.unsqueeze(1).unsqueeze(1).bool()
 
         hidden_states = []
@@ -213,6 +214,7 @@ class ClassicTransformerDecoder(ClassicTransformerBase):
             if attention_mask is not None:
                 mask &= attention_mask.unsqueeze(1).unsqueeze(1).bool()
         elif attention_mask is not None:
+            print(attention_mask.size())
             mask = attention_mask.unsqueeze(1).unsqueeze(1).bool()
         else:
             mask = None
```
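In the decoder hunk, the padding mask is intersected with an existing `mask` (built earlier in the method) via `&=`, after the same `unsqueeze(1).unsqueeze(1).bool()` reshape used in the encoder. A small illustrative sketch, assuming a lower-triangular causal mask and hypothetical shapes, of what that intersection produces:

```python
import torch

# Sketch only (not the package's code): intersecting a causal mask with a
# padding mask so a key position is visible only if it is both causally
# allowed and not padding.
batch, seq_len = 2, 5

# 1 = real token, 0 = padding
attention_mask = torch.tensor([[1, 1, 1, 1, 0],
                               [1, 1, 1, 0, 0]])

# Lower-triangular causal mask, shaped (1, 1, seq_len, seq_len) to broadcast
causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool)).view(1, 1, seq_len, seq_len)

# Padding mask reshaped to (batch, 1, 1, seq_len), then intersected
mask = causal & attention_mask.unsqueeze(1).unsqueeze(1).bool()

print(mask.shape)  # torch.Size([2, 1, 5, 5])
```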
{rxnn-0.2.58.dist-info → rxnn-0.2.59.dist-info}/RECORD
CHANGED
```diff
@@ -5,11 +5,11 @@ rxnn/experimental/attention.py,sha256=jlNS82INjycNEfmk3HtkIacUvT_ELhaCO2g-kZTvhX
 rxnn/experimental/models.py,sha256=oJWd56LUsLc9S8eCZw-ShvuWjoQxj4C9GitbohlQ0ok,5139
 rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/memory/attention.py,sha256=
+rxnn/memory/attention.py,sha256=Mt9LC0DK5asjN4pTqWOLDAoGzAgg3FhQoUw6baDJ3NI,3309
 rxnn/memory/norm.py,sha256=cVjjhCLqR5K6-321SP_ObG17y-ddlcTJeCTXvW4vpk0,6675
 rxnn/memory/stm.py,sha256=jv57gsH9XW19sLbxpRDqsp1yfsii_4Ef4Ncr_ztk-i4,3937
 rxnn/rxt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/rxt/models.py,sha256=
+rxnn/rxt/models.py,sha256=jd1UVBUWJzWdw7Rjcvo9k5BXCJriQ0khuVszqEyfD7M,14665
 rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/training/base.py,sha256=CqaArEZYOdH64nmKfx28U3GI46TzO4oNkjf_hrF23Cw,11835
 rxnn/training/bml.py,sha256=hw6gLpLkGvqLzxIvBg4MvCc5r8cHpEm2RDyh7nH6CtE,16914
@@ -28,12 +28,12 @@ rxnn/transformers/attention.py,sha256=KRnKT6XUqAXElxV9y72mSpdTeiMgCKCCLqqxCFNTHm
 rxnn/transformers/ff.py,sha256=WDjO-H9XWInoWnUnxiseIH6Kx5GlHP0zGJygwhcb1gc,2589
 rxnn/transformers/layers.py,sha256=OlbqD5kKygn5WZziLbU3jZjhr8hBrxLpqlCjJ_BNCW0,8119
 rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
-rxnn/transformers/models.py,sha256=
+rxnn/transformers/models.py,sha256=bHR5gy74aV20Sl0371vVCvb3Z2pcqYTdOSnrLSkAIiI,10802
 rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
 rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
 rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
-rxnn-0.2.
-rxnn-0.2.
-rxnn-0.2.
-rxnn-0.2.
+rxnn-0.2.59.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.2.59.dist-info/METADATA,sha256=yj3LmcYzHTe7O5Y3LuDkCbjoTcY01WRQxaPQ9G4a8GA,25997
+rxnn-0.2.59.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+rxnn-0.2.59.dist-info/RECORD,,
```
{rxnn-0.2.58.dist-info → rxnn-0.2.59.dist-info}/LICENSE
File without changes
{rxnn-0.2.58.dist-info → rxnn-0.2.59.dist-info}/WHEEL
File without changes