rxnn 0.2.2__py3-none-any.whl → 0.2.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rxnn/rxt/models.py CHANGED
@@ -53,7 +53,7 @@ class RxTAlphaComponentBase(nn.Module, PyTorchModelHubMixin):
  att_heads: int = 16,
  seq_len: int = 1024,
  stm_size: int = 1024,
- use_flash_attention: bool = True,
+ use_flash_attention: bool = False,
  use_gated: bool = True,
  ff_activation: str = "swish",
  ff_dropout: float = 0.0,
@@ -232,7 +232,7 @@ class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2
  att_heads: int = 16,
  seq_len: int = 1024,
  stm_size: int = 1024,
- use_flash_attention: bool = True,
+ use_flash_attention: bool = False,
  att_dropout: float = 0.0,
  norm_type: str = 'rms',
  att_groups: int = 1,
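
Note: both constructors flip the use_flash_attention default from True to False, so the 0.2.2 behaviour now requires opting in explicitly. A minimal sketch, assuming the other constructor arguments keep the defaults shown above (real models may need additional arguments not visible in this hunk):

from rxnn.rxt.models import RxTAlphaMemoryAttention

# Hypothetical usage: restore the pre-0.2.4 behaviour by passing the flag explicitly.
mem_att = RxTAlphaMemoryAttention(
    att_heads=16,
    seq_len=1024,
    stm_size=1024,
    use_flash_attention=True,  # default was True up to 0.2.2, False from 0.2.4
)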
@@ -271,7 +271,7 @@ class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2
  self.model.update_max_len(max_seq_len)

  def reset_memory(self, init_type: str = None):
- self.model.stm.reset_memory(init_type)
+ self.model.stm.reset(init_type)

  def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
  return self.model(x, attention_mask=attention_mask)
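
Note: only the internal call changes here; the wrapper's public reset_memory method keeps its name and now forwards to the STM's reset method instead of reset_memory. A hedged sketch (the 'normal' init_type value is a placeholder, not taken from the diff):

# Hypothetical call; still valid after 0.2.4.
mem_att.reset_memory(init_type='normal')
# ...which, as of 0.2.4, delegates to: self.model.stm.reset(init_type)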
@@ -557,7 +557,7 @@ class MrlPrintCallback(MrlTrainerCallback):
  def on_epoch_start(self, actor: nn.Module, epoch: int, stage_epochs: int, curriculum_config: dict,
  global_epoch: int, global_epochs: int) -> None:
  print(
- f'Starting epoch {epoch}/{stage_epochs} (stage) | {global_epoch}/{global_epochs} (global) for {curriculum_config['steps']} steps in {curriculum_config['strategy']} strategy.')
+ f'Starting epoch {epoch}/{stage_epochs} (stage) | {global_epoch}/{global_epochs} (global) for {curriculum_config["steps"]} steps in {curriculum_config["strategy"]} strategy.')

  def on_epoch_end(self, actor: nn.Module, epoch: int, stage_epochs: int, policy_loss: float,
  critic_loss: float, global_epoch: int, global_epochs: int) -> None:
@@ -580,7 +580,7 @@ class MrlPrintCallback(MrlTrainerCallback):
  print(f'Epoch {epoch} | Step {step} - updated policy loss {critic_loss}')

  def on_training_end(self, actor: nn.Module, critic: nn.Module, curriculum_config: dict) -> None:
- print(f'Finished training for {curriculum_config['steps']} steps in {curriculum_config['strategy']} strategy.')
+ print(f'Finished training for {curriculum_config["steps"]} steps in {curriculum_config["strategy"]} strategy.')

  def on_eval_end(self, actor: nn.Module, critic: nn.Module, epoch: int, eval_mean_reward: float) -> None:
  print(f'Eval epoch {epoch} - mean reward {eval_mean_reward}')
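
Note: the two print statements above reused single quotes for the dict subscripts inside single-quoted f-strings, which is a SyntaxError on Python versions before 3.12 (PEP 701 lifted that restriction); switching the inner quotes to double quotes works everywhere. A minimal, self-contained illustration:

curriculum_config = {'steps': 10, 'strategy': 'multi-step'}  # placeholder values

# SyntaxError on Python < 3.12 (same quote character reused inside the f-string):
# print(f'Finished training for {curriculum_config['steps']} steps')

# Portable form used in 0.2.4:
print(f'Finished training for {curriculum_config["steps"]} steps')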
rxnn/training/mrl.py CHANGED
@@ -9,7 +9,7 @@ import random, os
  from ..transformers.sampler import BatchSampler
  from .callbacks import MrlTrainerCallback
  from .dataset import MrlCurriculumDataset
- from .utils import smart_concat, smart_concat_critic_states, SpecialTokenIds, TokenizedDict
+ from .utils import smart_concat, smart_concat_critic_states, TokenizedDict
  from .rl import RlAlgorithm
  from .reward import MrlRewardMode, MrlRewardModel
  from .models import MrlActorAction, MrlActorModel, MrlCriticModel
@@ -74,7 +74,6 @@ class MRLTrainer:
  sampler_config: Optional[SamplerConfig] = None,
  log_dir: str = None,
  pad_token_id: int = 0,
- start_token_id: int = 2,
  end_token_id: int = 3,
  use_ddp: bool = False,
  use_amp: bool = False,
@@ -112,11 +111,7 @@ class MRLTrainer:
  top_p=None,
  ) if sampler_config is None else sampler_config

- self.special_token_ids: SpecialTokenIds = {
- 'pad': pad_token_id,
- 'bos': start_token_id,
- 'eos': end_token_id,
- }
+ self.pad_token_id = pad_token_id

  self.use_ddp = use_ddp
  self.use_amp = use_amp
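
Note: MRLTrainer now stores only pad_token_id, and the start_token_id constructor argument is removed, so passing it will typically raise a TypeError; all later call sites in this file switch from self.special_token_ids['pad'] to self.pad_token_id. A hedged construction sketch (the positional arguments are placeholders, not taken from the diff):

# Hypothetical: constructing the trainer against the 0.2.4 signature.
trainer = MRLTrainer(
    actor_model, critic_model, reward_model,  # placeholder required arguments
    pad_token_id=0,
    end_token_id=3,
    # start_token_id=2,  # removed in 0.2.4
)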
@@ -191,12 +186,12 @@ class MRLTrainer:
  if self.use_amp:
  with torch.amp.autocast(device_type=self.device.type, dtype=self.dtype):
  # 2. Concatenate batch of queries and answers (they are already on training device)
- inputs = smart_concat(query, answer, self.max_seq_len, self.special_token_ids['pad'])
+ inputs = smart_concat(query, answer, self.max_seq_len, self.pad_token_id)
  # 3. Encode data and update STM
  self.actor(inputs['input_ids'], attention_mask=inputs['attention_mask'], action=MrlActorAction.UPDATE)
  else:
  # 2. Concatenate batch of queries and answers (they are already on training device)
- inputs = smart_concat(query, answer, self.max_seq_len, self.special_token_ids['pad'])
+ inputs = smart_concat(query, answer, self.max_seq_len, self.pad_token_id)
  # 3. Encode data and update STM
  self.actor(inputs['input_ids'], attention_mask=inputs['attention_mask'], action=MrlActorAction.UPDATE)

@@ -235,11 +230,11 @@ class MRLTrainer:
  if self.use_amp:
  with torch.amp.autocast(device_type=self.device.type, dtype=self.dtype):
  saved_interaction = smart_concat(saved_query, saved_answer, max_length=self.max_seq_len,
- pad_token_id=self.special_token_ids['pad'])
+ pad_token_id=self.pad_token_id)
  reward = self.reward(generated, reference, saved_interaction, mode=mode)
  else:
  saved_interaction = smart_concat(saved_query, saved_answer, max_length=self.max_seq_len,
- pad_token_id=self.special_token_ids['pad'])
+ pad_token_id=self.pad_token_id)
  reward = self.reward(generated, reference, saved_interaction, mode=mode)

  # 2. Run 'on reward' callbacks
@@ -404,7 +399,7 @@ class MRLTrainer:
  inputs = smart_concat_critic_states(
  prev_query, prev_answer, next_query,
  max_length=self.critic_max_len,
- pad_token_id=self.special_token_ids['pad'],
+ pad_token_id=self.pad_token_id,
  )
  loss = self._critic_loss(inputs, batch_rewards)
  # Run backpropagation with scaler
@@ -420,7 +415,7 @@ class MRLTrainer:
  inputs = smart_concat_critic_states(
  prev_query, prev_answer, next_query,
  max_length=self.critic_max_len,
- pad_token_id=self.special_token_ids['pad'],
+ pad_token_id=self.pad_token_id,
  )
  # Calculate loss
  loss = self._critic_loss(inputs, reward.to(self.device, dtype=self.dtype))
@@ -486,11 +481,11 @@ class MRLTrainer:
  with torch.amp.autocast(device_type=self.device.type, dtype=self.dtype):
  critic_state = smart_concat_critic_states(query, answer, next_query,
  max_length=self.critic_max_len,
- pad_token_id=self.special_token_ids['pad'])
+ pad_token_id=self.pad_token_id)
  advantages = self._critic_advantages(critic_state, rewards)
  else:
  critic_state = smart_concat_critic_states(query, answer, next_query, max_length=self.critic_max_len,
- pad_token_id=self.special_token_ids['pad'])
+ pad_token_id=self.pad_token_id)
  advantages = self._critic_advantages(critic_state, rewards)

  # 5. Encode and update STM on each step, to include encoder and memory attention gradients in loss
@@ -499,12 +494,12 @@ class MRLTrainer:
  if self.use_amp:
  with torch.amp.autocast(device_type=self.device.type, dtype=self.dtype):
  inputs = smart_concat(next_query, action, max_length=self.max_seq_len,
- pad_token_id=self.special_token_ids['pad'])
+ pad_token_id=self.pad_token_id)
  logits = self.actor(inputs['input_ids'], attention_mask=inputs['attention_mask'],
  action=MrlActorAction.DECODE)
  else:
  inputs = smart_concat(next_query, action, max_length=self.max_seq_len,
- pad_token_id=self.special_token_ids['pad'])
+ pad_token_id=self.pad_token_id)
  logits = self.actor(inputs['input_ids'], attention_mask=inputs['attention_mask'],
  action=MrlActorAction.DECODE)

rxnn/training/reward.py CHANGED
@@ -33,13 +33,6 @@ class MrlRewardModel:
  self.device = device
  self.bleu_with_saved_data = bleu_with_saved_data

- if not allow_not_summing_factors:
- assert bleu_factor + cos_factor == 1.0
- assert cos_ref_factor + cos_saved_factor == 1.0
- assert neg_bleu_factor + neg_cos_factor == 1.0
- assert neg_cos_ref_factor + neg_cos_saved_factor == 1.0
- assert neg_bleu_ref_factor + neg_bleu_saved_factor == 1.0
-
  self.bleu_factor = bleu_factor
  self.cos_factor = cos_factor
  self.cos_ref_factor = cos_ref_factor
@@ -51,6 +44,13 @@ class MrlRewardModel:
  self.neg_bleu_ref_factor = neg_bleu_ref_factor
  self.neg_bleu_saved_factor = neg_bleu_saved_factor

+ if not allow_not_summing_factors:
+ assert self.bleu_factor + self.cos_factor == 1.0
+ assert self.cos_ref_factor + self.cos_saved_factor == 1.0
+ assert self.neg_bleu_factor + self.neg_cos_factor == 1.0
+ assert self.neg_cos_ref_factor + self.neg_cos_saved_factor == 1.0
+ assert self.neg_bleu_ref_factor + self.neg_bleu_saved_factor == 1.0
+
  def _sentence_bleu(self, generated: torch.Tensor, reference: torch.Tensor, saved_data: torch.Tensor) -> float:
  from nltk.translate.bleu_score import sentence_bleu
  refs = [reference, saved_data] if self.bleu_with_saved_data else [reference]
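
Note: the factor-sum assertions are unchanged in substance; they simply run after the assignments and check the stored self.* attributes instead of the constructor locals, and can still be disabled with allow_not_summing_factors. A standalone sketch of the invariant (names are placeholders, not the library API):

def validate_factor_pair(a: float, b: float, allow_not_summing: bool = False) -> None:
    # Mirrors the constructor check: paired reward weights must sum to 1.0.
    if not allow_not_summing:
        assert a + b == 1.0, f'factors must sum to 1.0, got {a + b}'

validate_factor_pair(0.5, 0.5)                          # passes
validate_factor_pair(0.7, 0.5, allow_not_summing=True)  # check skipped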
rxnn/training/utils.py CHANGED
@@ -1,10 +1,6 @@
  import torch
  from typing import TypedDict

- class SpecialTokenIds(TypedDict):
- bos: int
- eos: int
- pad: int

  class TokenizedDict(TypedDict):
  input_ids: torch.Tensor
rxnn-0.2.2.dist-info/METADATA → rxnn-0.2.4.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.3
  Name: rxnn
- Version: 0.2.2
+ Version: 0.2.4
  Summary: RxNN: Reactive Neural Networks Platform
  License: Apache-2.0
  Keywords: deep-learning,ai,machine-learning
rxnn-0.2.2.dist-info/RECORD → rxnn-0.2.4.dist-info/RECORD CHANGED
@@ -9,19 +9,19 @@ rxnn/memory/attention.py,sha256=p-r8DK3iVhNn-JAESVzIXDCG8gk1R_-x5xHclZ5jgb0,1813
  rxnn/memory/norm.py,sha256=Ofl8Q5NYEF9GQeO0bhM43tkTW91J0y6TSvTAOYMgloM,6278
  rxnn/memory/stm.py,sha256=eSMK5KdupWNf56FcDYprHnjA51EeYBzSKza7tiZxKSc,3618
  rxnn/rxt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- rxnn/rxt/models.py,sha256=zNrf6mn-s2vJyauHwNgYm_e-gFI1clmXp_JyCKGQD3E,12083
+ rxnn/rxt/models.py,sha256=qlYgU002VE21ZOlcxEM9iv9tAvsbe4mngcMI2sw3j9k,12078
  rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  rxnn/training/base.py,sha256=_xik1GXE4RJ_nxwqLQ1ccXA5pRtBCi-jL-jeRFBdHBU,11851
  rxnn/training/bml.py,sha256=FJszaQXOLx2ZHBa1CQpyMrG8i4Kj14E-gzDAEK_Ei5k,17272
- rxnn/training/callbacks.py,sha256=aqi8CfXUWnjMDbELYC5BPBbYyq0YiMicyVaTIr778DY,35053
+ rxnn/training/callbacks.py,sha256=o68IPFJyWM1CGooPRDNU9DfNcy4H_o0PcKDTn_ZLnKA,35053
  rxnn/training/dataset.py,sha256=XeRzo0KUYyQ43XjZ3o6Jban9ePIRtpHsqUmeKAQPRQk,50305
  rxnn/training/models.py,sha256=renPa5YH443XNTMFI-YTCwi5vNp3QzwF5UXedNd5hDk,5187
- rxnn/training/mrl.py,sha256=KcGvBWlBcFJ5GSwd4lx3pUXKlcyeNgJYPZAk3DRMH48,39179
- rxnn/training/reward.py,sha256=bjm8ya-HFIRA56JvQgnhtotKEpt8yw6yacVTV_SDpm4,5564
+ rxnn/training/mrl.py,sha256=WDQ8xsrHfpRmTczDZhBuOlqHX8JBaEp5SchlTdAxttY,38883
+ rxnn/training/reward.py,sha256=i0nhrPCDgy1di89HWylRBS6cQ7rSSxJUiS3TX8fiiHE,5614
  rxnn/training/rl.py,sha256=FKrBOBAfNub_qzkceFQR-WUtCBffC6oGHE8wlPsz2YA,2682
  rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
  rxnn/training/tokenizer.py,sha256=umaLByMBx_NMrQElA45HLm9gkuzyKWDTFaKVd-CjXl0,8344
- rxnn/training/utils.py,sha256=c-6aBaLnKeGfMW6Sp29z3FPLj5hdV3pyGJ2rZMcKs2s,5775
+ rxnn/training/utils.py,sha256=7ED5RIC8AybCmmQrbsU6Krd7brRILxVIeTlJYtJWl_4,5702
  rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  rxnn/transformers/attention.py,sha256=d0Igo1Nrn76BphbHrzekiKJfT3RCy4iSXSB6FLAOTgc,16247
  rxnn/transformers/ff.py,sha256=jJnuBDsnnX5uYC_WZH8cXAYrMnz0P-iX7MwcPivjRtI,2533
@@ -32,7 +32,7 @@ rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
  rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
  rxnn/transformers/sampler.py,sha256=2dpUQv88ekZa_CMSPLrXvB6X684wxUE2bDVznsi5ACs,17429
  rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
- rxnn-0.2.2.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
- rxnn-0.2.2.dist-info/METADATA,sha256=GlH7tyaDt27dzlp7G3CafWLAic8S5dTd-eiYKzDNQlA,25959
- rxnn-0.2.2.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
- rxnn-0.2.2.dist-info/RECORD,,
+ rxnn-0.2.4.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+ rxnn-0.2.4.dist-info/METADATA,sha256=8qcHy1ysyg_6GiNe5Jd0sxsix9rPBDR_RhYgvCodK28,25959
+ rxnn-0.2.4.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+ rxnn-0.2.4.dist-info/RECORD,,
File without changes
File without changes