rxnn 0.2.65__py3-none-any.whl → 0.2.67__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rxnn/training/utils.py CHANGED
@@ -151,5 +151,9 @@ def get_gradient_norms(model: nn.Module):
151
151
  param_norm = p.grad.data.norm(2)
152
152
  total_norm += param_norm.item() ** 2
153
153
  total_norm = total_norm ** 0.5
154
- mean_norm = total_norm / len(grad_params)
154
+ params_len = len(grad_params)
155
+ if params_len != 0:
156
+ mean_norm = total_norm / params_len
157
+ else:
158
+ mean_norm = 0.0
155
159
  return total_norm, mean_norm
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: rxnn
3
- Version: 0.2.65
3
+ Version: 0.2.67
4
4
  Summary: RxNN: Reactive Neural Networks Platform
5
5
  License: Apache-2.0
6
6
  Keywords: deep-learning,ai,machine-learning
@@ -113,17 +113,17 @@ and [TensorBoard](https://github.com/tensorflow/tensorboard).
113
113
  > if it's set to `False` and the `flash-attn` library is installed, **PyTorch** will try to use it implicitly through the _SDPA backend_. It's better to set it
114
114
  > to `False` and let it be used automatically, because of better compatibility. Explicit options could be used for research
115
115
 
116
- ### Modules
116
+ ## Modules
117
117
  **RxNN** framework has multiple modules with models, layers, training and inference tools, made for complete development
118
118
  of _reactive models_, and can also be used for regular **Transformers**.
119
119
 
120
- #### Transformers
120
+ ### Transformers
121
121
  Transformers module includes classes for models and layers. It includes **Reactive Transformers** as well as **Classic Transformers**
122
122
 
123
123
  Submodules:
124
124
  - `rxnn.transformers.attention` - basic, most common attention layers - `MultiHeadAttention`, `GroupedQueryAttention` and `MultiQueryAttention`
125
125
  - additional attention layers, especially `SparseQueryAttention` could be found in `rxnn.experimental.attention` module
126
- - `SparseQueryAttention` will be moved to `rxnn.transformers.attention` in 0.2.x version
126
+ - `SparseQueryAttention` will be moved to `rxnn.transformers.attention` in 0.3.x version
127
127
  - `rxnn.transformers.positional` - positional encoding layers - `RotaryPositionalEmbedding` and legacy ones - `AbsolutePositionalEmbedding`/`RelativePositionalEmbedding`
128
128
  - `rxnn.transformers.ff` - dense feed forward layers, including gated layers (_SwiGLU_, etc.) - `FeedForward` & `GatedFeedForward` (recommended)
129
129
  - `rxnn.transformers.moe` - Mixture-of-Experts feed forward layers - `MoeFeedForward` & `GatedMoeFeedForward` (recommended)
@@ -133,7 +133,6 @@ Submodules:
133
133
 
134
134
  In **RxNN** models are initialized in declarative style by class composition, but then they are wrapped in imperative classes,
135
135
  to be compatible with HuggingFace **JSON** config. For example:
136
-
137
136
  ```python
138
137
  from typing import TypedDict
139
138
  import torch
@@ -193,11 +192,11 @@ class YourReactiveTransformerDecoder(nn.Module, PyTorchModelHubMixin):
193
192
  memory_cross_attention=GroupedQueryAttention(
194
193
  config['embed_dim'],
195
194
  config['att_heads'],
196
- config['att_groups'],
195
+ config['cross_att_groups'] if 'cross_att_groups' in config else config['att_groups'],
197
196
  rope=rope,
198
197
  dropout=0.1,
199
198
  max_seq_len=config['seq_len'],
200
- is_causal=True,
199
+ is_causal=False,
201
200
  rope_only_for_query=True
202
201
  ),
203
202
  ) for _ in range(config['num_layers'])
@@ -208,15 +207,99 @@ class YourReactiveTransformerDecoder(nn.Module, PyTorchModelHubMixin):
208
207
  return self.model(x, attention_mask=attention_mask)
209
208
  ```
210
209
 
211
- #### Memory
210
+ ### Memory
212
211
  The _memory_ module includes **Short-Term Memory** and layers responsible for its update. In future versions it will also
213
212
  include **Long-Term Memory**.
214
213
 
215
214
  The main `ShortTermMemory` class is located in `rxnn.memory.stm` module - the usage example is in Transformers module description.
216
215
 
217
- > 0.2.x Memory modules docs in progress - will be released soon
216
+ #### Memory Attention Network
217
+ **Memory Attention Network** is responsible for updating the memory layers. It includes memory attention layers, with normalization
218
+ and residual connections (with an optional gated residual). The **Memory Attention Network** should have the same number of layers
219
+ as the other components (encoder & decoder). It takes the results from each encoder layer (hidden states) and combines them
220
+ with the current memory state.
221
+
222
+ You can create your own **Memory Attention Network**, integrated with the **HuggingFace Hub**, the same way as reactive transformers:
223
+ ```python
224
+ from typing import TypedDict
225
+ import torch
226
+ import torch.nn as nn
227
+ from huggingface_hub import PyTorchModelHubMixin
228
+ from rxnn.transformers.attention import GroupedQueryAttention
229
+ from rxnn.transformers.positional import RotaryPositionalEmbedding
230
+ from rxnn.memory.stm import ShortTermMemory
231
+ from rxnn.memory.attention import StmMemoryAttention
232
+
233
+ class YourMemoryAttentionConfig(TypedDict):
234
+ num_layers: int
235
+ vocab_size: int
236
+ embed_dim: int
237
+ ff_dim: int
238
+ att_heads: int
239
+ seq_len: int
240
+ stm_size: int
241
+ att_groups: int
242
+
243
+ class YourMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2.0"):
244
+ """RxT-Alpha (Reactive Transformer) memory attention model"""
245
+
246
+ def __init__(
247
+ self,
248
+ config: YourMemoryAttentionConfig,
249
+ **kwargs,
250
+ ):
251
+ super(YourMemoryAttention, self).__init__(**kwargs)
252
+
253
+ rope = RotaryPositionalEmbedding(config['embed_dim'] // config['att_heads'], config['seq_len'])
254
+ # This separately initialized STM will be replaced by shared instance with `load_shared_memory` call
255
+ stm = ShortTermMemory(config['num_layers'], config['embed_dim'], config['stm_size'])
256
+
257
+ self.model = StmMemoryAttention(
258
+ stm,
259
+ attention_layers=nn.ModuleList([
260
+ GroupedQueryAttention(
261
+ config['embed_dim'],
262
+ config['att_heads'],
263
+ config['att_groups'],
264
+ rope=rope,
265
+ dropout=0.1,
266
+ is_causal=False,
267
+ rope_only_for_keys=True
268
+ ) for _ in range(config['num_layers'])
269
+ ]),
270
+ memory_norm_layers=nn.ModuleList([
271
+ nn.RMSNorm(config['embed_dim']) for _ in range(config['num_layers'])
272
+ ]),
273
+ use_gated_residual=False, # memory attention residual gate
274
+ per_slot_gate=False, # gate per memory slot, otherwise it's per layer
275
+ init_gate=None, # initial value for gate weights
276
+ use_dynamic_gate=False, # dynamic gate calculated from weights and memory state, otherwise it's calculated only from weights
277
+ use_tanh_gate=False, # use tanh gate, otherwise it's sigmoid
278
+ )
279
+
280
+ def load_shared_memory(self, stm: ShortTermMemory):
281
+ self.model.stm = stm
282
+
283
+ def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
284
+ return self.model(x, attention_mask=attention_mask)
285
+ ```
286
+
287
+ > #### Gated residual
288
+ > An optional gated residual could be used to improve Memory Attention expressiveness. It uses a gate (sigmoid or tanh)
289
+ > with trainable weights to decide how much information from the old and the newly updated memory state should be stored. Depending
290
+ > on the params, weights are declared per layer or per memory slot (more expressive). It can work in two modes, which can
291
+ > be switched, because they use the same weights shape:
292
+ > - static - gate values calculated only from weights (`gate = torch.sigmoid(weights)`) - enables explicit control with the `init_gate` param
293
+ > - dynamic - gate values calculated from weights and updated memory state (`gate = torch.sigmoid(weights * (new_layer_stm + layer_stm).mean(dim=-1, keepdim=True))`)
294
+ >
295
+ > Depending on the `use_tanh_gate` param, the final gated residual connection is calculated with different formulas:
296
+ > - sigmoid gate - `stm[i] = layer_gate * new_layer_stm + (1 - layer_gate) * layer_stm`
297
+ > - tanh gate - `stm[i] = (1 + layer_gate) * new_layer_stm + (1 - layer_gate) * layer_stm`
298
+ > - the tanh gate preserves the residual connection scale, while the sigmoid gate interpolates between the two states (with gate 0.5 it's equivalent to `(new_layer_stm + layer_stm) / 2`)
299
+ >
300
+ > **Gated residual** is currently being tested - we are not sure if it will provide better results, so **it's not recommended** yet
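+
+ A minimal, illustrative sketch of the gate computation described in the note above (a standalone function for clarity, not the actual `StmMemoryAttention` implementation - argument names and shapes are assumptions):
+ ```python
+ import torch
+
+ def gated_residual(
+     layer_stm: torch.Tensor,      # current memory state of a layer
+     new_layer_stm: torch.Tensor,  # updated state produced by memory attention
+     weights: torch.Tensor,        # trainable gate weights (per layer or per memory slot)
+     use_dynamic_gate: bool = False,
+     use_tanh_gate: bool = False,
+ ) -> torch.Tensor:
+     if use_dynamic_gate:
+         # dynamic mode - gate depends on the weights and the combined memory states
+         gate_input = weights * (new_layer_stm + layer_stm).mean(dim=-1, keepdim=True)
+     else:
+         # static mode - gate depends only on the trainable weights (controlled via `init_gate`)
+         gate_input = weights
+     layer_gate = torch.tanh(gate_input) if use_tanh_gate else torch.sigmoid(gate_input)
+     if use_tanh_gate:
+         # tanh gate preserves the residual connection scale
+         return (1 + layer_gate) * new_layer_stm + (1 - layer_gate) * layer_stm
+     # sigmoid gate interpolates between the old and the new memory state
+     return layer_gate * new_layer_stm + (1 - layer_gate) * layer_stm
+ ```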
218
301
 
219
- #### Training
302
+ ### Training
220
303
  Training module includes **Trainers** for different training stages of reactive models and shared training utils.
221
304
 
222
305
  Submodules:
@@ -233,6 +316,9 @@ Submodules:
233
316
  returns new subset and modifying existing one - i.e. `valid_dataset = train_dataset.get_subset(0.1)`
234
317
  - for concatenated datasets, validation/test split could be created with `concat_from_hf_hub_with_subset` - it cuts the
235
318
  same percentage of each loaded dataset
319
+ - each dataset has a `pre_tokenize` method, to tokenize all items before the training (otherwise they are tokenized
320
+ dynamically on item access). It's recommended for smaller datasets (fine-tuning, MRL, etc.) and not recommended for
321
+ very big datasets (pre-training), because it uses a lot of CPU RAM - see the short sketch after this list
236
322
  - `rxnn.training.callbacks` contains Trainer callbacks, for different kinds of utils (more info below)
237
323
  - `rxnn.training.scheduler` includes learning rate scheduler for training
238
324
  - `rxnn.training.bml` - Base Model Learning module with Trainers for pre-training and fine-tuning
@@ -240,9 +326,474 @@ Submodules:
240
326
  - `rxnn.training.rxrlhf` - Reinforcement Learning from Human Feedback for Reactive Models module (from 0.3.x)
241
327
  - `rxnn.training.brl` - Behavioral Reinforcement Learning module (Reactor / from 0.7.x)
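+
+ A short sketch of the dataset utilities mentioned above (`get_subset` and `pre_tokenize`), reusing calls shown elsewhere in this README - the dataset, model id and argument values are only illustrative:
+ ```python
+ from rxnn.training.dataset import MaskedLMDataset
+ from rxnn.training.tokenizer import load_tokenizer_from_hf_hub
+
+ tokenizer = load_tokenizer_from_hf_hub('ReactiveAI/RxT-Alpha-Micro-Plus-Encoder', token='HF_TOKEN')
+ train_dataset = MaskedLMDataset.from_hf_hub(
+     'roneneldan/TinyStories', load_kwargs={'trust_remote_code': True}, tokenizer=tokenizer, max_seq_len=256
+ )
+
+ # Cut 10% of the training set into a new validation subset (this also modifies the existing dataset)
+ valid_dataset = train_dataset.get_subset(0.1)
+
+ # Optionally tokenize all items up-front - recommended only for smaller datasets
+ train_dataset.pre_tokenize(verbose=True, log_interval=5000)
+ valid_dataset.pre_tokenize(verbose=True, log_interval=1000)
+ ```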
242
328
 
243
- ##### Base Model Learning
244
- Docs in progress
329
+ #### Base Model Learning
330
+ The **Base Model Learning (BML)** module is made for both pre-training and fine-tuning the base models that will be used as components
331
+ in reactive models. Generally, the only two differences between pre-training and supervised fine-tuning are the dataset
332
+ classes and the trainer/callbacks hyperparameter config.
333
+
334
+ Reactive models are based on a transformer decoder and encoder, with shared embeddings and memory layers - this requires special
335
+ handling in the first training stages:
336
+ - layers connected with memory - **Memory Cross-Attention** - are frozen during pre-training and fine-tuning, and they are
337
+ bypassed through residual connections
338
+ - as the encoder is able to learn slightly better embeddings, thanks to bidirectional modelling, it's pre-trained first, then the
339
+ decoder is trained with the frozen embeddings from the encoder
340
+ - in **Reactive Transformer** fine-tuning, both the encoder and decoder are fitted to the interaction format (single query and answer); the
341
+ training order is the same as for pre-training
342
+ - in the **Preactor** architecture there are 2 encoders and a single decoder. The encoders are fine-tuned from a single pre-trained
343
+ encoder - the first one processes only queries and the second one only answers. More info soon
344
+ - in the **Reactor** architecture there are 2 encoders and 2 decoders. Both encoders and decoders are fine-tuned from a single
345
+ pre-trained encoder and decoder. Each component is fine-tuned for its specific task. More info soon
346
+
347
+ ##### Pre-training
348
+ We have to start by importing the required modules/libraries, initializing the models and loading the tokenizer - I will
349
+ use the _RxT-Alpha-Micro-Plus_ models as an example:
350
+ ```python
351
+ import torch
352
+ from rxnn.rxt.models import RxTAlphaDecoder, RxTAlphaEncoder
353
+ from rxnn.training.dataset import AutoregressiveLMDataset, MaskedLMDataset
354
+ from rxnn.training.bml import AutoregressiveTrainer, MLMTrainer
355
+ from rxnn.training.models import MLMHead, MLMTrainingModel
356
+ from rxnn.training.scheduler import get_transformer_lr_scheduler, calculate_steps
357
+ from rxnn.training.callbacks import PrintLossCallback, PrintAccuracyCallback, TokenCounterCallback, ModelSaveCallback, JointModelSaveCallback
358
+ from rxnn.training.tokenizer import load_tokenizer_from_hf_hub
359
+ from rxnn.utils import set_random_seed, cache_clean
360
+
361
+ embed_dim = 128
362
+ vocab_size = 7_500
363
+ seq_len = 256
364
+
365
+ set_random_seed(42)
366
+
367
+ config = {
368
+ 'num_layers': 10,
369
+ 'vocab_size': vocab_size,
370
+ 'embed_dim': embed_dim,
371
+ 'att_heads': 16,
372
+ 'att_groups': 8,
373
+ 'seq_len': seq_len,
374
+ 'stm_size': seq_len,
375
+ 'use_flash_attention': False,
376
+ 'use_gated': True,
377
+ 'ff_dropout': 0.1,
378
+ 'self_att_type': 'sqa',
379
+ 'cross_att_type': 'sqa',
380
+ 'att_query_groups': 8,
381
+ }
382
+
383
+ encoder_config = {
384
+ 'ff_dim': 384,
385
+ **config
386
+ }
387
+
388
+ decoder_config = {
389
+ 'ff_dim': 256,
390
+ 'use_moe': True,
391
+ 'num_experts': 20,
392
+ 'moe_top_k': 4,
393
+ **config
394
+ }
395
+
396
+ encoder = RxTAlphaEncoder(**encoder_config)
397
+ decoder = RxTAlphaDecoder(**decoder_config)
398
+ head = MLMHead(embed_dim, vocab_size)
399
+
400
+ # Tokenizer is the same for encoder and decoder
401
+ tokenizer = load_tokenizer_from_hf_hub('ReactiveAI/RxT-Alpha-Micro-Plus-Encoder', token='HF_TOKEN')
402
+ ```
403
+ Then, we have to load the MLM datasets, set up the callbacks and run the encoder training:
404
+ ```python
405
+ # 1. Load datasets
406
+ load_kwargs = {
407
+ 'trust_remote_code': True
408
+ }
409
+
410
+ train_dataset = MaskedLMDataset.from_hf_hub('roneneldan/TinyStories', load_kwargs=load_kwargs, tokenizer=tokenizer, max_seq_len=seq_len)
411
+ valid_dataset = MaskedLMDataset.from_hf_hub('roneneldan/TinyStories', split="validation", load_kwargs=load_kwargs, tokenizer=tokenizer, max_seq_len=seq_len)
412
+
413
+ # 2. Select device
414
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
415
+
416
+ # 3. Clean GPU cache (optional)
417
+ cache_clean()
418
+
419
+ # 4. Set training config variables
420
+ batch_size = 256
421
+ epochs = 8
422
+ gradient_acc_steps = 1
423
+ peak_lr = 1e-3 * gradient_acc_steps
424
+
425
+ # 5. Get number of steps for scheduler
426
+ steps_config = calculate_steps(len(train_dataset), epochs, batch_size, warmup_ratio=0.05, verbose=True)
427
+ steps_per_epoch, total_steps, warmup_steps = steps_config['epoch'], steps_config['total'], steps_config['warmup']
428
+
429
+ # 6. Freeze memory cross-attention layers
430
+ encoder.freeze_memory()
431
+
432
+ # 7. Select directory for TensorBoard logs
433
+ logs_dir = './micro/tensorboard_logs/encoder-plus-sft'
434
+
435
+ # 8. Basic callbacks - print loss, accuracy and number of processed tokens
436
+ print_cb = PrintLossCallback(batches_per_epoch=steps_per_epoch)
437
+ count_cb = TokenCounterCallback(3_000_000_000)
438
+ acc_cb = PrintAccuracyCallback()
439
+
440
+ # 9. Joint model save callback - used to save encoder and MLM head, and push them to HuggingFace Hub
441
+ save_cb = JointModelSaveCallback(
442
+ './micro/encoder-plus-sft',
443
+ push_to_hub=True,
444
+ hub_model_decoder=None,
445
+ hub_model_encoder='Your encoder model id',
446
+ hub_model_head='Your mlm model id',
447
+ push_checkpoint_weights=True, # push epoch checkpoints to hub
448
+ final_commit_message='Final commit message',
449
+ private_repo=False, # use HF private repository
450
+ save_checkpoint_after_n_batches=1000, # save model after N batches in epoch (batch checkpoint)
451
+ push_batch_checkpoint=True, # push batch checkpoints to HF Hub
452
+ mlm_mode=True, # use MLM mode
453
+ hf_token='HF_TOKEN',
454
+ use_ddp=False, # use distributed training mode
455
+ )
456
+
457
+ # 10. Init training model - encoder + head
458
+ model = MLMTrainingModel(encoder, head)
459
+
460
+ # 11. Init MLM Trainer
461
+ trainer = MLMTrainer(
462
+ model,
463
+ device,
464
+ dataset=train_dataset,
465
+ validation_dataset=valid_dataset,
466
+ vocab_size=vocab_size,
467
+ callbacks=[print_cb, acc_cb, count_cb, save_cb],
468
+ use_amp=True, # use autocast
469
+ dtype=torch.bfloat16, # data type for training
470
+ log_dir=logs_dir,
471
+ use_ddp=False, # use distributed training mode
472
+ )
473
+
474
+ # 12. Init optimizer and cosine annealing scheduler
475
+ optimizer = torch.optim.AdamW(model.parameters(), lr=peak_lr, weight_decay=0.02)
476
+ scheduler = get_transformer_lr_scheduler(
477
+ optimizer,
478
+ warmup_steps=warmup_steps,
479
+ num_training_steps=total_steps
480
+ )
481
+
482
+ # 13. Run the training for the selected number of epochs
483
+ trainer(epochs=epochs, batch_size=batch_size, optimizer=optimizer, scheduler=scheduler)
484
+ ```
485
+ After the encoder's training, we have to train the decoder:
486
+ ```python
487
+ # 1. Load datasets
488
+ load_kwargs = {
489
+ 'trust_remote_code': True
490
+ }
491
+
492
+ train_dataset = AutoregressiveLMDataset.from_hf_hub('roneneldan/TinyStories', load_kwargs=load_kwargs, tokenizer=tokenizer, max_seq_len=seq_len)
493
+ valid_dataset = AutoregressiveLMDataset.from_hf_hub('roneneldan/TinyStories', split="validation", load_kwargs=load_kwargs, tokenizer=tokenizer, max_seq_len=seq_len)
494
+
495
+ # 2. Load shared embedding and memory, then freeze embedding and memory cross-attention
496
+ decoder.load_shared_embedding(encoder.model.embedding)
497
+ decoder.load_shared_memory(encoder.model.stm)
498
+
499
+ decoder.model.embedding.requires_grad_(False)
500
+ decoder.freeze_memory()
501
+
502
+ # 3. Clean GPU cache (optional)
503
+ cache_clean()
504
+
505
+ # 4. Set training config variables
506
+ batch_size = 256
507
+ epochs = 8
508
+ gradient_acc_steps = 1
509
+ peak_lr = 1e-3 * gradient_acc_steps
510
+
511
+ # 5. Get number of steps for scheduler
512
+ steps_config = calculate_steps(len(train_dataset), epochs, batch_size, warmup_ratio=0.05, verbose=True)
513
+ steps_per_epoch, total_steps, warmup_steps = steps_config['epoch'], steps_config['total'], steps_config['warmup']
514
+
515
+ # 6. Select directory for TensorBoard logs
516
+ logs_dir = './micro/tensorboard_logs/decoder-plus-sft'
517
+
518
+ # 7. Basic callbacks - print loss, accuracy and number of processed tokens
519
+ print_cb = PrintLossCallback(batches_per_epoch=steps_per_epoch)
520
+ count_cb = TokenCounterCallback(5_000_000_000)
521
+ acc_cb = PrintAccuracyCallback()
522
+
523
+ # 8. Model save callback - used to save decoder and push it to HuggingFace Hub
524
+ save_cb = ModelSaveCallback(
525
+ './micro/decoder-plus-sft',
526
+ push_to_hub=True,
527
+ hub_model_id='Your decoder model id',
528
+ push_checkpoint_weights=True, # push epoch checkpoints to hub
529
+ final_commit_message='Final commit message',
530
+ private_repo=False, # use HF private repository
531
+ save_checkpoint_after_n_batches=1000, # save model after N batches in epoch (batch checkpoint)
532
+ push_batch_checkpoint=True, # push batch checkpoints to HF Hub
533
+ hf_token='HF_TOKEN',
534
+ use_ddp=False, # use distributed training mode
535
+ )
536
+
537
+ # 9. Init Autoregressive Trainer
538
+ trainer = AutoregressiveTrainer(
539
+ decoder,
540
+ device,
541
+ dataset=train_dataset,
542
+ validation_dataset=valid_dataset,
543
+ vocab_size=vocab_size,
544
+ callbacks=[print_cb, acc_cb, count_cb, save_cb],
545
+ use_amp=True,
546
+ dtype=torch.bfloat16,
547
+ log_dir=logs_dir,
548
+ use_moe_aux_loss=True, # Add MoE Router auxiliary loss to main loss
549
+ moe_aux_loss_scale=0.02, # MoE Router aux loss scale
550
+ use_ddp=False, # use distributed training mode
551
+ )
552
+
553
+ # 10. Init optimizer and cosine annealing scheduler
554
+ optimizer = torch.optim.AdamW(decoder.parameters(), lr=peak_lr, weight_decay=0.02)
555
+ scheduler = get_transformer_lr_scheduler(
556
+ optimizer,
557
+ warmup_steps=warmup_steps,
558
+ num_training_steps=total_steps
559
+ )
560
+
561
+ # 11. Run the training for the selected number of epochs
562
+ trainer(epochs=epochs, batch_size=batch_size, optimizer=optimizer, scheduler=scheduler)
563
+ ```
564
+
565
+ ##### Fine-tuning
566
+ For _**Interaction Supervised Fine-Tuning**_, the code is almost the same as for pre-training, with some small changes.
567
+
568
+ First, we have to load the pre-trained models instead of initializing them with configs:
569
+ ```python
570
+ encoder = RxTAlphaEncoder.from_pretrained('ReactiveAI/RxT-Alpha-Micro-Plus-Encoder', token='HF_TOKEN')
571
+ decoder = RxTAlphaDecoder.from_pretrained('ReactiveAI/RxT-Alpha-Micro-Plus-Decoder', token='HF_TOKEN')
572
+ head = MLMHead.from_pretrained('ReactiveAI/RxT-Alpha-Micro-Plus-MLM', token='HF_TOKEN')
573
+ ```
574
+
575
+ Then, we have to change the dataset loading part. For the encoder:
576
+ ```python
577
+ from rxnn.training.dataset import EncoderSftDataset  # assumed location of the SFT dataset class
+
+ # 1. Load datasets
578
+ train_dataset = EncoderSftDataset.from_hf_hub('ReactiveAI/TinyStories-Plus-Interaction-SFT', tokenizer=tokenizer, max_seq_len=seq_len)
579
+ valid_dataset = EncoderSftDataset.from_hf_hub('ReactiveAI/TinyStories-Plus-Interaction-SFT', split="validation", tokenizer=tokenizer, max_seq_len=seq_len)
580
+
581
+ # 2. Pre-tokenize dataset with verbose logging (optional)
582
+ train_dataset.pre_tokenize(verbose=True, log_interval=5000)
583
+ valid_dataset.pre_tokenize(verbose=True, log_interval=1000)
584
+ ```
585
+ And the same for the decoder:
586
+ ```python
587
+ from rxnn.training.dataset import DecoderSftDataset  # assumed location of the SFT dataset class
+
+ # 1. Load datasets
588
+ train_dataset = DecoderSftDataset.from_hf_hub('ReactiveAI/TinyStories-Plus-Interaction-SFT', tokenizer=tokenizer, max_seq_len=seq_len)
589
+ valid_dataset = DecoderSftDataset.from_hf_hub('ReactiveAI/TinyStories-Plus-Interaction-SFT', split="validation", tokenizer=tokenizer, max_seq_len=seq_len)
590
+
591
+ # 2. Pre-tokenize dataset with verbose logging (optional)
592
+ train_dataset.pre_tokenize(verbose=True, log_interval=5000)
593
+ valid_dataset.pre_tokenize(verbose=True, log_interval=1000)
594
+ ```
595
+
596
+ We could also add an early stoppage callback:
597
+ ```python
598
+ from rxnn.training.callbacks import EarlyStoppageCallback
245
599
 
600
+ stop_cb = EarlyStoppageCallback(num_plateau_epochs=5)
601
+ ```
602
+
603
+ Additionally, in fine-tuning we will typically use a different config for the number of epochs, steps, learning rate, etc.
604
+
605
+ > #### Classic Transformer Training
606
+ > The same code could also be used to train classic decoder-only or encoder-only transformers; the only difference is
607
+ > that they don't require memory cross-attention freezing.
608
+
609
+ ##### Joint Training
610
+ There are also `JointLMDataset` and `JointLMTrainer` classes to train the encoder and decoder at once. In that case, embeddings
611
+ are updated by both the encoder and decoder optimization. However, I noticed some issues with balancing the training in that mode,
612
+ so it's **not recommended** for now, until it's tested and fixed
613
+
614
+ #### Memory Reinforcement Learning
615
+ **Memory Reinforcement Learning (MRL)** is the most important training stage for the reactive model's **Attention-Based Memory System**.
616
+ In this stage we are training the model to remember information across multiple interactions, with different curriculum stage
617
+ configs. Theoretical foundations are described in [research docs](https://github.com/RxAI-dev/RxNN/blob/main/docs/research/ReactiveTransformer/mrl.md).
618
+
619
+ > The **MRL** algorithm is currently being tested and a lot of things could still change!
620
+
621
+ In practice, the algorithm has over 50 hyperparams, so it requires careful handling. We start by importing modules, loading
622
+ the pre-trained models from the SFT stage, and initializing the new Memory Attention, actor and critic models:
623
+ ```python
624
+ import torch
625
+ from rxnn.rxt.models import RxTAlphaDecoder, RxTAlphaEncoder, RxTAlphaMemoryAttention
626
+ from rxnn.training.tokenizer import load_tokenizer_from_hf_hub
627
+ from rxnn.training.dataset import MrlDatasets
628
+ from rxnn.training.models import MrlActorModel, MrlCriticModel
629
+ from rxnn.training.reward import MrlRewardModel
630
+ from rxnn.training.mrl import MRLTrainer, CurriculumConfig, MrlStrategy, MrlConfig
631
+ from rxnn.training.rl import PPOAlgorithm, PPOConfig
632
+ from rxnn.training.callbacks import MrlPrintCallback, MrlEarlyStoppageCallback, MrlModelSaveCallback, MrlGeneratedTokensCallback
633
+ from rxnn.utils import set_random_seed
634
+
635
+ # 1. Set random seed, batch size and embed dim
636
+ set_random_seed(42)
637
+ batch_size = 64
638
+ embed_dim = 128
639
+
640
+ # 2. Get pre-trained microscale PoC models
641
+ decoder = RxTAlphaDecoder.from_pretrained('ReactiveAI/RxT-Alpha-Micro-Plus-Decoder-SFT', token='HF_TOKEN')
642
+ encoder = RxTAlphaEncoder.from_pretrained('ReactiveAI/RxT-Alpha-Micro-Plus-Encoder-SFT', token='HF_TOKEN')
643
+ # 3. Init Memory Attention Network
644
+ mem_attn = RxTAlphaMemoryAttention(
645
+ num_layers=10,
646
+ embed_dim=embed_dim,
647
+ att_heads=8,
648
+ seq_len=256,
649
+ stm_size=256,
650
+ use_flash_attention=False,
651
+ norm_type='classic-rms',
652
+ att_groups=4,
653
+ att_type='sqa',
654
+ att_query_groups=4,
655
+ )
656
+
657
+ # 4. Load shared embedding and memory from encoder to other models
658
+ decoder.load_shared_embedding(encoder.model.embedding)
659
+ encoder.model.stm.batched_memory(batch_size=batch_size, init_type='standard')
660
+ decoder.load_shared_memory(encoder.model.stm)
661
+ mem_attn.load_shared_memory(encoder.model.stm)
662
+
663
+ # 5. Init Actor model
664
+ actor = MrlActorModel(encoder, decoder, mem_attn)
665
+
666
+ # 6. Get pre-trained encoder, extend its context size, freeze memory and use as a body for Critic model
667
+ critic_encoder = RxTAlphaEncoder.from_pretrained('ReactiveAI/RxT-Alpha-Micro-Plus-Encoder-SFT', token='HF_TOKEN')
668
+
669
+ critic_encoder.update_max_len(512)
670
+ critic_encoder.freeze_memory()
671
+ # 7. Init Critic model
672
+ critic = MrlCriticModel(critic_encoder, embed_dim)
673
+ ```
674
+
675
+ Then, we have to load tokenizer and MRL Datasets, and create _curriculum config_:
676
+ ```python
677
+ # 1. Load tokenizer
678
+ tokenizer = load_tokenizer_from_hf_hub('ReactiveAI/RxT-Alpha-Micro-Plus-Decoder', token='HF_TOKEN')
679
+
680
+ # 2. Load PoC TinyStories based MRL Dataset, starting from 4 steps to 16 in long range
681
+ mrl_datasets = MrlDatasets.from_hf_hub(
682
+ 'ReactiveAI/TinyStories-MRL',
683
+ tokenizer,
684
+ mrl_curriculum_steps=[
685
+ { 'subset_name': 'steps-4', 'steps': 4, 'is_long_range': False },
686
+ { 'subset_name': 'steps-6', 'steps': 6, 'is_long_range': False },
687
+ { 'subset_name': 'steps-8', 'steps': 8, 'is_long_range': False },
688
+ { 'subset_name': 'steps-8-lr', 'steps': 8, 'is_long_range': True },
689
+ { 'subset_name': 'steps-12', 'steps': 12, 'is_long_range': True },
690
+ { 'subset_name': 'steps-16', 'steps': 16, 'is_long_range': True },
691
+ ],
692
+ eval_split='validation',
693
+ max_seq_len=256,
694
+ )
695
+
696
+ # 3. Create curriculum stages config
697
+ curriculum_stages = [CurriculumConfig(
698
+ steps=item['steps'],
699
+ epochs=10 if item['steps'] == 4 else 8 if item['steps'] == 8 and item['is_long_range'] else 5,
700
+ dataset=item['dataset'],
701
+ eval_dataset=item['eval_dataset'],
702
+ callbacks=[
703
+ MrlPrintCallback(),
704
+ MrlModelSaveCallback(
705
+ './models', push_to_hub=True, hub_model_critic='ReactiveAI/RxT-Alpha-Micro-Critic-MRL',
706
+ hub_model_decoder='ReactiveAI/RxT-Alpha-Micro-Decoder-MRL', hub_model_encoder='ReactiveAI/RxT-Alpha-Micro-Encoder-MRL',
707
+ hub_model_memory_attention='ReactiveAI/RxT-Alpha-Micro-MemAtt-MRL', private_repo=True,
708
+ hf_token='HF_TOKEN', final_commit_message=f"MRL steps: {item['steps']} {'lr' if item['is_long_range'] else ''}",
709
+ push_checkpoint_weights=True,
710
+ )
711
+ ],
712
+ strategy=MrlStrategy.LONG_RANGE_STRATEGY if item['is_long_range'] else MrlStrategy.MULTI_STEP_STRATEGY,
713
+ unfreeze_epoch=((2, 2e-5), (4, 8e-5), (6, 1e-5), 8) if item['steps'] == 4 else (0, 1, (2, 1e-6), 4),
714
+ random_resets=item['steps'] > 4,
715
+ random_resets_from=2,
716
+ random_resets_ratio=0.4 if item['steps'] != 4 else None,
717
+ separate_memory_lr=True,
718
+ memory_lr=6e-4 if item['steps'] == 4 else 4e-4 if item['steps'] == 8 and item['is_long_range'] else None,
719
+ lr=3e-4 if item['steps'] == 4 else 2e-4 if item['steps'] == 8 and item['is_long_range'] else None,
720
+ critic_lr=4e-4 if item['steps'] == 4 else None,
721
+ critic_encoder_lr=2e-4 if item['steps'] == 4 else None,
722
+ teacher_forcing=True if item['steps'] <= 8 else False,
723
+ ) for item in mrl_datasets]
724
+ ```
725
+
726
+ After that, we have to configure the reward model. It's based on BLEU scores and cosine similarity between the generated answers
727
+ and the saved data from previous steps and the reference answers from the dataset. Cosine similarity is also calculated against the running
728
+ mean embedding of previous steps. The reward model also includes an optional length reward. Its config includes a lot of options
729
+ to set different factors for different reward parts.
730
+ ```python
731
+ # 1. Init GPU device
732
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
733
+
734
+ # 2. Create reward model
735
+ reward_model = MrlRewardModel(
736
+ encoder.model.embedding,
737
+ device,
738
+ bleu_with_saved_data=True,
739
+ reward_len=True,
740
+ neg_reward_len=True,
741
+ target_len_as_ref=True,
742
+ bleu_factor=0.4,
743
+ cos_factor=0.5,
744
+ len_factor=0.1,
745
+ bleu_ref_factor=0.4,
746
+ bleu_saved_factor=0.6,
747
+ cos_ref_factor=0.35,
748
+ cos_saved_factor=0.65,
749
+ neg_bleu_factor=0.45,
750
+ neg_cos_factor=0.45,
751
+ neg_cos_ref_factor=0.3,
752
+ neg_cos_saved_factor=0.7,
753
+ neg_bleu_ref_factor=0.3,
754
+ neg_bleu_saved_factor=0.7,
755
+ multi_cos_ref_factor=0.3,
756
+ multi_cos_saved_factor=0.5,
757
+ multi_cos_running_mean_factor=0.2,
758
+ bleu_ref_weights=(0.2, 0.2, 0.3, 0.3),
759
+ bleu_saved_weights=(0.2, 0.2, 0.3, 0.3),
760
+ tanh_reward_scale=False,
761
+ rewards_scale=1.0,
762
+ )
763
+ ```
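+
+ The many `*_factor` parameters above control how the partial scores are mixed. A purely illustrative sketch of such a weighting scheme, inferred only from the parameter names - the actual `MrlRewardModel` logic may differ:
+ ```python
+ def combine_rewards(bleu_ref: float, bleu_saved: float, cos_ref: float, cos_saved: float, len_reward: float) -> float:
+     # inner mixes: reference-answer vs. saved-data scores (weights taken from the config above)
+     bleu = 0.4 * bleu_ref + 0.6 * bleu_saved   # bleu_ref_factor / bleu_saved_factor
+     cos = 0.35 * cos_ref + 0.65 * cos_saved    # cos_ref_factor / cos_saved_factor
+     # outer mix: BLEU, cosine similarity and length reward
+     return 0.4 * bleu + 0.5 * cos + 0.1 * len_reward  # bleu_factor / cos_factor / len_factor
+ ```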
764
+
765
+ And finally, we could create the MRL Trainer with RL algorithm (currently only PPO available) and start the training:
766
+ ```python
767
+ # 1. Init PPO Algorithm
768
+ algorithm = PPOAlgorithm(
769
+ PPOConfig(clip_eps=0.2, gae_lambda=0.95, gae_gamma=0.99, entropy_coef=0.01, critic_value_clip=50.0)
770
+ )
771
+
772
+ # 2. Create config for MRLTrainer
773
+ mrl_config = MrlConfig(
774
+ lr=1e-4,
775
+ critic_lr=2e-4,
776
+ critic_encoder_lr=1e-4,
777
+ separate_memory_lr=True,
778
+ memory_lr=3e-4,
779
+ max_seq_len=256,
780
+ critic_max_len=512,
781
+ weight_decay=0.01,
782
+ critic_weight_decay=0.01,
783
+ update_epochs=10,
784
+ pad_token_id=0,
785
+ end_token_id=3,
786
+ use_moe_aux_loss=True,
787
+ embedding_lr=5e-6,
788
+ use_memory_warmup=False,
789
+ )
790
+
791
+ # 3. Initialize MRL Trainer
792
+ trainer = MRLTrainer(actor, critic, reward_model, device, mrl_config, algorithm, use_amp=True, dtype=torch.bfloat16)
793
+
794
+ # 4. Train with curriculum stages config
795
+ trainer(curriculum_stages, batch_size=batch_size)
796
+ ```
246
797
 
247
798
  Apache License
248
799
  Version 2.0, January 2004
@@ -22,7 +22,7 @@ rxnn/training/reward.py,sha256=uiSsBXmjMw2yv-1Bssy3RTlpU6zP8ape3490Sl-aT0M,16144
22
22
  rxnn/training/rl.py,sha256=hWtExxY-_pAmTOGYxyCNounUbaGWvLDVltC4sRC7MN4,7175
23
23
  rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
24
24
  rxnn/training/tokenizer.py,sha256=umaLByMBx_NMrQElA45HLm9gkuzyKWDTFaKVd-CjXl0,8344
25
- rxnn/training/utils.py,sha256=QMNkJPQBY04DX9WN7GHnI2EZTBbAzWkjt2W-798oUII,6129
25
+ rxnn/training/utils.py,sha256=ngDCm654NL3UsPy190Er4XPc9HI-OyEV6tDLMgEEvQc,6219
26
26
  rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
27
27
  rxnn/transformers/attention.py,sha256=KRnKT6XUqAXElxV9y72mSpdTeiMgCKCCLqqxCFNTHmA,16372
28
28
  rxnn/transformers/ff.py,sha256=WDjO-H9XWInoWnUnxiseIH6Kx5GlHP0zGJygwhcb1gc,2589
@@ -33,7 +33,7 @@ rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
33
33
  rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
34
34
  rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
35
35
  rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
36
- rxnn-0.2.65.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
37
- rxnn-0.2.65.dist-info/METADATA,sha256=tIYHsYRYeZlVc4c7evGHH0pVqIh0jWoaJp3kEIUmL8c,25997
38
- rxnn-0.2.65.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
39
- rxnn-0.2.65.dist-info/RECORD,,
36
+ rxnn-0.2.67.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
37
+ rxnn-0.2.67.dist-info/METADATA,sha256=LEIwAXp3Eau7DrEUCeJ5etTC6nl-rNzsQfJxiRXD7xI,49548
38
+ rxnn-0.2.67.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
39
+ rxnn-0.2.67.dist-info/RECORD,,