rxnn-0.1.42-py3-none-any.whl → rxnn-0.1.44-py3-none-any.whl

This diff compares the contents of two package versions publicly released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions exactly as they appear in the public registry.
rxnn/experimental/attention.py
@@ -68,81 +68,52 @@ class GroupedMoeAttention(GroupedQueryAttention):
     def _init_kv(self, embed_dim: int):
         self.router = MoeRouter(embed_dim, self.num_experts, top_k=self.num_groups)
         hidden_dim = embed_dim // self.num_heads
-        self.wk = nn.Parameter(torch.empty(self.num_experts, embed_dim, hidden_dim))
-        self.bk = nn.Parameter(torch.zeros(self.num_experts, hidden_dim)) if self.use_bias else None
-        self.wv = nn.Parameter(torch.empty(self.num_experts, embed_dim, hidden_dim))
-        self.bv = nn.Parameter(torch.zeros(self.num_experts, hidden_dim)) if self.use_bias else None
-        self._init_experts()
+        moe_dim = hidden_dim * self.num_experts
+        self.k_proj = nn.Linear(embed_dim, moe_dim, bias=self.use_bias)
+        self.v_proj = nn.Linear(embed_dim, moe_dim, bias=self.use_bias)
+        # self.wk = nn.Parameter(torch.empty(self.num_experts, embed_dim, hidden_dim))
+        # self.bk = nn.Parameter(torch.zeros(self.num_experts, hidden_dim)) if self.use_bias else None
+        # self.wv = nn.Parameter(torch.empty(self.num_experts, embed_dim, hidden_dim))
+        # self.bv = nn.Parameter(torch.zeros(self.num_experts, hidden_dim)) if self.use_bias else None
+        # self._init_experts()
 
     def _init_experts(self):
-        torch.nn.init.xavier_uniform_(self.wk)
-        torch.nn.init.xavier_uniform_(self.wv)
-        if self.use_bias:
-            torch.nn.init.zeros_(self.bk)
-            torch.nn.init.zeros_(self.bv)
+        pass
+        # torch.nn.init.xavier_uniform_(self.wk)
+        # torch.nn.init.xavier_uniform_(self.wv)
+        # if self.use_bias:
+        #     torch.nn.init.zeros_(self.bk)
+        #     torch.nn.init.zeros_(self.bv)
 
     def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int,
                      skip_query_processing: bool = False):
-        head_dim = d // self.num_heads
-
         # Process Query as in GQA
         q = self.q_proj(query).view(b, t, self.num_heads, -1).transpose(1, 2) if not skip_query_processing else query
 
         # Key/Value MoE routing
         B, S, D = key.shape
-        print('key/value type', key.dtype, value.dtype)
         key_flat = key.reshape(-1, D)
-        print('key flat type', key_flat.dtype)
         weights, indices = self.router(key_flat)  # (B*S, num_groups), (B*S, num_groups)
         weights = weights.view(B, S, self.num_groups, 1)
         indices = indices.view(B, S, self.num_groups)
-        print('weights/indices type', weights.dtype, indices.dtype)
-        # Compute all experts' projections
-        # Shape: (B*S, num_experts, head_dim)
-        k_all = torch.einsum('bd,edh->beh', key_flat, self.wk)  # [B*S, num_experts, head_dim]
-        v_all = torch.einsum('bd,edh->beh', value.view(-1, D), self.wv)
-
-        if self.use_bias:
-            k_all += self.bk
-            v_all += self.bv
-
-        print('k all/v all before get all')
-        print(k_all.size(), k_all.dtype)
-        print(v_all.size(), v_all.dtype)
-
-        # Get results for all heads
-        k_all = k_all.view(B, S, self.num_experts, -1)  # [B, S, num_experts, head_dim]
-        v_all = v_all.view(B, S, self.num_experts, -1)  # [B, S, num_experts, head_dim]
 
-        print('k all/v all get all')
-        print(k_all.size(), k_all.dtype)
-        print(v_all.size(), v_all.dtype)
+        # Compute all experts' projections
+        k_all = self.k_proj(key_flat).view(B, S, self.num_experts, -1).permute(0, 2, 1, 3)  # [B, num_experts, S, head_dim]
+        v_all = self.v_proj(value).view(B, S, self.num_experts, -1).permute(0, 2, 1, 3)  # [B, num_experts, S, head_dim]
 
         # Gather top-k experts using expanded indices
-        expanded_indices = indices.unsqueeze(-1).expand(-1, -1, -1, k_all.size(-1))  # [B, S, num_groups, head_dim]
-        selected_k = torch.gather(k_all, 2, expanded_indices)  # [B, S, num_groups, head_dim]
-        selected_v = torch.gather(v_all, 2, expanded_indices)  # [B, S, num_groups, head_dim]
-
-        print('selected k/selected v')
-        print(selected_k.size(), selected_k.dtype)
-        print(selected_v.size(), selected_v.dtype)
+        expanded_indices = indices.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, S, -1)  # [B, num_groups, S, head_dim]
+        selected_k = torch.gather(k_all, 1, expanded_indices)  # [B, num_groups, S, head_dim]
+        selected_v = torch.gather(v_all, 1, expanded_indices)  # [B, num_groups, S, head_dim]
 
         # Weighted
         weighted_k = (selected_k * weights).to(selected_k.device, dtype=selected_k.dtype)  # [B, S, num_groups, head_dim]
         weighted_v = (selected_v * weights).to(selected_k.device, dtype=selected_k.dtype)  # [B, S, num_groups, head_dim]
 
-        print('weighted')
-        print(weighted_k.size(), weighted_k.dtype)
-        print(weighted_v.size(), weighted_v.dtype)
-
         # Reshape to GQA format
         k = weighted_k.view(B, S, self.num_groups, -1).permute(0, 2, 1, 3)  # [B, num_groups, S, head_dim]
         v = weighted_v.view(B, S, self.num_groups, -1).permute(0, 2, 1, 3)  # [B, num_groups, S, head_dim]
 
-        print('out 1')
-        print(k.size(), k.dtype)
-        print(v.size(), v.dtype)
-
         if self.rel_embed:
             group_heads = self.num_heads // self.num_groups
 
@@ -152,10 +123,6 @@ class GroupedMoeAttention(GroupedQueryAttention):
         k = k.flatten(start_dim=1, end_dim=2)  # (B, H, S, head_dim)
         v = v.flatten(start_dim=1, end_dim=2)  # (B, H, S, head_dim)
 
-        print('out 2')
-        print(k.size(), k.dtype)
-        print(v.size(), v.dtype)
-
         return q, k, v
 
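Note on the refactor above: the per-expert einsum projections (wk/wv with optional biases) are replaced by a single fused nn.Linear per projection, sized num_experts * head_dim, so every expert is computed in one matmul and the router's top-k experts are then gathered per token. The sketch below is an illustrative reconstruction of that pattern, not the rxnn API: ToyRouter and fused_expert_kv are hypothetical names, and the gather runs over the expert dimension of a [B, S, num_experts, head_dim] tensor, assuming head_dim = embed_dim // num_heads.

# Minimal sketch, assuming a softmax-gated top-k router in place of MoeRouter.
import torch
import torch.nn as nn

class ToyRouter(nn.Module):
    """Simplified stand-in for MoeRouter: linear gate, softmax, top-k."""
    def __init__(self, embed_dim: int, num_experts: int, top_k: int):
        super().__init__()
        self.gate = nn.Linear(embed_dim, num_experts, bias=False)
        self.top_k = top_k

    def forward(self, x: torch.Tensor):
        scores = torch.softmax(self.gate(x), dim=-1)  # [N, num_experts]
        return scores.topk(self.top_k, dim=-1)  # weights, indices: [N, top_k]

def fused_expert_kv(key_flat, k_proj, router, B, S, num_experts, num_groups):
    # One fused Linear evaluates every expert in a single matmul:
    # [B*S, embed_dim] -> [B*S, num_experts * head_dim]
    k_all = k_proj(key_flat).view(B, S, num_experts, -1)  # [B, S, num_experts, head_dim]
    weights, indices = router(key_flat)  # each [B*S, num_groups]
    weights = weights.view(B, S, num_groups, 1)
    indices = indices.view(B, S, num_groups)
    # Select each token's top-k experts along the expert dimension
    idx = indices.unsqueeze(-1).expand(-1, -1, -1, k_all.size(-1))
    selected_k = torch.gather(k_all, 2, idx)  # [B, S, num_groups, head_dim]
    return selected_k * weights  # router-weighted keys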
@@ -215,14 +182,17 @@ class DeepMoeAttention(GroupedMoeAttention):
     def _init_q(self, embed_dim: int):
         self.query_router = MoeRouter(embed_dim, self.num_query_experts, top_k=self.num_query_groups)
         hidden_dim = embed_dim // self.num_heads
-        self.wq = nn.Parameter(torch.empty(self.num_query_experts, embed_dim, hidden_dim))
-        self.bq = nn.Parameter(torch.zeros(self.num_query_experts, hidden_dim)) if self.use_bias else None
-        self._init_query_experts()
+        moe_dim = hidden_dim * self.num_query_experts
+        self.q_proj = nn.Linear(embed_dim, moe_dim)
+        # self.wq = nn.Parameter(torch.empty(self.num_query_experts, embed_dim, hidden_dim))
+        # self.bq = nn.Parameter(torch.zeros(self.num_query_experts, hidden_dim)) if self.use_bias else None
+        # self._init_query_experts()
 
     def _init_query_experts(self):
-        torch.nn.init.xavier_uniform_(self.wq)
-        if self.use_bias:
-            torch.nn.init.zeros_(self.bq)
+        pass
+        # torch.nn.init.xavier_uniform_(self.wq)
+        # if self.use_bias:
+        #     torch.nn.init.zeros_(self.bq)
 
     def _init_out(self, embed_dim: int):
         """Initialize output projection"""
@@ -237,23 +207,18 @@ class DeepMoeAttention(GroupedMoeAttention):
     def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int):
         B, T, D = query.shape
         query_flat = query.reshape(-1, D)
-        weights_q, indices_q = self.query_router(query_flat)
-        weights_q = weights_q.view(B, T, self.num_query_groups, 1)
-        indices_q = indices_q.view(B, T, self.num_query_groups)
+        weights, indices = self.query_router(query_flat)
+        weights = weights.view(B, T, self.num_query_groups, 1)
+        indices = indices.view(B, T, self.num_query_groups)
 
-        # Compute all query experts
-        q_all = torch.einsum('bd,edh->beh', query_flat, self.wq)  # [B*T, num_query_experts, head_dim]
-        if self.use_bias:
-            q_all += self.bq
+        q_all = self.q_proj(query_flat).view(B, T, self.num_query_experts, -1).permute(0, 2, 1, 3)  # [B, num_query_experts, T, head_dim]
 
-        q_all = q_all.view(B, T, self.num_query_experts, -1)  # [B, T, num_query_experts, head_dim]
-
-        # Gather top-k experts
-        expanded_indices = indices_q.unsqueeze(-1).expand(-1, -1, -1, q_all.size(-1))  # [B, T, num_query_groups, head_dim]
-        selected_q = torch.gather(q_all, 2, expanded_indices)  # [B, T, num_query_groups, head_dim]
+        # Gather top-k experts using expanded indices
+        expanded_indices = indices.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, T, -1)  # [B, num_query_groups, T, head_dim]
+        selected_q = torch.gather(q_all, 1, expanded_indices)  # [B, num_query_groups, T, head_dim]
 
         # Weighted sum
-        q = (selected_q * weights_q).to(selected_q.device, dtype=selected_q.dtype)  # [B, T, num_query_groups, head_dim]
+        q = (selected_q * weights).to(selected_q.device, dtype=selected_q.dtype)  # [B, T, num_query_groups, head_dim]
         q = q.view(B, T, self.num_query_groups, -1).permute(0, 2, 1, 3)  # [B, num_query_groups, T, head_dim]
 
         return super()._forward_qkv(q, key, value, b, t, d, skip_query_processing=True)
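The DeepMoeAttention change is the query-side mirror of the key/value refactor: one fused q_proj over num_query_experts, then route, gather, and weight over num_query_groups. A hypothetical run of the sketch from the note above, to make the shapes concrete:

# Toy dimensions only; not taken from rxnn defaults.
B, S, D, num_heads = 2, 16, 64, 4
num_experts, num_groups = 8, 2
head_dim = D // num_heads  # 16
router = ToyRouter(D, num_experts, top_k=num_groups)
k_proj = nn.Linear(D, head_dim * num_experts)  # fused: all experts in one Linear
key = torch.randn(B, S, D)
out = fused_expert_kv(key.reshape(-1, D), k_proj, router, B, S, num_experts, num_groups)
print(out.shape)  # torch.Size([2, 16, 2, 16])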
rxnn/transformers/attention.py
@@ -106,9 +106,6 @@ class MultiHeadAttention(nn.Module):
         # with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
         #     return self._torch_attention(q, k, v, b, t, d, mask=mask, enable_gqa=enable_gqa)
         from flash_attn import flash_attn_func
-        print(q.size(), q.dtype)
-        print(k.size(), k.dtype)
-        print(v.size(), v.dtype)
         attn_output = flash_attn_func(q, k, v, dropout_p=self.dropout.p if self.training else 0.0, causal=self.is_causal)
         return self._transpose_output(attn_output, b, t, d)
 
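For context, flash_attn_func comes from the separate flash-attn package and expects (batch, seq_len, num_heads, head_dim) tensors in fp16/bf16 on CUDA. Below is a minimal sketch of the usual guard around that optional dependency, with PyTorch's built-in scaled_dot_product_attention as the fallback; the attend helper is illustrative, not rxnn code, and the commented-out sdpa_kernel path above hints at the same fallback idea:

import torch.nn.functional as F

def attend(q, k, v, dropout_p: float = 0.0, causal: bool = False):
    # q, k, v: [B, S, num_heads, head_dim], the layout flash_attn_func expects
    try:
        from flash_attn import flash_attn_func
        return flash_attn_func(q, k, v, dropout_p=dropout_p, causal=causal)
    except ImportError:
        # scaled_dot_product_attention wants [B, num_heads, S, head_dim],
        # so transpose around the call
        out = F.scaled_dot_product_attention(
            q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
            dropout_p=dropout_p, is_causal=causal,
        )
        return out.transpose(1, 2)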
@@ -137,10 +134,6 @@ class MultiHeadAttention(nn.Module):
         return self._calculate_output(attn_weights, v, b, t, d)
 
     def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: torch.Tensor = None):
-        print('MHA forward')
-        print(query.size(), query.dtype)
-        print(key.size(), key.dtype)
-        print(value.size(), value.dtype)
         b, t, d = query.size()
         q, k, v = self._forward_qkv(query, key, value, b, t, d)
         if not self.rel_embed:
rxnn-0.1.42.dist-info/METADATA → rxnn-0.1.44.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.1.42
+Version: 0.1.44
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning
rxnn-0.1.42.dist-info/RECORD → rxnn-0.1.44.dist-info/RECORD
@@ -1,6 +1,6 @@
 rxnn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/experimental/attention.py,sha256=mPDLI5lwujNTELdnVXDuIpagoQqHDP1GG6-ObCyM-Hw,30510
+rxnn/experimental/attention.py,sha256=-AlMmfb7PVXpSIOY6VqNSMSrwbO3HhzGiuy1Jyrt_bk,29543
 rxnn/experimental/models.py,sha256=IzUVc5s-cA__8jsG2mVvzUDmzPRcfBcI5btaOjnPYhA,4598
 rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -16,7 +16,7 @@ rxnn/training/dataset.py,sha256=vQ5mDF3bA0HXya474n4D4iL8Mn3AEpJukgzFNVkxjGU,5106
 rxnn/training/scheduler.py,sha256=ow6oALzWjWQmHSpcJEjv6tg4g4CDMvr73TypxfcefMc,712
 rxnn/training/tokenizer.py,sha256=4Y41f07uo2KPA_7bp3FCcwGKbXoS2hsckOoXUsXfQxY,8052
 rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/transformers/attention.py,sha256=ZDdu1plCLhh-c5x-ZSMD1Avl36LMuTJSq617hD_hLCg,15942
+rxnn/transformers/attention.py,sha256=dC0UmC-_kjX8US6Sf0Fi5zw5kJ-P6orH3JDHeBB5gI8,15695
 rxnn/transformers/ff.py,sha256=jJnuBDsnnX5uYC_WZH8cXAYrMnz0P-iX7MwcPivjRtI,2533
 rxnn/transformers/layers.py,sha256=n_jZTqEF_vLkF31AkB5XGErfm2sQFd9CRqJUHKRFkKI,6956
 rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
@@ -25,7 +25,7 @@ rxnn/transformers/moe.py,sha256=6Cffyo0QjmEWc4rK1ncOmLRCQbY0OpQJ4D7xH_4nTN4,4738
 rxnn/transformers/positional.py,sha256=2l38RS0Dini3f6Z3LUHr3XwWzg1UK7fO2C6wazWDAYU,4292
 rxnn/transformers/sampler.py,sha256=poWBpxg1iuK5gEJtxHkk5VVfS9V48hs2Olqdhy_Gw8c,6548
 rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
-rxnn-0.1.42.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
-rxnn-0.1.42.dist-info/METADATA,sha256=38UxLA25RpEi1-how5kIxMtlyJSrqZWZNzw0sgtnoDs,16627
-rxnn-0.1.42.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
-rxnn-0.1.42.dist-info/RECORD,,
+rxnn-0.1.44.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.1.44.dist-info/METADATA,sha256=2nJ6QfPr4pcgXVvBwG7XTMGoC4z5ku5f5RnT53B64Rs,16627
+rxnn-0.1.44.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
+rxnn-0.1.44.dist-info/RECORD,,