rxnn 0.2.52__tar.gz → 0.2.53__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. {rxnn-0.2.52 → rxnn-0.2.53}/PKG-INFO +1 -1
  2. {rxnn-0.2.52 → rxnn-0.2.53}/pyproject.toml +1 -1
  3. {rxnn-0.2.52 → rxnn-0.2.53}/src/rxnn/experimental/attention.py +103 -60
  4. {rxnn-0.2.52 → rxnn-0.2.53}/src/rxnn/experimental/models.py +7 -3
  5. {rxnn-0.2.52 → rxnn-0.2.53}/src/rxnn/transformers/attention.py +3 -1
  6. {rxnn-0.2.52 → rxnn-0.2.53}/LICENSE +0 -0
  7. {rxnn-0.2.52 → rxnn-0.2.53}/README.md +0 -0
  8. {rxnn-0.2.52 → rxnn-0.2.53}/src/rxnn/.DS_Store +0 -0
  9. {rxnn-0.2.52 → rxnn-0.2.53}/src/rxnn/__init__.py +0 -0
  10. {rxnn-0.2.52 → rxnn-0.2.53}/src/rxnn/experimental/__init__.py +0 -0
  11. {rxnn-0.2.52 → rxnn-0.2.53}/src/rxnn/experimental/moe.py +0 -0
  12. {rxnn-0.2.52 → rxnn-0.2.53}/src/rxnn/memory/__init__.py +0 -0
  13. {rxnn-0.2.52 → rxnn-0.2.53}/src/rxnn/memory/attention.py +0 -0
  14. {rxnn-0.2.52 → rxnn-0.2.53}/src/rxnn/memory/norm.py +0 -0
  15. {rxnn-0.2.52 → rxnn-0.2.53}/src/rxnn/memory/stm.py +0 -0
  16. {rxnn-0.2.52 → rxnn-0.2.53}/src/rxnn/rxt/__init__.py +0 -0
  17. {rxnn-0.2.52 → rxnn-0.2.53}/src/rxnn/rxt/models.py +0 -0
  18. {rxnn-0.2.52 → rxnn-0.2.53}/src/rxnn/training/__init__.py +0 -0
  19. {rxnn-0.2.52 → rxnn-0.2.53}/src/rxnn/training/base.py +0 -0
  20. {rxnn-0.2.52 → rxnn-0.2.53}/src/rxnn/training/bml.py +0 -0
  21. {rxnn-0.2.52 → rxnn-0.2.53}/src/rxnn/training/callbacks.py +0 -0
  22. {rxnn-0.2.52 → rxnn-0.2.53}/src/rxnn/training/dataset.py +0 -0
  23. {rxnn-0.2.52 → rxnn-0.2.53}/src/rxnn/training/ddp.py +0 -0
  24. {rxnn-0.2.52 → rxnn-0.2.53}/src/rxnn/training/models.py +0 -0
  25. {rxnn-0.2.52 → rxnn-0.2.53}/src/rxnn/training/mrl.py +0 -0
  26. {rxnn-0.2.52 → rxnn-0.2.53}/src/rxnn/training/reward.py +0 -0
  27. {rxnn-0.2.52 → rxnn-0.2.53}/src/rxnn/training/rl.py +0 -0
  28. {rxnn-0.2.52 → rxnn-0.2.53}/src/rxnn/training/scheduler.py +0 -0
  29. {rxnn-0.2.52 → rxnn-0.2.53}/src/rxnn/training/tokenizer.py +0 -0
  30. {rxnn-0.2.52 → rxnn-0.2.53}/src/rxnn/training/utils.py +0 -0
  31. {rxnn-0.2.52 → rxnn-0.2.53}/src/rxnn/transformers/__init__.py +0 -0
  32. {rxnn-0.2.52 → rxnn-0.2.53}/src/rxnn/transformers/ff.py +0 -0
  33. {rxnn-0.2.52 → rxnn-0.2.53}/src/rxnn/transformers/layers.py +0 -0
  34. {rxnn-0.2.52 → rxnn-0.2.53}/src/rxnn/transformers/mask.py +0 -0
  35. {rxnn-0.2.52 → rxnn-0.2.53}/src/rxnn/transformers/models.py +0 -0
  36. {rxnn-0.2.52 → rxnn-0.2.53}/src/rxnn/transformers/moe.py +0 -0
  37. {rxnn-0.2.52 → rxnn-0.2.53}/src/rxnn/transformers/positional.py +0 -0
  38. {rxnn-0.2.52 → rxnn-0.2.53}/src/rxnn/transformers/sampler.py +0 -0
  39. {rxnn-0.2.52 → rxnn-0.2.53}/src/rxnn/utils.py +0 -0

PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.2.52
+Version: 0.2.53
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning

pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
 
 [tool.poetry]
 name = "rxnn"
-version = "0.2.52"
+version = "0.2.53"
 description = "RxNN: Reactive Neural Networks Platform"
 
 license = "Apache-2.0"

src/rxnn/experimental/attention.py
@@ -318,76 +318,102 @@ class SparseQueryAttention(MultiHeadAttention):
 
 
 # Others
-
 class FlexAttention(MultiHeadAttention):
     def __init__(
-            self,
-            embed_dim: int,
-            num_heads: int,
-            num_global_tokens: int = 16,
-            window_size: int = 128,
-            **kwargs
+            self,
+            embed_dim: int,
+            num_heads: int,
+            dropout: float = 0.0,
+            rope: RotaryPositionalEmbedding = None,
+            rope_only_for_query: bool = False,
+            rope_only_for_keys: bool = False,
+            use_relative_embeddings: bool = False,
+            max_seq_len: int = 1024,
+            use_flash_attention: bool = True,
+            is_causal: bool = False,
+            use_bias: bool = False,
+            num_global_tokens: int = 16,
+            window_size: int = 128,
     ):
-        super().__init__(embed_dim, num_heads, **kwargs)
+        super(FlexAttention, self).__init__(
+            embed_dim,
+            num_heads,
+            dropout=dropout,
+            rope=rope,
+            rope_only_for_query=rope_only_for_query,
+            rope_only_for_keys=rope_only_for_keys,
+            use_relative_embeddings=use_relative_embeddings,
+            max_seq_len=max_seq_len,
+            use_flash_attention=use_flash_attention,
+            is_causal=is_causal,
+            use_bias=use_bias,
+        )
+        self.head_dim = embed_dim // num_heads
         self.num_global_tokens = num_global_tokens
         self.window_size = window_size
-        self.global_tokens = nn.Parameter(torch.zeros(1, num_global_tokens, embed_dim))
 
-    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask=None):
+        # Learnable global tokens
+        self.global_tokens = nn.Parameter(torch.randn(1, num_global_tokens, embed_dim))
+
+    def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
+        b, t, d = x.size()
+        return x.view(b, t, self.num_heads, self.head_dim).transpose(1, 2)
+
+    def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
+        b, h, t, d = x.size()
+        return self._transpose_output(x, b, t, h * d)
+
+    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: torch.Tensor = None):
         b, t, d = query.size()
-        head_dim = d // self.num_heads
 
-        # Split into global and local
-        x = torch.cat([self.global_tokens.expand(b, -1, -1), query], dim=1)
-        seq_len = x.size(1)
-        num_windows = (seq_len - self.num_global_tokens + self.window_size - 1) // self.window_size
+        # Prepend global tokens to the input query
+        global_tokens = self.global_tokens.expand(b, -1, -1)
+        x = torch.cat([global_tokens, query], dim=1)
 
         # Project Q, K, V
-        q, k, v = self._forward_qkv(x, key, value, b, seq_len, d)
-
-        # Process Global-to-Global Attention
-        global_q = q[:, :, :self.num_global_tokens]  # [B, H, G, head_dim]
-        global_k = k[:, :, :self.num_global_tokens]
-        global_v = v[:, :, :self.num_global_tokens]
-        global_attn = self._calculate_attn_weights(global_q, global_k, d) @ global_v
-
-        # Process Global-to-Local Attention
-        local_k = k[:, :, self.num_global_tokens:]  # [B, H, (num_windows * window_size), head_dim]
-        local_v = v[:, :, self.num_global_tokens:]
-        # Apply RoPE to local_k if needed
+        q = self._split_heads(self.q_proj(x))
+        k = self._split_heads(self.k_proj(key))
+        v = self._split_heads(self.v_proj(value))
+
+        # Apply RoPE
         if self.rope:
-            # Compute frequencies for entire local sequence
-            local_k = self.rope.forward_one(local_k)
-
-        global_local_attn = self._calculate_attn_weights(global_q, local_k, d) @ local_v
-
-        # Process Local-to-Local Attention (per window)
-        local_q = q[:, :, self.num_global_tokens:]  # [B, H, (num_windows * window_size), head_dim]
-        local_q = local_q.view(b, self.num_heads, num_windows, self.window_size, head_dim)
-        local_k = local_k.view(b, self.num_heads, num_windows, self.window_size, head_dim)
-        local_v = local_v.view(b, self.num_heads, num_windows, self.window_size, head_dim)
-
-        local_attn = []
-        for i in range(num_windows):
-            window_q = local_q[:, :, i]  # [B, H, window_size, head_dim]
-            window_k = local_k[:, :, i]
-            window_v = local_v[:, :, i]
-
-            # Apply RoPE to window_q and window_k
-            if self.rope:
-                # Compute frequencies for this window
-                window_q, window_k = self.rope(window_q, window_k)
-
-            # Calculate attention for this window
-            attn = self._calculate_attn_weights(window_q, window_k, d)
-            attn_i = torch.einsum('bhij, bhjd -> bhid', attn, window_v)
-            local_attn.append(attn_i)
-        local_attn = torch.cat(local_attn, dim=2).view(b, self.num_heads, -1, head_dim)
-
-        # Combine all attention outputs
-        combined_attn = torch.cat([global_attn, global_local_attn, local_attn], dim=2)
-        output = self._calculate_output(combined_attn, v, b, t, d)
-        return self.out_proj(output)
+            q, k = self._apply_rope(q, k, separate=True)
+
+        # Split Q into global and local parts
+        global_q = q[:, :, :self.num_global_tokens]  # (B, H, G, head_dim)
+        local_q = q[:, :, self.num_global_tokens:]  # (B, H, L, head_dim)
+        L = local_q.size(2)
+        S = k.size(2)
+
+        # Global attention: global_q attends to all K/V
+        global_attn = F.scaled_dot_product_attention(
+            global_q, k, v, attn_mask=mask if not self.is_causal else None,
+            dropout_p=self.dropout.p if self.training else 0, is_causal=self.is_causal
+        )
+
+        # Local attention: local_q attends to windowed K/V
+        # Vectorized window mask
+        indices = torch.arange(S, device=local_q.device)
+        local_pos = torch.arange(L, device=local_q.device)
+        local_window = (local_pos // self.window_size).unsqueeze(-1)  # (L, 1)
+        key_window = (indices // self.window_size).expand(L, -1)  # (L, S)
+        window_mask = (local_window == key_window).to(device=local_q.device)
+
+        local_attn = F.scaled_dot_product_attention(
+            local_q, k, v, attn_mask=window_mask if not self.is_causal else None,
+            dropout_p=self.dropout.p if self.training else 0, is_causal=self.is_causal
+        )
+
+        # Combine global and local attention outputs
+        attn = torch.cat([global_attn, local_attn], dim=2)  # (B, H, G+L, head_dim)
+
+        # Merge heads and project back
+        output = self._merge_heads(attn)  # (B, G+L, D)
+        output = self.out_proj(output)  # (B, G+L, D)
+
+        # Return only the local tokens (original query tokens)
+        return output[:, self.num_global_tokens:, :]
 
 
 class InfiniteAttention(MultiHeadAttention):
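
The rewritten forward pass replaces the old per-window Python loop with two calls to F.scaled_dot_product_attention and a vectorized block-diagonal window mask built by integer-dividing positions by window_size. A minimal standalone sketch of that mask construction (plain PyTorch, not taken from the package; sizes are illustrative, variable names mirror the diff):

import torch

L, S, window_size = 8, 8, 4  # local query length, key length, window size

local_pos = torch.arange(L)
key_pos = torch.arange(S)                                  # called `indices` in the diff
local_window = (local_pos // window_size).unsqueeze(-1)    # (L, 1): window index of each query position
key_window = (key_pos // window_size).expand(L, -1)        # (L, S): window index of each key position
window_mask = local_window == key_window                   # (L, S) boolean, True = may attend

print(window_mask.int())
# Two 4x4 blocks of ones on the diagonal: queries 0-3 see only keys 0-3,
# queries 4-7 see only keys 4-7.

Passed as attn_mask, this restricts each local query to keys in its own window, while the prepended global tokens attend to the full sequence through the separate global call.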
@@ -467,8 +493,10 @@ def init_experimental_attention(
         num_experts: int = None,
         num_query_experts: int = None,
         num_query_groups: int = None,
+        num_global_tokens: int = 16,
+        window_size: int = 128,
 ) -> MultiHeadAttention:
-    assert attention_type in ['gma', 'dma', 'sqa'], "Error, attention type should be one of: 'gma', 'dma', 'sqa'"
+    assert attention_type in ['gma', 'dma', 'sqa', 'flex'], "Error, attention type should be one of: 'gma', 'dma', 'sqa', 'flex'"
 
     if attention_type == "gma":
         return GroupedMoeAttention(
@@ -504,6 +532,21 @@ def init_experimental_attention(
             num_query_experts=num_query_experts,
             num_query_groups=num_query_groups,
         )
+    elif attention_type == "flex":
+        return FlexAttention(
+            embed_dim,
+            num_heads,
+            dropout=dropout,
+            rope=rope,
+            max_seq_len=max_seq_len,
+            rope_only_for_query=rope_only_for_query,
+            rope_only_for_keys=rope_only_for_keys,
+            use_flash_attention=use_flash_attention,
+            is_causal=is_causal,
+            use_bias=use_bias,
+            num_global_tokens=num_global_tokens,
+            window_size=window_size,
+        )
     else:
         return SparseQueryAttention(
             embed_dim,
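
A hedged usage sketch of the extended factory. The import paths are assumptions inferred from the file list (src/rxnn/experimental/attention.py and src/rxnn/transformers/positional.py), and the hyperparameter values are illustrative; only parameters visible in this diff are used.

import torch
from rxnn.experimental.attention import init_experimental_attention
from rxnn.transformers.positional import RotaryPositionalEmbedding

embed_dim, num_heads, seq_len = 256, 8, 1024
rope = RotaryPositionalEmbedding(embed_dim // num_heads, seq_len)

# 'flex' is the newly accepted attention_type; the fourth positional argument is the
# group count passed as att_groups in models.py (likely unused by the 'flex' variant).
attention = init_experimental_attention(
    embed_dim, num_heads, 'flex', 1,
    rope=rope, dropout=0.1, max_seq_len=seq_len, is_causal=False,
    num_global_tokens=16, window_size=128,
)

x = torch.randn(2, seq_len, embed_dim)
out = attention(x, x, x)  # self-attention; output should match the query shape: (2, 1024, 256)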

src/rxnn/experimental/models.py
@@ -32,6 +32,8 @@ class ExperimentalAttentionTransformerConfig(TypedDict):
     att_num_experts: int
     att_num_query_experts: int
     att_num_query_groups: int
+    att_num_global_tokens: int
+    att_window_size: int
 
 
 class ExperimentalAttentionTransformer(nn.Module, PyTorchModelHubMixin, pipeline_tag="text-generation", license="apache-2.0"):
@@ -63,13 +65,15 @@ class ExperimentalAttentionTransformer(nn.Module, PyTorchModelHubMixin, pipeline
             att_num_experts: int = None,
             att_num_query_experts: int = None,
             att_num_query_groups: int = None,
+            att_num_global_tokens: int = 16,
+            att_window_size: int = 128,
             **kwargs
     ):
         super(ExperimentalAttentionTransformer, self).__init__(**kwargs)
         assert ff_activation in ['relu', 'gelu',
                                  'swish', 'silu', 'linear',
                                  'sigmoid'], 'Feed-forward activation could be "relu", "gelu", "swish", "silu", "linear", "sigmoid".'
-        assert att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma', 'sqa'], 'Self-attention type could be "mha", "gqa", "mqa", "gma", "dma", "sqa".'
+        assert att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma', 'sqa', 'flex'], 'Self-attention type could be "mha", "gqa", "mqa", "gma", "dma", "sqa", "flex".'
 
         embedding = nn.Embedding(vocab_size, embed_dim)
         rope = RotaryPositionalEmbedding(embed_dim // att_heads, seq_len)
@@ -84,8 +88,8 @@ class ExperimentalAttentionTransformer(nn.Module, PyTorchModelHubMixin, pipeline
             att_init = lambda: init_experimental_attention(embed_dim, att_heads, att_type, att_groups, rope=rope,
                                                            use_flash_attention=use_flash_attention, dropout=att_dropout,
                                                            max_seq_len=seq_len, is_causal=True, num_experts=att_num_experts,
-                                                           num_query_experts=att_num_query_experts,
-                                                           num_query_groups=att_num_query_groups)
+                                                           num_query_experts=att_num_query_experts, num_query_groups=att_num_query_groups,
+                                                           num_global_tokens=att_num_global_tokens, window_size=att_window_size)
 
         use_moe_att = att_type in ['gma', 'dma']
 
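
The two new config fields map one-to-one onto the factory arguments above. A hedged sketch of the slice of an ExperimentalAttentionTransformer configuration touched by this diff (the remaining required fields are omitted; values are the constructor defaults):

flex_attention_config = {
    'att_type': 'flex',            # now permitted by the extended assert
    'att_num_global_tokens': 16,   # learnable global tokens prepended to the query
    'att_window_size': 128,        # size of each local attention window
}
# Merged into a full model config, e.g.:
# model = ExperimentalAttentionTransformer(**{**base_config, **flex_attention_config})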

src/rxnn/transformers/attention.py
@@ -69,12 +69,14 @@ class MultiHeadAttention(nn.Module):
         v = self.v_proj(value).view(b, -1, self.num_heads, d // self.num_heads).transpose(1, 2)
         return q, k, v
 
-    def _apply_rope(self, q: torch.Tensor, k: torch.Tensor):
+    def _apply_rope(self, q: torch.Tensor, k: torch.Tensor, separate: bool = False):
         if self.rope is not None:
             if self.rope_only_for_query:
                 q = self.rope.forward_one(q)
             elif self.rope_only_for_keys:
                 k = self.rope.forward_one(k)
+            elif separate:
+                q, k = self.rope.forward_one(q), self.rope.forward_one(k)
             else:
                 q, k = self.rope(q, k)
         return q, k
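
The new separate flag exists because FlexAttention's query (global tokens plus the sequence) and its key tensor have different lengths, so RoPE is applied to each tensor on its own through forward_one instead of the joint rope(q, k) path. A hedged shape-level illustration (the import path is an assumption; tensor sizes are illustrative):

import torch
from rxnn.transformers.positional import RotaryPositionalEmbedding

head_dim, max_seq_len = 32, 1024
rope = RotaryPositionalEmbedding(head_dim, max_seq_len)

q = torch.randn(2, 8, 16 + 512, head_dim)  # query with 16 prepended global tokens: (B, H, G + L, head_dim)
k = torch.randn(2, 8, 512, head_dim)       # keys without global tokens:            (B, H, S, head_dim)

# What _apply_rope(q, k, separate=True) does: rotate each tensor independently,
# so the length mismatch between q and k is not a problem.
q_rot, k_rot = rope.forward_one(q), rope.forward_one(k)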