rxnn 0.1.43__py3-none-any.whl → 0.1.44__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rxnn/experimental/attention.py
CHANGED
@@ -68,18 +68,22 @@ class GroupedMoeAttention(GroupedQueryAttention):
     def _init_kv(self, embed_dim: int):
         self.router = MoeRouter(embed_dim, self.num_experts, top_k=self.num_groups)
         hidden_dim = embed_dim // self.num_heads
-        self.wk = nn.Parameter(torch.empty(self.num_experts, embed_dim, hidden_dim))
-        self.bk = nn.Parameter(torch.zeros(self.num_experts, hidden_dim)) if self.use_bias else None
-        self.wv = nn.Parameter(torch.empty(self.num_experts, embed_dim, hidden_dim))
-        self.bv = nn.Parameter(torch.zeros(self.num_experts, hidden_dim)) if self.use_bias else None
-        self._init_experts()
+        moe_dim = hidden_dim * self.num_experts
+        self.k_proj = nn.Linear(embed_dim, moe_dim, bias=self.use_bias)
+        self.v_proj = nn.Linear(embed_dim, moe_dim, bias=self.use_bias)
+        # self.wk = nn.Parameter(torch.empty(self.num_experts, embed_dim, hidden_dim))
+        # self.bk = nn.Parameter(torch.zeros(self.num_experts, hidden_dim)) if self.use_bias else None
+        # self.wv = nn.Parameter(torch.empty(self.num_experts, embed_dim, hidden_dim))
+        # self.bv = nn.Parameter(torch.zeros(self.num_experts, hidden_dim)) if self.use_bias else None
+        # self._init_experts()

     def _init_experts(self):
-        torch.nn.init.xavier_uniform_(self.wk)
-        torch.nn.init.xavier_uniform_(self.wv)
-        if self.use_bias:
-            torch.nn.init.zeros_(self.bk)
-            torch.nn.init.zeros_(self.bv)
+        pass
+        # torch.nn.init.xavier_uniform_(self.wk)
+        # torch.nn.init.xavier_uniform_(self.wv)
+        # if self.use_bias:
+        #     torch.nn.init.zeros_(self.bk)
+        #     torch.nn.init.zeros_(self.bv)

     def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int,
                      skip_query_processing: bool = False):
@@ -94,22 +98,13 @@ class GroupedMoeAttention(GroupedQueryAttention):
         indices = indices.view(B, S, self.num_groups)

         # Compute all experts' projections
-        k_all = torch.einsum('bd,edh->beh', key_flat, self.wk)
-        v_all = torch.einsum('bd,edh->beh', value.view(-1, D), self.wv)
-
-        if self.use_bias:
-            k_all += self.bk
-            v_all += self.bv
-
-        # Get results for all heads
-        k_all = k_all.view(B, S, self.num_experts, -1)  # [B, S, num_experts, head_dim]
-        v_all = v_all.view(B, S, self.num_experts, -1)  # [B, S, num_experts, head_dim]
+        k_all = self.k_proj(key_flat).view(B, S, self.num_experts, -1).permute(0, 2, 1, 3)  # [B, num_experts, S, head_dim]
+        v_all = self.v_proj(value).view(B, S, self.num_experts, -1).permute(0, 2, 1, 3)  # [B, num_experts, S, head_dim]

         # Gather top-k experts using expanded indices
-        expanded_indices = indices.unsqueeze(-1).expand(-1, -1,
-        selected_k = torch.gather(k_all,
-        selected_v = torch.gather(v_all,
+        expanded_indices = indices.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, S, -1)  # [B, num_groups, S, head_dim]
+        selected_k = torch.gather(k_all, 1, expanded_indices)  # [B, num_groups, S, head_dim]
+        selected_v = torch.gather(v_all, 1, expanded_indices)  # [B, num_groups, S, head_dim]

         # Weighted
         weighted_k = (selected_k * weights).to(selected_k.device, dtype=selected_k.dtype)  # [B, S, num_groups, head_dim]
@@ -187,14 +182,17 @@ class DeepMoeAttention(GroupedMoeAttention):
     def _init_q(self, embed_dim: int):
         self.query_router = MoeRouter(embed_dim, self.num_query_experts, top_k=self.num_query_groups)
         hidden_dim = embed_dim // self.num_heads
-        self.wq = nn.Parameter(torch.empty(self.num_query_experts, embed_dim, hidden_dim))
-        self.bq = nn.Parameter(torch.zeros(self.num_query_experts, hidden_dim)) if self.use_bias else None
-        self._init_query_experts()
+        moe_dim = hidden_dim * self.num_query_experts
+        self.q_proj = nn.Linear(embed_dim, moe_dim)
+        # self.wq = nn.Parameter(torch.empty(self.num_query_experts, embed_dim, hidden_dim))
+        # self.bq = nn.Parameter(torch.zeros(self.num_query_experts, hidden_dim)) if self.use_bias else None
+        # self._init_query_experts()

     def _init_query_experts(self):
-        torch.nn.init.xavier_uniform_(self.wq)
-        if self.use_bias:
-            torch.nn.init.zeros_(self.bq)
+        pass
+        # torch.nn.init.xavier_uniform_(self.wq)
+        # if self.use_bias:
+        #     torch.nn.init.zeros_(self.bq)

     def _init_out(self, embed_dim: int):
         """Initialize output projection"""
@@ -209,23 +207,18 @@ class DeepMoeAttention(GroupedMoeAttention):
     def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int):
         B, T, D = query.shape
         query_flat = query.reshape(-1, D)
+        weights, indices = self.query_router(query_flat)
+        weights = weights.view(B, T, self.num_query_groups, 1)
+        indices = indices.view(B, T, self.num_query_groups)

-        # Compute all experts' projections
-        q_all = torch.einsum('bd,edh->beh', query_flat, self.wq)  # [B*T, num_query_experts, head_dim]
-        if self.use_bias:
-            q_all += self.bq
-
-        q_all = q_all.view(B, T, self.num_query_experts, -1)  # [B, T, num_query_experts, head_dim]
+        q_all = self.q_proj(query_flat).view(B, T, self.num_query_experts, -1).permute(0, 2, 1, 3)  # [B, num_query_experts, T, head_dim]

-        # Gather top-k experts
-        expanded_indices =
-        selected_q = torch.gather(q_all,
+        # Gather top-k experts using expanded indices
+        expanded_indices = indices.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, T, -1)  # [B, num_query_groups, T, head_dim]
+        selected_q = torch.gather(q_all, 1, expanded_indices)  # [B, num_query_groups, T, head_dim]

         # Weighted sum
-        q = (selected_q *
+        q = (selected_q * weights).to(selected_q.device, dtype=selected_q.dtype)  # [B, T, num_query_groups, head_dim]
         q = q.view(B, T, self.num_query_groups, -1).permute(0, 2, 1, 3)  # [B, num_query_groups, T, head_dim]

         return super()._forward_qkv(q, key, value, b, t, d, skip_query_processing=True)
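Both classes make the same underlying swap: a batched einsum over per-expert nn.Parameter weight tensors (wq/wk/wv) becomes one fused nn.Linear (q_proj/k_proj/v_proj) whose output is reshaped into per-expert slices. The two formulations are numerically equivalent up to how the weights are laid out. A short sketch of that equivalence, using made-up sizes and mirroring the wq/q_proj naming from the diff:

import torch
import torch.nn as nn

num_experts, embed_dim, head_dim = 3, 8, 2
x = torch.randn(5, embed_dim)  # flattened tokens, [B*T, embed_dim]

# Old formulation: one [embed_dim, head_dim] matrix per expert, applied via einsum
wq = torch.randn(num_experts, embed_dim, head_dim)
q_einsum = torch.einsum('bd,edh->beh', x, wq)  # [B*T, num_experts, head_dim]

# New formulation: a single fused linear layer holding all experts side by side
q_proj = nn.Linear(embed_dim, head_dim * num_experts, bias=False)
with torch.no_grad():
    # nn.Linear stores weight as [out_features, in_features]
    q_proj.weight.copy_(wq.permute(0, 2, 1).reshape(num_experts * head_dim, embed_dim))
q_linear = q_proj(x).view(-1, num_experts, head_dim)

print(torch.allclose(q_einsum, q_linear, atol=1e-5))  # True

Presumably the point of the fused layer is that a single GEMM replaces num_experts smaller matmuls, with the top-k selection deferred to the cheaper gather step.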
rxnn-0.1.44.dist-info/RECORD
CHANGED
@@ -1,6 +1,6 @@
 rxnn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/experimental/attention.py,sha256
+rxnn/experimental/attention.py,sha256=-AlMmfb7PVXpSIOY6VqNSMSrwbO3HhzGiuy1Jyrt_bk,29543
 rxnn/experimental/models.py,sha256=IzUVc5s-cA__8jsG2mVvzUDmzPRcfBcI5btaOjnPYhA,4598
 rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -25,7 +25,7 @@ rxnn/transformers/moe.py,sha256=6Cffyo0QjmEWc4rK1ncOmLRCQbY0OpQJ4D7xH_4nTN4,4738
 rxnn/transformers/positional.py,sha256=2l38RS0Dini3f6Z3LUHr3XwWzg1UK7fO2C6wazWDAYU,4292
 rxnn/transformers/sampler.py,sha256=poWBpxg1iuK5gEJtxHkk5VVfS9V48hs2Olqdhy_Gw8c,6548
 rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
-rxnn-0.1.
-rxnn-0.1.
-rxnn-0.1.
-rxnn-0.1.
+rxnn-0.1.44.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.1.44.dist-info/METADATA,sha256=2nJ6QfPr4pcgXVvBwG7XTMGoC4z5ku5f5RnT53B64Rs,16627
+rxnn-0.1.44.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
+rxnn-0.1.44.dist-info/RECORD,,