rxnn 0.2.67__tar.gz → 0.2.68__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. {rxnn-0.2.67 → rxnn-0.2.68}/PKG-INFO +186 -83
  2. {rxnn-0.2.67 → rxnn-0.2.68}/README.md +185 -82
  3. {rxnn-0.2.67 → rxnn-0.2.68}/pyproject.toml +1 -1
  4. {rxnn-0.2.67 → rxnn-0.2.68}/LICENSE +0 -0
  5. {rxnn-0.2.67 → rxnn-0.2.68}/src/rxnn/.DS_Store +0 -0
  6. {rxnn-0.2.67 → rxnn-0.2.68}/src/rxnn/__init__.py +0 -0
  7. {rxnn-0.2.67 → rxnn-0.2.68}/src/rxnn/experimental/__init__.py +0 -0
  8. {rxnn-0.2.67 → rxnn-0.2.68}/src/rxnn/experimental/attention.py +0 -0
  9. {rxnn-0.2.67 → rxnn-0.2.68}/src/rxnn/experimental/models.py +0 -0
  10. {rxnn-0.2.67 → rxnn-0.2.68}/src/rxnn/experimental/moe.py +0 -0
  11. {rxnn-0.2.67 → rxnn-0.2.68}/src/rxnn/memory/__init__.py +0 -0
  12. {rxnn-0.2.67 → rxnn-0.2.68}/src/rxnn/memory/attention.py +0 -0
  13. {rxnn-0.2.67 → rxnn-0.2.68}/src/rxnn/memory/norm.py +0 -0
  14. {rxnn-0.2.67 → rxnn-0.2.68}/src/rxnn/memory/stm.py +0 -0
  15. {rxnn-0.2.67 → rxnn-0.2.68}/src/rxnn/rxt/__init__.py +0 -0
  16. {rxnn-0.2.67 → rxnn-0.2.68}/src/rxnn/rxt/models.py +0 -0
  17. {rxnn-0.2.67 → rxnn-0.2.68}/src/rxnn/training/__init__.py +0 -0
  18. {rxnn-0.2.67 → rxnn-0.2.68}/src/rxnn/training/base.py +0 -0
  19. {rxnn-0.2.67 → rxnn-0.2.68}/src/rxnn/training/bml.py +0 -0
  20. {rxnn-0.2.67 → rxnn-0.2.68}/src/rxnn/training/callbacks.py +0 -0
  21. {rxnn-0.2.67 → rxnn-0.2.68}/src/rxnn/training/dataset.py +0 -0
  22. {rxnn-0.2.67 → rxnn-0.2.68}/src/rxnn/training/ddp.py +0 -0
  23. {rxnn-0.2.67 → rxnn-0.2.68}/src/rxnn/training/models.py +0 -0
  24. {rxnn-0.2.67 → rxnn-0.2.68}/src/rxnn/training/mrl.py +0 -0
  25. {rxnn-0.2.67 → rxnn-0.2.68}/src/rxnn/training/reward.py +0 -0
  26. {rxnn-0.2.67 → rxnn-0.2.68}/src/rxnn/training/rl.py +0 -0
  27. {rxnn-0.2.67 → rxnn-0.2.68}/src/rxnn/training/scheduler.py +0 -0
  28. {rxnn-0.2.67 → rxnn-0.2.68}/src/rxnn/training/tokenizer.py +0 -0
  29. {rxnn-0.2.67 → rxnn-0.2.68}/src/rxnn/training/utils.py +0 -0
  30. {rxnn-0.2.67 → rxnn-0.2.68}/src/rxnn/transformers/__init__.py +0 -0
  31. {rxnn-0.2.67 → rxnn-0.2.68}/src/rxnn/transformers/attention.py +0 -0
  32. {rxnn-0.2.67 → rxnn-0.2.68}/src/rxnn/transformers/ff.py +0 -0
  33. {rxnn-0.2.67 → rxnn-0.2.68}/src/rxnn/transformers/layers.py +0 -0
  34. {rxnn-0.2.67 → rxnn-0.2.68}/src/rxnn/transformers/mask.py +0 -0
  35. {rxnn-0.2.67 → rxnn-0.2.68}/src/rxnn/transformers/models.py +0 -0
  36. {rxnn-0.2.67 → rxnn-0.2.68}/src/rxnn/transformers/moe.py +0 -0
  37. {rxnn-0.2.67 → rxnn-0.2.68}/src/rxnn/transformers/positional.py +0 -0
  38. {rxnn-0.2.67 → rxnn-0.2.68}/src/rxnn/transformers/sampler.py +0 -0
  39. {rxnn-0.2.67 → rxnn-0.2.68}/src/rxnn/utils.py +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.3
  Name: rxnn
- Version: 0.2.67
+ Version: 0.2.68
  Summary: RxNN: Reactive Neural Networks Platform
  License: Apache-2.0
  Keywords: deep-learning,ai,machine-learning
@@ -57,10 +57,10 @@ sequence, it has to process it and save it in memory, but it could be done in ba

  ## Release plan
  We are working on three new reactive architectures, that progressively advance from language models to awareness models:
- - Reactive Transformer: Reactive Language Model (RLM) with Short-Term Memory
- - Preactor: extending Reactive Transformer with additional Long-Term Memory, providing theoretically infinite context (only
+ - **Reactive Transformer**: Reactive Language Model (RLM) with Short-Term Memory. [Research docs](https://github.com/RxAI-dev/RxNN/blob/main/docs/research/ReactiveTransformer/reactive-transformer.md)
+ - **Preactor**: extending Reactive Transformer with additional Long-Term Memory, providing theoretically infinite context (only
  single message length is limited) and the ability to learn from interactions (Live Learning)
- - Reactor: AGI awareness model & Strong Reactive Neural Network, that's working in infinite reasoning loop and doesn't require explicit human commands
+ - **Reactor**: AGI awareness model & Strong Reactive Neural Network that works in an infinite reasoning loop and doesn't require explicit human commands

  Each new architecture is based on the previous one and adding new features/abilities. They will be progressively
  released with next versions of **RxNN** framework:
@@ -207,11 +207,57 @@ class YourReactiveTransformerDecoder(nn.Module, PyTorchModelHubMixin):
  return self.model(x, attention_mask=attention_mask)
  ```

+ #### RxT-Alpha
+ `RxTAlphaEncoder` and `RxTAlphaDecoder` are ready-to-use **Reactive Transformer** components, compatible with the Hugging Face
+ Hub (the above example is based on their code), so they can be used instead of creating a custom class. An example usage can
+ be found in the [pre-training docs](#pre-training)
+
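For a quick illustration of these ready-made components, a minimal sketch follows. It assumes the classes live in `rxnn.rxt.models` (matching `src/rxnn/rxt/models.py` in this package) and that the forward call mirrors the custom-class example above; the repo id and tensor shapes are hypothetical, so treat it as illustrative only:

```python
import torch
from rxnn.rxt.models import RxTAlphaDecoder  # assumed import path

# PyTorchModelHubMixin gives Hub-compatible models a from_pretrained() classmethod
decoder = RxTAlphaDecoder.from_pretrained('ReactiveAI/RxT-Alpha-Micro-Plus-Decoder')  # hypothetical repo id

tokens = torch.randint(0, 1000, (1, 32))        # dummy token ids
mask = torch.ones(1, 32, dtype=torch.bool)      # attention mask, as in the example class above
logits = decoder(tokens, attention_mask=mask)   # forward signature assumed to mirror the example above
```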
  ### Memory
- The _memory_ module includes **Short-Term Memory** and layers responsible for its update. In future versions it will also
- include **Long-Term Memory**.
+ The _memory_ module includes **Short-Term Memory (STM)** and the layers responsible for its update. In future versions it will also
+ include **Long-Term Memory (LTM)**.
+
+ #### Short-Term Memory
+ The main `ShortTermMemory` class is located in the `rxnn.memory.stm` module. As described in the [Reactive Transformer research docs](https://github.com/RxAI-dev/RxNN/blob/main/docs/research/ReactiveTransformer/reactive-transformer.md),
+ each transformer (encoder and decoder) layer has its own **STM** layer of shape `[batch_size, stm_size, embed_dim]`. Initially,
+ for the first training stages (pre-training and supervised fine-tuning), **STM** is in "single/no batch" mode (`batch_size = 1`),
+ because it's not used. For the reinforcement learning stages (**MRL/RxRLHF/BRL**), we have to switch short-term memory to batch
+ mode, because items in batches are independent. After training, we could switch back to "single/no batch" mode. Example:
+ ```python
+ from rxnn.memory.stm import ShortTermMemory
+
+ num_layers = 10
+ stm_size = 256
+ embed_dim = 128
+ batch_size = 32

- The main `ShortTermMemory` class is located in `rxnn.memory.stm` module - the usage example is in Transformers module description.
+ # 1. Init STM
+ stm = ShortTermMemory(
+ num_layers, embed_dim, stm_size,
+ init_type='normal' # memory init type, 'normal' is default and means normal distribution with 0.0 mean and 0.02 std
+ )
+
+ # 2. Set "batch" mode for MRL
+ stm.batched_memory(
+ batch_size,
+ init_type='standard' # init type could be changed for batch mode, 'standard' is normal distribution with 0.0 mean and 1.0 std
+ )
+
+ # 3. Reset STM with optional init type change
+ stm.reset(init_type='uniform') # init type could also be 'ones' or 'zeros', but it's not recommended
+
+ # 4. Back to "single" mode for inference (optionally using mean value from batch)
+ stm.single_memory(
+ init_type='standard', # we could change init type again
+ use_mean_from_batch=True # use mean values from batch as new memory
+ )
+ ```
+
+ > ##### Other utils
+ > `ShortTermMemory` could also be resized with the `stm.resize(new_stm_size, init_type)` method, detached and cloned
+ > with `stm.clone_detach_reset()` (used in MRL), or could be made trainable (experimental option):
+ > - could be initialized as trainable - `stm = ShortTermMemory(num_layers, embed_dim, stm_size, is_trainable=True)`
+ > - could be switched to trainable - `stm.make_trainable()`
+ > - and switched back to buffer - `stm.freeze()`
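A short illustrative sketch of the utilities from the note above, using only the method names it mentions (`resize`, `clone_detach_reset`, `make_trainable`, `freeze`); the exact behaviour and return values are assumptions:

```python
from rxnn.memory.stm import ShortTermMemory

stm = ShortTermMemory(10, 128, 256)   # num_layers, embed_dim, stm_size, as in the example above

stm.resize(512, 'normal')             # stm.resize(new_stm_size, init_type) - grow STM to 512 slots
stm.clone_detach_reset()              # detach, clone and reset the memory state (used in MRL)

trainable_stm = ShortTermMemory(10, 128, 256, is_trainable=True)  # experimental trainable STM
trainable_stm.freeze()                # switch back to a plain (non-trainable) buffer
stm.make_trainable()                  # or promote an existing buffer to a trainable parameter
```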
 
  #### Memory Attention Network
  **Memory Attention Network** is responsible for memory layers update. It includes memory attention layers, with normalization
@@ -299,6 +345,10 @@ class YourMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2.0")
  >
  > **Gated residual** is currently in tests - we are not sure if it will provide better results, so **it's not recommended**

+ ##### RxT-Alpha Memory Attention
+ `RxTAlphaMemoryAttention` is a ready-to-use Memory Attention network for the **Reactive Transformer** Proof-of-Concept that
+ can be used instead of creating a custom class. An example is in the [Memory Reinforcement Learning docs](#memory-reinforcement-learning)
+
  ### Training
  Training module includes **Trainers** for different training stages of reactive models and shared training utils.

@@ -368,16 +418,17 @@ config = {
  'num_layers': 10,
  'vocab_size': vocab_size,
  'embed_dim': embed_dim,
- 'att_heads': 16,
- 'att_groups': 8,
+ 'att_heads': 16, # attention heads, in SQA it's used only for dimension split
+ 'att_groups': 8, # key/value groups for GQA/SQA
  'seq_len': seq_len,
  'stm_size': seq_len,
- 'use_flash_attention': False,
- 'use_gated': True,
+ 'use_flash_attention': False, # explicitly use flash-attn function (otherwise it's used through PyTorch backend) - not recommended
+ 'use_gated': True, # use Gated Linear Units in feed forward, True by default
+ 'ff_activation': 'silu', # feed forward activation, 'silu' is default for SwiGLU layers
  'ff_dropout': 0.1,
- 'self_att_type': 'sqa',
- 'cross_att_type': 'sqa',
- 'att_query_groups': 8,
+ 'self_att_type': 'sqa', # self attention could be 'sqa', 'gqa', 'mqa' or 'mha'
+ 'cross_att_type': 'sqa', # cross attention could be 'sqa', 'gqa', 'mqa' or 'mha'
+ 'att_query_groups': 8, # query groups for SQA
  }

  encoder_config = {
@@ -387,9 +438,9 @@ encoder_config = {

  decoder_config = {
  'ff_dim': 256,
- 'use_moe': True,
- 'num_experts': 20,
- 'moe_top_k': 4,
+ 'use_moe': True, # use Mixture-of-Experts feed forward
+ 'num_experts': 20, # number of experts
+ 'moe_top_k': 4, # number of activated experts (per token)
  **config
  }
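As a quick sanity check on the MoE settings above, a small plain-Python sketch of how top-k routing limits the expert parameters used per token (it assumes SwiGLU-style experts with three `embed_dim x ff_dim` projections; the real layer shapes may differ):

```python
embed_dim = 128
ff_dim = 256
num_experts = 20
moe_top_k = 4

params_per_expert = 3 * embed_dim * ff_dim            # gate, up and down projections, biases ignored
total_expert_params = num_experts * params_per_expert
active_expert_params = moe_top_k * params_per_expert

print(total_expert_params)                            # 1966080 expert parameters stored per layer
print(active_expert_params)                           # 393216 of them touched for any given token
print(active_expert_params / total_expert_params)     # 0.2 -> top-4 of 20 experts = 20% active
```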
 
@@ -647,11 +698,11 @@ mem_attn = RxTAlphaMemoryAttention(
  att_heads=8,
  seq_len=256,
  stm_size=256,
- use_flash_attention=False,
- norm_type='classic-rms',
- att_groups=4,
- att_type='sqa',
- att_query_groups=4,
+ use_flash_attention=False, # explicitly use flash-attn function (otherwise it's used through PyTorch backend)
+ norm_type='classic-rms', # memory norm type
+ att_groups=4, # key/value groups for SQA/GQA
+ att_type='sqa', # attention type, could be 'sqa', 'gqa', 'mqa' or 'mha'
+ att_query_groups=4, # query groups for SQA
  )

  # 4. Load shared embedding and memory from encoder to other models
@@ -677,7 +728,7 @@ Then, we have to load tokenizer and MRL Datasets, and create _curriculum config_
  # 1. Load tokenizer
  tokenizer = load_tokenizer_from_hf_hub('ReactiveAI/RxT-Alpha-Micro-Plus-Decoder', token='HF_TOKEN')

- # 2. Load PoC TinyStories based MRL Dataset, starting from 4 steps to 16 in long range
+ # 2. Load PoC TinyStories based MRL Dataset, starting from 4 steps to 16 in long range, and pre-tokenize it
  mrl_datasets = MrlDatasets.from_hf_hub(
  'ReactiveAI/TinyStories-MRL',
  tokenizer,
@@ -693,33 +744,40 @@ mrl_datasets = MrlDatasets.from_hf_hub(
  max_seq_len=256,
  )

+ mrl_datasets.pre_tokenize(verbose=True, log_interval=100)
+
  # 3. Create curriculum stages config
  curriculum_stages = [CurriculumConfig(
- steps=item['steps'],
- epochs=10 if item['steps'] == 4 else 8 if item['steps'] == 8 and item['is_long_range'] else 5,
+ steps=item['steps'], # number of steps in curriculum stage
+ epochs=10 if item['steps'] == 4 else 5, # number of epochs in curriculum stage
  dataset=item['dataset'],
  eval_dataset=item['eval_dataset'],
  callbacks=[
- MrlPrintCallback(),
+ MrlPrintCallback(), # Print loss/reward callback
  MrlModelSaveCallback(
- './models', push_to_hub=True, hub_model_critic='ReactiveAI/RxT-Alpha-Micro-Critic-MRL',
- hub_model_decoder='ReactiveAI/RxT-Alpha-Micro-Decoder-MRL', hub_model_encoder='ReactiveAI/RxT-Alpha-Micro-Encoder-MRL',
- hub_model_memory_attention='ReactiveAI/RxT-Alpha-Micro-MemAtt-MRL', private_repo=True,
- hf_token='HF_TOKEN', final_commit_message=f"MRL steps: {item['steps']} {'lr' if item['is_long_range'] else ''}",
+ './models',
+ push_to_hub=True,
+ hub_model_critic='Your critic model hub id',
+ hub_model_decoder='Your MRL decoder model hub id',
+ hub_model_encoder='Your MRL encoder model hub id',
+ hub_model_memory_attention='Your memory-attention model hub id',
+ private_repo=True,
+ hf_token='HF_TOKEN',
+ final_commit_message=f"MRL steps: {item['steps']} {'lr' if item['is_long_range'] else ''}",
  push_checkpoint_weights=True,
- )
+ ) # MRL model save callback - saves and pushes to hub the critic model and actor components
  ],
- strategy=MrlStrategy.LONG_RANGE_STRATEGY if item['is_long_range'] else MrlStrategy.MULTI_STEP_STRATEGY,
- unfreeze_epoch=((2, 2e-5), (4, 8e-5), (6, 1e-5), 8) if item['steps'] == 4 else (0, 1, (2, 1e-6), 4),
- random_resets=item['steps'] > 4,
- random_resets_from=2,
- random_resets_ratio=0.4 if item['steps'] != 4 else None,
- separate_memory_lr=True,
- memory_lr=6e-4 if item['steps'] == 4 else 4e-4 if item['steps'] == 8 and item['is_long_range'] else None,
- lr=3e-4 if item['steps'] == 4 else 2e-4 if item['steps'] == 8 and item['is_long_range'] else None,
- critic_lr=4e-4 if item['steps'] == 4 else None,
- critic_encoder_lr=2e-4 if item['steps'] == 4 else None,
- teacher_forcing=True if item['steps'] <= 8 else False,
+ strategy=MrlStrategy.LONG_RANGE_STRATEGY if item['is_long_range'] else MrlStrategy.MULTI_STEP_STRATEGY, # strategy for curriculum stage
+ unfreeze_epoch=((2, 2e-5), (4, 8e-5), (6, 1e-5), 8) if item['steps'] == 4 else (0, 1, (2, 1e-6), 4), # unfreeze strategy config
+ random_resets=item['steps'] > 4, # enable random memory resets
+ random_resets_from=2, # epoch when random resets start
+ random_resets_ratio=0.4 if item['steps'] != 4 else None, # probability of STM reset before an episode
+ separate_memory_lr=True, # use separate memory LR in the current curriculum stage
+ memory_lr=6e-4 if item['steps'] == 4 else None, # memory LR for the curriculum stage; if None, use global config
+ lr=3e-4 if item['steps'] == 4 else None, # model LR for the curriculum stage; if None, use global config
+ critic_lr=4e-4 if item['steps'] == 4 else None, # critic (head) LR for the curriculum stage; if None, use global config
+ critic_encoder_lr=2e-4 if item['steps'] == 4 else None, # critic (encoder) LR for the curriculum stage; if None, use global config
+ teacher_forcing=item['steps'] <= 8, # use teacher forcing - save reference answers from the dataset in memory instead of generated ones
  ) for item in mrl_datasets]
  ```

@@ -735,30 +793,33 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  reward_model = MrlRewardModel(
  encoder.model.embedding,
  device,
- bleu_with_saved_data=True,
- reward_len=True,
- neg_reward_len=True,
- target_len_as_ref=True,
- bleu_factor=0.4,
- cos_factor=0.5,
- len_factor=0.1,
- bleu_ref_factor=0.4,
- bleu_saved_factor=0.6,
- cos_ref_factor=0.35,
- cos_saved_factor=0.65,
- neg_bleu_factor=0.45,
- neg_cos_factor=0.45,
- neg_cos_ref_factor=0.3,
- neg_cos_saved_factor=0.7,
- neg_bleu_ref_factor=0.3,
- neg_bleu_saved_factor=0.7,
- multi_cos_ref_factor=0.3,
- multi_cos_saved_factor= 0.5,
- multi_cos_running_mean_factor = 0.2,
- bleu_ref_weights=(0.2, 0.2, 0.3, 0.3),
- bleu_saved_weights=(0.2, 0.2, 0.3, 0.3),
- tanh_reward_scale=False,
- rewards_scale=1.0,
+ bleu_with_saved_data=True, # use saved data (previous or first interaction) in BLEU calculation
+ reward_len=True, # use length reward in calculation (answer_len / target_len)
+ max_rewarded_len=None, # target length awarded as 1.0
+ neg_reward_len=True, # negative length reward - lower reward when answer is too long (target_len / answer_len)
+ target_len_as_ref=True, # use reference answer len as target
+ use_running_mean=True, # use running mean embedding of all previous answers in cosine similarity calculation
+ allow_not_summing_factors=False, # if True sum of reward factors could be different from 1.0, it's False by default
+ bleu_factor=0.4, # factor for BLEU score in standard reward
+ cos_factor=0.5, # factor for cosine similarity score in standard reward
+ len_factor=0.1, # factor for length reward score in standard reward
+ bleu_ref_factor=0.4, # factor for reference answer score in BLEU calculation (standard mode)
+ bleu_saved_factor=0.6, # factor for saved data score in BLEU calculation (standard mode)
+ cos_ref_factor=0.35, # factor for reference answer score in cosine sim calculation (standard mode)
+ cos_saved_factor=0.65, # factor for saved data score in cosine sim calculation (standard mode)
+ multi_cos_ref_factor=0.3, # factor for reference answer in multi-step cosine sim calculation
+ multi_cos_saved_factor=0.5, # factor for saved data in multi-step cosine sim calculation
+ multi_cos_running_mean_factor=0.2, # factor for previous answers running mean in multi-step cosine sim calculation
+ neg_bleu_factor=0.45, # factor for BLEU score in negative reward
+ neg_cos_factor=0.45, # factor for cosine similarity score in negative reward
+ neg_bleu_ref_factor=0.3, # factor for reference answer score in BLEU calculation (negative mode)
+ neg_bleu_saved_factor=0.7, # factor for saved data score in BLEU calculation (negative mode)
+ neg_cos_ref_factor=0.3, # factor for reference answer score in cosine sim calculation (negative mode)
+ neg_cos_saved_factor=0.7, # factor for saved data score in cosine sim calculation (negative mode)
+ bleu_ref_weights=(0.2, 0.2, 0.3, 0.3), # weights for n-grams in NLTK BLEU calculation for reference answers
+ bleu_saved_weights=(0.2, 0.2, 0.3, 0.3), # weights for n-grams in NLTK BLEU calculation for saved data
+ tanh_reward_scale=False, # scale rewards to -1.0 to 1.0 range, instead of standard 0.0-1.0
+ rewards_scale=1.0, # rewards scaling factor (reward * rewards_scale)
  )
  ```
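The factors above can be read as weights of a weighted average. A tiny illustrative calculation, under the assumption that the standard reward is a plain weighted sum of the BLEU, cosine and length scores (the real `MrlRewardModel` may combine them differently):

```python
# Hypothetical per-interaction scores, each already in the 0.0-1.0 range
bleu_score, cos_score, len_score = 0.35, 0.80, 0.90

bleu_factor, cos_factor, len_factor = 0.4, 0.5, 0.1   # same values as in the config above
assert abs((bleu_factor + cos_factor + len_factor) - 1.0) < 1e-9  # factors sum to 1.0 (allow_not_summing_factors=False)

reward = bleu_factor * bleu_score + cos_factor * cos_score + len_factor * len_score
print(reward)  # 0.63 - stays in the 0.0-1.0 range before rewards_scale (and optional tanh scaling) is applied
```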
 
@@ -769,32 +830,74 @@ algorithm = PPOAlgorithm(
  PPOConfig(clip_eps=0.2, gae_lambda=0.95, gae_gamma=0.99, entropy_coef=0.01, critic_value_clip=50.0)
  )

- # 2. Create config for MRLTrainer
+ # 2. Create config for MRLTrainer (most of the MrlConfig fields can be overwritten in each curriculum stage)
  mrl_config = MrlConfig(
- lr=1e-4,
- critic_lr=2e-4,
- critic_encoder_lr=1e-4,
- separate_memory_lr=True,
- memory_lr=3e-4,
- max_seq_len=256,
- critic_max_len=512,
- weight_decay=0.01,
- critic_weight_decay=0.01,
- update_epochs=10,
- pad_token_id=0,
- end_token_id=3,
- use_moe_aux_loss=True,
- embedding_lr=5e-6,
- use_memory_warmup=False,
+ lr=1e-4, # main LR, used for decoder layers
+ encoder_lr=2e-4, # encoder LR, used for encoder layers (if None, lr is used)
+ critic_lr=2e-4, # critic LR, used for critic value head
+ critic_encoder_lr=1e-4, # critic encoder LR (if not set, critic_lr is used)
+ separate_memory_lr=True, # use separate LR for memory attention and memory cross-attention
+ encoder_memory_lr=5e-4, # LR for encoder memory cross-attention (if None, memory_lr is used)
+ memory_lr=3e-4, # memory LR, used for decoder memory cross-attention
+ memory_attn_lr=5e-4, # memory attention LR (if None, memory_lr is used)
+ max_seq_len=256, # maximum length of a single interaction
+ critic_max_len=512, # maximum length of the critic sequence (has to be longer than the actor's context)
+ weight_decay=0.01, # weight decay for the actor AdamW optimizer
+ critic_weight_decay=0.01, # weight decay for the critic AdamW optimizer
+ update_epochs=10, # inner PPO update epochs
+ pad_token_id=0, # tokenizer padding token id
+ end_token_id=3, # tokenizer EOS token id
+ use_moe_aux_loss=True, # add Mixture-of-Experts Router auxiliary loss to policy loss
+ freeze_embeddings=False, # freeze pre-trained embeddings for MRL training
+ embedding_lr=5e-6, # LR for embeddings, if not frozen (if None, lr is used)
+ use_memory_warmup=False, # memory warmup - update memory with the first interaction in no-grad mode, before the episode, for better initialization
  )

  # 3. Initialize MRL Trainer
- trainer = MRLTrainer(actor, critic, reward_model, device, mrl_config, algorithm, use_amp=True, dtype=torch.bfloat16)
+ trainer = MRLTrainer(
+ actor, critic, reward_model, device, mrl_config, algorithm,
+ use_amp=True, # use autocast in MRL Training
+ dtype=torch.bfloat16, # data type for MRL
+ use_ddp=False, # use distributed training with DDP
+ )

  # 4. Train with curriculum stages config
  trainer(curriculum_stages, batch_size=batch_size)
  ```

+ ## Experimental attention layers
+ While working on reactive architectures, we also developed several new types of attention layers, some of which achieve
+ very promising results. Even considering that reactive models, processing single interactions, have much lower computational
+ requirements, we need the most efficient attention mechanisms, consistent with memory requirements. Since memory is not a
+ sequence but a set, spatial sparsity is probably not a good solution here, so we were looking for an efficient alternative
+ to Flex Attention with full access to all memory positions. The new attention layers are implemented in the `rxnn.experimental.attention`
+ module:
+ - **Grouped Mixture-of-Experts Attention (GMA)** - uses MoE routing to dynamically select K active key/value heads for each token, instead
+ of the static selection in **GQA**. While it's theoretically interesting, in practice it achieved worse results than **GQA**,
+ and even **MQA**, in all tests, and is a lot slower because of routing overhead, so we abandoned further research. More details
+ in the [research docs](https://github.com/RxAI-dev/RxNN/blob/main/docs/research/moe_attention.md)
+ - **Deep Mixture-of-Experts Attention (DMA)** - extends **GMA** with the same MoE routing for query heads. Like **GMA**,
+ it gives even worse results, and all the computational performance benefits from sparse query heads (like in
+ **SQA**) are lost to routing overhead (lack of specialized kernels for head selection), so further research is also
+ abandoned. [Research docs](https://github.com/RxAI-dev/RxNN/blob/main/docs/research/moe_attention.md)
+ - **Hierarchical MoE Attention (HMA)** - extends **DMA/GMA**, using a different number of query/key/value heads for tokens with
+ different priority. It's only an idea and is not implemented, because of the poor results of GMA/DMA. [More info](https://github.com/RxAI-dev/RxNN/blob/main/docs/research/hierarchical_moe_attention.md)
+ - **Sparse Query Attention (SQA)** - the most trivial extension to GQA, reducing not only the number of key/value heads but
+ also the number of query heads. It results in a model that's up to 2-3x faster (for 32k/131k tokens). **SQA** is the fastest attention
+ mechanism for 0-131k sequence lengths; for longer sequences **Flex Attention** becomes faster. That's ideal for reactive models,
+ which don't need a million-token context for single interaction processing. In tested cases **SQA** model results (loss/accuracy)
+ were close to GQA and the differences were almost unnoticeable, but it still requires more tests (see the sketch after this list). [Research docs](https://github.com/RxAI-dev/RxNN/blob/main/docs/research/sparse_query_attention.md)
+ - **Flex Sparse Query Attention (Flex-SQA)** - **Flex Attention** combined with **SQA** - enables handling 4-8x longer sliding
+ windows in a shorter time than base **Flex**, so it should lead to better results. **Flex-SQA** should be the fastest
+ attention mechanism for sequences longer than 131k tokens and is made for classic transformers, or potentially self-attention
+ in bigger reactive models. Currently, it's viable only with symmetric variants of **SQA** (same number of used query
+ and key/value heads), because the kernels aren't compatible with GQA in sliding windows and non-symmetric variants are 2x slower
+ than they should be. Docs and tests in progress
+
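To make the SQA speed argument concrete, a rough back-of-the-envelope comparison of attention-score FLOPs (illustrative only; it reuses the head counts from the config example earlier in this README and ignores projections, softmax and kernel effects):

```python
seq_len = 32_000
head_dim = 8           # embed_dim=128 split over att_heads=16, as in the config example above
full_query_heads = 16  # MHA/GQA keep all 16 query heads (GQA only shrinks the key/value side)
sqa_query_heads = 8    # SQA with att_query_groups=8 uses half of the query heads

def qk_flops(num_query_heads: int) -> int:
    # multiply-adds for the QK^T score matrix: heads * seq_len^2 * head_dim
    return num_query_heads * seq_len * seq_len * head_dim

print(qk_flops(full_query_heads))                              # 131072000000 for MHA/GQA
print(qk_flops(sqa_query_heads))                               # 65536000000 for SQA
print(qk_flops(full_query_heads) / qk_flops(sqa_query_heads))  # 2.0 -> half the query heads, half the score FLOPs
```

The end-to-end speedups reported in the research docs also depend on projection layers, kernels and sequence length, so they differ from this idealized ratio.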
+ ### Test usage
+ Experimental attention layers can be tested with the `ExperimentalAttentionTransformer` model from `rxnn.experimental.models`.
+ A usage example can be found in our notebooks repository - [RxNN Notebooks](https://github.com/RxAI-dev/rxnn-notebooks)
+
  Apache License
  Version 2.0, January 2004
  http://www.apache.org/licenses/
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"

  [tool.poetry]
  name = "rxnn"
- version = "0.2.67"
+ version = "0.2.68"
  description = "RxNN: Reactive Neural Networks Platform"

  license = "Apache-2.0"