rxnn-0.2.4-py3-none-any.whl → rxnn-0.2.6-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rxnn/memory/stm.py CHANGED
@@ -43,27 +43,29 @@ class ShortTermMemory(nn.Module):
 
     def update_all(self, new_stm: torch.Tensor):
         self.memory.copy_(new_stm)
+        print(self.memory.size())
 
     def make_trainable(self):
         if not self.is_trainable:
             self.is_trainable = True
             initial_stm = self.memory.clone()
-            del self.memory
+            delattr(self, 'memory')
             self.memory = nn.Parameter(initial_stm)
 
     def freeze(self):
         if self.is_trainable:
             self.requires_grad_(False)
             trained_stm = self.memory.clone()
-            del self.memory
+            delattr(self, 'memory')
             self.register_buffer('memory', trained_stm)
 
     def reset(self, init_type: str = None):
-        self.memory = self._init_tensor(init_type)
+        self.memory.copy_(self._init_tensor(init_type))
 
     def resize(self, new_stm_size: int, init_type: str = None):
         self.stm_size = new_stm_size
-        self.memory = self._init_tensor(init_type)
+        delattr(self, 'memory')
+        self.register_buffer('memory', self._init_tensor(init_type))
 
     def batched_memory(self, batch_size: int, init_type: str = None):
         if init_type is not None:
@@ -71,7 +73,8 @@ class ShortTermMemory(nn.Module):
                 'STM init type must be one of "normal", "standard", "uniform", "ones", "zeros"'
             self.init_type = init_type
         self.batch_size = batch_size
-        self.memory = self._init_tensor()
+        delattr(self, 'memory')
+        self.register_buffer('memory', self._init_tensor())
 
     def single_memory(self, init_type: str = None, use_mean_from_batch: bool = False):
         if init_type is not None:
@@ -81,7 +84,9 @@ class ShortTermMemory(nn.Module):
         self.batch_size = 1
         if use_mean_from_batch:
             batch_mean = self.memory.mean(dim=(1, 2, 3), keepdim=True)
-            self.memory = self._init_tensor()
+            delattr(self, 'memory')
+            self.register_buffer('memory', self._init_tensor())
             self.memory.copy_(batch_mean)
         else:
-            self.memory = self._init_tensor()
+            delattr(self, 'memory')
+            self.register_buffer('memory', self._init_tensor())
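
Why the switch from plain assignment to delattr plus register_buffer matters: once `memory` is registered on the `nn.Module`, assigning a plain tensor to the attribute raises a `TypeError` whenever it is currently an `nn.Parameter`, so the attribute has to be removed and re-registered (or updated in place with `copy_`). The sketch below illustrates that buffer/parameter round-trip with a hypothetical `TinyMemory` module; it is not the rxnn `ShortTermMemory` implementation, just a minimal model of the pattern used above.

import torch
import torch.nn as nn

class TinyMemory(nn.Module):
    """Minimal stand-in for a module whose 'memory' switches between buffer and Parameter."""
    def __init__(self, size: int = 4):
        super().__init__()
        self.register_buffer('memory', torch.zeros(size))

    def make_trainable(self):
        initial = self.memory.clone()
        delattr(self, 'memory')              # drop the buffer entry
        self.memory = nn.Parameter(initial)  # re-register the same data as a trainable parameter

    def resize(self, new_size: int):
        delattr(self, 'memory')              # works whether 'memory' is a buffer or a Parameter
        self.register_buffer('memory', torch.zeros(new_size))

m = TinyMemory()
m.make_trainable()
print(isinstance(m.memory, nn.Parameter))                   # True
m.resize(8)
print(m.memory.shape, 'memory' in dict(m.named_buffers()))  # torch.Size([8]) True

Note that `delattr(self, 'memory')` is functionally identical to `del self.memory`; the substantive part of the change is re-registering the new tensor with `register_buffer` (or copying into the existing one with `copy_`) instead of assigning a plain tensor to the attribute.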
rxnn/training/mrl.py CHANGED
@@ -275,7 +275,7 @@ class MRLTrainer:
             self.writer.add_scalar(f'Collect/episode reward (steps: {self.curriculum_steps})', avg_reward,
                                    self.stage_step['collect'])
 
-    def collect_trajectories(self, dataloader: DataLoader, epoch: int) -> list[MrlTrajectoryEpisode]:
+    def collect_trajectories(self, dataloader: DataLoader, epoch: int, batch_size: int) -> list[MrlTrajectoryEpisode]:
         """Collect trajectories for PPO for current curriculum step."""
         # 1. Init trajectories list
         trajectories = []
@@ -283,79 +283,82 @@ class MRLTrainer:
         with torch.no_grad():
             # 2. Collect episode trajectories for all batches in dataset
             for batch_idx, batch in enumerate(dataloader):
-                self._increment_steps('collect')
-                # 3. Reset Short-Term Memory state (with random reset ratio - sometimes it will be good to build memory
-                # state from existing one, instead of new random one)
-                reset_done = self.reset_stm()
-
-                # 4. Get first batch of interactions (data to save) and follow-up interactions for current episode, based on curriculum step
-                first_query, first_answer, interactions = batch['query'], batch['answer'], batch['interactions']
-                interactions = interactions[:self.curriculum_steps]
-                interactions_len = len(interactions)
-                # 5. Encode and update STM with data to save from first interaction
-                self.encode_and_update_stm(*self._move_multiple_batches(first_query, first_answer))
-
-                # 6. Save first interaction as data to save (for trajectory state)
-                query, answer = first_query, first_answer
-
-                # 7. Run training strategy for follow-up interactions
-                episode_steps = []
-                episode_rewards = []
-                for i, interaction in enumerate(interactions):
-                    # 8. Generate batch of answers based on batch of follow-up queries
-                    next_query = self._move_batch(interaction['query'])
-                    generated_answer, log_probs = self.generate_answer(next_query)
-
-                    is_last_interaction = (i + 1) == interactions_len
-
-                    detached_answer = self._cpu_detach(generated_answer)  # detach and keep states on CPU
-
-                    # 9. Depending on strategy compute reward
-                    if self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and i == 0:
-                        # a) long-range - first interaction - change topic - negative reward (it shouldn't include saved data)
-                        reward = self.compute_reward(detached_answer, interaction['answer'], (query, answer),
-                                                     mode=MrlRewardMode.NEGATIVE)
-                    elif self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and is_last_interaction:
-                        # b) long-range - last interaction - first interaction topic - long-range reward (it should include content from first interaction)
-                        reward = self.compute_reward(detached_answer, interaction['answer'],
-                                                     (first_query, first_answer), mode=MrlRewardMode.LONG_RANGE)
-                    else:
-                        # c) standard reward - generated answer should include some content from previous interaction (saved data), like reference answer
-                        reward = self.compute_reward(detached_answer, interaction['answer'], (query, answer),
-                                                     mode=MrlRewardMode.STANDARD)
-
-                    # 10. Update STM with generated response (except last interaction, it's not needed)
-                    if not is_last_interaction:
-                        self.encode_and_update_stm(next_query, generated_answer)  # update with generated_answer on GPU
-
-                    # 11. Store trajectory step
-                    trajectory: MrlTrajectoryStep = {
-                        'state': (query, answer, interaction['query']),
-                        'action': detached_answer,
-                        'log_probs': log_probs.detach().cpu(),
-                        'reward': reward,
-                        'reference': interaction['answer'],
+                if batch['query']['input_ids'].size(0) != batch_size:
+                    print('Incorrect batch size: ', batch['query']['input_ids'].size(0))
+                else:
+                    self._increment_steps('collect')
+                    # 3. Reset Short-Term Memory state (with random reset ratio - sometimes it will be good to build memory
+                    # state from existing one, instead of new random one)
+                    reset_done = self.reset_stm()
+
+                    # 4. Get first batch of interactions (data to save) and follow-up interactions for current episode, based on curriculum step
+                    first_query, first_answer, interactions = batch['query'], batch['answer'], batch['interactions']
+                    interactions = interactions[:self.curriculum_steps]
+                    interactions_len = len(interactions)
+                    # 5. Encode and update STM with data to save from first interaction
+                    self.encode_and_update_stm(*self._move_multiple_batches(first_query, first_answer))
+
+                    # 6. Save first interaction as data to save (for trajectory state)
+                    query, answer = first_query, first_answer
+
+                    # 7. Run training strategy for follow-up interactions
+                    episode_steps = []
+                    episode_rewards = []
+                    for i, interaction in enumerate(interactions):
+                        # 8. Generate batch of answers based on batch of follow-up queries
+                        next_query = self._move_batch(interaction['query'])
+                        generated_answer, log_probs = self.generate_answer(next_query)
+
+                        is_last_interaction = (i + 1) == interactions_len
+
+                        detached_answer = self._cpu_detach(generated_answer)  # detach and keep states on CPU
+
+                        # 9. Depending on strategy compute reward
+                        if self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and i == 0:
+                            # a) long-range - first interaction - change topic - negative reward (it shouldn't include saved data)
+                            reward = self.compute_reward(detached_answer, interaction['answer'], (query, answer),
+                                                         mode=MrlRewardMode.NEGATIVE)
+                        elif self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and is_last_interaction:
+                            # b) long-range - last interaction - first interaction topic - long-range reward (it should include content from first interaction)
+                            reward = self.compute_reward(detached_answer, interaction['answer'],
+                                                         (first_query, first_answer), mode=MrlRewardMode.LONG_RANGE)
+                        else:
+                            # c) standard reward - generated answer should include some content from previous interaction (saved data), like reference answer
+                            reward = self.compute_reward(detached_answer, interaction['answer'], (query, answer),
+                                                         mode=MrlRewardMode.STANDARD)
+
+                        # 10. Update STM with generated response (except last interaction, it's not needed)
+                        if not is_last_interaction:
+                            self.encode_and_update_stm(next_query, generated_answer)  # update with generated_answer on GPU
+
+                        # 11. Store trajectory step
+                        trajectory: MrlTrajectoryStep = {
+                            'state': (query, answer, interaction['query']),
+                            'action': detached_answer,
+                            'log_probs': log_probs.detach().cpu(),
+                            'reward': reward,
+                            'reference': interaction['answer'],
+                        }
+                        episode_steps.append(trajectory)
+                        episode_rewards.append(reward)
+
+                        # 12. Set current interaction query and generated answer (batches), as saved data for next interaction
+                        query, answer = interaction['query'], detached_answer
+
+                    # 13. Append full batched episode (number of steps depends on curriculum stage) to trajectories
+                    episode_trajectory: MrlTrajectoryEpisode = {
+                        'reset_stm': reset_done,
+                        'steps': episode_steps,
                     }
-                    episode_steps.append(trajectory)
-                    episode_rewards.append(reward)
-
-                    # 12. Set current interaction query and generated answer (batches), as saved data for next interaction
-                    query, answer = interaction['query'], detached_answer
+                    trajectories.append(episode_trajectory)
 
-                # 13. Append full batched episode (number of steps depends on curriculum stage) to trajectories
-                episode_trajectory: MrlTrajectoryEpisode = {
-                    'reset_stm': reset_done,
-                    'steps': episode_steps,
-                }
-                trajectories.append(episode_trajectory)
+                    mean_episode_reward = torch.tensor(episode_rewards).mean().item()
 
-                mean_episode_reward = torch.tensor(episode_rewards).mean().item()
+                    self._collect_writer(mean_episode_reward, epoch)
 
-                self._collect_writer(mean_episode_reward, epoch)
-
-                # 14. Run "on episode collected" callbacks
-                for cb in self.callbacks:
-                    cb.on_episode_collected(self.actor, batch_idx, episode_trajectory, mean_episode_reward)
+                    # 14. Run "on episode collected" callbacks
+                    for cb in self.callbacks:
+                        cb.on_episode_collected(self.actor, batch_idx, episode_trajectory, mean_episode_reward)
 
         return trajectories
 
@@ -539,10 +542,10 @@ class MRLTrainer:
         rewards = [torch.tensor(t['reward']) for t in flat_trajectories]
         return states, rewards
 
-    def train_epoch(self, dataloader: DataLoader, epoch: int):
+    def train_epoch(self, dataloader: DataLoader, epoch: int, batch_size: int):
         """Train for one epoch."""
         # 1. Collect trajectories for current epoch
-        trajectories = self.collect_trajectories(dataloader, epoch)
+        trajectories = self.collect_trajectories(dataloader, epoch, batch_size)
 
         # 2. Flatten trajectories and collect state and rewards for critic update
         states, rewards = self._critic_states_and_rewards(trajectories)
@@ -696,6 +699,7 @@ class MRLTrainer:
 
         # 0. Set global epoch count for all stages
        self.global_epochs_count = sum(stage['epochs'] for stage in curriculum_config)
+        self.global_epoch = 0
 
         # 1. Init DDP for distributed training mode
         if self.use_ddp:
@@ -765,7 +769,7 @@ class MRLTrainer:
                train_sampler.set_epoch(epoch)
 
            # 13. Run reinforcement learning algorithms for current epoch
-            policy_loss, critic_loss = self.train_epoch(dataloader, epoch)
+            policy_loss, critic_loss = self.train_epoch(dataloader, epoch, batch_size)
 
            # 14. If evaluation dataset is provided, run evaluation steps
            if self.eval_dataset:
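
For context on the new `batch_size` argument threaded through `train_epoch` and `collect_trajectories`: episodes are collected per batch, and the added guard skips any trailing batch whose first dimension is smaller than the configured batch size, presumably because the batched STM state is sized for a fixed batch dimension. Below is a minimal sketch of the same guard on plain tensors (the dataset and names are illustrative, not the rxnn batch format); `drop_last=True` on the `DataLoader` is the usual way to get the same effect at the loader level.

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(10).unsqueeze(1))  # 10 samples -> last batch has only 2
batch_size = 4

for (batch,) in DataLoader(dataset, batch_size=batch_size):
    if batch.size(0) != batch_size:   # same check as the new guard in collect_trajectories
        print('Incorrect batch size:', batch.size(0))
        continue
    ...  # collect trajectories for a full batch

# Equivalent behaviour without an explicit check: only full batches are yielded.
for (batch,) in DataLoader(dataset, batch_size=batch_size, drop_last=True):
    ...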
rxnn/training/utils.py CHANGED
@@ -127,7 +127,7 @@ def smart_concat(query: TokenizedDict, answer: TokenizedDict, max_length: int, p
     # Build combined_ids using vectorized where
     combined_ids = torch.where(
         query_mask,
-        query['input_ids'].gather(1, torch.minimum(positions, query_lens.unsqueeze(1) - 1)),
+        query['input_ids'].gather(1, torch.minimum(positions, query_lens.unsqueeze(1) - 1)).to(torch.int64),
         torch.where(
             answer_mask,
             answer['input_ids'].gather(1, answer_pos),
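
The change in `smart_concat` is an explicit cast of the gathered query ids to `int64` (`torch.long`), the dtype PyTorch conventionally uses for token ids that are later embedded or combined with other id tensors. Below is a small sketch of the clamp-and-gather pattern with that cast; the shapes and values are illustrative and not taken from rxnn.

import torch

query_ids = torch.tensor([[5, 6, 7, 0], [8, 9, 0, 0]], dtype=torch.int32)  # padded queries
query_lens = torch.tensor([3, 2])                                          # real lengths per row
positions = torch.arange(4).expand(2, 4)                                   # 0..3 for each row

# Clamp positions to the last real token so padded slots never index out of range,
# then gather the ids and make the dtype explicit, as in the new line above.
clamped = torch.minimum(positions, query_lens.unsqueeze(1) - 1)
gathered = query_ids.gather(1, clamped).to(torch.int64)

print(gathered)        # tensor([[5, 6, 7, 7], [8, 9, 9, 9]])
print(gathered.dtype)  # torch.int64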
rxnn-0.2.4.dist-info/METADATA → rxnn-0.2.6.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.2.4
+Version: 0.2.6
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning
rxnn-0.2.4.dist-info/RECORD → rxnn-0.2.6.dist-info/RECORD RENAMED
@@ -7,7 +7,7 @@ rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/memory/attention.py,sha256=p-r8DK3iVhNn-JAESVzIXDCG8gk1R_-x5xHclZ5jgb0,1813
 rxnn/memory/norm.py,sha256=Ofl8Q5NYEF9GQeO0bhM43tkTW91J0y6TSvTAOYMgloM,6278
-rxnn/memory/stm.py,sha256=eSMK5KdupWNf56FcDYprHnjA51EeYBzSKza7tiZxKSc,3618
+rxnn/memory/stm.py,sha256=eIxbmOh7SSI3YDim6ki2JgiCel0sDlLnD_qSqcrf7Wc,3881
 rxnn/rxt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/rxt/models.py,sha256=qlYgU002VE21ZOlcxEM9iv9tAvsbe4mngcMI2sw3j9k,12078
 rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -16,12 +16,12 @@ rxnn/training/bml.py,sha256=FJszaQXOLx2ZHBa1CQpyMrG8i4Kj14E-gzDAEK_Ei5k,17272
 rxnn/training/callbacks.py,sha256=o68IPFJyWM1CGooPRDNU9DfNcy4H_o0PcKDTn_ZLnKA,35053
 rxnn/training/dataset.py,sha256=XeRzo0KUYyQ43XjZ3o6Jban9ePIRtpHsqUmeKAQPRQk,50305
 rxnn/training/models.py,sha256=renPa5YH443XNTMFI-YTCwi5vNp3QzwF5UXedNd5hDk,5187
-rxnn/training/mrl.py,sha256=WDQ8xsrHfpRmTczDZhBuOlqHX8JBaEp5SchlTdAxttY,38883
+rxnn/training/mrl.py,sha256=G1sxKrnfkggHRqSInjk17PfpNhj5neqfh5Y2RUabJLk,39392
 rxnn/training/reward.py,sha256=i0nhrPCDgy1di89HWylRBS6cQ7rSSxJUiS3TX8fiiHE,5614
 rxnn/training/rl.py,sha256=FKrBOBAfNub_qzkceFQR-WUtCBffC6oGHE8wlPsz2YA,2682
 rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
 rxnn/training/tokenizer.py,sha256=umaLByMBx_NMrQElA45HLm9gkuzyKWDTFaKVd-CjXl0,8344
-rxnn/training/utils.py,sha256=7ED5RIC8AybCmmQrbsU6Krd7brRILxVIeTlJYtJWl_4,5702
+rxnn/training/utils.py,sha256=U4fVNNomwAs_XqzqAdqtlPY1UeAl6E6oKOOSB0-Ckx0,5718
 rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/transformers/attention.py,sha256=d0Igo1Nrn76BphbHrzekiKJfT3RCy4iSXSB6FLAOTgc,16247
 rxnn/transformers/ff.py,sha256=jJnuBDsnnX5uYC_WZH8cXAYrMnz0P-iX7MwcPivjRtI,2533
@@ -32,7 +32,7 @@ rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
 rxnn/transformers/sampler.py,sha256=2dpUQv88ekZa_CMSPLrXvB6X684wxUE2bDVznsi5ACs,17429
 rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
-rxnn-0.2.4.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
-rxnn-0.2.4.dist-info/METADATA,sha256=8qcHy1ysyg_6GiNe5Jd0sxsix9rPBDR_RhYgvCodK28,25959
-rxnn-0.2.4.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
-rxnn-0.2.4.dist-info/RECORD,,
+rxnn-0.2.6.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.2.6.dist-info/METADATA,sha256=eoxdTixI_-XuziJ64YG3QGa8TD0f06q95stHWgNzSjE,25959
+rxnn-0.2.6.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+rxnn-0.2.6.dist-info/RECORD,,