rxnn 0.1.82__py3-none-any.whl → 0.1.83__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rxnn/training/base.py CHANGED
@@ -39,6 +39,8 @@ class BaseTrainer(ABC):
         self.optimizer = optimizer
         self.dataset = dataset
         self.callbacks = callbacks or []
+        if log_dir and not os.path.exists(log_dir):
+            os.makedirs(log_dir)
         self.writer = SummaryWriter(log_dir) if log_dir else None
         self.use_ddp = use_ddp
         self.use_amp = use_amp
rxnn/training/callbacks.py CHANGED
@@ -499,3 +499,24 @@ class JointModelSaveCallback(TrainerCallback):
         if not self.mlm_mode:
             self._save_final(model.decoder, 'decoder', hub_id=self.hub_model_decoder)
         self._save_final(model.mlm_head, 'head', hub_id=self.hub_model_head)
+
+class EarlyStoppageCallback(TrainerCallback):
+    def __init__(self, num_plateau_epochs: int = 3) -> None:
+        super().__init__()
+        self.num_plateau_epochs = num_plateau_epochs
+        self.best_loss = 9999.0
+        self.best_loss_epoch = 0
+
+    def on_validation_end(
+        self,
+        model: torch.nn.Module,
+        epoch: int,
+        val_loss: float,
+        val_metrics: dict
+    ):
+        if val_loss < self.best_loss:
+            self.best_loss = val_loss
+            self.best_loss_epoch = epoch
+        elif epoch - self.best_loss_epoch > self.num_plateau_epochs:
+            return True
+        return None
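
The new `EarlyStoppageCallback` tracks the best validation loss and signals a stop once no improvement has been seen for more than `num_plateau_epochs` epochs. A minimal sketch of that behaviour in isolation (the loss curve and the stand-in `nn.Linear` module are invented for illustration; during real training the trainer calls `on_validation_end` itself):

```python
import torch.nn as nn

from rxnn.training.callbacks import EarlyStoppageCallback  # added in 0.1.83

model = nn.Linear(8, 8)  # stand-in module; the callback only inspects losses
callback = EarlyStoppageCallback(num_plateau_epochs=3)

val_losses = [2.1, 1.7, 1.5, 1.52, 1.51, 1.55, 1.56]  # plateaus after epoch 2
for epoch, val_loss in enumerate(val_losses):
    if callback.on_validation_end(model, epoch, val_loss, {'loss': val_loss}):
        print(f'No improvement for over 3 epochs - stopping at epoch {epoch}')
        break
```

With the values above the callback returns `True` at epoch 6, four epochs after the best loss at epoch 2.
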
rxnn/training/dataset.py CHANGED
@@ -314,6 +314,8 @@ class JointLMDataset(BaseDataset):
     def __getitem__(self, idx: int) -> dict[str, dict[str, torch.Tensor]]:
         inputs = self.get_tokenized_text(idx)
         encoder_input_ids = inputs['input_ids'][0]
+        if self.is_pre_tokenized:
+            encoder_input_ids = encoder_input_ids.clone()
         attention_mask = inputs['attention_mask'][0]

         decoder_input_ids = encoder_input_ids.clone()
@@ -361,6 +363,8 @@ class MaskedLMDataset(BaseDataset):
         inputs = self.get_tokenized_text(idx)

         input_ids = inputs['input_ids'][0]
+        if self.is_pre_tokenized:
+            input_ids = input_ids.clone()
         attention_mask = inputs['attention_mask'][0]
         labels = input_ids.clone()
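
Both hunks address the same aliasing issue: with `is_pre_tokenized`, `__getitem__` hands out tensors that live in the dataset's in-memory cache, so cloning them first protects the cached copy from any later in-place edits (such as MLM masking). An illustrative, framework-free sketch of the problem (the tensor values and the mask id are made up):

```python
import torch

cache = [torch.tensor([5, 6, 7, 8])]   # stands in for a pre-tokenized dataset row

row = cache[0]                          # without .clone() this aliases the cached tensor
row[0] = 103                            # e.g. writing a [MASK] token id in place
print(cache[0])                         # tensor([103, 6, 7, 8]) - the cache is corrupted

cache = [torch.tensor([5, 6, 7, 8])]
row = cache[0].clone()                  # what the 0.1.83 change does for pre-tokenized data
row[0] = 103
print(cache[0])                         # tensor([5, 6, 7, 8]) - the cache stays intact
```
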
 
rxnn/training/scheduler.py CHANGED
@@ -17,3 +17,21 @@ def get_transformer_lr_scheduler(
         return LambdaLR(optimizer, lr_lambda)
     else:
         return CosineAnnealingLR(optimizer, T_max=num_training_steps)
+
+def calculate_steps(
+        dataset_size: int,
+        epochs: int,
+        batch_size: int,
+        warmup_ratio: float = 0.0,
+        num_workers: int = 1,
+        gradient_accumulation_steps: int = 1,
+        verbose: bool = True,
+):
+    steps_per_epoch = int((dataset_size / batch_size - 1) // num_workers)
+    total_steps = int((epochs * steps_per_epoch) / gradient_accumulation_steps)
+    warmup_steps = int(warmup_ratio * total_steps)
+    if verbose:
+        print(f'Total steps: {total_steps}')
+        print(f'Warmup steps: {warmup_steps}')
+        print(f'Total steps per epoch: {steps_per_epoch}')
+    return { 'total': total_steps, 'warmup': warmup_steps, 'epoch': steps_per_epoch}
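
A usage sketch for the new helper; the dataset and batch sizes are arbitrary, and the resulting `'total'`/`'warmup'` values would typically be passed on to `get_transformer_lr_scheduler`, whose full signature is not shown in this diff:

```python
from rxnn.training.scheduler import calculate_steps

steps = calculate_steps(
    dataset_size=100_000,
    epochs=3,
    batch_size=32,
    warmup_ratio=0.05,
    gradient_accumulation_steps=2,
    verbose=False,
)
# steps_per_epoch = int((100_000 / 32 - 1) // 1) = 3124
# total = int((3 * 3124) / 2) = 4686, warmup = int(0.05 * 4686) = 234
assert steps == {'total': 4686, 'warmup': 234, 'epoch': 3124}
```
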
rxnn/utils.py CHANGED
@@ -1,6 +1,9 @@
+import random, gc
 import torch
+import numpy as np

 def human_format(num: int):
+    """Format numbers to human-readable format."""
     num = float('{:.3g}'.format(num))
     magnitude = 0
     while abs(num) >= 1000:
@@ -10,5 +13,23 @@ def human_format(num: int):


 def get_model_size(model: torch.nn.Module):
+    """Calculate all models parameters with requires_grad param set as True"""
     trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
-    return f'Model params {human_format(trainable_params)}'
+    return f'Model params {human_format(trainable_params)}'
+
+def set_random_seed(seed: int):
+    """
+    Set random seed for reproducibility.
+
+    Applied on 3 libs: PyTorch, Numpy and random
+
+    seed (int): Random seed value
+    """
+    torch.random.manual_seed(seed)
+    np.random.seed(seed)
+    random.seed(seed)
+
+def cache_clean():
+    gc.collect()
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
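
A short sketch of how the new helpers fit into a training script; the seed value and the `nn.Linear` placeholder model are arbitrary:

```python
import torch.nn as nn

from rxnn.utils import set_random_seed, cache_clean, get_model_size

set_random_seed(42)           # seeds PyTorch, NumPy and random in one call

model = nn.Linear(512, 512)   # placeholder for a real RxNN model
print(get_model_size(model))  # trainable-parameter count in human-readable form (~263K here)

# ... after a heavy training/validation phase ...
cache_clean()                 # gc.collect(), plus torch.cuda.empty_cache() when CUDA is available
```
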
{rxnn-0.1.82.dist-info → rxnn-0.1.83.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.1.82
+Version: 0.1.83
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning
@@ -29,7 +29,7 @@ Description-Content-Type: text/markdown
 # Reactive AI - RxNN
 ## Reactive Neural Networks Platform

-RxNN is AI/DeepLearning development platform made for Reactive Neural Networks and Event-driven AI, introduced by Reactive AI.
+RxNN is AI/Deep Learning development platform made for Reactive Neural Networks and Event-driven AI, introduced by Reactive AI.

 ## Reactive Neural Networks and Event-driven AI
 Reactive neural networks (RxNN) are a new family of memory-augmented neural networks that combine classical deep learning
@@ -75,6 +75,172 @@ released with next versions of **RxNN** framework:
 - 1.x.x: Multimodal reactive models (could be released earlier, depending on progress)
 - 2.0.0: Real-Time Vision Reactor - Worker class models
 - x.x.x: ...and more!
+
+## Usage
+**RxNN** is made to train models based on reactive architectures, as well as transformer language models. Current version
+is based on PyTorch and HuggingFace libraries (Transformers/Datasets/Tokenizer/Hub), and is integrated with [HuggingFace Hub](https://hugginface.co)
+and [TensorBoard](https://github.com/tensorflow/tensorboard).
+
+> We are also planning a version for **TensorFlow**, more info soon
+
+### Install library and dependencies
+- RxNN and required deps: `pip install rxnn torch transformers tokenizers huggingface_hub`
+- Datasets are required only for training: `pip install datasets`
+- TensorBoard is optional: `pip install tensorboard`
+- [Flash Attention](https://github.com/Dao-AILab/flash-attention) is recommended for faster training/inference (required for models with explicit `use_flash_attention=True`) - check its separate [installation guide](#installing-flash-attention)
+- **NumPy** should be installed too: `pip install numpy`
+
+> ### Installing Flash Attention
+> Installing `flash-attn` could be very frustrating and may take hours (with standard method), only to result in some incompatibility
+> error. Fortunately, the prebuilt versions could be downloaded from GitHub and installed just in seconds. However, you should choose
+> the compatible version based on:
+> - Python version
+> - CUDA version
+> - PyTorch version (2.7 is currently not supported)
+> - ABI
+>
+> #### Steps
+> 1. Choose your version from [https://github.com/Dao-AILab/flash-attention/releases](https://github.com/Dao-AILab/flash-attention/releases)
+> 2. Download prebuilt release, in example: `wget https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp312-cp312-linux_x86_64.whl`
+> 3. Install it, in example: `pip install --no-dependencies --upgrade flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp312-cp312-linux_x86_64.whl`
+> 4. Verify: `flash_attn.__version__` (an incorrect version will cause the error when importing)
+>
+> #### Note on `use_flash_attention` option in models/layers
+> Explicit `use_flash_attention` option is made to enable direct calls to `flash_attn_func` without using **PyTorch** `scaled_dot_product_attention`. Even
+> if it's set to `False`, when `flash-attn` library is installed, **PyTorch** will try to use it implicitly through _SDPA backend_. It's better to set it
+> to `False` and use automatically, because of better compatibility. Explicit options could be used for research
+
+### Modules
+**RxNN** framework has multiple modules with models, layers, training and inference tools, made for complete development
+of _reactive models_, and could be also used for regular **Transformers**.
+
+#### Transformers
+Transformers module includes classes for models and layers. It includes **Reactive Transformers** as well as **Classic Transformers**
+
+Submodules:
+- `rxnn.transformers.attention` - basic, most common attention layers - `MultiHeadAttention`, `GroupedQueryAttention` and `MultiQueryAttention`
+  - additional attention layers, especially `SparseQueryAttention` could be found in `rxnn.experimental.attention` module
+  - `SparseQueryAttention` will be moved to `rxnn.transformers.attention` in 0.2.x version
+- `rxnn.transformers.positional` - positional encoding layers - `RotaryPositionalEmbedding` and legacy ones - `AbsolutePositionalEmbedding`/`RelativePositionalEmbedding`
+- `rxnn.transformers.ff` - dense feed forward layers, including gated layers (_SwiGLU_, etc.) - `FeedForward` & `GatedFeedForward` (recommended)
+- `rxnn.transformers.moe` - Mixture-of-Experts feed forward layers - `MoeFeedForward` & `GatedMoeFeedForward` (recommended)
+- `rxnn.transformer.layers` - complete reactive/classic transformer layers - `ReactiveTransformerLayer` & `ClassicTransformerLayer`
+- `rxnn.transformer.models` - reactive/classic transformer models - `ReactiveTransformerEncoder`, `ReactiveTransformerDecoder` & `ClassicTransformerEncoder`, `ClassicTransformerDecoder`
+- `rxnn.transformer.sampler` - samplers for reactive models (Sampler is the integral part of reactive architectures) - `Sampler` & `SampleDecoder`
+
+In **RxNN** models are initialized in declarative style by class composition, but then they are wrapped in imperative classes,
+to be compatible with HuggingFace **JSON** config. In example:
+
+```python
+from typing import TypedDict
+import torch
+import torch.nn as nn
+from huggingface_hub import PyTorchModelHubMixin
+from rxnn.transformers.attention import GroupedQueryAttention
+from rxnn.transformers.positional import RotaryPositionalEmbedding
+from rxnn.transformers.layers import ReactiveTransformerLayer
+from rxnn.transformers.models import ReactiveTransformerDecoder
+from rxnn.memory.stm import ShortTermMemory
+
+class YourReactiveTransformerConfig(TypedDict):
+    num_layers: int
+    vocab_size: int
+    embed_dim: int
+    ff_dim: int
+    att_heads: int
+    seq_len: int
+    stm_size: int
+    att_groups: int
+    cross_att_groups: int
+
+
+class YourReactiveTransformerDecoder(nn.Module, PyTorchModelHubMixin):
+    def __init__(
+            self,
+            config: YourReactiveTransformerConfig,
+            **kwargs
+    ):
+        super(YourReactiveTransformerDecoder, self).__init__(**kwargs)
+
+        embedding = nn.Embedding(config['vocab_size'], config['embed_dim'])
+        rope = RotaryPositionalEmbedding(config['embed_dim'] // config['att_heads'], config['seq_len'])
+        stm = ShortTermMemory(config['num_layers'], config['embed_dim'], config['stm_size'])
+
+        self.model = ReactiveTransformerDecoder(
+            stm=stm,
+            embedding=embedding,
+            own_layers=nn.ModuleList([
+                ReactiveTransformerLayer(
+                    config['embed_dim'],
+                    config['ff_dim'],
+                    use_gated=True,
+                    use_moe=False,
+                    ff_activation=nn.GELU(),
+                    ff_dropout=0.1,
+                    use_rms_norm=True,
+                    self_attention=GroupedQueryAttention(
+                        config['embed_dim'],
+                        config['att_heads'],
+                        config['att_groups'],
+                        rope=rope,
+                        dropout=0.1,
+                        max_seq_len=config['seq_len'],
+                        is_causal=True,
+                    ),
+                    memory_cross_attention=GroupedQueryAttention(
+                        config['embed_dim'],
+                        config['att_heads'],
+                        config['att_groups'],
+                        rope=rope,
+                        dropout=0.1,
+                        max_seq_len=config['seq_len'],
+                        is_causal=True,
+                        rope_only_for_query=True
+                    ),
+                ) for _ in range(config['num_layers'])
+            ])
+        )
+
+    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None):
+        return self.model(x, attention_mask=attention_mask)
+```
+
+#### Memory
+The _memory_ module includes **Short-Term Memory** and layers responsible for its update. In future versions it will also
+include **Long-Term Memory**.
+
+The main `ShortTermMemory` class is located in `rxnn.memory.stm` module - the usage example is in Transformers module description.
+
+Other submodules are connected to **Memory Attention** and will be described in 0.2.x version, after MRL
+
+#### Training
+Training module includes **Trainers** for different training stages of reactive models and shared training utils.
+
+Submodules:
+- `rxnn.training.tokenizer` - custom Trainer for **HuggingFace** `tokenizers` and utils to load tokenizer from Hub
+  - Tokenizer could be loaded from Hub with `load_tokenizer_from_hf_hub(repo_id)`
+- `rxnn.training.dataset` - datasets for different training stages:
+  - `MaskedLMDataset` & `AutoregressiveLMDataset` are made for base models pre-training
+  - `EncoderSftDataset` & `DecoderSftDataset` are made for Interaction Supervised Fine-Tuning for reactive models
+  - `MrlCurriculumDataset` is the dataset for single MRL Curriculum step
+  - `MrlDatasets` is wrapping MRL datasets for all curriculum steps
+  - each dataset has `from_hf_hub` class method to load dataset from Hub
+  - they have also `concat_from_hf_hub` class method to load multiple Hub datasets into single training dataset
+  - if dataset has no validation/test split, each dataset has `get_subset(subset_size, from_start=False)` method - it
+    returns new subset and modifying existing one - i.e. `valid_dataset = train_dataset.get_subset(0.1)`
+  - for concatenated datasets, validation/test split could be created with `concat_from_hf_hub_with_subset` - it cuts the
+    same percentage of each loaded dataset
+- `rxnn.training.callbacks` contain Trainer callbacks, for different kind of utils (more info below)
+- `rxnn.training.scheduler` includes learning rate scheduler for training
+- `rxnn.training.bml` - Base Model Learning module with Trainers for pre-training and fine-tuning
+- `rxnn.training.mrl` - Memory Reinforcement Learning module with Trainers for MRL (from 0.2.x)
+- `rxnn.training.rxrlhf` - Reinforcement Learning from Human Feedback for Reactive Models module (from 0.3.x)
+- `rxnn.training.brl` - Behavioral Reinforcement Learning module (Reactor / from 0.7.x)
+
+##### Base Model Learning
+Docs in progress
+
+
 Apache License
 Version 2.0, January 2004
 http://www.apache.org/licenses/
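
To ground the `rxnn.training.dataset` description in the README section added above, a hedged sketch of the documented flow; the repo ids are placeholders and the keyword arguments passed to `from_hf_hub` are assumptions, since the diff only names the method:

```python
from rxnn.training.tokenizer import load_tokenizer_from_hf_hub
from rxnn.training.dataset import MaskedLMDataset

# Placeholder repo ids - substitute real HuggingFace Hub repositories.
tokenizer = load_tokenizer_from_hf_hub('your-org/your-tokenizer')

# `from_hf_hub` is documented above; `tokenizer=` is an assumed keyword argument.
train_dataset = MaskedLMDataset.from_hf_hub('your-org/your-pretraining-dataset', tokenizer=tokenizer)

# If the source dataset has no validation split, carve one out as documented:
# get_subset returns the new 10% subset and shrinks the original dataset in place.
valid_dataset = train_dataset.get_subset(0.1)
```
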
{rxnn-0.1.82.dist-info → rxnn-0.1.83.dist-info}/RECORD CHANGED
@@ -9,11 +9,11 @@ rxnn/memory/stm.py,sha256=EsD8slSP4_9dLuq6aFPDmuFe8PWilxh90so5Z3nm-ig,2057
 rxnn/rxt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/rxt/models.py,sha256=iUlSvdXrD1NVzZFmdC55qp4_3xoJj31FC40BGgYlf4Q,8763
 rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/training/base.py,sha256=xPMA2Bg9-oUZvSZg67ls2p7Gk9pZ9IHUiIJwUzSe2K8,11766
+rxnn/training/base.py,sha256=_xik1GXE4RJ_nxwqLQ1ccXA5pRtBCi-jL-jeRFBdHBU,11851
 rxnn/training/bml.py,sha256=S1ZaXTybzeJH7uVFamCr4TPl2bLyZ5xmn_lSsjThTiM,19162
-rxnn/training/callbacks.py,sha256=_YfMKY_eFdc-tubhO9nYH2PXDZDQwlSI74FVOoCXpQg,22108
-rxnn/training/dataset.py,sha256=xI7bbARRWifunVX6HakCroSFqkM401BQmxfsf9pDeY4,35621
-rxnn/training/scheduler.py,sha256=ow6oALzWjWQmHSpcJEjv6tg4g4CDMvr73TypxfcefMc,712
+rxnn/training/callbacks.py,sha256=xcU3W6_OsIEDTFTbN7S3uIWyGqLulbUWZMpW0aIXmF4,22699
+rxnn/training/dataset.py,sha256=XEDmOwD8v0c9u0QCk7I3xZShKaMtBDwYlfK1ofu6A1E,35789
+rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
 rxnn/training/tokenizer.py,sha256=umaLByMBx_NMrQElA45HLm9gkuzyKWDTFaKVd-CjXl0,8344
 rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/transformers/attention.py,sha256=dC0UmC-_kjX8US6Sf0Fi5zw5kJ-P6orH3JDHeBB5gI8,15695
@@ -24,8 +24,8 @@ rxnn/transformers/models.py,sha256=xbnn3FTNZFhaqq9A0XEM12ie_WL_58pPeq0qFXIgve0,7
 rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=ge-kaS6WnWnPGnWVp25ZK5bVkmhBUNCaELaN2rN_fSY,4097
 rxnn/transformers/sampler.py,sha256=poWBpxg1iuK5gEJtxHkk5VVfS9V48hs2Olqdhy_Gw8c,6548
-rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
-rxnn-0.1.82.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
-rxnn-0.1.82.dist-info/METADATA,sha256=xhip3_H9uGKIHKfyTnR0vk_a9zr0TzTIr8buNIiDUQY,16589
-rxnn-0.1.82.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
-rxnn-0.1.82.dist-info/RECORD,,
+rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
+rxnn-0.1.83.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.1.83.dist-info/METADATA,sha256=AhGTqWM9mvBzDRWliKeTRySDAL2cXXTYefRL_HGJN_Q,25930
+rxnn-0.1.83.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+rxnn-0.1.83.dist-info/RECORD,,