rxnn 0.2.7__tar.gz → 0.2.8__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. {rxnn-0.2.7 → rxnn-0.2.8}/PKG-INFO +1 -1
  2. {rxnn-0.2.7 → rxnn-0.2.8}/pyproject.toml +1 -1
  3. {rxnn-0.2.7 → rxnn-0.2.8}/src/rxnn/training/utils.py +1 -1
  4. {rxnn-0.2.7 → rxnn-0.2.8}/src/rxnn/transformers/layers.py +3 -1
  5. {rxnn-0.2.7 → rxnn-0.2.8}/src/rxnn/transformers/models.py +2 -1
  6. {rxnn-0.2.7 → rxnn-0.2.8}/src/rxnn/transformers/sampler.py +5 -0
  7. {rxnn-0.2.7 → rxnn-0.2.8}/LICENSE +0 -0
  8. {rxnn-0.2.7 → rxnn-0.2.8}/README.md +0 -0
  9. {rxnn-0.2.7 → rxnn-0.2.8}/src/rxnn/.DS_Store +0 -0
  10. {rxnn-0.2.7 → rxnn-0.2.8}/src/rxnn/__init__.py +0 -0
  11. {rxnn-0.2.7 → rxnn-0.2.8}/src/rxnn/experimental/__init__.py +0 -0
  12. {rxnn-0.2.7 → rxnn-0.2.8}/src/rxnn/experimental/attention.py +0 -0
  13. {rxnn-0.2.7 → rxnn-0.2.8}/src/rxnn/experimental/models.py +0 -0
  14. {rxnn-0.2.7 → rxnn-0.2.8}/src/rxnn/experimental/moe.py +0 -0
  15. {rxnn-0.2.7 → rxnn-0.2.8}/src/rxnn/memory/__init__.py +0 -0
  16. {rxnn-0.2.7 → rxnn-0.2.8}/src/rxnn/memory/attention.py +0 -0
  17. {rxnn-0.2.7 → rxnn-0.2.8}/src/rxnn/memory/norm.py +0 -0
  18. {rxnn-0.2.7 → rxnn-0.2.8}/src/rxnn/memory/stm.py +0 -0
  19. {rxnn-0.2.7 → rxnn-0.2.8}/src/rxnn/rxt/__init__.py +0 -0
  20. {rxnn-0.2.7 → rxnn-0.2.8}/src/rxnn/rxt/models.py +0 -0
  21. {rxnn-0.2.7 → rxnn-0.2.8}/src/rxnn/training/__init__.py +0 -0
  22. {rxnn-0.2.7 → rxnn-0.2.8}/src/rxnn/training/base.py +0 -0
  23. {rxnn-0.2.7 → rxnn-0.2.8}/src/rxnn/training/bml.py +0 -0
  24. {rxnn-0.2.7 → rxnn-0.2.8}/src/rxnn/training/callbacks.py +0 -0
  25. {rxnn-0.2.7 → rxnn-0.2.8}/src/rxnn/training/dataset.py +0 -0
  26. {rxnn-0.2.7 → rxnn-0.2.8}/src/rxnn/training/models.py +0 -0
  27. {rxnn-0.2.7 → rxnn-0.2.8}/src/rxnn/training/mrl.py +0 -0
  28. {rxnn-0.2.7 → rxnn-0.2.8}/src/rxnn/training/reward.py +0 -0
  29. {rxnn-0.2.7 → rxnn-0.2.8}/src/rxnn/training/rl.py +0 -0
  30. {rxnn-0.2.7 → rxnn-0.2.8}/src/rxnn/training/scheduler.py +0 -0
  31. {rxnn-0.2.7 → rxnn-0.2.8}/src/rxnn/training/tokenizer.py +0 -0
  32. {rxnn-0.2.7 → rxnn-0.2.8}/src/rxnn/transformers/__init__.py +0 -0
  33. {rxnn-0.2.7 → rxnn-0.2.8}/src/rxnn/transformers/attention.py +0 -0
  34. {rxnn-0.2.7 → rxnn-0.2.8}/src/rxnn/transformers/ff.py +0 -0
  35. {rxnn-0.2.7 → rxnn-0.2.8}/src/rxnn/transformers/mask.py +0 -0
  36. {rxnn-0.2.7 → rxnn-0.2.8}/src/rxnn/transformers/moe.py +0 -0
  37. {rxnn-0.2.7 → rxnn-0.2.8}/src/rxnn/transformers/positional.py +0 -0
  38. {rxnn-0.2.7 → rxnn-0.2.8}/src/rxnn/utils.py +0 -0
--- rxnn-0.2.7/PKG-INFO
+++ rxnn-0.2.8/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.2.7
+Version: 0.2.8
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning
--- rxnn-0.2.7/pyproject.toml
+++ rxnn-0.2.8/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
 
 [tool.poetry]
 name = "rxnn"
-version = "0.2.7"
+version = "0.2.8"
 description = "RxNN: Reactive Neural Networks Platform"
 
 license = "Apache-2.0"
--- rxnn-0.2.7/src/rxnn/training/utils.py
+++ rxnn-0.2.8/src/rxnn/training/utils.py
@@ -127,7 +127,7 @@ def smart_concat(query: TokenizedDict, answer: TokenizedDict, max_length: int, p
     # Build combined_ids using vectorized where
     combined_ids = torch.where(
         query_mask,
-        query['input_ids'].gather(1, torch.minimum(positions, query_lens.unsqueeze(1) - 1)).to(torch.int64),
+        query['input_ids'].gather(1, torch.minimum(positions, query_lens.unsqueeze(1) - 1).to(torch.int64)),
        torch.where(
            answer_mask,
            answer['input_ids'].gather(1, answer_pos),
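The only functional change here moves the .to(torch.int64) cast from the gathered result onto the index tensor itself. That matters because torch.gather requires an int64 (long) index: if torch.minimum(positions, query_lens.unsqueeze(1) - 1) comes out in a narrower integer dtype, the 0.2.7 code would raise before its trailing cast ever ran. A minimal standalone reproduction with toy tensors (not the rxnn code):

import torch

src = torch.tensor([[10, 20, 30]])
idx = torch.tensor([[2, 0]], dtype=torch.int32)  # narrower than int64

try:
    src.gather(1, idx)  # torch.gather rejects non-int64 index tensors
except RuntimeError as err:
    print(err)  # "gather(): Expected dtype int64 for index"

# Casting the index, as 0.2.8 does, works; casting the result would be too late.
print(src.gather(1, idx.to(torch.int64)))  # tensor([[30, 10]])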
--- rxnn-0.2.7/src/rxnn/transformers/layers.py
+++ rxnn-0.2.8/src/rxnn/transformers/layers.py
@@ -87,11 +87,12 @@ class ReactiveTransformerLayer(nn.Module):
         else:
             return None
 
-    def forward(self, x: torch.Tensor, stm: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
+    def forward(self, x: torch.Tensor, stm: torch.Tensor, mask: torch.Tensor = None, logs: bool = False) -> torch.Tensor:
         # First step, self-attention
         residual = x
         if not self.use_post_norm:
             x = self.norm1(x)
+        if logs: print('att', x.size())
         x = self.attention(x, x, x, mask=mask)
         x = residual + x
         if self.use_post_norm:
@@ -100,6 +101,7 @@ class ReactiveTransformerLayer(nn.Module):
         residual = x
         if not self.use_post_norm:
             x = self.norm2(x)
+        if logs: print('mxatt', x.size())
         x = self.memory_cross_attention(x, stm, stm)
         x = residual + x
         if self.use_post_norm:
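This change threads an opt-in logs flag through forward(), printing the tensor size right before self-attention and right before memory cross-attention. A minimal sketch of the same pattern on a toy module (not the rxnn classes; dimensions are illustrative):

import torch
from torch import nn

class DebugLayer(nn.Module):
    """Toy stand-in showing the opt-in size-logging pattern from the diff."""
    def __init__(self, dim: int):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor, logs: bool = False) -> torch.Tensor:
        x = self.norm(x)
        if logs: print('att', x.size())  # same style as the rxnn prints
        return self.proj(x)

x = torch.randn(2, 16, 64)    # (batch, seq_len, dim)
DebugLayer(64)(x, logs=True)  # prints: att torch.Size([2, 16, 64])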
--- rxnn-0.2.7/src/rxnn/transformers/models.py
+++ rxnn-0.2.8/src/rxnn/transformers/models.py
@@ -83,7 +83,8 @@ class ReactiveTransformerDecoder(ReactiveTransformerBase):
             # expand layer STM to batch size, if it's not in batch mode
             if layer_stm.size(0) == 1:
                 layer_stm = layer_stm.expand(x.size(0), -1, -1)
-            x = self.layers[i](x, layer_stm, mask=mask)
+            print(f'Layer {i} before: {x.size()}')
+            x = self.layers[i](x, layer_stm, mask=mask, logs=True)
         return self.head(x)
 
 
--- rxnn-0.2.7/src/rxnn/transformers/sampler.py
+++ rxnn-0.2.8/src/rxnn/transformers/sampler.py
@@ -293,6 +293,7 @@ class BatchSampler:
         working_mask = attention_mask.clone()
 
         for step in range(max_gen_len):
+            print('Sampler step', step)
             active = (~finished) & (current_lens < max_seq_len)
             if not active.any():
                 break
@@ -302,9 +303,13 @@ class BatchSampler:
             max_len = active_current_lens.max().item()
 
             with torch.set_grad_enabled(not no_grad):
+                print('Active size', active.size())
                 # Slice input and mask up to the current max length among active sequences
                 inputs = working_ids[active, :max_len]
                 masks = working_mask[active, :max_len]
+                print('Model input sizes')
+                print(inputs.size())
+                print(masks.size())
                 logits = self.model(inputs, attention_mask=masks)
 
                 # Get the last valid token index for each active sequence
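The sampler changes are purely diagnostic prints around the existing active-sequence slicing: active is a boolean mask over the batch, and working_ids[active, :max_len] keeps only the unfinished rows, trimmed to the longest active sequence. A toy illustration of that slicing (stand-in tensors, not the real sampler state):

import torch

working_ids = torch.tensor([[1, 2, 3, 0],
                            [4, 5, 0, 0],
                            [6, 0, 0, 0]])
current_lens = torch.tensor([3, 2, 1])
finished = torch.tensor([False, True, False])
max_seq_len = 4

active = (~finished) & (current_lens < max_seq_len)  # tensor([True, False, True])
max_len = current_lens[active].max().item()          # 3

inputs = working_ids[active, :max_len]  # unfinished rows only, trimmed to max_len
print(active.size(), inputs.size())     # torch.Size([3]) torch.Size([2, 3])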