rxnn 0.1.38__py3-none-any.whl → 0.1.40__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their public registry.
@@ -91,10 +91,11 @@ class GroupedMoeAttention(GroupedQueryAttention):
  # Key/Value MoE routing
  B, S, D = key.shape
  key_flat = key.reshape(-1, D)
+ print('key flat type', key_flat.dtype)
  weights, indices = self.router(key_flat) # (B*S, num_groups), (B*S, num_groups)
  weights = weights.view(B, S, self.num_groups, 1)
  indices = indices.view(B, S, self.num_groups)
-
+ print('weights/indices type', weights.dtype, indices.dtype)
  # Compute all experts' projections
  # Shape: (B*S, num_experts, head_dim)
  k_all = torch.einsum('bd,edh->beh', key_flat, self.wk) # [B*S, num_experts, head_dim]
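
Note: the hunk above covers the key/value MoE routing step — a router scores each flattened token, and every expert's key projection is computed in one batched einsum before the top-k results are selected. A minimal standalone sketch of that pattern (toy sizes; `TopKRouter` is a hypothetical stand-in for the package's router, not its actual class):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class TopKRouter(nn.Module):
        # Stand-in router: softmax gate over experts, keep top-k per token
        def __init__(self, dim: int, num_experts: int, top_k: int):
            super().__init__()
            self.gate = nn.Linear(dim, num_experts)
            self.top_k = top_k

        def forward(self, x):  # x: [B*S, dim]
            scores = F.softmax(self.gate(x), dim=-1)           # [B*S, num_experts]
            weights, indices = torch.topk(scores, self.top_k)  # both [B*S, top_k]
            return weights, indices

    B, S, D, num_experts, head_dim, top_k = 2, 4, 16, 8, 8, 2
    key = torch.randn(B, S, D)
    wk = torch.randn(num_experts, D, head_dim)  # per-expert K projection weights

    key_flat = key.reshape(-1, D)                      # [B*S, D]
    weights, indices = TopKRouter(D, num_experts, top_k)(key_flat)
    k_all = torch.einsum('bd,edh->beh', key_flat, wk)  # [B*S, num_experts, head_dim]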
@@ -104,24 +105,44 @@ class GroupedMoeAttention(GroupedQueryAttention):
  k_all += self.bk
  v_all += self.bv

+ print('k all/v all before get all')
+ print(k_all.size(), k_all.dtype)
+ print(v_all.size(), v_all.dtype)
+
  # Get results for all heads
  k_all = k_all.view(B, S, self.num_experts, -1) # [B, S, num_experts, head_dim]
  v_all = v_all.view(B, S, self.num_experts, -1) # [B, S, num_experts, head_dim]

+ print('k all/v all get all')
+ print(k_all.size(), k_all.dtype)
+ print(v_all.size(), v_all.dtype)
+
  # Gather top-k experts using expanded indices
  expanded_indices = indices.unsqueeze(-1).expand(-1, -1, -1, k_all.size(-1)) # [B, S, num_groups, head_dim]
  selected_k = torch.gather(k_all, 2, expanded_indices) # [B, S, num_groups, head_dim]
  selected_v = torch.gather(v_all, 2, expanded_indices) # [B, S, num_groups, head_dim]

+ print('selected k/selected v')
+ print(selected_k.size(), selected_k.dtype)
+ print(selected_v.size(), selected_v.dtype)
+
  # Weighted
  weighted_k = selected_k * weights # [B, S, num_groups, head_dim]
  weighted_v = selected_v * weights # [B, S, num_groups, head_dim]

+ print('weighted')
+ print(weighted_k.size(), weighted_k.dtype)
+ print(weighted_v.size(), weighted_v.dtype)
+
  # Reshape to GQA format
  k = weighted_k.view(B, S, self.num_groups, -1).permute(0, 2, 1, 3) # [B, num_groups, S, head_dim]
  v = weighted_v.view(B, S, self.num_groups, -1).permute(0, 2, 1, 3) # [B, num_groups, S, head_dim]

- if not self.rel_embed:
+ print('out 1')
+ print(k.size(), k.dtype)
+ print(v.size(), v.dtype)
+
+ if self.rel_embed:
  group_heads = self.num_heads // self.num_groups

  k = k.unsqueeze(2).expand(-1, -1, group_heads, -1, -1) # (B, G, group_heads, S, head_dim)
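
Note: the gather step above is the core of the top-k selection — the [B, S, num_groups] index tensor is broadcast along the head dimension so torch.gather can pick whole expert outputs per token, which are then scaled by the router weights. A sketch of just that selection, under the same toy shapes as the previous note:

    import torch

    B, S, num_experts, num_groups, head_dim = 2, 4, 8, 2, 8
    k_all = torch.randn(B, S, num_experts, head_dim)   # all experts' outputs
    weights = torch.rand(B, S, num_groups, 1)          # router weights
    indices = torch.randint(0, num_experts, (B, S, num_groups))

    expanded = indices.unsqueeze(-1).expand(-1, -1, -1, head_dim)  # [B, S, num_groups, head_dim]
    selected_k = torch.gather(k_all, 2, expanded)                  # [B, S, num_groups, head_dim]
    weighted_k = selected_k * weights                              # router-weighted expert outputs
    k = weighted_k.permute(0, 2, 1, 3)                             # [B, num_groups, S, head_dim]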
@@ -130,6 +151,10 @@ class GroupedMoeAttention(GroupedQueryAttention):
  k = k.flatten(start_dim=1, end_dim=2) # (B, H, S, head_dim)
  v = v.flatten(start_dim=1, end_dim=2) # (B, H, S, head_dim)

+ print('out 2')
+ print(k.size(), k.dtype)
+ print(v.size(), v.dtype)
+
  return q, k, v


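Note: the branch whose condition is flipped above (`if not self.rel_embed:` → `if self.rel_embed:`) expands the grouped K/V to one tensor per attention head by repeating each group's K/V for the heads in that group. A small sketch of that expand-and-flatten under the same toy shapes (not the package's exact code):

    import torch

    B, num_groups, num_heads, S, head_dim = 2, 2, 8, 4, 8
    k = torch.randn(B, num_groups, S, head_dim)

    group_heads = num_heads // num_groups
    k = k.unsqueeze(2).expand(-1, -1, group_heads, -1, -1)  # [B, G, group_heads, S, head_dim]
    k = k.flatten(start_dim=1, end_dim=2)                   # [B, H, S, head_dim]
    assert k.shape == (B, num_heads, S, head_dim)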
@@ -220,10 +245,10 @@ class DeepMoeAttention(GroupedMoeAttention):
  if self.use_bias:
  q_all += self.bq

- q_all = q_all.view(B, T, self.num_query_experts, -1)
+ q_all = q_all.view(B, T, self.num_query_experts, -1) # [B, T, num_query_experts, head_dim]

  # Gather top-k experts
- expanded_indices = indices_q.unsqueeze(-1).expand(-1, -1, -1, q_all.size(-1))
+ expanded_indices = indices_q.unsqueeze(-1).expand(-1, -1, -1, q_all.size(-1)) # [B, T, num_query_groups, head_dim]
  selected_q = torch.gather(q_all, 2, expanded_indices) # [B, T, num_query_groups, head_dim]

  # Weighted sum
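
Note: the query path in DeepMoeAttention mirrors the key/value path, only per query token — num_query_experts candidate projections are computed and num_query_groups of them are gathered per position. The shape comments added above can be checked with a quick assertion sketch (toy values, not the real module):

    import torch

    B, T, num_query_experts, num_query_groups, head_dim = 2, 4, 8, 2, 8
    q_all = torch.randn(B * T, num_query_experts, head_dim)
    indices_q = torch.randint(0, num_query_experts, (B, T, num_query_groups))

    q_all = q_all.view(B, T, num_query_experts, -1)  # [B, T, num_query_experts, head_dim]
    expanded_indices = indices_q.unsqueeze(-1).expand(-1, -1, -1, q_all.size(-1))
    selected_q = torch.gather(q_all, 2, expanded_indices)
    assert selected_q.shape == (B, T, num_query_groups, head_dim)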
@@ -106,6 +106,9 @@ class MultiHeadAttention(nn.Module):
  # with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
  # return self._torch_attention(q, k, v, b, t, d, mask=mask, enable_gqa=enable_gqa)
  from flash_attn import flash_attn_func
+ print(q.size(), q.dtype)
+ print(k.size(), k.dtype)
+ print(v.size(), v.dtype)
  attn_output = flash_attn_func(q, k, v, dropout_p=self.dropout.p if self.training else 0.0, causal=self.is_causal)
  return self._transpose_output(attn_output, b, t, d)

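Note: the size/dtype prints inserted before the flash_attn_func call are the relevant debug points — flash_attn_func expects q, k, v as [batch, seqlen, num_heads, head_dim] tensors in fp16 or bf16 on a CUDA device, and rejects fp32 inputs. A hedged guard sketch of those constraints (assuming that layout; this helper is not part of the package):

    import torch

    def check_flash_attn_inputs(q, k, v):
        # Validate the input constraints flash_attn_func enforces
        for name, t in (('q', q), ('k', k), ('v', v)):
            assert t.is_cuda, f'{name} must be on a CUDA device'
            assert t.dtype in (torch.float16, torch.bfloat16), \
                f'{name} must be fp16/bf16, got {t.dtype}'
        # q: [B, S_q, H_q, head_dim]; k/v may carry fewer heads (GQA/MQA)
        assert q.size(-1) == k.size(-1) == v.size(-1), 'head_dim must match'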
@@ -1,6 +1,6 @@
  Metadata-Version: 2.3
  Name: rxnn
- Version: 0.1.38
+ Version: 0.1.40
  Summary: RxNN: Reactive Neural Networks Platform
  License: Apache-2.0
  Keywords: deep-learning,ai,machine-learning
@@ -1,6 +1,6 @@
  rxnn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  rxnn/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- rxnn/experimental/attention.py,sha256=PjmVwNeJXDy72LJr5cl9JD1oqjlwYK-Ahx1K1gLQgf8,29426
+ rxnn/experimental/attention.py,sha256=ZvLTfuTk05SzuRXNS91rc1TuJVPiOB8EvNfFpb4TLGw,30309
  rxnn/experimental/models.py,sha256=IzUVc5s-cA__8jsG2mVvzUDmzPRcfBcI5btaOjnPYhA,4598
  rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
  rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -16,7 +16,7 @@ rxnn/training/dataset.py,sha256=vQ5mDF3bA0HXya474n4D4iL8Mn3AEpJukgzFNVkxjGU,5106
  rxnn/training/scheduler.py,sha256=ow6oALzWjWQmHSpcJEjv6tg4g4CDMvr73TypxfcefMc,712
  rxnn/training/tokenizer.py,sha256=4Y41f07uo2KPA_7bp3FCcwGKbXoS2hsckOoXUsXfQxY,8052
  rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- rxnn/transformers/attention.py,sha256=dC0UmC-_kjX8US6Sf0Fi5zw5kJ-P6orH3JDHeBB5gI8,15695
+ rxnn/transformers/attention.py,sha256=NjMjg4K4sIyQL4BjPhF1KeYe6yhUpOVUyc9AHS7rWEw,15794
  rxnn/transformers/ff.py,sha256=jJnuBDsnnX5uYC_WZH8cXAYrMnz0P-iX7MwcPivjRtI,2533
  rxnn/transformers/layers.py,sha256=n_jZTqEF_vLkF31AkB5XGErfm2sQFd9CRqJUHKRFkKI,6956
  rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
@@ -25,7 +25,7 @@ rxnn/transformers/moe.py,sha256=6Cffyo0QjmEWc4rK1ncOmLRCQbY0OpQJ4D7xH_4nTN4,4738
  rxnn/transformers/positional.py,sha256=2l38RS0Dini3f6Z3LUHr3XwWzg1UK7fO2C6wazWDAYU,4292
  rxnn/transformers/sampler.py,sha256=poWBpxg1iuK5gEJtxHkk5VVfS9V48hs2Olqdhy_Gw8c,6548
  rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
- rxnn-0.1.38.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
- rxnn-0.1.38.dist-info/METADATA,sha256=_2XywHD_OlazI_xonQN5xhLGIaHFgYTAMWNi3HiQXnE,16627
- rxnn-0.1.38.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
- rxnn-0.1.38.dist-info/RECORD,,
+ rxnn-0.1.40.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+ rxnn-0.1.40.dist-info/METADATA,sha256=wiq3LTNOFaUDCN9Nsbxc_yjQX2eICtIRBVdtQTvL3Zc,16627
+ rxnn-0.1.40.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
+ rxnn-0.1.40.dist-info/RECORD,,