rxnn 0.1.43__py3-none-any.whl → 0.1.44__py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
rxnn/experimental/attention.py
@@ -68,18 +68,22 @@ class GroupedMoeAttention(GroupedQueryAttention):
     def _init_kv(self, embed_dim: int):
         self.router = MoeRouter(embed_dim, self.num_experts, top_k=self.num_groups)
         hidden_dim = embed_dim // self.num_heads
-        self.wk = nn.Parameter(torch.empty(self.num_experts, embed_dim, hidden_dim))
-        self.bk = nn.Parameter(torch.zeros(self.num_experts, hidden_dim)) if self.use_bias else None
-        self.wv = nn.Parameter(torch.empty(self.num_experts, embed_dim, hidden_dim))
-        self.bv = nn.Parameter(torch.zeros(self.num_experts, hidden_dim)) if self.use_bias else None
-        self._init_experts()
+        moe_dim = hidden_dim * self.num_experts
+        self.k_proj = nn.Linear(embed_dim, moe_dim, bias=self.use_bias)
+        self.v_proj = nn.Linear(embed_dim, moe_dim, bias=self.use_bias)
+        # self.wk = nn.Parameter(torch.empty(self.num_experts, embed_dim, hidden_dim))
+        # self.bk = nn.Parameter(torch.zeros(self.num_experts, hidden_dim)) if self.use_bias else None
+        # self.wv = nn.Parameter(torch.empty(self.num_experts, embed_dim, hidden_dim))
+        # self.bv = nn.Parameter(torch.zeros(self.num_experts, hidden_dim)) if self.use_bias else None
+        # self._init_experts()

     def _init_experts(self):
-        torch.nn.init.xavier_uniform_(self.wk)
-        torch.nn.init.xavier_uniform_(self.wv)
-        if self.use_bias:
-            torch.nn.init.zeros_(self.bk)
-            torch.nn.init.zeros_(self.bv)
+        pass
+        # torch.nn.init.xavier_uniform_(self.wk)
+        # torch.nn.init.xavier_uniform_(self.wv)
+        # if self.use_bias:
+        #     torch.nn.init.zeros_(self.bk)
+        #     torch.nn.init.zeros_(self.bv)

     def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int,
                      skip_query_processing: bool = False):
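
The refactor above replaces the hand-rolled per-expert parameters (wk/wv plus an einsum) with a single fused nn.Linear of output width num_experts * head_dim: one GEMM computes every expert's K/V projection, and a reshape recovers the per-expert axis. A minimal standalone sketch (not the package's code; sizes are illustrative) of why the two formulations are equivalent:

```python
import torch
import torch.nn as nn

embed_dim, num_experts, head_dim = 64, 4, 16
x = torch.randn(8, embed_dim)  # 8 flattened (batch * seq) token vectors

# Fused projection: one GEMM computes all experts' K projections at once.
fused = nn.Linear(embed_dim, num_experts * head_dim, bias=False)
k_fused = fused(x).view(8, num_experts, head_dim)

# The removed formulation: a stacked weight tensor [E, D, H] and an einsum.
wk = fused.weight.view(num_experts, head_dim, embed_dim).transpose(1, 2)
k_einsum = torch.einsum('bd,edh->beh', x, wk)

assert torch.allclose(k_fused, k_einsum, atol=1e-5)
```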
@@ -94,22 +98,13 @@
         indices = indices.view(B, S, self.num_groups)

         # Compute all experts' projections
-        # Shape: (B*S, num_experts, head_dim)
-        k_all = torch.einsum('bd,edh->beh', key_flat, self.wk)  # [B*S, num_experts, head_dim]
-        v_all = torch.einsum('bd,edh->beh', value.view(-1, D), self.wv)
-
-        if self.use_bias:
-            k_all += self.bk
-            v_all += self.bv
-
-        # Get results for all heads
-        k_all = k_all.view(B, S, self.num_experts, -1)  # [B, S, num_experts, head_dim]
-        v_all = v_all.view(B, S, self.num_experts, -1)  # [B, S, num_experts, head_dim]
+        k_all = self.k_proj(key_flat).view(B, S, self.num_experts, -1).permute(0, 2, 1, 3)  # [B, num_experts, S, head_dim]
+        v_all = self.v_proj(value).view(B, S, self.num_experts, -1).permute(0, 2, 1, 3)  # [B, num_experts, S, head_dim]

         # Gather top-k experts using expanded indices
-        expanded_indices = indices.unsqueeze(-1).expand(-1, -1, -1, k_all.size(-1))  # [B, S, num_groups, head_dim]
-        selected_k = torch.gather(k_all, 2, expanded_indices)  # [B, S, num_groups, head_dim]
-        selected_v = torch.gather(v_all, 2, expanded_indices)  # [B, S, num_groups, head_dim]
+        expanded_indices = indices.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, S, -1)  # [B, num_groups, S, head_dim]
+        selected_k = torch.gather(k_all, 1, expanded_indices)  # [B, num_groups, S, head_dim]
+        selected_v = torch.gather(v_all, 1, expanded_indices)  # [B, num_groups, S, head_dim]

         # Weighted
         weighted_k = (selected_k * weights).to(selected_k.device, dtype=selected_k.dtype)  # [B, S, num_groups, head_dim]
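
The gather now selects whole expert slices along the expert axis (dim 1 of the new [B, num_experts, S, head_dim] layout) instead of per-position slices along dim 2. A self-contained sketch of that pattern, assuming one routing decision per token (shapes are illustrative; this shows the general technique, not the package's exact code):

```python
import torch

B, E, G, S, H = 2, 4, 2, 5, 8             # batch, experts, top-k groups, seq, head_dim
k_all = torch.randn(B, E, S, H)           # all experts' key projections
indices = torch.randint(0, E, (B, S, G))  # router's top-k expert ids per token

# Align the group axis with the expert axis and broadcast over head_dim,
# then gather along dim 1 (the expert axis).
idx = indices.permute(0, 2, 1).unsqueeze(-1).expand(-1, -1, -1, H)  # [B, G, S, H]
selected_k = torch.gather(k_all, 1, idx)                            # [B, G, S, H]
print(selected_k.shape)  # torch.Size([2, 2, 5, 8])
```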
@@ -187,14 +182,17 @@ class DeepMoeAttention(GroupedMoeAttention):
     def _init_q(self, embed_dim: int):
         self.query_router = MoeRouter(embed_dim, self.num_query_experts, top_k=self.num_query_groups)
         hidden_dim = embed_dim // self.num_heads
-        self.wq = nn.Parameter(torch.empty(self.num_query_experts, embed_dim, hidden_dim))
-        self.bq = nn.Parameter(torch.zeros(self.num_query_experts, hidden_dim)) if self.use_bias else None
-        self._init_query_experts()
+        moe_dim = hidden_dim * self.num_query_experts
+        self.q_proj = nn.Linear(embed_dim, moe_dim)
+        # self.wq = nn.Parameter(torch.empty(self.num_query_experts, embed_dim, hidden_dim))
+        # self.bq = nn.Parameter(torch.zeros(self.num_query_experts, hidden_dim)) if self.use_bias else None
+        # self._init_query_experts()

     def _init_query_experts(self):
-        torch.nn.init.xavier_uniform_(self.wq)
-        if self.use_bias:
-            torch.nn.init.zeros_(self.bq)
+        pass
+        # torch.nn.init.xavier_uniform_(self.wq)
+        # if self.use_bias:
+        #     torch.nn.init.zeros_(self.bq)

     def _init_out(self, embed_dim: int):
         """Initialize output projection"""
@@ -209,23 +207,18 @@ class DeepMoeAttention(GroupedMoeAttention):
     def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int):
         B, T, D = query.shape
         query_flat = query.reshape(-1, D)
-        weights_q, indices_q = self.query_router(query_flat)
-        weights_q = weights_q.view(B, T, self.num_query_groups, 1)
-        indices_q = indices_q.view(B, T, self.num_query_groups)
+        weights, indices = self.query_router(query_flat)
+        weights = weights.view(B, T, self.num_query_groups, 1)
+        indices = indices.view(B, T, self.num_query_groups)

-        # Compute all query experts
-        q_all = torch.einsum('bd,edh->beh', query_flat, self.wq)  # [B*T, num_query_experts, head_dim]
-        if self.use_bias:
-            q_all += self.bq
-
-        q_all = q_all.view(B, T, self.num_query_experts, -1)  # [B, T, num_query_experts, head_dim]
+        q_all = self.q_proj(query_flat).view(B, T, self.num_query_experts, -1).permute(0, 2, 1, 3)  # [B, num_query_experts, T, head_dim]

-        # Gather top-k experts
-        expanded_indices = indices_q.unsqueeze(-1).expand(-1, -1, -1, q_all.size(-1))  # [B, T, num_query_groups, head_dim]
-        selected_q = torch.gather(q_all, 2, expanded_indices)  # [B, T, num_query_groups, head_dim]
+        # Gather top-k experts using expanded indices
+        expanded_indices = indices.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, T, -1)  # [B, num_query_groups, T, head_dim]
+        selected_q = torch.gather(q_all, 1, expanded_indices)  # [B, num_query_groups, T, head_dim]

         # Weighted sum
-        q = (selected_q * weights_q).to(selected_q.device, dtype=selected_q.dtype)  # [B, T, num_query_groups, head_dim]
+        q = (selected_q * weights).to(selected_q.device, dtype=selected_q.dtype)  # [B, T, num_query_groups, head_dim]
         q = q.view(B, T, self.num_query_groups, -1).permute(0, 2, 1, 3)  # [B, num_query_groups, T, head_dim]

         return super()._forward_qkv(q, key, value, b, t, d, skip_query_processing=True)
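
Both _forward_qkv overrides consume the (weights, indices) pair produced by MoeRouter. For context, a minimal top-k softmax gate in the same spirit (an illustrative sketch; the package's MoeRouter may differ, e.g. in weight renormalization or auxiliary load-balancing handling):

```python
import torch
import torch.nn as nn

class TinyRouter(nn.Module):
    """Minimal top-k softmax gate returning (weights, indices) per token."""
    def __init__(self, embed_dim: int, num_experts: int, top_k: int):
        super().__init__()
        self.gate = nn.Linear(embed_dim, num_experts, bias=False)
        self.top_k = top_k

    def forward(self, x: torch.Tensor):
        probs = torch.softmax(self.gate(x), dim=-1)        # [N, num_experts]
        weights, indices = probs.topk(self.top_k, dim=-1)  # [N, top_k] each
        return weights, indices
```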
rxnn-0.1.44.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.1.43
+Version: 0.1.44
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning
rxnn-0.1.44.dist-info/RECORD
@@ -1,6 +1,6 @@
 rxnn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/experimental/attention.py,sha256=tyeZ2xkZfFrlTWzqjKS13NnIRSLGe1-oR7J20hN6BgI,29602
+rxnn/experimental/attention.py,sha256=-AlMmfb7PVXpSIOY6VqNSMSrwbO3HhzGiuy1Jyrt_bk,29543
 rxnn/experimental/models.py,sha256=IzUVc5s-cA__8jsG2mVvzUDmzPRcfBcI5btaOjnPYhA,4598
 rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -25,7 +25,7 @@ rxnn/transformers/moe.py,sha256=6Cffyo0QjmEWc4rK1ncOmLRCQbY0OpQJ4D7xH_4nTN4,4738
 rxnn/transformers/positional.py,sha256=2l38RS0Dini3f6Z3LUHr3XwWzg1UK7fO2C6wazWDAYU,4292
 rxnn/transformers/sampler.py,sha256=poWBpxg1iuK5gEJtxHkk5VVfS9V48hs2Olqdhy_Gw8c,6548
 rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
-rxnn-0.1.43.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
-rxnn-0.1.43.dist-info/METADATA,sha256=7iCT0KM5FyGKZwe9ZbxnULu9z-rwyko4Xbklok1ZsOs,16627
-rxnn-0.1.43.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
-rxnn-0.1.43.dist-info/RECORD,,
+rxnn-0.1.44.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.1.44.dist-info/METADATA,sha256=2nJ6QfPr4pcgXVvBwG7XTMGoC4z5ku5f5RnT53B64Rs,16627
+rxnn-0.1.44.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
+rxnn-0.1.44.dist-info/RECORD,,