rxnn 0.2.67__tar.gz → 0.2.68__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {rxnn-0.2.67 → rxnn-0.2.68}/PKG-INFO +186 -83
- {rxnn-0.2.67 → rxnn-0.2.68}/README.md +185 -82
- {rxnn-0.2.67 → rxnn-0.2.68}/pyproject.toml +1 -1
- {rxnn-0.2.67 → rxnn-0.2.68}/LICENSE +0 -0
- {rxnn-0.2.67 → rxnn-0.2.68}/src/rxnn/.DS_Store +0 -0
- {rxnn-0.2.67 → rxnn-0.2.68}/src/rxnn/__init__.py +0 -0
- {rxnn-0.2.67 → rxnn-0.2.68}/src/rxnn/experimental/__init__.py +0 -0
- {rxnn-0.2.67 → rxnn-0.2.68}/src/rxnn/experimental/attention.py +0 -0
- {rxnn-0.2.67 → rxnn-0.2.68}/src/rxnn/experimental/models.py +0 -0
- {rxnn-0.2.67 → rxnn-0.2.68}/src/rxnn/experimental/moe.py +0 -0
- {rxnn-0.2.67 → rxnn-0.2.68}/src/rxnn/memory/__init__.py +0 -0
- {rxnn-0.2.67 → rxnn-0.2.68}/src/rxnn/memory/attention.py +0 -0
- {rxnn-0.2.67 → rxnn-0.2.68}/src/rxnn/memory/norm.py +0 -0
- {rxnn-0.2.67 → rxnn-0.2.68}/src/rxnn/memory/stm.py +0 -0
- {rxnn-0.2.67 → rxnn-0.2.68}/src/rxnn/rxt/__init__.py +0 -0
- {rxnn-0.2.67 → rxnn-0.2.68}/src/rxnn/rxt/models.py +0 -0
- {rxnn-0.2.67 → rxnn-0.2.68}/src/rxnn/training/__init__.py +0 -0
- {rxnn-0.2.67 → rxnn-0.2.68}/src/rxnn/training/base.py +0 -0
- {rxnn-0.2.67 → rxnn-0.2.68}/src/rxnn/training/bml.py +0 -0
- {rxnn-0.2.67 → rxnn-0.2.68}/src/rxnn/training/callbacks.py +0 -0
- {rxnn-0.2.67 → rxnn-0.2.68}/src/rxnn/training/dataset.py +0 -0
- {rxnn-0.2.67 → rxnn-0.2.68}/src/rxnn/training/ddp.py +0 -0
- {rxnn-0.2.67 → rxnn-0.2.68}/src/rxnn/training/models.py +0 -0
- {rxnn-0.2.67 → rxnn-0.2.68}/src/rxnn/training/mrl.py +0 -0
- {rxnn-0.2.67 → rxnn-0.2.68}/src/rxnn/training/reward.py +0 -0
- {rxnn-0.2.67 → rxnn-0.2.68}/src/rxnn/training/rl.py +0 -0
- {rxnn-0.2.67 → rxnn-0.2.68}/src/rxnn/training/scheduler.py +0 -0
- {rxnn-0.2.67 → rxnn-0.2.68}/src/rxnn/training/tokenizer.py +0 -0
- {rxnn-0.2.67 → rxnn-0.2.68}/src/rxnn/training/utils.py +0 -0
- {rxnn-0.2.67 → rxnn-0.2.68}/src/rxnn/transformers/__init__.py +0 -0
- {rxnn-0.2.67 → rxnn-0.2.68}/src/rxnn/transformers/attention.py +0 -0
- {rxnn-0.2.67 → rxnn-0.2.68}/src/rxnn/transformers/ff.py +0 -0
- {rxnn-0.2.67 → rxnn-0.2.68}/src/rxnn/transformers/layers.py +0 -0
- {rxnn-0.2.67 → rxnn-0.2.68}/src/rxnn/transformers/mask.py +0 -0
- {rxnn-0.2.67 → rxnn-0.2.68}/src/rxnn/transformers/models.py +0 -0
- {rxnn-0.2.67 → rxnn-0.2.68}/src/rxnn/transformers/moe.py +0 -0
- {rxnn-0.2.67 → rxnn-0.2.68}/src/rxnn/transformers/positional.py +0 -0
- {rxnn-0.2.67 → rxnn-0.2.68}/src/rxnn/transformers/sampler.py +0 -0
- {rxnn-0.2.67 → rxnn-0.2.68}/src/rxnn/utils.py +0 -0
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.2.67
+Version: 0.2.68
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning
@@ -57,10 +57,10 @@ sequence, it has to process it and save it in memory, but it could be done in ba

 ## Release plan
 We are working on three new reactive architectures, that progressively advance from language models to awareness models:
-- Reactive Transformer
-- Preactor
+- **Reactive Transformer**: Reactive Language Model (RLM) with Short-Term Memory. [Research docs](https://github.com/RxAI-dev/RxNN/blob/main/docs/research/ReactiveTransformer/reactive-transformer.md)
+- **Preactor**: extending Reactive Transformer with additional Long-Term Memory, providing theoretically infinite context (only
 single message length is limited) and the ability to learn from interactions (Live Learning)
-- Reactor
+- **Reactor**: AGI awareness model & Strong Reactive Neural Network that works in an infinite reasoning loop and doesn't require explicit human commands

 Each new architecture is based on the previous one and adding new features/abilities. They will be progressively
 released with next versions of **RxNN** framework:
@@ -207,11 +207,57 @@ class YourReactiveTransformerDecoder(nn.Module, PyTorchModelHubMixin):
         return self.model(x, attention_mask=attention_mask)
 ```

+#### RxT-Alpha
+`RxTAlphaEncoder` and `RxTAlphaDecoder` are ready-to-use **Reactive Transformer** components, compatible with Hugging Face
+Hub (the above example is based on their code), so they could be used instead of creating a custom class. Example usage could
+be found in [pre-training docs](#pre-training)
+
 ### Memory
-The _memory_ module includes **Short-Term Memory** and layers responsible for its update. In future versions it will also
-include **Long-Term Memory**.
+The _memory_ module includes **Short-Term Memory (STM)** and layers responsible for its update. In future versions it will also
+include **Long-Term Memory (LTM)**.
+
+#### Short Term Memory
+The main `ShortTermMemory` class is located in the `rxnn.memory.stm` module. As described in [Reactive Transformer research docs](https://github.com/RxAI-dev/RxNN/blob/main/docs/research/ReactiveTransformer/reactive-transformer.md),
+each transformer (encoder and decoder) layer has its own **STM** layer of shape `[batch_size, stm_size, embed_dim]`. Initially,
+for the first training stages (pre-training and supervised fine-tuning), **STM** is in "single/no batch" mode (`batch_size = 1`),
+because it's not used. For reinforcement learning stages (**MRL/RxRLHF/BRL**), we have to switch short-term memory to batch
+mode, because items in batches are independent. After training, we could switch back to "single/no batch" mode. Example:
+```python
+from rxnn.memory.stm import ShortTermMemory
+
+num_layers = 10
+stm_size = 256
+embed_dim = 128
+batch_size = 32

-
+# 1. Init STM
+stm = ShortTermMemory(
+    num_layers, embed_dim, stm_size,
+    init_type='normal' # memory init type, 'normal' is default and means normal distribution with 0.0 mean and 0.02 std
+)
+
+# 2. Set "batch" mode for MRL
+stm.batched_memory(
+    batch_size,
+    init_type='standard' # init type could be changed for batch mode, 'standard' is normal distribution with 0.0 mean and 1.0 std
+)
+
+# 3. Reset STM with optional init type change
+stm.reset(init_type='uniform') # init type could be also 'ones' or 'zeros', but it's not recommended
+
+# 4. Back to "single" mode for inference (optionally using mean value from batch)
+stm.single_memory(
+    init_type='standard', # we could change init type again
+    use_mean_from_batch=True # use mean values from batch as new memory
+)
+```
+
+> ##### Other utils
+> `ShortTermMemory` could also be resized with the `stm.resize(new_stm_size, init_type)` method, detached and cloned
+> with `stm.clone_detach_reset()` (used in MRL), or could be made trainable (experimental option):
+> - could be initialized as trainable - `stm = ShortTermMemory(num_layers, embed_dim, stm_size, is_trainable=True)`
+> - could be switched to trainable - `stm.make_trainable()`
+> - and switched back to buffer - `stm.freeze()`

 #### Memory Attention Network
 **Memory Attention Network** is responsible for memory layers update. It includes memory attention layers, with normalization
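The `ShortTermMemory` utilities named in the blockquote above are only listed in the diff; the following is a minimal sketch of how they might be combined, assuming the call signatures quoted there and leaving the return value of `clone_detach_reset()` unspecified.

```python
# Sketch only - based on the method names and signatures quoted in the README diff above,
# not on verified library behavior.
from rxnn.memory.stm import ShortTermMemory

num_layers, embed_dim, stm_size = 10, 128, 256

# STM initialized as trainable (experimental option from the blockquote)
stm = ShortTermMemory(num_layers, embed_dim, stm_size, is_trainable=True)

stm.freeze()          # switch back to a plain (non-trainable) buffer
stm.make_trainable()  # ...and to trainable again

# Grow the per-layer memory from 256 to 512 slots; init_type follows the quoted signature
stm.resize(512, init_type='normal')

# Detach/clone/reset combination used internally by MRL (return value not documented in the diff)
stm.clone_detach_reset()
```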
@@ -299,6 +345,10 @@ class YourMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2.0")
 >
 > **Gated residual** is currently in tests - we are not sure if it will provide better results, so **it's not recommended**

+##### RxT-Alpha Memory Attention
+`RxTAlphaMemoryAttention` is a ready-to-use Memory Attention network for the **Reactive Transformer** Proof-of-Concept, that
+could be used instead of creating a custom class. An example is in [Memory Reinforcement Learning docs](#memory-reinforcement-learning)
+
 ### Training
 Training module includes **Trainers** for different training stages of reactive models and shared training utils.

@@ -368,16 +418,17 @@ config = {
     'num_layers': 10,
     'vocab_size': vocab_size,
     'embed_dim': embed_dim,
-    'att_heads': 16,
-    'att_groups': 8,
+    'att_heads': 16, # attention heads, in SQA it's used only for dimension split
+    'att_groups': 8, # key/value groups for GQA/SQA
     'seq_len': seq_len,
     'stm_size': seq_len,
-    'use_flash_attention': False,
-    'use_gated': True,
+    'use_flash_attention': False, # explicitly use flash-attn function (otherwise it's used through PyTorch backend) - not recommended
+    'use_gated': True, # use Gated Linear Units in feed forward, True by default
+    'ff_activation': 'silu', # feed forward activation, 'silu' is default for SwiGLU layers
     'ff_dropout': 0.1,
-    'self_att_type': 'sqa',
-    'cross_att_type': 'sqa',
-    'att_query_groups': 8,
+    'self_att_type': 'sqa', # self attention could be 'sqa', 'gqa', 'mqa' or 'mha'
+    'cross_att_type': 'sqa', # cross attention could be 'sqa', 'gqa', 'mqa' or 'mha'
+    'att_query_groups': 8, # query groups for SQA
 }

 encoder_config = {
@@ -387,9 +438,9 @@ encoder_config = {

 decoder_config = {
     'ff_dim': 256,
-    'use_moe': True,
-    'num_experts': 20,
-    'moe_top_k': 4,
+    'use_moe': True, # use Mixture-of-Experts feed forward
+    'num_experts': 20, # number of experts
+    'moe_top_k': 4, # number of activated experts (per token)
     **config
 }

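The `num_experts`/`moe_top_k` comments added above describe per-token expert routing. As a generic illustration of what "number of activated experts (per token)" means (not RxNN's actual `rxnn.transformers.moe` implementation), selecting the top 4 of 20 experts for a single token could look like this:

```python
# Generic top-k MoE routing illustration - not taken from rxnn.transformers.moe.
import torch

num_experts, moe_top_k = 20, 4           # values from decoder_config above
hidden = torch.randn(1, 128)             # one token embedding (batch of 1, embed_dim 128)
router = torch.nn.Linear(128, num_experts)

logits = router(hidden)                  # one score per expert
weights, expert_ids = torch.topk(logits.softmax(-1), k=moe_top_k, dim=-1)
weights = weights / weights.sum(dim=-1, keepdim=True)  # renormalize over the active experts
print(expert_ids, weights)               # the 4 experts this token is dispatched to, with mixing weights
```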
@@ -647,11 +698,11 @@ mem_attn = RxTAlphaMemoryAttention(
     att_heads=8,
     seq_len=256,
     stm_size=256,
-    use_flash_attention=False,
-    norm_type='classic-rms',
-    att_groups=4,
-    att_type='sqa',
-    att_query_groups=4,
+    use_flash_attention=False, # explicitly use flash-attn function (otherwise it's used through PyTorch backend)
+    norm_type='classic-rms', # memory norm type
+    att_groups=4, # key/value groups for SQA/GQA
+    att_type='sqa', # attention type, could be 'sqa', 'gqa', 'mqa' or 'mha'
+    att_query_groups=4, # query groups for SQA
 )

 # 4. Load shared embedding and memory from encoder to other models
@@ -677,7 +728,7 @@ Then, we have to load tokenizer and MRL Datasets, and create _curriculum config_
 # 1. Load tokenizer
 tokenizer = load_tokenizer_from_hf_hub('ReactiveAI/RxT-Alpha-Micro-Plus-Decoder', token='HF_TOKEN')

-# 2. Load PoC TinyStories based MRL Dataset, starting from 4 steps to 16 in long range
+# 2. Load PoC TinyStories based MRL Dataset, starting from 4 steps to 16 in long range, and pre-tokenize it
 mrl_datasets = MrlDatasets.from_hf_hub(
     'ReactiveAI/TinyStories-MRL',
     tokenizer,
@@ -693,33 +744,40 @@ mrl_datasets = MrlDatasets.from_hf_hub(
     max_seq_len=256,
 )

+mrl_datasets.pre_tokenize(verbose=True, log_interval=100)
+
 # 3. Create curriculum stages config
 curriculum_stages = [CurriculumConfig(
-    steps=item['steps'],
-    epochs=10 if item['steps'] == 4 else
+    steps=item['steps'], # number of steps in curriculum stage
+    epochs=10 if item['steps'] == 4 else 5, # number of epochs in curriculum stage
     dataset=item['dataset'],
     eval_dataset=item['eval_dataset'],
     callbacks=[
-        MrlPrintCallback(),
+        MrlPrintCallback(), # Print loss/reward callback
         MrlModelSaveCallback(
-            './models',
-
-
-
+            './models',
+            push_to_hub=True,
+            hub_model_critic='Your critic model hub id',
+            hub_model_decoder='Your MRL decoder model hub id',
+            hub_model_encoder='Your MRL encoder model hub id',
+            hub_model_memory_attention='Your memory-attention model hub id',
+            private_repo=True,
+            hf_token='HF_TOKEN',
+            final_commit_message=f"MRL steps: {item['steps']} {'lr' if item['is_long_range'] else ''}",
             push_checkpoint_weights=True,
-        )
+        ) # MRL Model save callback - save and push to hub critic model and actor components
     ],
-    strategy=MrlStrategy.LONG_RANGE_STRATEGY if item['is_long_range'] else MrlStrategy.MULTI_STEP_STRATEGY,
-    unfreeze_epoch=((2, 2e-5), (4, 8e-5), (6, 1e-5), 8) if item['steps'] == 4 else (0, 1, (2, 1e-6), 4),
-    random_resets=item['steps'] > 4,
-    random_resets_from=2,
-    random_resets_ratio=0.4 if item['steps'] != 4 else None,
-    separate_memory_lr=True,
-    memory_lr=6e-4 if item['steps'] == 4 else
-    lr=3e-4 if item['steps'] == 4 else
-    critic_lr=4e-4 if item['steps'] == 4 else None,
-    critic_encoder_lr=2e-4 if item['steps'] == 4 else None,
-    teacher_forcing=
+    strategy=MrlStrategy.LONG_RANGE_STRATEGY if item['is_long_range'] else MrlStrategy.MULTI_STEP_STRATEGY, # strategy for curriculum stage
+    unfreeze_epoch=((2, 2e-5), (4, 8e-5), (6, 1e-5), 8) if item['steps'] == 4 else (0, 1, (2, 1e-6), 4), # unfreeze strategy config
+    random_resets=item['steps'] > 4, # enable random memory resets
+    random_resets_from=2, # epoch when random resets starts
+    random_resets_ratio=0.4 if item['steps'] != 4 else None, # probability of STM reset before episode
+    separate_memory_lr=True, # use separate memory LR in current curriculum stage
+    memory_lr=6e-4 if item['steps'] == 4 else None, # memory LR for curriculum stage, if None, use global config
+    lr=3e-4 if item['steps'] == 4 else None, # model LR for curriculum stage, if None, use global config
+    critic_lr=4e-4 if item['steps'] == 4 else None, # critic (head) LR for curriculum stage, if None, use global config
+    critic_encoder_lr=2e-4 if item['steps'] == 4 else None, # critic (encoder) LR for curriculum stage, if None, use global config
+    teacher_forcing=item['steps'] <= 8, # use teacher forcing - save reference answers from dataset in memory instead of generated ones
 ) for item in mrl_datasets]
 ```

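Several of the annotated `CurriculumConfig` fields above are conditioned on `item['steps']` and fall back to the global config when set to `None`. A small illustration of how those expressions evaluate, using hypothetical step counts (the real values come from the TinyStories-MRL dataset and are not shown in the diff):

```python
# Hypothetical step counts, used only to show how the conditional fields above behave.
for steps in (4, 8, 16):
    print({
        'steps': steps,
        'epochs': 10 if steps == 4 else 5,
        'random_resets': steps > 4,                     # enabled for longer curricula
        'random_resets_ratio': 0.4 if steps != 4 else None,
        'memory_lr': 6e-4 if steps == 4 else None,      # None -> fall back to the MrlConfig value
        'teacher_forcing': steps <= 8,                  # reference answers kept in memory for short stages
    })
```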
@@ -735,30 +793,33 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 reward_model = MrlRewardModel(
     encoder.model.embedding,
     device,
-    bleu_with_saved_data=True,
-    reward_len=True,
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    bleu_with_saved_data=True, # use saved data (previous or first interaction) in BLEU calculation
+    reward_len=True, # use length reward in calculation (answer_len / target_len)
+    max_rewarded_len=None, # target length awarded as 1.0
+    neg_reward_len=True, # negative length reward - lower reward when answer is too long (target_len / answer_len)
+    target_len_as_ref=True, # use reference answer len as target
+    use_running_mean=True, # use running mean embedding of all previous answers in cosine similarity calculation
+    allow_not_summing_factors=False, # if True sum of reward factors could be different from 1.0, it's False by default
+    bleu_factor=0.4, # factor for BLEU score in standard reward
+    cos_factor=0.5, # factor for cosine similarity score in standard reward
+    len_factor=0.1, # factor for length reward score in standard reward
+    bleu_ref_factor=0.4, # factor for reference answer score in BLEU calculation (standard mode)
+    bleu_saved_factor=0.6, # factor for saved data score in BLEU calculation (standard mode)
+    cos_ref_factor=0.35, # factor for reference answer score in cosine sim calculation (standard mode)
+    cos_saved_factor=0.65, # factor for saved data score in cosine sim calculation (standard mode)
+    multi_cos_ref_factor=0.3, # factor for reference answer in multi-step cosine sim calculation
+    multi_cos_saved_factor=0.5, # factor for saved data in multi-step cosine sim calculation
+    multi_cos_running_mean_factor=0.2, # factor for previous answers running mean in multi-step cosine sim calculation
+    neg_bleu_factor=0.45, # factor for BLEU score in negative reward
+    neg_cos_factor=0.45, # factor for cosine similarity score in negative reward
+    neg_bleu_ref_factor=0.3, # factor for reference answer score in BLEU calculation (negative mode)
+    neg_bleu_saved_factor=0.7, # factor for saved data score in BLEU calculation (negative mode)
+    neg_cos_ref_factor=0.3, # factor for reference answer score in cosine sim calculation (negative mode)
+    neg_cos_saved_factor=0.7, # factor for saved data score in cosine sim calculation (negative mode)
+    bleu_ref_weights=(0.2, 0.2, 0.3, 0.3), # weights for n-grams in NLTK BLEU calculation for reference answers
+    bleu_saved_weights=(0.2, 0.2, 0.3, 0.3), # weights for n-grams in NLTK BLEU calculation for saved data
+    tanh_reward_scale=False, # scale rewards to -1.0 to 1.0 range, instead of standard 0.0-1.0
+    rewards_scale=1.0, # rewards scaling factor (reward * rewards_scale)
 )
 ```

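The factor comments above suggest the standard reward is a weighted combination of BLEU, cosine-similarity, and length scores, with the factors expected to sum to 1.0 unless `allow_not_summing_factors` is set. A back-of-the-envelope sketch of that weighting (an assumption about how `MrlRewardModel` combines the scores, not its actual code):

```python
# Assumed weighting scheme implied by the factor comments above - not MrlRewardModel internals.
bleu_factor, cos_factor, len_factor = 0.4, 0.5, 0.1
assert abs(bleu_factor + cos_factor + len_factor - 1.0) < 1e-9  # factors sum to 1.0

# Example per-step scores in the 0.0-1.0 range
bleu_score, cos_score, len_score = 0.35, 0.80, 1.0

reward = bleu_factor * bleu_score + cos_factor * cos_score + len_factor * len_score
print(reward)  # 0.4*0.35 + 0.5*0.80 + 0.1*1.0 = 0.64
```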
@@ -769,32 +830,74 @@ algorithm = PPOAlgorithm(
     PPOConfig(clip_eps=0.2, gae_lambda=0.95, gae_gamma=0.99, entropy_coef=0.01, critic_value_clip=50.0)
 )

-# 2. Create config for MRLTrainer
+# 2. Create config for MRLTrainer (most MrlConfig fields could be overwritten in each curriculum stage)
 mrl_config = MrlConfig(
-    lr=1e-4,
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    lr=1e-4, # main LR, used for decoder layers
+    encoder_lr=2e-4, # encoder LR, used for encoder layers (if None, lr is used)
+    critic_lr=2e-4, # critic LR, used for critic value head
+    critic_encoder_lr=1e-4, # critic encoder LR (if not set, critic_lr is used)
+    separate_memory_lr=True, # use separate LR for memory attention and memory cross-attention
+    encoder_memory_lr=5e-4, # LR for encoder memory cross-attention (if None, memory_lr is used)
+    memory_lr=3e-4, # memory LR, used for decoder memory cross-attention
+    memory_attn_lr=5e-4, # memory attention LR (if None, memory_lr is used)
+    max_seq_len=256, # maximum length of single interaction
+    critic_max_len=512, # maximum length of critic sequence (has to be longer than actor's context)
+    weight_decay=0.01, # weight decay for actor AdamW optimizer
+    critic_weight_decay=0.01, # weight decay for critic AdamW optimizer
+    update_epochs=10, # inner PPO update epochs
+    pad_token_id=0, # tokenizer padding token id
+    end_token_id=3, # tokenizer EOS token id
+    use_moe_aux_loss=True, # add Mixture-of-Experts Router auxiliary loss to policy loss
+    freeze_embeddings=False, # freeze pre-trained embeddings for MRL training
+    embedding_lr=5e-6, # LR for embeddings, if not frozen (if None, lr is used)
+    use_memory_warmup=False, # memory warmup - update memory with first interaction in no grad mode, before episode, for better initialization
 )

 # 3. Initialize MRL Trainer
-trainer = MRLTrainer(
+trainer = MRLTrainer(
+    actor, critic, reward_model, device, mrl_config, algorithm,
+    use_amp=True, # use autocast in MRL Training
+    dtype=torch.bfloat16, # data type for MRL
+    use_ddp=False, # use distributed training with DDP
+)

 # 4. Train with curriculum stages config
 trainer(curriculum_stages, batch_size=batch_size)
 ```

+## Experimental attention layers
+While working on reactive architectures, we also developed several new types of attention layers, some of which achieve
+very promising results. Even considering that reactive models, processing single interactions, have much lower computational
+requirements, we need the most efficient attention mechanisms, consistent with memory requirements. Since memory is not a
+sequence but a set, spatial sparsity is probably not a good solution here, so we were looking for an efficient alternative
+to Flex Attention with full access to all memory positions. New attention layers are implemented in the `rxnn.experimental.attention`
+module:
+- **Grouped Mixture-of-Experts Attention (GMA)** - uses MoE routing to dynamically select K active key/value heads for each token, instead
+of the static selection in **GQA**. While it's theoretically interesting, in practice it achieved worse results than **GQA**,
+and even **MQA**, in all tests, and is a lot slower because of routing overhead, so we abandoned further research. More details
+in [research docs](https://github.com/RxAI-dev/RxNN/blob/main/docs/research/moe_attention.md)
+- **Deep Mixture-of-Experts Attention (DMA)** - extends **GMA** with the same MoE routing for query heads. Like **GMA**,
+it gives even worse results, and all the computational performance benefits from the sparse query heads (like in
+**SQA**) are lost to routing overhead (lack of specialized kernels for head selection), so further research was also
+abandoned. [Research docs](https://github.com/RxAI-dev/RxNN/blob/main/docs/research/moe_attention.md)
+- **Hierarchical MoE Attention (HMA)** - extends **DMA/GMA**, using different numbers of query/key/value heads for tokens with
+different priorities. It's only an idea and is not implemented, because of the poor results of GMA/DMA. [More info](https://github.com/RxAI-dev/RxNN/blob/main/docs/research/hierarchical_moe_attention.md)
+- **Sparse Query Attention (SQA)** - the most trivial extension to GQA, reducing not only the number of key/value heads, but
+also the number of query heads. It results in an even 2-3x faster model (for 32k/131k tokens). **SQA** is the fastest attention
+mechanism for 0-131k sequence lengths; for longer sequences **Flex Attention** becomes faster. That's ideal for reactive models,
+which don't need a million-token context for single interaction processing. In tested cases, **SQA** model results (loss/accuracy)
+were close to GQA, differences were almost unnoticeable, but it still requires more tests. [Research docs](https://github.com/RxAI-dev/RxNN/blob/main/docs/research/sparse_query_attention.md)
+- **Flex Sparse Query Attention (Flex-SQA)** - **Flex Attention** combined with **SQA** - enables handling 4-8x longer sliding
+windows in a shorter time than base **Flex**, so it should lead to better results. **Flex-SQA** should be the fastest
+attention mechanism for sequences longer than 131k tokens and is made for classic transformers, or potentially self-attention
+in bigger reactive models. Currently, it's viable only with symmetric variants of **SQA** (same number of used query
+and key/value heads), because kernels aren't compatible with GQA in sliding windows and non-symmetric variants are 2x slower
+than they should be. Docs and tests in progress
+
+### Test usage
+Experimental attention layers could be tested with the `ExperimentalAttentionTransformer` model from `rxnn.experimental.models`.
+A usage example could be found in our notebooks repository - [RxNN Notebooks](https://github.com/RxAI-dev/rxnn-notebooks)
+
 Apache License
 Version 2.0, January 2004
 http://www.apache.org/licenses/
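To make the **Sparse Query Attention** description added above concrete: with the head counts used in the configs earlier (16 attention heads, 8 key/value groups, 8 query groups), the quadratic attention-score work scales with the number of query heads, which is roughly where the reported 2-3x speedup at long sequence lengths comes from. A rough shape comparison (an illustration, not the `rxnn.experimental.attention` implementation):

```python
# Rough illustration of why SQA reduces attention compute - not the rxnn implementation.
seq_len = 32_000

variants = {
    'MHA': {'q_heads': 16, 'kv_heads': 16},
    'GQA': {'q_heads': 16, 'kv_heads': 8},   # att_groups=8 shared key/value heads
    'SQA': {'q_heads': 8,  'kv_heads': 8},   # att_query_groups=8 query heads actually used
}

for name, v in variants.items():
    # The QK^T score tensor has one (seq_len x seq_len) map per *query* head,
    # so halving the query heads roughly halves the quadratic part of the cost.
    score_elems = v['q_heads'] * seq_len * seq_len
    print(name, f"{score_elems:.3e}")
```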