rxnn 0.2.5__py3-none-any.whl → 0.2.6__py3-none-any.whl
This diff shows the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in the public registry.
- rxnn/memory/stm.py +1 -0
- rxnn/training/mrl.py +77 -73
- rxnn/training/utils.py +1 -1
- {rxnn-0.2.5.dist-info → rxnn-0.2.6.dist-info}/METADATA +1 -1
- {rxnn-0.2.5.dist-info → rxnn-0.2.6.dist-info}/RECORD +7 -7
- {rxnn-0.2.5.dist-info → rxnn-0.2.6.dist-info}/LICENSE +0 -0
- {rxnn-0.2.5.dist-info → rxnn-0.2.6.dist-info}/WHEEL +0 -0
rxnn/memory/stm.py
CHANGED
rxnn/training/mrl.py
CHANGED
@@ -275,7 +275,7 @@ class MRLTrainer:
             self.writer.add_scalar(f'Collect/episode reward (steps: {self.curriculum_steps})', avg_reward,
                                    self.stage_step['collect'])
 
-    def collect_trajectories(self, dataloader: DataLoader, epoch: int) -> list[MrlTrajectoryEpisode]:
+    def collect_trajectories(self, dataloader: DataLoader, epoch: int, batch_size: int) -> list[MrlTrajectoryEpisode]:
         """Collect trajectories for PPO for current curriculum step."""
         # 1. Init trajectories list
         trajectories = []
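Both `collect_trajectories` and `train_epoch` (further down in this file) now take the configured `batch_size` explicitly, so the collector can detect a smaller trailing batch and skip it instead of collecting a malformed episode; the size check appears at the top of the next hunk. Below is a minimal sketch of how a caller might keep the loader and the trainer in sync. The names `train_dataset`, `trainer`, and `epoch` are hypothetical and used only for illustration, and `drop_last=True` is simply the standard PyTorch option that avoids a short final batch altogether:

    from torch.utils.data import DataLoader

    batch_size = 64  # hypothetical value; must match what the collector expects

    # Option A: build the loader and pass the same batch_size through the trainer,
    # as the new signatures allow; a short final batch is then skipped with a printed notice.
    dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    policy_loss, critic_loss = trainer.train_epoch(dataloader, epoch, batch_size)

    # Option B: drop the incomplete final batch at the loader level,
    # so the size guard in collect_trajectories never triggers.
    dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)

Passing the size explicitly keeps the guard independent of how the DataLoader was constructed.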
@@ -283,79 +283,82 @@ class MRLTrainer:
         with torch.no_grad():
             # 2. Collect episode trajectories for all batches in dataset
             for batch_idx, batch in enumerate(dataloader):
-                [removed lines 286-337 (previous per-batch collection body) are not rendered in this diff view]
+                if batch['query']['input_ids'].size(0) != batch_size:
+                    print('Incorrect batch size: ', batch['query']['input_ids'].size(0))
+                else:
+                    self._increment_steps('collect')
+                    # 3. Reset Short-Term Memory state (with random reset ratio - sometimes it will be good to build memory
+                    # state from existing one, instead of new random one)
+                    reset_done = self.reset_stm()
+
+                    # 4. Get first batch of interactions (data to save) and follow-up interactions for current episode, based on curriculum step
+                    first_query, first_answer, interactions = batch['query'], batch['answer'], batch['interactions']
+                    interactions = interactions[:self.curriculum_steps]
+                    interactions_len = len(interactions)
+                    # 5. Encode and update STM with data to save from first interaction
+                    self.encode_and_update_stm(*self._move_multiple_batches(first_query, first_answer))
+
+                    # 6. Save first interaction as data to save (for trajectory state)
+                    query, answer = first_query, first_answer
+
+                    # 7. Run training strategy for follow-up interactions
+                    episode_steps = []
+                    episode_rewards = []
+                    for i, interaction in enumerate(interactions):
+                        # 8. Generate batch of answers based on batch of follow-up queries
+                        next_query = self._move_batch(interaction['query'])
+                        generated_answer, log_probs = self.generate_answer(next_query)
+
+                        is_last_interaction = (i + 1) == interactions_len
+
+                        detached_answer = self._cpu_detach(generated_answer)  # detach and keep states on CPU
+
+                        # 9. Depending on strategy compute reward
+                        if self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and i == 0:
+                            # a) long-range - first interaction - change topic - negative reward (it shouldn't include saved data)
+                            reward = self.compute_reward(detached_answer, interaction['answer'], (query, answer),
+                                                         mode=MrlRewardMode.NEGATIVE)
+                        elif self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and is_last_interaction:
+                            # b) long-range - last interaction - first interaction topic - long-range reward (it should include content from first interaction)
+                            reward = self.compute_reward(detached_answer, interaction['answer'],
+                                                         (first_query, first_answer), mode=MrlRewardMode.LONG_RANGE)
+                        else:
+                            # c) standard reward - generated answer should include some content from previous interaction (saved data), like reference answer
+                            reward = self.compute_reward(detached_answer, interaction['answer'], (query, answer),
+                                                         mode=MrlRewardMode.STANDARD)
+
+                        # 10. Update STM with generated response (except last interaction, it's not needed)
+                        if not is_last_interaction:
+                            self.encode_and_update_stm(next_query, generated_answer)  # update with generated_answer on GPU
+
+                        # 11. Store trajectory step
+                        trajectory: MrlTrajectoryStep = {
+                            'state': (query, answer, interaction['query']),
+                            'action': detached_answer,
+                            'log_probs': log_probs.detach().cpu(),
+                            'reward': reward,
+                            'reference': interaction['answer'],
+                        }
+                        episode_steps.append(trajectory)
+                        episode_rewards.append(reward)
+
+                        # 12. Set current interaction query and generated answer (batches), as saved data for next interaction
+                        query, answer = interaction['query'], detached_answer
+
+                    # 13. Append full batched episode (number of steps depends on curriculum stage) to trajectories
+                    episode_trajectory: MrlTrajectoryEpisode = {
+                        'reset_stm': reset_done,
+                        'steps': episode_steps,
                     }
-
-                    episode_rewards.append(reward)
-
-                    # 12. Set current interaction query and generated answer (batches), as saved data for next interaction
-                    query, answer = interaction['query'], detached_answer
+                    trajectories.append(episode_trajectory)
 
-
-                episode_trajectory: MrlTrajectoryEpisode = {
-                    'reset_stm': reset_done,
-                    'steps': episode_steps,
-                }
-                trajectories.append(episode_trajectory)
+                    mean_episode_reward = torch.tensor(episode_rewards).mean().item()
 
-
+                    self._collect_writer(mean_episode_reward, epoch)
 
-
-
-
-                for cb in self.callbacks:
-                    cb.on_episode_collected(self.actor, batch_idx, episode_trajectory, mean_episode_reward)
+                    # 14. Run "on episode collected" callbacks
+                    for cb in self.callbacks:
+                        cb.on_episode_collected(self.actor, batch_idx, episode_trajectory, mean_episode_reward)
 
         return trajectories
 
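The step and episode dictionaries built above are annotated as `MrlTrajectoryStep` and `MrlTrajectoryEpisode`. The sketch below reconstructs those shapes from the literals in this hunk purely for orientation; the actual definitions live elsewhere in the package, and the exact field types (notably `reward`) are assumptions, not taken from this diff:

    from typing import TypedDict
    import torch

    # Assumed stand-in: a tokenized batch of the 'input_ids' / 'attention_mask' kind.
    TokenizedDict = dict[str, torch.Tensor]

    class MrlTrajectoryStep(TypedDict):
        state: tuple[TokenizedDict, TokenizedDict, TokenizedDict]  # (saved query, saved answer, next query)
        action: TokenizedDict        # generated answer, detached and kept on CPU
        log_probs: torch.Tensor      # log-probabilities of the generated tokens, on CPU
        reward: list[float]          # per-item rewards for the batch (type assumed)
        reference: TokenizedDict     # reference answer from the dataset

    class MrlTrajectoryEpisode(TypedDict):
        reset_stm: bool              # whether STM was reset before this episode
        steps: list[MrlTrajectoryStep]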
@@ -539,10 +542,10 @@ class MRLTrainer:
         rewards = [torch.tensor(t['reward']) for t in flat_trajectories]
         return states, rewards
 
-    def train_epoch(self, dataloader: DataLoader, epoch: int):
+    def train_epoch(self, dataloader: DataLoader, epoch: int, batch_size: int):
         """Train for one epoch."""
         # 1. Collect trajectories for current epoch
-        trajectories = self.collect_trajectories(dataloader, epoch)
+        trajectories = self.collect_trajectories(dataloader, epoch, batch_size)
 
         # 2. Flatten trajectories and collect state and rewards for critic update
         states, rewards = self._critic_states_and_rewards(trajectories)
@@ -696,6 +699,7 @@ class MRLTrainer:
 
         # 0. Set global epoch count for all stages
         self.global_epochs_count = sum(stage['epochs'] for stage in curriculum_config)
+        self.global_epoch = 0
 
         # 1. Init DDP for distributed training mode
         if self.use_ddp:
@@ -765,7 +769,7 @@ class MRLTrainer:
                 train_sampler.set_epoch(epoch)
 
             # 13. Run reinforcement learning algorithms for current epoch
-            policy_loss, critic_loss = self.train_epoch(dataloader, epoch)
+            policy_loss, critic_loss = self.train_epoch(dataloader, epoch, batch_size)
 
             # 14. If evaluation dataset is provided, run evaluation steps
             if self.eval_dataset:
rxnn/training/utils.py
CHANGED
@@ -127,7 +127,7 @@ def smart_concat(query: TokenizedDict, answer: TokenizedDict, max_length: int, p
     # Build combined_ids using vectorized where
     combined_ids = torch.where(
         query_mask,
-        query['input_ids'].gather(1, torch.minimum(positions, query_lens.unsqueeze(1) - 1)),
+        query['input_ids'].gather(1, torch.minimum(positions, query_lens.unsqueeze(1) - 1)).to(torch.int64),
         torch.where(
             answer_mask,
             answer['input_ids'].gather(1, answer_pos),
{rxnn-0.2.5.dist-info → rxnn-0.2.6.dist-info}/RECORD
CHANGED
@@ -7,7 +7,7 @@ rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/memory/attention.py,sha256=p-r8DK3iVhNn-JAESVzIXDCG8gk1R_-x5xHclZ5jgb0,1813
 rxnn/memory/norm.py,sha256=Ofl8Q5NYEF9GQeO0bhM43tkTW91J0y6TSvTAOYMgloM,6278
-rxnn/memory/stm.py,sha256=
+rxnn/memory/stm.py,sha256=eIxbmOh7SSI3YDim6ki2JgiCel0sDlLnD_qSqcrf7Wc,3881
 rxnn/rxt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/rxt/models.py,sha256=qlYgU002VE21ZOlcxEM9iv9tAvsbe4mngcMI2sw3j9k,12078
 rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -16,12 +16,12 @@ rxnn/training/bml.py,sha256=FJszaQXOLx2ZHBa1CQpyMrG8i4Kj14E-gzDAEK_Ei5k,17272
 rxnn/training/callbacks.py,sha256=o68IPFJyWM1CGooPRDNU9DfNcy4H_o0PcKDTn_ZLnKA,35053
 rxnn/training/dataset.py,sha256=XeRzo0KUYyQ43XjZ3o6Jban9ePIRtpHsqUmeKAQPRQk,50305
 rxnn/training/models.py,sha256=renPa5YH443XNTMFI-YTCwi5vNp3QzwF5UXedNd5hDk,5187
-rxnn/training/mrl.py,sha256=
+rxnn/training/mrl.py,sha256=G1sxKrnfkggHRqSInjk17PfpNhj5neqfh5Y2RUabJLk,39392
 rxnn/training/reward.py,sha256=i0nhrPCDgy1di89HWylRBS6cQ7rSSxJUiS3TX8fiiHE,5614
 rxnn/training/rl.py,sha256=FKrBOBAfNub_qzkceFQR-WUtCBffC6oGHE8wlPsz2YA,2682
 rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
 rxnn/training/tokenizer.py,sha256=umaLByMBx_NMrQElA45HLm9gkuzyKWDTFaKVd-CjXl0,8344
-rxnn/training/utils.py,sha256=
+rxnn/training/utils.py,sha256=U4fVNNomwAs_XqzqAdqtlPY1UeAl6E6oKOOSB0-Ckx0,5718
 rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/transformers/attention.py,sha256=d0Igo1Nrn76BphbHrzekiKJfT3RCy4iSXSB6FLAOTgc,16247
 rxnn/transformers/ff.py,sha256=jJnuBDsnnX5uYC_WZH8cXAYrMnz0P-iX7MwcPivjRtI,2533
@@ -32,7 +32,7 @@ rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
 rxnn/transformers/sampler.py,sha256=2dpUQv88ekZa_CMSPLrXvB6X684wxUE2bDVznsi5ACs,17429
 rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
-rxnn-0.2.
-rxnn-0.2.
-rxnn-0.2.
-rxnn-0.2.
+rxnn-0.2.6.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.2.6.dist-info/METADATA,sha256=eoxdTixI_-XuziJ64YG3QGa8TD0f06q95stHWgNzSjE,25959
+rxnn-0.2.6.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+rxnn-0.2.6.dist-info/RECORD,,
{rxnn-0.2.5.dist-info → rxnn-0.2.6.dist-info}/LICENSE
File without changes
{rxnn-0.2.5.dist-info → rxnn-0.2.6.dist-info}/WHEEL
File without changes