rxnn 0.2.6__py3-none-any.whl → 0.2.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rxnn/training/mrl.py +8 -1
- rxnn/training/utils.py +1 -1
- rxnn/transformers/layers.py +3 -1
- rxnn/transformers/models.py +2 -1
- rxnn/transformers/sampler.py +5 -0
- {rxnn-0.2.6.dist-info → rxnn-0.2.8.dist-info}/METADATA +1 -1
- {rxnn-0.2.6.dist-info → rxnn-0.2.8.dist-info}/RECORD +9 -9
- {rxnn-0.2.6.dist-info → rxnn-0.2.8.dist-info}/LICENSE +0 -0
- {rxnn-0.2.6.dist-info → rxnn-0.2.8.dist-info}/WHEEL +0 -0
rxnn/training/mrl.py
CHANGED
@@ -283,7 +283,12 @@ class MRLTrainer:
|
|
283
283
|
with torch.no_grad():
|
284
284
|
# 2. Collect episode trajectories for all batches in dataset
|
285
285
|
for batch_idx, batch in enumerate(dataloader):
|
286
|
-
|
286
|
+
print('first query size: ', batch['query']['input_ids'].size())
|
287
|
+
print('first answer size: ', batch['answer']['input_ids'].size())
|
288
|
+
print('next query size: ', batch['interactions'][0]['query']['input_ids'].size())
|
289
|
+
print('Interactions len: ', len(batch['interactions']))
|
290
|
+
|
291
|
+
if batch['query']['input_ids'].size(0) != batch_size or batch['interactions'][0]['query']['input_ids'].size(0) != batch_size:
|
287
292
|
print('Incorrect batch size: ', batch['query']['input_ids'].size(0))
|
288
293
|
else:
|
289
294
|
self._increment_steps('collect')
|
@@ -304,9 +309,11 @@ class MRLTrainer:
|
|
304
309
|
# 7. Run training strategy for follow-up interactions
|
305
310
|
episode_steps = []
|
306
311
|
episode_rewards = []
|
312
|
+
|
307
313
|
for i, interaction in enumerate(interactions):
|
308
314
|
# 8. Generate batch of answers based on batch of follow-up queries
|
309
315
|
next_query = self._move_batch(interaction['query'])
|
316
|
+
print('first query: ', next_query['input_ids'].size())
|
310
317
|
generated_answer, log_probs = self.generate_answer(next_query)
|
311
318
|
|
312
319
|
is_last_interaction = (i + 1) == interactions_len
|
rxnn/training/utils.py
CHANGED
@@ -127,7 +127,7 @@ def smart_concat(query: TokenizedDict, answer: TokenizedDict, max_length: int, p
|
|
127
127
|
# Build combined_ids using vectorized where
|
128
128
|
combined_ids = torch.where(
|
129
129
|
query_mask,
|
130
|
-
query['input_ids'].gather(1, torch.minimum(positions, query_lens.unsqueeze(1) - 1)
|
130
|
+
query['input_ids'].gather(1, torch.minimum(positions, query_lens.unsqueeze(1) - 1).to(torch.int64)),
|
131
131
|
torch.where(
|
132
132
|
answer_mask,
|
133
133
|
answer['input_ids'].gather(1, answer_pos),
|
rxnn/transformers/layers.py
CHANGED
@@ -87,11 +87,12 @@ class ReactiveTransformerLayer(nn.Module):
|
|
87
87
|
else:
|
88
88
|
return None
|
89
89
|
|
90
|
-
def forward(self, x: torch.Tensor, stm: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
|
90
|
+
def forward(self, x: torch.Tensor, stm: torch.Tensor, mask: torch.Tensor = None, logs: bool = False) -> torch.Tensor:
|
91
91
|
# First step, self-attention
|
92
92
|
residual = x
|
93
93
|
if not self.use_post_norm:
|
94
94
|
x = self.norm1(x)
|
95
|
+
if logs: print('att', x.size())
|
95
96
|
x = self.attention(x, x, x, mask=mask)
|
96
97
|
x = residual + x
|
97
98
|
if self.use_post_norm:
|
@@ -100,6 +101,7 @@ class ReactiveTransformerLayer(nn.Module):
|
|
100
101
|
residual = x
|
101
102
|
if not self.use_post_norm:
|
102
103
|
x = self.norm2(x)
|
104
|
+
if logs: print('mxatt', x.size())
|
103
105
|
x = self.memory_cross_attention(x, stm, stm)
|
104
106
|
x = residual + x
|
105
107
|
if self.use_post_norm:
|
rxnn/transformers/models.py
CHANGED
@@ -83,7 +83,8 @@ class ReactiveTransformerDecoder(ReactiveTransformerBase):
|
|
83
83
|
# expand layer STM to batch size, if it's not in batch mode
|
84
84
|
if layer_stm.size(0) == 1:
|
85
85
|
layer_stm = layer_stm.expand(x.size(0), -1, -1)
|
86
|
-
|
86
|
+
print(f'Layer {i} before: {x.size()}')
|
87
|
+
x = self.layers[i](x, layer_stm, mask=mask, logs=True)
|
87
88
|
return self.head(x)
|
88
89
|
|
89
90
|
|
rxnn/transformers/sampler.py
CHANGED
@@ -293,6 +293,7 @@ class BatchSampler:
|
|
293
293
|
working_mask = attention_mask.clone()
|
294
294
|
|
295
295
|
for step in range(max_gen_len):
|
296
|
+
print('Sampler step', step)
|
296
297
|
active = (~finished) & (current_lens < max_seq_len)
|
297
298
|
if not active.any():
|
298
299
|
break
|
@@ -302,9 +303,13 @@ class BatchSampler:
|
|
302
303
|
max_len = active_current_lens.max().item()
|
303
304
|
|
304
305
|
with torch.set_grad_enabled(not no_grad):
|
306
|
+
print('Active size', active.size())
|
305
307
|
# Slice input and mask up to the current max length among active sequences
|
306
308
|
inputs = working_ids[active, :max_len]
|
307
309
|
masks = working_mask[active, :max_len]
|
310
|
+
print('Model input sizes')
|
311
|
+
print(inputs.size())
|
312
|
+
print(masks.size())
|
308
313
|
logits = self.model(inputs, attention_mask=masks)
|
309
314
|
|
310
315
|
# Get the last valid token index for each active sequence
|
@@ -16,23 +16,23 @@ rxnn/training/bml.py,sha256=FJszaQXOLx2ZHBa1CQpyMrG8i4Kj14E-gzDAEK_Ei5k,17272
|
|
16
16
|
rxnn/training/callbacks.py,sha256=o68IPFJyWM1CGooPRDNU9DfNcy4H_o0PcKDTn_ZLnKA,35053
|
17
17
|
rxnn/training/dataset.py,sha256=XeRzo0KUYyQ43XjZ3o6Jban9ePIRtpHsqUmeKAQPRQk,50305
|
18
18
|
rxnn/training/models.py,sha256=renPa5YH443XNTMFI-YTCwi5vNp3QzwF5UXedNd5hDk,5187
|
19
|
-
rxnn/training/mrl.py,sha256=
|
19
|
+
rxnn/training/mrl.py,sha256=qGEz9C8O4Yq5P0_WTdvl0QWzdDtTZZVeoNTON6ICB74,39877
|
20
20
|
rxnn/training/reward.py,sha256=i0nhrPCDgy1di89HWylRBS6cQ7rSSxJUiS3TX8fiiHE,5614
|
21
21
|
rxnn/training/rl.py,sha256=FKrBOBAfNub_qzkceFQR-WUtCBffC6oGHE8wlPsz2YA,2682
|
22
22
|
rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
|
23
23
|
rxnn/training/tokenizer.py,sha256=umaLByMBx_NMrQElA45HLm9gkuzyKWDTFaKVd-CjXl0,8344
|
24
|
-
rxnn/training/utils.py,sha256=
|
24
|
+
rxnn/training/utils.py,sha256=Bw8nZLKIt7NQpUVCYkb_79kWKChVFOYgYXwODo4SvNc,5718
|
25
25
|
rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
26
26
|
rxnn/transformers/attention.py,sha256=d0Igo1Nrn76BphbHrzekiKJfT3RCy4iSXSB6FLAOTgc,16247
|
27
27
|
rxnn/transformers/ff.py,sha256=jJnuBDsnnX5uYC_WZH8cXAYrMnz0P-iX7MwcPivjRtI,2533
|
28
|
-
rxnn/transformers/layers.py,sha256=
|
28
|
+
rxnn/transformers/layers.py,sha256=u_48ocn_JtdJ905hLDIxO12qQgday_6GuAyplBKKXz0,7579
|
29
29
|
rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
|
30
|
-
rxnn/transformers/models.py,sha256=
|
30
|
+
rxnn/transformers/models.py,sha256=kZIZxH4MbgQjQeeKIXGi4tuzXWAfmXjgUCfEzyyeOxA,8338
|
31
31
|
rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
|
32
32
|
rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
|
33
|
-
rxnn/transformers/sampler.py,sha256=
|
33
|
+
rxnn/transformers/sampler.py,sha256=GK5aYALWW-DegI7sIjiO531EWAfKl71k_6CbVuYXwfw,17637
|
34
34
|
rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
|
35
|
-
rxnn-0.2.
|
36
|
-
rxnn-0.2.
|
37
|
-
rxnn-0.2.
|
38
|
-
rxnn-0.2.
|
35
|
+
rxnn-0.2.8.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
|
36
|
+
rxnn-0.2.8.dist-info/METADATA,sha256=mOklYF8wm6kn9R70EWls0IYatIxEyKrmpNYspsnMQoM,25959
|
37
|
+
rxnn-0.2.8.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
|
38
|
+
rxnn-0.2.8.dist-info/RECORD,,
|
File without changes
|
File without changes
|