rxnn-0.1.39-py3-none-any.whl → rxnn-0.1.41-py3-none-any.whl
This diff compares the contents of two publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
- rxnn/experimental/attention.py +30 -4
- rxnn/transformers/attention.py +4 -0
- {rxnn-0.1.39.dist-info → rxnn-0.1.41.dist-info}/METADATA +1 -1
- {rxnn-0.1.39.dist-info → rxnn-0.1.41.dist-info}/RECORD +6 -6
- {rxnn-0.1.39.dist-info → rxnn-0.1.41.dist-info}/LICENSE +0 -0
- {rxnn-0.1.39.dist-info → rxnn-0.1.41.dist-info}/WHEEL +0 -0
rxnn/experimental/attention.py
CHANGED
@@ -90,11 +90,13 @@ class GroupedMoeAttention(GroupedQueryAttention):
 
         # Key/Value MoE routing
         B, S, D = key.shape
+        print('key/value type', key.dtype, value.dtype)
         key_flat = key.reshape(-1, D)
+        print('key flat type', key_flat.dtype)
         weights, indices = self.router(key_flat)  # (B*S, num_groups), (B*S, num_groups)
         weights = weights.view(B, S, self.num_groups, 1)
         indices = indices.view(B, S, self.num_groups)
-
+        print('weights/indices type', weights.dtype, indices.dtype)
         # Compute all experts' projections
         # Shape: (B*S, num_experts, head_dim)
         k_all = torch.einsum('bd,edh->beh', key_flat, self.wk)  # [B*S, num_experts, head_dim]
@@ -104,24 +106,44 @@ class GroupedMoeAttention(GroupedQueryAttention):
         k_all += self.bk
         v_all += self.bv
 
+        print('k all/v all before get all')
+        print(k_all.size(), k_all.dtype)
+        print(v_all.size(), v_all.dtype)
+
         # Get results for all heads
         k_all = k_all.view(B, S, self.num_experts, -1)  # [B, S, num_experts, head_dim]
         v_all = v_all.view(B, S, self.num_experts, -1)  # [B, S, num_experts, head_dim]
 
+        print('k all/v all get all')
+        print(k_all.size(), k_all.dtype)
+        print(v_all.size(), v_all.dtype)
+
         # Gather top-k experts using expanded indices
         expanded_indices = indices.unsqueeze(-1).expand(-1, -1, -1, k_all.size(-1))  # [B, S, num_groups, head_dim]
         selected_k = torch.gather(k_all, 2, expanded_indices)  # [B, S, num_groups, head_dim]
         selected_v = torch.gather(v_all, 2, expanded_indices)  # [B, S, num_groups, head_dim]
 
+        print('selected k/selected v')
+        print(selected_k.size(), selected_k.dtype)
+        print(selected_v.size(), selected_v.dtype)
+
         # Weighted
         weighted_k = selected_k * weights  # [B, S, num_groups, head_dim]
         weighted_v = selected_v * weights  # [B, S, num_groups, head_dim]
 
+        print('weighted')
+        print(weighted_k.size(), weighted_k.dtype)
+        print(weighted_v.size(), weighted_v.dtype)
+
         # Reshape to GQA format
         k = weighted_k.view(B, S, self.num_groups, -1).permute(0, 2, 1, 3)  # [B, num_groups, S, head_dim]
         v = weighted_v.view(B, S, self.num_groups, -1).permute(0, 2, 1, 3)  # [B, num_groups, S, head_dim]
 
-
+        print('out 1')
+        print(k.size(), k.dtype)
+        print(v.size(), v.dtype)
+
+        if self.rel_embed:
             group_heads = self.num_heads // self.num_groups
 
             k = k.unsqueeze(2).expand(-1, -1, group_heads, -1, -1)  # (B, G, group_heads, S, head_dim)
@@ -130,6 +152,10 @@ class GroupedMoeAttention(GroupedQueryAttention):
             k = k.flatten(start_dim=1, end_dim=2)  # (B, H, S, head_dim)
             v = v.flatten(start_dim=1, end_dim=2)  # (B, H, S, head_dim)
 
+            print('out 2')
+            print(k.size(), k.dtype)
+            print(v.size(), v.dtype)
+
         return q, k, v
 
 
@@ -220,10 +246,10 @@ class DeepMoeAttention(GroupedMoeAttention):
         if self.use_bias:
             q_all += self.bq
 
-        q_all = q_all.view(B, T, self.num_query_experts, -1)
+        q_all = q_all.view(B, T, self.num_query_experts, -1)  # [B, T, num_query_experts, head_dim]
 
         # Gather top-k experts
-        expanded_indices = indices_q.unsqueeze(-1).expand(-1, -1, -1, q_all.size(-1))
+        expanded_indices = indices_q.unsqueeze(-1).expand(-1, -1, -1, q_all.size(-1))  # [B, T, num_query_groups, head_dim]
         selected_q = torch.gather(q_all, 2, expanded_indices)  # [B, T, num_query_groups, head_dim]
 
         # Weighted sum
rxnn/transformers/attention.py
CHANGED
@@ -137,6 +137,10 @@ class MultiHeadAttention(nn.Module):
         return self._calculate_output(attn_weights, v, b, t, d)
 
     def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: torch.Tensor = None):
+        print('MHA forward')
+        print(query.size(), query.dtype)
+        print(key.size(), key.dtype)
+        print(value.size(), value.dtype)
         b, t, d = query.size()
         q, k, v = self._forward_qkv(query, key, value, b, t, d)
         if not self.rel_embed:
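The only change in transformers/attention.py is the entry trace at the top of MultiHeadAttention.forward, printing the size and dtype of each input. As a rough illustration of what that trace emits, here is a hypothetical stand-in module (not rxnn's class) called with inputs of shape [2, 8, 16]:

import torch
from torch import nn

class TracedAttention(nn.Module):
    # Hypothetical stand-in that only reproduces the added trace.
    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
                mask: torch.Tensor = None):
        print('MHA forward')
        print(query.size(), query.dtype)
        print(key.size(), key.dtype)
        print(value.size(), value.dtype)
        return query  # the real module goes on to compute attention

x = torch.randn(2, 8, 16)
TracedAttention()(x, x, x)
# MHA forward
# torch.Size([2, 8, 16]) torch.float32
# torch.Size([2, 8, 16]) torch.float32
# torch.Size([2, 8, 16]) torch.float32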
{rxnn-0.1.39.dist-info → rxnn-0.1.41.dist-info}/RECORD
CHANGED
@@ -1,6 +1,6 @@
 rxnn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/experimental/attention.py,sha256=
+rxnn/experimental/attention.py,sha256=tvyjyBWLh9Bcu3MXlo3yx1hA48La8Zks4wWYtcnOYc8,30365
 rxnn/experimental/models.py,sha256=IzUVc5s-cA__8jsG2mVvzUDmzPRcfBcI5btaOjnPYhA,4598
 rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -16,7 +16,7 @@ rxnn/training/dataset.py,sha256=vQ5mDF3bA0HXya474n4D4iL8Mn3AEpJukgzFNVkxjGU,5106
 rxnn/training/scheduler.py,sha256=ow6oALzWjWQmHSpcJEjv6tg4g4CDMvr73TypxfcefMc,712
 rxnn/training/tokenizer.py,sha256=4Y41f07uo2KPA_7bp3FCcwGKbXoS2hsckOoXUsXfQxY,8052
 rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/transformers/attention.py,sha256=
+rxnn/transformers/attention.py,sha256=ZDdu1plCLhh-c5x-ZSMD1Avl36LMuTJSq617hD_hLCg,15942
 rxnn/transformers/ff.py,sha256=jJnuBDsnnX5uYC_WZH8cXAYrMnz0P-iX7MwcPivjRtI,2533
 rxnn/transformers/layers.py,sha256=n_jZTqEF_vLkF31AkB5XGErfm2sQFd9CRqJUHKRFkKI,6956
 rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
@@ -25,7 +25,7 @@ rxnn/transformers/moe.py,sha256=6Cffyo0QjmEWc4rK1ncOmLRCQbY0OpQJ4D7xH_4nTN4,4738
 rxnn/transformers/positional.py,sha256=2l38RS0Dini3f6Z3LUHr3XwWzg1UK7fO2C6wazWDAYU,4292
 rxnn/transformers/sampler.py,sha256=poWBpxg1iuK5gEJtxHkk5VVfS9V48hs2Olqdhy_Gw8c,6548
 rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
-rxnn-0.1.39.dist-info/LICENSE,sha256=
-rxnn-0.1.39.dist-info/METADATA,sha256=
-rxnn-0.1.39.dist-info/WHEEL,sha256=
-rxnn-0.1.39.dist-info/RECORD,,
+rxnn-0.1.41.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.1.41.dist-info/METADATA,sha256=B9AeelUbhWbH1uYB17WIEz4gjo_1c0V5jTM5S2QgulE,16627
+rxnn-0.1.41.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
+rxnn-0.1.41.dist-info/RECORD,,
{rxnn-0.1.39.dist-info → rxnn-0.1.41.dist-info}/LICENSE
File without changes
{rxnn-0.1.39.dist-info → rxnn-0.1.41.dist-info}/WHEEL
File without changes