rxnn 0.2.58.tar.gz → 0.2.60.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. {rxnn-0.2.58 → rxnn-0.2.60}/PKG-INFO +1 -1
  2. {rxnn-0.2.58 → rxnn-0.2.60}/pyproject.toml +1 -1
  3. {rxnn-0.2.58 → rxnn-0.2.60}/src/rxnn/memory/attention.py +3 -2
  4. {rxnn-0.2.58 → rxnn-0.2.60}/src/rxnn/rxt/models.py +2 -2
  5. {rxnn-0.2.58 → rxnn-0.2.60}/src/rxnn/transformers/models.py +1 -0
  6. {rxnn-0.2.58 → rxnn-0.2.60}/LICENSE +0 -0
  7. {rxnn-0.2.58 → rxnn-0.2.60}/README.md +0 -0
  8. {rxnn-0.2.58 → rxnn-0.2.60}/src/rxnn/.DS_Store +0 -0
  9. {rxnn-0.2.58 → rxnn-0.2.60}/src/rxnn/__init__.py +0 -0
  10. {rxnn-0.2.58 → rxnn-0.2.60}/src/rxnn/experimental/__init__.py +0 -0
  11. {rxnn-0.2.58 → rxnn-0.2.60}/src/rxnn/experimental/attention.py +0 -0
  12. {rxnn-0.2.58 → rxnn-0.2.60}/src/rxnn/experimental/models.py +0 -0
  13. {rxnn-0.2.58 → rxnn-0.2.60}/src/rxnn/experimental/moe.py +0 -0
  14. {rxnn-0.2.58 → rxnn-0.2.60}/src/rxnn/memory/__init__.py +0 -0
  15. {rxnn-0.2.58 → rxnn-0.2.60}/src/rxnn/memory/norm.py +0 -0
  16. {rxnn-0.2.58 → rxnn-0.2.60}/src/rxnn/memory/stm.py +0 -0
  17. {rxnn-0.2.58 → rxnn-0.2.60}/src/rxnn/rxt/__init__.py +0 -0
  18. {rxnn-0.2.58 → rxnn-0.2.60}/src/rxnn/training/__init__.py +0 -0
  19. {rxnn-0.2.58 → rxnn-0.2.60}/src/rxnn/training/base.py +0 -0
  20. {rxnn-0.2.58 → rxnn-0.2.60}/src/rxnn/training/bml.py +0 -0
  21. {rxnn-0.2.58 → rxnn-0.2.60}/src/rxnn/training/callbacks.py +0 -0
  22. {rxnn-0.2.58 → rxnn-0.2.60}/src/rxnn/training/dataset.py +0 -0
  23. {rxnn-0.2.58 → rxnn-0.2.60}/src/rxnn/training/ddp.py +0 -0
  24. {rxnn-0.2.58 → rxnn-0.2.60}/src/rxnn/training/models.py +0 -0
  25. {rxnn-0.2.58 → rxnn-0.2.60}/src/rxnn/training/mrl.py +0 -0
  26. {rxnn-0.2.58 → rxnn-0.2.60}/src/rxnn/training/reward.py +0 -0
  27. {rxnn-0.2.58 → rxnn-0.2.60}/src/rxnn/training/rl.py +0 -0
  28. {rxnn-0.2.58 → rxnn-0.2.60}/src/rxnn/training/scheduler.py +0 -0
  29. {rxnn-0.2.58 → rxnn-0.2.60}/src/rxnn/training/tokenizer.py +0 -0
  30. {rxnn-0.2.58 → rxnn-0.2.60}/src/rxnn/training/utils.py +0 -0
  31. {rxnn-0.2.58 → rxnn-0.2.60}/src/rxnn/transformers/__init__.py +0 -0
  32. {rxnn-0.2.58 → rxnn-0.2.60}/src/rxnn/transformers/attention.py +0 -0
  33. {rxnn-0.2.58 → rxnn-0.2.60}/src/rxnn/transformers/ff.py +0 -0
  34. {rxnn-0.2.58 → rxnn-0.2.60}/src/rxnn/transformers/layers.py +0 -0
  35. {rxnn-0.2.58 → rxnn-0.2.60}/src/rxnn/transformers/mask.py +0 -0
  36. {rxnn-0.2.58 → rxnn-0.2.60}/src/rxnn/transformers/moe.py +0 -0
  37. {rxnn-0.2.58 → rxnn-0.2.60}/src/rxnn/transformers/positional.py +0 -0
  38. {rxnn-0.2.58 → rxnn-0.2.60}/src/rxnn/transformers/sampler.py +0 -0
  39. {rxnn-0.2.58 → rxnn-0.2.60}/src/rxnn/utils.py +0 -0
{rxnn-0.2.58 → rxnn-0.2.60}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.2.58
+Version: 0.2.60
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning
{rxnn-0.2.58 → rxnn-0.2.60}/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
 
 [tool.poetry]
 name = "rxnn"
-version = "0.2.58"
+version = "0.2.60"
 description = "RxNN: Reactive Neural Networks Platform"
 
 license = "Apache-2.0"
{rxnn-0.2.58 → rxnn-0.2.60}/src/rxnn/memory/attention.py
@@ -48,7 +48,8 @@ class StmMemoryAttention(nn.Module):
         return layer_gate * new_layer_stm + (1 - layer_gate) * layer_stm
 
     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
-        mem_mask = attention_mask.unsqueeze(1).unsqueeze(1).bool() if attention_mask else None
+        if attention_mask is not None:
+            attention_mask = attention_mask.unsqueeze(1).unsqueeze(1).bool()
         new_stm = torch.zeros_like(self.stm.memory)
         for i in range(self.num_layers):
             layer_stm = self.stm(i)
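
Note: the replaced one-liner evaluated a tensor in a boolean context (`if attention_mask`), which raises a RuntimeError in PyTorch for any multi-element tensor, so the masked path could never actually run. A minimal standalone sketch of the failure and of the 0.2.60 pattern (not the package's code):

```python
import torch

mask = torch.ones(2, 5)  # a typical (batch, seq_len) attention mask
try:
    if mask:  # what `... if attention_mask else None` evaluated implicitly
        pass
except RuntimeError as err:
    print(err)  # "Boolean value of Tensor with more than one element is ambiguous"

# The fixed pattern: explicit None check, then reshape for broadcasting.
if mask is not None:
    mask = mask.unsqueeze(1).unsqueeze(1).bool()  # (batch, seq_len) -> (batch, 1, 1, seq_len)
print(mask.shape)  # torch.Size([2, 1, 1, 5])
```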
{rxnn-0.2.58 → rxnn-0.2.60}/src/rxnn/memory/attention.py
@@ -57,7 +58,7 @@ class StmMemoryAttention(nn.Module):
             layer_stm = layer_stm.expand(x.size(0), -1, -1)
             encoded_layer_data = x[i]
             normalized_layer_stm = self.memory_norm_layers[i](layer_stm)
-            new_layer_stm = self.attention_layers[i](normalized_layer_stm, encoded_layer_data, encoded_layer_data, mask=mem_mask)
+            new_layer_stm = self.attention_layers[i](normalized_layer_stm, encoded_layer_data, encoded_layer_data, mask=attention_mask)
             if self.use_gated_residual:
                 new_stm[i] = self._residual_gate(self.gate[i], layer_stm, new_layer_stm)  # gated residual
             else:
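
After the rename, each memory attention layer receives a mask of shape `(batch, 1, 1, seq_len)`, which broadcasts over heads and query positions. A sketch of that broadcast, assuming the attention layers apply the mask with `masked_fill` (the actual application happens inside the layers):

```python
import torch

B, H, Q, S = 2, 4, 8, 5                # batch, heads, STM slots (queries), keys
scores = torch.randn(B, H, Q, S)       # raw attention scores
key_mask = torch.ones(B, S)
key_mask[1, 3:] = 0                    # pad out the tail of the second sequence
mask = key_mask.unsqueeze(1).unsqueeze(1).bool()  # (B, 1, 1, S)

# One key mask broadcasts across every head and every query position.
masked_scores = scores.masked_fill(~mask, float('-inf'))
print(masked_scores.shape)  # torch.Size([2, 4, 8, 5])
```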
{rxnn-0.2.58 → rxnn-0.2.60}/src/rxnn/rxt/models.py
@@ -103,13 +103,13 @@ class RxTAlphaComponentBase(nn.Module, PyTorchModelHubMixin):
         if cross_att_type in ['mha', 'gqa', 'mqa']:
             cross_att_init = lambda: init_attention(embed_dim, att_heads, cross_att_type, att_groups, rope=rope,
                                                     use_flash_attention=use_flash_attention, dropout=att_dropout,
-                                                    max_seq_len=seq_len, is_causal=is_causal, rope_only_for_query=True)
+                                                    max_seq_len=seq_len, is_causal=False, rope_only_for_query=True)
         else:
             cross_att_init = lambda: init_experimental_attention(embed_dim, att_heads, cross_att_type,
                                                                  cross_att_groups or att_groups, rope=rope,
                                                                  use_flash_attention=use_flash_attention,
                                                                  dropout=att_dropout,
-                                                                 max_seq_len=seq_len, is_causal=is_causal,
+                                                                 max_seq_len=seq_len, is_causal=False,
                                                                  num_experts=att_experts,
                                                                  num_query_experts=att_query_experts,
                                                                  num_query_groups=cross_att_query_groups or att_query_groups,
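
Hardcoding `is_causal=False` is consistent with what cross-attention needs: queries (memory slots) and keys (encoded sequence states) index different axes, so a lower-triangular causal mask has no meaningful reading there and would just hide keys arbitrarily. A small illustration with hypothetical sizes:

```python
import torch

stm_len, seq_len = 4, 10  # hypothetical: memory slots attending to encoded states

# A causal mask assumes queries and keys share one timeline; across two
# unrelated axes it merely truncates what each memory slot can see.
causal = torch.tril(torch.ones(stm_len, seq_len, dtype=torch.bool))
print(causal.sum(dim=-1))  # tensor([1, 2, 3, 4]): slot i would see only i+1 of 10 keys
```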
{rxnn-0.2.58 → rxnn-0.2.60}/src/rxnn/transformers/models.py
@@ -213,6 +213,7 @@ class ClassicTransformerDecoder(ClassicTransformerBase):
             if attention_mask is not None:
                 mask &= attention_mask.unsqueeze(1).unsqueeze(1).bool()
         elif attention_mask is not None:
+            print(attention_mask.size())
             mask = attention_mask.unsqueeze(1).unsqueeze(1).bool()
         else:
             mask = None
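
The added `print(attention_mask.size())` reads like a temporary debugging aid in the padding-only branch. For context, a minimal standalone sketch of the two mask shapes this code distinguishes, assuming a boolean lower-triangular causal mask:

```python
import torch

B, S = 2, 6
attention_mask = torch.ones(B, S, dtype=torch.bool)
attention_mask[1, 4:] = False  # pad out the tail of the second sequence

# Causal branch: AND the padding mask into a lower-triangular mask ...
causal = torch.tril(torch.ones(S, S, dtype=torch.bool)).view(1, 1, S, S)
combined = causal & attention_mask.unsqueeze(1).unsqueeze(1)  # (B, 1, S, S)
# ... padding-only branch: broadcastable over all query positions.
padding_only = attention_mask.unsqueeze(1).unsqueeze(1)       # (B, 1, 1, S)
print(combined.shape, padding_only.shape)
```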