rxnn 0.2.6__py3-none-any.whl → 0.2.7__py3-none-any.whl
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
rxnn/training/mrl.py
CHANGED
@@ -283,7 +283,12 @@ class MRLTrainer:
|
|
283
283
|
with torch.no_grad():
|
284
284
|
# 2. Collect episode trajectories for all batches in dataset
|
285
285
|
for batch_idx, batch in enumerate(dataloader):
|
286
|
-
|
286
|
+
print('first query size: ', batch['query']['input_ids'].size())
|
287
|
+
print('first answer size: ', batch['answer']['input_ids'].size())
|
288
|
+
print('next query size: ', batch['interactions'][0]['query']['input_ids'].size())
|
289
|
+
print('Interactions len: ', len(batch['interactions']))
|
290
|
+
|
291
|
+
if batch['query']['input_ids'].size(0) != batch_size or batch['interactions'][0]['query']['input_ids'].size(0) != batch_size:
|
287
292
|
print('Incorrect batch size: ', batch['query']['input_ids'].size(0))
|
288
293
|
else:
|
289
294
|
self._increment_steps('collect')
|
@@ -304,9 +309,11 @@ class MRLTrainer:
|
|
304
309
|
# 7. Run training strategy for follow-up interactions
|
305
310
|
episode_steps = []
|
306
311
|
episode_rewards = []
|
312
|
+
|
307
313
|
for i, interaction in enumerate(interactions):
|
308
314
|
# 8. Generate batch of answers based on batch of follow-up queries
|
309
315
|
next_query = self._move_batch(interaction['query'])
|
316
|
+
print('first query: ', next_query['input_ids'].size())
|
310
317
|
generated_answer, log_probs = self.generate_answer(next_query)
|
311
318
|
|
312
319
|
is_last_interaction = (i + 1) == interactions_len
|
@@ -16,7 +16,7 @@ rxnn/training/bml.py,sha256=FJszaQXOLx2ZHBa1CQpyMrG8i4Kj14E-gzDAEK_Ei5k,17272
|
|
16
16
|
rxnn/training/callbacks.py,sha256=o68IPFJyWM1CGooPRDNU9DfNcy4H_o0PcKDTn_ZLnKA,35053
|
17
17
|
rxnn/training/dataset.py,sha256=XeRzo0KUYyQ43XjZ3o6Jban9ePIRtpHsqUmeKAQPRQk,50305
|
18
18
|
rxnn/training/models.py,sha256=renPa5YH443XNTMFI-YTCwi5vNp3QzwF5UXedNd5hDk,5187
|
19
|
-
rxnn/training/mrl.py,sha256=
|
19
|
+
rxnn/training/mrl.py,sha256=qGEz9C8O4Yq5P0_WTdvl0QWzdDtTZZVeoNTON6ICB74,39877
|
20
20
|
rxnn/training/reward.py,sha256=i0nhrPCDgy1di89HWylRBS6cQ7rSSxJUiS3TX8fiiHE,5614
|
21
21
|
rxnn/training/rl.py,sha256=FKrBOBAfNub_qzkceFQR-WUtCBffC6oGHE8wlPsz2YA,2682
|
22
22
|
rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
|
@@ -32,7 +32,7 @@ rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
|
|
32
32
|
rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
|
33
33
|
rxnn/transformers/sampler.py,sha256=2dpUQv88ekZa_CMSPLrXvB6X684wxUE2bDVznsi5ACs,17429
|
34
34
|
rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
|
35
|
-
rxnn-0.2.
|
36
|
-
rxnn-0.2.
|
37
|
-
rxnn-0.2.
|
38
|
-
rxnn-0.2.
|
35
|
+
rxnn-0.2.7.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
|
36
|
+
rxnn-0.2.7.dist-info/METADATA,sha256=xyc69x0B-y2j_ww61VhAeKoQL3pvMMXe2hUV2QNWKZI,25959
|
37
|
+
rxnn-0.2.7.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
|
38
|
+
rxnn-0.2.7.dist-info/RECORD,,
|
File without changes
|
File without changes
|