rxnn 0.2.13__py3-none-any.whl → 0.2.15__py3-none-any.whl

This diff shows the contents of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
rxnn/training/dataset.py CHANGED
@@ -977,6 +977,7 @@ class MrlCurriculumDataset(Dataset):
  answer_field: str = 'answer',
  interactions_field: str = 'interactions',
  load_kwargs: dict = None,
+ max_seq_len: int = 1024,
  **kwargs
  ):
  """
@@ -993,6 +994,7 @@ class MrlCurriculumDataset(Dataset):
  answer_field (str): Answer field (default: "answer")
  interactions_field (str): Interactions field (default: "interactions")
  load_kwargs (dict): Additional args for HuggingFace API load_dataset function
+ max_seq_len (int): Maximum sequence length (default: 1024)
  **kwargs: Additional args for RxNN Dataset class
  """
  if load_kwargs is None:
@@ -1000,7 +1002,7 @@ class MrlCurriculumDataset(Dataset):

  hf_dataset = load_dataset(dataset_id, mrl_subset, split=split, **load_kwargs)

- return cls(hf_dataset, tokenizer, query_field=query_field, answer_field=answer_field, interactions_field=interactions_field, **kwargs)
+ return cls(hf_dataset, tokenizer, query_field=query_field, answer_field=answer_field, interactions_field=interactions_field, max_seq_len=max_seq_len, **kwargs)

  @staticmethod
  def collate_mrl_batch(batch: list[MrlDataItem]) -> MrlDataItem:
@@ -1098,6 +1100,7 @@ class MrlDatasets:
  load_kwargs: dict = None,
  mrl_ds_kwargs: dict = None,
  eval_split: str = None,
+ max_seq_len: int = 256,
  ):
  """
  Load dataset from HuggingFace Hub and convert it to RxNN training dataset.
@@ -1115,6 +1118,7 @@ class MrlDatasets:
  load_kwargs (dict): Additional args for HuggingFace API load_dataset function
  mrl_ds_kwargs (dict): Additional args for RxNN MrlCurriculumDataset class
  eval_split (str): Load also evaluation/validation split (default: None)
+ max_seq_len (int): Maximum sequence length (default: 256)
  """
  if load_kwargs is None:
  load_kwargs = {}
@@ -1131,6 +1135,7 @@ class MrlDatasets:
  interactions_field=interactions_field,
  split=load_split,
  load_kwargs=load_kwargs,
+ max_seq_len=max_seq_len,
  **mrl_ds_kwargs,
  )

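The max_seq_len value added above is the kind of cap that is normally enforced at tokenization time. A minimal sketch of that pattern, assuming a HuggingFace tokenizer; the checkpoint name and sample text are placeholders, not taken from the package:

```python
# Minimal sketch: how a max_seq_len cap is typically applied when tokenizing
# an interaction. Tokenizer checkpoint and example text are placeholders.
from transformers import AutoTokenizer

max_seq_len = 256  # matches the new MrlDatasets default shown above
tokenizer = AutoTokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default

encoded = tokenizer(
    "Example follow-up query text",
    max_length=max_seq_len,   # truncate anything longer than the cap
    truncation=True,
    padding='max_length',     # pad shorter sequences up to the cap
    return_tensors='pt',
)
print(encoded['input_ids'].shape)  # torch.Size([1, 256])
```

With truncation=True and padding='max_length', every tokenized item comes out exactly max_seq_len tokens long, which is what a fixed per-dataset cap implies.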
rxnn/training/mrl.py CHANGED
@@ -92,8 +92,8 @@ class MRLTrainer:
  self.critic = critic
  self.reward = reward
  self.device = device
- self.max_seq_len = config.get('max_seq_len', 1024)
- self.critic_max_len = config.get('critic_max_len', 2048)
+ self.max_seq_len = config.get('max_seq_len', 256)
+ self.critic_max_len = config.get('critic_max_len', 512)

  # Move models to device
  if use_amp:
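The two values read here come from the trainer's config dict, so the previous, larger limits can still be requested explicitly. A minimal sketch of the relevant entries only (all other MRLTrainer configuration is omitted):

```python
# Sketch of the relevant config entries; setting them explicitly overrides
# the new, smaller defaults (256 / 512) that replaced 1024 / 2048.
mrl_config = {
    'max_seq_len': 1024,      # read via config.get('max_seq_len', 256)
    'critic_max_len': 2048,   # read via config.get('critic_max_len', 512)
}
```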
@@ -306,6 +306,7 @@ class MRLTrainer:
  for i, interaction in enumerate(interactions):
  # 8. Generate batch of answers based on batch of follow-up queries
  next_query = self._move_batch(interaction['query'])
+ print(next_query['input_ids'].size())
  generated_answer, log_probs = self.generate_answer(next_query)

  is_last_interaction = (i + 1) == interactions_len
@@ -365,7 +366,7 @@ class MRLTrainer:
  def _critic_loss(self, inputs: TokenizedDict, rewards: torch.Tensor) -> torch.Tensor:
  # 1. Calculate values with critic encoder
  values = self.critic(
- input_ids=inputs['input_ids'],
+ inputs['input_ids'],
  attention_mask=inputs['attention_mask'],
  ).squeeze()
  # 2. Calculate critic loss
@@ -461,6 +462,7 @@ class MRLTrainer:
  # 1. Run update separately for episodes in trajectory - we have to reset memory before each episode, and update
  # memory, based on collected episode data
  all_losses = []
+ trajectories_len = len(trajectories)
  for episode_idx, episode in enumerate(trajectories):
  episode_steps = episode['steps']
  should_reset_stm = episode['reset_stm']
@@ -506,7 +508,7 @@ class MRLTrainer:
  action=MrlActorAction.DECODE)

  # 7. Calculate RL Algorithm (PPO etc.) loss
- policy_loss = self.rl_algorithm.policy_loss(action['input_ids'], logits, log_probs, advantages)
+ policy_loss = self.rl_algorithm.policy_loss(next_query, action, inputs['input_ids'], logits, log_probs, advantages)

  # 8. Reset gradients
  self.optimizer.zero_grad()
@@ -593,63 +595,64 @@ class MRLTrainer:
  # 2. Run evaluation on all batch episodes
  for batch in dataloader:
  with torch.no_grad():
- self._increment_steps('eval')
- # 3. Reset STM with random resets ratio
- self.reset_stm()
-
- # 4. Get batches for first queries, answers and all follow-up interactions
- first_query, first_answer, interactions = batch['query'], batch['answer'], batch['interactions']
- # 5. Encode and update STM with initial interactions (batch)
- self.encode_and_update_stm(*self._move_multiple_batches(first_query, first_answer))
-
- # 6. Save follow-up interactions len and first query and answer as previous one for iteration
- interactions_len = len(interactions)
- query, answer = first_query, first_answer
- episode_reward = torch.tensor(0.0).to(self.device)
- episode_interactions = torch.tensor(0).to(self.device)
- # 7. Run all follow-up interactions
- for i, interaction in enumerate(interactions):
- # 8. Generate batch of answers
- next_query = self._move_batch(interaction['query'])
- generated_answer, _ = self.generate_answer(next_query)
-
- is_last_interaction = (i + 1) == interactions_len
-
- detached_answer = self._cpu_detach(generated_answer)
-
- # 9. Depending on current strategy and step, compute reward
- if self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and i == 0:
- reward = self.compute_reward(detached_answer, interaction['answer'], (query, answer),
- mode=MrlRewardMode.NEGATIVE, eval_mode=True)
- elif self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and is_last_interaction:
- reward = self.compute_reward(detached_answer, interaction['answer'],
- (first_query, first_answer), mode=MrlRewardMode.LONG_RANGE,
- eval_mode=True)
- else:
- reward = self.compute_reward(detached_answer, interaction['answer'], (query, answer),
- mode=MrlRewardMode.STANDARD, eval_mode=True)
-
- # 10. Encode and update memory for the next interaction
- if not is_last_interaction:
- self.encode_and_update_stm(next_query, generated_answer)
-
- # 11. Accumulate rewards
- step_reward = torch.tensor(reward).mean().to(self.device)
- # total
- total_reward += step_reward
- count += 1
- # episode
- episode_reward += step_reward
- episode_interactions += 1
- # 12. Save previous interaction
- query, answer = interaction['query'], detached_answer
- avg_episode_reward = (episode_reward / episode_interactions).item()
- # 13. Run eval TensorBoard writer with average episode reward
- self._eval_writer(avg_episode_reward, epoch)
-
- # 14. Run "on eval episode end" callbacks
- for cb in self.callbacks:
- cb.on_eval_episode_end(self.actor, epoch, self.epoch_step['eval'], avg_episode_reward)
+ if batch['query']['input_ids'].size(0) == batch_size:
+ self._increment_steps('eval')
+ # 3. Reset STM with random resets ratio
+ self.reset_stm()
+
+ # 4. Get batches for first queries, answers and all follow-up interactions
+ first_query, first_answer, interactions = batch['query'], batch['answer'], batch['interactions']
+ # 5. Encode and update STM with initial interactions (batch)
+ self.encode_and_update_stm(*self._move_multiple_batches(first_query, first_answer))
+
+ # 6. Save follow-up interactions len and first query and answer as previous one for iteration
+ interactions_len = len(interactions)
+ query, answer = first_query, first_answer
+ episode_reward = torch.tensor(0.0).to(self.device)
+ episode_interactions = torch.tensor(0).to(self.device)
+ # 7. Run all follow-up interactions
+ for i, interaction in enumerate(interactions):
+ # 8. Generate batch of answers
+ next_query = self._move_batch(interaction['query'])
+ generated_answer, _ = self.generate_answer(next_query)
+
+ is_last_interaction = (i + 1) == interactions_len
+
+ detached_answer = self._cpu_detach(generated_answer)
+
+ # 9. Depending on current strategy and step, compute reward
+ if self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and i == 0:
+ reward = self.compute_reward(detached_answer, interaction['answer'], (query, answer),
+ mode=MrlRewardMode.NEGATIVE, eval_mode=True)
+ elif self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and is_last_interaction:
+ reward = self.compute_reward(detached_answer, interaction['answer'],
+ (first_query, first_answer), mode=MrlRewardMode.LONG_RANGE,
+ eval_mode=True)
+ else:
+ reward = self.compute_reward(detached_answer, interaction['answer'], (query, answer),
+ mode=MrlRewardMode.STANDARD, eval_mode=True)
+
+ # 10. Encode and update memory for the next interaction
+ if not is_last_interaction:
+ self.encode_and_update_stm(next_query, generated_answer)
+
+ # 11. Accumulate rewards
+ step_reward = torch.tensor(reward).mean().to(self.device)
+ # total
+ total_reward += step_reward
+ count += 1
+ # episode
+ episode_reward += step_reward
+ episode_interactions += 1
+ # 12. Save previous interaction
+ query, answer = interaction['query'], detached_answer
+ avg_episode_reward = (episode_reward / episode_interactions).item()
+ # 13. Run eval TensorBoard writer with average episode reward
+ self._eval_writer(avg_episode_reward, epoch)
+
+ # 14. Run "on eval episode end" callbacks
+ for cb in self.callbacks:
+ cb.on_eval_episode_end(self.actor, epoch, self.epoch_step['eval'], avg_episode_reward)

  # 15. Calculate average reward
  if self.use_ddp:
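The new guard above skips evaluation batches whose first dimension is smaller than batch_size, i.e. a short final batch. A minimal sketch of an equivalent approach at the DataLoader level, using standard PyTorch rather than the package's own code; the toy dataset stands in for the real MRL evaluation dataset:

```python
# Sketch: dropping the incomplete final batch up front instead of guarding on
# batch size inside the evaluation loop. A toy TensorDataset is a placeholder.
import torch
from torch.utils.data import DataLoader, TensorDataset

batch_size = 64
toy_dataset = TensorDataset(torch.zeros(1000, 8, dtype=torch.long))  # 1000 samples -> 15 full batches

eval_loader = DataLoader(
    toy_dataset,
    batch_size=batch_size,
    shuffle=False,
    drop_last=True,   # the last, smaller batch (1000 % 64 = 40 samples) is never yielded
)

for (batch,) in eval_loader:
    assert batch.size(0) == batch_size  # every yielded batch is full
```

Whether skipping in the loop or dropping in the loader is preferable depends on how much evaluation data one is willing to discard; both keep all yielded batches at a uniform size.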
@@ -804,3 +807,4 @@ class MRLTrainer:
  # 21. Close writer
  if self.writer:
  self.writer.close()
+
rxnn/training/rl.py CHANGED
@@ -3,6 +3,7 @@ import torch.nn as nn
  import torch.nn.functional as F
  from abc import ABC, abstractmethod
  from typing import TypedDict
+ from .utils import TokenizedDict


  class RlAlgorithm(ABC):
@@ -11,7 +12,8 @@ class RlAlgorithm(ABC):
  self.critic_loss = nn.MSELoss()

  @abstractmethod
- def policy_loss(self, input_ids: torch.Tensor, logits: torch.Tensor, old_log_probs: torch.Tensor, advantages: torch.Tensor) -> torch.Tensor:
+ def policy_loss(self, query: TokenizedDict, answer: TokenizedDict, logits: torch.Tensor,
+ old_log_probs: torch.Tensor, advantages: torch.Tensor) -> torch.Tensor:
  pass

  @abstractmethod
@@ -35,10 +37,38 @@ class PPOAlgorithm(RlAlgorithm):
  self.gae_lambda = config.get('gae_lambda', 0.95)
  self.clip_eps = config.get('clip_eps', 0.2)

- def policy_loss(self, input_ids: torch.Tensor, logits: torch.Tensor, old_log_probs: torch.Tensor, advantages: torch.Tensor) -> torch.Tensor:
+ def policy_loss(self, query: TokenizedDict, answer: TokenizedDict, logits: torch.Tensor,
+ old_log_probs: torch.Tensor, advantages: torch.Tensor) -> torch.Tensor:
+ query_lens = query['attention_mask'].sum(dim=1).long() # Query lengths per sample
+ answer_mask = answer['attention_mask']
+ answer_lens = answer_mask.sum(dim=1).long() # Answer lengths per sample (before padding)
+
+ max_length = query['input_ids'].size(1)
+
+ combined_lens = torch.minimum(
+ query_lens + answer_lens,
+ torch.full_like(query_lens, max_length)
+ )
+
+ def extract_answer_tokens(tensor: torch.Tensor) -> torch.Tensor:
+ B, L, *rest = tensor.size()
+ result = torch.zeros((B, max_length, *rest), dtype=tensor.dtype, device=tensor.device)
+
+ for i in range(B):
+ s = query_lens[i].item()
+ e = combined_lens[i].item()
+ valid_len = e - s
+ if valid_len > 0:
+ result[i, :valid_len] = tensor[i, s:e]
+ return result
+
+ new_logits = extract_answer_tokens(logits)
+
  # a) Get new log probs
- new_probs = F.log_softmax(logits, dim=-1)
- new_log_probs = new_probs.gather(-1, input_ids.unsqueeze(-1)).squeeze(-1)
+ new_probs = F.log_softmax(new_logits, dim=-1)
+ new_log_probs = new_probs.gather(-1, answer['input_ids'].unsqueeze(-1)).squeeze(-1)
+
+ new_log_probs = extract_answer_tokens(new_log_probs.unsqueeze(-1)).squeeze(-1) # Ensure 3D for extraction (add singleton dim)

  # b) Calculate ratio
  ratio = (new_log_probs - old_log_probs).exp()
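The extract_answer_tokens helper added above copies each sample's answer slice in a Python loop. The same result can be computed with a batched gather plus a mask; a minimal sketch under the same shape assumptions (a hypothetical standalone function, not part of the package):

```python
import torch

def extract_answer_tokens_vectorized(tensor: torch.Tensor,
                                     query_lens: torch.Tensor,
                                     combined_lens: torch.Tensor,
                                     max_length: int) -> torch.Tensor:
    """Batched version of the per-sample loop: result[i, :e-s] = tensor[i, s:e]."""
    B, L = tensor.size(0), tensor.size(1)
    rest = tensor.size()[2:]
    positions = torch.arange(max_length, device=tensor.device).unsqueeze(0)   # (1, max_length)
    src_idx = (query_lens.unsqueeze(1) + positions).clamp(max=L - 1)          # (B, max_length)
    valid = positions < (combined_lens - query_lens).unsqueeze(1)             # (B, max_length)
    # Expand indices over trailing dims (e.g. vocab) and gather along the time axis.
    gather_idx = src_idx.reshape(B, max_length, *([1] * len(rest))).expand(B, max_length, *rest)
    gathered = tensor.gather(1, gather_idx)
    # Zero out positions past each sample's answer length, matching the zero-initialised loop result.
    return gathered * valid.reshape(B, max_length, *([1] * len(rest))).to(tensor.dtype)
```

Under those assumptions, extract_answer_tokens_vectorized(logits, query_lens, combined_lens, max_length) produces the same zero-padded output as the loop, without per-sample Python iteration.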
rxnn-0.2.13.dist-info/METADATA → rxnn-0.2.15.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.3
  Name: rxnn
- Version: 0.2.13
+ Version: 0.2.15
  Summary: RxNN: Reactive Neural Networks Platform
  License: Apache-2.0
  Keywords: deep-learning,ai,machine-learning
rxnn-0.2.13.dist-info/RECORD → rxnn-0.2.15.dist-info/RECORD CHANGED
@@ -14,11 +14,11 @@ rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  rxnn/training/base.py,sha256=_xik1GXE4RJ_nxwqLQ1ccXA5pRtBCi-jL-jeRFBdHBU,11851
  rxnn/training/bml.py,sha256=FJszaQXOLx2ZHBa1CQpyMrG8i4Kj14E-gzDAEK_Ei5k,17272
  rxnn/training/callbacks.py,sha256=-N0MQPpZQaUWCINdTOsjul4bDGbGr2JgQBqOIXBLS6o,35053
- rxnn/training/dataset.py,sha256=i8EdSJnoPbkuDSdqoYDj-Aig5Se_uPY4lulkD2bdOrs,50331
+ rxnn/training/dataset.py,sha256=mXUZa6ypTt73sE-G-s9jQ4_Vhp8zw43bjhsLEEPPnDo,50611
  rxnn/training/models.py,sha256=renPa5YH443XNTMFI-YTCwi5vNp3QzwF5UXedNd5hDk,5187
- rxnn/training/mrl.py,sha256=53uOwotmgwKeceMYA6qXQbQMZmggXt_5hq08X-YwrEY,39327
+ rxnn/training/mrl.py,sha256=CezloyaXOKrc_F_eDt99EZ1fmKAMCCCMh5Ry6vF82Ro,39607
  rxnn/training/reward.py,sha256=C0ToTz-u-L-qyBd2yJ1HlvVPS110OChYj9ZhD6iSSMU,5654
- rxnn/training/rl.py,sha256=s6wPbg0X6y-RX9-5ctZIDpdJPfExI9DzWUy-TvAiiow,2710
+ rxnn/training/rl.py,sha256=DHFwnPUlnq2JVj6CS6DwifnC_eMeBAUVp36UCAWNMis,3934
  rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
  rxnn/training/tokenizer.py,sha256=umaLByMBx_NMrQElA45HLm9gkuzyKWDTFaKVd-CjXl0,8344
  rxnn/training/utils.py,sha256=Bw8nZLKIt7NQpUVCYkb_79kWKChVFOYgYXwODo4SvNc,5718
@@ -32,7 +32,7 @@ rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
  rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
  rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
  rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
- rxnn-0.2.13.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
- rxnn-0.2.13.dist-info/METADATA,sha256=BOn4qig3IKpYiG0NEWHiF_5NWsWboBqVNeGb2-mYesU,25960
- rxnn-0.2.13.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
- rxnn-0.2.13.dist-info/RECORD,,
+ rxnn-0.2.15.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+ rxnn-0.2.15.dist-info/METADATA,sha256=ZMMOf5u9DqEQPyYszLG-o51FcVN_3vzE4aDIwsdk-lg,25960
+ rxnn-0.2.15.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+ rxnn-0.2.15.dist-info/RECORD,,