rxnn 0.2.8__py3-none-any.whl → 0.2.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rxnn/memory/stm.py CHANGED
@@ -43,7 +43,6 @@ class ShortTermMemory(nn.Module):
43
43
 
44
44
  def update_all(self, new_stm: torch.Tensor):
45
45
  self.memory.copy_(new_stm)
46
- print(self.memory.size())
47
46
 
48
47
  def make_trainable(self):
49
48
  if not self.is_trainable:
rxnn/training/mrl.py CHANGED
@@ -283,14 +283,7 @@ class MRLTrainer:
283
283
  with torch.no_grad():
284
284
  # 2. Collect episode trajectories for all batches in dataset
285
285
  for batch_idx, batch in enumerate(dataloader):
286
- print('first query size: ', batch['query']['input_ids'].size())
287
- print('first answer size: ', batch['answer']['input_ids'].size())
288
- print('next query size: ', batch['interactions'][0]['query']['input_ids'].size())
289
- print('Interactions len: ', len(batch['interactions']))
290
-
291
- if batch['query']['input_ids'].size(0) != batch_size or batch['interactions'][0]['query']['input_ids'].size(0) != batch_size:
292
- print('Incorrect batch size: ', batch['query']['input_ids'].size(0))
293
- else:
286
+ if batch['query']['input_ids'].size(0) == batch_size:
294
287
  self._increment_steps('collect')
295
288
  # 3. Reset Short-Term Memory state (with random reset ratio - sometimes it will be good to build memory
296
289
  # state from existing one, instead of new random one)
@@ -313,7 +306,6 @@ class MRLTrainer:
313
306
  for i, interaction in enumerate(interactions):
314
307
  # 8. Generate batch of answers based on batch of follow-up queries
315
308
  next_query = self._move_batch(interaction['query'])
316
- print('first query: ', next_query['input_ids'].size())
317
309
  generated_answer, log_probs = self.generate_answer(next_query)
318
310
 
319
311
  is_last_interaction = (i + 1) == interactions_len
rxnn/training/reward.py CHANGED
@@ -103,9 +103,9 @@ class MrlRewardModel:
103
103
  if mode == MrlRewardMode.STANDARD or mode == MrlRewardMode.LONG_RANGE:
104
104
  bleu = self.batch_bleu(generated['input_ids'], reference['input_ids'], saved_data['input_ids'])
105
105
  cosine = self.batch_cosine(generated['input_ids'], reference['input_ids'], saved_data['input_ids'])
106
- return (self.bleu_factor * torch.tensor(bleu) + self.cos_factor * cosine).tolist()
106
+ return (self.bleu_factor * torch.tensor(bleu, device=self.device) + self.cos_factor * cosine).tolist()
107
107
  else:
108
108
  bleu = self.batch_bleu(generated['input_ids'], reference['input_ids'], saved_data['input_ids'])
109
109
  cosine = self.negative_cosine(generated['input_ids'], reference['input_ids'], saved_data['input_ids'])
110
- return (self.neg_bleu_factor * torch.tensor(bleu) + self.neg_cos_factor * cosine).tolist()
110
+ return (self.neg_bleu_factor * torch.tensor(bleu, device=self.device) + self.neg_cos_factor * cosine).tolist()
111
111
 
@@ -87,12 +87,11 @@ class ReactiveTransformerLayer(nn.Module):
87
87
  else:
88
88
  return None
89
89
 
90
- def forward(self, x: torch.Tensor, stm: torch.Tensor, mask: torch.Tensor = None, logs: bool = False) -> torch.Tensor:
90
+ def forward(self, x: torch.Tensor, stm: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
91
91
  # First step, self-attention
92
92
  residual = x
93
93
  if not self.use_post_norm:
94
94
  x = self.norm1(x)
95
- if logs: print('att', x.size())
96
95
  x = self.attention(x, x, x, mask=mask)
97
96
  x = residual + x
98
97
  if self.use_post_norm:
@@ -101,7 +100,6 @@ class ReactiveTransformerLayer(nn.Module):
101
100
  residual = x
102
101
  if not self.use_post_norm:
103
102
  x = self.norm2(x)
104
- if logs: print('mxatt', x.size())
105
103
  x = self.memory_cross_attention(x, stm, stm)
106
104
  x = residual + x
107
105
  if self.use_post_norm:
@@ -83,8 +83,7 @@ class ReactiveTransformerDecoder(ReactiveTransformerBase):
83
83
  # expand layer STM to batch size, if it's not in batch mode
84
84
  if layer_stm.size(0) == 1:
85
85
  layer_stm = layer_stm.expand(x.size(0), -1, -1)
86
- print(f'Layer {i} before: {x.size()}')
87
- x = self.layers[i](x, layer_stm, mask=mask, logs=True)
86
+ x = self.layers[i](x, layer_stm, mask=mask)
88
87
  return self.head(x)
89
88
 
90
89
 
@@ -293,28 +293,21 @@ class BatchSampler:
293
293
  working_mask = attention_mask.clone()
294
294
 
295
295
  for step in range(max_gen_len):
296
- print('Sampler step', step)
297
296
  active = (~finished) & (current_lens < max_seq_len)
298
297
  if not active.any():
299
298
  break
300
299
 
301
- active_indices = active.nonzero(as_tuple=True)[0]
302
- active_current_lens = current_lens[active]
303
- max_len = active_current_lens.max().item()
300
+ max_len = current_lens.max().item()
304
301
 
305
302
  with torch.set_grad_enabled(not no_grad):
306
- print('Active size', active.size())
307
303
  # Slice input and mask up to the current max length among active sequences
308
- inputs = working_ids[active, :max_len]
309
- masks = working_mask[active, :max_len]
310
- print('Model input sizes')
311
- print(inputs.size())
312
- print(masks.size())
304
+ inputs = working_ids[:, :max_len]
305
+ masks = working_mask[:, :max_len]
313
306
  logits = self.model(inputs, attention_mask=masks)
314
307
 
315
308
  # Get the last valid token index for each active sequence
316
- indices = (active_current_lens - 1).to(self.device)
317
- last_logits = logits[torch.arange(len(active_indices), device=self.device), indices]
309
+ indices = (current_lens - 1).to(self.device)
310
+ last_logits = logits[torch.arange(batch_size, device=self.device), indices]
318
311
 
319
312
  # Sample next tokens and log probs
320
313
  next_tokens, step_log_probs = sample_batch(
@@ -322,15 +315,18 @@ class BatchSampler:
322
315
  )
323
316
 
324
317
  # Update working tensors
325
- for i, idx in enumerate(active_indices):
326
- if current_lens[idx] >= max_seq_len:
318
+ for idx in range(batch_size):
319
+ if finished[idx] or current_lens[idx] >= max_seq_len:
327
320
  continue
321
+
328
322
  pos = current_lens[idx].item()
329
- working_ids[idx, pos] = next_tokens[i]
330
- working_mask[idx, pos] = 1
331
- log_probs[idx, step] = step_log_probs[i]
323
+ token = next_tokens[idx] # Use original batch index
324
+ working_ids[idx, pos] = token
325
+ working_mask[idx, pos] = 1 if token != 0 else 0
326
+ log_probs[idx, step] = step_log_probs[idx] # Use original batch index
332
327
  current_lens[idx] += 1
333
- if next_tokens[i] == self.end_token_id:
328
+
329
+ if token == self.end_token_id:
334
330
  finished[idx] = True
335
331
 
336
332
  # Extract generated tokens
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: rxnn
3
- Version: 0.2.8
3
+ Version: 0.2.10
4
4
  Summary: RxNN: Reactive Neural Networks Platform
5
5
  License: Apache-2.0
6
6
  Keywords: deep-learning,ai,machine-learning
@@ -7,7 +7,7 @@ rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
7
7
  rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
8
8
  rxnn/memory/attention.py,sha256=p-r8DK3iVhNn-JAESVzIXDCG8gk1R_-x5xHclZ5jgb0,1813
9
9
  rxnn/memory/norm.py,sha256=Ofl8Q5NYEF9GQeO0bhM43tkTW91J0y6TSvTAOYMgloM,6278
10
- rxnn/memory/stm.py,sha256=eIxbmOh7SSI3YDim6ki2JgiCel0sDlLnD_qSqcrf7Wc,3881
10
+ rxnn/memory/stm.py,sha256=DPkK1q1SLRw3HWM0dcvkn4XvIrfwUK47h4KmvFVWljc,3847
11
11
  rxnn/rxt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
12
12
  rxnn/rxt/models.py,sha256=qlYgU002VE21ZOlcxEM9iv9tAvsbe4mngcMI2sw3j9k,12078
13
13
  rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -16,8 +16,8 @@ rxnn/training/bml.py,sha256=FJszaQXOLx2ZHBa1CQpyMrG8i4Kj14E-gzDAEK_Ei5k,17272
16
16
  rxnn/training/callbacks.py,sha256=o68IPFJyWM1CGooPRDNU9DfNcy4H_o0PcKDTn_ZLnKA,35053
17
17
  rxnn/training/dataset.py,sha256=XeRzo0KUYyQ43XjZ3o6Jban9ePIRtpHsqUmeKAQPRQk,50305
18
18
  rxnn/training/models.py,sha256=renPa5YH443XNTMFI-YTCwi5vNp3QzwF5UXedNd5hDk,5187
19
- rxnn/training/mrl.py,sha256=qGEz9C8O4Yq5P0_WTdvl0QWzdDtTZZVeoNTON6ICB74,39877
20
- rxnn/training/reward.py,sha256=i0nhrPCDgy1di89HWylRBS6cQ7rSSxJUiS3TX8fiiHE,5614
19
+ rxnn/training/mrl.py,sha256=cftSa6zS3jfremxX0SQrxtbhbEHfZ-nvZT3Hl6GlgpI,39282
20
+ rxnn/training/reward.py,sha256=C0ToTz-u-L-qyBd2yJ1HlvVPS110OChYj9ZhD6iSSMU,5654
21
21
  rxnn/training/rl.py,sha256=FKrBOBAfNub_qzkceFQR-WUtCBffC6oGHE8wlPsz2YA,2682
22
22
  rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
23
23
  rxnn/training/tokenizer.py,sha256=umaLByMBx_NMrQElA45HLm9gkuzyKWDTFaKVd-CjXl0,8344
@@ -25,14 +25,14 @@ rxnn/training/utils.py,sha256=Bw8nZLKIt7NQpUVCYkb_79kWKChVFOYgYXwODo4SvNc,5718
25
25
  rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
26
26
  rxnn/transformers/attention.py,sha256=d0Igo1Nrn76BphbHrzekiKJfT3RCy4iSXSB6FLAOTgc,16247
27
27
  rxnn/transformers/ff.py,sha256=jJnuBDsnnX5uYC_WZH8cXAYrMnz0P-iX7MwcPivjRtI,2533
28
- rxnn/transformers/layers.py,sha256=u_48ocn_JtdJ905hLDIxO12qQgday_6GuAyplBKKXz0,7579
28
+ rxnn/transformers/layers.py,sha256=MbOIX4PurbTbYxcXSavyFsNpTHCm26K_Ssk_VUCzKIE,7469
29
29
  rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
30
- rxnn/transformers/models.py,sha256=kZIZxH4MbgQjQeeKIXGi4tuzXWAfmXjgUCfEzyyeOxA,8338
30
+ rxnn/transformers/models.py,sha256=VvP7r7E6tj7OWsYKlJLCy2vsQ3xSSnlNez6QxR-jBAA,8276
31
31
  rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
32
32
  rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
33
- rxnn/transformers/sampler.py,sha256=GK5aYALWW-DegI7sIjiO531EWAfKl71k_6CbVuYXwfw,17637
33
+ rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
34
34
  rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
35
- rxnn-0.2.8.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
36
- rxnn-0.2.8.dist-info/METADATA,sha256=mOklYF8wm6kn9R70EWls0IYatIxEyKrmpNYspsnMQoM,25959
37
- rxnn-0.2.8.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
38
- rxnn-0.2.8.dist-info/RECORD,,
35
+ rxnn-0.2.10.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
36
+ rxnn-0.2.10.dist-info/METADATA,sha256=eHGpOf-EPGR8Ljs2S7agTAaQB79o6ZYdEtuJRgZLi8E,25960
37
+ rxnn-0.2.10.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
38
+ rxnn-0.2.10.dist-info/RECORD,,
File without changes
File without changes