rxnn-0.2.55-py3-none-any.whl → rxnn-0.2.57-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rxnn/experimental/attention.py CHANGED
@@ -571,7 +571,7 @@ def init_experimental_attention(
         num_global_tokens: int = 16,
         window_size: int = 128,
 ) -> MultiHeadAttention:
-    assert attention_type in ['gma', 'dma', 'sqa', 'flex'], "Error, attention type should be one of: 'gma', 'dma', 'sqa', 'flex'"
+    assert attention_type in ['gma', 'dma', 'sqa', 'flex', 'flex-sqa'], "Error, attention type should be one of: 'gma', 'dma', 'sqa', 'flex', 'flex-sqa"
 
     if attention_type == "gma":
         return GroupedMoeAttention(
@@ -622,6 +622,23 @@ def init_experimental_attention(
             num_global_tokens=num_global_tokens,
             window_size=window_size,
         )
+    elif attention_type == "flex-sqa":
+        return FlexSparseQueryAttention(
+            embed_dim,
+            num_heads,
+            gqa_groups,
+            num_query_groups,
+            dropout=dropout,
+            rope=rope,
+            max_seq_len=max_seq_len,
+            rope_only_for_query=rope_only_for_query,
+            rope_only_for_keys=rope_only_for_keys,
+            use_flash_attention=use_flash_attention,
+            is_causal=is_causal,
+            use_bias=use_bias,
+            num_global_tokens=num_global_tokens,
+            window_size=window_size,
+        )
     else:
         return SparseQueryAttention(
             embed_dim,
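For context, the new 'flex-sqa' branch simply routes the factory to FlexSparseQueryAttention. A minimal usage sketch, assuming init_experimental_attention (in rxnn.experimental.attention) accepts these names as keyword arguments; the values are placeholders and any parameters not shown are assumed to have defaults:

from rxnn.experimental.attention import init_experimental_attention

# Illustrative call only: argument names follow the parameters forwarded in the diff above.
attention = init_experimental_attention(
    embed_dim=512,
    num_heads=16,
    gqa_groups=4,
    num_query_groups=8,
    attention_type='flex-sqa',   # new option in 0.2.57; 0.2.55 rejected it in the assert
    dropout=0.1,
    max_seq_len=1024,
    num_global_tokens=16,
    window_size=128,
)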
rxnn/experimental/models.py CHANGED
@@ -73,7 +73,7 @@ class ExperimentalAttentionTransformer(nn.Module, PyTorchModelHubMixin, pipeline
         assert ff_activation in ['relu', 'gelu',
                                  'swish', 'silu', 'linear',
                                  'sigmoid'], 'Feed-forward activation could be "relu", "gelu", "swish", "silu", "linear", "sigmoid".'
-        assert att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma', 'sqa', 'flex'], 'Self-attention type could be "mha", "gqa", "mqa", "gma", "dma", "sqa", "flex".'
+        assert att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma', 'sqa', 'flex', 'flex-sqa'], 'Self-attention type could be "mha", "gqa", "mqa", "gma", "dma", "sqa", "flex", "flex-sqa".'
 
         embedding = nn.Embedding(vocab_size, embed_dim)
         rope = RotaryPositionalEmbedding(embed_dim // att_heads, seq_len)
rxnn/training/dataset.py CHANGED
@@ -943,7 +943,7 @@ class MrlCurriculumDataset(Dataset):
         return self.get_tokenized_item(idx)
 
     def __len__(self) -> int:
-        return len(self.episodes)
+        return len(self.inputs if self.is_pre_tokenized else self.episodes)
 
     def get_subset(self, size: float, from_start: bool = False, **kwargs) -> "MRlCurriculumDataset":
         split_point = int(len(self.episodes) * ((1 - size) if not from_start else size))
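The __len__ fix matters once the curriculum dataset has been pre-tokenized: the usable items then live in self.inputs rather than self.episodes. A standalone sketch of that pattern, with only the field names taken from the diff and everything else illustrative:

from typing import Optional

# Sketch only: field names (episodes, inputs, is_pre_tokenized) follow the diff; the class is illustrative.
class LengthExample:
    def __init__(self, episodes: list, inputs: Optional[list] = None):
        self.episodes = episodes                              # raw interaction episodes
        self.inputs = inputs if inputs is not None else []    # tokenized items after pre-tokenization
        self.is_pre_tokenized = inputs is not None

    def __len__(self) -> int:
        # After pre-tokenization, length must reflect the tokenized inputs, not the raw episodes.
        return len(self.inputs if self.is_pre_tokenized else self.episodes)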
rxnn/training/mrl.py CHANGED
@@ -36,6 +36,8 @@ class MrlConfig(TypedDict):
     freeze_embeddings: Optional[bool]
     embedding_lr: Optional[float]
     use_memory_warmup: Optional[bool]
+    debug_mode: Optional[bool]
+    debug_interval: Optional[int]
 
 
 class MrlStrategy(Enum):
@@ -109,7 +111,6 @@ class MRLTrainer:
             use_ddp: bool = False,
             use_amp: bool = False,
             dtype: torch.dtype = torch.float32,
-            debug_mode: bool = False,
     ):
         """
         Trainer for Memory Reinforcement Learning (MRL) algorithm for reactive models and Attention-Based Memory System.
@@ -140,7 +141,8 @@ class MRLTrainer:
         self.shared_freeze_embeddings = config.get('freeze_embeddings', False)
         self.freeze_embeddings = self.shared_freeze_embeddings
         self.use_memory_warmup = config.get('use_memory_warmup', False)
-        self.debug_mode = debug_mode
+        self.debug_mode = config.get('debug_mode', False)
+        self.debug_interval = config.get('debug_interval', 10)
         # Internal update epochs config
         self.shared_update_epochs = config.get('update_epochs', 10)
         self.update_epochs = self.shared_update_epochs
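With this change, debug logging is configured through MrlConfig rather than a constructor flag. A partial config sketch; only the debug-related keys are shown, other MrlConfig fields are omitted, and defaults follow the config.get calls above:

# Partial sketch: unspecified keys fall back to the defaults read via config.get above.
mrl_config = {
    'use_memory_warmup': False,
    'debug_mode': True,        # replaces the removed MRLTrainer(debug_mode=...) argument
    'debug_interval': 10,      # gradients are logged only every 10th training step
}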
@@ -606,7 +608,7 @@ class MRLTrainer:
                 self.scaler.unscale_(self.optimizer)
                 torch.nn.utils.clip_grad_norm_(self.actor.unique_parameters(), max_norm=1.0,
                                                error_if_nonfinite=False)
-                if self.debug_mode:
+                if self.debug_mode and self.epoch_step['train'] % self.debug_interval == 0:
                     self._log_gradients()
                 # 4.5 Run scaled optimization step
                 self.scaler.step(self.optimizer)
@@ -625,7 +627,7 @@ class MRLTrainer:
                 # 4.4 Clip gradient norms
                 torch.nn.utils.clip_grad_norm_(self.actor.unique_parameters(), max_norm=1.0,
                                                error_if_nonfinite=False)
-                if self.debug_mode:
+                if self.debug_mode and self.epoch_step['train'] % self.debug_interval == 0:
                     self._log_gradients()
                 # 4.5 Run scaled optimization step
                 self.optimizer.step()
rxnn/training/rl.py CHANGED
@@ -33,10 +33,12 @@ class PPOConfig(TypedDict):
     use_distributed_advantage_norm: Optional[bool]
     clip_critic_values: Optional[bool]
     critic_value_clip: Optional[float]
+    debug_mode: Optional[bool]
+    debug_interval: Optional[int]
 
 
 class PPOAlgorithm(RlAlgorithm):
-    def __init__(self, config: Optional[PPOConfig] = None, debug_mode: bool = False):
+    def __init__(self, config: Optional[PPOConfig] = None):
         super(PPOAlgorithm, self).__init__()
 
         if config is None:
@@ -50,7 +52,9 @@ class PPOAlgorithm(RlAlgorithm):
         self.use_distributed_advantage_norm = config.get('use_distributed_advantage_norm', False)
         self.clip_critic_values = config.get('clip_critic_values', True)
         self.critic_value_clip = config.get('critic_value_clip', 20.0)
-        self.debug_mode = debug_mode
+        self.debug_mode = config.get('debug_mode', False)
+        self.debug_interval = config.get('debug_interval', 10)
+        self.debug_step = 0
 
     def critic_loss(self, values: torch.Tensor, ref_values: torch.Tensor) -> torch.Tensor:
         # Critic loss with clipped values
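PPOAlgorithm now reads its debug settings from PPOConfig as well, so the old debug_mode constructor argument is gone. A hedged example, with the remaining PPOConfig keys left to the defaults shown above:

# Sketch only: unspecified PPOConfig keys fall back to the defaults read via config.get above.
algorithm = PPOAlgorithm(config={
    'clip_critic_values': True,
    'critic_value_clip': 20.0,
    'debug_mode': True,        # 0.2.55 used PPOAlgorithm(config, debug_mode=True) instead
    'debug_interval': 10,      # stats are printed roughly every 10th policy-loss call
})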
@@ -98,12 +102,16 @@ class PPOAlgorithm(RlAlgorithm):
         advantages = advantages.unsqueeze(-1)
 
         if self.debug_mode:
-            print(
-                f"Logits stats: min={new_logits.min().item():.4f}, max={new_logits.max().item():.4f}, mean={new_logits.mean().item():.4f}")
-            print(
-                f"Ratio stats: min={ratio.min().item():.4f}, max={ratio.max().item():.4f}, mean={ratio.mean().item():.4f}")
-            print(
-                f"Advantage stats: min={advantages.min().item():.4f}, max={advantages.max().item():.4f}, mean={advantages.mean().item():.4f}")
+            if self.debug_step != 0 and self.debug_step % self.debug_interval == 0:
+                self.debug_step = 0
+                print(
+                    f"Logits stats: min={new_logits.min().item():.4f}, max={new_logits.max().item():.4f}, mean={new_logits.mean().item():.4f}")
+                print(
+                    f"Ratio stats: min={ratio.min().item():.4f}, max={ratio.max().item():.4f}, mean={ratio.mean().item():.4f}")
+                print(
+                    f"Advantage stats: min={advantages.min().item():.4f}, max={advantages.max().item():.4f}, mean={advantages.mean().item():.4f}")
+            else:
+                self.debug_step += 1
 
         # c) Clipped surrogate loss
         surr1 = ratio * advantages
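The debug_step counter throttles the stats printing so it fires roughly once per debug_interval calls instead of on every policy update. The same pattern in isolation (the names here are illustrative, not part of rxnn):

# Illustrative throttling pattern, mirroring the debug_step logic above.
class ThrottledPrinter:
    def __init__(self, interval: int = 10):
        self.interval = interval
        self.step = 0

    def maybe_print(self, message: str) -> None:
        # Print only once the counter reaches the interval, then restart counting.
        if self.step != 0 and self.step % self.interval == 0:
            self.step = 0
            print(message)
        else:
            self.step += 1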
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.2.55
+Version: 0.2.57
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning
@@ -1,8 +1,8 @@
 rxnn/.DS_Store,sha256=BxZLo9tFs48JMq6jhumiCnCPLTeCwl619CFSg4ClRAY,6148
 rxnn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/experimental/attention.py,sha256=JMs6Wr2rRe5J5m0ULhudmhBrzPicGuOOyg5hO8aLFiQ,27846
-rxnn/experimental/models.py,sha256=HPOIRpnX_oiI10wsVC4J6rzo3T6dj10aNWGYpa9S1UU,5115
+rxnn/experimental/attention.py,sha256=jlNS82INjycNEfmk3HtkIacUvT_ELhaCO2g-kZTvhXg,28455
+rxnn/experimental/models.py,sha256=oJWd56LUsLc9S8eCZw-ShvuWjoQxj4C9GitbohlQ0ok,5139
 rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/memory/attention.py,sha256=wnYjd3UnmzAA79-7QxpMoEk3O1qRy8LmW1JcE_Fotck,3094
@@ -14,12 +14,12 @@ rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/training/base.py,sha256=CqaArEZYOdH64nmKfx28U3GI46TzO4oNkjf_hrF23Cw,11835
 rxnn/training/bml.py,sha256=hw6gLpLkGvqLzxIvBg4MvCc5r8cHpEm2RDyh7nH6CtE,16914
 rxnn/training/callbacks.py,sha256=rS8leuVFPVVfE5Zc8DMkUZhRIPN-vpPbUjowXE5TSBw,36779
-rxnn/training/dataset.py,sha256=7hTilFWPpqUEc6zNcMqBPjxFKxCfvTKKF3E8tVlwccQ,51250
+rxnn/training/dataset.py,sha256=tbtOSYldHnQB6SWgee_yUj9zTbgoEoLFNa6wvUS6Apg,51292
 rxnn/training/ddp.py,sha256=VsNBjn3cY-uUj8hbsW7oKvb0_ZKnXnJ2KgObm-Mr9i4,836
 rxnn/training/models.py,sha256=CS6mjD338knXmCbMZ3bCpOlA-DR3kmQUOSj5u5F6jII,9002
-rxnn/training/mrl.py,sha256=185rZsaFVaAt4mYksuflzKPiDuEaHpjsc3vPzxt9ax0,61862
+rxnn/training/mrl.py,sha256=H2JcamaJv19vKqOgdoyhcCBwu1lb_aKfCmR_MuuvmS0,62085
 rxnn/training/reward.py,sha256=uiSsBXmjMw2yv-1Bssy3RTlpU6zP8ape3490Sl-aT0M,16144
-rxnn/training/rl.py,sha256=iBEPC_gfydXtWkVORO3REMWvOtx60-0xB7MFzfghUK8,6825
+rxnn/training/rl.py,sha256=hWtExxY-_pAmTOGYxyCNounUbaGWvLDVltC4sRC7MN4,7175
 rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
 rxnn/training/tokenizer.py,sha256=umaLByMBx_NMrQElA45HLm9gkuzyKWDTFaKVd-CjXl0,8344
 rxnn/training/utils.py,sha256=C0OS2RAGQ3L7D_G3CWupu_BpAFhkovMByBKm355Ibfc,6087
@@ -33,7 +33,7 @@ rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
 rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
 rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
-rxnn-0.2.55.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
-rxnn-0.2.55.dist-info/METADATA,sha256=4XCpsJFv9dpetex6uDRLrzKlMYlQZFLK2H2j---WZmA,25997
-rxnn-0.2.55.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
-rxnn-0.2.55.dist-info/RECORD,,
+rxnn-0.2.57.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.2.57.dist-info/METADATA,sha256=xZ60cC1PUzse2BBDKPSOrFc_oVVLUdl6qKmZWUMaUa4,25997
+rxnn-0.2.57.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+rxnn-0.2.57.dist-info/RECORD,,