rxnn 0.1.43__py3-none-any.whl → 0.1.45__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rxnn/experimental/attention.py
CHANGED
@@ -68,18 +68,22 @@ class GroupedMoeAttention(GroupedQueryAttention):
     def _init_kv(self, embed_dim: int):
         self.router = MoeRouter(embed_dim, self.num_experts, top_k=self.num_groups)
         hidden_dim = embed_dim // self.num_heads
-        self.wk = nn.Parameter(torch.empty(self.num_experts, embed_dim, hidden_dim))
-        self.bk = nn.Parameter(torch.zeros(self.num_experts, hidden_dim)) if self.use_bias else None
-        self.wv = nn.Parameter(torch.empty(self.num_experts, embed_dim, hidden_dim))
-        self.bv = nn.Parameter(torch.zeros(self.num_experts, hidden_dim)) if self.use_bias else None
-        self._init_experts()
+        moe_dim = hidden_dim * self.num_experts
+        self.k_proj = nn.Linear(embed_dim, moe_dim, bias=self.use_bias)
+        self.v_proj = nn.Linear(embed_dim, moe_dim, bias=self.use_bias)
+        # self.wk = nn.Parameter(torch.empty(self.num_experts, embed_dim, hidden_dim))
+        # self.bk = nn.Parameter(torch.zeros(self.num_experts, hidden_dim)) if self.use_bias else None
+        # self.wv = nn.Parameter(torch.empty(self.num_experts, embed_dim, hidden_dim))
+        # self.bv = nn.Parameter(torch.zeros(self.num_experts, hidden_dim)) if self.use_bias else None
+        # self._init_experts()

     def _init_experts(self):
-        torch.nn.init.xavier_uniform_(self.wk)
-        torch.nn.init.xavier_uniform_(self.wv)
-        if self.use_bias:
-            torch.nn.init.zeros_(self.bk)
-            torch.nn.init.zeros_(self.bv)
+        pass
+        # torch.nn.init.xavier_uniform_(self.wk)
+        # torch.nn.init.xavier_uniform_(self.wv)
+        # if self.use_bias:
+        # torch.nn.init.zeros_(self.bk)
+        # torch.nn.init.zeros_(self.bv)

     def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int,
                      skip_query_processing: bool = False):
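
The hunk above swaps the per-expert `nn.Parameter` weights (`wk`/`wv` plus optional biases, applied via `einsum`) for a single fused `nn.Linear` per projection whose output width is `hidden_dim * num_experts`. A minimal sketch of that equivalence follows; the sizes and variable names are illustrative only and are not taken from the package:

```python
import torch
import torch.nn as nn

# Illustrative sizes only (not from the package): flattened tokens,
# model width, per-head width, number of K/V experts.
tokens, embed_dim, hidden_dim, num_experts = 8, 64, 16, 4
x = torch.randn(tokens, embed_dim)

# Old layout (sketch): one (embed_dim, hidden_dim) weight per expert, applied via einsum.
wk = torch.empty(num_experts, embed_dim, hidden_dim)
nn.init.xavier_uniform_(wk)
k_old = torch.einsum('bd,edh->beh', x, wk)  # [tokens, num_experts, hidden_dim]

# New layout (sketch): one fused Linear that emits all experts at once, then a reshape.
k_proj = nn.Linear(embed_dim, hidden_dim * num_experts, bias=False)
k_new = k_proj(x).view(tokens, num_experts, hidden_dim)  # same shape as k_old

print(k_old.shape, k_new.shape)  # torch.Size([8, 4, 16]) torch.Size([8, 4, 16])
```

Both paths produce a `[tokens, num_experts, hidden_dim]` tensor; with the fused layer the expert weights live in one matrix, so the default `nn.Linear` initialization applies and the separate `_init_experts()` step can become a no-op, as the diff shows.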
@@ -91,25 +95,15 @@ class GroupedMoeAttention(GroupedQueryAttention):
         key_flat = key.reshape(-1, D)
         weights, indices = self.router(key_flat)  # (B*S, num_groups), (B*S, num_groups)
         weights = weights.view(B, S, self.num_groups, 1)
-        indices = indices.view(B, S, self.num_groups)
+        indices = indices.view(B, S, self.num_groups).unsqueeze(-1).expand(-1, -1, S, -1)

         # Compute all experts' projections
-        k_all = torch.einsum('bd,edh->beh', key_flat, self.wk)
-
-        v_all = torch.einsum('bd,edh->beh', value.view(-1, D), self.wv)
-
-        if self.use_bias:
-            k_all += self.bk
-            v_all += self.bv
-
-        # Get results for all heads
-        k_all = k_all.view(B, S, self.num_experts, -1)  # [B, S, num_experts, head_dim]
-        v_all = v_all.view(B, S, self.num_experts, -1)  # [B, S, num_experts, head_dim]
+        k_all = self.k_proj(key_flat).view(B, S, self.num_experts, -1).permute(0, 2, 1, 3)  # [B, num_experts, S, head_dim]
+        v_all = self.v_proj(value).view(B, S, self.num_experts, -1).permute(0, 2, 1, 3)  # [B, num_experts, S, head_dim]

         # Gather top-k experts using expanded indices
-
-        selected_k = torch.gather(k_all, 2, expanded_indices)  # [B, S, num_groups, head_dim]
-        selected_v = torch.gather(v_all, 2, expanded_indices)  # [B, S, num_groups, head_dim]
+        selected_k = torch.gather(k_all, 1, indices)  # [B, num_groups, S, head_dim]
+        selected_v = torch.gather(v_all, 1, indices)  # [B, num_groups, S, head_dim]

         # Weighted
         weighted_k = (selected_k * weights).to(selected_k.device, dtype=selected_k.dtype)  # [B, S, num_groups, head_dim]
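
The new selection path gathers along the expert dimension of a `[B, num_experts, S, head_dim]` tensor instead of gathering per token. A minimal sketch of that gather pattern is below; the sizes and the index construction are illustrative and are not the package's exact code:

```python
import torch

# Illustrative sizes (not from the package): batch, sequence length,
# number of experts, selected groups per token, head width.
B, S, E, G, H = 2, 5, 4, 2, 8

k_all = torch.randn(B, E, S, H)           # per-expert key projections, expert-major layout
top_idx = torch.randint(0, E, (B, S, G))  # router's top-k expert ids for every token

# torch.gather needs an index with the same rank as the input, so the
# [B, S, G] ids are rearranged and broadcast to [B, G, S, H]: entry
# [b, g, s, :] holds the expert chosen for token s in group g.
idx = top_idx.permute(0, 2, 1).unsqueeze(-1).expand(B, G, S, H)

selected_k = torch.gather(k_all, 1, idx)  # [B, G, S, H]
print(selected_k.shape)                   # torch.Size([2, 2, 5, 8])
```

In this sketch the gather on dim 1 picks, for every token and group, the row of the chosen expert, which is then scaled by the router weights just as the `weighted_k` line in the diff does.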
@@ -187,14 +181,17 @@ class DeepMoeAttention(GroupedMoeAttention):
     def _init_q(self, embed_dim: int):
         self.query_router = MoeRouter(embed_dim, self.num_query_experts, top_k=self.num_query_groups)
         hidden_dim = embed_dim // self.num_heads
-        self.wq = nn.Parameter(torch.empty(self.num_query_experts, embed_dim, hidden_dim))
-        self.bq = nn.Parameter(torch.zeros(self.num_query_experts, hidden_dim)) if self.use_bias else None
-        self._init_query_experts()
+        moe_dim = hidden_dim * self.num_query_experts
+        self.q_proj = nn.Linear(embed_dim, moe_dim)
+        # self.wq = nn.Parameter(torch.empty(self.num_query_experts, embed_dim, hidden_dim))
+        # self.bq = nn.Parameter(torch.zeros(self.num_query_experts, hidden_dim)) if self.use_bias else None
+        # self._init_query_experts()

     def _init_query_experts(self):
-        torch.nn.init.xavier_uniform_(self.wq)
-        if self.use_bias:
-            torch.nn.init.zeros_(self.bq)
+        pass
+        # torch.nn.init.xavier_uniform_(self.wq)
+        # if self.use_bias:
+        # torch.nn.init.zeros_(self.bq)

     def _init_out(self, embed_dim: int):
         """Initialize output projection"""
@@ -209,23 +206,17 @@ class DeepMoeAttention(GroupedMoeAttention):
     def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int):
         B, T, D = query.shape
         query_flat = query.reshape(-1, D)
-
-
-
+        weights, indices = self.query_router(query_flat)
+        weights = weights.view(B, T, self.num_query_groups, 1)
+        indices = indices.view(B, T, self.num_query_groups).unsqueeze(-1).expand(-1, -1, T, -1)  # [B, num_query_groups, T, head_dim]

-        #
-        q_all = torch.einsum('bd,edh->beh', query_flat, self.wq)  # [B*T, num_query_experts, head_dim]
-        if self.use_bias:
-            q_all += self.bq
-
-        q_all = q_all.view(B, T, self.num_query_experts, -1)  # [B, T, num_query_experts, head_dim]
+        q_all = self.q_proj(query_flat).view(B, T, self.num_query_experts, -1).permute(0, 2, 1, 3)  # [B, num_query_experts, T, head_dim]

-        # Gather top-k experts
-
-        selected_q = torch.gather(q_all, 2, expanded_indices)  # [B, T, num_query_groups, head_dim]
+        # Gather top-k experts using expanded indices
+        selected_q = torch.gather(q_all, 1, indices)  # [B, num_query_groups, T, head_dim]

         # Weighted sum
-        q = (selected_q *
+        q = (selected_q * weights).to(selected_q.device, dtype=selected_q.dtype)  # [B, T, num_query_groups, head_dim]
         q = q.view(B, T, self.num_query_groups, -1).permute(0, 2, 1, 3)  # [B, num_query_groups, T, head_dim]

         return super()._forward_qkv(q, key, value, b, t, d, skip_query_processing=True)
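
Both the key/value path and the query path above rely on a router (`MoeRouter` / `query_router`) that returns per-token top-k expert weights and indices. A rough, hypothetical stand-in for that routing pattern is sketched below; it is not the package's `MoeRouter` implementation, only the general top-k gating idea its call signature suggests:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyTopKRouter(nn.Module):
    """Hypothetical stand-in for a top-k MoE router (not the package's MoeRouter):
    scores every expert per token, keeps the top_k, and renormalizes their weights."""

    def __init__(self, embed_dim: int, num_experts: int, top_k: int):
        super().__init__()
        self.gate = nn.Linear(embed_dim, num_experts, bias=False)
        self.top_k = top_k

    def forward(self, x: torch.Tensor):
        # x: [tokens, embed_dim] -> weights, indices: [tokens, top_k]
        scores = F.softmax(self.gate(x), dim=-1)
        weights, indices = torch.topk(scores, self.top_k, dim=-1)
        weights = weights / weights.sum(dim=-1, keepdim=True)
        return weights, indices

router = TinyTopKRouter(embed_dim=64, num_experts=4, top_k=2)
w, idx = router(torch.randn(8, 64))
print(w.shape, idx.shape)  # torch.Size([8, 2]) torch.Size([8, 2])
```

The output shapes line up with how the diff consumes the router: `weights` is reshaped to `[B, S, num_groups, 1]` (or `[B, T, num_query_groups, 1]`) for the weighted sum, and `indices` drives the `torch.gather` selection.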
rxnn-0.1.43.dist-info/RECORD → rxnn-0.1.45.dist-info/RECORD
CHANGED
@@ -1,6 +1,6 @@
 rxnn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/experimental/attention.py,sha256=
+rxnn/experimental/attention.py,sha256=gj1Zi7GGxn2DYJiBjlbMaNWwfmuUxVxi7j94AaJlm44,29387
 rxnn/experimental/models.py,sha256=IzUVc5s-cA__8jsG2mVvzUDmzPRcfBcI5btaOjnPYhA,4598
 rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -25,7 +25,7 @@ rxnn/transformers/moe.py,sha256=6Cffyo0QjmEWc4rK1ncOmLRCQbY0OpQJ4D7xH_4nTN4,4738
 rxnn/transformers/positional.py,sha256=2l38RS0Dini3f6Z3LUHr3XwWzg1UK7fO2C6wazWDAYU,4292
 rxnn/transformers/sampler.py,sha256=poWBpxg1iuK5gEJtxHkk5VVfS9V48hs2Olqdhy_Gw8c,6548
 rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
-rxnn-0.1.
-rxnn-0.1.
-rxnn-0.1.
-rxnn-0.1.
+rxnn-0.1.45.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.1.45.dist-info/METADATA,sha256=E-oJ9C52lOKVC4ZCnYQRSfqh9zhGFjLfc8OF_BWlUfc,16627
+rxnn-0.1.45.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
+rxnn-0.1.45.dist-info/RECORD,,