rxnn 0.2.53__tar.gz → 0.2.54__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. {rxnn-0.2.53 → rxnn-0.2.54}/PKG-INFO +1 -1
  2. {rxnn-0.2.53 → rxnn-0.2.54}/pyproject.toml +1 -1
  3. {rxnn-0.2.53 → rxnn-0.2.54}/src/rxnn/experimental/attention.py +30 -27
  4. {rxnn-0.2.53 → rxnn-0.2.54}/LICENSE +0 -0
  5. {rxnn-0.2.53 → rxnn-0.2.54}/README.md +0 -0
  6. {rxnn-0.2.53 → rxnn-0.2.54}/src/rxnn/.DS_Store +0 -0
  7. {rxnn-0.2.53 → rxnn-0.2.54}/src/rxnn/__init__.py +0 -0
  8. {rxnn-0.2.53 → rxnn-0.2.54}/src/rxnn/experimental/__init__.py +0 -0
  9. {rxnn-0.2.53 → rxnn-0.2.54}/src/rxnn/experimental/models.py +0 -0
  10. {rxnn-0.2.53 → rxnn-0.2.54}/src/rxnn/experimental/moe.py +0 -0
  11. {rxnn-0.2.53 → rxnn-0.2.54}/src/rxnn/memory/__init__.py +0 -0
  12. {rxnn-0.2.53 → rxnn-0.2.54}/src/rxnn/memory/attention.py +0 -0
  13. {rxnn-0.2.53 → rxnn-0.2.54}/src/rxnn/memory/norm.py +0 -0
  14. {rxnn-0.2.53 → rxnn-0.2.54}/src/rxnn/memory/stm.py +0 -0
  15. {rxnn-0.2.53 → rxnn-0.2.54}/src/rxnn/rxt/__init__.py +0 -0
  16. {rxnn-0.2.53 → rxnn-0.2.54}/src/rxnn/rxt/models.py +0 -0
  17. {rxnn-0.2.53 → rxnn-0.2.54}/src/rxnn/training/__init__.py +0 -0
  18. {rxnn-0.2.53 → rxnn-0.2.54}/src/rxnn/training/base.py +0 -0
  19. {rxnn-0.2.53 → rxnn-0.2.54}/src/rxnn/training/bml.py +0 -0
  20. {rxnn-0.2.53 → rxnn-0.2.54}/src/rxnn/training/callbacks.py +0 -0
  21. {rxnn-0.2.53 → rxnn-0.2.54}/src/rxnn/training/dataset.py +0 -0
  22. {rxnn-0.2.53 → rxnn-0.2.54}/src/rxnn/training/ddp.py +0 -0
  23. {rxnn-0.2.53 → rxnn-0.2.54}/src/rxnn/training/models.py +0 -0
  24. {rxnn-0.2.53 → rxnn-0.2.54}/src/rxnn/training/mrl.py +0 -0
  25. {rxnn-0.2.53 → rxnn-0.2.54}/src/rxnn/training/reward.py +0 -0
  26. {rxnn-0.2.53 → rxnn-0.2.54}/src/rxnn/training/rl.py +0 -0
  27. {rxnn-0.2.53 → rxnn-0.2.54}/src/rxnn/training/scheduler.py +0 -0
  28. {rxnn-0.2.53 → rxnn-0.2.54}/src/rxnn/training/tokenizer.py +0 -0
  29. {rxnn-0.2.53 → rxnn-0.2.54}/src/rxnn/training/utils.py +0 -0
  30. {rxnn-0.2.53 → rxnn-0.2.54}/src/rxnn/transformers/__init__.py +0 -0
  31. {rxnn-0.2.53 → rxnn-0.2.54}/src/rxnn/transformers/attention.py +0 -0
  32. {rxnn-0.2.53 → rxnn-0.2.54}/src/rxnn/transformers/ff.py +0 -0
  33. {rxnn-0.2.53 → rxnn-0.2.54}/src/rxnn/transformers/layers.py +0 -0
  34. {rxnn-0.2.53 → rxnn-0.2.54}/src/rxnn/transformers/mask.py +0 -0
  35. {rxnn-0.2.53 → rxnn-0.2.54}/src/rxnn/transformers/models.py +0 -0
  36. {rxnn-0.2.53 → rxnn-0.2.54}/src/rxnn/transformers/moe.py +0 -0
  37. {rxnn-0.2.53 → rxnn-0.2.54}/src/rxnn/transformers/positional.py +0 -0
  38. {rxnn-0.2.53 → rxnn-0.2.54}/src/rxnn/transformers/sampler.py +0 -0
  39. {rxnn-0.2.53 → rxnn-0.2.54}/src/rxnn/utils.py +0 -0
{rxnn-0.2.53 → rxnn-0.2.54}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.2.53
+Version: 0.2.54
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning
{rxnn-0.2.53 → rxnn-0.2.54}/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"

 [tool.poetry]
 name = "rxnn"
-version = "0.2.53"
+version = "0.2.54"
 description = "RxNN: Reactive Neural Networks Platform"

 license = "Apache-2.0"
{rxnn-0.2.53 → rxnn-0.2.54}/src/rxnn/experimental/attention.py
@@ -367,7 +367,7 @@ class FlexAttention(MultiHeadAttention):
     def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: torch.Tensor = None):
         b, t, d = query.size()

-        # Prepend global tokens to the input query
+        # Prepend global tokens
         global_tokens = self.global_tokens.expand(b, -1, -1)
         x = torch.cat([global_tokens, query], dim=1)

@@ -380,39 +380,42 @@ class FlexAttention(MultiHeadAttention):
         if self.rope:
             q, k = self._apply_rope(q, k, separate=True)

-        # Split Q into global and local parts
-        global_q = q[:, :, :self.num_global_tokens]  # (B, H, G, head_dim)
-        local_q = q[:, :, self.num_global_tokens:]  # (B, H, L, head_dim)
+        # Split Q into global and local
+        global_q = q[:, :, :self.num_global_tokens]  # (B, H, G, D)
+        local_q = q[:, :, self.num_global_tokens:]  # (B, H, L, D)
+
+        # Global attention
+        global_attn = F.scaled_dot_product_attention(global_q, k, v, attn_mask=mask, dropout_p=self.dropout.p if self.training else 0, is_causal=self.is_causal)
+
+        # Local attention with windowed slicing (no large masks)
         L = local_q.size(2)
         S = k.size(2)
+        windowed_attn = []

-        # Global attention: global_q attends to all K/V
-        global_attn = F.scaled_dot_product_attention(
-            global_q, k, v, attn_mask=mask if not self.is_causal else None,
-            dropout_p=self.dropout.p if self.training else 0, is_causal=self.is_causal
-        )
+        for i in range(0, L, self.window_size):
+            start = i
+            end = min(i + self.window_size, L)
+            window_q = local_q[:, :, start:end]  # (B, H, W, D)

-        # Local attention: local_q attends to windowed K/V
-        # Vectorized window mask
-        indices = torch.arange(S, device=local_q.device)
-        local_pos = torch.arange(L, device=local_q.device)
-        local_window = (local_pos // self.window_size).unsqueeze(-1)  # (L, 1)
-        key_window = (indices // self.window_size).expand(L, -1)  # (L, S)
-        window_mask = (local_window == key_window).to(device=local_q.device)
-
-        local_attn = F.scaled_dot_product_attention(
-            local_q, k, v, attn_mask=window_mask if not self.is_causal else None,
-            dropout_p=self.dropout.p if self.training else 0, is_causal=self.is_causal
-        )
+            # Use only relevant keys/values (same window)
+            k_window_start = max(0, start - self.window_size)
+            k_window_end = min(S, end + self.window_size)
+            window_k = k[:, :, k_window_start:k_window_end]
+            window_v = v[:, :, k_window_start:k_window_end]
+
+
+            window_attn = F.scaled_dot_product_attention(window_q, window_k, window_v, attn_mask=None, dropout_p=self.dropout.p if self.training else 0, is_causal=self.is_causal)
+
+            windowed_attn.append(window_attn)

-        # Combine global and local attention outputs
-        attn = torch.cat([global_attn, local_attn], dim=2)  # (B, H, G+L, head_dim)
+        # Concat local attention
+        local_attn = torch.cat(windowed_attn, dim=2)

-        # Merge heads and project back
-        output = self._merge_heads(attn)  # (B, G+L, D)
-        output = self.out_proj(output)  # (B, G+L, D)
+        # Combine global and local
+        attn = torch.cat([global_attn, local_attn], dim=2)
+        output = self._merge_heads(attn)
+        output = self.out_proj(output)

-        # Return only the local tokens (original query tokens)
         return output[:, self.num_global_tokens:, :]

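For readers skimming the diff, the sketch below isolates the new local-attention path in plain PyTorch: instead of building an (L, S) window mask and running every query against all S keys (the 0.2.53 approach), 0.2.54 slices the key/value tensors so each query window only sees its own window plus one window on either side. The function name windowed_local_attention, the tensor shapes, and the example sizes are illustrative assumptions, not package API; the real FlexAttention.forward additionally handles global tokens, RoPE, dropout, masking, and the output projection.

import torch
import torch.nn.functional as F

def windowed_local_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, window_size: int) -> torch.Tensor:
    # q, k, v: (B, H, L, D) local queries/keys/values after the head split.
    # Each query window attends only to the keys/values of its own window plus
    # one neighbouring window on each side, so no (L, S) mask is materialized.
    L, S = q.size(2), k.size(2)
    chunks = []
    for start in range(0, L, window_size):
        end = min(start + window_size, L)
        k_start = max(0, start - window_size)
        k_end = min(S, end + window_size)
        chunks.append(F.scaled_dot_product_attention(
            q[:, :, start:end],       # (B, H, W, D)
            k[:, :, k_start:k_end],   # (B, H, <=3W, D)
            v[:, :, k_start:k_end],
        ))
    return torch.cat(chunks, dim=2)   # (B, H, L, D)

# Hypothetical sizes: 2 sequences, 4 heads, 256 local tokens, 64-dim heads, 128-token windows.
x = torch.randn(2, 4, 256, 64)
print(windowed_local_attention(x, x, x, window_size=128).shape)  # torch.Size([2, 4, 256, 64])

Note that the sliced key range (start - window_size to end + window_size) is wider than the strict same-window mask it replaces, so adjacent windows now attend to each other; the dropout and is_causal arguments used by the package are omitted here for brevity.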