rxnn 0.2.8__py3-none-any.whl → 0.2.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rxnn/memory/stm.py +0 -1
- rxnn/training/mrl.py +1 -9
- rxnn/training/reward.py +2 -2
- rxnn/transformers/layers.py +1 -3
- rxnn/transformers/models.py +1 -2
- rxnn/transformers/sampler.py +14 -18
- {rxnn-0.2.8.dist-info → rxnn-0.2.10.dist-info}/METADATA +1 -1
- {rxnn-0.2.8.dist-info → rxnn-0.2.10.dist-info}/RECORD +10 -10
- {rxnn-0.2.8.dist-info → rxnn-0.2.10.dist-info}/LICENSE +0 -0
- {rxnn-0.2.8.dist-info → rxnn-0.2.10.dist-info}/WHEEL +0 -0
rxnn/memory/stm.py
CHANGED
rxnn/training/mrl.py
CHANGED
@@ -283,14 +283,7 @@ class MRLTrainer:
|
|
283
283
|
with torch.no_grad():
|
284
284
|
# 2. Collect episode trajectories for all batches in dataset
|
285
285
|
for batch_idx, batch in enumerate(dataloader):
|
286
|
-
|
287
|
-
print('first answer size: ', batch['answer']['input_ids'].size())
|
288
|
-
print('next query size: ', batch['interactions'][0]['query']['input_ids'].size())
|
289
|
-
print('Interactions len: ', len(batch['interactions']))
|
290
|
-
|
291
|
-
if batch['query']['input_ids'].size(0) != batch_size or batch['interactions'][0]['query']['input_ids'].size(0) != batch_size:
|
292
|
-
print('Incorrect batch size: ', batch['query']['input_ids'].size(0))
|
293
|
-
else:
|
286
|
+
if batch['query']['input_ids'].size(0) == batch_size:
|
294
287
|
self._increment_steps('collect')
|
295
288
|
# 3. Reset Short-Term Memory state (with random reset ratio - sometimes it will be good to build memory
|
296
289
|
# state from existing one, instead of new random one)
|
@@ -313,7 +306,6 @@ class MRLTrainer:
|
|
313
306
|
for i, interaction in enumerate(interactions):
|
314
307
|
# 8. Generate batch of answers based on batch of follow-up queries
|
315
308
|
next_query = self._move_batch(interaction['query'])
|
316
|
-
print('first query: ', next_query['input_ids'].size())
|
317
309
|
generated_answer, log_probs = self.generate_answer(next_query)
|
318
310
|
|
319
311
|
is_last_interaction = (i + 1) == interactions_len
|
rxnn/training/reward.py
CHANGED
@@ -103,9 +103,9 @@ class MrlRewardModel:
|
|
103
103
|
if mode == MrlRewardMode.STANDARD or mode == MrlRewardMode.LONG_RANGE:
|
104
104
|
bleu = self.batch_bleu(generated['input_ids'], reference['input_ids'], saved_data['input_ids'])
|
105
105
|
cosine = self.batch_cosine(generated['input_ids'], reference['input_ids'], saved_data['input_ids'])
|
106
|
-
return (self.bleu_factor * torch.tensor(bleu) + self.cos_factor * cosine).tolist()
|
106
|
+
return (self.bleu_factor * torch.tensor(bleu, device=self.device) + self.cos_factor * cosine).tolist()
|
107
107
|
else:
|
108
108
|
bleu = self.batch_bleu(generated['input_ids'], reference['input_ids'], saved_data['input_ids'])
|
109
109
|
cosine = self.negative_cosine(generated['input_ids'], reference['input_ids'], saved_data['input_ids'])
|
110
|
-
return (self.neg_bleu_factor * torch.tensor(bleu) + self.neg_cos_factor * cosine).tolist()
|
110
|
+
return (self.neg_bleu_factor * torch.tensor(bleu, device=self.device) + self.neg_cos_factor * cosine).tolist()
|
111
111
|
|
rxnn/transformers/layers.py
CHANGED
@@ -87,12 +87,11 @@ class ReactiveTransformerLayer(nn.Module):
|
|
87
87
|
else:
|
88
88
|
return None
|
89
89
|
|
90
|
-
def forward(self, x: torch.Tensor, stm: torch.Tensor, mask: torch.Tensor = None
|
90
|
+
def forward(self, x: torch.Tensor, stm: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
|
91
91
|
# First step, self-attention
|
92
92
|
residual = x
|
93
93
|
if not self.use_post_norm:
|
94
94
|
x = self.norm1(x)
|
95
|
-
if logs: print('att', x.size())
|
96
95
|
x = self.attention(x, x, x, mask=mask)
|
97
96
|
x = residual + x
|
98
97
|
if self.use_post_norm:
|
@@ -101,7 +100,6 @@ class ReactiveTransformerLayer(nn.Module):
|
|
101
100
|
residual = x
|
102
101
|
if not self.use_post_norm:
|
103
102
|
x = self.norm2(x)
|
104
|
-
if logs: print('mxatt', x.size())
|
105
103
|
x = self.memory_cross_attention(x, stm, stm)
|
106
104
|
x = residual + x
|
107
105
|
if self.use_post_norm:
|
rxnn/transformers/models.py
CHANGED
@@ -83,8 +83,7 @@ class ReactiveTransformerDecoder(ReactiveTransformerBase):
|
|
83
83
|
# expand layer STM to batch size, if it's not in batch mode
|
84
84
|
if layer_stm.size(0) == 1:
|
85
85
|
layer_stm = layer_stm.expand(x.size(0), -1, -1)
|
86
|
-
|
87
|
-
x = self.layers[i](x, layer_stm, mask=mask, logs=True)
|
86
|
+
x = self.layers[i](x, layer_stm, mask=mask)
|
88
87
|
return self.head(x)
|
89
88
|
|
90
89
|
|
rxnn/transformers/sampler.py
CHANGED
@@ -293,28 +293,21 @@ class BatchSampler:
|
|
293
293
|
working_mask = attention_mask.clone()
|
294
294
|
|
295
295
|
for step in range(max_gen_len):
|
296
|
-
print('Sampler step', step)
|
297
296
|
active = (~finished) & (current_lens < max_seq_len)
|
298
297
|
if not active.any():
|
299
298
|
break
|
300
299
|
|
301
|
-
|
302
|
-
active_current_lens = current_lens[active]
|
303
|
-
max_len = active_current_lens.max().item()
|
300
|
+
max_len = current_lens.max().item()
|
304
301
|
|
305
302
|
with torch.set_grad_enabled(not no_grad):
|
306
|
-
print('Active size', active.size())
|
307
303
|
# Slice input and mask up to the current max length among active sequences
|
308
|
-
inputs = working_ids[
|
309
|
-
masks = working_mask[
|
310
|
-
print('Model input sizes')
|
311
|
-
print(inputs.size())
|
312
|
-
print(masks.size())
|
304
|
+
inputs = working_ids[:, :max_len]
|
305
|
+
masks = working_mask[:, :max_len]
|
313
306
|
logits = self.model(inputs, attention_mask=masks)
|
314
307
|
|
315
308
|
# Get the last valid token index for each active sequence
|
316
|
-
indices = (
|
317
|
-
last_logits = logits[torch.arange(
|
309
|
+
indices = (current_lens - 1).to(self.device)
|
310
|
+
last_logits = logits[torch.arange(batch_size, device=self.device), indices]
|
318
311
|
|
319
312
|
# Sample next tokens and log probs
|
320
313
|
next_tokens, step_log_probs = sample_batch(
|
@@ -322,15 +315,18 @@ class BatchSampler:
|
|
322
315
|
)
|
323
316
|
|
324
317
|
# Update working tensors
|
325
|
-
for
|
326
|
-
if current_lens[idx] >= max_seq_len:
|
318
|
+
for idx in range(batch_size):
|
319
|
+
if finished[idx] or current_lens[idx] >= max_seq_len:
|
327
320
|
continue
|
321
|
+
|
328
322
|
pos = current_lens[idx].item()
|
329
|
-
|
330
|
-
|
331
|
-
|
323
|
+
token = next_tokens[idx] # Use original batch index
|
324
|
+
working_ids[idx, pos] = token
|
325
|
+
working_mask[idx, pos] = 1 if token != 0 else 0
|
326
|
+
log_probs[idx, step] = step_log_probs[idx] # Use original batch index
|
332
327
|
current_lens[idx] += 1
|
333
|
-
|
328
|
+
|
329
|
+
if token == self.end_token_id:
|
334
330
|
finished[idx] = True
|
335
331
|
|
336
332
|
# Extract generated tokens
|
@@ -7,7 +7,7 @@ rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
|
|
7
7
|
rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
8
8
|
rxnn/memory/attention.py,sha256=p-r8DK3iVhNn-JAESVzIXDCG8gk1R_-x5xHclZ5jgb0,1813
|
9
9
|
rxnn/memory/norm.py,sha256=Ofl8Q5NYEF9GQeO0bhM43tkTW91J0y6TSvTAOYMgloM,6278
|
10
|
-
rxnn/memory/stm.py,sha256=
|
10
|
+
rxnn/memory/stm.py,sha256=DPkK1q1SLRw3HWM0dcvkn4XvIrfwUK47h4KmvFVWljc,3847
|
11
11
|
rxnn/rxt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
12
12
|
rxnn/rxt/models.py,sha256=qlYgU002VE21ZOlcxEM9iv9tAvsbe4mngcMI2sw3j9k,12078
|
13
13
|
rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
@@ -16,8 +16,8 @@ rxnn/training/bml.py,sha256=FJszaQXOLx2ZHBa1CQpyMrG8i4Kj14E-gzDAEK_Ei5k,17272
|
|
16
16
|
rxnn/training/callbacks.py,sha256=o68IPFJyWM1CGooPRDNU9DfNcy4H_o0PcKDTn_ZLnKA,35053
|
17
17
|
rxnn/training/dataset.py,sha256=XeRzo0KUYyQ43XjZ3o6Jban9ePIRtpHsqUmeKAQPRQk,50305
|
18
18
|
rxnn/training/models.py,sha256=renPa5YH443XNTMFI-YTCwi5vNp3QzwF5UXedNd5hDk,5187
|
19
|
-
rxnn/training/mrl.py,sha256=
|
20
|
-
rxnn/training/reward.py,sha256=
|
19
|
+
rxnn/training/mrl.py,sha256=cftSa6zS3jfremxX0SQrxtbhbEHfZ-nvZT3Hl6GlgpI,39282
|
20
|
+
rxnn/training/reward.py,sha256=C0ToTz-u-L-qyBd2yJ1HlvVPS110OChYj9ZhD6iSSMU,5654
|
21
21
|
rxnn/training/rl.py,sha256=FKrBOBAfNub_qzkceFQR-WUtCBffC6oGHE8wlPsz2YA,2682
|
22
22
|
rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
|
23
23
|
rxnn/training/tokenizer.py,sha256=umaLByMBx_NMrQElA45HLm9gkuzyKWDTFaKVd-CjXl0,8344
|
@@ -25,14 +25,14 @@ rxnn/training/utils.py,sha256=Bw8nZLKIt7NQpUVCYkb_79kWKChVFOYgYXwODo4SvNc,5718
|
|
25
25
|
rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
26
26
|
rxnn/transformers/attention.py,sha256=d0Igo1Nrn76BphbHrzekiKJfT3RCy4iSXSB6FLAOTgc,16247
|
27
27
|
rxnn/transformers/ff.py,sha256=jJnuBDsnnX5uYC_WZH8cXAYrMnz0P-iX7MwcPivjRtI,2533
|
28
|
-
rxnn/transformers/layers.py,sha256=
|
28
|
+
rxnn/transformers/layers.py,sha256=MbOIX4PurbTbYxcXSavyFsNpTHCm26K_Ssk_VUCzKIE,7469
|
29
29
|
rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
|
30
|
-
rxnn/transformers/models.py,sha256=
|
30
|
+
rxnn/transformers/models.py,sha256=VvP7r7E6tj7OWsYKlJLCy2vsQ3xSSnlNez6QxR-jBAA,8276
|
31
31
|
rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
|
32
32
|
rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
|
33
|
-
rxnn/transformers/sampler.py,sha256=
|
33
|
+
rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
|
34
34
|
rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
|
35
|
-
rxnn-0.2.
|
36
|
-
rxnn-0.2.
|
37
|
-
rxnn-0.2.
|
38
|
-
rxnn-0.2.
|
35
|
+
rxnn-0.2.10.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
|
36
|
+
rxnn-0.2.10.dist-info/METADATA,sha256=eHGpOf-EPGR8Ljs2S7agTAaQB79o6ZYdEtuJRgZLi8E,25960
|
37
|
+
rxnn-0.2.10.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
|
38
|
+
rxnn-0.2.10.dist-info/RECORD,,
|
File without changes
|
File without changes
|