rxnn 0.2.5__py3-none-any.whl → 0.2.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rxnn/memory/stm.py +1 -0
- rxnn/training/mrl.py +83 -72
- rxnn/training/utils.py +1 -1
- {rxnn-0.2.5.dist-info → rxnn-0.2.7.dist-info}/METADATA +1 -1
- {rxnn-0.2.5.dist-info → rxnn-0.2.7.dist-info}/RECORD +7 -7
- {rxnn-0.2.5.dist-info → rxnn-0.2.7.dist-info}/LICENSE +0 -0
- {rxnn-0.2.5.dist-info → rxnn-0.2.7.dist-info}/WHEEL +0 -0
rxnn/memory/stm.py
CHANGED
rxnn/training/mrl.py
CHANGED
@@ -275,7 +275,7 @@ class MRLTrainer:
             self.writer.add_scalar(f'Collect/episode reward (steps: {self.curriculum_steps})', avg_reward,
                                    self.stage_step['collect'])

-    def collect_trajectories(self, dataloader: DataLoader, epoch: int) -> list[MrlTrajectoryEpisode]:
+    def collect_trajectories(self, dataloader: DataLoader, epoch: int, batch_size: int) -> list[MrlTrajectoryEpisode]:
         """Collect trajectories for PPO for current curriculum step."""
         # 1. Init trajectories list
         trajectories = []
@@ -283,79 +283,89 @@ class MRLTrainer:
         with torch.no_grad():
             # 2. Collect episode trajectories for all batches in dataset
             for batch_idx, batch in enumerate(dataloader):
-
-
-
-
+                print('first query size: ', batch['query']['input_ids'].size())
+                print('first answer size: ', batch['answer']['input_ids'].size())
+                print('next query size: ', batch['interactions'][0]['query']['input_ids'].size())
+                print('Interactions len: ', len(batch['interactions']))

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-#
-
-
-
-
-
-
-
-
-
-
-
-
-
+                if batch['query']['input_ids'].size(0) != batch_size or batch['interactions'][0]['query']['input_ids'].size(0) != batch_size:
+                    print('Incorrect batch size: ', batch['query']['input_ids'].size(0))
+                else:
+                    self._increment_steps('collect')
+                    # 3. Reset Short-Term Memory state (with random reset ratio - sometimes it will be good to build memory
+                    # state from existing one, instead of new random one)
+                    reset_done = self.reset_stm()
+
+                    # 4. Get first batch of interactions (data to save) and follow-up interactions for current episode, based on curriculum step
+                    first_query, first_answer, interactions = batch['query'], batch['answer'], batch['interactions']
+                    interactions = interactions[:self.curriculum_steps]
+                    interactions_len = len(interactions)
+                    # 5. Encode and update STM with data to save from first interaction
+                    self.encode_and_update_stm(*self._move_multiple_batches(first_query, first_answer))
+
+                    # 6. Save first interaction as data to save (for trajectory state)
+                    query, answer = first_query, first_answer
+
+                    # 7. Run training strategy for follow-up interactions
+                    episode_steps = []
+                    episode_rewards = []
+
+                    for i, interaction in enumerate(interactions):
+                        # 8. Generate batch of answers based on batch of follow-up queries
+                        next_query = self._move_batch(interaction['query'])
+                        print('first query: ', next_query['input_ids'].size())
+                        generated_answer, log_probs = self.generate_answer(next_query)
+
+                        is_last_interaction = (i + 1) == interactions_len
+
+                        detached_answer = self._cpu_detach(generated_answer)  # detach and keep states on CPU
+
+                        # 9. Depending on strategy compute reward
+                        if self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and i == 0:
+                            # a) long-range - first interaction - change topic - negative reward (it shouldn't include saved data)
+                            reward = self.compute_reward(detached_answer, interaction['answer'], (query, answer),
+                                                         mode=MrlRewardMode.NEGATIVE)
+                        elif self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and is_last_interaction:
+                            # b) long-range - last interaction - first interaction topic - long-range reward (it should include content from first interaction)
+                            reward = self.compute_reward(detached_answer, interaction['answer'],
+                                                         (first_query, first_answer), mode=MrlRewardMode.LONG_RANGE)
+                        else:
+                            # c) standard reward - generated answer should include some content from previous interaction (saved data), like reference answer
+                            reward = self.compute_reward(detached_answer, interaction['answer'], (query, answer),
+                                                         mode=MrlRewardMode.STANDARD)
+
+                        # 10. Update STM with generated response (except last interaction, it's not needed)
+                        if not is_last_interaction:
+                            self.encode_and_update_stm(next_query, generated_answer)  # update with generated_answer on GPU
+
+                        # 11. Store trajectory step
+                        trajectory: MrlTrajectoryStep = {
+                            'state': (query, answer, interaction['query']),
+                            'action': detached_answer,
+                            'log_probs': log_probs.detach().cpu(),
+                            'reward': reward,
+                            'reference': interaction['answer'],
+                        }
+                        episode_steps.append(trajectory)
+                        episode_rewards.append(reward)
+
+                        # 12. Set current interaction query and generated answer (batches), as saved data for next interaction
+                        query, answer = interaction['query'], detached_answer
+
+                    # 13. Append full batched episode (number of steps depends on curriculum stage) to trajectories
+                    episode_trajectory: MrlTrajectoryEpisode = {
+                        'reset_stm': reset_done,
+                        'steps': episode_steps,
                     }
-
-                    episode_rewards.append(reward)
-
-                    # 12. Set current interaction query and generated answer (batches), as saved data for next interaction
-                    query, answer = interaction['query'], detached_answer
+                    trajectories.append(episode_trajectory)

-
-                episode_trajectory: MrlTrajectoryEpisode = {
-                    'reset_stm': reset_done,
-                    'steps': episode_steps,
-                }
-                trajectories.append(episode_trajectory)
+                    mean_episode_reward = torch.tensor(episode_rewards).mean().item()

-
+                    self._collect_writer(mean_episode_reward, epoch)

-
-
-
-                for cb in self.callbacks:
-                    cb.on_episode_collected(self.actor, batch_idx, episode_trajectory, mean_episode_reward)
+                    # 14. Run "on episode collected" callbacks
+                    for cb in self.callbacks:
+                        cb.on_episode_collected(self.actor, batch_idx, episode_trajectory, mean_episode_reward)

         return trajectories

@@ -539,10 +549,10 @@ class MRLTrainer:
         rewards = [torch.tensor(t['reward']) for t in flat_trajectories]
         return states, rewards

-    def train_epoch(self, dataloader: DataLoader, epoch: int):
+    def train_epoch(self, dataloader: DataLoader, epoch: int, batch_size: int):
         """Train for one epoch."""
         # 1. Collect trajectories for current epoch
-        trajectories = self.collect_trajectories(dataloader, epoch)
+        trajectories = self.collect_trajectories(dataloader, epoch, batch_size)

         # 2. Flatten trajectories and collect state and rewards for critic update
         states, rewards = self._critic_states_and_rewards(trajectories)
@@ -696,6 +706,7 @@ class MRLTrainer:

         # 0. Set global epoch count for all stages
         self.global_epochs_count = sum(stage['epochs'] for stage in curriculum_config)
+        self.global_epoch = 0

         # 1. Init DDP for distributed training mode
         if self.use_ddp:
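The added `self.global_epoch = 0` initializes a cross-stage epoch counter next to `self.global_epochs_count`, which is computed as the sum of epochs over all curriculum stages. A minimal sketch of that counter pattern, using hypothetical stage names and assuming the counter is advanced once per epoch across stages (the increment itself is not part of this hunk):

    # Hypothetical curriculum config: each stage declares how many epochs it runs.
    curriculum_config = [{'epochs': 2}, {'epochs': 3}]

    global_epochs_count = sum(stage['epochs'] for stage in curriculum_config)
    global_epoch = 0  # initialized once, before the first stage starts

    for stage in curriculum_config:
        for epoch in range(stage['epochs']):
            global_epoch += 1  # assumed increment; the real trainer may track this differently
            print(f'global epoch {global_epoch}/{global_epochs_count}')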
@@ -765,7 +776,7 @@ class MRLTrainer:
                     train_sampler.set_epoch(epoch)

                 # 13. Run reinforcement learning algorithms for current epoch
-                policy_loss, critic_loss = self.train_epoch(dataloader, epoch)
+                policy_loss, critic_loss = self.train_epoch(dataloader, epoch, batch_size)

                 # 14. If evaluation dataset is provided, run evaluation steps
                 if self.eval_dataset:
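The bulk of the mrl.py change wraps the body of `collect_trajectories` in a guard that skips any batch whose actual size differs from the configured `batch_size` (checked on both the first query and the follow-up interactions), which typically happens on the last, smaller batch when the dataset length is not divisible by the batch size. A self-contained sketch of the same skip-incomplete-batches pattern, using toy tensors rather than the library's dataset classes:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    batch_size = 4
    # Toy data: 10 samples, so batch_size=4 leaves a final batch of only 2.
    dataset = TensorDataset(torch.arange(10).unsqueeze(1))
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    collected = []
    for batch_idx, (ids,) in enumerate(loader):
        if ids.size(0) != batch_size:
            # Skip incomplete batches so every collected step keeps the same leading dimension.
            print('Incorrect batch size:', ids.size(0))
            continue
        collected.append(ids)

    print(len(collected), 'full batches kept')  # 2 full batches kept, the trailing batch of 2 skipped

Constructing the DataLoader with `drop_last=True` would achieve the same effect without the explicit check.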
rxnn/training/utils.py
CHANGED
@@ -127,7 +127,7 @@ def smart_concat(query: TokenizedDict, answer: TokenizedDict, max_length: int, p
     # Build combined_ids using vectorized where
     combined_ids = torch.where(
         query_mask,
-        query['input_ids'].gather(1, torch.minimum(positions, query_lens.unsqueeze(1) - 1)),
+        query['input_ids'].gather(1, torch.minimum(positions, query_lens.unsqueeze(1) - 1)).to(torch.int64),
         torch.where(
             answer_mask,
             answer['input_ids'].gather(1, answer_pos),
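The utils.py change appends `.to(torch.int64)` to the gathered query ids inside the nested `torch.where` that builds `combined_ids`. The two branches of `torch.where` need compatible dtypes, and downstream embedding lookups expect int64 indices, so the explicit cast keeps the concatenated token ids in int64 even if the gathered query tensor arrives in a narrower integer type. A small standalone illustration of the dtype behaviour (not the library's `smart_concat` itself):

    import torch

    # Token ids from two sources can end up with different integer dtypes,
    # e.g. int32 from one pipeline and int64 (PyTorch's default integer type) from another.
    query_ids = torch.tensor([[5, 6, 7, 0]], dtype=torch.int32)
    answer_ids = torch.tensor([[8, 9, 0, 0]], dtype=torch.int64)
    query_mask = torch.tensor([[True, True, False, False]])

    # Depending on the PyTorch version, mixing dtypes in torch.where either raises an
    # error or silently promotes; the explicit cast pins the result to int64.
    combined_ids = torch.where(query_mask, query_ids.to(torch.int64), answer_ids)
    print(combined_ids.dtype)  # torch.int64, safe to feed into nn.Embedding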
{rxnn-0.2.5.dist-info → rxnn-0.2.7.dist-info}/RECORD
CHANGED
@@ -7,7 +7,7 @@ rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/memory/attention.py,sha256=p-r8DK3iVhNn-JAESVzIXDCG8gk1R_-x5xHclZ5jgb0,1813
 rxnn/memory/norm.py,sha256=Ofl8Q5NYEF9GQeO0bhM43tkTW91J0y6TSvTAOYMgloM,6278
-rxnn/memory/stm.py,sha256=
+rxnn/memory/stm.py,sha256=eIxbmOh7SSI3YDim6ki2JgiCel0sDlLnD_qSqcrf7Wc,3881
 rxnn/rxt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/rxt/models.py,sha256=qlYgU002VE21ZOlcxEM9iv9tAvsbe4mngcMI2sw3j9k,12078
 rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -16,12 +16,12 @@ rxnn/training/bml.py,sha256=FJszaQXOLx2ZHBa1CQpyMrG8i4Kj14E-gzDAEK_Ei5k,17272
 rxnn/training/callbacks.py,sha256=o68IPFJyWM1CGooPRDNU9DfNcy4H_o0PcKDTn_ZLnKA,35053
 rxnn/training/dataset.py,sha256=XeRzo0KUYyQ43XjZ3o6Jban9ePIRtpHsqUmeKAQPRQk,50305
 rxnn/training/models.py,sha256=renPa5YH443XNTMFI-YTCwi5vNp3QzwF5UXedNd5hDk,5187
-rxnn/training/mrl.py,sha256=
+rxnn/training/mrl.py,sha256=qGEz9C8O4Yq5P0_WTdvl0QWzdDtTZZVeoNTON6ICB74,39877
 rxnn/training/reward.py,sha256=i0nhrPCDgy1di89HWylRBS6cQ7rSSxJUiS3TX8fiiHE,5614
 rxnn/training/rl.py,sha256=FKrBOBAfNub_qzkceFQR-WUtCBffC6oGHE8wlPsz2YA,2682
 rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
 rxnn/training/tokenizer.py,sha256=umaLByMBx_NMrQElA45HLm9gkuzyKWDTFaKVd-CjXl0,8344
-rxnn/training/utils.py,sha256=
+rxnn/training/utils.py,sha256=U4fVNNomwAs_XqzqAdqtlPY1UeAl6E6oKOOSB0-Ckx0,5718
 rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/transformers/attention.py,sha256=d0Igo1Nrn76BphbHrzekiKJfT3RCy4iSXSB6FLAOTgc,16247
 rxnn/transformers/ff.py,sha256=jJnuBDsnnX5uYC_WZH8cXAYrMnz0P-iX7MwcPivjRtI,2533
@@ -32,7 +32,7 @@ rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
 rxnn/transformers/sampler.py,sha256=2dpUQv88ekZa_CMSPLrXvB6X684wxUE2bDVznsi5ACs,17429
 rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
-rxnn-0.2.
-rxnn-0.2.
-rxnn-0.2.
-rxnn-0.2.
+rxnn-0.2.7.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.2.7.dist-info/METADATA,sha256=xyc69x0B-y2j_ww61VhAeKoQL3pvMMXe2hUV2QNWKZI,25959
+rxnn-0.2.7.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+rxnn-0.2.7.dist-info/RECORD,,
{rxnn-0.2.5.dist-info → rxnn-0.2.7.dist-info}/LICENSE
File without changes
{rxnn-0.2.5.dist-info → rxnn-0.2.7.dist-info}/WHEEL
File without changes