rxnn 0.1.22__py3-none-any.whl → 0.1.23__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rxnn/experimental/attention.py
CHANGED
@@ -125,9 +125,6 @@ class GroupedMoeAttention(GroupedQueryAttention):
|
|
125
125
|
k = self._process_grouped_experts(key, self.wk, self.bk, weights_k, indices_k)
|
126
126
|
v = self._process_grouped_experts(value, self.wv, self.bv, weights_k, indices_k)
|
127
127
|
|
128
|
-
print('processed k', k.size())
|
129
|
-
print('processed v', v.size())
|
130
|
-
|
131
128
|
# Expand to GQA format
|
132
129
|
k = k.permute(0, 2, 1, 3).reshape(B, self.num_groups, S, -1)
|
133
130
|
v = v.permute(0, 2, 1, 3).reshape(B, self.num_groups, S, -1)
|
@@ -141,10 +138,6 @@ class GroupedMoeAttention(GroupedQueryAttention):
|
|
141
138
|
k = k.flatten(start_dim=1, end_dim=2) # (B, H, S, head_dim)
|
142
139
|
v = v.flatten(start_dim=1, end_dim=2) # (B, H, S, head_dim)
|
143
140
|
|
144
|
-
print('q', q.size())
|
145
|
-
print('k', k.size())
|
146
|
-
print('v', v.size())
|
147
|
-
|
148
141
|
return q, k, v
|
149
142
|
|
150
143
|
|
@@ -229,13 +222,8 @@ class DeepMoeAttention(GroupedMoeAttention):
|
|
229
222
|
weights_q = weights_q_flat.view(B, T, -1)
|
230
223
|
indices_q = indices_q_flat.view(B, T, -1)
|
231
224
|
q = self._process_grouped_experts(query, self.wq, self.bq, weights_q, indices_q)
|
232
|
-
print('processed q', q.size())
|
233
|
-
q = q.permute(0, 2, 1, 3).reshape(B, self.num_query_groups, T, -1)
|
234
|
-
|
235
|
-
# Expand query groups to match head count
|
236
|
-
group_heads = self.num_heads // self.num_query_groups
|
237
|
-
q = q.unsqueeze(2).expand(-1, -1, group_heads, -1, -1).flatten(1, 2)
|
238
225
|
|
226
|
+
q = q.permute(0, 2, 1, 3).reshape(B, self.num_query_groups, T, -1)
|
239
227
|
# Key/Value processing
|
240
228
|
return super()._forward_qkv(q, key, value, b, t, d, skip_query_processing=True)
|
241
229
|
|
@@ -1,6 +1,6 @@
|
|
1
1
|
rxnn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
2
|
rxnn/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
|
-
rxnn/experimental/attention.py,sha256=
|
3
|
+
rxnn/experimental/attention.py,sha256=gMEcFJHGOkz8R_s4dGEJB5cb2K3pbXZi4XBwyhEdB4s,31967
|
4
4
|
rxnn/experimental/models.py,sha256=-XkEHsyT8iNAjhZbgC7N_5nzP4ENVJLwxSoLHgMfA0I,4668
|
5
5
|
rxnn/experimental/moe.py,sha256=PhiaNr3FwR2Zv2a0tfj6sfZ4iyhLo3Jyp2DwXq19qZQ,7935
|
6
6
|
rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
@@ -25,7 +25,7 @@ rxnn/transformers/moe.py,sha256=msspVdefdt2ekIN8aT-V8DolK4taESQL_NVsSGOepIs,4739
|
|
25
25
|
rxnn/transformers/positional.py,sha256=2l38RS0Dini3f6Z3LUHr3XwWzg1UK7fO2C6wazWDAYU,4292
|
26
26
|
rxnn/transformers/sampler.py,sha256=poWBpxg1iuK5gEJtxHkk5VVfS9V48hs2Olqdhy_Gw8c,6548
|
27
27
|
rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
|
28
|
-
rxnn-0.1.
|
29
|
-
rxnn-0.1.
|
30
|
-
rxnn-0.1.
|
31
|
-
rxnn-0.1.
|
28
|
+
rxnn-0.1.23.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
|
29
|
+
rxnn-0.1.23.dist-info/METADATA,sha256=rZSBuoIgf8jKB11LKgMg7U42Wx7VNT_4EU3FVyED2YQ,16627
|
30
|
+
rxnn-0.1.23.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
|
31
|
+
rxnn-0.1.23.dist-info/RECORD,,
|
File without changes
|
File without changes
|