rxnn-0.1.42-py3-none-any.whl → rxnn-0.1.44-py3-none-any.whl
This diff compares two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in their public registry.
- rxnn/experimental/attention.py +37 -72
- rxnn/transformers/attention.py +0 -7
- {rxnn-0.1.42.dist-info → rxnn-0.1.44.dist-info}/METADATA +1 -1
- {rxnn-0.1.42.dist-info → rxnn-0.1.44.dist-info}/RECORD +6 -6
- {rxnn-0.1.42.dist-info → rxnn-0.1.44.dist-info}/LICENSE +0 -0
- {rxnn-0.1.42.dist-info → rxnn-0.1.44.dist-info}/WHEEL +0 -0
rxnn/experimental/attention.py
CHANGED
@@ -68,81 +68,52 @@ class GroupedMoeAttention(GroupedQueryAttention):
     def _init_kv(self, embed_dim: int):
         self.router = MoeRouter(embed_dim, self.num_experts, top_k=self.num_groups)
         hidden_dim = embed_dim // self.num_heads
-
-        self.
-        self.
-        self.
-        self.
+        moe_dim = hidden_dim * self.num_experts
+        self.k_proj = nn.Linear(embed_dim, moe_dim, bias=self.use_bias)
+        self.v_proj = nn.Linear(embed_dim, moe_dim, bias=self.use_bias)
+        # self.wk = nn.Parameter(torch.empty(self.num_experts, embed_dim, hidden_dim))
+        # self.bk = nn.Parameter(torch.zeros(self.num_experts, hidden_dim)) if self.use_bias else None
+        # self.wv = nn.Parameter(torch.empty(self.num_experts, embed_dim, hidden_dim))
+        # self.bv = nn.Parameter(torch.zeros(self.num_experts, hidden_dim)) if self.use_bias else None
+        # self._init_experts()

     def _init_experts(self):
-
-        torch.nn.init.xavier_uniform_(self.
-
-
-
+        pass
+        # torch.nn.init.xavier_uniform_(self.wk)
+        # torch.nn.init.xavier_uniform_(self.wv)
+        # if self.use_bias:
+        #     torch.nn.init.zeros_(self.bk)
+        #     torch.nn.init.zeros_(self.bv)

     def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int,
                      skip_query_processing: bool = False):
-        head_dim = d // self.num_heads
-
         # Process Query as in GQA
         q = self.q_proj(query).view(b, t, self.num_heads, -1).transpose(1, 2) if not skip_query_processing else query

         # Key/Value MoE routing
         B, S, D = key.shape
-        print('key/value type', key.dtype, value.dtype)
         key_flat = key.reshape(-1, D)
-        print('key flat type', key_flat.dtype)
         weights, indices = self.router(key_flat)  # (B*S, num_groups), (B*S, num_groups)
         weights = weights.view(B, S, self.num_groups, 1)
         indices = indices.view(B, S, self.num_groups)
-        print('weights/indices type', weights.dtype, indices.dtype)
-        # Compute all experts' projections
-        # Shape: (B*S, num_experts, head_dim)
-        k_all = torch.einsum('bd,edh->beh', key_flat, self.wk)  # [B*S, num_experts, head_dim]
-        v_all = torch.einsum('bd,edh->beh', value.view(-1, D), self.wv)
-
-        if self.use_bias:
-            k_all += self.bk
-            v_all += self.bv
-
-        print('k all/v all before get all')
-        print(k_all.size(), k_all.dtype)
-        print(v_all.size(), v_all.dtype)
-
-        # Get results for all heads
-        k_all = k_all.view(B, S, self.num_experts, -1)  # [B, S, num_experts, head_dim]
-        v_all = v_all.view(B, S, self.num_experts, -1)  # [B, S, num_experts, head_dim]

-
-
-
+        # Compute all experts' projections
+        k_all = self.k_proj(key_flat).view(B, S, self.num_experts, -1).permute(0, 2, 1, 3)  # [B, num_experts, S, head_dim]
+        v_all = self.v_proj(value).view(B, S, self.num_experts, -1).permute(0, 2, 1, 3)  # [B, num_experts, S, head_dim]

         # Gather top-k experts using expanded indices
-        expanded_indices = indices.unsqueeze(-1).expand(-1, -1,
-        selected_k = torch.gather(k_all,
-        selected_v = torch.gather(v_all,
-
-        print('selected k/selected v')
-        print(selected_k.size(), selected_k.dtype)
-        print(selected_v.size(), selected_v.dtype)
+        expanded_indices = indices.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, S, -1)  # [B, num_groups, S, head_dim]
+        selected_k = torch.gather(k_all, 1, expanded_indices)  # [B, num_groups, S, head_dim]
+        selected_v = torch.gather(v_all, 1, expanded_indices)  # [B, num_groups, S, head_dim]

         # Weighted
         weighted_k = (selected_k * weights).to(selected_k.device, dtype=selected_k.dtype)  # [B, S, num_groups, head_dim]
         weighted_v = (selected_v * weights).to(selected_k.device, dtype=selected_k.dtype)  # [B, S, num_groups, head_dim]

-        print('weighted')
-        print(weighted_k.size(), weighted_k.dtype)
-        print(weighted_v.size(), weighted_v.dtype)
-
         # Reshape to GQA format
         k = weighted_k.view(B, S, self.num_groups, -1).permute(0, 2, 1, 3)  # [B, num_groups, S, head_dim]
         v = weighted_v.view(B, S, self.num_groups, -1).permute(0, 2, 1, 3)  # [B, num_groups, S, head_dim]

-        print('out 1')
-        print(k.size(), k.dtype)
-        print(v.size(), v.dtype)
-
         if self.rel_embed:
             group_heads = self.num_heads // self.num_groups

@@ -152,10 +123,6 @@ class GroupedMoeAttention(GroupedQueryAttention):
             k = k.flatten(start_dim=1, end_dim=2)  # (B, H, S, head_dim)
             v = v.flatten(start_dim=1, end_dim=2)  # (B, H, S, head_dim)

-        print('out 2')
-        print(k.size(), k.dtype)
-        print(v.size(), v.dtype)
-
         return q, k, v

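The substance of these hunks: the per-expert `wk`/`wv`/`bk`/`bv` parameters and their `torch.einsum` projections are replaced by two fused `nn.Linear` layers that project all experts in one matmul, after which the router's top-k expert indices are picked out with `torch.gather`; the debug `print` calls are dropped along the way. Below is a minimal, self-contained sketch of that fused-projection-plus-gather pattern. The toy dimensions, the stand-in router, and the per-token gather along the expert axis are illustrative assumptions, not the package's `MoeRouter`/`GroupedMoeAttention` code.

```python
import torch
import torch.nn as nn

# Illustrative shapes (assumed): batch, sequence, embed dim, experts, routed groups
B, S, D = 2, 5, 32
num_experts, num_groups, head_dim = 4, 2, 32 // 8   # head_dim = embed_dim // num_heads

k_proj = nn.Linear(D, head_dim * num_experts)        # one fused projection for all experts
router = nn.Linear(D, num_experts)                   # stand-in for MoeRouter

key = torch.randn(B, S, D)
scores = router(key).softmax(dim=-1)                 # [B, S, num_experts]
weights, indices = scores.topk(num_groups, dim=-1)   # both [B, S, num_groups]

# Project every token for every expert in a single matmul, then split the expert axis.
k_all = k_proj(key).view(B, S, num_experts, head_dim)          # [B, S, num_experts, head_dim]

# Gather only the routed experts per token and weight them by the router scores.
idx = indices.unsqueeze(-1).expand(-1, -1, -1, head_dim)       # [B, S, num_groups, head_dim]
selected_k = torch.gather(k_all, 2, idx)                       # [B, S, num_groups, head_dim]
k = (selected_k * weights.unsqueeze(-1)).permute(0, 2, 1, 3)   # [B, num_groups, S, head_dim]

print(k.shape)  # torch.Size([2, 2, 5, 4])
```

The appeal of the pattern is that one matmul covers every expert, so routing reduces to indexing into an already-computed tensor instead of a per-expert einsum over separate weight matrices.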
@@ -215,14 +182,17 @@ class DeepMoeAttention(GroupedMoeAttention):
     def _init_q(self, embed_dim: int):
         self.query_router = MoeRouter(embed_dim, self.num_query_experts, top_k=self.num_query_groups)
         hidden_dim = embed_dim // self.num_heads
-
-        self.
-        self.
+        moe_dim = hidden_dim * self.num_query_experts
+        self.q_proj = nn.Linear(embed_dim, moe_dim)
+        # self.wq = nn.Parameter(torch.empty(self.num_query_experts, embed_dim, hidden_dim))
+        # self.bq = nn.Parameter(torch.zeros(self.num_query_experts, hidden_dim)) if self.use_bias else None
+        # self._init_query_experts()

     def _init_query_experts(self):
-
-
-
+        pass
+        # torch.nn.init.xavier_uniform_(self.wq)
+        # if self.use_bias:
+        #     torch.nn.init.zeros_(self.bq)

     def _init_out(self, embed_dim: int):
         """Initialize output projection"""
@@ -237,23 +207,18 @@ class DeepMoeAttention(GroupedMoeAttention):
     def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int):
         B, T, D = query.shape
         query_flat = query.reshape(-1, D)
-
-
-
+        weights, indices = self.query_router(query_flat)
+        weights = weights.view(B, T, self.num_query_groups, 1)
+        indices = indices.view(B, T, self.num_query_groups)

-        #
-        q_all = torch.einsum('bd,edh->beh', query_flat, self.wq)  # [B*T, num_query_experts, head_dim]
-        if self.use_bias:
-            q_all += self.bq
+        q_all = self.q_proj(query_flat).view(B, T, self.num_query_experts, -1).permute(0, 2, 1, 3)  # [B, num_query_experts, T, head_dim]

-
-
-        #
-        expanded_indices = indices_q.unsqueeze(-1).expand(-1, -1, -1, q_all.size(-1))  # [B, T, num_query_groups, head_dim]
-        selected_q = torch.gather(q_all, 2, expanded_indices)  # [B, T, num_query_groups, head_dim]
+        # Gather top-k experts using expanded indices
+        expanded_indices = indices.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, T, -1)  # [B, num_query_groups, T, head_dim]
+        selected_q = torch.gather(q_all, 1, expanded_indices)  # [B, num_query_groups, T, head_dim]

         # Weighted sum
-        q = (selected_q *
+        q = (selected_q * weights).to(selected_q.device, dtype=selected_q.dtype)  # [B, T, num_query_groups, head_dim]
         q = q.view(B, T, self.num_query_groups, -1).permute(0, 2, 1, 3)  # [B, num_query_groups, T, head_dim]

         return super()._forward_qkv(q, key, value, b, t, d, skip_query_processing=True)
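`DeepMoeAttention` applies the same rewrite to the query path: a fused `q_proj` produces all query experts at once, the router's indices select the active groups, and the routed `q` is handed back to the parent's `_forward_qkv` with `skip_query_processing=True`. A short, hedged sketch of just that query-side gather, with assumed toy shapes and a pre-computed stand-in for the router output:

```python
import torch
import torch.nn as nn

# Assumed toy shapes; not the package's defaults.
B, T, D = 2, 7, 32
num_query_experts, num_query_groups, head_dim = 4, 2, 32 // 8

q_proj = nn.Linear(D, head_dim * num_query_experts)   # fused projection replaces the per-expert einsum
query = torch.randn(B, T, D)

# Stand-ins for what a router would return: per-token scores and expert ids.
weights = torch.rand(B, T, num_query_groups, 1)
indices = torch.randint(num_query_experts, (B, T, num_query_groups))

q_all = q_proj(query).view(B, T, num_query_experts, head_dim)
idx = indices.unsqueeze(-1).expand(-1, -1, -1, head_dim)
q = torch.gather(q_all, 2, idx) * weights              # [B, T, num_query_groups, head_dim]
q = q.permute(0, 2, 1, 3)                              # [B, num_query_groups, T, head_dim]

print(q.shape)  # torch.Size([2, 2, 7, 4])
```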
rxnn/transformers/attention.py
CHANGED
@@ -106,9 +106,6 @@ class MultiHeadAttention(nn.Module):
         # with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
         #     return self._torch_attention(q, k, v, b, t, d, mask=mask, enable_gqa=enable_gqa)
         from flash_attn import flash_attn_func
-        print(q.size(), q.dtype)
-        print(k.size(), k.dtype)
-        print(v.size(), v.dtype)
         attn_output = flash_attn_func(q, k, v, dropout_p=self.dropout.p if self.training else 0.0, causal=self.is_causal)
         return self._transpose_output(attn_output, b, t, d)

@@ -137,10 +134,6 @@ class MultiHeadAttention(nn.Module):
         return self._calculate_output(attn_weights, v, b, t, d)

     def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: torch.Tensor = None):
-        print('MHA forward')
-        print(query.size(), query.dtype)
-        print(key.size(), key.dtype)
-        print(value.size(), value.dtype)
         b, t, d = query.size()
         q, k, v = self._forward_qkv(query, key, value, b, t, d)
         if not self.rel_embed:
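The only change in this file is removing the remaining debug `print` calls; `MultiHeadAttention` keeps importing `flash_attn_func` and calling it directly. For orientation, here is a hedged, stand-alone sketch of that call pattern with a plain-PyTorch SDPA fallback. It is not the module's actual dispatch logic: `flash_attn_func` expects `(batch, seq, heads, head_dim)` fp16/bf16 CUDA tensors, while SDPA expects the head axis second.

```python
import torch
import torch.nn.functional as F

def attend(q, k, v, dropout_p: float = 0.0, causal: bool = False):
    """Flash-attention when usable, otherwise PyTorch SDPA. Inputs: (B, S, H, head_dim)."""
    if q.is_cuda and q.dtype in (torch.float16, torch.bfloat16):
        try:
            # flash_attn_func needs fp16/bf16 tensors on CUDA.
            from flash_attn import flash_attn_func
            return flash_attn_func(q, k, v, dropout_p=dropout_p, causal=causal)
        except ImportError:
            pass
    # SDPA wants (B, H, S, head_dim); transpose in and out.
    out = F.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
        dropout_p=dropout_p, is_causal=causal,
    )
    return out.transpose(1, 2)

# On a CPU-only setup this exercises the SDPA fallback.
q = k = v = torch.randn(2, 8, 4, 16)
print(attend(q, k, v, causal=True).shape)  # torch.Size([2, 8, 4, 16])
```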
{rxnn-0.1.42.dist-info → rxnn-0.1.44.dist-info}/RECORD
CHANGED
@@ -1,6 +1,6 @@
 rxnn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/experimental/attention.py,sha256
+rxnn/experimental/attention.py,sha256=-AlMmfb7PVXpSIOY6VqNSMSrwbO3HhzGiuy1Jyrt_bk,29543
 rxnn/experimental/models.py,sha256=IzUVc5s-cA__8jsG2mVvzUDmzPRcfBcI5btaOjnPYhA,4598
 rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -16,7 +16,7 @@ rxnn/training/dataset.py,sha256=vQ5mDF3bA0HXya474n4D4iL8Mn3AEpJukgzFNVkxjGU,5106
 rxnn/training/scheduler.py,sha256=ow6oALzWjWQmHSpcJEjv6tg4g4CDMvr73TypxfcefMc,712
 rxnn/training/tokenizer.py,sha256=4Y41f07uo2KPA_7bp3FCcwGKbXoS2hsckOoXUsXfQxY,8052
 rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/transformers/attention.py,sha256=
+rxnn/transformers/attention.py,sha256=dC0UmC-_kjX8US6Sf0Fi5zw5kJ-P6orH3JDHeBB5gI8,15695
 rxnn/transformers/ff.py,sha256=jJnuBDsnnX5uYC_WZH8cXAYrMnz0P-iX7MwcPivjRtI,2533
 rxnn/transformers/layers.py,sha256=n_jZTqEF_vLkF31AkB5XGErfm2sQFd9CRqJUHKRFkKI,6956
 rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
@@ -25,7 +25,7 @@ rxnn/transformers/moe.py,sha256=6Cffyo0QjmEWc4rK1ncOmLRCQbY0OpQJ4D7xH_4nTN4,4738
 rxnn/transformers/positional.py,sha256=2l38RS0Dini3f6Z3LUHr3XwWzg1UK7fO2C6wazWDAYU,4292
 rxnn/transformers/sampler.py,sha256=poWBpxg1iuK5gEJtxHkk5VVfS9V48hs2Olqdhy_Gw8c,6548
 rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
-rxnn-0.1.
-rxnn-0.1.
-rxnn-0.1.
-rxnn-0.1.
+rxnn-0.1.44.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.1.44.dist-info/METADATA,sha256=2nJ6QfPr4pcgXVvBwG7XTMGoC4z5ku5f5RnT53B64Rs,16627
+rxnn-0.1.44.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
+rxnn-0.1.44.dist-info/RECORD,,
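Each RECORD entry pairs a path with a `sha256=` digest and a byte size; per the standard wheel RECORD convention, the digest is the file's SHA-256 encoded as URL-safe base64 with the `=` padding stripped. A small sketch for reproducing such a digest locally (the path in the usage comment is just a hypothetical example):

```python
import base64
import hashlib

def record_digest(path: str) -> str:
    """Return a wheel-RECORD style 'sha256=<urlsafe base64, no padding>' digest for a file."""
    with open(path, "rb") as f:
        digest = hashlib.sha256(f.read()).digest()
    return "sha256=" + base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")

# Hypothetical usage against an unpacked wheel:
# print(record_digest("rxnn/transformers/attention.py"))
```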
{rxnn-0.1.42.dist-info → rxnn-0.1.44.dist-info}/LICENSE
File without changes

{rxnn-0.1.42.dist-info → rxnn-0.1.44.dist-info}/WHEEL
File without changes