rxnn 0.2.67__py3-none-any.whl → 0.2.69__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rxnn/training/mrl.py CHANGED
@@ -41,6 +41,7 @@ class MrlConfig(TypedDict):
     use_memory_warmup: Optional[bool]
     debug_mode: Optional[bool]
     debug_interval: Optional[int]
+    clamp_logits: Optional[bool]


 class MrlStrategy(Enum):
@@ -152,6 +153,7 @@ class MRLTrainer:
         self.use_memory_warmup = config.get('use_memory_warmup', False)
         self.debug_mode = config.get('debug_mode', False)
         self.debug_interval = config.get('debug_interval', 10)
+        self.clamp_logits = config.get('clamp_logits', False)
         # Internal update epochs config
         self.shared_update_epochs = config.get('update_epochs', 10)
         self.update_epochs = self.shared_update_epochs
@@ -594,7 +596,9 @@ class MRLTrainer:
         else:
             return main_loss

-    def _log_gradients(self):
+    def _log_gradients(self, logits: torch.Tensor):
+        print(
+            f"Returned logits stats: min={logits.min().item():.4f}, max={logits.max().item():.4f}")
         encoder_total, encoder_mean = get_gradient_norms(self.actor.encoder)
         decoder_total, decoder_mean = get_gradient_norms(self.actor.decoder)
         mem_att_total, mem_att_mean = get_gradient_norms(self.actor.memory_attention)
@@ -633,6 +637,8 @@ class MRLTrainer:
                                 pad_token_id=self.pad_token_id)
             logits = self.actor(inputs['input_ids'], attention_mask=inputs['attention_mask'],
                                 action=MrlActorAction.DECODE)
+            if self.clamp_logits:
+                logits = logits.clamp(min=-20.0, max=20.0)
             # 4.2 Calculate policy loss with selected algorithm
             policy_loss = self.rl_algorithm.policy_loss(next_query, action, logits, old_log_probs,
                                                         advantages)
@@ -645,7 +651,7 @@ class MRLTrainer:
             torch.nn.utils.clip_grad_norm_(self.actor.unique_parameters(), max_norm=1.0,
                                            error_if_nonfinite=False)
             if self.debug_mode and self.epoch_step['train'] % self.debug_interval == 0:
-                self._log_gradients()
+                self._log_gradients(logits)
             # 4.5 Run scaled optimization step
             self.scaler.step(self.optimizer)
             self.scaler.update()
@@ -655,6 +661,8 @@ class MRLTrainer:
                                 pad_token_id=self.pad_token_id)
             logits = self.actor(inputs['input_ids'], attention_mask=inputs['attention_mask'],
                                 action=MrlActorAction.DECODE)
+            if self.clamp_logits:
+                logits = logits.clamp(min=-20.0, max=20.0)
             # 4.2 Calculate policy loss with selected algorithm
             policy_loss = self.rl_algorithm.policy_loss(next_query, action, logits, old_log_probs, advantages)
             policy_loss = self._moe_aux_loss(policy_loss)
@@ -664,7 +672,7 @@ class MRLTrainer:
             torch.nn.utils.clip_grad_norm_(self.actor.unique_parameters(), max_norm=1.0,
                                            error_if_nonfinite=False)
             if self.debug_mode and self.epoch_step['train'] % self.debug_interval == 0:
-                self._log_gradients()
+                self._log_gradients(logits)
             # 4.5 Run scaled optimization step
             self.optimizer.step()
             # 5. Get float loss value for callbacks/writer
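Taken together, the `mrl.py` changes add an optional, config-driven safeguard: when `clamp_logits` is set in `MrlConfig`, the trainer clamps the decoder logits to `[-20.0, 20.0]` before computing the policy loss, and the extended `_log_gradients` debug output also prints the logits range. A minimal, illustrative sketch of turning it on (only the fields relevant to this diff are shown; a real `MrlConfig` would include the rest of its training fields as well):

```python
import torch
from rxnn.training.mrl import MrlConfig

# Only the fields touched by this diff are shown; MrlConfig is a TypedDict,
# so the remaining training fields would normally be filled in too.
mrl_config = MrlConfig(
    debug_mode=True,      # print gradient norms and logits stats...
    debug_interval=10,    # ...every 10 training steps
    clamp_logits=True,    # clamp decoder logits before the policy loss (off by default)
)

# What the trainer does internally when clamp_logits is enabled (see the hunks above):
logits = torch.randn(2, 8, 32000) * 50.0    # stand-in for raw actor decoder output
logits = logits.clamp(min=-20.0, max=20.0)  # bounds extreme logit values
```

Clamping like this bounds the values that reach the log-softmax in the policy loss, presumably to guard against inf/nan probability ratios under bf16 autocast; the diff leaves the option disabled by default.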
rxnn-0.2.67.dist-info/METADATA → rxnn-0.2.69.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.2.67
+Version: 0.2.69
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning
@@ -57,10 +57,10 @@ sequence, it has to process it and save it in memory, but it could be done in ba
 
 ## Release plan
 We are working on three new reactive architectures that progressively advance from language models to awareness models:
-- Reactive Transformer: Reactive Language Model (RLM) with Short-Term Memory
-- Preactor: extending Reactive Transformer with additional Long-Term Memory, providing theoretically infinite context (only
+- **Reactive Transformer**: Reactive Language Model (RLM) with Short-Term Memory. [Research docs](https://github.com/RxAI-dev/RxNN/blob/main/docs/research/ReactiveTransformer/reactive-transformer.md)
+- **Preactor**: extends Reactive Transformer with an additional Long-Term Memory, providing theoretically infinite context (only
   single message length is limited) and the ability to learn from interactions (Live Learning)
-- Reactor: AGI awareness model & Strong Reactive Neural Network, that's working in infinite reasoning loop and doesn't require explicit human commands
+- **Reactor**: AGI awareness model & Strong Reactive Neural Network that works in an infinite reasoning loop and doesn't require explicit human commands
 
 Each new architecture is based on the previous one and adds new features/abilities. They will be progressively
 released with the next versions of the **RxNN** framework:
@@ -207,11 +207,57 @@ class YourReactiveTransformerDecoder(nn.Module, PyTorchModelHubMixin):
         return self.model(x, attention_mask=attention_mask)
 ```
 
+#### RxT-Alpha
+`RxTAlphaEncoder` and `RxTAlphaDecoder` are ready-to-use **Reactive Transformer** components, compatible with the Hugging Face
+Hub (the example above is based on their code), so they can be used instead of creating custom classes. Example usage can
+be found in the [pre-training docs](#pre-training)
+
 ### Memory
-The _memory_ module includes **Short-Term Memory** and layers responsible for its update. In future versions it will also
-include **Long-Term Memory**.
+The _memory_ module includes **Short-Term Memory (STM)** and layers responsible for its update. In future versions it will also
+include **Long-Term Memory (LTM)**.
+
+#### Short Term Memory
+The main `ShortTermMemory` class is located in the `rxnn.memory.stm` module. As described in the [Reactive Transformer research docs](https://github.com/RxAI-dev/RxNN/blob/main/docs/research/ReactiveTransformer/reactive-transformer.md),
+each transformer (encoder and decoder) layer has its own **STM** layer of shape `[batch_size, stm_size, embed_dim]`. Initially,
+for the first training stages (pre-training and supervised fine-tuning), **STM** is in "single/no batch" mode (`batch_size = 1`),
+because it's not used. For the reinforcement learning stages (**MRL/RxRLHF/BRL**), we have to switch short-term memory to batch
+mode, because items in batches are independent. After training, we can switch back to "single/no batch" mode. Example:
+```python
+from rxnn.memory.stm import ShortTermMemory
+
+num_layers = 10
+stm_size = 256
+embed_dim = 128
+batch_size = 32
 
-The main `ShortTermMemory` class is located in `rxnn.memory.stm` module - the usage example is in Transformers module description.
+# 1. Init STM
+stm = ShortTermMemory(
+    num_layers, embed_dim, stm_size,
+    init_type='normal'  # memory init type, 'normal' is the default and means normal distribution with 0.0 mean and 0.02 std
+)
+
+# 2. Set "batch" mode for MRL
+stm.batched_memory(
+    batch_size,
+    init_type='standard'  # init type can be changed for batch mode, 'standard' is normal distribution with 0.0 mean and 1.0 std
+)
+
+# 3. Reset STM with optional init type change
+stm.reset(init_type='uniform')  # init type can also be 'ones' or 'zeros', but it's not recommended
+
+# 4. Back to "single" mode for inference (optionally using mean value from batch)
+stm.single_memory(
+    init_type='standard',  # we can change the init type again
+    use_mean_from_batch=True  # use mean values from batch as new memory
+)
+```
+
+> ##### Other utils
+> `ShortTermMemory` can also be resized with the `stm.resize(new_stm_size, init_type)` method, detached and cloned
+> with `stm.clone_detach_reset()` (used in MRL), or made trainable (experimental option):
+> - it can be initialized as trainable - `stm = ShortTermMemory(num_layers, embed_dim, stm_size, is_trainable=True)`
+> - it can be switched to trainable - `stm.make_trainable()`
+> - and switched back to a buffer - `stm.freeze()`
 
 #### Memory Attention Network
 **Memory Attention Network** is responsible for memory layers update. It includes memory attention layers, with normalization
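As a quick illustration of the "Other utils" note in the hunk above, the calls it names can be strung together as follows. This is a sketch only: the argument values are arbitrary and the exact signatures should be checked against the `rxnn.memory.stm` API.

```python
from rxnn.memory.stm import ShortTermMemory

num_layers, embed_dim, stm_size = 10, 128, 256

# Trainable STM (the experimental option mentioned in the note above)
stm = ShortTermMemory(num_layers, embed_dim, stm_size, is_trainable=True)
stm.freeze()           # switch back to a plain (non-trainable) buffer
stm.make_trainable()   # ...and make it trainable again

# Resize the memory slots, then detach/clone the state (as used between MRL episodes)
stm.resize(512, 'normal')
stm.clone_detach_reset()
```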
@@ -299,6 +345,10 @@ class YourMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2.0")
 >
 > **Gated residual** is currently in tests - we are not sure if it will provide better results, so **it's not recommended**
 
+##### RxT-Alpha Memory Attention
+`RxTAlphaMemoryAttention` is a ready-to-use Memory Attention network for the **Reactive Transformer** Proof-of-Concept that
+can be used instead of creating a custom class. An example is in the [Memory Reinforcement Learning docs](#memory-reinforcement-learning)
+
 ### Training
 The training module includes **Trainers** for the different training stages of reactive models and shared training utils.
 
@@ -368,16 +418,17 @@ config = {
     'num_layers': 10,
     'vocab_size': vocab_size,
     'embed_dim': embed_dim,
-    'att_heads': 16,
-    'att_groups': 8,
+    'att_heads': 16,  # attention heads, in SQA they are used only for the dimension split
+    'att_groups': 8,  # key/value groups for GQA/SQA
     'seq_len': seq_len,
     'stm_size': seq_len,
-    'use_flash_attention': False,
-    'use_gated': True,
+    'use_flash_attention': False,  # explicitly use the flash-attn function (otherwise it's used through the PyTorch backend) - not recommended
+    'use_gated': True,  # use Gated Linear Units in feed forward, True by default
+    'ff_activation': 'silu',  # feed forward activation, 'silu' is the default for SwiGLU layers
     'ff_dropout': 0.1,
-    'self_att_type': 'sqa',
-    'cross_att_type': 'sqa',
-    'att_query_groups': 8,
+    'self_att_type': 'sqa',  # self attention can be 'sqa', 'gqa', 'mqa' or 'mha'
+    'cross_att_type': 'sqa',  # cross attention can be 'sqa', 'gqa', 'mqa' or 'mha'
+    'att_query_groups': 8,  # query groups for SQA
 }
 
 encoder_config = {
@@ -387,9 +438,9 @@ encoder_config = {
 
 decoder_config = {
     'ff_dim': 256,
-    'use_moe': True,
-    'num_experts': 20,
-    'moe_top_k': 4,
+    'use_moe': True,  # use Mixture-of-Experts feed forward
+    'num_experts': 20,  # number of experts
+    'moe_top_k': 4,  # number of activated experts (per token)
     **config
 }
 
@@ -647,11 +698,11 @@ mem_attn = RxTAlphaMemoryAttention(
     att_heads=8,
     seq_len=256,
     stm_size=256,
-    use_flash_attention=False,
-    norm_type='classic-rms',
-    att_groups=4,
-    att_type='sqa',
-    att_query_groups=4,
+    use_flash_attention=False,  # explicitly use the flash-attn function (otherwise it's used through the PyTorch backend)
+    norm_type='classic-rms',  # memory norm type
+    att_groups=4,  # key/value groups for SQA/GQA
+    att_type='sqa',  # attention type, can be 'sqa', 'gqa', 'mqa' or 'mha'
+    att_query_groups=4,  # query groups for SQA
 )
 
 # 4. Load shared embedding and memory from encoder to other models
@@ -677,7 +728,7 @@ Then, we have to load tokenizer and MRL Datasets, and create _curriculum config_
 # 1. Load tokenizer
 tokenizer = load_tokenizer_from_hf_hub('ReactiveAI/RxT-Alpha-Micro-Plus-Decoder', token='HF_TOKEN')
 
-# 2. Load PoC TinyStories based MRL Dataset, starting from 4 steps to 16 in long range
+# 2. Load the PoC TinyStories-based MRL Dataset, starting from 4 steps up to 16 in long range, and pre-tokenize it
 mrl_datasets = MrlDatasets.from_hf_hub(
     'ReactiveAI/TinyStories-MRL',
     tokenizer,
@@ -693,33 +744,40 @@ mrl_datasets = MrlDatasets.from_hf_hub(
     max_seq_len=256,
 )
 
+mrl_datasets.pre_tokenize(verbose=True, log_interval=100)
+
 # 3. Create curriculum stages config
 curriculum_stages = [CurriculumConfig(
-    steps=item['steps'],
-    epochs=10 if item['steps'] == 4 else 8 if item['steps'] == 8 and item['is_long_range'] else 5,
+    steps=item['steps'],  # number of steps in curriculum stage
+    epochs=10 if item['steps'] == 4 else 5,  # number of epochs in curriculum stage
     dataset=item['dataset'],
     eval_dataset=item['eval_dataset'],
     callbacks=[
-        MrlPrintCallback(),
+        MrlPrintCallback(),  # print loss/reward callback
         MrlModelSaveCallback(
-            './models', push_to_hub=True, hub_model_critic='ReactiveAI/RxT-Alpha-Micro-Critic-MRL',
-            hub_model_decoder='ReactiveAI/RxT-Alpha-Micro-Decoder-MRL', hub_model_encoder='ReactiveAI/RxT-Alpha-Micro-Encoder-MRL',
-            hub_model_memory_attention='ReactiveAI/RxT-Alpha-Micro-MemAtt-MRL', private_repo=True,
-            hf_token='HF_TOKEN', final_commit_message=f"MRL steps: {item['steps']} {'lr' if item['is_long_range'] else ''}",
+            './models',
+            push_to_hub=True,
+            hub_model_critic='Your critic model hub id',
+            hub_model_decoder='Your MRL decoder model hub id',
+            hub_model_encoder='Your MRL encoder model hub id',
+            hub_model_memory_attention='Your memory-attention model hub id',
+            private_repo=True,
+            hf_token='HF_TOKEN',
+            final_commit_message=f"MRL steps: {item['steps']} {'lr' if item['is_long_range'] else ''}",
             push_checkpoint_weights=True,
-        )
+        )  # MRL model save callback - saves and pushes to hub the critic model and actor components
     ],
-    strategy=MrlStrategy.LONG_RANGE_STRATEGY if item['is_long_range'] else MrlStrategy.MULTI_STEP_STRATEGY,
-    unfreeze_epoch=((2, 2e-5), (4, 8e-5), (6, 1e-5), 8) if item['steps'] == 4 else (0, 1, (2, 1e-6), 4),
-    random_resets=item['steps'] > 4,
-    random_resets_from=2,
-    random_resets_ratio=0.4 if item['steps'] != 4 else None,
-    separate_memory_lr=True,
-    memory_lr=6e-4 if item['steps'] == 4 else 4e-4 if item['steps'] == 8 and item['is_long_range'] else None,
-    lr=3e-4 if item['steps'] == 4 else 2e-4 if item['steps'] == 8 and item['is_long_range'] else None,
-    critic_lr=4e-4 if item['steps'] == 4 else None,
-    critic_encoder_lr=2e-4 if item['steps'] == 4 else None,
-    teacher_forcing=True if item['steps'] <= 8 else False,
+    strategy=MrlStrategy.LONG_RANGE_STRATEGY if item['is_long_range'] else MrlStrategy.MULTI_STEP_STRATEGY,  # strategy for curriculum stage
+    unfreeze_epoch=((2, 2e-5), (4, 8e-5), (6, 1e-5), 8) if item['steps'] == 4 else (0, 1, (2, 1e-6), 4),  # unfreeze strategy config
+    random_resets=item['steps'] > 4,  # enable random memory resets
+    random_resets_from=2,  # epoch when random resets start
+    random_resets_ratio=0.4 if item['steps'] != 4 else None,  # probability of STM reset before an episode
+    separate_memory_lr=True,  # use separate memory LR in the current curriculum stage
+    memory_lr=6e-4 if item['steps'] == 4 else None,  # memory LR for curriculum stage; if None, use global config
+    lr=3e-4 if item['steps'] == 4 else None,  # model LR for curriculum stage; if None, use global config
+    critic_lr=4e-4 if item['steps'] == 4 else None,  # critic (head) LR for curriculum stage; if None, use global config
+    critic_encoder_lr=2e-4 if item['steps'] == 4 else None,  # critic (encoder) LR for curriculum stage; if None, use global config
+    teacher_forcing=item['steps'] <= 8,  # use teacher forcing - save reference answers from dataset in memory instead of generated ones
 ) for item in mrl_datasets]
 ```
 
@@ -735,30 +793,33 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 reward_model = MrlRewardModel(
     encoder.model.embedding,
     device,
-    bleu_with_saved_data=True,
-    reward_len=True,
-    neg_reward_len=True,
-    target_len_as_ref=True,
-    bleu_factor=0.4,
-    cos_factor=0.5,
-    len_factor=0.1,
-    bleu_ref_factor=0.4,
-    bleu_saved_factor=0.6,
-    cos_ref_factor=0.35,
-    cos_saved_factor=0.65,
-    neg_bleu_factor=0.45,
-    neg_cos_factor=0.45,
-    neg_cos_ref_factor=0.3,
-    neg_cos_saved_factor=0.7,
-    neg_bleu_ref_factor=0.3,
-    neg_bleu_saved_factor=0.7,
-    multi_cos_ref_factor=0.3,
-    multi_cos_saved_factor= 0.5,
-    multi_cos_running_mean_factor = 0.2,
-    bleu_ref_weights=(0.2, 0.2, 0.3, 0.3),
-    bleu_saved_weights=(0.2, 0.2, 0.3, 0.3),
-    tanh_reward_scale=False,
-    rewards_scale=1.0,
+    bleu_with_saved_data=True,  # use saved data (previous or first interaction) in BLEU calculation
+    reward_len=True,  # use length reward in calculation (answer_len / target_len)
+    max_rewarded_len=None,  # target length awarded as 1.0
+    neg_reward_len=True,  # negative length reward - lower reward when answer is too long (target_len / answer_len)
+    target_len_as_ref=True,  # use reference answer len as target
+    use_running_mean=True,  # use running mean embedding of all previous answers in cosine similarity calculation
+    allow_not_summing_factors=False,  # if True, the sum of reward factors could be different from 1.0; it's False by default
+    bleu_factor=0.4,  # factor for BLEU score in standard reward
+    cos_factor=0.5,  # factor for cosine similarity score in standard reward
+    len_factor=0.1,  # factor for length reward score in standard reward
+    bleu_ref_factor=0.4,  # factor for reference answer score in BLEU calculation (standard mode)
+    bleu_saved_factor=0.6,  # factor for saved data score in BLEU calculation (standard mode)
+    cos_ref_factor=0.35,  # factor for reference answer score in cosine sim calculation (standard mode)
+    cos_saved_factor=0.65,  # factor for saved data score in cosine sim calculation (standard mode)
+    multi_cos_ref_factor=0.3,  # factor for reference answer in multi-step cosine sim calculation
+    multi_cos_saved_factor=0.5,  # factor for saved data in multi-step cosine sim calculation
+    multi_cos_running_mean_factor=0.2,  # factor for previous answers running mean in multi-step cosine sim calculation
+    neg_bleu_factor=0.45,  # factor for BLEU score in negative reward
+    neg_cos_factor=0.45,  # factor for cosine similarity score in negative reward
+    neg_bleu_ref_factor=0.3,  # factor for reference answer score in BLEU calculation (negative mode)
+    neg_bleu_saved_factor=0.7,  # factor for saved data score in BLEU calculation (negative mode)
+    neg_cos_ref_factor=0.3,  # factor for reference answer score in cosine sim calculation (negative mode)
+    neg_cos_saved_factor=0.7,  # factor for saved data score in cosine sim calculation (negative mode)
+    bleu_ref_weights=(0.2, 0.2, 0.3, 0.3),  # weights for n-grams in NLTK BLEU calculation for reference answers
+    bleu_saved_weights=(0.2, 0.2, 0.3, 0.3),  # weights for n-grams in NLTK BLEU calculation for saved data
+    tanh_reward_scale=False,  # scale rewards to -1.0 to 1.0 range, instead of the standard 0.0-1.0
+    rewards_scale=1.0,  # rewards scaling factor (reward * rewards_scale)
 )
 ```
 
@@ -769,32 +830,74 @@ algorithm = PPOAlgorithm(
     PPOConfig(clip_eps=0.2, gae_lambda=0.95, gae_gamma=0.99, entropy_coef=0.01, critic_value_clip=50.0)
 )
 
-# 2. Create config for MRLTrainer
+# 2. Create config for MRLTrainer (most of the MrlConfig fields can be overridden in each curriculum stage)
 mrl_config = MrlConfig(
-    lr=1e-4,
-    critic_lr=2e-4,
-    critic_encoder_lr=1e-4,
-    separate_memory_lr=True,
-    memory_lr=3e-4,
-    max_seq_len=256,
-    critic_max_len=512,
-    weight_decay=0.01,
-    critic_weight_decay=0.01,
-    update_epochs=10,
-    pad_token_id=0,
-    end_token_id=3,
-    use_moe_aux_loss=True,
-    embedding_lr=5e-6,
-    use_memory_warmup=False,
+    lr=1e-4,  # main LR, used for decoder layers
+    encoder_lr=2e-4,  # encoder LR, used for encoder layers (if None, lr is used)
+    critic_lr=2e-4,  # critic LR, used for critic value head
+    critic_encoder_lr=1e-4,  # critic encoder LR (if not set, critic_lr is used)
+    separate_memory_lr=True,  # use separate LR for memory attention and memory cross-attention
+    encoder_memory_lr=5e-4,  # LR for encoder memory cross-attention (if None, memory_lr is used)
+    memory_lr=3e-4,  # memory LR, used for decoder memory cross-attention
+    memory_attn_lr=5e-4,  # memory attention LR (if None, memory_lr is used)
+    max_seq_len=256,  # maximum length of a single interaction
+    critic_max_len=512,  # maximum length of the critic sequence (has to be longer than the actor's context)
+    weight_decay=0.01,  # weight decay for actor AdamW optimizer
+    critic_weight_decay=0.01,  # weight decay for critic AdamW optimizer
+    update_epochs=10,  # inner PPO update epochs
+    pad_token_id=0,  # tokenizer padding token id
+    end_token_id=3,  # tokenizer EOS token id
+    use_moe_aux_loss=True,  # add Mixture-of-Experts Router auxiliary loss to the policy loss
+    freeze_embeddings=False,  # freeze pre-trained embeddings for MRL training
+    embedding_lr=5e-6,  # LR for embeddings, if not frozen (if None, lr is used)
+    use_memory_warmup=False,  # memory warmup - update memory with the first interaction in no-grad mode, before the episode, for better initialization
 )
 
 # 3. Initialize MRL Trainer
-trainer = MRLTrainer(actor, critic, reward_model, device, mrl_config, algorithm, use_amp=True, dtype=torch.bfloat16)
+trainer = MRLTrainer(
+    actor, critic, reward_model, device, mrl_config, algorithm,
+    use_amp=True,  # use autocast in MRL training
+    dtype=torch.bfloat16,  # data type for MRL
+    use_ddp=False,  # use distributed training with DDP
+)
 
 # 4. Train with curriculum stages config
 trainer(curriculum_stages, batch_size=batch_size)
 ```
 
+## Experimental attention layers
+While working on reactive architectures, we also developed several new types of attention layers, some of which achieve
+very promising results. Even though reactive models, which process single interactions, have much lower computational
+requirements, we still need the most efficient attention mechanisms that are consistent with the memory requirements. Since
+memory is not a sequence but a set, spatial sparsity is probably not a good solution here, so we were looking for an efficient
+alternative to Flex Attention with full access to all memory positions. The new attention layers are implemented in the
+`rxnn.experimental.attention` module:
+- **Grouped Mixture-of-Experts Attention (GMA)** - uses MoE routing to dynamically select K active key/value heads for each token, instead
+of the static selection used in **GQA**. While it's theoretically interesting, in practice it achieved worse results than **GQA**,
+and even **MQA**, in all tests, and it is a lot slower because of routing overhead, so we abandoned further research. More details
+in the [research docs](https://github.com/RxAI-dev/RxNN/blob/main/docs/research/moe_attention.md)
+- **Deep Mixture-of-Experts Attention (DMA)** - extends **GMA** with the same MoE routing for query heads. Like **GMA**,
+it gives even worse results, and all the computational performance benefits of sparse query heads (as in
+**SQA**) are lost to routing overhead (lack of specialized kernels for head selection), so further research was also
+abandoned. [Research docs](https://github.com/RxAI-dev/RxNN/blob/main/docs/research/moe_attention.md)
+- **Hierarchical MoE Attention (HMA)** - extends **DMA/GMA**, using a different number of query/key/value heads for tokens with
+different priorities. It's only an idea and has not been implemented, because of the poor results of GMA/DMA. [More info](https://github.com/RxAI-dev/RxNN/blob/main/docs/research/hierarchical_moe_attention.md)
+- **Sparse Query Attention (SQA)** - the simplest extension to GQA, reducing not only the number of key/value heads but
+also the number of query heads. It results in models up to 2-3x faster (for 32k/131k tokens). **SQA** is the fastest attention
+mechanism for sequence lengths up to 131k; for longer sequences **Flex Attention** becomes faster. That's ideal for reactive models,
+which don't need a million-token context for single-interaction processing. In the tested cases, **SQA** model results (loss/accuracy)
+were close to GQA, with almost unnoticeable differences, but it still requires more tests. [Research docs](https://github.com/RxAI-dev/RxNN/blob/main/docs/research/sparse_query_attention.md)
+- **Flex Sparse Query Attention (Flex-SQA)** - **Flex Attention** combined with **SQA** - enables handling 4-8x longer sliding
+windows in a shorter time than base **Flex**, which should lead to better results. **Flex-SQA** should be the fastest
+attention mechanism for sequences longer than 131k tokens and is made for classic transformers, or potentially for self-attention
+in bigger reactive models. Currently, it's viable only with symmetric variants of **SQA** (same number of used query
+and key/value heads), because the kernels aren't compatible with GQA in sliding windows, and non-symmetric variants are 2x slower
+than they should be. Docs and tests are in progress
+
+### Test usage
+Experimental attention layers can be tested with the `ExperimentalAttentionTransformer` model from `rxnn.experimental.models`.
+A usage example can be found in our notebooks repository - [RxNN Notebooks](https://github.com/RxAI-dev/rxnn-notebooks)
+
 Apache License
 Version 2.0, January 2004
 http://www.apache.org/licenses/
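To make the **Sparse Query Attention (SQA)** bullet above more concrete, here is a minimal, framework-agnostic sketch of the idea in plain PyTorch (illustrative only, not the `rxnn.experimental.attention` implementation): the layer projects to fewer query heads than a full MHA layer would use, plus grouped key/value heads as in GQA, so the attention-score work scales with the reduced query-head count.

```python
import torch
import torch.nn.functional as F

batch, seq_len, dim = 2, 1024, 512
head_dim, num_q_heads, num_kv_heads = 64, 4, 2  # a full MHA layer would use 512 / 64 = 8 heads

x = torch.randn(batch, seq_len, dim)
q_proj = torch.nn.Linear(dim, num_q_heads * head_dim)        # fewer query heads than full MHA
kv_proj = torch.nn.Linear(dim, 2 * num_kv_heads * head_dim)  # grouped key/value heads, as in GQA
out_proj = torch.nn.Linear(num_q_heads * head_dim, dim)

q = q_proj(x).view(batch, seq_len, num_q_heads, head_dim).transpose(1, 2)
k, v = kv_proj(x).view(batch, seq_len, 2, num_kv_heads, head_dim).permute(2, 0, 3, 1, 4)

# Repeat the k/v groups to match the (already reduced) number of query heads,
# then run the standard scaled-dot-product attention kernel.
k = k.repeat_interleave(num_q_heads // num_kv_heads, dim=1)
v = v.repeat_interleave(num_q_heads // num_kv_heads, dim=1)
attn = F.scaled_dot_product_attention(q, k, v)  # [batch, num_q_heads, seq_len, head_dim]
out = out_proj(attn.transpose(1, 2).reshape(batch, seq_len, num_q_heads * head_dim))
print(out.shape)  # torch.Size([2, 1024, 512])
```

In this sketch, using 4 query heads instead of the 8 a full multi-head layer would use at `head_dim=64` roughly halves the attention-score computation, which matches the direction of the speedups reported for SQA in the bullet above.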
rxnn-0.2.67.dist-info/RECORD → rxnn-0.2.69.dist-info/RECORD CHANGED
@@ -17,7 +17,7 @@ rxnn/training/callbacks.py,sha256=rS8leuVFPVVfE5Zc8DMkUZhRIPN-vpPbUjowXE5TSBw,36
 rxnn/training/dataset.py,sha256=tbtOSYldHnQB6SWgee_yUj9zTbgoEoLFNa6wvUS6Apg,51292
 rxnn/training/ddp.py,sha256=VsNBjn3cY-uUj8hbsW7oKvb0_ZKnXnJ2KgObm-Mr9i4,836
 rxnn/training/models.py,sha256=KIiOCW0VgKtMA4EMQ---xsVExdI1mBsgWjtRSmJpecA,9033
-rxnn/training/mrl.py,sha256=2J6Wh4xtsVoE6duEevmovDpmSsMkEoH39Ru0bE8lhFo,65481
+rxnn/training/mrl.py,sha256=c_7P_DhroK3pQLubfmlVryWBSwlZ0BssU8zZ6UhjOaI,65919
 rxnn/training/reward.py,sha256=uiSsBXmjMw2yv-1Bssy3RTlpU6zP8ape3490Sl-aT0M,16144
 rxnn/training/rl.py,sha256=hWtExxY-_pAmTOGYxyCNounUbaGWvLDVltC4sRC7MN4,7175
 rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
@@ -33,7 +33,7 @@ rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
 rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
 rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
-rxnn-0.2.67.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
-rxnn-0.2.67.dist-info/METADATA,sha256=LEIwAXp3Eau7DrEUCeJ5etTC6nl-rNzsQfJxiRXD7xI,49548
-rxnn-0.2.67.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
-rxnn-0.2.67.dist-info/RECORD,,
+rxnn-0.2.69.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.2.69.dist-info/METADATA,sha256=YcmghdF8ypeyOCmglJaws18cDtTqSIE8P-gReGIMzsU,60420
+rxnn-0.2.69.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+rxnn-0.2.69.dist-info/RECORD,,
rxnn-0.2.67.dist-info/LICENSE → rxnn-0.2.69.dist-info/LICENSE: file without changes
rxnn-0.2.67.dist-info/WHEEL → rxnn-0.2.69.dist-info/WHEEL: file without changes