rxnn 0.1.41.tar.gz → 0.1.43.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31)
  1. {rxnn-0.1.41 → rxnn-0.1.43}/PKG-INFO +1 -1
  2. {rxnn-0.1.41 → rxnn-0.1.43}/pyproject.toml +1 -1
  3. {rxnn-0.1.41 → rxnn-0.1.43}/src/rxnn/experimental/attention.py +4 -32
  4. {rxnn-0.1.41 → rxnn-0.1.43}/src/rxnn/transformers/attention.py +0 -7
  5. {rxnn-0.1.41 → rxnn-0.1.43}/LICENSE +0 -0
  6. {rxnn-0.1.41 → rxnn-0.1.43}/README.md +0 -0
  7. {rxnn-0.1.41 → rxnn-0.1.43}/src/rxnn/__init__.py +0 -0
  8. {rxnn-0.1.41 → rxnn-0.1.43}/src/rxnn/experimental/__init__.py +0 -0
  9. {rxnn-0.1.41 → rxnn-0.1.43}/src/rxnn/experimental/models.py +0 -0
  10. {rxnn-0.1.41 → rxnn-0.1.43}/src/rxnn/experimental/moe.py +0 -0
  11. {rxnn-0.1.41 → rxnn-0.1.43}/src/rxnn/memory/__init__.py +0 -0
  12. {rxnn-0.1.41 → rxnn-0.1.43}/src/rxnn/memory/norm.py +0 -0
  13. {rxnn-0.1.41 → rxnn-0.1.43}/src/rxnn/memory/stm.py +0 -0
  14. {rxnn-0.1.41 → rxnn-0.1.43}/src/rxnn/rxt/__init__.py +0 -0
  15. {rxnn-0.1.41 → rxnn-0.1.43}/src/rxnn/rxt/models.py +0 -0
  16. {rxnn-0.1.41 → rxnn-0.1.43}/src/rxnn/training/__init__.py +0 -0
  17. {rxnn-0.1.41 → rxnn-0.1.43}/src/rxnn/training/base.py +0 -0
  18. {rxnn-0.1.41 → rxnn-0.1.43}/src/rxnn/training/bml.py +0 -0
  19. {rxnn-0.1.41 → rxnn-0.1.43}/src/rxnn/training/callbacks.py +0 -0
  20. {rxnn-0.1.41 → rxnn-0.1.43}/src/rxnn/training/dataset.py +0 -0
  21. {rxnn-0.1.41 → rxnn-0.1.43}/src/rxnn/training/scheduler.py +0 -0
  22. {rxnn-0.1.41 → rxnn-0.1.43}/src/rxnn/training/tokenizer.py +0 -0
  23. {rxnn-0.1.41 → rxnn-0.1.43}/src/rxnn/transformers/__init__.py +0 -0
  24. {rxnn-0.1.41 → rxnn-0.1.43}/src/rxnn/transformers/ff.py +0 -0
  25. {rxnn-0.1.41 → rxnn-0.1.43}/src/rxnn/transformers/layers.py +0 -0
  26. {rxnn-0.1.41 → rxnn-0.1.43}/src/rxnn/transformers/mask.py +0 -0
  27. {rxnn-0.1.41 → rxnn-0.1.43}/src/rxnn/transformers/models.py +0 -0
  28. {rxnn-0.1.41 → rxnn-0.1.43}/src/rxnn/transformers/moe.py +0 -0
  29. {rxnn-0.1.41 → rxnn-0.1.43}/src/rxnn/transformers/positional.py +0 -0
  30. {rxnn-0.1.41 → rxnn-0.1.43}/src/rxnn/transformers/sampler.py +0 -0
  31. {rxnn-0.1.41 → rxnn-0.1.43}/src/rxnn/utils.py +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.3
  Name: rxnn
- Version: 0.1.41
+ Version: 0.1.43
  Summary: RxNN: Reactive Neural Networks Platform
  License: Apache-2.0
  Keywords: deep-learning,ai,machine-learning
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"

  [tool.poetry]
  name = "rxnn"
- version = "0.1.41"
+ version = "0.1.43"
  description = "RxNN: Reactive Neural Networks Platform"

  license = "Apache-2.0"
@@ -83,20 +83,16 @@ class GroupedMoeAttention(GroupedQueryAttention):

  def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int,
  skip_query_processing: bool = False):
- head_dim = d // self.num_heads
-
  # Process Query as in GQA
  q = self.q_proj(query).view(b, t, self.num_heads, -1).transpose(1, 2) if not skip_query_processing else query

  # Key/Value MoE routing
  B, S, D = key.shape
- print('key/value type', key.dtype, value.dtype)
  key_flat = key.reshape(-1, D)
- print('key flat type', key_flat.dtype)
  weights, indices = self.router(key_flat) # (B*S, num_groups), (B*S, num_groups)
  weights = weights.view(B, S, self.num_groups, 1)
  indices = indices.view(B, S, self.num_groups)
- print('weights/indices type', weights.dtype, indices.dtype)
+
  # Compute all experts' projections
  # Shape: (B*S, num_experts, head_dim)
  k_all = torch.einsum('bd,edh->beh', key_flat, self.wk) # [B*S, num_experts, head_dim]
@@ -106,43 +102,23 @@ class GroupedMoeAttention(GroupedQueryAttention):
  k_all += self.bk
  v_all += self.bv

- print('k all/v all before get all')
- print(k_all.size(), k_all.dtype)
- print(v_all.size(), v_all.dtype)
-
  # Get results for all heads
  k_all = k_all.view(B, S, self.num_experts, -1) # [B, S, num_experts, head_dim]
  v_all = v_all.view(B, S, self.num_experts, -1) # [B, S, num_experts, head_dim]

- print('k all/v all get all')
- print(k_all.size(), k_all.dtype)
- print(v_all.size(), v_all.dtype)
-
  # Gather top-k experts using expanded indices
  expanded_indices = indices.unsqueeze(-1).expand(-1, -1, -1, k_all.size(-1)) # [B, S, num_groups, head_dim]
  selected_k = torch.gather(k_all, 2, expanded_indices) # [B, S, num_groups, head_dim]
  selected_v = torch.gather(v_all, 2, expanded_indices) # [B, S, num_groups, head_dim]

- print('selected k/selected v')
- print(selected_k.size(), selected_k.dtype)
- print(selected_v.size(), selected_v.dtype)
-
  # Weighted
- weighted_k = selected_k * weights # [B, S, num_groups, head_dim]
- weighted_v = selected_v * weights # [B, S, num_groups, head_dim]
-
- print('weighted')
- print(weighted_k.size(), weighted_k.dtype)
- print(weighted_v.size(), weighted_v.dtype)
+ weighted_k = (selected_k * weights).to(selected_k.device, dtype=selected_k.dtype) # [B, S, num_groups, head_dim]
+ weighted_v = (selected_v * weights).to(selected_k.device, dtype=selected_k.dtype) # [B, S, num_groups, head_dim]

  # Reshape to GQA format
  k = weighted_k.view(B, S, self.num_groups, -1).permute(0, 2, 1, 3) # [B, num_groups, S, head_dim]
  v = weighted_v.view(B, S, self.num_groups, -1).permute(0, 2, 1, 3) # [B, num_groups, S, head_dim]

- print('out 1')
- print(k.size(), k.dtype)
- print(v.size(), v.dtype)
-
  if self.rel_embed:
  group_heads = self.num_heads // self.num_groups

@@ -152,10 +128,6 @@ class GroupedMoeAttention(GroupedQueryAttention):
  k = k.flatten(start_dim=1, end_dim=2) # (B, H, S, head_dim)
  v = v.flatten(start_dim=1, end_dim=2) # (B, H, S, head_dim)

- print('out 2')
- print(k.size(), k.dtype)
- print(v.size(), v.dtype)
-
  return q, k, v
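
Note on the hunks above: apart from dropping the temporary `print` tracing, the only functional change is that the weighted expert outputs are now cast back to the device and dtype of the gathered tensors. Below is a minimal, self-contained sketch of the tensor flow this diff shows (router scores, per-expert `einsum` projection, `gather` of the selected groups, weighted combine with a dtype-restoring cast). The sizes and the top-k router are stand-ins for illustration only, not the rxnn implementation.

```python
import torch

# Stand-in sizes (hypothetical, for illustration only).
B, S, D = 2, 16, 64                      # batch, sequence, model dim
num_experts, num_groups, head_dim = 4, 2, 32

key = torch.randn(B, S, D, dtype=torch.bfloat16)
wk = torch.randn(num_experts, D, head_dim, dtype=torch.bfloat16)  # per-expert K projection
router_w = torch.randn(D, num_experts)                            # hypothetical router weights (fp32)

# Router: softmax scores, keep the top num_groups experts per token.
key_flat = key.reshape(-1, D)                                # [B*S, D]
scores = (key_flat.float() @ router_w).softmax(dim=-1)       # [B*S, num_experts], fp32
weights, indices = torch.topk(scores, num_groups, dim=-1)    # both [B*S, num_groups]
weights = weights.view(B, S, num_groups, 1)
indices = indices.view(B, S, num_groups)

# All experts' K projections, then gather only the selected groups.
k_all = torch.einsum('bd,edh->beh', key_flat, wk)            # [B*S, num_experts, head_dim], bf16
k_all = k_all.view(B, S, num_experts, head_dim)
expanded = indices.unsqueeze(-1).expand(-1, -1, -1, head_dim)
selected_k = torch.gather(k_all, 2, expanded)                # [B, S, num_groups, head_dim], bf16

# Multiplying by the fp32 router weights promotes the product to fp32;
# the 0.1.43 change casts it back so downstream code keeps seeing bf16.
weighted_k = (selected_k * weights).to(selected_k.device, dtype=selected_k.dtype)
print(weighted_k.shape, weighted_k.dtype)                    # torch.Size([2, 16, 2, 32]) torch.bfloat16
```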


@@ -253,7 +225,7 @@ class DeepMoeAttention(GroupedMoeAttention):
  selected_q = torch.gather(q_all, 2, expanded_indices) # [B, T, num_query_groups, head_dim]

  # Weighted sum
- q = selected_q * weights_q # [B, T, num_query_groups, head_dim]
+ q = (selected_q * weights_q).to(selected_q.device, dtype=selected_q.dtype) # [B, T, num_query_groups, head_dim]
  q = q.view(B, T, self.num_query_groups, -1).permute(0, 2, 1, 3) # [B, num_query_groups, T, head_dim]

  return super()._forward_qkv(q, key, value, b, t, d, skip_query_processing=True)
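
The `DeepMoeAttention` change is the same fix applied to the query path. A likely reason (my reading; the diff itself does not say) is PyTorch type promotion: multiplying half-precision expert outputs by fp32 router weights silently upcasts the product to fp32, which can then trip kernels downstream that expect fp16/bf16. A tiny sketch of the promotion and of the cast the diff adds:

```python
import torch

selected_q = torch.randn(2, 8, 2, 32, dtype=torch.bfloat16)  # gathered expert outputs
weights_q = torch.rand(2, 8, 2, 1)                           # router weights, fp32

promoted = selected_q * weights_q
print(promoted.dtype)            # torch.float32 -- silent upcast via type promotion

# Casting back, as the 0.1.43 diff does, keeps the original dtype flowing downstream.
q = (selected_q * weights_q).to(selected_q.device, dtype=selected_q.dtype)
print(q.dtype)                   # torch.bfloat16
```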
@@ -106,9 +106,6 @@ class MultiHeadAttention(nn.Module):
  # with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
  # return self._torch_attention(q, k, v, b, t, d, mask=mask, enable_gqa=enable_gqa)
  from flash_attn import flash_attn_func
- print(q.size(), q.dtype)
- print(k.size(), k.dtype)
- print(v.size(), v.dtype)
  attn_output = flash_attn_func(q, k, v, dropout_p=self.dropout.p if self.training else 0.0, causal=self.is_causal)
  return self._transpose_output(attn_output, b, t, d)
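
For context on this hunk: `flash_attn_func` from the flash-attn package only accepts fp16/bf16 CUDA tensors laid out as `(batch, seqlen, nheads, head_dim)`, which is consistent with the shape/dtype prints that were being used here and are now removed. A hedged sketch of a defensive wrapper, not part of rxnn, that falls back to PyTorch SDPA when flash-attn is unavailable or the inputs do not qualify:

```python
import torch
import torch.nn.functional as F

def attention_with_optional_flash(q, k, v, dropout_p=0.0, causal=True):
    """Hypothetical helper (not part of rxnn): use flash-attn when the inputs
    qualify (CUDA + fp16/bf16), otherwise fall back to PyTorch SDPA.

    q, k, v: (batch, seqlen, nheads, head_dim), as flash_attn_func expects.
    """
    flash_ok = q.is_cuda and q.dtype in (torch.float16, torch.bfloat16)
    if flash_ok:
        try:
            from flash_attn import flash_attn_func
            return flash_attn_func(q, k, v, dropout_p=dropout_p, causal=causal)
        except ImportError:
            pass
    # SDPA wants (batch, nheads, seqlen, head_dim); transpose in and out.
    out = F.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
        dropout_p=dropout_p, is_causal=causal,
    )
    return out.transpose(1, 2)
```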
@@ -137,10 +134,6 @@ class MultiHeadAttention(nn.Module):
  return self._calculate_output(attn_weights, v, b, t, d)

  def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: torch.Tensor = None):
- print('MHA forward')
- print(query.size(), query.dtype)
- print(key.size(), key.dtype)
- print(value.size(), value.dtype)
  b, t, d = query.size()
  q, k, v = self._forward_qkv(query, key, value, b, t, d)
  if not self.rel_embed:
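
The final hunk only strips the remaining shape/dtype `print` calls from `MultiHeadAttention.forward`. If that kind of tracing is ever needed again, one option (a sketch, not something rxnn ships) is to gate it behind the standard `logging` module so the hot path stays quiet unless DEBUG logging is enabled:

```python
import logging
import torch

logger = logging.getLogger("rxnn.attention.debug")  # hypothetical logger name

def trace_qkv(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> None:
    """Log shapes/dtypes only when DEBUG logging is enabled, instead of print()."""
    if logger.isEnabledFor(logging.DEBUG):
        logger.debug("q %s %s | k %s %s | v %s %s",
                     tuple(query.shape), query.dtype,
                     tuple(key.shape), key.dtype,
                     tuple(value.shape), value.dtype)
```

During a debugging run, `logging.basicConfig(level=logging.DEBUG)` turns the traces on without touching the forward pass otherwise.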