rxnn-0.2.2-py3-none-any.whl → rxnn-0.2.4-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rxnn/rxt/models.py +3 -3
- rxnn/training/callbacks.py +2 -2
- rxnn/training/mrl.py +12 -17
- rxnn/training/reward.py +7 -7
- rxnn/training/utils.py +0 -4
- {rxnn-0.2.2.dist-info → rxnn-0.2.4.dist-info}/METADATA +1 -1
- {rxnn-0.2.2.dist-info → rxnn-0.2.4.dist-info}/RECORD +9 -9
- {rxnn-0.2.2.dist-info → rxnn-0.2.4.dist-info}/LICENSE +0 -0
- {rxnn-0.2.2.dist-info → rxnn-0.2.4.dist-info}/WHEEL +0 -0
rxnn/rxt/models.py
CHANGED
@@ -53,7 +53,7 @@ class RxTAlphaComponentBase(nn.Module, PyTorchModelHubMixin):
             att_heads: int = 16,
             seq_len: int = 1024,
             stm_size: int = 1024,
-            use_flash_attention: bool =
+            use_flash_attention: bool = False,
             use_gated: bool = True,
             ff_activation: str = "swish",
             ff_dropout: float = 0.0,
@@ -232,7 +232,7 @@ class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2
             att_heads: int = 16,
             seq_len: int = 1024,
             stm_size: int = 1024,
-            use_flash_attention: bool =
+            use_flash_attention: bool = False,
             att_dropout: float = 0.0,
             norm_type: str = 'rms',
             att_groups: int = 1,
@@ -271,7 +271,7 @@ class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2
         self.model.update_max_len(max_seq_len)

     def reset_memory(self, init_type: str = None):
-        self.model.stm.
+        self.model.stm.reset(init_type)

     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
         return self.model(x, attention_mask=attention_mask)

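In 0.2.4 the default for use_flash_attention drops to False and reset_memory now forwards its init_type argument to the short-term memory's reset. A toy sketch of that delegation pattern follows (illustration only, not rxnn source; the ToySTM class and its init strategy names are invented for this example):

# Toy sketch of the reset_memory delegation introduced in 0.2.4 (not library code):
# the wrapper passes init_type straight through to the STM's reset method.
class ToySTM:
    def __init__(self):
        self.state = 'zeros'

    def reset(self, init_type: str = None):
        # hypothetical init strategy names, used here only for illustration
        self.state = init_type or 'zeros'

class ToyMemoryAttention:
    def __init__(self):
        self.stm = ToySTM()

    def reset_memory(self, init_type: str = None):
        # mirrors the 0.2.4 change: self.model.stm.reset(init_type)
        self.stm.reset(init_type)

mem = ToyMemoryAttention()
mem.reset_memory(init_type='normal')
print(mem.stm.state)  # -> normal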
rxnn/training/callbacks.py
CHANGED
@@ -557,7 +557,7 @@ class MrlPrintCallback(MrlTrainerCallback):
     def on_epoch_start(self, actor: nn.Module, epoch: int, stage_epochs: int, curriculum_config: dict,
                        global_epoch: int, global_epochs: int) -> None:
         print(
-            f'Starting epoch {epoch}/{stage_epochs} (stage) | {global_epoch}/{global_epochs} (global) for {curriculum_config[
+            f'Starting epoch {epoch}/{stage_epochs} (stage) | {global_epoch}/{global_epochs} (global) for {curriculum_config["steps"]} steps in {curriculum_config["strategy"]} strategy.')

     def on_epoch_end(self, actor: nn.Module, epoch: int, stage_epochs: int, policy_loss: float,
                      critic_loss: float, global_epoch: int, global_epochs: int) -> None:
@@ -580,7 +580,7 @@ class MrlPrintCallback(MrlTrainerCallback):
         print(f'Epoch {epoch} | Step {step} - updated policy loss {critic_loss}')

     def on_training_end(self, actor: nn.Module, critic: nn.Module, curriculum_config: dict) -> None:
-        print(f'Finished training for {curriculum_config[
+        print(f'Finished training for {curriculum_config["steps"]} steps in {curriculum_config["strategy"]} strategy.')

     def on_eval_end(self, actor: nn.Module, critic: nn.Module, epoch: int, eval_mean_reward: float) -> None:
         print(f'Eval epoch {epoch} - mean reward {eval_mean_reward}')

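The repaired f-strings read two keys from curriculum_config. A minimal sketch of the resulting log line, assuming a hypothetical config (only the "steps" and "strategy" keys come from the diff; the values are made up):

# Hypothetical curriculum_config; only the two keys the callback reads are shown.
curriculum_config = {'steps': 4, 'strategy': 'multi_step'}
epoch, stage_epochs, global_epoch, global_epochs = 1, 5, 1, 20
print(
    f'Starting epoch {epoch}/{stage_epochs} (stage) | {global_epoch}/{global_epochs} (global) '
    f'for {curriculum_config["steps"]} steps in {curriculum_config["strategy"]} strategy.')
# -> Starting epoch 1/5 (stage) | 1/20 (global) for 4 steps in multi_step strategy.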
rxnn/training/mrl.py
CHANGED
@@ -9,7 +9,7 @@ import random, os
 from ..transformers.sampler import BatchSampler
 from .callbacks import MrlTrainerCallback
 from .dataset import MrlCurriculumDataset
-from .utils import smart_concat, smart_concat_critic_states,
+from .utils import smart_concat, smart_concat_critic_states, TokenizedDict
 from .rl import RlAlgorithm
 from .reward import MrlRewardMode, MrlRewardModel
 from .models import MrlActorAction, MrlActorModel, MrlCriticModel
@@ -74,7 +74,6 @@ class MRLTrainer:
             sampler_config: Optional[SamplerConfig] = None,
             log_dir: str = None,
             pad_token_id: int = 0,
-            start_token_id: int = 2,
             end_token_id: int = 3,
             use_ddp: bool = False,
             use_amp: bool = False,
@@ -112,11 +111,7 @@ class MRLTrainer:
             top_p=None,
         ) if sampler_config is None else sampler_config

-        self.
-            'pad': pad_token_id,
-            'bos': start_token_id,
-            'eos': end_token_id,
-        }
+        self.pad_token_id = pad_token_id

         self.use_ddp = use_ddp
         self.use_amp = use_amp
@@ -191,12 +186,12 @@ class MRLTrainer:
         if self.use_amp:
             with torch.amp.autocast(device_type=self.device.type, dtype=self.dtype):
                 # 2. Concatenate batch of queries and answers (they are already on training device)
-                inputs = smart_concat(query, answer, self.max_seq_len, self.
+                inputs = smart_concat(query, answer, self.max_seq_len, self.pad_token_id)
                 # 3. Encode data and update STM
                 self.actor(inputs['input_ids'], attention_mask=inputs['attention_mask'], action=MrlActorAction.UPDATE)
         else:
             # 2. Concatenate batch of queries and answers (they are already on training device)
-            inputs = smart_concat(query, answer, self.max_seq_len, self.
+            inputs = smart_concat(query, answer, self.max_seq_len, self.pad_token_id)
             # 3. Encode data and update STM
             self.actor(inputs['input_ids'], attention_mask=inputs['attention_mask'], action=MrlActorAction.UPDATE)

@@ -235,11 +230,11 @@ class MRLTrainer:
         if self.use_amp:
             with torch.amp.autocast(device_type=self.device.type, dtype=self.dtype):
                 saved_interaction = smart_concat(saved_query, saved_answer, max_length=self.max_seq_len,
-                                                 pad_token_id=self.
+                                                 pad_token_id=self.pad_token_id)
                 reward = self.reward(generated, reference, saved_interaction, mode=mode)
         else:
             saved_interaction = smart_concat(saved_query, saved_answer, max_length=self.max_seq_len,
-                                             pad_token_id=self.
+                                             pad_token_id=self.pad_token_id)
             reward = self.reward(generated, reference, saved_interaction, mode=mode)

         # 2. Run 'on reward' callbacks
@@ -404,7 +399,7 @@ class MRLTrainer:
                 inputs = smart_concat_critic_states(
                     prev_query, prev_answer, next_query,
                     max_length=self.critic_max_len,
-                    pad_token_id=self.
+                    pad_token_id=self.pad_token_id,
                 )
                 loss = self._critic_loss(inputs, batch_rewards)
                 # Run backpropagation with scaler
@@ -420,7 +415,7 @@ class MRLTrainer:
                 inputs = smart_concat_critic_states(
                     prev_query, prev_answer, next_query,
                     max_length=self.critic_max_len,
-                    pad_token_id=self.
+                    pad_token_id=self.pad_token_id,
                 )
                 # Calculate loss
                 loss = self._critic_loss(inputs, reward.to(self.device, dtype=self.dtype))
@@ -486,11 +481,11 @@ class MRLTrainer:
            with torch.amp.autocast(device_type=self.device.type, dtype=self.dtype):
                 critic_state = smart_concat_critic_states(query, answer, next_query,
                                                           max_length=self.critic_max_len,
-                                                          pad_token_id=self.
+                                                          pad_token_id=self.pad_token_id)
                 advantages = self._critic_advantages(critic_state, rewards)
         else:
             critic_state = smart_concat_critic_states(query, answer, next_query, max_length=self.critic_max_len,
-                                                      pad_token_id=self.
+                                                      pad_token_id=self.pad_token_id)
             advantages = self._critic_advantages(critic_state, rewards)

         # 5. Encode and update STM on each step, to include encoder and memory attention gradients in loss
@@ -499,12 +494,12 @@ class MRLTrainer:
         if self.use_amp:
             with torch.amp.autocast(device_type=self.device.type, dtype=self.dtype):
                 inputs = smart_concat(next_query, action, max_length=self.max_seq_len,
-                                      pad_token_id=self.
+                                      pad_token_id=self.pad_token_id)
                 logits = self.actor(inputs['input_ids'], attention_mask=inputs['attention_mask'],
                                     action=MrlActorAction.DECODE)
         else:
             inputs = smart_concat(next_query, action, max_length=self.max_seq_len,
-                                  pad_token_id=self.
+                                  pad_token_id=self.pad_token_id)
             logits = self.actor(inputs['input_ids'], attention_mask=inputs['attention_mask'],
                                 action=MrlActorAction.DECODE)

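The trainer now keeps a single self.pad_token_id (the start_token_id argument and the old special-token mapping are removed) and passes it to every smart_concat / smart_concat_critic_states call. A toy illustration of that call shape follows; this is not the library's smart_concat (whose implementation lives in rxnn/training/utils.py), only a stand-in showing the inputs and the input_ids / attention_mask dict the trainer indexes afterwards:

import torch

# Toy stand-in for smart_concat (illustration only): concatenate a tokenized
# query and answer batch, pad to max_length with pad_token_id, and return the
# usual input_ids / attention_mask dict.
def toy_concat(query: dict, answer: dict, max_length: int, pad_token_id: int) -> dict:
    ids = torch.cat([query['input_ids'], answer['input_ids']], dim=-1)[:, :max_length]
    pad = torch.full((ids.size(0), max_length - ids.size(1)), pad_token_id, dtype=ids.dtype)
    ids = torch.cat([ids, pad], dim=-1)
    return {'input_ids': ids, 'attention_mask': (ids != pad_token_id).long()}

query = {'input_ids': torch.tensor([[5, 6, 7]])}
answer = {'input_ids': torch.tensor([[8, 9]])}
inputs = toy_concat(query, answer, max_length=8, pad_token_id=0)
print(inputs['input_ids'])       # tensor([[5, 6, 7, 8, 9, 0, 0, 0]])
print(inputs['attention_mask'])  # tensor([[1, 1, 1, 1, 1, 0, 0, 0]])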
rxnn/training/reward.py
CHANGED
@@ -33,13 +33,6 @@ class MrlRewardModel:
         self.device = device
         self.bleu_with_saved_data = bleu_with_saved_data

-        if not allow_not_summing_factors:
-            assert bleu_factor + cos_factor == 1.0
-            assert cos_ref_factor + cos_saved_factor == 1.0
-            assert neg_bleu_factor + neg_cos_factor == 1.0
-            assert neg_cos_ref_factor + neg_cos_saved_factor == 1.0
-            assert neg_bleu_ref_factor + neg_bleu_saved_factor == 1.0
-
         self.bleu_factor = bleu_factor
         self.cos_factor = cos_factor
         self.cos_ref_factor = cos_ref_factor
@@ -51,6 +44,13 @@ class MrlRewardModel:
         self.neg_bleu_ref_factor = neg_bleu_ref_factor
         self.neg_bleu_saved_factor = neg_bleu_saved_factor

+        if not allow_not_summing_factors:
+            assert self.bleu_factor + self.cos_factor == 1.0
+            assert self.cos_ref_factor + self.cos_saved_factor == 1.0
+            assert self.neg_bleu_factor + self.neg_cos_factor == 1.0
+            assert self.neg_cos_ref_factor + self.neg_cos_saved_factor == 1.0
+            assert self.neg_bleu_ref_factor + self.neg_bleu_saved_factor == 1.0
+
     def _sentence_bleu(self, generated: torch.Tensor, reference: torch.Tensor, saved_data: torch.Tensor) -> float:
         from nltk.translate.bleu_score import sentence_bleu
         refs = [reference, saved_data] if self.bleu_with_saved_data else [reference]

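The factor-sum assertions now run after the factors are stored on the instance and check the attributes rather than the constructor locals. A short sketch of a factor configuration that would pass the checks (the values are hypothetical; exactly representable floats are chosen so the == 1.0 comparisons hold):

# Hypothetical factor settings that satisfy the relocated assertions:
# each paired set of weights sums to exactly 1.0.
factors = dict(
    bleu_factor=0.5, cos_factor=0.5,
    cos_ref_factor=0.75, cos_saved_factor=0.25,
    neg_bleu_factor=0.75, neg_cos_factor=0.25,
    neg_cos_ref_factor=0.75, neg_cos_saved_factor=0.25,
    neg_bleu_ref_factor=0.75, neg_bleu_saved_factor=0.25,
)
assert factors['bleu_factor'] + factors['cos_factor'] == 1.0
assert factors['neg_bleu_ref_factor'] + factors['neg_bleu_saved_factor'] == 1.0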
{rxnn-0.2.2.dist-info → rxnn-0.2.4.dist-info}/RECORD
CHANGED
@@ -9,19 +9,19 @@ rxnn/memory/attention.py,sha256=p-r8DK3iVhNn-JAESVzIXDCG8gk1R_-x5xHclZ5jgb0,1813
 rxnn/memory/norm.py,sha256=Ofl8Q5NYEF9GQeO0bhM43tkTW91J0y6TSvTAOYMgloM,6278
 rxnn/memory/stm.py,sha256=eSMK5KdupWNf56FcDYprHnjA51EeYBzSKza7tiZxKSc,3618
 rxnn/rxt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/rxt/models.py,sha256=
+rxnn/rxt/models.py,sha256=qlYgU002VE21ZOlcxEM9iv9tAvsbe4mngcMI2sw3j9k,12078
 rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/training/base.py,sha256=_xik1GXE4RJ_nxwqLQ1ccXA5pRtBCi-jL-jeRFBdHBU,11851
 rxnn/training/bml.py,sha256=FJszaQXOLx2ZHBa1CQpyMrG8i4Kj14E-gzDAEK_Ei5k,17272
-rxnn/training/callbacks.py,sha256=
+rxnn/training/callbacks.py,sha256=o68IPFJyWM1CGooPRDNU9DfNcy4H_o0PcKDTn_ZLnKA,35053
 rxnn/training/dataset.py,sha256=XeRzo0KUYyQ43XjZ3o6Jban9ePIRtpHsqUmeKAQPRQk,50305
 rxnn/training/models.py,sha256=renPa5YH443XNTMFI-YTCwi5vNp3QzwF5UXedNd5hDk,5187
-rxnn/training/mrl.py,sha256=
-rxnn/training/reward.py,sha256=
+rxnn/training/mrl.py,sha256=WDQ8xsrHfpRmTczDZhBuOlqHX8JBaEp5SchlTdAxttY,38883
+rxnn/training/reward.py,sha256=i0nhrPCDgy1di89HWylRBS6cQ7rSSxJUiS3TX8fiiHE,5614
 rxnn/training/rl.py,sha256=FKrBOBAfNub_qzkceFQR-WUtCBffC6oGHE8wlPsz2YA,2682
 rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
 rxnn/training/tokenizer.py,sha256=umaLByMBx_NMrQElA45HLm9gkuzyKWDTFaKVd-CjXl0,8344
-rxnn/training/utils.py,sha256=
+rxnn/training/utils.py,sha256=7ED5RIC8AybCmmQrbsU6Krd7brRILxVIeTlJYtJWl_4,5702
 rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/transformers/attention.py,sha256=d0Igo1Nrn76BphbHrzekiKJfT3RCy4iSXSB6FLAOTgc,16247
 rxnn/transformers/ff.py,sha256=jJnuBDsnnX5uYC_WZH8cXAYrMnz0P-iX7MwcPivjRtI,2533
@@ -32,7 +32,7 @@ rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
 rxnn/transformers/sampler.py,sha256=2dpUQv88ekZa_CMSPLrXvB6X684wxUE2bDVznsi5ACs,17429
 rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
-rxnn-0.2.
-rxnn-0.2.
-rxnn-0.2.
-rxnn-0.2.
+rxnn-0.2.4.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.2.4.dist-info/METADATA,sha256=8qcHy1ysyg_6GiNe5Jd0sxsix9rPBDR_RhYgvCodK28,25959
+rxnn-0.2.4.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+rxnn-0.2.4.dist-info/RECORD,,