rxnn 0.1.40__py3-none-any.whl → 0.1.42__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rxnn/experimental/attention.py +4 -3
- rxnn/transformers/attention.py +4 -0
- {rxnn-0.1.40.dist-info → rxnn-0.1.42.dist-info}/METADATA +1 -1
- {rxnn-0.1.40.dist-info → rxnn-0.1.42.dist-info}/RECORD +6 -6
- {rxnn-0.1.40.dist-info → rxnn-0.1.42.dist-info}/LICENSE +0 -0
- {rxnn-0.1.40.dist-info → rxnn-0.1.42.dist-info}/WHEEL +0 -0
rxnn/experimental/attention.py
CHANGED
```diff
@@ -90,6 +90,7 @@ class GroupedMoeAttention(GroupedQueryAttention):
 
         # Key/Value MoE routing
         B, S, D = key.shape
+        print('key/value type', key.dtype, value.dtype)
         key_flat = key.reshape(-1, D)
         print('key flat type', key_flat.dtype)
         weights, indices = self.router(key_flat)  # (B*S, num_groups), (B*S, num_groups)
@@ -127,8 +128,8 @@ class GroupedMoeAttention(GroupedQueryAttention):
         print(selected_v.size(), selected_v.dtype)
 
         # Weighted
-        weighted_k = selected_k * weights  # [B, S, num_groups, head_dim]
-        weighted_v = selected_v * weights  # [B, S, num_groups, head_dim]
+        weighted_k = (selected_k * weights).to(selected_k.device, dtype=selected_k.dtype)  # [B, S, num_groups, head_dim]
+        weighted_v = (selected_v * weights).to(selected_k.device, dtype=selected_k.dtype)  # [B, S, num_groups, head_dim]
 
         print('weighted')
         print(weighted_k.size(), weighted_k.dtype)
@@ -252,7 +253,7 @@ class DeepMoeAttention(GroupedMoeAttention):
         selected_q = torch.gather(q_all, 2, expanded_indices)  # [B, T, num_query_groups, head_dim]
 
         # Weighted sum
-        q = selected_q * weights_q  # [B, T, num_query_groups, head_dim]
+        q = (selected_q * weights_q).to(selected_q.device, dtype=selected_q.dtype)  # [B, T, num_query_groups, head_dim]
         q = q.view(B, T, self.num_query_groups, -1).permute(0, 2, 1, 3)  # [B, num_query_groups, T, head_dim]
 
         return super()._forward_qkv(q, key, value, b, t, d, skip_query_processing=True)
```
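The substantive change in both classes is the same: the product of the gathered tensors and the routing weights is now explicitly cast back to the source tensor's device and dtype. A minimal sketch of the dtype promotion this guards against, assuming a mixed-precision scenario where the gathered keys are float16 while the router weights remain float32 (shapes and dtypes here are illustrative, not taken from the package):

```python
import torch

# Illustrative [B, S, num_groups, head_dim] keys in half precision and
# [B, S, num_groups, 1] routing weights left in full precision.
selected_k = torch.randn(2, 16, 4, 64, dtype=torch.float16)
weights = torch.rand(2, 16, 4, 1, dtype=torch.float32)

# Plain multiplication follows PyTorch type promotion:
# float16 * float32 silently yields float32.
weighted_k = selected_k * weights
print(weighted_k.dtype)  # torch.float32

# The 0.1.42 code pins the result back to the keys' device and dtype.
weighted_k = (selected_k * weights).to(selected_k.device, dtype=selected_k.dtype)
print(weighted_k.dtype)  # torch.float16
```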
rxnn/transformers/attention.py
CHANGED
```diff
@@ -137,6 +137,10 @@ class MultiHeadAttention(nn.Module):
         return self._calculate_output(attn_weights, v, b, t, d)
 
     def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: torch.Tensor = None):
+        print('MHA forward')
+        print(query.size(), query.dtype)
+        print(key.size(), key.dtype)
+        print(value.size(), value.dtype)
         b, t, d = query.size()
         q, k, v = self._forward_qkv(query, key, value, b, t, d)
         if not self.rel_embed:
```
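This hunk only adds debug prints at the top of `MultiHeadAttention.forward`, logging each input's shape and dtype. For reference, a standalone sketch of what such prints emit, with made-up dimensions (batch 2, sequence length 8, model dim 16):

```python
import torch

# Hypothetical input; the real sizes depend on the calling model.
query = torch.randn(2, 8, 16, dtype=torch.float16)

print('MHA forward')
print(query.size(), query.dtype)
# Prints:
# MHA forward
# torch.Size([2, 8, 16]) torch.float16
```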
{rxnn-0.1.40.dist-info → rxnn-0.1.42.dist-info}/RECORD
CHANGED
```diff
@@ -1,6 +1,6 @@
 rxnn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/experimental/attention.py,sha256=
+rxnn/experimental/attention.py,sha256=mPDLI5lwujNTELdnVXDuIpagoQqHDP1GG6-ObCyM-Hw,30510
 rxnn/experimental/models.py,sha256=IzUVc5s-cA__8jsG2mVvzUDmzPRcfBcI5btaOjnPYhA,4598
 rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -16,7 +16,7 @@ rxnn/training/dataset.py,sha256=vQ5mDF3bA0HXya474n4D4iL8Mn3AEpJukgzFNVkxjGU,5106
 rxnn/training/scheduler.py,sha256=ow6oALzWjWQmHSpcJEjv6tg4g4CDMvr73TypxfcefMc,712
 rxnn/training/tokenizer.py,sha256=4Y41f07uo2KPA_7bp3FCcwGKbXoS2hsckOoXUsXfQxY,8052
 rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/transformers/attention.py,sha256=
+rxnn/transformers/attention.py,sha256=ZDdu1plCLhh-c5x-ZSMD1Avl36LMuTJSq617hD_hLCg,15942
 rxnn/transformers/ff.py,sha256=jJnuBDsnnX5uYC_WZH8cXAYrMnz0P-iX7MwcPivjRtI,2533
 rxnn/transformers/layers.py,sha256=n_jZTqEF_vLkF31AkB5XGErfm2sQFd9CRqJUHKRFkKI,6956
 rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
@@ -25,7 +25,7 @@ rxnn/transformers/moe.py,sha256=6Cffyo0QjmEWc4rK1ncOmLRCQbY0OpQJ4D7xH_4nTN4,4738
 rxnn/transformers/positional.py,sha256=2l38RS0Dini3f6Z3LUHr3XwWzg1UK7fO2C6wazWDAYU,4292
 rxnn/transformers/sampler.py,sha256=poWBpxg1iuK5gEJtxHkk5VVfS9V48hs2Olqdhy_Gw8c,6548
 rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
-rxnn-0.1.
-rxnn-0.1.
-rxnn-0.1.
-rxnn-0.1.
+rxnn-0.1.42.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.1.42.dist-info/METADATA,sha256=38UxLA25RpEi1-how5kIxMtlyJSrqZWZNzw0sgtnoDs,16627
+rxnn-0.1.42.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
+rxnn-0.1.42.dist-info/RECORD,,
```
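The two updated hashes correspond exactly to the two modified source files above. As a sketch of how `RECORD` entries like these are computed (per the wheel format: a urlsafe-base64-encoded sha256 digest with `=` padding stripped, followed by the file size in bytes; `record_entry` below is a hypothetical helper, not part of rxnn or its tooling):

```python
import base64
import hashlib

def record_entry(path: str) -> str:
    """Build one wheel RECORD line: path,sha256=<digest>,<size in bytes>."""
    with open(path, 'rb') as f:
        data = f.read()
    digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest()).rstrip(b'=')
    return f"{path},sha256={digest.decode()},{len(data)}"

# e.g. record_entry('rxnn/experimental/attention.py') should reproduce the
# '+' line above, ending in the new file size of 30510 bytes.
```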
{rxnn-0.1.40.dist-info → rxnn-0.1.42.dist-info}/LICENSE
File without changes
{rxnn-0.1.40.dist-info → rxnn-0.1.42.dist-info}/WHEEL
File without changes