rxnn 0.2.7__py3-none-any.whl → 0.2.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rxnn/memory/stm.py CHANGED
@@ -43,7 +43,6 @@ class ShortTermMemory(nn.Module):
43
43
 
44
44
  def update_all(self, new_stm: torch.Tensor):
45
45
  self.memory.copy_(new_stm)
46
- print(self.memory.size())
47
46
 
48
47
  def make_trainable(self):
49
48
  if not self.is_trainable:
rxnn/training/mrl.py CHANGED
@@ -283,14 +283,7 @@ class MRLTrainer:
283
283
  with torch.no_grad():
284
284
  # 2. Collect episode trajectories for all batches in dataset
285
285
  for batch_idx, batch in enumerate(dataloader):
286
- print('first query size: ', batch['query']['input_ids'].size())
287
- print('first answer size: ', batch['answer']['input_ids'].size())
288
- print('next query size: ', batch['interactions'][0]['query']['input_ids'].size())
289
- print('Interactions len: ', len(batch['interactions']))
290
-
291
- if batch['query']['input_ids'].size(0) != batch_size or batch['interactions'][0]['query']['input_ids'].size(0) != batch_size:
292
- print('Incorrect batch size: ', batch['query']['input_ids'].size(0))
293
- else:
286
+ if batch['query']['input_ids'].size(0) == batch_size:
294
287
  self._increment_steps('collect')
295
288
  # 3. Reset Short-Term Memory state (with random reset ratio - sometimes it will be good to build memory
296
289
  # state from existing one, instead of new random one)
@@ -313,7 +306,6 @@ class MRLTrainer:
313
306
  for i, interaction in enumerate(interactions):
314
307
  # 8. Generate batch of answers based on batch of follow-up queries
315
308
  next_query = self._move_batch(interaction['query'])
316
- print('first query: ', next_query['input_ids'].size())
317
309
  generated_answer, log_probs = self.generate_answer(next_query)
318
310
 
319
311
  is_last_interaction = (i + 1) == interactions_len
rxnn/training/utils.py CHANGED
@@ -127,7 +127,7 @@ def smart_concat(query: TokenizedDict, answer: TokenizedDict, max_length: int, p
127
127
  # Build combined_ids using vectorized where
128
128
  combined_ids = torch.where(
129
129
  query_mask,
130
- query['input_ids'].gather(1, torch.minimum(positions, query_lens.unsqueeze(1) - 1)).to(torch.int64),
130
+ query['input_ids'].gather(1, torch.minimum(positions, query_lens.unsqueeze(1) - 1).to(torch.int64)),
131
131
  torch.where(
132
132
  answer_mask,
133
133
  answer['input_ids'].gather(1, answer_pos),
@@ -297,19 +297,17 @@ class BatchSampler:
297
297
  if not active.any():
298
298
  break
299
299
 
300
- active_indices = active.nonzero(as_tuple=True)[0]
301
- active_current_lens = current_lens[active]
302
- max_len = active_current_lens.max().item()
300
+ max_len = current_lens.max().item()
303
301
 
304
302
  with torch.set_grad_enabled(not no_grad):
305
303
  # Slice input and mask up to the current max length among active sequences
306
- inputs = working_ids[active, :max_len]
307
- masks = working_mask[active, :max_len]
304
+ inputs = working_ids[:, :max_len]
305
+ masks = working_mask[:, :max_len]
308
306
  logits = self.model(inputs, attention_mask=masks)
309
307
 
310
308
  # Get the last valid token index for each active sequence
311
- indices = (active_current_lens - 1).to(self.device)
312
- last_logits = logits[torch.arange(len(active_indices), device=self.device), indices]
309
+ indices = (current_lens - 1).to(self.device)
310
+ last_logits = logits[torch.arange(batch_size, device=self.device), indices]
313
311
 
314
312
  # Sample next tokens and log probs
315
313
  next_tokens, step_log_probs = sample_batch(
@@ -317,15 +315,18 @@ class BatchSampler:
317
315
  )
318
316
 
319
317
  # Update working tensors
320
- for i, idx in enumerate(active_indices):
321
- if current_lens[idx] >= max_seq_len:
318
+ for idx in range(batch_size):
319
+ if finished[idx] or current_lens[idx] >= max_seq_len:
322
320
  continue
321
+
323
322
  pos = current_lens[idx].item()
324
- working_ids[idx, pos] = next_tokens[i]
325
- working_mask[idx, pos] = 1
326
- log_probs[idx, step] = step_log_probs[i]
323
+ token = next_tokens[idx] # Use original batch index
324
+ working_ids[idx, pos] = token
325
+ working_mask[idx, pos] = 1 if token != 0 else 0
326
+ log_probs[idx, step] = step_log_probs[idx] # Use original batch index
327
327
  current_lens[idx] += 1
328
- if next_tokens[i] == self.end_token_id:
328
+
329
+ if token == self.end_token_id:
329
330
  finished[idx] = True
330
331
 
331
332
  # Extract generated tokens
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: rxnn
3
- Version: 0.2.7
3
+ Version: 0.2.9
4
4
  Summary: RxNN: Reactive Neural Networks Platform
5
5
  License: Apache-2.0
6
6
  Keywords: deep-learning,ai,machine-learning
@@ -7,7 +7,7 @@ rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
7
7
  rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
8
8
  rxnn/memory/attention.py,sha256=p-r8DK3iVhNn-JAESVzIXDCG8gk1R_-x5xHclZ5jgb0,1813
9
9
  rxnn/memory/norm.py,sha256=Ofl8Q5NYEF9GQeO0bhM43tkTW91J0y6TSvTAOYMgloM,6278
10
- rxnn/memory/stm.py,sha256=eIxbmOh7SSI3YDim6ki2JgiCel0sDlLnD_qSqcrf7Wc,3881
10
+ rxnn/memory/stm.py,sha256=DPkK1q1SLRw3HWM0dcvkn4XvIrfwUK47h4KmvFVWljc,3847
11
11
  rxnn/rxt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
12
12
  rxnn/rxt/models.py,sha256=qlYgU002VE21ZOlcxEM9iv9tAvsbe4mngcMI2sw3j9k,12078
13
13
  rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -16,12 +16,12 @@ rxnn/training/bml.py,sha256=FJszaQXOLx2ZHBa1CQpyMrG8i4Kj14E-gzDAEK_Ei5k,17272
16
16
  rxnn/training/callbacks.py,sha256=o68IPFJyWM1CGooPRDNU9DfNcy4H_o0PcKDTn_ZLnKA,35053
17
17
  rxnn/training/dataset.py,sha256=XeRzo0KUYyQ43XjZ3o6Jban9ePIRtpHsqUmeKAQPRQk,50305
18
18
  rxnn/training/models.py,sha256=renPa5YH443XNTMFI-YTCwi5vNp3QzwF5UXedNd5hDk,5187
19
- rxnn/training/mrl.py,sha256=qGEz9C8O4Yq5P0_WTdvl0QWzdDtTZZVeoNTON6ICB74,39877
19
+ rxnn/training/mrl.py,sha256=cftSa6zS3jfremxX0SQrxtbhbEHfZ-nvZT3Hl6GlgpI,39282
20
20
  rxnn/training/reward.py,sha256=i0nhrPCDgy1di89HWylRBS6cQ7rSSxJUiS3TX8fiiHE,5614
21
21
  rxnn/training/rl.py,sha256=FKrBOBAfNub_qzkceFQR-WUtCBffC6oGHE8wlPsz2YA,2682
22
22
  rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
23
23
  rxnn/training/tokenizer.py,sha256=umaLByMBx_NMrQElA45HLm9gkuzyKWDTFaKVd-CjXl0,8344
24
- rxnn/training/utils.py,sha256=U4fVNNomwAs_XqzqAdqtlPY1UeAl6E6oKOOSB0-Ckx0,5718
24
+ rxnn/training/utils.py,sha256=Bw8nZLKIt7NQpUVCYkb_79kWKChVFOYgYXwODo4SvNc,5718
25
25
  rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
26
26
  rxnn/transformers/attention.py,sha256=d0Igo1Nrn76BphbHrzekiKJfT3RCy4iSXSB6FLAOTgc,16247
27
27
  rxnn/transformers/ff.py,sha256=jJnuBDsnnX5uYC_WZH8cXAYrMnz0P-iX7MwcPivjRtI,2533
@@ -30,9 +30,9 @@ rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
30
30
  rxnn/transformers/models.py,sha256=VvP7r7E6tj7OWsYKlJLCy2vsQ3xSSnlNez6QxR-jBAA,8276
31
31
  rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
32
32
  rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
33
- rxnn/transformers/sampler.py,sha256=2dpUQv88ekZa_CMSPLrXvB6X684wxUE2bDVznsi5ACs,17429
33
+ rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
34
34
  rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
35
- rxnn-0.2.7.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
36
- rxnn-0.2.7.dist-info/METADATA,sha256=xyc69x0B-y2j_ww61VhAeKoQL3pvMMXe2hUV2QNWKZI,25959
37
- rxnn-0.2.7.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
38
- rxnn-0.2.7.dist-info/RECORD,,
35
+ rxnn-0.2.9.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
36
+ rxnn-0.2.9.dist-info/METADATA,sha256=ERPLS1G1D0zDUR3OUpTAPUHQ47Xjd3U4b4bSMnqA4p4,25959
37
+ rxnn-0.2.9.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
38
+ rxnn-0.2.9.dist-info/RECORD,,
File without changes
File without changes