rxnn 0.2.7__py3-none-any.whl → 0.2.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rxnn/memory/stm.py +0 -1
- rxnn/training/mrl.py +1 -9
- rxnn/training/utils.py +1 -1
- rxnn/transformers/sampler.py +14 -13
- {rxnn-0.2.7.dist-info → rxnn-0.2.9.dist-info}/METADATA +1 -1
- {rxnn-0.2.7.dist-info → rxnn-0.2.9.dist-info}/RECORD +8 -8
- {rxnn-0.2.7.dist-info → rxnn-0.2.9.dist-info}/LICENSE +0 -0
- {rxnn-0.2.7.dist-info → rxnn-0.2.9.dist-info}/WHEEL +0 -0
rxnn/memory/stm.py
CHANGED
rxnn/training/mrl.py
CHANGED
@@ -283,14 +283,7 @@ class MRLTrainer:
|
|
283
283
|
with torch.no_grad():
|
284
284
|
# 2. Collect episode trajectories for all batches in dataset
|
285
285
|
for batch_idx, batch in enumerate(dataloader):
|
286
|
-
|
287
|
-
print('first answer size: ', batch['answer']['input_ids'].size())
|
288
|
-
print('next query size: ', batch['interactions'][0]['query']['input_ids'].size())
|
289
|
-
print('Interactions len: ', len(batch['interactions']))
|
290
|
-
|
291
|
-
if batch['query']['input_ids'].size(0) != batch_size or batch['interactions'][0]['query']['input_ids'].size(0) != batch_size:
|
292
|
-
print('Incorrect batch size: ', batch['query']['input_ids'].size(0))
|
293
|
-
else:
|
286
|
+
if batch['query']['input_ids'].size(0) == batch_size:
|
294
287
|
self._increment_steps('collect')
|
295
288
|
# 3. Reset Short-Term Memory state (with random reset ratio - sometimes it will be good to build memory
|
296
289
|
# state from existing one, instead of new random one)
|
@@ -313,7 +306,6 @@ class MRLTrainer:
|
|
313
306
|
for i, interaction in enumerate(interactions):
|
314
307
|
# 8. Generate batch of answers based on batch of follow-up queries
|
315
308
|
next_query = self._move_batch(interaction['query'])
|
316
|
-
print('first query: ', next_query['input_ids'].size())
|
317
309
|
generated_answer, log_probs = self.generate_answer(next_query)
|
318
310
|
|
319
311
|
is_last_interaction = (i + 1) == interactions_len
|
rxnn/training/utils.py
CHANGED
@@ -127,7 +127,7 @@ def smart_concat(query: TokenizedDict, answer: TokenizedDict, max_length: int, p
|
|
127
127
|
# Build combined_ids using vectorized where
|
128
128
|
combined_ids = torch.where(
|
129
129
|
query_mask,
|
130
|
-
query['input_ids'].gather(1, torch.minimum(positions, query_lens.unsqueeze(1) - 1)
|
130
|
+
query['input_ids'].gather(1, torch.minimum(positions, query_lens.unsqueeze(1) - 1).to(torch.int64)),
|
131
131
|
torch.where(
|
132
132
|
answer_mask,
|
133
133
|
answer['input_ids'].gather(1, answer_pos),
|
rxnn/transformers/sampler.py
CHANGED
@@ -297,19 +297,17 @@ class BatchSampler:
|
|
297
297
|
if not active.any():
|
298
298
|
break
|
299
299
|
|
300
|
-
|
301
|
-
active_current_lens = current_lens[active]
|
302
|
-
max_len = active_current_lens.max().item()
|
300
|
+
max_len = current_lens.max().item()
|
303
301
|
|
304
302
|
with torch.set_grad_enabled(not no_grad):
|
305
303
|
# Slice input and mask up to the current max length among active sequences
|
306
|
-
inputs = working_ids[
|
307
|
-
masks = working_mask[
|
304
|
+
inputs = working_ids[:, :max_len]
|
305
|
+
masks = working_mask[:, :max_len]
|
308
306
|
logits = self.model(inputs, attention_mask=masks)
|
309
307
|
|
310
308
|
# Get the last valid token index for each active sequence
|
311
|
-
indices = (
|
312
|
-
last_logits = logits[torch.arange(
|
309
|
+
indices = (current_lens - 1).to(self.device)
|
310
|
+
last_logits = logits[torch.arange(batch_size, device=self.device), indices]
|
313
311
|
|
314
312
|
# Sample next tokens and log probs
|
315
313
|
next_tokens, step_log_probs = sample_batch(
|
@@ -317,15 +315,18 @@ class BatchSampler:
|
|
317
315
|
)
|
318
316
|
|
319
317
|
# Update working tensors
|
320
|
-
for
|
321
|
-
if current_lens[idx] >= max_seq_len:
|
318
|
+
for idx in range(batch_size):
|
319
|
+
if finished[idx] or current_lens[idx] >= max_seq_len:
|
322
320
|
continue
|
321
|
+
|
323
322
|
pos = current_lens[idx].item()
|
324
|
-
|
325
|
-
|
326
|
-
|
323
|
+
token = next_tokens[idx] # Use original batch index
|
324
|
+
working_ids[idx, pos] = token
|
325
|
+
working_mask[idx, pos] = 1 if token != 0 else 0
|
326
|
+
log_probs[idx, step] = step_log_probs[idx] # Use original batch index
|
327
327
|
current_lens[idx] += 1
|
328
|
-
|
328
|
+
|
329
|
+
if token == self.end_token_id:
|
329
330
|
finished[idx] = True
|
330
331
|
|
331
332
|
# Extract generated tokens
|
@@ -7,7 +7,7 @@ rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
|
|
7
7
|
rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
8
8
|
rxnn/memory/attention.py,sha256=p-r8DK3iVhNn-JAESVzIXDCG8gk1R_-x5xHclZ5jgb0,1813
|
9
9
|
rxnn/memory/norm.py,sha256=Ofl8Q5NYEF9GQeO0bhM43tkTW91J0y6TSvTAOYMgloM,6278
|
10
|
-
rxnn/memory/stm.py,sha256=
|
10
|
+
rxnn/memory/stm.py,sha256=DPkK1q1SLRw3HWM0dcvkn4XvIrfwUK47h4KmvFVWljc,3847
|
11
11
|
rxnn/rxt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
12
12
|
rxnn/rxt/models.py,sha256=qlYgU002VE21ZOlcxEM9iv9tAvsbe4mngcMI2sw3j9k,12078
|
13
13
|
rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
@@ -16,12 +16,12 @@ rxnn/training/bml.py,sha256=FJszaQXOLx2ZHBa1CQpyMrG8i4Kj14E-gzDAEK_Ei5k,17272
|
|
16
16
|
rxnn/training/callbacks.py,sha256=o68IPFJyWM1CGooPRDNU9DfNcy4H_o0PcKDTn_ZLnKA,35053
|
17
17
|
rxnn/training/dataset.py,sha256=XeRzo0KUYyQ43XjZ3o6Jban9ePIRtpHsqUmeKAQPRQk,50305
|
18
18
|
rxnn/training/models.py,sha256=renPa5YH443XNTMFI-YTCwi5vNp3QzwF5UXedNd5hDk,5187
|
19
|
-
rxnn/training/mrl.py,sha256=
|
19
|
+
rxnn/training/mrl.py,sha256=cftSa6zS3jfremxX0SQrxtbhbEHfZ-nvZT3Hl6GlgpI,39282
|
20
20
|
rxnn/training/reward.py,sha256=i0nhrPCDgy1di89HWylRBS6cQ7rSSxJUiS3TX8fiiHE,5614
|
21
21
|
rxnn/training/rl.py,sha256=FKrBOBAfNub_qzkceFQR-WUtCBffC6oGHE8wlPsz2YA,2682
|
22
22
|
rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
|
23
23
|
rxnn/training/tokenizer.py,sha256=umaLByMBx_NMrQElA45HLm9gkuzyKWDTFaKVd-CjXl0,8344
|
24
|
-
rxnn/training/utils.py,sha256=
|
24
|
+
rxnn/training/utils.py,sha256=Bw8nZLKIt7NQpUVCYkb_79kWKChVFOYgYXwODo4SvNc,5718
|
25
25
|
rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
26
26
|
rxnn/transformers/attention.py,sha256=d0Igo1Nrn76BphbHrzekiKJfT3RCy4iSXSB6FLAOTgc,16247
|
27
27
|
rxnn/transformers/ff.py,sha256=jJnuBDsnnX5uYC_WZH8cXAYrMnz0P-iX7MwcPivjRtI,2533
|
@@ -30,9 +30,9 @@ rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
|
|
30
30
|
rxnn/transformers/models.py,sha256=VvP7r7E6tj7OWsYKlJLCy2vsQ3xSSnlNez6QxR-jBAA,8276
|
31
31
|
rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
|
32
32
|
rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
|
33
|
-
rxnn/transformers/sampler.py,sha256=
|
33
|
+
rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
|
34
34
|
rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
|
35
|
-
rxnn-0.2.
|
36
|
-
rxnn-0.2.
|
37
|
-
rxnn-0.2.
|
38
|
-
rxnn-0.2.
|
35
|
+
rxnn-0.2.9.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
|
36
|
+
rxnn-0.2.9.dist-info/METADATA,sha256=ERPLS1G1D0zDUR3OUpTAPUHQ47Xjd3U4b4bSMnqA4p4,25959
|
37
|
+
rxnn-0.2.9.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
|
38
|
+
rxnn-0.2.9.dist-info/RECORD,,
|
File without changes
|
File without changes
|