rxnn 0.2.12__py3-none-any.whl → 0.2.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rxnn/memory/stm.py CHANGED
@@ -39,6 +39,7 @@ class ShortTermMemory(nn.Module):
         return self.memory[layer]
 
     def update_layer(self, layer: int, new_stm: torch.Tensor):
+        self.memory = self.memory.clone()
         self.memory[layer] = new_stm
 
     def update_all(self, new_stm: torch.Tensor):
@@ -60,7 +61,7 @@ class ShortTermMemory(nn.Module):
         self.register_buffer('memory', trained_stm)
 
     def reset(self, init_type: str = None):
-        self.memory.copy_(self._init_tensor(init_type))
+        self.memory = self._init_tensor(init_type).to(self.memory.device)
 
     def resize(self, new_stm_size: int, init_type: str = None):
         self.stm_size = new_stm_size
@@ -85,8 +86,7 @@ class ShortTermMemory(nn.Module):
         if use_mean_from_batch:
             batch_mean = self.memory.mean(dim=(1, 2, 3), keepdim=True)
             delattr(self, 'memory')
-            self.register_buffer('memory', self._init_tensor())
-            self.memory.copy_(batch_mean)
+            self.register_buffer('memory', batch_mean)
         else:
             delattr(self, 'memory')
             self.register_buffer('memory', self._init_tensor())
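Note: update_layer() now clones the memory buffer before writing a single layer, and reset() rebinds self.memory to a freshly initialized tensor instead of copy_()-ing into the existing buffer. Presumably this avoids in-place writes to a tensor that an autograd graph from a previous forward pass may still reference. A minimal standalone sketch of that failure mode and the clone-and-rebind workaround (toy tensors, not RxNN code):

import torch

x = torch.randn(3, requires_grad=True)
memory = torch.zeros(3)

out = (memory * x).sum()  # `memory` is saved for backward (d out / d x == memory)

# An in-place write here (memory[0] = 1.0) would bump the saved tensor's version
# counter and make out.backward() raise a "modified by an inplace operation" error.
# Cloning and rebinding leaves the saved tensor untouched:
memory = memory.clone()
memory[0] = 1.0

out.backward()
print(x.grad)  # gradient computed against the original (all-zero) memory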
rxnn/training/dataset.py CHANGED
@@ -1098,6 +1098,7 @@ class MrlDatasets:
             load_kwargs: dict = None,
             mrl_ds_kwargs: dict = None,
             eval_split: str = None,
+            max_seq_len: int = 256,
     ):
         """
         Load dataset from HuggingFace Hub and convert it to RxNN training dataset.
@@ -1115,6 +1116,7 @@ class MrlDatasets:
             load_kwargs (dict): Additional args for HuggingFace API load_dataset function
             mrl_ds_kwargs (dict): Additional args for RxNN MrlCurriculumDataset class
             eval_split (str): Load also evaluation/validation split (default: None)
+            max_seq_len (int): Maximum sequence length (default: 256)
         """
         if load_kwargs is None:
             load_kwargs = {}
@@ -1131,6 +1133,7 @@ class MrlDatasets:
                 interactions_field=interactions_field,
                 split=load_split,
                 load_kwargs=load_kwargs,
+                max_seq_len=max_seq_len,
                 **mrl_ds_kwargs,
             )
 
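Note: the only functional change above is that MrlDatasets now accepts a max_seq_len argument (default 256) and forwards it to MrlCurriculumDataset. As an illustration of what a sequence-length cap typically means at tokenization time, a hypothetical HuggingFace-style snippet (not RxNN code; the tokenizer name is chosen arbitrarily):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
batch = tokenizer(
    ['a short query', 'a much longer follow-up query that will be truncated ...'],
    max_length=256,        # plays the role of max_seq_len
    truncation=True,
    padding='max_length',
    return_tensors='pt',
)
print(batch['input_ids'].shape)  # torch.Size([2, 256])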
rxnn/training/mrl.py CHANGED
@@ -92,8 +92,8 @@ class MRLTrainer:
         self.critic = critic
         self.reward = reward
         self.device = device
-        self.max_seq_len = config.get('max_seq_len', 1024)
-        self.critic_max_len = config.get('critic_max_len', 2048)
+        self.max_seq_len = config.get('max_seq_len', 256)
+        self.critic_max_len = config.get('critic_max_len', 512)
 
         # Move models to device
         if use_amp:
@@ -306,6 +306,7 @@ class MRLTrainer:
             for i, interaction in enumerate(interactions):
                 # 8. Generate batch of answers based on batch of follow-up queries
                 next_query = self._move_batch(interaction['query'])
+                print(next_query['input_ids'].size())
                 generated_answer, log_probs = self.generate_answer(next_query)
 
                 is_last_interaction = (i + 1) == interactions_len
@@ -365,7 +366,7 @@ class MRLTrainer:
     def _critic_loss(self, inputs: TokenizedDict, rewards: torch.Tensor) -> torch.Tensor:
         # 1. Calculate values with critic encoder
         values = self.critic(
-            input_ids=inputs['input_ids'],
+            inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
         ).squeeze()
         # 2. Calculate critic loss
@@ -507,21 +508,21 @@ class MRLTrainer:
                                    action=MrlActorAction.DECODE)
 
         # 7. Calculate RL Algorithm (PPO etc.) loss
-        policy_loss = self.rl_algorithm.policy_loss(action['input_ids'], logits, log_probs, advantages)
+        policy_loss = self.rl_algorithm.policy_loss(next_query, action, inputs['input_ids'], logits, log_probs, advantages)
 
         # 8. Reset gradients
         self.optimizer.zero_grad()
 
         # 9. Update the model in AMP or regular mode
         if self.use_amp:
-            self.scaler.scale(policy_loss).backward()
+            self.scaler.scale(policy_loss).backward(retain_graph=True)
             self.scaler.unscale_(self.optimizer)
             torch.nn.utils.clip_grad_norm_(self.actor.unique_parameters(), max_norm=1.0,
                                            error_if_nonfinite=False)
             self.scaler.step(self.optimizer)
             self.scaler.update()
         else:
-            policy_loss.backward()
+            policy_loss.backward(retain_graph=True)
             torch.nn.utils.clip_grad_norm_(self.actor.unique_parameters(), max_norm=1.0,
                                            error_if_nonfinite=False)
             self.optimizer.step()
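Note: besides lowering the default max_seq_len/critic_max_len, adding a debug print of the query shape, and switching the critic call to a positional first argument, this hunk makes both backward() calls pass retain_graph=True. Presumably successive MRL updates backpropagate through activations shared with earlier computation (for example the short-term memory state), whose buffers would otherwise be freed by the first backward pass. A generic sketch of the error the flag avoids (toy graph, not RxNN code):

import torch

w = torch.randn(3, requires_grad=True)
shared = (w * 2).sum()   # subgraph reused by two losses
loss_a = shared * 1.0
loss_b = shared * 3.0

loss_a.backward(retain_graph=True)  # keep the shared subgraph's buffers alive
loss_b.backward()                   # without retain_graph above, this raises
                                    # "Trying to backward through the graph a second time"
print(w.grad)                       # both passes accumulate: 2 + 6 per element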
@@ -594,63 +595,64 @@ class MRLTrainer:
         # 2. Run evaluation on all batch episodes
         for batch in dataloader:
             with torch.no_grad():
-                self._increment_steps('eval')
-                # 3. Reset STM with random resets ratio
-                self.reset_stm()
-
-                # 4. Get batches for first queries, answers and all follow-up interactions
-                first_query, first_answer, interactions = batch['query'], batch['answer'], batch['interactions']
-                # 5. Encode and update STM with initial interactions (batch)
-                self.encode_and_update_stm(*self._move_multiple_batches(first_query, first_answer))
-
-                # 6. Save follow-up interactions len and first query and answer as previous one for iteration
-                interactions_len = len(interactions)
-                query, answer = first_query, first_answer
-                episode_reward = torch.tensor(0.0).to(self.device)
-                episode_interactions = torch.tensor(0).to(self.device)
-                # 7. Run all follow-up interactions
-                for i, interaction in enumerate(interactions):
-                    # 8. Generate batch of answers
-                    next_query = self._move_batch(interaction['query'])
-                    generated_answer, _ = self.generate_answer(next_query)
-
-                    is_last_interaction = (i + 1) == interactions_len
-
-                    detached_answer = self._cpu_detach(generated_answer)
-
-                    # 9. Depending on current strategy and step, compute reward
-                    if self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and i == 0:
-                        reward = self.compute_reward(detached_answer, interaction['answer'], (query, answer),
-                                                     mode=MrlRewardMode.NEGATIVE, eval_mode=True)
-                    elif self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and is_last_interaction:
-                        reward = self.compute_reward(detached_answer, interaction['answer'],
-                                                     (first_query, first_answer), mode=MrlRewardMode.LONG_RANGE,
-                                                     eval_mode=True)
-                    else:
-                        reward = self.compute_reward(detached_answer, interaction['answer'], (query, answer),
-                                                     mode=MrlRewardMode.STANDARD, eval_mode=True)
-
-                    # 10. Encode and update memory for the next interaction
-                    if not is_last_interaction:
-                        self.encode_and_update_stm(next_query, generated_answer)
-
-                    # 11. Accumulate rewards
-                    step_reward = torch.tensor(reward).mean().to(self.device)
-                    # total
-                    total_reward += step_reward
-                    count += 1
-                    # episode
-                    episode_reward += step_reward
-                    episode_interactions += 1
-                    # 12. Save previous interaction
-                    query, answer = interaction['query'], detached_answer
-                avg_episode_reward = (episode_reward / episode_interactions).item()
-                # 13. Run eval TensorBoard writer with average episode reward
-                self._eval_writer(avg_episode_reward, epoch)
-
-                # 14. Run "on eval episode end" callbacks
-                for cb in self.callbacks:
-                    cb.on_eval_episode_end(self.actor, epoch, self.epoch_step['eval'], avg_episode_reward)
+                if batch['query']['input_ids'].size(0) == batch_size:
+                    self._increment_steps('eval')
+                    # 3. Reset STM with random resets ratio
+                    self.reset_stm()
+
+                    # 4. Get batches for first queries, answers and all follow-up interactions
+                    first_query, first_answer, interactions = batch['query'], batch['answer'], batch['interactions']
+                    # 5. Encode and update STM with initial interactions (batch)
+                    self.encode_and_update_stm(*self._move_multiple_batches(first_query, first_answer))
+
+                    # 6. Save follow-up interactions len and first query and answer as previous one for iteration
+                    interactions_len = len(interactions)
+                    query, answer = first_query, first_answer
+                    episode_reward = torch.tensor(0.0).to(self.device)
+                    episode_interactions = torch.tensor(0).to(self.device)
+                    # 7. Run all follow-up interactions
+                    for i, interaction in enumerate(interactions):
+                        # 8. Generate batch of answers
+                        next_query = self._move_batch(interaction['query'])
+                        generated_answer, _ = self.generate_answer(next_query)
+
+                        is_last_interaction = (i + 1) == interactions_len
+
+                        detached_answer = self._cpu_detach(generated_answer)
+
+                        # 9. Depending on current strategy and step, compute reward
+                        if self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and i == 0:
+                            reward = self.compute_reward(detached_answer, interaction['answer'], (query, answer),
+                                                         mode=MrlRewardMode.NEGATIVE, eval_mode=True)
+                        elif self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and is_last_interaction:
+                            reward = self.compute_reward(detached_answer, interaction['answer'],
+                                                         (first_query, first_answer), mode=MrlRewardMode.LONG_RANGE,
+                                                         eval_mode=True)
+                        else:
+                            reward = self.compute_reward(detached_answer, interaction['answer'], (query, answer),
+                                                         mode=MrlRewardMode.STANDARD, eval_mode=True)
+
+                        # 10. Encode and update memory for the next interaction
+                        if not is_last_interaction:
+                            self.encode_and_update_stm(next_query, generated_answer)
+
+                        # 11. Accumulate rewards
+                        step_reward = torch.tensor(reward).mean().to(self.device)
+                        # total
+                        total_reward += step_reward
+                        count += 1
+                        # episode
+                        episode_reward += step_reward
+                        episode_interactions += 1
+                        # 12. Save previous interaction
+                        query, answer = interaction['query'], detached_answer
+                    avg_episode_reward = (episode_reward / episode_interactions).item()
+                    # 13. Run eval TensorBoard writer with average episode reward
+                    self._eval_writer(avg_episode_reward, epoch)
+
+                    # 14. Run "on eval episode end" callbacks
+                    for cb in self.callbacks:
+                        cb.on_eval_episode_end(self.actor, epoch, self.epoch_step['eval'], avg_episode_reward)
 
         # 15. Calculate average reward
         if self.use_ddp:
@@ -805,3 +807,4 @@ class MRLTrainer:
         # 21. Close writer
         if self.writer:
             self.writer.close()
+
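Note: the evaluation loop body is unchanged except for being nested under a guard that skips any batch whose size differs from batch_size, i.e. the partial trailing batch a DataLoader produces when drop_last=False. A minimal sketch of the pattern (generic PyTorch, not RxNN code):

import torch
from torch.utils.data import DataLoader, TensorDataset

batch_size = 4
dataset = TensorDataset(torch.arange(10).float().unsqueeze(1))
loader = DataLoader(dataset, batch_size=batch_size)  # last batch has only 2 samples

for (x,) in loader:
    if x.size(0) != batch_size:
        continue  # skip the partial trailing batch, as the eval loop now does
    print(x.size())  # always torch.Size([4, 1]) here

Passing drop_last=True to the DataLoader would have the same effect without the explicit guard.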
rxnn/training/rl.py CHANGED
@@ -3,6 +3,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from abc import ABC, abstractmethod
 from typing import TypedDict
+from .utils import TokenizedDict
 
 
 class RlAlgorithm(ABC):
@@ -11,7 +12,8 @@ class RlAlgorithm(ABC):
         self.critic_loss = nn.MSELoss()
 
     @abstractmethod
-    def policy_loss(self, input_ids: torch.Tensor, logits: torch.Tensor, old_log_probs: torch.Tensor, advantages: torch.Tensor) -> torch.Tensor:
+    def policy_loss(self, query: TokenizedDict, answer: TokenizedDict, logits: torch.Tensor,
+                    old_log_probs: torch.Tensor, advantages: torch.Tensor) -> torch.Tensor:
         pass
 
     @abstractmethod
@@ -35,14 +37,44 @@ class PPOAlgorithm(RlAlgorithm):
         self.gae_lambda = config.get('gae_lambda', 0.95)
         self.clip_eps = config.get('clip_eps', 0.2)
 
-    def policy_loss(self, input_ids: torch.Tensor, logits: torch.Tensor, old_log_probs: torch.Tensor, advantages: torch.Tensor) -> torch.Tensor:
+    def policy_loss(self, query: TokenizedDict, answer: TokenizedDict, logits: torch.Tensor,
+                    old_log_probs: torch.Tensor, advantages: torch.Tensor) -> torch.Tensor:
+        query_lens = query['attention_mask'].sum(dim=1).long()  # Query lengths per sample
+        answer_mask = answer['attention_mask']
+        answer_lens = answer_mask.sum(dim=1).long()  # Answer lengths per sample (before padding)
+
+        max_length = query['input_ids'].size(1)
+
+        combined_lens = torch.minimum(
+            query_lens + answer_lens,
+            torch.full_like(query_lens, max_length)
+        )
+
+        def extract_answer_tokens(tensor: torch.Tensor) -> torch.Tensor:
+            B, L, *rest = tensor.size()
+            result = torch.zeros((B, max_length, *rest), dtype=tensor.dtype, device=tensor.device)
+
+            for i in range(B):
+                s = query_lens[i].item()
+                e = combined_lens[i].item()
+                valid_len = e - s
+                if valid_len > 0:
+                    result[i, :valid_len] = tensor[i, s:e]
+            return result
+
+        new_logits = extract_answer_tokens(logits)
+
         # a) Get new log probs
-        new_probs = F.log_softmax(logits, dim=-1)
-        new_log_probs = new_probs.gather(-1, input_ids.unsqueeze(-1)).squeeze(-1)
+        new_probs = F.log_softmax(new_logits, dim=-1)
+        new_log_probs = new_probs.gather(-1, answer['input_ids'].unsqueeze(-1)).squeeze(-1)
+
+        new_log_probs = extract_answer_tokens(new_log_probs.unsqueeze(-1)).squeeze(-1)  # Ensure 3D for extraction (add singleton dim)
 
         # b) Calculate ratio
         ratio = (new_log_probs - old_log_probs).exp()
 
+        advantages = advantages.unsqueeze(-1)
+
         # c) Clipped surrogate loss
         surr1 = ratio * advantages
         surr2 = torch.clamp(ratio, 1.0 - self.clip_eps, 1.0 + self.clip_eps) * advantages
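Note: policy_loss now takes the tokenized query and answer dicts instead of raw input_ids, slices out the answer positions of each sample, and gathers the new log-probabilities against answer['input_ids']. The per-sample slice-and-left-align step is reproduced below on toy tensors purely for illustration (extract_answer_tokens itself is a closure inside PPOAlgorithm.policy_loss and is not importable):

import torch

# Toy shapes: batch of 2, sequence length 6, vocabulary of 3.
logits = torch.arange(2 * 6 * 3, dtype=torch.float32).reshape(2, 6, 3)
query_lens = torch.tensor([2, 3])    # non-padding query tokens per sample
answer_lens = torch.tensor([3, 2])   # non-padding answer tokens per sample
max_length = logits.size(1)
combined_lens = torch.minimum(query_lens + answer_lens,
                              torch.full_like(query_lens, max_length))

# Keep only the answer-position logits and left-align them at index 0,
# zero-padding the remainder (same logic as the helper above).
result = torch.zeros_like(logits)
for i in range(logits.size(0)):
    s, e = query_lens[i].item(), combined_lens[i].item()
    result[i, :e - s] = logits[i, s:e]

print(result[0, :3])  # the logits that sat at positions 2..4 of sample 0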
rxnn-0.2.12.dist-info/METADATA → rxnn-0.2.14.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.2.12
+Version: 0.2.14
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning
rxnn-0.2.12.dist-info/RECORD → rxnn-0.2.14.dist-info/RECORD CHANGED
@@ -7,18 +7,18 @@ rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/memory/attention.py,sha256=p-r8DK3iVhNn-JAESVzIXDCG8gk1R_-x5xHclZ5jgb0,1813
 rxnn/memory/norm.py,sha256=Ofl8Q5NYEF9GQeO0bhM43tkTW91J0y6TSvTAOYMgloM,6278
-rxnn/memory/stm.py,sha256=S5CtPI2KXxjs_vvMtb-w57ZPN3TmvVvU3TBHG2au2VE,3879
+rxnn/memory/stm.py,sha256=IH_3INw7FdI013t56ui3Zq9GPUq-k3HeZGjx6BerS4g,3888
 rxnn/rxt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/rxt/models.py,sha256=qlYgU002VE21ZOlcxEM9iv9tAvsbe4mngcMI2sw3j9k,12078
 rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/training/base.py,sha256=_xik1GXE4RJ_nxwqLQ1ccXA5pRtBCi-jL-jeRFBdHBU,11851
 rxnn/training/bml.py,sha256=FJszaQXOLx2ZHBa1CQpyMrG8i4Kj14E-gzDAEK_Ei5k,17272
 rxnn/training/callbacks.py,sha256=-N0MQPpZQaUWCINdTOsjul4bDGbGr2JgQBqOIXBLS6o,35053
-rxnn/training/dataset.py,sha256=i8EdSJnoPbkuDSdqoYDj-Aig5Se_uPY4lulkD2bdOrs,50331
+rxnn/training/dataset.py,sha256=m1opjNA7XHl6Ys-NtERM00c0BLN2xuu84lsfXp-3GQA,50478
 rxnn/training/models.py,sha256=renPa5YH443XNTMFI-YTCwi5vNp3QzwF5UXedNd5hDk,5187
-rxnn/training/mrl.py,sha256=mCsg50bX0iqPozvvQB6CeZ0FYEfuj9ln1p-4IaZBryo,39338
+rxnn/training/mrl.py,sha256=CezloyaXOKrc_F_eDt99EZ1fmKAMCCCMh5Ry6vF82Ro,39607
 rxnn/training/reward.py,sha256=C0ToTz-u-L-qyBd2yJ1HlvVPS110OChYj9ZhD6iSSMU,5654
-rxnn/training/rl.py,sha256=T69gLwDlvMMyLuRaJSRmwzO0Mcu0uLXwhAiBB58VK-Y,2663
+rxnn/training/rl.py,sha256=DHFwnPUlnq2JVj6CS6DwifnC_eMeBAUVp36UCAWNMis,3934
 rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
 rxnn/training/tokenizer.py,sha256=umaLByMBx_NMrQElA45HLm9gkuzyKWDTFaKVd-CjXl0,8344
 rxnn/training/utils.py,sha256=Bw8nZLKIt7NQpUVCYkb_79kWKChVFOYgYXwODo4SvNc,5718
@@ -32,7 +32,7 @@ rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
 rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
 rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
-rxnn-0.2.12.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
-rxnn-0.2.12.dist-info/METADATA,sha256=HvEJSZUelxjiAKWAQ3wwbNtNmMsJjxlstZZModU9UMw,25960
-rxnn-0.2.12.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
-rxnn-0.2.12.dist-info/RECORD,,
+rxnn-0.2.14.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.2.14.dist-info/METADATA,sha256=dutamudjxMj9IzykuCONpMyqnU4emEEwvseD4nmKkfs,25960
+rxnn-0.2.14.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+rxnn-0.2.14.dist-info/RECORD,,