rxnn 0.1.82__py3-none-any.whl → 0.1.83__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rxnn/training/base.py +2 -0
- rxnn/training/callbacks.py +21 -0
- rxnn/training/dataset.py +4 -0
- rxnn/training/scheduler.py +18 -0
- rxnn/utils.py +22 -1
- {rxnn-0.1.82.dist-info → rxnn-0.1.83.dist-info}/METADATA +168 -2
- {rxnn-0.1.82.dist-info → rxnn-0.1.83.dist-info}/RECORD +9 -9
- {rxnn-0.1.82.dist-info → rxnn-0.1.83.dist-info}/LICENSE +0 -0
- {rxnn-0.1.82.dist-info → rxnn-0.1.83.dist-info}/WHEEL +0 -0
rxnn/training/base.py
CHANGED
@@ -39,6 +39,8 @@ class BaseTrainer(ABC):
|
|
39
39
|
self.optimizer = optimizer
|
40
40
|
self.dataset = dataset
|
41
41
|
self.callbacks = callbacks or []
|
42
|
+
if log_dir and not os.path.exists(log_dir):
|
43
|
+
os.makedirs(log_dir)
|
42
44
|
self.writer = SummaryWriter(log_dir) if log_dir else None
|
43
45
|
self.use_ddp = use_ddp
|
44
46
|
self.use_amp = use_amp
|
rxnn/training/callbacks.py
CHANGED
@@ -499,3 +499,24 @@ class JointModelSaveCallback(TrainerCallback):
|
|
499
499
|
if not self.mlm_mode:
|
500
500
|
self._save_final(model.decoder, 'decoder', hub_id=self.hub_model_decoder)
|
501
501
|
self._save_final(model.mlm_head, 'head', hub_id=self.hub_model_head)
|
502
|
+
|
503
|
+
class EarlyStoppageCallback(TrainerCallback):
    """Callback that requests early stopping when validation loss plateaus.

    Tracks the lowest validation loss seen so far and the epoch in which it
    occurred. When a validation round does not improve on that loss and more
    than ``num_plateau_epochs`` epochs have passed since the best epoch,
    ``on_validation_end`` returns ``True``.

    Args:
        num_plateau_epochs: Number of epochs without improvement that are
            tolerated (stopping is signalled once the gap exceeds this value).
    """

    def __init__(self, num_plateau_epochs: int = 3) -> None:
        super().__init__()
        self.num_plateau_epochs = num_plateau_epochs
        # Start from +inf so the first validation loss is always recorded as
        # the best one. A fixed sentinel (previously 9999.0) ignores losses
        # above it, leaving best_loss_epoch at 0 and triggering a premature stop.
        self.best_loss = float('inf')
        self.best_loss_epoch = 0

    def on_validation_end(
            self,
            model: torch.nn.Module,
            epoch: int,
            val_loss: float,
            val_metrics: dict
    ):
        """Update the best-loss record and report whether training should stop.

        Args:
            model: The model being trained (unused here).
            epoch: Index of the epoch that just finished validation.
            val_loss: Validation loss for this epoch.
            val_metrics: Additional validation metrics (unused here).

        Returns:
            ``True`` when the loss has not improved for more than
            ``num_plateau_epochs`` epochs since the best epoch, otherwise
            ``None``. NOTE(review): presumably the trainer treats ``True`` as a
            stop request; confirm against the trainer's callback handling.
        """
        if val_loss < self.best_loss:
            self.best_loss = val_loss
            self.best_loss_epoch = epoch
        elif epoch - self.best_loss_epoch > self.num_plateau_epochs:
            return True
        return None
|
rxnn/training/dataset.py
CHANGED
@@ -314,6 +314,8 @@ class JointLMDataset(BaseDataset):
|
|
314
314
|
def __getitem__(self, idx: int) -> dict[str, dict[str, torch.Tensor]]:
|
315
315
|
inputs = self.get_tokenized_text(idx)
|
316
316
|
encoder_input_ids = inputs['input_ids'][0]
|
317
|
+
if self.is_pre_tokenized:
|
318
|
+
encoder_input_ids = encoder_input_ids.clone()
|
317
319
|
attention_mask = inputs['attention_mask'][0]
|
318
320
|
|
319
321
|
decoder_input_ids = encoder_input_ids.clone()
|
@@ -361,6 +363,8 @@ class MaskedLMDataset(BaseDataset):
|
|
361
363
|
inputs = self.get_tokenized_text(idx)
|
362
364
|
|
363
365
|
input_ids = inputs['input_ids'][0]
|
366
|
+
if self.is_pre_tokenized:
|
367
|
+
input_ids = input_ids.clone()
|
364
368
|
attention_mask = inputs['attention_mask'][0]
|
365
369
|
labels = input_ids.clone()
|
366
370
|
|
rxnn/training/scheduler.py
CHANGED
@@ -17,3 +17,21 @@ def get_transformer_lr_scheduler(
|
|
17
17
|
return LambdaLR(optimizer, lr_lambda)
|
18
18
|
else:
|
19
19
|
return CosineAnnealingLR(optimizer, T_max=num_training_steps)
|
20
|
+
|
21
|
+
def calculate_steps(
        dataset_size: int,
        epochs: int,
        batch_size: int,
        warmup_ratio: float = 0.0,
        num_workers: int = 1,
        gradient_accumulation_steps: int = 1,
        verbose: bool = True,
):
    """Estimate optimizer step counts for learning-rate scheduling.

    Args:
        dataset_size: Number of training samples.
        epochs: Number of training epochs.
        batch_size: Batch size used by the data loader.
        warmup_ratio: Fraction of the total steps used for warmup.
        num_workers: Number of workers the per-epoch steps are split across
            (presumably data-parallel processes; verify against callers).
        gradient_accumulation_steps: Number of batches accumulated per
            optimizer step; the total step count is divided by it.
        verbose: When True, print the computed step counts.

    Returns:
        A dict with keys ``'total'`` (total optimizer steps), ``'warmup'``
        (warmup steps) and ``'epoch'`` (steps per epoch), all non-negative ints.
    """
    # Same estimate as floor((dataset_size / batch_size - 1) / num_workers),
    # computed with exact integer floor division. It is clamped at zero because
    # a dataset smaller than one batch previously yielded a negative step count.
    steps_per_epoch = int(max(0, (dataset_size - batch_size) // (batch_size * num_workers)))
    total_steps = int((epochs * steps_per_epoch) / gradient_accumulation_steps)
    warmup_steps = int(warmup_ratio * total_steps)
    if verbose:
        print(f'Total steps: {total_steps}')
        print(f'Warmup steps: {warmup_steps}')
        print(f'Total steps per epoch: {steps_per_epoch}')
    return {'total': total_steps, 'warmup': warmup_steps, 'epoch': steps_per_epoch}
|
rxnn/utils.py
CHANGED
@@ -1,6 +1,9 @@
|
|
1
|
+
import random, gc
|
1
2
|
import torch
|
3
|
+
import numpy as np
|
2
4
|
|
3
5
|
def human_format(num: int):
|
6
|
+
"""Format numbers to human-readable format."""
|
4
7
|
num = float('{:.3g}'.format(num))
|
5
8
|
magnitude = 0
|
6
9
|
while abs(num) >= 1000:
|
@@ -10,5 +13,23 @@ def human_format(num: int):
|
|
10
13
|
|
11
14
|
|
12
15
|
def get_model_size(model: torch.nn.Module):
    """Describe the model's size as a human-readable trainable-parameter count.

    Only parameters whose ``requires_grad`` flag is set are counted. The total
    is formatted with ``human_format`` and returned as ``'Model params <n>'``.
    """
    trainable_count = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable_count += param.numel()
    return f'Model params {human_format(trainable_count)}'
|
19
|
+
|
20
|
+
def set_random_seed(seed: int):
    """Seed all random number generators used for training, for reproducibility.

    The same value is applied to Python's ``random`` module, NumPy's global
    generator and PyTorch's default generator.

    Args:
        seed (int): Seed value passed to each library's seeding function.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.random.manual_seed(seed)
|
31
|
+
|
32
|
+
def cache_clean():
    """Run Python garbage collection and release cached CUDA memory if possible.

    ``gc.collect()`` always runs. ``torch.cuda.empty_cache()`` is called only
    when CUDA is available.
    """
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.3
|
2
2
|
Name: rxnn
|
3
|
-
Version: 0.1.
|
3
|
+
Version: 0.1.83
|
4
4
|
Summary: RxNN: Reactive Neural Networks Platform
|
5
5
|
License: Apache-2.0
|
6
6
|
Keywords: deep-learning,ai,machine-learning
|
@@ -29,7 +29,7 @@ Description-Content-Type: text/markdown
|
|
29
29
|
# Reactive AI - RxNN
|
30
30
|
## Reactive Neural Networks Platform
|
31
31
|
|
32
|
-
RxNN is AI/
|
32
|
+
RxNN is an AI/Deep Learning development platform built for Reactive Neural Networks and Event-driven AI, introduced by Reactive AI.
|
33
33
|
|
34
34
|
## Reactive Neural Networks and Event-driven AI
|
35
35
|
Reactive neural networks (RxNN) are a new family of memory-augmented neural networks that combine classical deep learning
|
@@ -75,6 +75,172 @@ released with next versions of **RxNN** framework:
|
|
75
75
|
- 1.x.x: Multimodal reactive models (could be released earlier, depending on progress)
|
76
76
|
- 2.0.0: Real-Time Vision Reactor - Worker class models
|
77
77
|
- x.x.x: ...and more!
|
78
|
+
|
79
|
+
## Usage
|
80
|
+
**RxNN** is made to train models based on reactive architectures, as well as transformer language models. The current version
|
81
|
+
is based on PyTorch and HuggingFace libraries (Transformers/Datasets/Tokenizers/Hub), and is integrated with [HuggingFace Hub](https://huggingface.co)
|
82
|
+
and [TensorBoard](https://github.com/tensorflow/tensorboard).
|
83
|
+
|
84
|
+
> We are also planning a **TensorFlow** version; more info soon.
|
85
|
+
|
86
|
+
### Install library and dependencies
|
87
|
+
- RxNN and required deps: `pip install rxnn torch transformers tokenizers huggingface_hub`
|
88
|
+
- Datasets are required only for training: `pip install datasets`
|
89
|
+
- TensorBoard is optional: `pip install tensorboard`
|
90
|
+
- [Flash Attention](https://github.com/Dao-AILab/flash-attention) is recommended for faster training/inference (required for models with explicit `use_flash_attention=True`) - check its separate [installation guide](#installing-flash-attention)
|
91
|
+
- **NumPy** should be installed too: `pip install numpy`
|
92
|
+
|
93
|
+
> ### Installing Flash Attention
|
94
|
+
> Installing `flash-attn` can be very frustrating and may take hours (with the standard method), only to end in an incompatibility
|
95
|
+
> error. Fortunately, prebuilt versions can be downloaded from GitHub and installed in seconds. However, you must choose
|
96
|
+
> the compatible version based on:
|
97
|
+
> - Python version
|
98
|
+
> - CUDA version
|
99
|
+
> - PyTorch version (2.7 is currently not supported)
|
100
|
+
> - ABI
|
101
|
+
>
|
102
|
+
> #### Steps
|
103
|
+
> 1. Choose your version from [https://github.com/Dao-AILab/flash-attention/releases](https://github.com/Dao-AILab/flash-attention/releases)
|
104
|
+
> 2. Download a prebuilt release, for example: `wget https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp312-cp312-linux_x86_64.whl`
|
105
|
+
> 3. Install it, for example: `pip install --no-dependencies --upgrade flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp312-cp312-linux_x86_64.whl`
|
106
|
+
> 4. Verify: `import flash_attn; print(flash_attn.__version__)` (an incompatible build will raise an error on import)
|
107
|
+
>
|
108
|
+
> #### Note on `use_flash_attention` option in models/layers
|
109
|
+
> The explicit `use_flash_attention` option enables direct calls to `flash_attn_func` instead of **PyTorch** `scaled_dot_product_attention`. Even
|
110
|
+
> if it's set to `False`, when `flash-attn` library is installed, **PyTorch** will try to use it implicitly through _SDPA backend_. It's better to set it
|
111
|
+
> to `False` and let PyTorch choose the attention backend automatically, for better compatibility. The explicit option is mainly intended for research.
|
112
|
+
|
113
|
+
### Modules
|
114
|
+
The **RxNN** framework has multiple modules with models, layers, training and inference tools, built for complete development
|
115
|
+
of _reactive models_, and can also be used for regular **Transformers**.
|
116
|
+
|
117
|
+
#### Transformers
|
118
|
+
The Transformers module includes classes for models and layers, covering both **Reactive Transformers** and **Classic Transformers**.
|
119
|
+
|
120
|
+
Submodules:
|
121
|
+
- `rxnn.transformers.attention` - basic, most common attention layers - `MultiHeadAttention`, `GroupedQueryAttention` and `MultiQueryAttention`
|
122
|
+
- additional attention layers, especially `SparseQueryAttention`, can be found in the `rxnn.experimental.attention` module
|
123
|
+
- `SparseQueryAttention` will be moved to `rxnn.transformers.attention` in 0.2.x version
|
124
|
+
- `rxnn.transformers.positional` - positional encoding layers - `RotaryPositionalEmbedding` and legacy ones - `AbsolutePositionalEmbedding`/`RelativePositionalEmbedding`
|
125
|
+
- `rxnn.transformers.ff` - dense feed forward layers, including gated layers (_SwiGLU_, etc.) - `FeedForward` & `GatedFeedForward` (recommended)
|
126
|
+
- `rxnn.transformers.moe` - Mixture-of-Experts feed forward layers - `MoeFeedForward` & `GatedMoeFeedForward` (recommended)
|
127
|
+
- `rxnn.transformers.layers` - complete reactive/classic transformer layers - `ReactiveTransformerLayer` & `ClassicTransformerLayer`
|
128
|
+
- `rxnn.transformers.models` - reactive/classic transformer models - `ReactiveTransformerEncoder`, `ReactiveTransformerDecoder` & `ClassicTransformerEncoder`, `ClassicTransformerDecoder`
|
129
|
+
- `rxnn.transformers.sampler` - samplers for reactive models (the Sampler is an integral part of reactive architectures) - `Sampler` & `SampleDecoder`
|
130
|
+
|
131
|
+
In **RxNN**, models are initialized in a declarative style by class composition, but then they are wrapped in imperative classes,
|
132
|
+
to be compatible with the HuggingFace **JSON** config. For example:
|
133
|
+
|
134
|
+
```python
|
135
|
+
from typing import TypedDict
|
136
|
+
import torch
|
137
|
+
import torch.nn as nn
|
138
|
+
from huggingface_hub import PyTorchModelHubMixin
|
139
|
+
from rxnn.transformers.attention import GroupedQueryAttention
|
140
|
+
from rxnn.transformers.positional import RotaryPositionalEmbedding
|
141
|
+
from rxnn.transformers.layers import ReactiveTransformerLayer
|
142
|
+
from rxnn.transformers.models import ReactiveTransformerDecoder
|
143
|
+
from rxnn.memory.stm import ShortTermMemory
|
144
|
+
|
145
|
+
class YourReactiveTransformerConfig(TypedDict):
    """Configuration dictionary for the example reactive transformer decoder."""
    # depth and vocabulary
    num_layers: int
    vocab_size: int
    # embedding and feed-forward widths
    embed_dim: int
    ff_dim: int
    # attention heads and maximum sequence length
    att_heads: int
    seq_len: int
    # short-term memory size
    stm_size: int
    # key/value group counts for self- and cross-attention
    att_groups: int
    cross_att_groups: int
|
155
|
+
|
156
|
+
|
157
|
+
class YourReactiveTransformerDecoder(nn.Module, PyTorchModelHubMixin):
    """Example reactive decoder assembled by composing RxNN layers.

    Built from a ``YourReactiveTransformerConfig`` dict, it wraps a
    ``ReactiveTransformerDecoder`` whose layers each have a causal
    grouped-query self-attention and a grouped-query memory cross-attention
    over the short-term memory (STM).
    """

    def __init__(
            self,
            config: YourReactiveTransformerConfig,
            **kwargs
    ):
        super().__init__(**kwargs)

        embedding = nn.Embedding(config['vocab_size'], config['embed_dim'])
        # RoPE is built for the per-head dimension (embed_dim // att_heads).
        rope = RotaryPositionalEmbedding(config['embed_dim'] // config['att_heads'], config['seq_len'])
        stm = ShortTermMemory(config['num_layers'], config['embed_dim'], config['stm_size'])

        self.model = ReactiveTransformerDecoder(
            stm=stm,
            embedding=embedding,
            own_layers=nn.ModuleList([
                ReactiveTransformerLayer(
                    config['embed_dim'],
                    config['ff_dim'],
                    use_gated=True,
                    use_moe=False,
                    ff_activation=nn.GELU(),
                    ff_dropout=0.1,
                    use_rms_norm=True,
                    self_attention=GroupedQueryAttention(
                        config['embed_dim'],
                        config['att_heads'],
                        config['att_groups'],
                        rope=rope,
                        dropout=0.1,
                        max_seq_len=config['seq_len'],
                        is_causal=True,
                    ),
                    memory_cross_attention=GroupedQueryAttention(
                        config['embed_dim'],
                        config['att_heads'],
                        # Memory cross-attention takes its group count from the
                        # dedicated `cross_att_groups` config field.
                        config['cross_att_groups'],
                        rope=rope,
                        dropout=0.1,
                        max_seq_len=config['seq_len'],
                        is_causal=True,
                        rope_only_for_query=True
                    ),
                ) for _ in range(config['num_layers'])
            ])
        )

    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None):
        """Run the wrapped decoder on ``x`` with an optional ``attention_mask``."""
        return self.model(x, attention_mask=attention_mask)
|
206
|
+
```
|
207
|
+
|
208
|
+
#### Memory
|
209
|
+
The _memory_ module includes **Short-Term Memory** and layers responsible for its update. In future versions it will also
|
210
|
+
include **Long-Term Memory**.
|
211
|
+
|
212
|
+
The main `ShortTermMemory` class is located in the `rxnn.memory.stm` module - a usage example is in the Transformers module description.
|
213
|
+
|
214
|
+
Other submodules relate to **Memory Attention** and will be described in version 0.2.x, after MRL.
|
215
|
+
|
216
|
+
#### Training
|
217
|
+
Training module includes **Trainers** for different training stages of reactive models and shared training utils.
|
218
|
+
|
219
|
+
Submodules:
|
220
|
+
- `rxnn.training.tokenizer` - custom Trainer for **HuggingFace** `tokenizers` and utils to load a tokenizer from the Hub
|
221
|
+
- A tokenizer can be loaded from the Hub with `load_tokenizer_from_hf_hub(repo_id)`
|
222
|
+
- `rxnn.training.dataset` - datasets for different training stages:
|
223
|
+
- `MaskedLMDataset` & `AutoregressiveLMDataset` are made for base models pre-training
|
224
|
+
- `EncoderSftDataset` & `DecoderSftDataset` are made for Interaction Supervised Fine-Tuning for reactive models
|
225
|
+
- `MrlCurriculumDataset` is the dataset for single MRL Curriculum step
|
226
|
+
- `MrlDatasets` is wrapping MRL datasets for all curriculum steps
|
227
|
+
- each dataset has `from_hf_hub` class method to load dataset from Hub
|
228
|
+
- they also have a `concat_from_hf_hub` class method to load multiple Hub datasets into a single training dataset
|
229
|
+
- if a dataset has no validation/test split, each dataset has a `get_subset(subset_size, from_start=False)` method - it
|
230
|
+
returns a new subset and modifies the existing one - e.g. `valid_dataset = train_dataset.get_subset(0.1)`
|
231
|
+
- for concatenated datasets, validation/test split could be created with `concat_from_hf_hub_with_subset` - it cuts the
|
232
|
+
same percentage of each loaded dataset
|
233
|
+
- `rxnn.training.callbacks` contains Trainer callbacks for different kinds of utilities (more info below)
|
234
|
+
- `rxnn.training.scheduler` includes learning rate schedulers for training
|
235
|
+
- `rxnn.training.bml` - Base Model Learning module with Trainers for pre-training and fine-tuning
|
236
|
+
- `rxnn.training.mrl` - Memory Reinforcement Learning module with Trainers for MRL (from 0.2.x)
|
237
|
+
- `rxnn.training.rxrlhf` - Reinforcement Learning from Human Feedback for Reactive Models module (from 0.3.x)
|
238
|
+
- `rxnn.training.brl` - Behavioral Reinforcement Learning module (Reactor / from 0.7.x)
|
239
|
+
|
240
|
+
##### Base Model Learning
|
241
|
+
Docs in progress
|
242
|
+
|
243
|
+
|
78
244
|
Apache License
|
79
245
|
Version 2.0, January 2004
|
80
246
|
http://www.apache.org/licenses/
|
@@ -9,11 +9,11 @@ rxnn/memory/stm.py,sha256=EsD8slSP4_9dLuq6aFPDmuFe8PWilxh90so5Z3nm-ig,2057
|
|
9
9
|
rxnn/rxt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
10
10
|
rxnn/rxt/models.py,sha256=iUlSvdXrD1NVzZFmdC55qp4_3xoJj31FC40BGgYlf4Q,8763
|
11
11
|
rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
12
|
-
rxnn/training/base.py,sha256=
|
12
|
+
rxnn/training/base.py,sha256=_xik1GXE4RJ_nxwqLQ1ccXA5pRtBCi-jL-jeRFBdHBU,11851
|
13
13
|
rxnn/training/bml.py,sha256=S1ZaXTybzeJH7uVFamCr4TPl2bLyZ5xmn_lSsjThTiM,19162
|
14
|
-
rxnn/training/callbacks.py,sha256=
|
15
|
-
rxnn/training/dataset.py,sha256=
|
16
|
-
rxnn/training/scheduler.py,sha256=
|
14
|
+
rxnn/training/callbacks.py,sha256=xcU3W6_OsIEDTFTbN7S3uIWyGqLulbUWZMpW0aIXmF4,22699
|
15
|
+
rxnn/training/dataset.py,sha256=XEDmOwD8v0c9u0QCk7I3xZShKaMtBDwYlfK1ofu6A1E,35789
|
16
|
+
rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
|
17
17
|
rxnn/training/tokenizer.py,sha256=umaLByMBx_NMrQElA45HLm9gkuzyKWDTFaKVd-CjXl0,8344
|
18
18
|
rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
19
19
|
rxnn/transformers/attention.py,sha256=dC0UmC-_kjX8US6Sf0Fi5zw5kJ-P6orH3JDHeBB5gI8,15695
|
@@ -24,8 +24,8 @@ rxnn/transformers/models.py,sha256=xbnn3FTNZFhaqq9A0XEM12ie_WL_58pPeq0qFXIgve0,7
|
|
24
24
|
rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
|
25
25
|
rxnn/transformers/positional.py,sha256=ge-kaS6WnWnPGnWVp25ZK5bVkmhBUNCaELaN2rN_fSY,4097
|
26
26
|
rxnn/transformers/sampler.py,sha256=poWBpxg1iuK5gEJtxHkk5VVfS9V48hs2Olqdhy_Gw8c,6548
|
27
|
-
rxnn/utils.py,sha256=
|
28
|
-
rxnn-0.1.
|
29
|
-
rxnn-0.1.
|
30
|
-
rxnn-0.1.
|
31
|
-
rxnn-0.1.
|
27
|
+
rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
|
28
|
+
rxnn-0.1.83.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
|
29
|
+
rxnn-0.1.83.dist-info/METADATA,sha256=AhGTqWM9mvBzDRWliKeTRySDAL2cXXTYefRL_HGJN_Q,25930
|
30
|
+
rxnn-0.1.83.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
|
31
|
+
rxnn-0.1.83.dist-info/RECORD,,
|
File without changes
|
File without changes
|