rxnn 0.2.5__py3-none-any.whl → 0.2.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rxnn/memory/stm.py CHANGED
@@ -43,6 +43,7 @@ class ShortTermMemory(nn.Module):
 
     def update_all(self, new_stm: torch.Tensor):
         self.memory.copy_(new_stm)
+        print(self.memory.size())
 
     def make_trainable(self):
         if not self.is_trainable:
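
Note on the stm.py change: update_all copies the new state into the STM buffer in place, and 0.2.7 additionally prints the buffer size, which looks like a temporary debug log. A minimal standalone sketch of the behaviour (the buffer shape below is a placeholder, not taken from the package):

import torch

# Hypothetical STM buffer; the real shape is defined elsewhere in ShortTermMemory.
memory = torch.zeros(4, 8, 512)
new_stm = torch.randn_like(memory)

memory.copy_(new_stm)   # mirrors self.memory.copy_(new_stm) in update_all
print(memory.size())    # mirrors the print added in 0.2.7, e.g. torch.Size([4, 8, 512])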
rxnn/training/mrl.py CHANGED
@@ -275,7 +275,7 @@ class MRLTrainer:
             self.writer.add_scalar(f'Collect/episode reward (steps: {self.curriculum_steps})', avg_reward,
                                    self.stage_step['collect'])
 
-    def collect_trajectories(self, dataloader: DataLoader, epoch: int) -> list[MrlTrajectoryEpisode]:
+    def collect_trajectories(self, dataloader: DataLoader, epoch: int, batch_size: int) -> list[MrlTrajectoryEpisode]:
         """Collect trajectories for PPO for current curriculum step."""
         # 1. Init trajectories list
         trajectories = []
@@ -283,79 +283,89 @@ class MRLTrainer:
         with torch.no_grad():
             # 2. Collect episode trajectories for all batches in dataset
             for batch_idx, batch in enumerate(dataloader):
-                self._increment_steps('collect')
-                # 3. Reset Short-Term Memory state (with random reset ratio - sometimes it will be good to build memory
-                # state from existing one, instead of new random one)
-                reset_done = self.reset_stm()
+                print('first query size: ', batch['query']['input_ids'].size())
+                print('first answer size: ', batch['answer']['input_ids'].size())
+                print('next query size: ', batch['interactions'][0]['query']['input_ids'].size())
+                print('Interactions len: ', len(batch['interactions']))
 
-                # 4. Get first batch of interactions (data to save) and follow-up interactions for current episode, based on curriculum step
-                first_query, first_answer, interactions = batch['query'], batch['answer'], batch['interactions']
-                interactions = interactions[:self.curriculum_steps]
-                interactions_len = len(interactions)
-                # 5. Encode and update STM with data to save from first interaction
-                self.encode_and_update_stm(*self._move_multiple_batches(first_query, first_answer))
-
-                # 6. Save first interaction as data to save (for trajectory state)
-                query, answer = first_query, first_answer
-
-                # 7. Run training strategy for follow-up interactions
-                episode_steps = []
-                episode_rewards = []
-                for i, interaction in enumerate(interactions):
-                    # 8. Generate batch of answers based on batch of follow-up queries
-                    next_query = self._move_batch(interaction['query'])
-                    generated_answer, log_probs = self.generate_answer(next_query)
-
-                    is_last_interaction = (i + 1) == interactions_len
-
-                    detached_answer = self._cpu_detach(generated_answer)  # detach and keep states on CPU
-
-                    # 9. Depending on strategy compute reward
-                    if self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and i == 0:
-                        # a) long-range - first interaction - change topic - negative reward (it shouldn't include saved data)
-                        reward = self.compute_reward(detached_answer, interaction['answer'], (query, answer),
-                                                     mode=MrlRewardMode.NEGATIVE)
-                    elif self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and is_last_interaction:
-                        # b) long-range - last interaction - first interaction topic - long-range reward (it should include content from first interaction)
-                        reward = self.compute_reward(detached_answer, interaction['answer'],
-                                                     (first_query, first_answer), mode=MrlRewardMode.LONG_RANGE)
-                    else:
-                        # c) standard reward - generated answer should include some content from previous interaction (saved data), like reference answer
-                        reward = self.compute_reward(detached_answer, interaction['answer'], (query, answer),
-                                                     mode=MrlRewardMode.STANDARD)
-
-                    # 10. Update STM with generated response (except last interaction, it's not needed)
-                    if not is_last_interaction:
-                        self.encode_and_update_stm(next_query, generated_answer)  # update with generated_answer on GPU
-
-                    # 11. Store trajectory step
-                    trajectory: MrlTrajectoryStep = {
-                        'state': (query, answer, interaction['query']),
-                        'action': detached_answer,
-                        'log_probs': log_probs.detach().cpu(),
-                        'reward': reward,
-                        'reference': interaction['answer'],
+                if batch['query']['input_ids'].size(0) != batch_size or batch['interactions'][0]['query']['input_ids'].size(0) != batch_size:
+                    print('Incorrect batch size: ', batch['query']['input_ids'].size(0))
+                else:
+                    self._increment_steps('collect')
+                    # 3. Reset Short-Term Memory state (with random reset ratio - sometimes it will be good to build memory
+                    # state from existing one, instead of new random one)
+                    reset_done = self.reset_stm()
+
+                    # 4. Get first batch of interactions (data to save) and follow-up interactions for current episode, based on curriculum step
+                    first_query, first_answer, interactions = batch['query'], batch['answer'], batch['interactions']
+                    interactions = interactions[:self.curriculum_steps]
+                    interactions_len = len(interactions)
+                    # 5. Encode and update STM with data to save from first interaction
+                    self.encode_and_update_stm(*self._move_multiple_batches(first_query, first_answer))
+
+                    # 6. Save first interaction as data to save (for trajectory state)
+                    query, answer = first_query, first_answer
+
+                    # 7. Run training strategy for follow-up interactions
+                    episode_steps = []
+                    episode_rewards = []
+
+                    for i, interaction in enumerate(interactions):
+                        # 8. Generate batch of answers based on batch of follow-up queries
+                        next_query = self._move_batch(interaction['query'])
+                        print('first query: ', next_query['input_ids'].size())
+                        generated_answer, log_probs = self.generate_answer(next_query)
+
+                        is_last_interaction = (i + 1) == interactions_len
+
+                        detached_answer = self._cpu_detach(generated_answer)  # detach and keep states on CPU
+
+                        # 9. Depending on strategy compute reward
+                        if self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and i == 0:
+                            # a) long-range - first interaction - change topic - negative reward (it shouldn't include saved data)
+                            reward = self.compute_reward(detached_answer, interaction['answer'], (query, answer),
+                                                         mode=MrlRewardMode.NEGATIVE)
+                        elif self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and is_last_interaction:
+                            # b) long-range - last interaction - first interaction topic - long-range reward (it should include content from first interaction)
+                            reward = self.compute_reward(detached_answer, interaction['answer'],
+                                                         (first_query, first_answer), mode=MrlRewardMode.LONG_RANGE)
+                        else:
+                            # c) standard reward - generated answer should include some content from previous interaction (saved data), like reference answer
+                            reward = self.compute_reward(detached_answer, interaction['answer'], (query, answer),
+                                                         mode=MrlRewardMode.STANDARD)
+
+                        # 10. Update STM with generated response (except last interaction, it's not needed)
+                        if not is_last_interaction:
+                            self.encode_and_update_stm(next_query, generated_answer)  # update with generated_answer on GPU
+
+                        # 11. Store trajectory step
+                        trajectory: MrlTrajectoryStep = {
+                            'state': (query, answer, interaction['query']),
+                            'action': detached_answer,
+                            'log_probs': log_probs.detach().cpu(),
+                            'reward': reward,
+                            'reference': interaction['answer'],
+                        }
+                        episode_steps.append(trajectory)
+                        episode_rewards.append(reward)
+
+                        # 12. Set current interaction query and generated answer (batches), as saved data for next interaction
+                        query, answer = interaction['query'], detached_answer
+
+                    # 13. Append full batched episode (number of steps depends on curriculum stage) to trajectories
+                    episode_trajectory: MrlTrajectoryEpisode = {
+                        'reset_stm': reset_done,
+                        'steps': episode_steps,
                     }
-                    episode_steps.append(trajectory)
-                    episode_rewards.append(reward)
-
-                    # 12. Set current interaction query and generated answer (batches), as saved data for next interaction
-                    query, answer = interaction['query'], detached_answer
+                    trajectories.append(episode_trajectory)
 
-                # 13. Append full batched episode (number of steps depends on curriculum stage) to trajectories
-                episode_trajectory: MrlTrajectoryEpisode = {
-                    'reset_stm': reset_done,
-                    'steps': episode_steps,
-                }
-                trajectories.append(episode_trajectory)
+                    mean_episode_reward = torch.tensor(episode_rewards).mean().item()
 
-                mean_episode_reward = torch.tensor(episode_rewards).mean().item()
+                    self._collect_writer(mean_episode_reward, epoch)
 
-                self._collect_writer(mean_episode_reward, epoch)
-
-                # 14. Run "on episode collected" callbacks
-                for cb in self.callbacks:
-                    cb.on_episode_collected(self.actor, batch_idx, episode_trajectory, mean_episode_reward)
+                    # 14. Run "on episode collected" callbacks
+                    for cb in self.callbacks:
+                        cb.on_episode_collected(self.actor, batch_idx, episode_trajectory, mean_episode_reward)
 
             return trajectories
 
@@ -539,10 +549,10 @@ class MRLTrainer:
         rewards = [torch.tensor(t['reward']) for t in flat_trajectories]
         return states, rewards
 
-    def train_epoch(self, dataloader: DataLoader, epoch: int):
+    def train_epoch(self, dataloader: DataLoader, epoch: int, batch_size: int):
         """Train for one epoch."""
         # 1. Collect trajectories for current epoch
-        trajectories = self.collect_trajectories(dataloader, epoch)
+        trajectories = self.collect_trajectories(dataloader, epoch, batch_size)
 
         # 2. Flatten trajectories and collect state and rewards for critic update
         states, rewards = self._critic_states_and_rewards(trajectories)
@@ -696,6 +706,7 @@ class MRLTrainer:
 
         # 0. Set global epoch count for all stages
         self.global_epochs_count = sum(stage['epochs'] for stage in curriculum_config)
+        self.global_epoch = 0
 
         # 1. Init DDP for distributed training mode
         if self.use_ddp:
@@ -765,7 +776,7 @@ class MRLTrainer:
                     train_sampler.set_epoch(epoch)
 
                 # 13. Run reinforcement learning algorithms for current epoch
-                policy_loss, critic_loss = self.train_epoch(dataloader, epoch)
+                policy_loss, critic_loss = self.train_epoch(dataloader, epoch, batch_size)
 
                 # 14. If evaluation dataset is provided, run evaluation steps
                 if self.eval_dataset:
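
The functional change in mrl.py, beyond the added debug prints, is that batch_size is now threaded from the training loop through train_epoch into collect_trajectories, where under-sized batches are skipped instead of collected. A minimal sketch of that guard, with toy tensors and placeholder shapes rather than the package's dataset:

import torch

def is_full_batch(batch: dict, batch_size: int) -> bool:
    # Mirrors the new guard in collect_trajectories: an episode is only collected
    # when both the first query and the first follow-up query match the configured
    # batch size, so trailing partial batches from the DataLoader are skipped.
    return (batch['query']['input_ids'].size(0) == batch_size
            and batch['interactions'][0]['query']['input_ids'].size(0) == batch_size)

# Toy batch (batch of 32, sequence length 16):
batch = {
    'query': {'input_ids': torch.zeros(32, 16, dtype=torch.long)},
    'interactions': [{'query': {'input_ids': torch.zeros(32, 16, dtype=torch.long)}}],
}
print(is_full_batch(batch, batch_size=32))  # True  -> collected
print(is_full_batch(batch, batch_size=64))  # False -> skipped with the 'Incorrect batch size' log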
rxnn/training/utils.py CHANGED
@@ -127,7 +127,7 @@ def smart_concat(query: TokenizedDict, answer: TokenizedDict, max_length: int, p
     # Build combined_ids using vectorized where
     combined_ids = torch.where(
         query_mask,
-        query['input_ids'].gather(1, torch.minimum(positions, query_lens.unsqueeze(1) - 1)),
+        query['input_ids'].gather(1, torch.minimum(positions, query_lens.unsqueeze(1) - 1)).to(torch.int64),
         torch.where(
             answer_mask,
             answer['input_ids'].gather(1, answer_pos),
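
The utils.py change casts the gathered query token ids to int64 inside smart_concat's torch.where. A plausible reading (an assumption, not stated in the package) is that it keeps both branches of the where at the same integer dtype when query and answer ids arrive with different dtypes; a minimal sketch of that situation:

import torch

# Token ids arriving with mismatched integer dtypes (hypothetical values):
query_ids = torch.tensor([[5, 6, 7]], dtype=torch.int32)
answer_ids = torch.tensor([[8, 9, 10]], dtype=torch.int64)
mask = torch.tensor([[True, True, False]])

# Casting the query branch to int64, as in the 0.2.7 change, keeps the combined
# ids consistently int64 (the dtype typically expected by embedding lookups).
combined = torch.where(mask, query_ids.to(torch.int64), answer_ids)
print(combined.dtype)  # torch.int64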
{rxnn-0.2.5.dist-info → rxnn-0.2.7.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.2.5
+Version: 0.2.7
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning
{rxnn-0.2.5.dist-info → rxnn-0.2.7.dist-info}/RECORD RENAMED
@@ -7,7 +7,7 @@ rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/memory/attention.py,sha256=p-r8DK3iVhNn-JAESVzIXDCG8gk1R_-x5xHclZ5jgb0,1813
 rxnn/memory/norm.py,sha256=Ofl8Q5NYEF9GQeO0bhM43tkTW91J0y6TSvTAOYMgloM,6278
-rxnn/memory/stm.py,sha256=DPkK1q1SLRw3HWM0dcvkn4XvIrfwUK47h4KmvFVWljc,3847
+rxnn/memory/stm.py,sha256=eIxbmOh7SSI3YDim6ki2JgiCel0sDlLnD_qSqcrf7Wc,3881
 rxnn/rxt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/rxt/models.py,sha256=qlYgU002VE21ZOlcxEM9iv9tAvsbe4mngcMI2sw3j9k,12078
 rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -16,12 +16,12 @@ rxnn/training/bml.py,sha256=FJszaQXOLx2ZHBa1CQpyMrG8i4Kj14E-gzDAEK_Ei5k,17272
 rxnn/training/callbacks.py,sha256=o68IPFJyWM1CGooPRDNU9DfNcy4H_o0PcKDTn_ZLnKA,35053
 rxnn/training/dataset.py,sha256=XeRzo0KUYyQ43XjZ3o6Jban9ePIRtpHsqUmeKAQPRQk,50305
 rxnn/training/models.py,sha256=renPa5YH443XNTMFI-YTCwi5vNp3QzwF5UXedNd5hDk,5187
-rxnn/training/mrl.py,sha256=WDQ8xsrHfpRmTczDZhBuOlqHX8JBaEp5SchlTdAxttY,38883
+rxnn/training/mrl.py,sha256=qGEz9C8O4Yq5P0_WTdvl0QWzdDtTZZVeoNTON6ICB74,39877
 rxnn/training/reward.py,sha256=i0nhrPCDgy1di89HWylRBS6cQ7rSSxJUiS3TX8fiiHE,5614
 rxnn/training/rl.py,sha256=FKrBOBAfNub_qzkceFQR-WUtCBffC6oGHE8wlPsz2YA,2682
 rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
 rxnn/training/tokenizer.py,sha256=umaLByMBx_NMrQElA45HLm9gkuzyKWDTFaKVd-CjXl0,8344
-rxnn/training/utils.py,sha256=7ED5RIC8AybCmmQrbsU6Krd7brRILxVIeTlJYtJWl_4,5702
+rxnn/training/utils.py,sha256=U4fVNNomwAs_XqzqAdqtlPY1UeAl6E6oKOOSB0-Ckx0,5718
 rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/transformers/attention.py,sha256=d0Igo1Nrn76BphbHrzekiKJfT3RCy4iSXSB6FLAOTgc,16247
 rxnn/transformers/ff.py,sha256=jJnuBDsnnX5uYC_WZH8cXAYrMnz0P-iX7MwcPivjRtI,2533
@@ -32,7 +32,7 @@ rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
 rxnn/transformers/sampler.py,sha256=2dpUQv88ekZa_CMSPLrXvB6X684wxUE2bDVznsi5ACs,17429
 rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
-rxnn-0.2.5.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
-rxnn-0.2.5.dist-info/METADATA,sha256=rfAJmz-On8W_e9tg8PVJ79-isZeBcv_3ejUkP2EcvA8,25959
-rxnn-0.2.5.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
-rxnn-0.2.5.dist-info/RECORD,,
+rxnn-0.2.7.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.2.7.dist-info/METADATA,sha256=xyc69x0B-y2j_ww61VhAeKoQL3pvMMXe2hUV2QNWKZI,25959
+rxnn-0.2.7.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+rxnn-0.2.7.dist-info/RECORD,,
{rxnn-0.2.5.dist-info → rxnn-0.2.7.dist-info}/LICENSE RENAMED
File without changes
{rxnn-0.2.5.dist-info → rxnn-0.2.7.dist-info}/WHEEL RENAMED
File without changes