rxnn 0.1.43__py3-none-any.whl → 0.1.45__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -68,18 +68,22 @@ class GroupedMoeAttention(GroupedQueryAttention):
68
68
  def _init_kv(self, embed_dim: int):
69
69
  self.router = MoeRouter(embed_dim, self.num_experts, top_k=self.num_groups)
70
70
  hidden_dim = embed_dim // self.num_heads
71
- self.wk = nn.Parameter(torch.empty(self.num_experts, embed_dim, hidden_dim))
72
- self.bk = nn.Parameter(torch.zeros(self.num_experts, hidden_dim)) if self.use_bias else None
73
- self.wv = nn.Parameter(torch.empty(self.num_experts, embed_dim, hidden_dim))
74
- self.bv = nn.Parameter(torch.zeros(self.num_experts, hidden_dim)) if self.use_bias else None
75
- self._init_experts()
71
+ moe_dim = hidden_dim * self.num_experts
72
+ self.k_proj = nn.Linear(embed_dim, moe_dim, bias=self.use_bias)
73
+ self.v_proj = nn.Linear(embed_dim, moe_dim, bias=self.use_bias)
74
+ # self.wk = nn.Parameter(torch.empty(self.num_experts, embed_dim, hidden_dim))
75
+ # self.bk = nn.Parameter(torch.zeros(self.num_experts, hidden_dim)) if self.use_bias else None
76
+ # self.wv = nn.Parameter(torch.empty(self.num_experts, embed_dim, hidden_dim))
77
+ # self.bv = nn.Parameter(torch.zeros(self.num_experts, hidden_dim)) if self.use_bias else None
78
+ # self._init_experts()
76
79
 
77
80
  def _init_experts(self):
78
- torch.nn.init.xavier_uniform_(self.wk)
79
- torch.nn.init.xavier_uniform_(self.wv)
80
- if self.use_bias:
81
- torch.nn.init.zeros_(self.bk)
82
- torch.nn.init.zeros_(self.bv)
81
+ pass
82
+ # torch.nn.init.xavier_uniform_(self.wk)
83
+ # torch.nn.init.xavier_uniform_(self.wv)
84
+ # if self.use_bias:
85
+ # torch.nn.init.zeros_(self.bk)
86
+ # torch.nn.init.zeros_(self.bv)
83
87
 
84
88
  def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int,
85
89
  skip_query_processing: bool = False):
@@ -91,25 +95,15 @@ class GroupedMoeAttention(GroupedQueryAttention):
91
95
  key_flat = key.reshape(-1, D)
92
96
  weights, indices = self.router(key_flat) # (B*S, num_groups), (B*S, num_groups)
93
97
  weights = weights.view(B, S, self.num_groups, 1)
94
- indices = indices.view(B, S, self.num_groups)
98
+ indices = indices.view(B, S, self.num_groups).unsqueeze(-1).expand(-1, -1, S, -1)
95
99
 
96
100
  # Compute all experts' projections
97
- # Shape: (B*S, num_experts, head_dim)
98
- k_all = torch.einsum('bd,edh->beh', key_flat, self.wk) # [B*S, num_experts, head_dim]
99
- v_all = torch.einsum('bd,edh->beh', value.view(-1, D), self.wv)
100
-
101
- if self.use_bias:
102
- k_all += self.bk
103
- v_all += self.bv
104
-
105
- # Get results for all heads
106
- k_all = k_all.view(B, S, self.num_experts, -1) # [B, S, num_experts, head_dim]
107
- v_all = v_all.view(B, S, self.num_experts, -1) # [B, S, num_experts, head_dim]
101
+ k_all = self.k_proj(key_flat).view(B, S, self.num_experts, -1).permute(0, 2, 1, 3) # [B, num_experts, S, head_dim]
102
+ v_all = self.v_proj(value).view(B, S, self.num_experts, -1).permute(0, 2, 1, 3) # [B, num_experts, S, head_dim]
108
103
 
109
104
  # Gather top-k experts using expanded indices
110
- expanded_indices = indices.unsqueeze(-1).expand(-1, -1, -1, k_all.size(-1)) # [B, S, num_groups, head_dim]
111
- selected_k = torch.gather(k_all, 2, expanded_indices) # [B, S, num_groups, head_dim]
112
- selected_v = torch.gather(v_all, 2, expanded_indices) # [B, S, num_groups, head_dim]
105
+ selected_k = torch.gather(k_all, 1, indices) # [B, num_groups, S, head_dim]
106
+ selected_v = torch.gather(v_all, 1, indices) # [B, num_groups, S, head_dim]
113
107
 
114
108
  # Weighted
115
109
  weighted_k = (selected_k * weights).to(selected_k.device, dtype=selected_k.dtype) # [B, S, num_groups, head_dim]
@@ -187,14 +181,17 @@ class DeepMoeAttention(GroupedMoeAttention):
187
181
  def _init_q(self, embed_dim: int):
188
182
  self.query_router = MoeRouter(embed_dim, self.num_query_experts, top_k=self.num_query_groups)
189
183
  hidden_dim = embed_dim // self.num_heads
190
- self.wq = nn.Parameter(torch.empty(self.num_query_experts, embed_dim, hidden_dim))
191
- self.bq = nn.Parameter(torch.zeros(self.num_query_experts, hidden_dim)) if self.use_bias else None
192
- self._init_query_experts()
184
+ moe_dim = hidden_dim * self.num_query_experts
185
+ self.q_proj = nn.Linear(embed_dim, moe_dim)
186
+ # self.wq = nn.Parameter(torch.empty(self.num_query_experts, embed_dim, hidden_dim))
187
+ # self.bq = nn.Parameter(torch.zeros(self.num_query_experts, hidden_dim)) if self.use_bias else None
188
+ # self._init_query_experts()
193
189
 
194
190
  def _init_query_experts(self):
195
- torch.nn.init.xavier_uniform_(self.wq)
196
- if self.use_bias:
197
- torch.nn.init.zeros_(self.bq)
191
+ pass
192
+ # torch.nn.init.xavier_uniform_(self.wq)
193
+ # if self.use_bias:
194
+ # torch.nn.init.zeros_(self.bq)
198
195
 
199
196
  def _init_out(self, embed_dim: int):
200
197
  """Initialize output projection"""
@@ -209,23 +206,17 @@ class DeepMoeAttention(GroupedMoeAttention):
209
206
  def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int):
210
207
  B, T, D = query.shape
211
208
  query_flat = query.reshape(-1, D)
212
- weights_q, indices_q = self.query_router(query_flat)
213
- weights_q = weights_q.view(B, T, self.num_query_groups, 1)
214
- indices_q = indices_q.view(B, T, self.num_query_groups)
209
+ weights, indices = self.query_router(query_flat)
210
+ weights = weights.view(B, T, self.num_query_groups, 1)
211
+ indices = indices.view(B, T, self.num_query_groups).unsqueeze(-1).expand(-1, -1, T, -1) # [B, num_query_groups, T, head_dim]
215
212
 
216
- # Compute all query experts
217
- q_all = torch.einsum('bd,edh->beh', query_flat, self.wq) # [B*T, num_query_experts, head_dim]
218
- if self.use_bias:
219
- q_all += self.bq
220
-
221
- q_all = q_all.view(B, T, self.num_query_experts, -1) # [B, T, num_query_experts, head_dim]
213
+ q_all = self.q_proj(query_flat).view(B, T, self.num_query_experts, -1).permute(0, 2, 1, 3) # [B, num_query_experts, T, head_dim]
222
214
 
223
- # Gather top-k experts
224
- expanded_indices = indices_q.unsqueeze(-1).expand(-1, -1, -1, q_all.size(-1)) # [B, T, num_query_groups, head_dim]
225
- selected_q = torch.gather(q_all, 2, expanded_indices) # [B, T, num_query_groups, head_dim]
215
+ # Gather top-k experts using expanded indices
216
+ selected_q = torch.gather(q_all, 1, indices) # [B, num_query_groups, T, head_dim]
226
217
 
227
218
  # Weighted sum
228
- q = (selected_q * weights_q).to(selected_q.device, dtype=selected_q.dtype) # [B, T, num_query_groups, head_dim]
219
+ q = (selected_q * weights).to(selected_q.device, dtype=selected_q.dtype) # [B, T, num_query_groups, head_dim]
229
220
  q = q.view(B, T, self.num_query_groups, -1).permute(0, 2, 1, 3) # [B, num_query_groups, T, head_dim]
230
221
 
231
222
  return super()._forward_qkv(q, key, value, b, t, d, skip_query_processing=True)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: rxnn
3
- Version: 0.1.43
3
+ Version: 0.1.45
4
4
  Summary: RxNN: Reactive Neural Networks Platform
5
5
  License: Apache-2.0
6
6
  Keywords: deep-learning,ai,machine-learning
@@ -1,6 +1,6 @@
1
1
  rxnn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
2
  rxnn/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
- rxnn/experimental/attention.py,sha256=tyeZ2xkZfFrlTWzqjKS13NnIRSLGe1-oR7J20hN6BgI,29602
3
+ rxnn/experimental/attention.py,sha256=gj1Zi7GGxn2DYJiBjlbMaNWwfmuUxVxi7j94AaJlm44,29387
4
4
  rxnn/experimental/models.py,sha256=IzUVc5s-cA__8jsG2mVvzUDmzPRcfBcI5btaOjnPYhA,4598
5
5
  rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
6
6
  rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -25,7 +25,7 @@ rxnn/transformers/moe.py,sha256=6Cffyo0QjmEWc4rK1ncOmLRCQbY0OpQJ4D7xH_4nTN4,4738
25
25
  rxnn/transformers/positional.py,sha256=2l38RS0Dini3f6Z3LUHr3XwWzg1UK7fO2C6wazWDAYU,4292
26
26
  rxnn/transformers/sampler.py,sha256=poWBpxg1iuK5gEJtxHkk5VVfS9V48hs2Olqdhy_Gw8c,6548
27
27
  rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
28
- rxnn-0.1.43.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
29
- rxnn-0.1.43.dist-info/METADATA,sha256=7iCT0KM5FyGKZwe9ZbxnULu9z-rwyko4Xbklok1ZsOs,16627
30
- rxnn-0.1.43.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
31
- rxnn-0.1.43.dist-info/RECORD,,
28
+ rxnn-0.1.45.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
29
+ rxnn-0.1.45.dist-info/METADATA,sha256=E-oJ9C52lOKVC4ZCnYQRSfqh9zhGFjLfc8OF_BWlUfc,16627
30
+ rxnn-0.1.45.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
31
+ rxnn-0.1.45.dist-info/RECORD,,
File without changes
File without changes