rxnn 0.2.53.tar.gz → 0.2.55.tar.gz

This diff compares publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (39)
  1. {rxnn-0.2.53 → rxnn-0.2.55}/PKG-INFO +1 -1
  2. {rxnn-0.2.53 → rxnn-0.2.55}/pyproject.toml +1 -1
  3. {rxnn-0.2.53 → rxnn-0.2.55}/src/rxnn/experimental/attention.py +105 -30
  4. {rxnn-0.2.53 → rxnn-0.2.55}/LICENSE +0 -0
  5. {rxnn-0.2.53 → rxnn-0.2.55}/README.md +0 -0
  6. {rxnn-0.2.53 → rxnn-0.2.55}/src/rxnn/.DS_Store +0 -0
  7. {rxnn-0.2.53 → rxnn-0.2.55}/src/rxnn/__init__.py +0 -0
  8. {rxnn-0.2.53 → rxnn-0.2.55}/src/rxnn/experimental/__init__.py +0 -0
  9. {rxnn-0.2.53 → rxnn-0.2.55}/src/rxnn/experimental/models.py +0 -0
  10. {rxnn-0.2.53 → rxnn-0.2.55}/src/rxnn/experimental/moe.py +0 -0
  11. {rxnn-0.2.53 → rxnn-0.2.55}/src/rxnn/memory/__init__.py +0 -0
  12. {rxnn-0.2.53 → rxnn-0.2.55}/src/rxnn/memory/attention.py +0 -0
  13. {rxnn-0.2.53 → rxnn-0.2.55}/src/rxnn/memory/norm.py +0 -0
  14. {rxnn-0.2.53 → rxnn-0.2.55}/src/rxnn/memory/stm.py +0 -0
  15. {rxnn-0.2.53 → rxnn-0.2.55}/src/rxnn/rxt/__init__.py +0 -0
  16. {rxnn-0.2.53 → rxnn-0.2.55}/src/rxnn/rxt/models.py +0 -0
  17. {rxnn-0.2.53 → rxnn-0.2.55}/src/rxnn/training/__init__.py +0 -0
  18. {rxnn-0.2.53 → rxnn-0.2.55}/src/rxnn/training/base.py +0 -0
  19. {rxnn-0.2.53 → rxnn-0.2.55}/src/rxnn/training/bml.py +0 -0
  20. {rxnn-0.2.53 → rxnn-0.2.55}/src/rxnn/training/callbacks.py +0 -0
  21. {rxnn-0.2.53 → rxnn-0.2.55}/src/rxnn/training/dataset.py +0 -0
  22. {rxnn-0.2.53 → rxnn-0.2.55}/src/rxnn/training/ddp.py +0 -0
  23. {rxnn-0.2.53 → rxnn-0.2.55}/src/rxnn/training/models.py +0 -0
  24. {rxnn-0.2.53 → rxnn-0.2.55}/src/rxnn/training/mrl.py +0 -0
  25. {rxnn-0.2.53 → rxnn-0.2.55}/src/rxnn/training/reward.py +0 -0
  26. {rxnn-0.2.53 → rxnn-0.2.55}/src/rxnn/training/rl.py +0 -0
  27. {rxnn-0.2.53 → rxnn-0.2.55}/src/rxnn/training/scheduler.py +0 -0
  28. {rxnn-0.2.53 → rxnn-0.2.55}/src/rxnn/training/tokenizer.py +0 -0
  29. {rxnn-0.2.53 → rxnn-0.2.55}/src/rxnn/training/utils.py +0 -0
  30. {rxnn-0.2.53 → rxnn-0.2.55}/src/rxnn/transformers/__init__.py +0 -0
  31. {rxnn-0.2.53 → rxnn-0.2.55}/src/rxnn/transformers/attention.py +0 -0
  32. {rxnn-0.2.53 → rxnn-0.2.55}/src/rxnn/transformers/ff.py +0 -0
  33. {rxnn-0.2.53 → rxnn-0.2.55}/src/rxnn/transformers/layers.py +0 -0
  34. {rxnn-0.2.53 → rxnn-0.2.55}/src/rxnn/transformers/mask.py +0 -0
  35. {rxnn-0.2.53 → rxnn-0.2.55}/src/rxnn/transformers/models.py +0 -0
  36. {rxnn-0.2.53 → rxnn-0.2.55}/src/rxnn/transformers/moe.py +0 -0
  37. {rxnn-0.2.53 → rxnn-0.2.55}/src/rxnn/transformers/positional.py +0 -0
  38. {rxnn-0.2.53 → rxnn-0.2.55}/src/rxnn/transformers/sampler.py +0 -0
  39. {rxnn-0.2.53 → rxnn-0.2.55}/src/rxnn/utils.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.3
  Name: rxnn
- Version: 0.2.53
+ Version: 0.2.55
  Summary: RxNN: Reactive Neural Networks Platform
  License: Apache-2.0
  Keywords: deep-learning,ai,machine-learning
pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"

  [tool.poetry]
  name = "rxnn"
- version = "0.2.53"
+ version = "0.2.55"
  description = "RxNN: Reactive Neural Networks Platform"

  license = "Apache-2.0"
src/rxnn/experimental/attention.py
@@ -319,6 +319,8 @@ class SparseQueryAttention(MultiHeadAttention):

  # Others
  class FlexAttention(MultiHeadAttention):
+     """Flex attention layer, with RoPE support"""
+
      def __init__(
              self,
              embed_dim: int,
@@ -355,8 +357,7 @@ class FlexAttention(MultiHeadAttention):
          # Learnable global tokens
          self.global_tokens = nn.Parameter(torch.randn(1, num_global_tokens, embed_dim))

-
-     def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
+     def _split_heads(self, x: torch.Tensor, is_query: bool = False) -> torch.Tensor:
          b, t, d = x.size()
          return x.view(b, t, self.num_heads, self.head_dim).transpose(1, 2)

@@ -364,15 +365,19 @@ class FlexAttention(MultiHeadAttention):
          b, h, t, d = x.size()
          return self._transpose_output(x, b, t, h * d)

+     def _sdpa(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
+         return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=self.dropout.p if self.training else 0,
+                                               is_causal=self.is_causal)
+
      def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: torch.Tensor = None):
          b, t, d = query.size()

-         # Prepend global tokens to the input query
+         # Prepend global tokens
          global_tokens = self.global_tokens.expand(b, -1, -1)
          x = torch.cat([global_tokens, query], dim=1)

          # Project Q, K, V
-         q = self._split_heads(self.q_proj(x))
+         q = self._split_heads(self.q_proj(x), is_query=True)
          k = self._split_heads(self.k_proj(key))
          v = self._split_heads(self.v_proj(value))
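For orientation, the following standalone sketch (not part of the package or of the diff) traces the tensor shapes produced by the global-token prepend and head split shown in the hunk above; the sizes B, T, G, H and head_dim are illustrative assumptions only.

    import torch
    import torch.nn as nn

    # Illustrative sizes (assumptions, not package defaults)
    B, T, G, H, head_dim = 2, 10, 4, 8, 16
    embed_dim = H * head_dim

    global_tokens = nn.Parameter(torch.randn(1, G, embed_dim))  # learnable global tokens
    query = torch.randn(B, T, embed_dim)

    # Prepend the global tokens to the query sequence, as FlexAttention.forward does
    x = torch.cat([global_tokens.expand(B, -1, -1), query], dim=1)  # (B, G + T, embed_dim)

    # Split into attention heads, mirroring _split_heads
    q = x.view(B, G + T, H, head_dim).transpose(1, 2)  # (B, H, G + T, head_dim)
    print(q.shape)  # torch.Size([2, 8, 14, 16])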
 
@@ -380,42 +385,112 @@ class FlexAttention(MultiHeadAttention):
          if self.rope:
              q, k = self._apply_rope(q, k, separate=True)

-         # Split Q into global and local parts
-         global_q = q[:, :, :self.num_global_tokens]  # (B, H, G, head_dim)
-         local_q = q[:, :, self.num_global_tokens:]  # (B, H, L, head_dim)
+         # Split Q into global and local
+         global_q = q[:, :, :self.num_global_tokens]  # (B, H, G, D)
+         local_q = q[:, :, self.num_global_tokens:]  # (B, H, L, D)
+
+         # Global attention
+         global_attn = self._sdpa(global_q, k, v, mask=mask)
+
+         # Local attention with windowed slicing (no large masks)
          L = local_q.size(2)
          S = k.size(2)
+         windowed_attn = []

-         # Global attention: global_q attends to all K/V
-         global_attn = F.scaled_dot_product_attention(
-             global_q, k, v, attn_mask=mask if not self.is_causal else None,
-             dropout_p=self.dropout.p if self.training else 0, is_causal=self.is_causal
-         )
+         for i in range(0, L, self.window_size):
+             start = i
+             end = min(i + self.window_size, L)
+             window_q = local_q[:, :, start:end]  # (B, H, W, D)

-         # Local attention: local_q attends to windowed K/V
-         # Vectorized window mask
-         indices = torch.arange(S, device=local_q.device)
-         local_pos = torch.arange(L, device=local_q.device)
-         local_window = (local_pos // self.window_size).unsqueeze(-1)  # (L, 1)
-         key_window = (indices // self.window_size).expand(L, -1)  # (L, S)
-         window_mask = (local_window == key_window).to(device=local_q.device)
-
-         local_attn = F.scaled_dot_product_attention(
-             local_q, k, v, attn_mask=window_mask if not self.is_causal else None,
-             dropout_p=self.dropout.p if self.training else 0, is_causal=self.is_causal
-         )
+             # Use only relevant keys/values (same window)
+             k_window_start = max(0, start - self.window_size)
+             k_window_end = min(S, end + self.window_size)
+             window_k = k[:, :, k_window_start:k_window_end]
+             window_v = v[:, :, k_window_start:k_window_end]
+
+             window_attn = self._sdpa(window_q, window_k, window_v, mask=None)
+
+             windowed_attn.append(window_attn)

-         # Combine global and local attention outputs
-         attn = torch.cat([global_attn, local_attn], dim=2)  # (B, H, G+L, head_dim)
+         # Concat local attention
+         local_attn = torch.cat(windowed_attn, dim=2)

-         # Merge heads and project back
-         output = self._merge_heads(attn)  # (B, G+L, D)
-         output = self.out_proj(output)  # (B, G+L, D)
+         # Combine global and local
+         attn = torch.cat([global_attn, local_attn], dim=2)
+         output = self._merge_heads(attn)
+         output = self.out_proj(output)

-         # Return only the local tokens (original query tokens)
          return output[:, self.num_global_tokens:, :]


+ class FlexSparseQueryAttention(FlexAttention):
+     """Combined Flex and Sparse Query attention layer, with RoPE support"""
+     def __init__(
+             self,
+             embed_dim: int,
+             num_heads: int,
+             num_groups: int,
+             num_query_groups: int,
+             dropout: float = 0.0,
+             rope: RotaryPositionalEmbedding = None,
+             rope_only_for_query: bool = False,
+             rope_only_for_keys: bool = False,
+             use_relative_embeddings: bool = False,
+             max_seq_len: int = 1024,
+             use_flash_attention: bool = True,
+             is_causal: bool = False,
+             use_bias: bool = False,
+             num_global_tokens: int = 16,
+             window_size: int = 128,
+     ):
+         self.num_groups = num_groups
+         self.num_query_groups = num_query_groups
+         super(FlexSparseQueryAttention, self).__init__(
+             embed_dim,
+             num_heads,
+             dropout=dropout,
+             rope=rope,
+             rope_only_for_query=rope_only_for_query,
+             rope_only_for_keys=rope_only_for_keys,
+             use_relative_embeddings=use_relative_embeddings,
+             max_seq_len=max_seq_len,
+             use_flash_attention=use_flash_attention,
+             is_causal=is_causal,
+             use_bias=use_bias,
+             num_global_tokens=num_global_tokens,
+             window_size=window_size,
+         )
+         assert num_heads % num_groups == 0, "num_heads must be divisible by num_groups"
+
+     def _init_kv(self, embed_dim: int):
+         self.k_proj = nn.Linear(embed_dim, embed_dim // (self.num_heads // self.num_groups), bias=self.use_bias)
+         self.v_proj = nn.Linear(embed_dim, embed_dim // (self.num_heads // self.num_groups), bias=self.use_bias)
+
+     def _init_q(self, embed_dim: int):
+         self.q_proj = nn.Linear(embed_dim, embed_dim // (self.num_heads // self.num_query_groups), bias=self.use_bias)
+
+     def _init_out(self, embed_dim: int):
+         """Initialize output projection"""
+         self.out_proj = nn.Linear(embed_dim // (self.num_heads // self.num_query_groups), embed_dim)
+
+     def _split_heads(self, x: torch.Tensor, is_query: bool = False) -> torch.Tensor:
+         b, t, d = x.size()
+         return x.view(b, t, self.num_query_groups if is_query else self.num_groups, self.head_dim).transpose(1, 2)
+
+     def _transpose_output(self, attn_output: torch.Tensor, b: int, t: int, d: int):
+         """Transpose attention output back to (B, T, D) shape"""
+         return attn_output.transpose(1, 2).contiguous().view(b, t, d // (self.num_heads // self.num_query_groups))
+
+     def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
+         b, h, t, d = x.size()
+         return self._transpose_output(x, b, t, self.embed_dim)
+
+     def _sdpa(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
+         is_gqa = self.num_query_groups != self.num_groups
+         return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=self.dropout.p if self.training else 0,
+                                               is_causal=self.is_causal, enable_gqa=is_gqa)
+
+
  class InfiniteAttention(MultiHeadAttention):
      def __init__(
              self,
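To make the new layer's interface concrete, here is a minimal usage sketch based on the constructor signature added above. The configuration values are illustrative assumptions rather than recommended settings, RoPE and masking are omitted, the expected output shape follows from the forward pass shown in the diff, and the enable_gqa code path assumes a sufficiently recent PyTorch.

    import torch
    from rxnn.experimental.attention import FlexSparseQueryAttention

    # Illustrative configuration (assumed values, not defaults from the package)
    attn = FlexSparseQueryAttention(
        embed_dim=512,
        num_heads=16,
        num_groups=4,         # reduced key/value heads (GQA-style)
        num_query_groups=8,   # reduced query heads (SQA-style)
        num_global_tokens=16,
        window_size=128,
    )

    x = torch.randn(2, 256, 512)  # (batch, seq_len, embed_dim)
    out = attn(x, x, x)           # self-attention; learnable global tokens are prepended internally
    print(out.shape)              # expected: torch.Size([2, 256, 512])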