rxnn 0.2.66__py3-none-any.whl → 0.2.68__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1102 @@
1
+ Metadata-Version: 2.3
2
+ Name: rxnn
3
+ Version: 0.2.68
4
+ Summary: RxNN: Reactive Neural Networks Platform
5
+ License: Apache-2.0
6
+ Keywords: deep-learning,ai,machine-learning
7
+ Author: Adam Filipek
8
+ Author-email: adamfilipek@rxai.dev
9
+ Requires-Python: >=3.10
10
+ Classifier: License :: OSI Approved :: Apache Software License
11
+ Classifier: Programming Language :: Python :: 3
12
+ Classifier: Programming Language :: Python :: 3.10
13
+ Classifier: Programming Language :: Python :: 3.11
14
+ Classifier: Programming Language :: Python :: 3.12
15
+ Classifier: Programming Language :: Python :: 3.13
16
+ Requires-Dist: datasets (>=3.5.0,<4.0.0)
17
+ Requires-Dist: huggingface-hub (>=0.30.0,<0.31.0)
18
+ Requires-Dist: nltk (>=3.9.1,<4.0.0)
19
+ Requires-Dist: tensorboard (>=2.19.0,<3.0.0)
20
+ Requires-Dist: tokenizers (>=0.21.0,<0.22.0)
21
+ Requires-Dist: torch (>=2.6.0,<3.0.0)
22
+ Requires-Dist: transformers (>=4.51.0,<5.0.0)
23
+ Project-URL: Homepage, https://rxai.dev/rxnn
24
+ Project-URL: Repository, https://github.com/RxAI-dev/rxnn/python
25
+ Description-Content-Type: text/markdown
26
+
27
+ <span>
28
+ <img src="https://raw.githubusercontent.com/RxAI-dev/RxNN/refs/heads/main/assets/logo/logo_rxai_v2.png" width="400" />
29
+ <img src="https://raw.githubusercontent.com/RxAI-dev/RxNN/refs/heads/main/assets/logo/logo_rxnn_v2.png" width="400" />
30
+ </span>
31
+
32
+ # Reactive AI - RxNN
33
+ ## Reactive Neural Networks Platform
34
+
35
+ RxNN is an AI/Deep Learning development platform for Reactive Neural Networks and Event-driven AI, introduced by Reactive AI.
36
+
37
+ ## Reactive Neural Networks and Event-driven AI
38
+ Reactive neural networks (RxNN) are a new family of memory-augmented neural networks that combine classical deep learning
39
+ algorithms with reactive communication patterns. In Event-driven AI, input data (a sequence) is treated as an event, and memory
40
+ state has to be kept between events/interactions. Technically, it's a specific kind of RNN that stores data between
41
+ processed sequences, instead of between sequence elements as in a regular RNN, so its recurrence operates on a higher level.
42
+ In the case of reactive communication patterns, RxNNs are stateful reactive data sources that you have to connect before
43
+ you can send and receive messages.
44
+ While RxNNs use some RNN concepts, they are primarily made to extend Transformer language/multi-modal models. In our
45
+ opinion, the biggest downside of current LLMs is their stateless nature - conversational models have to process the full chat
46
+ history on every interaction! That's not real-time processing, and it's not how human awareness works. In RxNN-based
47
+ transformers, the model processes single messages, while all the previous interaction history is saved in and read
48
+ from memory. These features are required by the **Weak** Reactive Neural Networks specification, and they are the first major
49
+ step in the transition from language models to awareness models - in the Reactive AI ecosystem, they will be introduced in the Reactive
50
+ Transformer architecture.
51
+
52
+ Additionally, to achieve awareness, **Strong** Reactive Neural Networks work in a reactive infinite reasoning loop
53
+ that generates an Infinite Chain-of-Thoughts and communicates in push-based mode (the model decides if and when to return output).
54
+
55
+ Reactive communication patterns in RxNN models are adapted to handle the asynchronous nature of the model - after it finishes generating
56
+ a sequence, it has to process it and save it in memory, but this can be done in the background.
57
+
58
+ ## Release plan
59
+ We are working on three new reactive architectures that progressively advance from language models to awareness models:
60
+ - **Reactive Transformer**: Reactive Language Model (RLM) with Short-Term Memory. [Research docs](https://github.com/RxAI-dev/RxNN/blob/main/docs/research/ReactiveTransformer/reactive-transformer.md)
61
+ - **Preactor**: extending Reactive Transformer with additional Long-Term Memory, providing theoretically infinite context (only
62
+ the single message length is limited) and the ability to learn from interactions (Live Learning)
63
+ - **Reactor**: AGI awareness model & Strong Reactive Neural Network that works in an infinite reasoning loop and doesn't require explicit human commands
64
+
65
+ Each new architecture is based on the previous one and adds new features/abilities. They will be progressively
66
+ released with the next versions of the **RxNN** framework:
67
+ - 0.1.x (Released): Reactive Transformer base models, Base Model Learning (pre-training/fine-tuning) & Transformers extensions (MoE Attention, Short-Term Memory, etc.)
68
+ - 0.2.x (Released): Memory Reinforcement Learning (MRL) for Short-Term Memory & Reactive Transformer, Attention-based Memory System details
69
+ - 0.3.x: Reinforcement Learning from Human Feedback for Reactive models (RxRLHF), basic Tensor Reactive
70
+ Extensions (TRX/Rust) for full Reactive Transformer, RxT-Alpha release (+following models - RxT-Beta, etc.)
71
+ - 0.4.x: Preactor base models, Tensor Database (TDB/Rust) for Long-Term Memory, mxRAG/revRAG subsystems
72
+ - 0.5.x: MRL for Long-Term Memory & Preactor, Live Learning for Preactor, PRx-Alpha release (+following models - PRx-Beta, etc.)
73
+ - 0.6.x: Reactor base models, TRX full implementation, Receptors & Effectors Reactive RNNs
74
+ - 0.7.x: Behavioral Reinforcement Learning (BRL) for Reactor's Infinite Chain-of-Thoughts, Continuous Live Learning for Reactor
75
+ - 0.8.x: Rx-Alpha release
76
+ - 0.9.x: Rx-Beta release
77
+ - 1.0.0: Reactor AGI official release (Expert, Assistant & Utility class models)
78
+ - 1.x.x: Multimodal reactive models (could be released earlier, depending on progress)
79
+ - 2.0.0: Real-Time Vision Reactor - Worker class models
80
+ - x.x.x: ...and more!
81
+
82
+ ## Usage
83
+ **RxNN** is made to train models based on reactive architectures, as well as transformer language models. The current version
84
+ is based on PyTorch and HuggingFace libraries (Transformers/Datasets/Tokenizers/Hub), and is integrated with [HuggingFace Hub](https://huggingface.co)
85
+ and [TensorBoard](https://github.com/tensorflow/tensorboard).
86
+
87
+ > We are also planning a version for **TensorFlow**, more info soon
88
+
89
+ ### Install library and dependencies
90
+ - RxNN and required deps: `pip install rxnn torch transformers tokenizers huggingface_hub`
91
+ - Datasets are required only for training: `pip install datasets`
92
+ - TensorBoard is optional: `pip install tensorboard`
93
+ - [Flash Attention](https://github.com/Dao-AILab/flash-attention) is recommended for faster training/inference (required for models with explicit `use_flash_attention=True`) - check its separate [installation guide](#installing-flash-attention)
94
+ - **NumPy** should be installed too: `pip install numpy`
95
+
96
+ > ### Installing Flash Attention
97
+ > Installing `flash-attn` can be very frustrating and may take hours (with the standard method), only to result in some incompatibility
98
+ > error. Fortunately, prebuilt versions can be downloaded from GitHub and installed in just seconds. However, you should choose
99
+ > a compatible version based on:
100
+ > - Python version
101
+ > - CUDA version
102
+ > - PyTorch version (2.7 is currently not supported)
103
+ > - ABI
104
+ >
105
+ > #### Steps
106
+ > 1. Choose your version from [https://github.com/Dao-AILab/flash-attention/releases](https://github.com/Dao-AILab/flash-attention/releases)
107
+ > 2. Download the prebuilt release, for example: `wget https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp312-cp312-linux_x86_64.whl`
108
+ > 3. Install it, for example: `pip install --no-dependencies --upgrade flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp312-cp312-linux_x86_64.whl`
109
+ > 4. Verify: `flash_attn.__version__` (an incorrect version will cause an error when importing)
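+ >
+ > A quick check (a minimal sketch, assuming the wheel above was installed into the active environment):
+ > ```python
+ > import flash_attn
+ >
+ > # should print the installed version, e.g. 2.7.4.post1; an error here means an incompatible build
+ > print(flash_attn.__version__)
+ > ```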
110
+ >
111
+ > #### Note on `use_flash_attention` option in models/layers
112
+ > The explicit `use_flash_attention` option enables direct calls to `flash_attn_func` without using **PyTorch** `scaled_dot_product_attention`. Even
113
+ > if it's set to `False`, when the `flash-attn` library is installed, **PyTorch** will try to use it implicitly through the _SDPA backend_. It's better to set it
114
+ > to `False` and let it be used automatically, because of better compatibility. The explicit option could be used for research.
115
+
116
+ ## Modules
117
+ The **RxNN** framework has multiple modules with models, layers, and training and inference tools, made for complete development
118
+ of _reactive models_; they can also be used for regular **Transformers**.
119
+
120
+ ### Transformers
121
+ The Transformers module includes classes for models and layers. It includes **Reactive Transformers** as well as **Classic Transformers**.
122
+
123
+ Submodules:
124
+ - `rxnn.transformers.attention` - basic, most common attention layers - `MultiHeadAttention`, `GroupedQueryAttention` and `MultiQueryAttention`
125
+ - additional attention layers, especially `SparseQueryAttention`, can be found in the `rxnn.experimental.attention` module
126
+ - `SparseQueryAttention` will be moved to `rxnn.transformers.attention` in 0.3.x version
127
+ - `rxnn.transformers.positional` - positional encoding layers - `RotaryPositionalEmbedding` and legacy ones - `AbsolutePositionalEmbedding`/`RelativePositionalEmbedding`
128
+ - `rxnn.transformers.ff` - dense feed forward layers, including gated layers (_SwiGLU_, etc.) - `FeedForward` & `GatedFeedForward` (recommended)
129
+ - `rxnn.transformers.moe` - Mixture-of-Experts feed forward layers - `MoeFeedForward` & `GatedMoeFeedForward` (recommended)
130
+ - `rxnn.transformers.layers` - complete reactive/classic transformer layers - `ReactiveTransformerLayer` & `ClassicTransformerLayer`
131
+ - `rxnn.transformers.models` - reactive/classic transformer models - `ReactiveTransformerEncoder`, `ReactiveTransformerDecoder` & `ClassicTransformerEncoder`, `ClassicTransformerDecoder`
132
+ - `rxnn.transformers.sampler` - samplers for reactive models (the Sampler is an integral part of reactive architectures) - `Sampler`, `SampleDecoder`, `BatchSampler` & `BatchSampleDecoder`
133
+
134
+ In **RxNN**, models are initialized in a declarative style by class composition, but then they are wrapped in imperative classes
135
+ to be compatible with the HuggingFace **JSON** config. For example:
136
+ ```python
137
+ from typing import TypedDict
138
+ import torch
139
+ import torch.nn as nn
140
+ from huggingface_hub import PyTorchModelHubMixin
141
+ from rxnn.transformers.attention import GroupedQueryAttention
142
+ from rxnn.transformers.positional import RotaryPositionalEmbedding
143
+ from rxnn.transformers.layers import ReactiveTransformerLayer
144
+ from rxnn.transformers.models import ReactiveTransformerDecoder
145
+ from rxnn.memory.stm import ShortTermMemory
146
+
147
+ class YourReactiveTransformerConfig(TypedDict):
148
+ num_layers: int
149
+ vocab_size: int
150
+ embed_dim: int
151
+ ff_dim: int
152
+ att_heads: int
153
+ seq_len: int
154
+ stm_size: int
155
+ att_groups: int
156
+ cross_att_groups: int
157
+
158
+
159
+ class YourReactiveTransformerDecoder(nn.Module, PyTorchModelHubMixin):
160
+ def __init__(
161
+ self,
162
+ config: YourReactiveTransformerConfig,
163
+ **kwargs
164
+ ):
165
+ super(YourReactiveTransformerDecoder, self).__init__(**kwargs)
166
+
167
+ embedding = nn.Embedding(config['vocab_size'], config['embed_dim'])
168
+ rope = RotaryPositionalEmbedding(config['embed_dim'] // config['att_heads'], config['seq_len'])
169
+ stm = ShortTermMemory(config['num_layers'], config['embed_dim'], config['stm_size'])
170
+
171
+ self.model = ReactiveTransformerDecoder(
172
+ stm=stm,
173
+ embedding=embedding,
174
+ own_layers=nn.ModuleList([
175
+ ReactiveTransformerLayer(
176
+ config['embed_dim'],
177
+ config['ff_dim'],
178
+ use_gated=True,
179
+ use_moe=False,
180
+ ff_activation=nn.GELU(),
181
+ ff_dropout=0.1,
182
+ use_rms_norm=True,
183
+ self_attention=GroupedQueryAttention(
184
+ config['embed_dim'],
185
+ config['att_heads'],
186
+ config['att_groups'],
187
+ rope=rope,
188
+ dropout=0.1,
189
+ max_seq_len=config['seq_len'],
190
+ is_causal=True,
191
+ ),
192
+ memory_cross_attention=GroupedQueryAttention(
193
+ config['embed_dim'],
194
+ config['att_heads'],
195
+ config['cross_att_groups'] if 'cross_att_groups' in config else config['att_groups'],
196
+ rope=rope,
197
+ dropout=0.1,
198
+ max_seq_len=config['seq_len'],
199
+ is_causal=False,
200
+ rope_only_for_query=True
201
+ ),
202
+ ) for _ in range(config['num_layers'])
203
+ ])
204
+ )
205
+
206
+ def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None):
207
+ return self.model(x, attention_mask=attention_mask)
208
+ ```
209
+
210
+ #### RxT-Alpha
211
+ `RxTAlphaEncoder` and `RxTAlphaDecoder` are ready-to-use **Reactive Transformer** components, compatible with the Hugging Face
212
+ Hub (the above example is based on their code), so they can be used instead of creating a custom class. Example usage can
213
+ be found in the [pre-training docs](#pre-training)
214
+
215
+ ### Memory
216
+ The _memory_ module includes **Short-Term Memory (STM)** and layers responsible for its update. In future versions it will also
217
+ include **Long-Term Memory (LTM)**.
218
+
219
+ #### Short Term Memory
220
+ The main `ShortTermMemory` class is located in `rxnn.memory.stm` module. As described in [Reactive Transformer research docs](https://github.com/RxAI-dev/RxNN/blob/main/docs/research/ReactiveTransformer/reactive-transformer.md),
221
+ each transformer (encoder and decoder) layer has its own **STM** layer of shape `[batch_size, stm_size, embed_dim]`. Initially,
222
+ for the first training stages (pre-training and supervised fine-tuning), **STM** is in "single/no batch" mode (`batch_size = 1`),
223
+ because it's not used. For reinforcement learning stages (**MRL/RxRLHF/BRL**), we have to switch short-term memory to batch
224
+ mode, because items in batches are independent. After training, we could switch back to "single/no batch" mode. Example:
225
+ ```python
226
+ from rxnn.memory.stm import ShortTermMemory
227
+
228
+ num_layers = 10
229
+ stm_size = 256
230
+ embed_dim = 128
231
+ batch_size = 32
232
+
233
+ # 1. Init STM
234
+ stm = ShortTermMemory(
235
+ num_layers, embed_dim, stm_size,
236
+ init_type='normal' # memory init type, 'normal' is default and means normal distribution with 0.0 mean and 0.02 std
237
+ )
238
+
239
+ # 2. Set "batch" mode for MRL
240
+ stm.batched_memory(
241
+ batch_size,
242
+ init_type='standard' # init type could be changed for batch mode, 'standard' is normal distribution with 0.0 mean and 1.0 std
243
+ )
244
+
245
+ # 3. Reset STM with optional init type change
246
+ stm.reset(init_type='uniform') # init type could be also 'ones' or 'zeros', but it's not recommended
247
+
248
+ # 4. Back to "single" mode for inference (optionally using mean value from batch)
249
+ stm.single_memory(
250
+ init_type='standard', # we could change init type again
251
+ use_mean_from_batch=True # use mean values from batch as new memory
252
+ )
253
+ ```
254
+
255
+ > ##### Other utils
256
+ > `ShortTermMemory` can also be resized with the `stm.resize(new_stm_size, init_type)` method, detached and cloned
257
+ > with `stm.clone_detach_reset()` (used in MRL), or made trainable (experimental option):
258
+ > - it can be initialized as trainable - `stm = ShortTermMemory(num_layers, embed_dim, stm_size, is_trainable=True)`
259
+ > - switched to trainable - `stm.make_trainable()`
260
+ > - and switched back to a buffer - `stm.freeze()`
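+ >
+ > A short sketch of these utils (assuming the `stm` instance from the example above):
+ > ```python
+ > stm.resize(512, 'normal')   # resize memory to 512 slots per layer, re-initialized with normal distribution
+ > stm.make_trainable()        # switch memory to trainable parameters (experimental)
+ > stm.freeze()                # and back to a non-trainable buffer
+ > stm.clone_detach_reset()    # detach, clone and reset memory state (used in MRL)
+ > ```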
261
+
262
+ #### Memory Attention Network
263
+ The **Memory Attention Network** is responsible for updating the memory layers. It includes memory attention layers, with normalization
264
+ and residual connections (with an optional gated residual). The **Memory Attention Network** should have the same number of layers
265
+ as the other components (encoder & decoder). It takes the results from each encoder layer (hidden states) and combines them
266
+ with the actual memory state.
267
+
268
+ You can create your own **Memory Attention Network**, integrated with the **HuggingFace Hub**, the same way as reactive transformers:
269
+ ```python
270
+ from typing import TypedDict
271
+ import torch
272
+ import torch.nn as nn
273
+ from huggingface_hub import PyTorchModelHubMixin
274
+ from rxnn.transformers.attention import GroupedQueryAttention
275
+ from rxnn.transformers.positional import RotaryPositionalEmbedding
276
+ from rxnn.memory.stm import ShortTermMemory
277
+ from rxnn.memory.attention import StmMemoryAttention
278
+
279
+ class YourMemoryAttentionConfig(TypedDict):
280
+ num_layers: int
281
+ vocab_size: int
282
+ embed_dim: int
283
+ ff_dim: int
284
+ att_heads: int
285
+ seq_len: int
286
+ stm_size: int
287
+ att_groups: int
288
+
289
+ class YourMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2.0"):
290
+ """RxT-Alpha (Reactive Transformer) memory attention model"""
291
+
292
+ def __init__(
293
+ self,
294
+ config: YourMemoryAttentionConfig,
295
+ **kwargs,
296
+ ):
297
+ super(YourMemoryAttention, self).__init__(**kwargs)
298
+
299
+ rope = RotaryPositionalEmbedding(config['embed_dim'] // config['att_heads'], config['seq_len'])
300
+ # This separately initialized STM will be replaced by shared instance with `load_shared_memory` call
301
+ stm = ShortTermMemory(config['num_layers'], config['embed_dim'], config['stm_size'])
302
+
303
+ self.model = StmMemoryAttention(
304
+ stm,
305
+ attention_layers=nn.ModuleList([
306
+ GroupedQueryAttention(
307
+ config['embed_dim'],
308
+ config['att_heads'],
309
+ config['att_groups'],
310
+ rope=rope,
311
+ dropout=0.1,
312
+ is_causal=False,
313
+ rope_only_for_keys=True
314
+ ) for _ in range(config['num_layers'])
315
+ ]),
316
+ memory_norm_layers=nn.ModuleList([
317
+ nn.RMSNorm(config['embed_dim']) for _ in range(config['num_layers'])
318
+ ]),
319
+ use_gated_residual=False, # memory attention residual gate
320
+ per_slot_gate=False, # gate per memory slot, otherwise it's per layer
321
+ init_gate=None, # initial value for gate weights
322
+ use_dynamic_gate=False, # dynamic gate calculated from weights and memory state, otherwise it's calculated only from weights
323
+ use_tanh_gate=False, # use tanh gate, otherwise it's sigmoid
324
+ )
325
+
326
+ def load_shared_memory(self, stm: ShortTermMemory):
327
+ self.model.stm = stm
328
+
329
+ def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
330
+ return self.model(x, attention_mask=attention_mask)
331
+ ```
332
+
333
+ > #### Gated residual
334
+ > An optional gated residual can be used to improve Memory Attention expressiveness. It uses a gate (sigmoid or tanh)
335
+ > with trainable weights to decide how much information from the old and the newly updated memory state should be stored. Depending
336
+ > on params, weights are declared per layer or per memory slot (more expressive). It can work in two modes, which can
337
+ > be switched, because they use the same weights shape:
338
+ > - static - gate values calculated only from weights (`gate = torch.sigmoid(weights)`) - enable explicit control with `init_gate` param
339
+ > - dynamic - gate values calculated from weights and updated memory state (`gate = torch.sigmoid(weights * (new_layer_stm + layer_stm).mean(dim=-1, keepdim=True))`)
340
+ >
341
+ > Depending on the `use_tanh_gate` param, the final gated residual connection is calculated with different formulas:
342
+ > - sigmoid gate - `stm[i] = layer_gate * new_layer_stm + (1 - layer_gate) * layer_stm`
343
+ > - tanh gate - `stm[i] = (1 + layer_gate) * new_layer_stm + (1 - layer_gate) * layer_stm`
344
+ > - the tanh gate preserves the residual connection scale, while the sigmoid gate result is equivalent to `(new_layer_stm + layer_stm) / 2`
345
+ >
346
+ > **Gated residual** is currently being tested - we are not sure if it will provide better results, so **it's not recommended**
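+ >
+ > A minimal sketch of the gate formulas above in plain PyTorch (illustrating the math, not the `StmMemoryAttention` internals; the shapes are hypothetical):
+ > ```python
+ > import torch
+ >
+ > stm_size, embed_dim = 256, 128
+ > layer_stm = torch.randn(stm_size, embed_dim)      # old memory state for one layer
+ > new_layer_stm = torch.randn(stm_size, embed_dim)  # updated state from memory attention
+ > weights = torch.zeros(stm_size, 1)                # per-slot gate weights (per-layer mode would use a single scalar)
+ >
+ > # static sigmoid gate - calculated only from weights
+ > layer_gate = torch.sigmoid(weights)
+ > updated_stm = layer_gate * new_layer_stm + (1 - layer_gate) * layer_stm
+ >
+ > # dynamic tanh gate - calculated from weights and the memory states
+ > layer_gate = torch.tanh(weights * (new_layer_stm + layer_stm).mean(dim=-1, keepdim=True))
+ > updated_stm = (1 + layer_gate) * new_layer_stm + (1 - layer_gate) * layer_stm
+ > ```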
347
+
348
+ ##### RxT-Alpha Memory Attention
349
+ `RxTAlphaMemoryAttention` is a ready-to-use Memory Attention network for the **Reactive Transformer** Proof-of-Concept, which
351
+ can be used instead of creating a custom class. An example is in the [Memory Reinforcement Learning docs](#memory-reinforcement-learning)
351
+
352
+ ### Training
353
+ The training module includes **Trainers** for different training stages of reactive models and shared training utils.
354
+
355
+ Submodules:
356
+ - `rxnn.training.tokenizer` - custom Trainer for **HuggingFace** `tokenizers` and utils to load a tokenizer from the Hub
357
+ - a tokenizer can be loaded from the Hub with `load_tokenizer_from_hf_hub(repo_id)`
358
+ - `rxnn.training.dataset` - datasets for different training stages:
359
+ - `MaskedLMDataset` & `AutoregressiveLMDataset` are made for base models pre-training
360
+ - `EncoderSftDataset` & `DecoderSftDataset` are made for Interaction Supervised Fine-Tuning for reactive models
361
+ - `MrlCurriculumDataset` is the dataset for single MRL Curriculum step
362
+ - `MrlDatasets` is wrapping MRL datasets for all curriculum steps
363
+ - each dataset has a `from_hf_hub` class method to load the dataset from the Hub
364
+ - they also have a `concat_from_hf_hub` class method to load multiple Hub datasets into a single training dataset
365
+ - if a dataset has no validation/test split, each dataset has a `get_subset(subset_size, from_start=False)` method - it
366
+ returns a new subset and modifies the existing one - e.g. `valid_dataset = train_dataset.get_subset(0.1)`
367
+ - for concatenated datasets, a validation/test split can be created with `concat_from_hf_hub_with_subset` - it cuts the
368
+ same percentage from each loaded dataset
369
+ - each dataset has a `pre_tokenize` method to tokenize all items before the training (otherwise they are tokenized
370
+ dynamically on item access). It's recommended for smaller datasets (fine-tuning, MRL, etc.) and not recommended for
371
+ very big datasets (pre-training), because it uses a lot of (CPU) RAM - see the short example after this list
372
+ - `rxnn.training.callbacks` contains Trainer callbacks for different kinds of utils (more info below)
373
+ - `rxnn.training.scheduler` includes the learning rate scheduler for training
374
+ - `rxnn.training.bml` - Base Model Learning module with Trainers for pre-training and fine-tuning
375
+ - `rxnn.training.mrl` - Memory Reinforcement Learning module with Trainers for MRL
376
+ - `rxnn.training.rxrlhf` - Reinforcement Learning from Human Feedback for Reactive Models module (from 0.3.x)
377
+ - `rxnn.training.brl` - Behavioral Reinforcement Learning module (Reactor / from 0.7.x)
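+
+ A short sketch of the dataset utilities described above (the dataset and tokenizer repo ids are placeholders):
+ ```python
+ from rxnn.training.dataset import MaskedLMDataset
+ from rxnn.training.tokenizer import load_tokenizer_from_hf_hub
+
+ # load a tokenizer and a Hub dataset for MLM training (placeholder repo ids)
+ tokenizer = load_tokenizer_from_hf_hub('your-org/your-tokenizer')
+ train_dataset = MaskedLMDataset.from_hf_hub('your-org/your-dataset', tokenizer=tokenizer, max_seq_len=256)
+
+ # no validation split? cut 10% from the training dataset (modifies train_dataset in place)
+ valid_dataset = train_dataset.get_subset(0.1)
+
+ # optionally pre-tokenize smaller datasets, instead of tokenizing dynamically on item access
+ train_dataset.pre_tokenize(verbose=True, log_interval=5000)
+ valid_dataset.pre_tokenize(verbose=True, log_interval=1000)
+ ```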
378
+
379
+ #### Base Model Learning
380
+ The **Base Model Learning (BML)** module is made for both pre-training and fine-tuning of base models that will be used as components
382
+ in reactive models. Generally, the only two differences between pre-training and supervised fine-tuning are the dataset
383
+ classes and the trainer/callback hyperparameter configs.
383
+
384
+ Reactive models are based on a transformer decoder and encoder, with shared embeddings and memory layers - this requires special
386
+ handling in the first training stages:
387
+ - layers connected with memory - **Memory Cross-Attention** - are frozen during pre-training and fine-tuning, and they are
388
+ skipped by residual connections
389
+ - as the encoder is able to learn slightly better embeddings, because of bidirectional modelling, it's pre-trained first, then the
390
+ decoder is trained with frozen embeddings from the encoder
391
+ - in **Reactive Transformer** fine-tuning, both encoder and decoder are fit to the interaction format (single query and answer); the
392
+ training order is the same as for pre-training
393
+ - in the **Preactor** architecture there are 2 encoders and a single decoder. Encoders are fine-tuned from a single pre-trained
394
+ encoder - the first one processes only queries and the second one only the answers. More info soon
395
+ - in the **Reactor** architecture there are 2 encoders and 2 decoders. Both encoders and decoders are fine-tuned from a single
396
+ pre-trained encoder and decoder. Each component is fine-tuned for its specific task. More info soon
396
+
397
+ ##### Pre-training
398
+ We have to start with importing the required modules/libraries, initializing the models and loading the tokenizer - I will
400
+ use the _RxT-Alpha-Micro-Plus_ models as an example:
400
+ ```python
401
+ import torch
402
+ from rxnn.rxt.models import RxTAlphaDecoder, RxTAlphaEncoder
403
+ from rxnn.training.dataset import AutoregressiveLMDataset, MaskedLMDataset
404
+ from rxnn.training.bml import AutoregressiveTrainer, MLMTrainer
405
+ from rxnn.training.models import MLMHead, MLMTrainingModel
406
+ from rxnn.training.scheduler import get_transformer_lr_scheduler, calculate_steps
407
+ from rxnn.training.callbacks import PrintLossCallback, PrintAccuracyCallback, TokenCounterCallback, ModelSaveCallback, JointModelSaveCallback
408
+ from rxnn.training.tokenizer import load_tokenizer_from_hf_hub
409
+ from rxnn.utils import set_random_seed, cache_clean
410
+
411
+ embed_dim = 128
412
+ vocab_size = 7_500
413
+ seq_len = 256
414
+
415
+ set_random_seed(42)
416
+
417
+ config = {
418
+ 'num_layers': 10,
419
+ 'vocab_size': vocab_size,
420
+ 'embed_dim': embed_dim,
421
+ 'att_heads': 16, # attention heads, in SQA it's used only for dimension split
422
+ 'att_groups': 8, # key/value groups for GQA/SQA
423
+ 'seq_len': seq_len,
424
+ 'stm_size': seq_len,
425
+ 'use_flash_attention': False, # explicitly use flash-attn function (otherwise it's used through PyTorch backend) - not recommended
426
+ 'use_gated': True, # use Gated Linear Units in feed forward, True by default
427
+ 'ff_activation': 'silu', # feed forward activation, 'silu' is default for SwiGLU layers
428
+ 'ff_dropout': 0.1,
429
+ 'self_att_type': 'sqa', # self attention could be 'sqa', 'gqa', 'mqa' or 'mha'
430
+ 'cross_att_type': 'sqa', # memory cross-attention could be 'sqa', 'gqa', 'mqa' or 'mha'
431
+ 'att_query_groups': 8, # query groups for SQA
432
+ }
433
+
434
+ encoder_config = {
435
+ 'ff_dim': 384,
436
+ **config
437
+ }
438
+
439
+ decoder_config = {
440
+ 'ff_dim': 256,
441
+ 'use_moe': True, # use Mixture-of-Experts feed forward
442
+ 'num_experts': 20, # number of experts
443
+ 'moe_top_k': 4, # number of activated experts (per token)
444
+ **config
445
+ }
446
+
447
+ encoder = RxTAlphaEncoder(**encoder_config)
448
+ decoder = RxTAlphaDecoder(**decoder_config)
449
+ head = MLMHead(embed_dim, vocab_size)
450
+
451
+ # Tokenizer is the same for encoder and decoder
452
+ tokenizer = load_tokenizer_from_hf_hub('ReactiveAI/RxT-Alpha-Micro-Plus-Encoder', token='HF_TOKEN')
453
+ ```
454
+ Then, we have to load MLM datasets, set callbacks and run encoder training:
455
+ ```python
456
+ # 1. Load datasets
457
+ load_kwargs = {
458
+ 'trust_remote_code': True
459
+ }
460
+
461
+ train_dataset = MaskedLMDataset.from_hf_hub('roneneldan/TinyStories', load_kwargs=load_kwargs, tokenizer=tokenizer, max_seq_len=seq_len)
462
+ valid_dataset = MaskedLMDataset.from_hf_hub('roneneldan/TinyStories', split="validation", load_kwargs=load_kwargs, tokenizer=tokenizer, max_seq_len=seq_len)
463
+
464
+ # 2. Select device
465
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
466
+
467
+ # 3. Clean GPU cache (optional)
468
+ cache_clean()
469
+
470
+ # 4. Set training config variables
471
+ batch_size = 256
472
+ epochs = 8
473
+ gradient_acc_steps = 1
474
+ peak_lr = 1e-3 * gradient_acc_steps
475
+
476
+ # 5. Get number of steps for scheduler
477
+ steps_config = calculate_steps(len(train_dataset), epochs, batch_size, warmup_ratio=0.05, verbose=True)
478
+ steps_per_epoch, total_steps, warmup_steps = steps_config['epoch'], steps_config['total'], steps_config['warmup']
479
+
480
+ # 6. Freeze memory cross-attention layers
481
+ encoder.freeze_memory()
482
+
483
+ # 7. Select directory for TensorBoard logs
484
+ logs_dir = './micro/tensorboard_logs/encoder-plus-sft'
485
+
486
+ # 8. Basic callbacks - print loss, accuracy and number of processed tokens
487
+ print_cb = PrintLossCallback(batches_per_epoch=steps_per_epoch)
488
+ count_cb = TokenCounterCallback(3_000_000_000)
489
+ acc_cb = PrintAccuracyCallback()
490
+
491
+ # 9. Joint model save callback - used to save encoder and MLM head, and push them to HuggingFace Hub
492
+ save_cb = JointModelSaveCallback(
493
+ './micro/encoder-plus-sft',
494
+ push_to_hub=True,
495
+ hub_model_decoder=None,
496
+ hub_model_encoder='Your encoder model id',
497
+ hub_model_head='Your mlm model id',
498
+ push_checkpoint_weights=True, # push epoch checkpoints to hub
499
+ final_commit_message='Final commit message',
500
+ private_repo=False, # use HF private repository
501
+ save_checkpoint_after_n_batches=1000, # save model after N batches in epoch (batch checkpoint)
502
+ push_batch_checkpoint=True, # push batch checkpoints to HF Hub
503
+ mlm_mode=True, # use MLM mode
504
+ hf_token='HF_TOKEN',
505
+ use_ddp=False, # use distributed training mode
506
+ )
507
+
508
+ # 10. Init training model - encoder + head
509
+ model = MLMTrainingModel(encoder, head)
510
+
511
+ # 11. Init MLM Trainer
512
+ trainer = MLMTrainer(
513
+ model,
514
+ device,
515
+ dataset=train_dataset,
516
+ validation_dataset=valid_dataset,
517
+ vocab_size=vocab_size,
518
+ callbacks=[print_cb, acc_cb, count_cb, save_cb],
519
+ use_amp=True, # use autocast
520
+ dtype=torch.bfloat16, # data type for training
521
+ log_dir=logs_dir,
522
+ use_ddp=False, # use distributed training mode
523
+ )
524
+
525
+ # 12. Init optimizer and cosine annealing scheduler
526
+ optimizer = torch.optim.AdamW(model.parameters(), lr=peak_lr, weight_decay=0.02)
527
+ scheduler = get_transformer_lr_scheduler(
528
+ optimizer,
529
+ warmup_steps=warmup_steps,
530
+ num_training_steps=total_steps
531
+ )
532
+
533
+ # 13. Run the training for the selected number of epochs
534
+ trainer(epochs=epochs, batch_size=batch_size, optimizer=optimizer, scheduler=scheduler)
535
+ ```
536
+ After the encoder's training, we have to train the decoder:
537
+ ```python
538
+ # 1. Load datasets
539
+ load_kwargs = {
540
+ 'trust_remote_code': True
541
+ }
542
+
543
+ train_dataset = AutoregressiveLMDataset.from_hf_hub('roneneldan/TinyStories', load_kwargs=load_kwargs, tokenizer=tokenizer, max_seq_len=seq_len)
544
+ valid_dataset = AutoregressiveLMDataset.from_hf_hub('roneneldan/TinyStories', split="validation", load_kwargs=load_kwargs, tokenizer=tokenizer, max_seq_len=seq_len)
545
+
546
+ # 2. Load shared embedding and memory, then freeze embedding and memory cross-attention
547
+ decoder.load_shared_embedding(encoder.model.embedding)
548
+ decoder.load_shared_memory(encoder.model.stm)
549
+
550
+ decoder.model.embedding.requires_grad_(False)
551
+ decoder.freeze_memory()
552
+
553
+ # 3. Clean GPU cache (optional)
554
+ cache_clean()
555
+
556
+ # 4. Set training config variables
557
+ batch_size = 256
558
+ epochs = 8
559
+ gradient_acc_steps = 1
560
+ peak_lr = 1e-3 * gradient_acc_steps
561
+
562
+ # 5. Get number of steps for scheduler
563
+ steps_config = calculate_steps(len(train_dataset), epochs, batch_size, warmup_ratio=0.05, verbose=True)
564
+ steps_per_epoch, total_steps, warmup_steps = steps_config['epoch'], steps_config['total'], steps_config['warmup']
565
+
566
+ # 6. Select directory for TensorBoard logs
567
+ logs_dir = './micro/tensorboard_logs/decoder-plus-sft'
568
+
569
+ # 7. Basic callbacks - print loss, accuracy and number of processed tokens
570
+ print_cb = PrintLossCallback(batches_per_epoch=steps_per_epoch)
571
+ count_cb = TokenCounterCallback(5_000_000_000)
572
+ acc_cb = PrintAccuracyCallback()
573
+
574
+ # 8. Model save callback - used to save decoder and push it to HuggingFace Hub
575
+ save_cb = ModelSaveCallback(
576
+ './micro/decoder-plus-sft',
577
+ push_to_hub=True,
578
+ hub_model_id='Your decoder model id',
579
+ push_checkpoint_weights=True, # push epoch checkpoints to hub
580
+ final_commit_message='Final commit message',
581
+ private_repo=False, # use HF private repository
582
+ save_checkpoint_after_n_batches=1000, # save model after N batches in epoch (batch checkpoint)
583
+ push_batch_checkpoint=True, # push batch checkpoints to HF Hub
584
+ hf_token='HF_TOKEN',
585
+ use_ddp=False, # use distributed training mode
586
+ )
587
+
588
+ # 9. Init Autoregressive Trainer
589
+ trainer = AutoregressiveTrainer(
590
+ decoder,
591
+ device,
592
+ dataset=train_dataset,
593
+ validation_dataset=valid_dataset,
594
+ vocab_size=vocab_size,
595
+ callbacks=[print_cb, acc_cb, count_cb, save_cb],
596
+ use_amp=True,
597
+ dtype=torch.bfloat16,
598
+ log_dir=logs_dir,
599
+ use_moe_aux_loss=True, # Add MoE Router auxiliary loss to main loss
600
+ moe_aux_loss_scale=0.02, # MoE Router aux loss scale
601
+ use_ddp=False, # use distributed training mode
602
+ )
603
+
604
+ # 10. Init optimizer and cosine annealing scheduler
605
+ optimizer = torch.optim.AdamW(decoder.parameters(), lr=peak_lr, weight_decay=0.02)
606
+ scheduler = get_transformer_lr_scheduler(
607
+ optimizer,
608
+ warmup_steps=warmup_steps,
609
+ num_training_steps=total_steps
610
+ )
611
+
612
+ # 11. Run the training for the selected number of epochs
613
+ trainer(epochs=epochs, batch_size=batch_size, optimizer=optimizer, scheduler=scheduler)
614
+ ```
615
+
616
+ ##### Fine-tuning
617
+ For _**Interaction Supervised Fine-Tuning**_, the code is almost the same as for pre-training, with some small changes.
618
+
619
+ First, we have to load pre-trained models, instead of initializing them with configs:
620
+ ```python
621
+ encoder = RxTAlphaEncoder.from_pretrained('ReactiveAI/RxT-Alpha-Micro-Plus-Encoder', token='HF_TOKEN')
622
+ decoder = RxTAlphaDecoder.from_pretrained('ReactiveAI/RxT-Alpha-Micro-Plus-Decoder', token='HF_TOKEN')
623
+ head = MLMHead.from_pretrained('ReactiveAI/RxT-Alpha-Micro-Plus-MLM', token='HF_TOKEN')
624
+ ```
625
+
626
+ Then, we have to change the dataset loading part. For the encoder:
627
+ ```python
628
+ # 1. Load datasets
629
+ train_dataset = EncoderSftDataset.from_hf_hub('ReactiveAI/TinyStories-Plus-Interaction-SFT', tokenizer=tokenizer, max_seq_len=seq_len)
630
+ valid_dataset = EncoderSftDataset.from_hf_hub('ReactiveAI/TinyStories-Plus-Interaction-SFT', split="validation", tokenizer=tokenizer, max_seq_len=seq_len)
631
+
632
+ # 2. Pre-tokenize dataset with verbose logging (optional)
633
+ train_dataset.pre_tokenize(verbose=True, log_interval=5000)
634
+ valid_dataset.pre_tokenize(verbose=True, log_interval=1000)
635
+ ```
636
+ And the same for decoder:
637
+ ```python
638
+ # 1. Load datasets
639
+ train_dataset = DecoderSftDataset.from_hf_hub('ReactiveAI/TinyStories-Plus-Interaction-SFT', tokenizer=tokenizer, max_seq_len=seq_len)
640
+ valid_dataset = DecoderSftDataset.from_hf_hub('ReactiveAI/TinyStories-Plus-Interaction-SFT', split="validation", tokenizer=tokenizer, max_seq_len=seq_len)
641
+
642
+ # 2. Pre-tokenize dataset with verbose logging (optional)
643
+ train_dataset.pre_tokenize(verbose=True, log_interval=5000)
644
+ valid_dataset.pre_tokenize(verbose=True, log_interval=1000)
645
+ ```
646
+
647
+ We could also add early stoppage callback:
648
+ ```python
649
+ from rxnn.training.callbacks import EarlyStoppageCallback
650
+
651
+ stop_cb = EarlyStoppageCallback(num_plateau_epochs=5)
652
+ ```
653
+
654
+ Additionally, in fine-tuning we will typically use a different config for the number of epochs, steps, learning rate, etc.
655
+
656
+ > #### Classic Transformer Training
657
+ > The same code can also be used to train classic decoder-only or encoder-only transformers; the only difference is
658
+ > that they don't require memory cross-attention freezing.
659
+
660
+ ##### Joint Training
661
+ There are also `JointLMDataset` and `JointLMTrainer` classes to train the encoder and decoder at once. In that case, embeddings
662
+ are updated from both encoder and decoder optimization. However, I noticed some issues with balancing training in that mode,
663
+ so it's **not recommended** for now, until it's tested and fixed.
664
+
665
+ #### Memory Reinforcement Learning
666
+ **Memory Reinforcement Learning (MRL)** is the most important training stage for the reactive model's **Attention-Based Memory System**.
667
+ In this stage we train the model to remember information across multiple interactions, with different curriculum stage
668
+ configs. Theoretical foundations are described in the [research docs](https://github.com/RxAI-dev/RxNN/blob/main/docs/research/ReactiveTransformer/mrl.md).
669
+
670
+ > The **MRL** algorithm is currently being tested and a lot of things could still change!
671
+
672
+ In practice, the algorithm has over 50 hyperparameters, so it requires careful handling. We start by importing modules, loading
673
+ pre-trained models from the SFT stage, and initializing the new Memory Attention, actor and critic models:
674
+ ```python
675
+ import torch
676
+ from rxnn.rxt.models import RxTAlphaDecoder, RxTAlphaEncoder, RxTAlphaMemoryAttention
677
+ from rxnn.training.tokenizer import load_tokenizer_from_hf_hub
678
+ from rxnn.training.dataset import MrlDatasets
679
+ from rxnn.training.models import MrlActorModel, MrlCriticModel
680
+ from rxnn.training.reward import MrlRewardModel
681
+ from rxnn.training.mrl import MRLTrainer, CurriculumConfig, MrlStrategy, MrlConfig
682
+ from rxnn.training.rl import PPOAlgorithm, PPOConfig
683
+ from rxnn.training.callbacks import MrlPrintCallback, MrlEarlyStoppageCallback, MrlModelSaveCallback, MrlGeneratedTokensCallback
684
+ from rxnn.utils import set_random_seed
685
+
686
+ # 1. Set random seed, batch size and embed dim
687
+ set_random_seed(42)
688
+ batch_size = 64
689
+ embed_dim = 128
690
+
691
+ # 2. Get pre-trained microscale PoC models
692
+ decoder = RxTAlphaDecoder.from_pretrained('ReactiveAI/RxT-Alpha-Micro-Plus-Decoder-SFT', token='HF_TOKEN')
693
+ encoder = RxTAlphaEncoder.from_pretrained('ReactiveAI/RxT-Alpha-Micro-Plus-Encoder-SFT', token='HF_TOKEN')
694
+ # 3. Init Memory Attention Network
695
+ mem_attn = RxTAlphaMemoryAttention(
696
+ num_layers=10,
697
+ embed_dim=embed_dim,
698
+ att_heads=8,
699
+ seq_len=256,
700
+ stm_size=256,
701
+ use_flash_attention=False, # explicitly use flash-attn function (otherwise it's used through PyTorch backend)
702
+ norm_type='classic-rms', # memory norm type
703
+ att_groups=4, # key/value groups for SQA/GQA
704
+ att_type='sqa', # attention type, could be 'sqa', 'gqa', 'mqa' or 'mha'
705
+ att_query_groups=4, # query groups for SQA
706
+ )
707
+
708
+ # 4. Load shared embedding and memory from encoder to other models
709
+ decoder.load_shared_embedding(encoder.model.embedding)
710
+ encoder.model.stm.batched_memory(batch_size=batch_size, init_type='standard')
711
+ decoder.load_shared_memory(encoder.model.stm)
712
+ mem_attn.load_shared_memory(encoder.model.stm)
713
+
714
+ # 5. Init Actor model
715
+ actor = MrlActorModel(encoder, decoder, mem_attn)
716
+
717
+ # 6. Get pre-trained encoder, extend its context size, freeze memory and use as a body for Critic model
718
+ critic_encoder = RxTAlphaEncoder.from_pretrained('ReactiveAI/RxT-Alpha-Micro-Plus-Encoder-SFT', token='HF_TOKEN')
719
+
720
+ critic_encoder.update_max_len(512)
721
+ critic_encoder.freeze_memory()
722
+ # 7. Init Critic model
723
+ critic = MrlCriticModel(critic_encoder, embed_dim)
724
+ ```
725
+
726
+ Then, we have to load the tokenizer and MRL datasets, and create the _curriculum config_:
727
+ ```python
728
+ # 1. Load tokenizer
729
+ tokenizer = load_tokenizer_from_hf_hub('ReactiveAI/RxT-Alpha-Micro-Plus-Decoder', token='HF_TOKEN')
730
+
731
+ # 2. Load PoC TinyStories based MRL Dataset, starting from 4 steps to 16 in long range, and pre-tokenize it
732
+ mrl_datasets = MrlDatasets.from_hf_hub(
733
+ 'ReactiveAI/TinyStories-MRL',
734
+ tokenizer,
735
+ mrl_curriculum_steps=[
736
+ { 'subset_name': 'steps-4', 'steps': 4, 'is_long_range': False },
737
+ { 'subset_name': 'steps-6', 'steps': 6, 'is_long_range': False },
738
+ { 'subset_name': 'steps-8', 'steps': 8, 'is_long_range': False },
739
+ { 'subset_name': 'steps-8-lr', 'steps': 8, 'is_long_range': True },
740
+ { 'subset_name': 'steps-12', 'steps': 12, 'is_long_range': True },
741
+ { 'subset_name': 'steps-16', 'steps': 16, 'is_long_range': True },
742
+ ],
743
+ eval_split='validation',
744
+ max_seq_len=256,
745
+ )
746
+
747
+ mrl_datasets.pre_tokenize(verbose=True, log_interval=100)
748
+
749
+ # 3. Create curriculum stages config
750
+ curriculum_stages = [CurriculumConfig(
751
+ steps=item['steps'], # number of steps in curriculum stage
752
+ epochs=10 if item['steps'] == 4 else 5, # number of epochs in curriculum stage
753
+ dataset=item['dataset'],
754
+ eval_dataset=item['eval_dataset'],
755
+ callbacks=[
756
+ MrlPrintCallback(), # Print loss/reward callback
757
+ MrlModelSaveCallback(
758
+ './models',
759
+ push_to_hub=True,
760
+ hub_model_critic='Your critic model hub id',
761
+ hub_model_decoder='Your MRL decoder model hub id',
762
+ hub_model_encoder='Your MRL encoder model hub id',
763
+ hub_model_memory_attention='Your memory-attention model hub id',
764
+ private_repo=True,
765
+ hf_token='HF_TOKEN',
766
+ final_commit_message=f"MRL steps: {item['steps']} {'lr' if item['is_long_range'] else ''}",
767
+ push_checkpoint_weights=True,
768
+ ) # MRL Model save callback - save and push to hub critic model and actor components
769
+ ],
770
+ strategy=MrlStrategy.LONG_RANGE_STRATEGY if item['is_long_range'] else MrlStrategy.MULTI_STEP_STRATEGY, # strategy for curriculum stage
771
+ unfreeze_epoch=((2, 2e-5), (4, 8e-5), (6, 1e-5), 8) if item['steps'] == 4 else (0, 1, (2, 1e-6), 4), # unfreeze strategy config
772
+ random_resets=item['steps'] > 4, # enable random memory resets
773
+ random_resets_from=2, # epoch when random resets starts
774
+ random_resets_ratio=0.4 if item['steps'] != 4 else None, # probability of STM reset before episode
775
+ separate_memory_lr=True, # use separate memory LR in current curriculum stage
776
+ memory_lr=6e-4 if item['steps'] == 4 else None, # memory LR for curriculum stage, if None, use global config
777
+ lr=3e-4 if item['steps'] == 4 else None, # model LR for curriculum stage, if None, use global config
778
+ critic_lr=4e-4 if item['steps'] == 4 else None, # critic (head) LR for curriculum stage, if None, use global config
779
+ critic_encoder_lr=2e-4 if item['steps'] == 4 else None, # critic (encoder) LR for curriculum stage, if None, use global config
780
+ teacher_forcing=item['steps'] <= 8, # use teacher forcing - save reference answers from dataset in memory instead of generated ones
781
+ ) for item in mrl_datasets]
782
+ ```
783
+
784
+ After that, we have to configure the reward model. It's based on BLEU scores and cosine similarity between generated answers,
785
+ saved data from previous steps, and reference answers from the dataset. Cosine similarity is also calculated from a running
786
+ mean embedding of previous steps. The reward model also includes an optional length reward. Its config includes a lot of options
787
+ to set different factors for different reward parts.
788
+ ```python
789
+ # 1. Init GPU device
790
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
791
+
792
+ # 2. Create reward model
793
+ reward_model = MrlRewardModel(
794
+ encoder.model.embedding,
795
+ device,
796
+ bleu_with_saved_data=True, # use saved data (previous or first interaction) in BLEU calculation
797
+ reward_len=True, # use length reward in calculation (answer_len / target_len)
798
+ max_rewarded_len=None, # target length awarded as 1.0
799
+ neg_reward_len=True, # negative length reward - lower reward when answer is too long (target_len / answer_len)
800
+ target_len_as_ref=True, # use reference answer len as target
801
+ use_running_mean=True, # use running mean embedding of all previous answers in cosine similarity calculation
802
+ allow_not_summing_factors=False, # if True sum of reward factors could be different from 1.0, it's False by default
803
+ bleu_factor=0.4, # factor for BLEU score in standard reward
804
+ cos_factor=0.5, # factor for cosine similarity score in standard reward
805
+ len_factor=0.1, # factor for length reward score in standard reward
806
+ bleu_ref_factor=0.4, # factor for reference answer score in BLEU calculation (standard mode)
807
+ bleu_saved_factor=0.6, # factor for saved data score in BLEU calculation (standard mode)
808
+ cos_ref_factor=0.35, # factor for reference answer score in cosine sim calculation (standard mode)
809
+ cos_saved_factor=0.65, # factor for saved data score in cosine sim calculation (standard mode)
810
+ multi_cos_ref_factor=0.3, # factor for reference answer in multi-step cosine sim calculation
811
+ multi_cos_saved_factor= 0.5, # factor for saved data in multi-step cosine sim calculation
812
+ multi_cos_running_mean_factor = 0.2, # factor for previous answers running mean in multi-step cosine sim calculation
813
+ neg_bleu_factor=0.45, # factor for BLEU score in negative reward
814
+ neg_cos_factor=0.45, # factor for cosine similarity score in negative reward
815
+ neg_bleu_ref_factor=0.3, # factor for reference answer score in BLEU calculation (negative mode)
816
+ neg_bleu_saved_factor=0.7, # factor for saved data score in BLEU calculation (negative mode)
817
+ neg_cos_ref_factor=0.3, # factor for reference answer score in cosine sim calculation (negative mode)
818
+ neg_cos_saved_factor=0.7, # factor for saved data score in cosine sim calculation (negative mode)
819
+ bleu_ref_weights=(0.2, 0.2, 0.3, 0.3), # weights for n-grams in NLTK BLEU calculation for reference answers
820
+ bleu_saved_weights=(0.2, 0.2, 0.3, 0.3), # weights for n-grams in NLTK BLEU calculation for saved data
821
+ tanh_reward_scale=False, # scale rewards to -1.0 to 1.0 range, instead of standard 0.0-1.0
822
+ rewards_scale=1.0, # rewards scaling factor (reward * rewards_scale)
823
+ )
824
+ ```
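+
+ The factor options above suggest how a single standard-mode reward is composed. A rough illustration (an assumption based on the factor names, not the exact `MrlRewardModel` implementation):
+ ```python
+ # hypothetical per-sample scores in the 0.0-1.0 range
+ bleu_ref, bleu_saved = 0.30, 0.45  # BLEU vs reference answer / saved data
+ cos_ref, cos_saved = 0.60, 0.70    # cosine similarity vs reference answer / saved data
+ len_score = 0.90                   # e.g. answer_len / target_len
+
+ bleu = 0.4 * bleu_ref + 0.6 * bleu_saved           # bleu_ref_factor / bleu_saved_factor
+ cos = 0.35 * cos_ref + 0.65 * cos_saved            # cos_ref_factor / cos_saved_factor
+ reward = 0.4 * bleu + 0.5 * cos + 0.1 * len_score  # bleu_factor / cos_factor / len_factor
+ print(round(reward, 3))  # ~0.58
+ ```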
825
+
826
+ And finally, we can create the MRL Trainer with the RL algorithm (currently only PPO is available) and start the training:
827
+ ```python
828
+ # 1. Init PPO Algorithm
829
+ algorithm = PPOAlgorithm(
830
+ PPOConfig(clip_eps=0.2, gae_lambda=0.95, gae_gamma=0.99, entropy_coef=0.01, critic_value_clip=50.0)
831
+ )
832
+
833
+ # 2. Create config for MRLTrainer (most of MrlConfig fields could be overwritten in each curriculum stage)
834
+ mrl_config = MrlConfig(
835
+ lr=1e-4, # main LR, used for decoder layers
836
+ encoder_lr=2e-4, # encoder LR, used for encoder layers (if None, lr is used)
837
+ critic_lr=2e-4, # critic LR, used for critic value head
838
+ critic_encoder_lr=1e-4, # critic encoder LR (if not set, critic_lr is used)
839
+ separate_memory_lr=True, # use separate LR for memory attention and memory cross-attention
840
+ encoder_memory_lr=5e-4, # LR for encoder memory cross-attention (if None, memory_lr is used)
841
+ memory_lr=3e-4, # memory LR, used for decoder memory cross-attention
842
+ memory_attn_lr=5e-4, # memory attention LR (if None, memory_lr is used)
843
+ max_seq_len=256, # maximum length of single interaction
844
+ critic_max_len=512, # maximum length of critic sequence (have to be longer than actor's context)
845
+ weight_decay=0.01, # weight decay for actor AdamW optimizer
846
+ critic_weight_decay=0.01, # weight decay for critic AdamW optimizer
847
+ update_epochs=10, # inner PPO update epochs
848
+ pad_token_id=0, # tokenizer padding token id
849
+ end_token_id=3, # tokenizer EOS token id
850
+ use_moe_aux_loss=True, # add Mixture-of-Experts Router auxiliary loss to policy loss
851
+ freeze_embeddings=False, # freeze pre-trained embeddings for MRL training
852
+ embedding_lr=5e-6, # LR for embeddings, if not frozen (if None, lr is used)
853
+ use_memory_warmup=False, # memory warmup - update memory with first interaction in no grad mode, before episode, for better initialization
854
+ )
855
+
856
+ # 3. Initialize MRL Trainer
857
+ trainer = MRLTrainer(
858
+ actor, critic, reward_model, device, mrl_config, algorithm,
859
+ use_amp=True, # use autocast in MRL Training
860
+ dtype=torch.bfloat16, # data type for MRL
861
+ use_ddp=False, # use distributed training with DDP
862
+ )
863
+
864
+ # 4. Train with curriculum stages config
865
+ trainer(curriculum_stages, batch_size=batch_size)
866
+ ```
867
+
868
+ ## Experimental attention layers
869
+ While working on reactive architectures, we also developed several new types of attention layers, some of which achieve
870
+ very promising results. Even considering that reactive models, processing single interactions, have much lower computational
871
+ requirements, we need the most efficient attention mechanisms, consistent with memory requirements. Since memory is not a
872
+ sequence but a set, spatial sparsity is probably not a good solution here, so we were looking for an efficient alternative
873
+ to Flex Attention with full access to all memory positions. New attention layers are implemented in `rxnn.experimental.attention`
874
+ module:
875
+ - **Grouped Mixture-of-Experts Attention (GMA)** - uses MoE routing to dynamically select K active key/value heads for each token, instead
876
+ of the static selection in **GQA**. While it's theoretically interesting, in practice it achieved worse results than **GQA**,
877
+ and even **MQA**, in all tests, and is a lot slower because of routing overhead, so we abandoned further research. More details
878
+ in [research docs](https://github.com/RxAI-dev/RxNN/blob/main/docs/research/moe_attention.md)
879
+ - **Deep Mixture-of-Experts Attention (DMA)** - extends **GMA** with the same MoE routing for query heads. Like **GMA**,
880
+ it gives even worse results, and all the computational performance benefits from the sparse query heads (like in
881
+ **SQA**) are lost to routing overhead (lack of specialized kernels for head selection), so further research is also
882
+ abandoned. [Research docs](https://github.com/RxAI-dev/RxNN/blob/main/docs/research/moe_attention.md)
883
+ - **Hierarchical MoE Attention (HMA)** - extends **DMA/GMA**, using a different number of query/key/value heads for tokens with
884
+ different priority. It's only an idea and is not implemented, because of the poor results of GMA/DMA. [More info](https://github.com/RxAI-dev/RxNN/blob/main/docs/research/hierarchical_moe_attention.md)
885
+ - **Sparse Query Attention (SQA)** - the most trivial extension to GQA, reducing not only the number of key/value heads, but
886
+ also the number of query heads. It results in up to a 2-3x faster model (for 32k/131k tokens). **SQA** is the fastest attention
887
+ mechanism for 0-131k sequence lengths; for longer sequences **Flex Attention** becomes faster. That's ideal for reactive models,
888
+ which don't need a million-token context for single interaction processing. In tested cases **SQA** model results (loss/accuracy)
889
+ were close to GQA, and the differences were almost unnoticeable, but it still requires more tests (a conceptual sketch of the idea follows after this list). [Research docs](https://github.com/RxAI-dev/RxNN/blob/main/docs/research/sparse_query_attention.md)
890
+ - **Flex Sparse Query Attention (Flex-SQA)** - **Flex Attention** combined with **SQA** - enables handling 4-8x longer sliding
891
+ windows in a shorter time than base **Flex**, so it should give better results. **Flex-SQA** should be the fastest
892
+ attention mechanism for sequences longer than 131k tokens and is made for classic transformers, or potentially self-attention
893
+ in bigger reactive models. Currently, it's viable only with symmetric variants of **SQA** (the same number of used query
894
+ and key/value heads), because the kernels aren't compatible with GQA in sliding windows, and non-symmetric variants are 2x slower
895
+ than they should be. Docs and tests in progress
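+
+ A conceptual sketch of the **SQA** idea in plain PyTorch (an illustration of reducing both query and key/value heads, not the actual `SparseQueryAttention` implementation from `rxnn.experimental.attention`):
+ ```python
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ class NaiveSparseQueryAttention(nn.Module):
+     def __init__(self, embed_dim: int, num_heads: int, query_groups: int, kv_groups: int):
+         super().__init__()
+         self.head_dim = embed_dim // num_heads  # num_heads is used only for the dimension split
+         self.query_groups = query_groups        # number of query heads actually computed
+         self.kv_groups = kv_groups              # number of key/value heads actually computed
+         self.q_proj = nn.Linear(embed_dim, query_groups * self.head_dim)
+         self.k_proj = nn.Linear(embed_dim, kv_groups * self.head_dim)
+         self.v_proj = nn.Linear(embed_dim, kv_groups * self.head_dim)
+         self.out_proj = nn.Linear(query_groups * self.head_dim, embed_dim)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         b, t, _ = x.shape
+         q = self.q_proj(x).view(b, t, self.query_groups, self.head_dim).transpose(1, 2)
+         k = self.k_proj(x).view(b, t, self.kv_groups, self.head_dim).transpose(1, 2)
+         v = self.v_proj(x).view(b, t, self.kv_groups, self.head_dim).transpose(1, 2)
+         # repeat key/value heads so each query head has a matching pair (GQA-style grouping)
+         k = k.repeat_interleave(self.query_groups // self.kv_groups, dim=1)
+         v = v.repeat_interleave(self.query_groups // self.kv_groups, dim=1)
+         attn = F.scaled_dot_product_attention(q, k, v)  # fewer heads -> fewer attention score matrices
+         return self.out_proj(attn.transpose(1, 2).reshape(b, t, -1))
+
+ x = torch.randn(2, 64, 128)
+ layer = NaiveSparseQueryAttention(embed_dim=128, num_heads=16, query_groups=8, kv_groups=4)
+ print(layer(x).shape)  # torch.Size([2, 64, 128])
+ ```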
896
+
897
+ ### Test usage
898
+ Experimental attention layers can be tested with the `ExperimentalAttentionTransformer` model from `rxnn.experimental.models`.
899
+ A usage example can be found in our notebooks repository - [RxNN Notebooks](https://github.com/RxAI-dev/rxnn-notebooks)
900
+
901
+ Apache License
902
+ Version 2.0, January 2004
903
+ http://www.apache.org/licenses/
904
+
905
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
906
+
907
+ 1. Definitions.
908
+
909
+ "License" shall mean the terms and conditions for use, reproduction,
910
+ and distribution as defined by Sections 1 through 9 of this document.
911
+
912
+ "Licensor" shall mean the copyright owner or entity authorized by
913
+ the copyright owner that is granting the License.
914
+
915
+ "Legal Entity" shall mean the union of the acting entity and all
916
+ other entities that control, are controlled by, or are under common
917
+ control with that entity. For the purposes of this definition,
918
+ "control" means (i) the power, direct or indirect, to cause the
919
+ direction or management of such entity, whether by contract or
920
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
921
+ outstanding shares, or (iii) beneficial ownership of such entity.
922
+
923
+ "You" (or "Your") shall mean an individual or Legal Entity
924
+ exercising permissions granted by this License.
925
+
926
+ "Source" form shall mean the preferred form for making modifications,
927
+ including but not limited to software source code, documentation
928
+ source, and configuration files.
929
+
930
+ "Object" form shall mean any form resulting from mechanical
931
+ transformation or translation of a Source form, including but
932
+ not limited to compiled object code, generated documentation,
933
+ and conversions to other media types.
934
+
935
+ "Work" shall mean the work of authorship, whether in Source or
936
+ Object form, made available under the License, as indicated by a
937
+ copyright notice that is included in or attached to the work
938
+ (an example is provided in the Appendix below).
939
+
940
+ "Derivative Works" shall mean any work, whether in Source or Object
941
+ form, that is based on (or derived from) the Work and for which the
942
+ editorial revisions, annotations, elaborations, or other modifications
943
+ represent, as a whole, an original work of authorship. For the purposes
944
+ of this License, Derivative Works shall not include works that remain
945
+ separable from, or merely link (or bind by name) to the interfaces of,
946
+ the Work and Derivative Works thereof.
947
+
948
+ "Contribution" shall mean any work of authorship, including
949
+ the original version of the Work and any modifications or additions
950
+ to that Work or Derivative Works thereof, that is intentionally
951
+ submitted to Licensor for inclusion in the Work by the copyright owner
952
+ or by an individual or Legal Entity authorized to submit on behalf of
953
+ the copyright owner. For the purposes of this definition, "submitted"
954
+ means any form of electronic, verbal, or written communication sent
955
+ to the Licensor or its representatives, including but not limited to
956
+ communication on electronic mailing lists, source code control systems,
957
+ and issue tracking systems that are managed by, or on behalf of, the
958
+ Licensor for the purpose of discussing and improving the Work, but
959
+ excluding communication that is conspicuously marked or otherwise
960
+ designated in writing by the owner as "Not a Contribution."
961
+
962
+ "Contributor" shall mean Licensor and any individual or Legal Entity
963
+ on behalf of whom a Contribution has been received by Licensor and
964
+ subsequently incorporated within the Work.
965
+
966
+ 2. Grant of Copyright License. Subject to the terms and conditions of
967
+ this License, each Contributor hereby grants to You a perpetual,
968
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
969
+ copyright license to reproduce, prepare Derivative Works of,
970
+ publicly display, publicly perform, sublicense, and distribute the
971
+ Work and such Derivative Works in Source or Object form.
972
+
973
+ 3. Grant of Patent License. Subject to the terms and conditions of
974
+ this License, each Contributor hereby grants to You a perpetual,
975
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
976
+ (except as stated in this section) patent license to make, have made,
977
+ use, offer to sell, sell, import, and otherwise transfer the Work,
978
+ where such license applies only to those patent claims licensable
979
+ by such Contributor that are necessarily infringed by their
980
+ Contribution(s) alone or by combination of their Contribution(s)
981
+ with the Work to which such Contribution(s) was submitted. If You
982
+ institute patent litigation against any entity (including a
983
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
984
+ or a Contribution incorporated within the Work constitutes direct
985
+ or contributory patent infringement, then any patent licenses
986
+ granted to You under this License for that Work shall terminate
987
+ as of the date such litigation is filed.
988
+
989
+ 4. Redistribution. You may reproduce and distribute copies of the
990
+ Work or Derivative Works thereof in any medium, with or without
991
+ modifications, and in Source or Object form, provided that You
992
+ meet the following conditions:
993
+
994
+ (a) You must give any other recipients of the Work or
995
+ Derivative Works a copy of this License; and
996
+
997
+ (b) You must cause any modified files to carry prominent notices
998
+ stating that You changed the files; and
999
+
1000
+ (c) You must retain, in the Source form of any Derivative Works
1001
+ that You distribute, all copyright, patent, trademark, and
1002
+ attribution notices from the Source form of the Work,
1003
+ excluding those notices that do not pertain to any part of
1004
+ the Derivative Works; and
1005
+
1006
+ (d) If the Work includes a "NOTICE" text file as part of its
1007
+ distribution, then any Derivative Works that You distribute must
1008
+ include a readable copy of the attribution notices contained
1009
+ within such NOTICE file, excluding those notices that do not
1010
+ pertain to any part of the Derivative Works, in at least one
1011
+ of the following places: within a NOTICE text file distributed
1012
+ as part of the Derivative Works; within the Source form or
1013
+ documentation, if provided along with the Derivative Works; or,
1014
+ within a display generated by the Derivative Works, if and
1015
+ wherever such third-party notices normally appear. The contents
1016
+ of the NOTICE file are for informational purposes only and
1017
+ do not modify the License. You may add Your own attribution
1018
+ notices within Derivative Works that You distribute, alongside
1019
+ or as an addendum to the NOTICE text from the Work, provided
1020
+ that such additional attribution notices cannot be construed
1021
+ as modifying the License.
1022
+
1023
+ You may add Your own copyright statement to Your modifications and
1024
+ may provide additional or different license terms and conditions
1025
+ for use, reproduction, or distribution of Your modifications, or
1026
+ for any such Derivative Works as a whole, provided Your use,
1027
+ reproduction, and distribution of the Work otherwise complies with
1028
+ the conditions stated in this License.
1029
+
1030
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
1031
+ any Contribution intentionally submitted for inclusion in the Work
1032
+ by You to the Licensor shall be under the terms and conditions of
1033
+ this License, without any additional terms or conditions.
1034
+ Notwithstanding the above, nothing herein shall supersede or modify
1035
+ the terms of any separate license agreement you may have executed
1036
+ with Licensor regarding such Contributions.
1037
+
1038
+ 6. Trademarks. This License does not grant permission to use the trade
1039
+ names, trademarks, service marks, or product names of the Licensor,
1040
+ except as required for reasonable and customary use in describing the
1041
+ origin of the Work and reproducing the content of the NOTICE file.
1042
+
1043
+ 7. Disclaimer of Warranty. Unless required by applicable law or
1044
+ agreed to in writing, Licensor provides the Work (and each
1045
+ Contributor provides its Contributions) on an "AS IS" BASIS,
1046
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
1047
+ implied, including, without limitation, any warranties or conditions
1048
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
1049
+ PARTICULAR PURPOSE. You are solely responsible for determining the
1050
+ appropriateness of using or redistributing the Work and assume any
1051
+ risks associated with Your exercise of permissions under this License.
1052
+
1053
+ 8. Limitation of Liability. In no event and under no legal theory,
1054
+ whether in tort (including negligence), contract, or otherwise,
1055
+ unless required by applicable law (such as deliberate and grossly
1056
+ negligent acts) or agreed to in writing, shall any Contributor be
1057
+ liable to You for damages, including any direct, indirect, special,
1058
+ incidental, or consequential damages of any character arising as a
1059
+ result of this License or out of the use or inability to use the
1060
+ Work (including but not limited to damages for loss of goodwill,
1061
+ work stoppage, computer failure or malfunction, or any and all
1062
+ other commercial damages or losses), even if such Contributor
1063
+ has been advised of the possibility of such damages.
1064
+
1065
+ 9. Accepting Warranty or Additional Liability. While redistributing
1066
+ the Work or Derivative Works thereof, You may choose to offer,
1067
+ and charge a fee for, acceptance of support, warranty, indemnity,
1068
+ or other liability obligations and/or rights consistent with this
1069
+ License. However, in accepting such obligations, You may act only
1070
+ on Your own behalf and on Your sole responsibility, not on behalf
1071
+ of any other Contributor, and only if You agree to indemnify,
1072
+ defend, and hold each Contributor harmless for any liability
1073
+ incurred by, or claims asserted against, such Contributor by reason
1074
+ of your accepting any such warranty or additional liability.
1075
+
1076
+ END OF TERMS AND CONDITIONS
1077
+
1078
+ APPENDIX: How to apply the Apache License to your work.
1079
+
1080
+ To apply the Apache License to your work, attach the following
1081
+ boilerplate notice, with the fields enclosed by brackets "[]"
1082
+ replaced with your own identifying information. (Don't include
1083
+ the brackets!) The text should be enclosed in the appropriate
1084
+ comment syntax for the file format. We also recommend that a
1085
+ file or class name and description of purpose be included on the
1086
+ same "printed page" as the copyright notice for easier
1087
+ identification within third-party archives.
1088
+
1089
+ Copyright 2024-2025 Adam Filipek
1090
+
1091
+ Licensed under the Apache License, Version 2.0 (the "License");
1092
+ you may not use this file except in compliance with the License.
1093
+ You may obtain a copy of the License at
1094
+
1095
+ http://www.apache.org/licenses/LICENSE-2.0
1096
+
1097
+ Unless required by applicable law or agreed to in writing, software
1098
+ distributed under the License is distributed on an "AS IS" BASIS,
1099
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1100
+ See the License for the specific language governing permissions and
1101
+ limitations under the License.
1102
+