rxnn-0.2.13-py3-none-any.whl → rxnn-0.2.14-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in a supported public registry. It is provided for informational purposes only.
rxnn/training/dataset.py CHANGED
@@ -1098,6 +1098,7 @@ class MrlDatasets:
             load_kwargs: dict = None,
             mrl_ds_kwargs: dict = None,
             eval_split: str = None,
+            max_seq_len: int = 256,
     ):
         """
         Load dataset from HuggingFace Hub and convert it to RxNN training dataset.
@@ -1115,6 +1116,7 @@ class MrlDatasets:
             load_kwargs (dict): Additional args for HuggingFace API load_dataset function
             mrl_ds_kwargs (dict): Additional args for RxNN MrlCurriculumDataset class
             eval_split (str): Load also evaluation/validation split (default: None)
+            max_seq_len (int): Maximum sequence length (default: 256)
         """
         if load_kwargs is None:
             load_kwargs = {}
@@ -1131,6 +1133,7 @@ class MrlDatasets:
             interactions_field=interactions_field,
             split=load_split,
             load_kwargs=load_kwargs,
+            max_seq_len=max_seq_len,
             **mrl_ds_kwargs,
         )
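Note on the dataset.py change: the new max_seq_len argument (default 256) is accepted by the MrlDatasets loader and forwarded to MrlCurriculumDataset. A minimal usage sketch follows; the classmethod name (from_hf_hub), the remaining arguments, and the dataset id are assumptions for illustration, only the max_seq_len keyword and its default come from this diff.

    # Hypothetical usage sketch -- the classmethod name (from_hf_hub), its other
    # parameters and the dataset id are assumptions; only the max_seq_len keyword
    # (default 256) is taken from this diff.
    from rxnn.training.dataset import MrlDatasets

    datasets = MrlDatasets.from_hf_hub(
        'your-org/your-mrl-dataset',   # placeholder dataset id
        load_kwargs={},                # extra args for HuggingFace load_dataset
        mrl_ds_kwargs={},              # extra args for MrlCurriculumDataset
        eval_split='validation',       # optionally load an evaluation split
        max_seq_len=256,               # new in 0.2.14, forwarded to MrlCurriculumDataset
    )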
 
rxnn/training/mrl.py CHANGED
@@ -92,8 +92,8 @@ class MRLTrainer:
         self.critic = critic
         self.reward = reward
         self.device = device
-        self.max_seq_len = config.get('max_seq_len', 1024)
-        self.critic_max_len = config.get('critic_max_len', 2048)
+        self.max_seq_len = config.get('max_seq_len', 256)
+        self.critic_max_len = config.get('critic_max_len', 512)
 
         # Move models to device
         if use_amp:
@@ -306,6 +306,7 @@ class MRLTrainer:
             for i, interaction in enumerate(interactions):
                 # 8. Generate batch of answers based on batch of follow-up queries
                 next_query = self._move_batch(interaction['query'])
+                print(next_query['input_ids'].size())
                 generated_answer, log_probs = self.generate_answer(next_query)
 
                 is_last_interaction = (i + 1) == interactions_len
@@ -365,7 +366,7 @@ class MRLTrainer:
     def _critic_loss(self, inputs: TokenizedDict, rewards: torch.Tensor) -> torch.Tensor:
         # 1. Calculate values with critic encoder
         values = self.critic(
-            input_ids=inputs['input_ids'],
+            inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
         ).squeeze()
         # 2. Calculate critic loss
@@ -461,6 +462,7 @@ class MRLTrainer:
         # 1. Run update separately for episodes in trajectory - we have to reset memory before each episode, and update
         # memory, based on collected episode data
         all_losses = []
+        trajectories_len = len(trajectories)
         for episode_idx, episode in enumerate(trajectories):
             episode_steps = episode['steps']
             should_reset_stm = episode['reset_stm']
@@ -506,7 +508,7 @@ class MRLTrainer:
                                     action=MrlActorAction.DECODE)
 
                 # 7. Calculate RL Algorithm (PPO etc.) loss
-                policy_loss = self.rl_algorithm.policy_loss(action['input_ids'], logits, log_probs, advantages)
+                policy_loss = self.rl_algorithm.policy_loss(next_query, action, inputs['input_ids'], logits, log_probs, advantages)
 
                 # 8. Reset gradients
                 self.optimizer.zero_grad()
@@ -593,63 +595,64 @@ class MRLTrainer:
         # 2. Run evaluation on all batch episodes
         for batch in dataloader:
             with torch.no_grad():
-                self._increment_steps('eval')
-                # 3. Reset STM with random resets ratio
-                self.reset_stm()
-
-                # 4. Get batches for first queries, answers and all follow-up interactions
-                first_query, first_answer, interactions = batch['query'], batch['answer'], batch['interactions']
-                # 5. Encode and update STM with initial interactions (batch)
-                self.encode_and_update_stm(*self._move_multiple_batches(first_query, first_answer))
-
-                # 6. Save follow-up interactions len and first query and answer as previous one for iteration
-                interactions_len = len(interactions)
-                query, answer = first_query, first_answer
-                episode_reward = torch.tensor(0.0).to(self.device)
-                episode_interactions = torch.tensor(0).to(self.device)
-                # 7. Run all follow-up interactions
-                for i, interaction in enumerate(interactions):
-                    # 8. Generate batch of answers
-                    next_query = self._move_batch(interaction['query'])
-                    generated_answer, _ = self.generate_answer(next_query)
-
-                    is_last_interaction = (i + 1) == interactions_len
-
-                    detached_answer = self._cpu_detach(generated_answer)
-
-                    # 9. Depending on current strategy and step, compute reward
-                    if self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and i == 0:
-                        reward = self.compute_reward(detached_answer, interaction['answer'], (query, answer),
-                                                     mode=MrlRewardMode.NEGATIVE, eval_mode=True)
-                    elif self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and is_last_interaction:
-                        reward = self.compute_reward(detached_answer, interaction['answer'],
-                                                     (first_query, first_answer), mode=MrlRewardMode.LONG_RANGE,
-                                                     eval_mode=True)
-                    else:
-                        reward = self.compute_reward(detached_answer, interaction['answer'], (query, answer),
-                                                     mode=MrlRewardMode.STANDARD, eval_mode=True)
-
-                    # 10. Encode and update memory for the next interaction
-                    if not is_last_interaction:
-                        self.encode_and_update_stm(next_query, generated_answer)
-
-                    # 11. Accumulate rewards
-                    step_reward = torch.tensor(reward).mean().to(self.device)
-                    # total
-                    total_reward += step_reward
-                    count += 1
-                    # episode
-                    episode_reward += step_reward
-                    episode_interactions += 1
-                    # 12. Save previous interaction
-                    query, answer = interaction['query'], detached_answer
-                avg_episode_reward = (episode_reward / episode_interactions).item()
-                # 13. Run eval TensorBoard writer with average episode reward
-                self._eval_writer(avg_episode_reward, epoch)
-
-                # 14. Run "on eval episode end" callbacks
-                for cb in self.callbacks:
-                    cb.on_eval_episode_end(self.actor, epoch, self.epoch_step['eval'], avg_episode_reward)
+                if batch['query']['input_ids'].size(0) == batch_size:
+                    self._increment_steps('eval')
+                    # 3. Reset STM with random resets ratio
+                    self.reset_stm()
+
+                    # 4. Get batches for first queries, answers and all follow-up interactions
+                    first_query, first_answer, interactions = batch['query'], batch['answer'], batch['interactions']
+                    # 5. Encode and update STM with initial interactions (batch)
+                    self.encode_and_update_stm(*self._move_multiple_batches(first_query, first_answer))
+
+                    # 6. Save follow-up interactions len and first query and answer as previous one for iteration
+                    interactions_len = len(interactions)
+                    query, answer = first_query, first_answer
+                    episode_reward = torch.tensor(0.0).to(self.device)
+                    episode_interactions = torch.tensor(0).to(self.device)
+                    # 7. Run all follow-up interactions
+                    for i, interaction in enumerate(interactions):
+                        # 8. Generate batch of answers
+                        next_query = self._move_batch(interaction['query'])
+                        generated_answer, _ = self.generate_answer(next_query)
+
+                        is_last_interaction = (i + 1) == interactions_len
+
+                        detached_answer = self._cpu_detach(generated_answer)
+
+                        # 9. Depending on current strategy and step, compute reward
+                        if self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and i == 0:
+                            reward = self.compute_reward(detached_answer, interaction['answer'], (query, answer),
+                                                         mode=MrlRewardMode.NEGATIVE, eval_mode=True)
+                        elif self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and is_last_interaction:
+                            reward = self.compute_reward(detached_answer, interaction['answer'],
+                                                         (first_query, first_answer), mode=MrlRewardMode.LONG_RANGE,
+                                                         eval_mode=True)
+                        else:
+                            reward = self.compute_reward(detached_answer, interaction['answer'], (query, answer),
+                                                         mode=MrlRewardMode.STANDARD, eval_mode=True)
+
+                        # 10. Encode and update memory for the next interaction
+                        if not is_last_interaction:
+                            self.encode_and_update_stm(next_query, generated_answer)
+
+                        # 11. Accumulate rewards
+                        step_reward = torch.tensor(reward).mean().to(self.device)
+                        # total
+                        total_reward += step_reward
+                        count += 1
+                        # episode
+                        episode_reward += step_reward
+                        episode_interactions += 1
+                        # 12. Save previous interaction
+                        query, answer = interaction['query'], detached_answer
+                    avg_episode_reward = (episode_reward / episode_interactions).item()
+                    # 13. Run eval TensorBoard writer with average episode reward
+                    self._eval_writer(avg_episode_reward, epoch)
+
+                    # 14. Run "on eval episode end" callbacks
+                    for cb in self.callbacks:
+                        cb.on_eval_episode_end(self.actor, epoch, self.epoch_step['eval'], avg_episode_reward)
 
         # 15. Calculate average reward
         if self.use_ddp:
@@ -804,3 +807,4 @@ class MRLTrainer:
         # 21. Close writer
         if self.writer:
             self.writer.close()
+
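Note on the mrl.py changes: the default sequence limits drop from 1024/2048 to 256/512 (max_seq_len / critic_max_len), evaluation now skips incomplete batches, a debug print of the query shape is added, and the policy_loss call passes the tokenized query and answer instead of raw input_ids. Below is a minimal sketch of overriding the new defaults, assuming only that MRLTrainer reads these keys from its config dict via config.get, as the diff shows; the remaining constructor arguments are placeholders, not the exact signature.

    # Minimal sketch, assuming MRLTrainer reads these keys from its config dict
    # via config.get(...); the other constructor arguments are placeholders.
    from rxnn.training.mrl import MRLTrainer

    config = {
        'max_seq_len': 1024,     # restore the pre-0.2.14 actor limit if needed
        'critic_max_len': 2048,  # restore the pre-0.2.14 critic limit if needed
    }
    trainer = MRLTrainer(actor, critic, reward, config=config, device=device)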
rxnn/training/rl.py CHANGED
@@ -3,6 +3,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from abc import ABC, abstractmethod
 from typing import TypedDict
+from .utils import TokenizedDict
 
 
 class RlAlgorithm(ABC):
@@ -11,7 +12,8 @@ class RlAlgorithm(ABC):
         self.critic_loss = nn.MSELoss()
 
     @abstractmethod
-    def policy_loss(self, input_ids: torch.Tensor, logits: torch.Tensor, old_log_probs: torch.Tensor, advantages: torch.Tensor) -> torch.Tensor:
+    def policy_loss(self, query: TokenizedDict, answer: TokenizedDict, logits: torch.Tensor,
+                    old_log_probs: torch.Tensor, advantages: torch.Tensor) -> torch.Tensor:
         pass
 
     @abstractmethod
@@ -35,10 +37,38 @@ class PPOAlgorithm(RlAlgorithm):
         self.gae_lambda = config.get('gae_lambda', 0.95)
         self.clip_eps = config.get('clip_eps', 0.2)
 
-    def policy_loss(self, input_ids: torch.Tensor, logits: torch.Tensor, old_log_probs: torch.Tensor, advantages: torch.Tensor) -> torch.Tensor:
+    def policy_loss(self, query: TokenizedDict, answer: TokenizedDict, logits: torch.Tensor,
+                    old_log_probs: torch.Tensor, advantages: torch.Tensor) -> torch.Tensor:
+        query_lens = query['attention_mask'].sum(dim=1).long()  # Query lengths per sample
+        answer_mask = answer['attention_mask']
+        answer_lens = answer_mask.sum(dim=1).long()  # Answer lengths per sample (before padding)
+
+        max_length = query['input_ids'].size(1)
+
+        combined_lens = torch.minimum(
+            query_lens + answer_lens,
+            torch.full_like(query_lens, max_length)
+        )
+
+        def extract_answer_tokens(tensor: torch.Tensor) -> torch.Tensor:
+            B, L, *rest = tensor.size()
+            result = torch.zeros((B, max_length, *rest), dtype=tensor.dtype, device=tensor.device)
+
+            for i in range(B):
+                s = query_lens[i].item()
+                e = combined_lens[i].item()
+                valid_len = e - s
+                if valid_len > 0:
+                    result[i, :valid_len] = tensor[i, s:e]
+            return result
+
+        new_logits = extract_answer_tokens(logits)
+
         # a) Get new log probs
-        new_probs = F.log_softmax(logits, dim=-1)
-        new_log_probs = new_probs.gather(-1, input_ids.unsqueeze(-1)).squeeze(-1)
+        new_probs = F.log_softmax(new_logits, dim=-1)
+        new_log_probs = new_probs.gather(-1, answer['input_ids'].unsqueeze(-1)).squeeze(-1)
+
+        new_log_probs = extract_answer_tokens(new_log_probs.unsqueeze(-1)).squeeze(-1)  # Ensure 3D for extraction (add singleton dim)
 
         # b) Calculate ratio
         ratio = (new_log_probs - old_log_probs).exp()
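Note on the rl.py changes: policy_loss now takes the tokenized query and answer (TokenizedDict), extracts the answer tokens from the logits using per-sample query lengths, and computes new log-probabilities only over those positions. The hunk ends at the probability ratio; the sketch below shows how a standard clipped PPO surrogate typically continues from that ratio, using the clip_eps default of 0.2 set in this class. It is a generic illustration, not the library's exact code, and the function name is hypothetical.

    # Generic clipped-PPO sketch continuing from the ratio computed above;
    # the function name and mean reduction are assumptions, not RxNN's exact code.
    import torch

    def clipped_ppo_policy_loss(new_log_probs: torch.Tensor,
                                old_log_probs: torch.Tensor,
                                advantages: torch.Tensor,
                                clip_eps: float = 0.2) -> torch.Tensor:
        # advantages must broadcast against the log-prob tensors,
        # e.g. shape (B, 1) against per-token log probs of shape (B, L)
        ratio = (new_log_probs - old_log_probs).exp()
        surr1 = ratio * advantages
        surr2 = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
        # pessimistic (clipped) surrogate, negated because optimizers minimize
        return -torch.min(surr1, surr2).mean()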
{rxnn-0.2.13.dist-info → rxnn-0.2.14.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.2.13
+Version: 0.2.14
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning
{rxnn-0.2.13.dist-info → rxnn-0.2.14.dist-info}/RECORD CHANGED
@@ -14,11 +14,11 @@ rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/training/base.py,sha256=_xik1GXE4RJ_nxwqLQ1ccXA5pRtBCi-jL-jeRFBdHBU,11851
 rxnn/training/bml.py,sha256=FJszaQXOLx2ZHBa1CQpyMrG8i4Kj14E-gzDAEK_Ei5k,17272
 rxnn/training/callbacks.py,sha256=-N0MQPpZQaUWCINdTOsjul4bDGbGr2JgQBqOIXBLS6o,35053
-rxnn/training/dataset.py,sha256=i8EdSJnoPbkuDSdqoYDj-Aig5Se_uPY4lulkD2bdOrs,50331
+rxnn/training/dataset.py,sha256=m1opjNA7XHl6Ys-NtERM00c0BLN2xuu84lsfXp-3GQA,50478
 rxnn/training/models.py,sha256=renPa5YH443XNTMFI-YTCwi5vNp3QzwF5UXedNd5hDk,5187
-rxnn/training/mrl.py,sha256=53uOwotmgwKeceMYA6qXQbQMZmggXt_5hq08X-YwrEY,39327
+rxnn/training/mrl.py,sha256=CezloyaXOKrc_F_eDt99EZ1fmKAMCCCMh5Ry6vF82Ro,39607
 rxnn/training/reward.py,sha256=C0ToTz-u-L-qyBd2yJ1HlvVPS110OChYj9ZhD6iSSMU,5654
-rxnn/training/rl.py,sha256=s6wPbg0X6y-RX9-5ctZIDpdJPfExI9DzWUy-TvAiiow,2710
+rxnn/training/rl.py,sha256=DHFwnPUlnq2JVj6CS6DwifnC_eMeBAUVp36UCAWNMis,3934
 rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
 rxnn/training/tokenizer.py,sha256=umaLByMBx_NMrQElA45HLm9gkuzyKWDTFaKVd-CjXl0,8344
 rxnn/training/utils.py,sha256=Bw8nZLKIt7NQpUVCYkb_79kWKChVFOYgYXwODo4SvNc,5718
@@ -32,7 +32,7 @@ rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
 rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
 rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
-rxnn-0.2.13.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
-rxnn-0.2.13.dist-info/METADATA,sha256=BOn4qig3IKpYiG0NEWHiF_5NWsWboBqVNeGb2-mYesU,25960
-rxnn-0.2.13.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
-rxnn-0.2.13.dist-info/RECORD,,
+rxnn-0.2.14.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.2.14.dist-info/METADATA,sha256=dutamudjxMj9IzykuCONpMyqnU4emEEwvseD4nmKkfs,25960
+rxnn-0.2.14.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+rxnn-0.2.14.dist-info/RECORD,,