rxnn-0.1.41-py3-none-any.whl → rxnn-0.1.43-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
--- rxnn/experimental/attention.py (0.1.41)
+++ rxnn/experimental/attention.py (0.1.43)
@@ -83,20 +83,16 @@ class GroupedMoeAttention(GroupedQueryAttention):
 
     def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int,
                      skip_query_processing: bool = False):
-        head_dim = d // self.num_heads
-
         # Process Query as in GQA
         q = self.q_proj(query).view(b, t, self.num_heads, -1).transpose(1, 2) if not skip_query_processing else query
 
         # Key/Value MoE routing
         B, S, D = key.shape
-        print('key/value type', key.dtype, value.dtype)
         key_flat = key.reshape(-1, D)
-        print('key flat type', key_flat.dtype)
         weights, indices = self.router(key_flat)  # (B*S, num_groups), (B*S, num_groups)
         weights = weights.view(B, S, self.num_groups, 1)
         indices = indices.view(B, S, self.num_groups)
-        print('weights/indices type', weights.dtype, indices.dtype)
+
         # Compute all experts' projections
         # Shape: (B*S, num_experts, head_dim)
         k_all = torch.einsum('bd,edh->beh', key_flat, self.wk)  # [B*S, num_experts, head_dim]
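Note: the hunk above drops an unused `head_dim` local and the dtype debug prints; the routing logic is unchanged. Each flattened key token is scored by `self.router`, which returns per-token expert weights and indices, and the einsum then applies every expert's projection in one batched contraction. A minimal sketch of that pattern, with toy shapes and a hypothetical `TopkRouter` standing in for `self.router` (whose definition is not part of this diff):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class TopkRouter(nn.Module):
        # Hypothetical gate: softmax over experts, then keep the top-k.
        def __init__(self, dim: int, num_experts: int, top_k: int):
            super().__init__()
            self.gate = nn.Linear(dim, num_experts, bias=False)
            self.top_k = top_k

        def forward(self, x: torch.Tensor):
            probs = F.softmax(self.gate(x), dim=-1)            # [N, num_experts]
            weights, indices = probs.topk(self.top_k, dim=-1)  # [N, top_k] each
            return weights, indices

    # 'bd,edh->beh' projects every token with every expert at once:
    N, D, E, H = 6, 16, 4, 8
    key_flat = torch.randn(N, D)
    wk = torch.randn(E, D, H)                          # one [D, H] matrix per expert
    k_all = torch.einsum('bd,edh->beh', key_flat, wk)  # [N, E, H]
    # Equivalent per-expert loop, for intuition:
    k_loop = torch.stack([key_flat @ wk[e] for e in range(E)], dim=1)
    assert torch.allclose(k_all, k_loop, atol=1e-5)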
@@ -106,43 +102,23 @@ class GroupedMoeAttention(GroupedQueryAttention):
         k_all += self.bk
         v_all += self.bv
 
-        print('k all/v all before get all')
-        print(k_all.size(), k_all.dtype)
-        print(v_all.size(), v_all.dtype)
-
         # Get results for all heads
         k_all = k_all.view(B, S, self.num_experts, -1)  # [B, S, num_experts, head_dim]
         v_all = v_all.view(B, S, self.num_experts, -1)  # [B, S, num_experts, head_dim]
 
-        print('k all/v all get all')
-        print(k_all.size(), k_all.dtype)
-        print(v_all.size(), v_all.dtype)
-
         # Gather top-k experts using expanded indices
         expanded_indices = indices.unsqueeze(-1).expand(-1, -1, -1, k_all.size(-1))  # [B, S, num_groups, head_dim]
         selected_k = torch.gather(k_all, 2, expanded_indices)  # [B, S, num_groups, head_dim]
         selected_v = torch.gather(v_all, 2, expanded_indices)  # [B, S, num_groups, head_dim]
 
-        print('selected k/selected v')
-        print(selected_k.size(), selected_k.dtype)
-        print(selected_v.size(), selected_v.dtype)
-
         # Weighted
-        weighted_k = selected_k * weights  # [B, S, num_groups, head_dim]
-        weighted_v = selected_v * weights  # [B, S, num_groups, head_dim]
-
-        print('weighted')
-        print(weighted_k.size(), weighted_k.dtype)
-        print(weighted_v.size(), weighted_v.dtype)
+        weighted_k = (selected_k * weights).to(selected_k.device, dtype=selected_k.dtype)  # [B, S, num_groups, head_dim]
+        weighted_v = (selected_v * weights).to(selected_k.device, dtype=selected_k.dtype)  # [B, S, num_groups, head_dim]
 
         # Reshape to GQA format
         k = weighted_k.view(B, S, self.num_groups, -1).permute(0, 2, 1, 3)  # [B, num_groups, S, head_dim]
         v = weighted_v.view(B, S, self.num_groups, -1).permute(0, 2, 1, 3)  # [B, num_groups, S, head_dim]
 
-        print('out 1')
-        print(k.size(), k.dtype)
-        print(v.size(), v.dtype)
-
         if self.rel_embed:
             group_heads = self.num_heads // self.num_groups
 
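Note: besides removing prints, this hunk carries the functional change of the release: the weighted K/V are now explicitly cast back to the device and dtype of the gathered projections. A self-contained sketch of the gather-then-weight step, with assumed toy shapes:

    import torch

    B, S, E, G, Hd = 2, 5, 4, 2, 8            # batch, seq, experts, groups, head_dim
    k_all = torch.randn(B, S, E, Hd)          # all experts' key projections
    indices = torch.randint(0, E, (B, S, G))  # router's chosen expert ids
    weights = torch.rand(B, S, G, 1)          # router's gate weights

    # Expand indices over head_dim so gather selects whole [Hd] vectors.
    expanded = indices.unsqueeze(-1).expand(-1, -1, -1, Hd)  # [B, S, G, Hd]
    selected_k = torch.gather(k_all, 2, expanded)            # [B, S, G, Hd]

    # The 0.1.43 cast pins the product to selected_k's dtype, so float32
    # gate weights cannot upcast half-precision projections.
    weighted_k = (selected_k * weights).to(selected_k.device, dtype=selected_k.dtype)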
@@ -152,10 +128,6 @@ class GroupedMoeAttention(GroupedQueryAttention):
             k = k.flatten(start_dim=1, end_dim=2)  # (B, H, S, head_dim)
             v = v.flatten(start_dim=1, end_dim=2)  # (B, H, S, head_dim)
 
-        print('out 2')
-        print(k.size(), k.dtype)
-        print(v.size(), v.dtype)
-
         return q, k, v
 
 
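Note: the surviving `flatten(start_dim=1, end_dim=2)` calls collapse the grouped K/V back to one tensor per attention head for the relative-embedding path. The expansion step itself sits outside this hunk, so the sketch below assumes the usual GQA broadcast from num_groups to num_heads = num_groups * group_heads:

    import torch

    B, G, S, Hd, group_heads = 2, 2, 5, 8, 3
    k = torch.randn(B, G, S, Hd)                            # grouped keys

    k = k.unsqueeze(2).expand(-1, -1, group_heads, -1, -1)  # [B, G, group_heads, S, Hd]
    k = k.flatten(start_dim=1, end_dim=2)                   # [B, H, S, Hd], H = G * group_heads
    assert k.shape == (B, G * group_heads, S, Hd)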
@@ -253,7 +225,7 @@ class DeepMoeAttention(GroupedMoeAttention):
         selected_q = torch.gather(q_all, 2, expanded_indices)  # [B, T, num_query_groups, head_dim]
 
         # Weighted sum
-        q = selected_q * weights_q  # [B, T, num_query_groups, head_dim]
+        q = (selected_q * weights_q).to(selected_q.device, dtype=selected_q.dtype)  # [B, T, num_query_groups, head_dim]
         q = q.view(B, T, self.num_query_groups, -1).permute(0, 2, 1, 3)  # [B, num_query_groups, T, head_dim]
 
         return super()._forward_qkv(q, key, value, b, t, d, skip_query_processing=True)
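Note: the same dtype-pinning cast is applied to the routed query in DeepMoeAttention. The motivation is not stated in the diff, but it is consistent with the removed dtype prints: under PyTorch type promotion, multiplying half-precision activations by float32 router weights silently upcasts the product to float32, which the flash-attention kernel used downstream then rejects. A minimal reproduction:

    import torch

    selected_q = torch.randn(2, 5, 2, 8, dtype=torch.float16)
    weights_q = torch.rand(2, 5, 2, 1, dtype=torch.float32)  # e.g. a float32 softmax gate

    q_old = selected_q * weights_q
    print(q_old.dtype)   # torch.float32 -- promotion upcast the product

    q_new = (selected_q * weights_q).to(selected_q.device, dtype=selected_q.dtype)
    print(q_new.dtype)   # torch.float16 -- pinned back to the activation dtype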
--- rxnn/transformers/attention.py (0.1.41)
+++ rxnn/transformers/attention.py (0.1.43)
@@ -106,9 +106,6 @@ class MultiHeadAttention(nn.Module):
         # with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
         #     return self._torch_attention(q, k, v, b, t, d, mask=mask, enable_gqa=enable_gqa)
         from flash_attn import flash_attn_func
-        print(q.size(), q.dtype)
-        print(k.size(), k.dtype)
-        print(v.size(), v.dtype)
         attn_output = flash_attn_func(q, k, v, dropout_p=self.dropout.p if self.training else 0.0, causal=self.is_causal)
         return self._transpose_output(attn_output, b, t, d)
 
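Note: the prints removed here sat directly in front of the flash_attn_func call, which is the kernel those dtypes matter to: flash-attn 2.x expects q, k, v laid out as (batch, seqlen, num_heads, head_dim) CUDA tensors in fp16 or bf16. A usage sketch with illustrative shapes (requires the flash-attn package and a CUDA device):

    import torch
    from flash_attn import flash_attn_func

    b, s, h, d = 2, 128, 8, 64
    q = torch.randn(b, s, h, d, device='cuda', dtype=torch.float16)
    k = torch.randn(b, s, h, d, device='cuda', dtype=torch.float16)
    v = torch.randn(b, s, h, d, device='cuda', dtype=torch.float16)

    # dropout_p is only applied in training; causal=True masks future positions.
    out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)  # [b, s, h, d]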
@@ -137,10 +134,6 @@ class MultiHeadAttention(nn.Module):
         return self._calculate_output(attn_weights, v, b, t, d)
 
     def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: torch.Tensor = None):
-        print('MHA forward')
-        print(query.size(), query.dtype)
-        print(key.size(), key.dtype)
-        print(value.size(), value.dtype)
         b, t, d = query.size()
         q, k, v = self._forward_qkv(query, key, value, b, t, d)
         if not self.rel_embed:
--- rxnn-0.1.41.dist-info/METADATA
+++ rxnn-0.1.43.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.1.41
+Version: 0.1.43
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning
--- rxnn-0.1.41.dist-info/RECORD
+++ rxnn-0.1.43.dist-info/RECORD
@@ -1,6 +1,6 @@
 rxnn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/experimental/attention.py,sha256=tvyjyBWLh9Bcu3MXlo3yx1hA48La8Zks4wWYtcnOYc8,30365
+rxnn/experimental/attention.py,sha256=tyeZ2xkZfFrlTWzqjKS13NnIRSLGe1-oR7J20hN6BgI,29602
 rxnn/experimental/models.py,sha256=IzUVc5s-cA__8jsG2mVvzUDmzPRcfBcI5btaOjnPYhA,4598
 rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -16,7 +16,7 @@ rxnn/training/dataset.py,sha256=vQ5mDF3bA0HXya474n4D4iL8Mn3AEpJukgzFNVkxjGU,5106
 rxnn/training/scheduler.py,sha256=ow6oALzWjWQmHSpcJEjv6tg4g4CDMvr73TypxfcefMc,712
 rxnn/training/tokenizer.py,sha256=4Y41f07uo2KPA_7bp3FCcwGKbXoS2hsckOoXUsXfQxY,8052
 rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/transformers/attention.py,sha256=ZDdu1plCLhh-c5x-ZSMD1Avl36LMuTJSq617hD_hLCg,15942
+rxnn/transformers/attention.py,sha256=dC0UmC-_kjX8US6Sf0Fi5zw5kJ-P6orH3JDHeBB5gI8,15695
 rxnn/transformers/ff.py,sha256=jJnuBDsnnX5uYC_WZH8cXAYrMnz0P-iX7MwcPivjRtI,2533
 rxnn/transformers/layers.py,sha256=n_jZTqEF_vLkF31AkB5XGErfm2sQFd9CRqJUHKRFkKI,6956
 rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
@@ -25,7 +25,7 @@ rxnn/transformers/moe.py,sha256=6Cffyo0QjmEWc4rK1ncOmLRCQbY0OpQJ4D7xH_4nTN4,4738
 rxnn/transformers/positional.py,sha256=2l38RS0Dini3f6Z3LUHr3XwWzg1UK7fO2C6wazWDAYU,4292
 rxnn/transformers/sampler.py,sha256=poWBpxg1iuK5gEJtxHkk5VVfS9V48hs2Olqdhy_Gw8c,6548
 rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
-rxnn-0.1.41.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
-rxnn-0.1.41.dist-info/METADATA,sha256=B9AeelUbhWbH1uYB17WIEz4gjo_1c0V5jTM5S2QgulE,16627
-rxnn-0.1.41.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
-rxnn-0.1.41.dist-info/RECORD,,
+rxnn-0.1.43.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.1.43.dist-info/METADATA,sha256=7iCT0KM5FyGKZwe9ZbxnULu9z-rwyko4Xbklok1ZsOs,16627
+rxnn-0.1.43.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
+rxnn-0.1.43.dist-info/RECORD,,