rxnn 0.1.38__py3-none-any.whl → 0.1.40__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rxnn/experimental/attention.py +29 -4
- rxnn/transformers/attention.py +3 -0
- {rxnn-0.1.38.dist-info → rxnn-0.1.40.dist-info}/METADATA +1 -1
- {rxnn-0.1.38.dist-info → rxnn-0.1.40.dist-info}/RECORD +6 -6
- {rxnn-0.1.38.dist-info → rxnn-0.1.40.dist-info}/LICENSE +0 -0
- {rxnn-0.1.38.dist-info → rxnn-0.1.40.dist-info}/WHEEL +0 -0
rxnn/experimental/attention.py
CHANGED
@@ -91,10 +91,11 @@ class GroupedMoeAttention(GroupedQueryAttention):
         # Key/Value MoE routing
         B, S, D = key.shape
         key_flat = key.reshape(-1, D)
+        print('key flat type', key_flat.dtype)
         weights, indices = self.router(key_flat)  # (B*S, num_groups), (B*S, num_groups)
         weights = weights.view(B, S, self.num_groups, 1)
         indices = indices.view(B, S, self.num_groups)
-
+        print('weights/indices type', weights.dtype, indices.dtype)
         # Compute all experts' projections
         # Shape: (B*S, num_experts, head_dim)
         k_all = torch.einsum('bd,edh->beh', key_flat, self.wk)  # [B*S, num_experts, head_dim]
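The self.router(key_flat) call whose output dtypes are printed above produces the per-token expert weights and indices used throughout the rest of the method. For context, a minimal sketch of a top-k router with that interface, assuming a plain linear gate (hypothetical; not the actual rxnn router):

    import torch
    import torch.nn as nn

    class TopKRouter(nn.Module):
        # Hypothetical stand-in for self.router: scores each flattened token
        # against num_experts gates and keeps the top num_groups of them.
        def __init__(self, dim: int, num_experts: int, num_groups: int):
            super().__init__()
            self.gate = nn.Linear(dim, num_experts, bias=False)
            self.num_groups = num_groups

        def forward(self, x: torch.Tensor):
            logits = self.gate(x)                          # (B*S, num_experts)
            weights, indices = torch.topk(logits, self.num_groups, dim=-1)
            weights = torch.softmax(weights, dim=-1)       # renormalize over kept experts
            return weights, indices                        # both (B*S, num_groups)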
@@ -104,24 +105,44 @@ class GroupedMoeAttention(GroupedQueryAttention):
             k_all += self.bk
             v_all += self.bv
 
+        print('k all/v all before get all')
+        print(k_all.size(), k_all.dtype)
+        print(v_all.size(), v_all.dtype)
+
         # Get results for all heads
         k_all = k_all.view(B, S, self.num_experts, -1)  # [B, S, num_experts, head_dim]
         v_all = v_all.view(B, S, self.num_experts, -1)  # [B, S, num_experts, head_dim]
 
+        print('k all/v all get all')
+        print(k_all.size(), k_all.dtype)
+        print(v_all.size(), v_all.dtype)
+
         # Gather top-k experts using expanded indices
         expanded_indices = indices.unsqueeze(-1).expand(-1, -1, -1, k_all.size(-1))  # [B, S, num_groups, head_dim]
         selected_k = torch.gather(k_all, 2, expanded_indices)  # [B, S, num_groups, head_dim]
         selected_v = torch.gather(v_all, 2, expanded_indices)  # [B, S, num_groups, head_dim]
 
+        print('selected k/selected v')
+        print(selected_k.size(), selected_k.dtype)
+        print(selected_v.size(), selected_v.dtype)
+
         # Weighted
         weighted_k = selected_k * weights  # [B, S, num_groups, head_dim]
         weighted_v = selected_v * weights  # [B, S, num_groups, head_dim]
 
+        print('weighted')
+        print(weighted_k.size(), weighted_k.dtype)
+        print(weighted_v.size(), weighted_v.dtype)
+
         # Reshape to GQA format
         k = weighted_k.view(B, S, self.num_groups, -1).permute(0, 2, 1, 3)  # [B, num_groups, S, head_dim]
         v = weighted_v.view(B, S, self.num_groups, -1).permute(0, 2, 1, 3)  # [B, num_groups, S, head_dim]
 
-
+        print('out 1')
+        print(k.size(), k.dtype)
+        print(v.size(), v.dtype)
+
+        if self.rel_embed:
             group_heads = self.num_heads // self.num_groups
 
             k = k.unsqueeze(2).expand(-1, -1, group_heads, -1, -1)  # (B, G, group_heads, S, head_dim)
@@ -130,6 +151,10 @@ class GroupedMoeAttention(GroupedQueryAttention):
             k = k.flatten(start_dim=1, end_dim=2)  # (B, H, S, head_dim)
             v = v.flatten(start_dim=1, end_dim=2)  # (B, H, S, head_dim)
 
+        print('out 2')
+        print(k.size(), k.dtype)
+        print(v.size(), v.dtype)
+
         return q, k, v
 
 
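Taken together, the prints added in the three hunks above trace shapes and dtypes through the whole key/value pipeline: per-expert projection, top-k gather, router weighting, and the reshape to GQA layout. A self-contained sketch of that gather pattern with arbitrary sizes (B=2, S=4, num_experts=8, num_groups=2, head_dim=16 are made up), which reproduces the shapes the debug output should report:

    import torch

    B, S, num_experts, num_groups, head_dim = 2, 4, 8, 2, 16
    k_all = torch.randn(B, S, num_experts, head_dim)       # all experts' projections
    weights = torch.rand(B, S, num_groups, 1)              # router weights for kept experts
    indices = torch.randint(0, num_experts, (B, S, num_groups))

    # Expand indices over head_dim so gather selects whole expert vectors.
    expanded = indices.unsqueeze(-1).expand(-1, -1, -1, head_dim)  # (B, S, num_groups, head_dim)
    selected_k = torch.gather(k_all, 2, expanded)                  # (B, S, num_groups, head_dim)
    weighted_k = selected_k * weights                              # (B, S, num_groups, head_dim)
    k = weighted_k.permute(0, 2, 1, 3)                             # (B, num_groups, S, head_dim)
    print(k.size(), k.dtype)  # torch.Size([2, 2, 4, 16]) torch.float32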
@@ -220,10 +245,10 @@ class DeepMoeAttention(GroupedMoeAttention):
         if self.use_bias:
             q_all += self.bq
 
-        q_all = q_all.view(B, T, self.num_query_experts, -1)
+        q_all = q_all.view(B, T, self.num_query_experts, -1)  # [B, T, num_query_experts, head_dim]
 
         # Gather top-k experts
-        expanded_indices = indices_q.unsqueeze(-1).expand(-1, -1, -1, q_all.size(-1))
+        expanded_indices = indices_q.unsqueeze(-1).expand(-1, -1, -1, q_all.size(-1))  # [B, T, num_query_groups, head_dim]
         selected_q = torch.gather(q_all, 2, expanded_indices)  # [B, T, num_query_groups, head_dim]
 
         # Weighted sum
rxnn/transformers/attention.py
CHANGED
@@ -106,6 +106,9 @@ class MultiHeadAttention(nn.Module):
         # with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
         #     return self._torch_attention(q, k, v, b, t, d, mask=mask, enable_gqa=enable_gqa)
         from flash_attn import flash_attn_func
+        print(q.size(), q.dtype)
+        print(k.size(), k.dtype)
+        print(v.size(), v.dtype)
         attn_output = flash_attn_func(q, k, v, dropout_p=self.dropout.p if self.training else 0.0, causal=self.is_causal)
         return self._transpose_output(attn_output, b, t, d)
 
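The three prints land immediately before the flash-attn call, and flash_attn_func accepts only CUDA tensors in float16 or bfloat16, so the dtype output is the natural thing to check when that call fails. A sketch of one common workaround, casting around the kernel (an assumption for illustration, not necessarily how rxnn resolves it):

    import torch
    from flash_attn import flash_attn_func  # assumes the flash-attn package is installed

    def flash_attention_fp32_safe(q, k, v, causal=True):
        # flash_attn_func expects (batch, seqlen, heads, head_dim) CUDA tensors
        # in fp16/bf16; cast in and out so callers can stay in fp32.
        orig_dtype = q.dtype
        if orig_dtype not in (torch.float16, torch.bfloat16):
            q, k, v = (t.to(torch.bfloat16) for t in (q, k, v))
        out = flash_attn_func(q, k, v, dropout_p=0.0, causal=causal)
        return out.to(orig_dtype)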
{rxnn-0.1.38.dist-info → rxnn-0.1.40.dist-info}/RECORD
CHANGED
@@ -1,6 +1,6 @@
 rxnn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/experimental/attention.py,sha256=
+rxnn/experimental/attention.py,sha256=ZvLTfuTk05SzuRXNS91rc1TuJVPiOB8EvNfFpb4TLGw,30309
 rxnn/experimental/models.py,sha256=IzUVc5s-cA__8jsG2mVvzUDmzPRcfBcI5btaOjnPYhA,4598
 rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -16,7 +16,7 @@ rxnn/training/dataset.py,sha256=vQ5mDF3bA0HXya474n4D4iL8Mn3AEpJukgzFNVkxjGU,5106
 rxnn/training/scheduler.py,sha256=ow6oALzWjWQmHSpcJEjv6tg4g4CDMvr73TypxfcefMc,712
 rxnn/training/tokenizer.py,sha256=4Y41f07uo2KPA_7bp3FCcwGKbXoS2hsckOoXUsXfQxY,8052
 rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/transformers/attention.py,sha256=
+rxnn/transformers/attention.py,sha256=NjMjg4K4sIyQL4BjPhF1KeYe6yhUpOVUyc9AHS7rWEw,15794
 rxnn/transformers/ff.py,sha256=jJnuBDsnnX5uYC_WZH8cXAYrMnz0P-iX7MwcPivjRtI,2533
 rxnn/transformers/layers.py,sha256=n_jZTqEF_vLkF31AkB5XGErfm2sQFd9CRqJUHKRFkKI,6956
 rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
@@ -25,7 +25,7 @@ rxnn/transformers/moe.py,sha256=6Cffyo0QjmEWc4rK1ncOmLRCQbY0OpQJ4D7xH_4nTN4,4738
 rxnn/transformers/positional.py,sha256=2l38RS0Dini3f6Z3LUHr3XwWzg1UK7fO2C6wazWDAYU,4292
 rxnn/transformers/sampler.py,sha256=poWBpxg1iuK5gEJtxHkk5VVfS9V48hs2Olqdhy_Gw8c,6548
 rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
-rxnn-0.1.
-rxnn-0.1.
-rxnn-0.1.
-rxnn-0.1.
+rxnn-0.1.40.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.1.40.dist-info/METADATA,sha256=wiq3LTNOFaUDCN9Nsbxc_yjQX2eICtIRBVdtQTvL3Zc,16627
+rxnn-0.1.40.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
+rxnn-0.1.40.dist-info/RECORD,,
{rxnn-0.1.38.dist-info → rxnn-0.1.40.dist-info}/LICENSE
File without changes
{rxnn-0.1.38.dist-info → rxnn-0.1.40.dist-info}/WHEEL
File without changes