rxnn 0.2.65__py3-none-any.whl → 0.2.67__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rxnn/training/utils.py
CHANGED
@@ -151,5 +151,9 @@ def get_gradient_norms(model: nn.Module):
         param_norm = p.grad.data.norm(2)
         total_norm += param_norm.item() ** 2
     total_norm = total_norm ** 0.5
-    mean_norm = total_norm / len(grad_params)
+    params_len = len(grad_params)
+    if params_len != 0:
+        mean_norm = total_norm / params_len
+    else:
+        mean_norm = 0.0
     return total_norm, mean_norm
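For context, the guarded mean computation added above fits into the helper roughly as follows; this is only a sketch, and everything outside the changed lines is inferred from the visible context rather than copied from the package source:

```python
import torch.nn as nn

def get_gradient_norms(model: nn.Module):
    # Assumed: collect only parameters that actually received gradients
    grad_params = [p for p in model.parameters() if p.grad is not None]
    total_norm = 0.0
    for p in grad_params:
        param_norm = p.grad.data.norm(2)
        total_norm += param_norm.item() ** 2
    total_norm = total_norm ** 0.5
    # New in 0.2.67: avoid division by zero when no parameter has gradients
    params_len = len(grad_params)
    mean_norm = total_norm / params_len if params_len != 0 else 0.0
    return total_norm, mean_norm
```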
rxnn-0.2.67.dist-info/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.2.65
+Version: 0.2.67
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning
@@ -113,17 +113,17 @@ and [TensorBoard](https://github.com/tensorflow/tensorboard).
 > if it's set to `False`, when the `flash-attn` library is installed, **PyTorch** will try to use it implicitly through the _SDPA backend_. It's better to set it
 > to `False` and let it be used automatically, because of better compatibility. Explicit options could be used for research.
 
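The _SDPA backend_ mentioned here is PyTorch's built-in `scaled_dot_product_attention`, which can pick a flash or memory-efficient kernel on its own when one is available; a minimal illustration (shapes chosen arbitrarily):

```python
import torch
import torch.nn.functional as F

# With the library's use_flash_attention option set to False, attention runs
# through PyTorch SDPA, which may still dispatch to a flash kernel under the hood.
q = torch.randn(2, 8, 256, 64)
k = torch.randn(2, 8, 256, 64)
v = torch.randn(2, 8, 256, 64)
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```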
-
+## Modules
 **RxNN** framework has multiple modules with models, layers, training and inference tools, made for complete development
 of _reactive models_, and they can also be used for regular **Transformers**.
 
-
+### Transformers
 Transformers module includes classes for models and layers. It includes **Reactive Transformers** as well as **Classic Transformers**.
 
 Submodules:
 - `rxnn.transformers.attention` - basic, most common attention layers - `MultiHeadAttention`, `GroupedQueryAttention` and `MultiQueryAttention`
   - additional attention layers, especially `SparseQueryAttention`, could be found in the `rxnn.experimental.attention` module
-  - `SparseQueryAttention` will be moved to `rxnn.transformers.attention` in 0.
+  - `SparseQueryAttention` will be moved to `rxnn.transformers.attention` in the 0.3.x version
 - `rxnn.transformers.positional` - positional encoding layers - `RotaryPositionalEmbedding` and legacy ones - `AbsolutePositionalEmbedding`/`RelativePositionalEmbedding`
 - `rxnn.transformers.ff` - dense feed forward layers, including gated layers (_SwiGLU_, etc.) - `FeedForward` & `GatedFeedForward` (recommended)
 - `rxnn.transformers.moe` - Mixture-of-Experts feed forward layers - `MoeFeedForward` & `GatedMoeFeedForward` (recommended)
@@ -133,7 +133,6 @@ Submodules:
 
 In **RxNN**, models are initialized in a declarative style by class composition, but then they are wrapped in imperative classes,
 to be compatible with the HuggingFace **JSON** config. For example:
-
 ```python
 from typing import TypedDict
 import torch
@@ -193,11 +192,11 @@ class YourReactiveTransformerDecoder(nn.Module, PyTorchModelHubMixin):
                 memory_cross_attention=GroupedQueryAttention(
                     config['embed_dim'],
                     config['att_heads'],
-                    config['att_groups'],
+                    config['cross_att_groups'] if 'cross_att_groups' in config else config['att_groups'],
                     rope=rope,
                     dropout=0.1,
                     max_seq_len=config['seq_len'],
-                    is_causal=
+                    is_causal=False,
                     rope_only_for_query=True
                 ),
             ) for _ in range(config['num_layers'])
@@ -208,15 +207,99 @@ class YourReactiveTransformerDecoder(nn.Module, PyTorchModelHubMixin):
         return self.model(x, attention_mask=attention_mask)
 ```
 
-
+### Memory
 The _memory_ module includes **Short-Term Memory** and the layers responsible for its update. In future versions it will also
 include **Long-Term Memory**.
 
 The main `ShortTermMemory` class is located in the `rxnn.memory.stm` module - a usage example is in the Transformers module description.
 
-
+#### Memory Attention Network
+The **Memory Attention Network** is responsible for updating the memory layers. It includes memory attention layers, with normalization
+and a residual connection (with an optional gated residual). The **Memory Attention Network** should have the same number of layers
+as the other components (encoder & decoder). It takes the results from each encoder layer (hidden states) and combines them
+with the actual memory state.
+
+You can create your own **Memory Attention Network**, integrated with the **HuggingFace Hub**, the same way as for reactive transformers:
+```python
+from typing import TypedDict
+import torch
+import torch.nn as nn
+from huggingface_hub import PyTorchModelHubMixin
+from rxnn.transformers.attention import GroupedQueryAttention
+from rxnn.transformers.positional import RotaryPositionalEmbedding
+from rxnn.memory.stm import ShortTermMemory
+from rxnn.memory.attention import StmMemoryAttention
+
+class YourMemoryAttentionConfig(TypedDict):
+    num_layers: int
+    vocab_size: int
+    embed_dim: int
+    ff_dim: int
+    att_heads: int
+    seq_len: int
+    stm_size: int
+    att_groups: int
+
+class YourMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2.0"):
+    """RxT-Alpha (Reactive Transformer) memory attention model"""
+
+    def __init__(
+        self,
+        config: YourMemoryAttentionConfig,
+        **kwargs,
+    ):
+        super(YourMemoryAttention, self).__init__(**kwargs)
+
+        rope = RotaryPositionalEmbedding(config['embed_dim'] // config['att_heads'], config['seq_len'])
+        # This separately initialized STM will be replaced by shared instance with `load_shared_memory` call
+        stm = ShortTermMemory(config['num_layers'], config['embed_dim'], config['stm_size'])
+
+        self.model = StmMemoryAttention(
+            stm,
+            attention_layers=nn.ModuleList([
+                GroupedQueryAttention(
+                    config['embed_dim'],
+                    config['att_heads'],
+                    config['att_groups'],
+                    rope=rope,
+                    dropout=0.1,
+                    is_causal=False,
+                    rope_only_for_keys=True
+                ) for _ in range(config['num_layers'])
+            ]),
+            memory_norm_layers=nn.ModuleList([
+                nn.RMSNorm(config['embed_dim']) for _ in range(config['num_layers'])
+            ]),
+            use_gated_residual=False,  # memory attention residual gate
+            per_slot_gate=False,  # gate per memory slot, otherwise it's per layer
+            init_gate=None,  # initial value for gate weights
+            use_dynamic_gate=False,  # dynamic gate calculated from weights and memory state, otherwise it's calculated only from weights
+            use_tanh_gate=False,  # use tanh gate, otherwise it's sigmoid
+        )
+
+    def load_shared_memory(self, stm: ShortTermMemory):
+        self.model.stm = stm
+
+    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
+        return self.model(x, attention_mask=attention_mask)
+```
+
+> #### Gated residual
+> An optional gated residual could be used to improve Memory Attention expressiveness. It uses a gate (sigmoid or tanh)
+> with trainable weights, to decide how much information from the old and the newly updated memory state should be stored. Depending
+> on the params, weights are declared per layer or per memory slot (more expressive). It could work in two modes, which could
+> be switched, because they are using the same weights shape:
+> - static - gate values calculated only from weights (`gate = torch.sigmoid(weights)`) - enables explicit control with the `init_gate` param
+> - dynamic - gate values calculated from weights and the updated memory state (`gate = torch.sigmoid(weights * (new_layer_stm + layer_stm).mean(dim=-1, keepdim=True))`)
+>
+> Depending on the `use_tanh_gate` param, the final gated residual connection is calculated with different formulas:
+> - sigmoid gate - `stm[i] = layer_gate * new_layer_stm + (1 - layer_gate) * layer_stm`
+> - tanh gate - `stm[i] = (1 + layer_gate) * new_layer_stm + (1 - layer_gate) * layer_stm`
+> - the tanh gate preserves the residual connection scale, while the sigmoid gate result is equivalent to `(new_layer_stm + layer_stm) / 2`
+>
+> **Gated residual** is currently in tests - we are not sure if it will provide better results, so **it's not recommended**
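To make the two gate modes and the two residual formulas above concrete, here is a rough, self-contained sketch of the update for a single memory layer; the tensor shapes and the `weights` layout are assumptions, not the exact `StmMemoryAttention` internals:

```python
import torch

def gated_memory_update(layer_stm, new_layer_stm, weights, use_dynamic_gate=False, use_tanh_gate=False):
    """Assumed shapes: layer_stm / new_layer_stm are [batch, stm_size, embed_dim];
    weights is a trainable tensor broadcastable to them (per layer or per slot)."""
    if use_dynamic_gate:
        # dynamic mode: the gate also depends on the updated memory state
        gate_input = weights * (new_layer_stm + layer_stm).mean(dim=-1, keepdim=True)
    else:
        # static mode: the gate depends only on the trainable weights
        gate_input = weights
    if use_tanh_gate:
        layer_gate = torch.tanh(gate_input)
        # tanh variant keeps the residual connection scale
        return (1 + layer_gate) * new_layer_stm + (1 - layer_gate) * layer_stm
    else:
        layer_gate = torch.sigmoid(gate_input)
        return layer_gate * new_layer_stm + (1 - layer_gate) * layer_stm
```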
 
-
+### Training
 Training module includes **Trainers** for different training stages of reactive models and shared training utils.
 
 Submodules:
@@ -233,6 +316,9 @@ Submodules:
   returns a new subset and modifies the existing one - i.e. `valid_dataset = train_dataset.get_subset(0.1)`
 - for concatenated datasets, a validation/test split could be created with `concat_from_hf_hub_with_subset` - it cuts the
   same percentage of each loaded dataset
+- each dataset has a `pre_tokenize` method, to tokenize all items before the training (otherwise they are tokenized
+  dynamically on item access). It's recommended for smaller datasets (fine-tuning, MRL, etc.) and not recommended for
+  very big datasets (pre-training), because it uses a lot of RAM (CPU) - see the short example after this list
 - `rxnn.training.callbacks` contains Trainer callbacks, for different kinds of utils (more info below)
 - `rxnn.training.scheduler` includes a learning rate scheduler for training
 - `rxnn.training.bml` - Base Model Learning module with Trainers for pre-training and fine-tuning
@@ -240,9 +326,474 @@ Submodules:
 - `rxnn.training.rxrlhf` - Reinforcement Learning from Human Feedback for Reactive Models module (from 0.3.x)
 - `rxnn.training.brl` - Behavioral Reinforcement Learning module (Reactor / from 0.7.x)
 
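A short sketch of the dataset utilities mentioned in the list above, combining only calls that appear elsewhere in this README (the dataset id, sequence length and intervals are just the values reused from the other examples):

```python
from rxnn.training.dataset import MaskedLMDataset
from rxnn.training.tokenizer import load_tokenizer_from_hf_hub

tokenizer = load_tokenizer_from_hf_hub('ReactiveAI/RxT-Alpha-Micro-Plus-Encoder', token='HF_TOKEN')

# Load a dataset and split off 10% of it as a validation subset
train_dataset = MaskedLMDataset.from_hf_hub(
    'roneneldan/TinyStories',
    load_kwargs={'trust_remote_code': True},
    tokenizer=tokenizer,
    max_seq_len=256,
)
valid_dataset = train_dataset.get_subset(0.1)

# Optionally pre-tokenize all items up-front (recommended only for smaller datasets)
train_dataset.pre_tokenize(verbose=True, log_interval=5000)
valid_dataset.pre_tokenize(verbose=True, log_interval=1000)
```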
-
-
+#### Base Model Learning
+The **Base Model Learning (BML)** module is made for both pre-training and fine-tuning the base models that will be used as components
+in reactive models. Generally, the only two differences between pre-training and supervised fine-tuning are the different dataset
+classes and the trainer/callbacks hyperparams config.
+
+Reactive models are based on a transformer decoder and encoder, with shared embeddings and memory layers - this requires special
+handling in the first training stages:
+- layers connected with memory - **Memory Cross-Attention** - are frozen during pre-training and fine-tuning, and they are
+  skipped by residual connections
+- as the encoder is able to learn a little better embeddings, because of bidirectional modelling, it's pre-trained first, then the
+  decoder is trained with frozen embeddings from the encoder
+- in **Reactive Transformer** fine-tuning, both encoder and decoder are fit to the interaction format (single query and answer); the
+  training order is the same as for pre-training
+- in the **Preactor** architecture there are 2 encoders and a single decoder. Encoders are fine-tuned from a single pre-trained
+  encoder - the first one processes only queries and the second one only the answers. More info soon
+- in the **Reactor** architecture there are 2 encoders and 2 decoders. Both encoders and decoders are fine-tuned from a single
+  pre-trained encoder and decoder. Each component is fine-tuned to its specific task. More info soon
+
+##### Pre-training
+We have to start with importing the required modules/libraries, initializing the models and loading the tokenizer - I will
+use the _RxT-Alpha-Micro-Plus_ models as an example:
+```python
+import torch
+from rxnn.rxt.models import RxTAlphaDecoder, RxTAlphaEncoder
+from rxnn.training.dataset import AutoregressiveLMDataset, MaskedLMDataset
+from rxnn.training.bml import AutoregressiveTrainer, MLMTrainer
+from rxnn.training.models import MLMHead, MLMTrainingModel
+from rxnn.training.scheduler import get_transformer_lr_scheduler, calculate_steps
+from rxnn.training.callbacks import PrintLossCallback, PrintAccuracyCallback, TokenCounterCallback, ModelSaveCallback, JointModelSaveCallback
+from rxnn.training.tokenizer import load_tokenizer_from_hf_hub
+from rxnn.utils import set_random_seed, cache_clean
+
+embed_dim = 128
+vocab_size = 7_500
+seq_len = 256
+
+set_random_seed(42)
+
+config = {
+    'num_layers': 10,
+    'vocab_size': vocab_size,
+    'embed_dim': embed_dim,
+    'att_heads': 16,
+    'att_groups': 8,
+    'seq_len': seq_len,
+    'stm_size': seq_len,
+    'use_flash_attention': False,
+    'use_gated': True,
+    'ff_dropout': 0.1,
+    'self_att_type': 'sqa',
+    'cross_att_type': 'sqa',
+    'att_query_groups': 8,
+}
+
+encoder_config = {
+    'ff_dim': 384,
+    **config
+}
+
+decoder_config = {
+    'ff_dim': 256,
+    'use_moe': True,
+    'num_experts': 20,
+    'moe_top_k': 4,
+    **config
+}
+
+encoder = RxTAlphaEncoder(**encoder_config)
+decoder = RxTAlphaDecoder(**decoder_config)
+head = MLMHead(embed_dim, vocab_size)
+
+# Tokenizer is the same for encoder and decoder
+tokenizer = load_tokenizer_from_hf_hub('ReactiveAI/RxT-Alpha-Micro-Plus-Encoder', token='HF_TOKEN')
+```
+Then, we have to load MLM datasets, set callbacks and run encoder training:
+```python
+# 1. Load datasets
+load_kwargs = {
+    'trust_remote_code': True
+}
+
+train_dataset = MaskedLMDataset.from_hf_hub('roneneldan/TinyStories', load_kwargs=load_kwargs, tokenizer=tokenizer, max_seq_len=seq_len)
+valid_dataset = MaskedLMDataset.from_hf_hub('roneneldan/TinyStories', split="validation", load_kwargs=load_kwargs, tokenizer=tokenizer, max_seq_len=seq_len)
+
+# 2. Select device
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+# 3. Clean GPU cache (optional)
+cache_clean()
+
+# 4. Set training config variables
+batch_size = 256
+epochs = 8
+gradient_acc_steps = 1
+peak_lr = 1e-3 * gradient_acc_steps
+
+# 5. Get number of steps for scheduler
+steps_config = calculate_steps(len(train_dataset), epochs, batch_size, warmup_ratio=0.05, verbose=True)
+steps_per_epoch, total_steps, warmup_steps = steps_config['epoch'], steps_config['total'], steps_config['warmup']
+
+# 6. Freeze memory cross-attention layers
+encoder.freeze_memory()
+
+# 7. Select directory for TensorBoard logs
+logs_dir = './micro/tensorboard_logs/encoder-plus-sft'
+
+# 8. Basic callbacks - print loss, accuracy and number of processed tokens
+print_cb = PrintLossCallback(batches_per_epoch=steps_per_epoch)
+count_cb = TokenCounterCallback(3_000_000_000)
+acc_cb = PrintAccuracyCallback()
+
+# 9. Joint model save callback - used to save encoder and MLM head, and push them to HuggingFace Hub
+save_cb = JointModelSaveCallback(
+    './micro/encoder-plus-sft',
+    push_to_hub=True,
+    hub_model_decoder=None,
+    hub_model_encoder='Your encoder model id',
+    hub_model_head='Your mlm model id',
+    push_checkpoint_weights=True, # push epoch checkpoints to hub
+    final_commit_message='Final commit message',
+    private_repo=False, # use HF private repository
+    save_checkpoint_after_n_batches=1000, # save model after N batches in epoch (batch checkpoint)
+    push_batch_checkpoint=True, # push batch checkpoints to HF Hub
+    mlm_mode=True, # use MLM mode
+    hf_token='HF_TOKEN',
+    use_ddp=False, # use distributed training mode
+)
+
+# 10. Init training model - encoder + head
+model = MLMTrainingModel(encoder, head)
+
+# 11. Init MLM Trainer
+trainer = MLMTrainer(
+    model,
+    device,
+    dataset=train_dataset,
+    validation_dataset=valid_dataset,
+    vocab_size=vocab_size,
+    callbacks=[print_cb, acc_cb, count_cb, save_cb],
+    use_amp=True, # use autocast
+    dtype=torch.bfloat16, # data type for training
+    log_dir=logs_dir,
+    use_ddp=False, # use distributed training mode
+)
+
+# 12. Init optimizer and cosine annealing scheduler
+optimizer = torch.optim.AdamW(model.parameters(), lr=peak_lr, weight_decay=0.02)
+scheduler = get_transformer_lr_scheduler(
+    optimizer,
+    warmup_steps=warmup_steps,
+    num_training_steps=total_steps
+)
+
+# 13. Run the training for the selected number of epochs
+trainer(epochs=epochs, batch_size=batch_size, optimizer=optimizer, scheduler=scheduler)
+```
+After the encoder's training, we have to train the decoder:
+```python
+# 1. Load datasets
+load_kwargs = {
+    'trust_remote_code': True
+}
+
+train_dataset = AutoregressiveLMDataset.from_hf_hub('roneneldan/TinyStories', load_kwargs=load_kwargs, tokenizer=tokenizer, max_seq_len=seq_len)
+valid_dataset = AutoregressiveLMDataset.from_hf_hub('roneneldan/TinyStories', split="validation", load_kwargs=load_kwargs, tokenizer=tokenizer, max_seq_len=seq_len)
+
+# 2. Load shared embedding and memory, then freeze embedding and memory cross-attention
+decoder.load_shared_embedding(encoder.model.embedding)
+decoder.load_shared_memory(encoder.model.stm)
+
+decoder.model.embedding.requires_grad_(False)
+decoder.freeze_memory()
+
+# 3. Clean GPU cache (optional)
+cache_clean()
+
+# 4. Set training config variables
+batch_size = 256
+epochs = 8
+gradient_acc_steps = 1
+peak_lr = 1e-3 * gradient_acc_steps
+
+# 5. Get number of steps for scheduler
+steps_config = calculate_steps(len(train_dataset), epochs, batch_size, warmup_ratio=0.05, verbose=True)
+steps_per_epoch, total_steps, warmup_steps = steps_config['epoch'], steps_config['total'], steps_config['warmup']
+
+# 6. Select directory for TensorBoard logs
+logs_dir = './micro/tensorboard_logs/decoder-plus-sft'
+
+# 7. Basic callbacks - print loss, accuracy and number of processed tokens
+print_cb = PrintLossCallback(batches_per_epoch=steps_per_epoch)
+count_cb = TokenCounterCallback(5_000_000_000)
+acc_cb = PrintAccuracyCallback()
+
+# 8. Model save callback - used to save decoder and push it to HuggingFace Hub
+save_cb = ModelSaveCallback(
+    './micro/decoder-plus-sft',
+    push_to_hub=True,
+    hub_model_id='Your decoder model id',
+    push_checkpoint_weights=True, # push epoch checkpoints to hub
+    final_commit_message='Final commit message',
+    private_repo=False, # use HF private repository
+    save_checkpoint_after_n_batches=1000, # save model after N batches in epoch (batch checkpoint)
+    push_batch_checkpoint=True, # push batch checkpoints to HF Hub
+    hf_token='HF_TOKEN',
+    use_ddp=False, # use distributed training mode
+)
+
+# 9. Init Autoregressive Trainer
+trainer = AutoregressiveTrainer(
+    decoder,
+    device,
+    dataset=train_dataset,
+    validation_dataset=valid_dataset,
+    vocab_size=vocab_size,
+    callbacks=[print_cb, acc_cb, count_cb, save_cb],
+    use_amp=True,
+    dtype=torch.bfloat16,
+    log_dir=logs_dir,
+    use_moe_aux_loss=True, # Add MoE Router auxiliary loss to main loss
+    moe_aux_loss_scale=0.02, # MoE Router aux loss scale
+    use_ddp=False, # use distributed training mode
+)
+
+# 10. Init optimizer and cosine annealing scheduler
+optimizer = torch.optim.AdamW(decoder.parameters(), lr=peak_lr, weight_decay=0.02)
+scheduler = get_transformer_lr_scheduler(
+    optimizer,
+    warmup_steps=warmup_steps,
+    num_training_steps=total_steps
+)
+
+# 11. Run the training for the selected number of epochs
+trainer(epochs=epochs, batch_size=batch_size, optimizer=optimizer, scheduler=scheduler)
+```
+
+##### Fine-tuning
+For _**Interaction Supervised Fine-Tuning**_, the code is almost the same as for pre-training, with some small changes.
+
+First, we have to load the pre-trained models, instead of initializing them with configs:
+```python
+encoder = RxTAlphaEncoder.from_pretrained('ReactiveAI/RxT-Alpha-Micro-Plus-Encoder', token='HF_TOKEN')
+decoder = RxTAlphaDecoder.from_pretrained('ReactiveAI/RxT-Alpha-Micro-Plus-Decoder', token='HF_TOKEN')
+head = MLMHead.from_pretrained('ReactiveAI/RxT-Alpha-Micro-Plus-MLM', token='HF_TOKEN')
+```
+
+Then, we have to change the dataset loading part. For the encoder:
+```python
+# 1. Load datasets
+from rxnn.training.dataset import EncoderSftDataset
+
+train_dataset = EncoderSftDataset.from_hf_hub('ReactiveAI/TinyStories-Plus-Interaction-SFT', tokenizer=tokenizer, max_seq_len=seq_len)
+valid_dataset = EncoderSftDataset.from_hf_hub('ReactiveAI/TinyStories-Plus-Interaction-SFT', split="validation", tokenizer=tokenizer, max_seq_len=seq_len)
+
+# 2. Pre-tokenize dataset with verbose logging (optional)
+train_dataset.pre_tokenize(verbose=True, log_interval=5000)
+valid_dataset.pre_tokenize(verbose=True, log_interval=1000)
+```
+And the same for the decoder:
+```python
+# 1. Load datasets
+from rxnn.training.dataset import DecoderSftDataset
+
+train_dataset = DecoderSftDataset.from_hf_hub('ReactiveAI/TinyStories-Plus-Interaction-SFT', tokenizer=tokenizer, max_seq_len=seq_len)
+valid_dataset = DecoderSftDataset.from_hf_hub('ReactiveAI/TinyStories-Plus-Interaction-SFT', split="validation", tokenizer=tokenizer, max_seq_len=seq_len)
+
+# 2. Pre-tokenize dataset with verbose logging (optional)
+train_dataset.pre_tokenize(verbose=True, log_interval=5000)
+valid_dataset.pre_tokenize(verbose=True, log_interval=1000)
+```
+
+We could also add an early stoppage callback:
+```python
+from rxnn.training.callbacks import EarlyStoppageCallback
 
+stop_cb = EarlyStoppageCallback(num_plateau_epochs=5)
+```
+
+Additionally, in fine-tuning we will rather use a different config for the number of epochs, steps, learning rate, etc.
+
+> #### Classic Transformer Training
+> The same code could also be used to train classic decoder-only or encoder-only transformers; the only difference is
+> that they don't require memory cross-attention freezing.
+
+##### Joint Training
+There are also `JointLMDataset` and `JointLMTrainer` classes to train the encoder and decoder at once. In that case, embeddings
+are updated from both the encoder and decoder optimization. However, I noticed some issues with balancing training in that mode,
+so it's **not recommended** for now, until it is tested and fixed.
+
+#### Memory Reinforcement Learning
+**Memory Reinforcement Learning (MRL)** is the most important training stage for the reactive models' **Attention-Based Memory System**.
+In this stage we are training the model to remember information between multiple interactions, with different curriculum stage
+configs. Theoretical foundations are described in the [research docs](https://github.com/RxAI-dev/RxNN/blob/main/docs/research/ReactiveTransformer/mrl.md).
+
+> The **MRL** algorithm is currently in tests and a lot of things could still be changed!
+
+In practice, the algorithm has over 50 hyperparams, so it requires careful handling. We start by importing modules, loading
+the pre-trained models from the SFT stage, initializing a new Memory Attention, and creating the actor and critic models:
+```python
+import torch
+from rxnn.rxt.models import RxTAlphaDecoder, RxTAlphaEncoder, RxTAlphaMemoryAttention
+from rxnn.training.tokenizer import load_tokenizer_from_hf_hub
+from rxnn.training.dataset import MrlDatasets
+from rxnn.training.models import MrlActorModel, MrlCriticModel
+from rxnn.training.reward import MrlRewardModel
+from rxnn.training.mrl import MRLTrainer, CurriculumConfig, MrlStrategy, MrlConfig
+from rxnn.training.rl import PPOAlgorithm, PPOConfig
+from rxnn.training.callbacks import MrlPrintCallback, MrlEarlyStoppageCallback, MrlModelSaveCallback, MrlGeneratedTokensCallback
+from rxnn.utils import set_random_seed
+
+# 1. Set random seed, batch size and embed dim
+set_random_seed(42)
+batch_size = 64
+embed_dim = 128
+
+# 2. Get pre-trained microscale PoC models
+decoder = RxTAlphaDecoder.from_pretrained('ReactiveAI/RxT-Alpha-Micro-Plus-Decoder-SFT', token='HF_TOKEN')
+encoder = RxTAlphaEncoder.from_pretrained('ReactiveAI/RxT-Alpha-Micro-Plus-Encoder-SFT', token='HF_TOKEN')
+# 3. Init Memory Attention Network
+mem_attn = RxTAlphaMemoryAttention(
+    num_layers=10,
+    embed_dim=embed_dim,
+    att_heads=8,
+    seq_len=256,
+    stm_size=256,
+    use_flash_attention=False,
+    norm_type='classic-rms',
+    att_groups=4,
+    att_type='sqa',
+    att_query_groups=4,
+)
+
+# 4. Load shared embedding and memory from encoder to other models
+decoder.load_shared_embedding(encoder.model.embedding)
+encoder.model.stm.batched_memory(batch_size=batch_size, init_type='standard')
+decoder.load_shared_memory(encoder.model.stm)
+mem_attn.load_shared_memory(encoder.model.stm)
+
+# 5. Init Actor model
+actor = MrlActorModel(encoder, decoder, mem_attn)
+
+# 6. Get pre-trained encoder, extend its context size, freeze memory and use as a body for Critic model
+critic_encoder = RxTAlphaEncoder.from_pretrained('ReactiveAI/RxT-Alpha-Micro-Plus-Encoder-SFT', token='HF_TOKEN')
+
+critic_encoder.update_max_len(512)
+critic_encoder.freeze_memory()
+# 7. Init Critic model
+critic = MrlCriticModel(critic_encoder, embed_dim)
+```
+
+Then, we have to load the tokenizer and the MRL datasets, and create the _curriculum config_:
+```python
+# 1. Load tokenizer
+tokenizer = load_tokenizer_from_hf_hub('ReactiveAI/RxT-Alpha-Micro-Plus-Decoder', token='HF_TOKEN')
+
+# 2. Load PoC TinyStories based MRL Dataset, starting from 4 steps to 16 in long range
+mrl_datasets = MrlDatasets.from_hf_hub(
+    'ReactiveAI/TinyStories-MRL',
+    tokenizer,
+    mrl_curriculum_steps=[
+        { 'subset_name': 'steps-4', 'steps': 4, 'is_long_range': False },
+        { 'subset_name': 'steps-6', 'steps': 6, 'is_long_range': False },
+        { 'subset_name': 'steps-8', 'steps': 8, 'is_long_range': False },
+        { 'subset_name': 'steps-8-lr', 'steps': 8, 'is_long_range': True },
+        { 'subset_name': 'steps-12', 'steps': 12, 'is_long_range': True },
+        { 'subset_name': 'steps-16', 'steps': 16, 'is_long_range': True },
+    ],
+    eval_split='validation',
+    max_seq_len=256,
+)
+
+# 3. Create curriculum stages config
+curriculum_stages = [CurriculumConfig(
+    steps=item['steps'],
+    epochs=10 if item['steps'] == 4 else 8 if item['steps'] == 8 and item['is_long_range'] else 5,
+    dataset=item['dataset'],
+    eval_dataset=item['eval_dataset'],
+    callbacks=[
+        MrlPrintCallback(),
+        MrlModelSaveCallback(
+            './models', push_to_hub=True, hub_model_critic='ReactiveAI/RxT-Alpha-Micro-Critic-MRL',
+            hub_model_decoder='ReactiveAI/RxT-Alpha-Micro-Decoder-MRL', hub_model_encoder='ReactiveAI/RxT-Alpha-Micro-Encoder-MRL',
+            hub_model_memory_attention='ReactiveAI/RxT-Alpha-Micro-MemAtt-MRL', private_repo=True,
+            hf_token='HF_TOKEN', final_commit_message=f"MRL steps: {item['steps']} {'lr' if item['is_long_range'] else ''}",
+            push_checkpoint_weights=True,
+        )
+    ],
+    strategy=MrlStrategy.LONG_RANGE_STRATEGY if item['is_long_range'] else MrlStrategy.MULTI_STEP_STRATEGY,
+    unfreeze_epoch=((2, 2e-5), (4, 8e-5), (6, 1e-5), 8) if item['steps'] == 4 else (0, 1, (2, 1e-6), 4),
+    random_resets=item['steps'] > 4,
+    random_resets_from=2,
+    random_resets_ratio=0.4 if item['steps'] != 4 else None,
+    separate_memory_lr=True,
+    memory_lr=6e-4 if item['steps'] == 4 else 4e-4 if item['steps'] == 8 and item['is_long_range'] else None,
+    lr=3e-4 if item['steps'] == 4 else 2e-4 if item['steps'] == 8 and item['is_long_range'] else None,
+    critic_lr=4e-4 if item['steps'] == 4 else None,
+    critic_encoder_lr=2e-4 if item['steps'] == 4 else None,
+    teacher_forcing=True if item['steps'] <= 8 else False,
+) for item in mrl_datasets]
+```
+
+After that, we have to configure the reward model. It's based on BLEU scores and cosine similarity between the generated answers,
+saved data from previous steps, and reference answers from the dataset. Cosine similarity is also calculated from the running
+mean embedding of previous steps. The reward model also includes an optional length reward. Its config includes a lot of options
+to set different factors for the different reward parts.
+```python
+# 1. Init GPU device
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+# 2. Create reward model
+reward_model = MrlRewardModel(
+    encoder.model.embedding,
+    device,
+    bleu_with_saved_data=True,
+    reward_len=True,
+    neg_reward_len=True,
+    target_len_as_ref=True,
+    bleu_factor=0.4,
+    cos_factor=0.5,
+    len_factor=0.1,
+    bleu_ref_factor=0.4,
+    bleu_saved_factor=0.6,
+    cos_ref_factor=0.35,
+    cos_saved_factor=0.65,
+    neg_bleu_factor=0.45,
+    neg_cos_factor=0.45,
+    neg_cos_ref_factor=0.3,
+    neg_cos_saved_factor=0.7,
+    neg_bleu_ref_factor=0.3,
+    neg_bleu_saved_factor=0.7,
+    multi_cos_ref_factor=0.3,
+    multi_cos_saved_factor=0.5,
+    multi_cos_running_mean_factor=0.2,
+    bleu_ref_weights=(0.2, 0.2, 0.3, 0.3),
+    bleu_saved_weights=(0.2, 0.2, 0.3, 0.3),
+    tanh_reward_scale=False,
+    rewards_scale=1.0,
+)
+```
+
+And finally, we could create the MRL Trainer with the RL algorithm (currently only PPO is available) and start the training:
+```python
+# 1. Init PPO Algorithm
+algorithm = PPOAlgorithm(
+    PPOConfig(clip_eps=0.2, gae_lambda=0.95, gae_gamma=0.99, entropy_coef=0.01, critic_value_clip=50.0)
+)
+
+# 2. Create config for MRLTrainer
+mrl_config = MrlConfig(
+    lr=1e-4,
+    critic_lr=2e-4,
+    critic_encoder_lr=1e-4,
+    separate_memory_lr=True,
+    memory_lr=3e-4,
+    max_seq_len=256,
+    critic_max_len=512,
+    weight_decay=0.01,
+    critic_weight_decay=0.01,
+    update_epochs=10,
+    pad_token_id=0,
+    end_token_id=3,
+    use_moe_aux_loss=True,
+    embedding_lr=5e-6,
+    use_memory_warmup=False,
+)
+
+# 3. Initialize MRL Trainer
+trainer = MRLTrainer(actor, critic, reward_model, device, mrl_config, algorithm, use_amp=True, dtype=torch.bfloat16)
+
+# 4. Train with curriculum stages config
+trainer(curriculum_stages, batch_size=batch_size)
+```
 
 Apache License
 Version 2.0, January 2004
rxnn-0.2.67.dist-info/RECORD
CHANGED
@@ -22,7 +22,7 @@ rxnn/training/reward.py,sha256=uiSsBXmjMw2yv-1Bssy3RTlpU6zP8ape3490Sl-aT0M,16144
 rxnn/training/rl.py,sha256=hWtExxY-_pAmTOGYxyCNounUbaGWvLDVltC4sRC7MN4,7175
 rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
 rxnn/training/tokenizer.py,sha256=umaLByMBx_NMrQElA45HLm9gkuzyKWDTFaKVd-CjXl0,8344
-rxnn/training/utils.py,sha256=
+rxnn/training/utils.py,sha256=ngDCm654NL3UsPy190Er4XPc9HI-OyEV6tDLMgEEvQc,6219
 rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/transformers/attention.py,sha256=KRnKT6XUqAXElxV9y72mSpdTeiMgCKCCLqqxCFNTHmA,16372
 rxnn/transformers/ff.py,sha256=WDjO-H9XWInoWnUnxiseIH6Kx5GlHP0zGJygwhcb1gc,2589
@@ -33,7 +33,7 @@ rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
 rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
 rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
-rxnn-0.2.
-rxnn-0.2.
-rxnn-0.2.
-rxnn-0.2.
+rxnn-0.2.67.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.2.67.dist-info/METADATA,sha256=LEIwAXp3Eau7DrEUCeJ5etTC6nl-rNzsQfJxiRXD7xI,49548
+rxnn-0.2.67.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+rxnn-0.2.67.dist-info/RECORD,,
File without changes
File without changes