rxnn 0.2.57__tar.gz → 0.2.58__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. {rxnn-0.2.57 → rxnn-0.2.58}/PKG-INFO +1 -1
  2. {rxnn-0.2.57 → rxnn-0.2.58}/pyproject.toml +1 -1
  3. {rxnn-0.2.57 → rxnn-0.2.58}/src/rxnn/memory/attention.py +3 -2
  4. {rxnn-0.2.57 → rxnn-0.2.58}/src/rxnn/rxt/models.py +2 -2
  5. {rxnn-0.2.57 → rxnn-0.2.58}/src/rxnn/training/models.py +1 -1
  6. {rxnn-0.2.57 → rxnn-0.2.58}/src/rxnn/transformers/layers.py +5 -1
  7. {rxnn-0.2.57 → rxnn-0.2.58}/LICENSE +0 -0
  8. {rxnn-0.2.57 → rxnn-0.2.58}/README.md +0 -0
  9. {rxnn-0.2.57 → rxnn-0.2.58}/src/rxnn/.DS_Store +0 -0
  10. {rxnn-0.2.57 → rxnn-0.2.58}/src/rxnn/__init__.py +0 -0
  11. {rxnn-0.2.57 → rxnn-0.2.58}/src/rxnn/experimental/__init__.py +0 -0
  12. {rxnn-0.2.57 → rxnn-0.2.58}/src/rxnn/experimental/attention.py +0 -0
  13. {rxnn-0.2.57 → rxnn-0.2.58}/src/rxnn/experimental/models.py +0 -0
  14. {rxnn-0.2.57 → rxnn-0.2.58}/src/rxnn/experimental/moe.py +0 -0
  15. {rxnn-0.2.57 → rxnn-0.2.58}/src/rxnn/memory/__init__.py +0 -0
  16. {rxnn-0.2.57 → rxnn-0.2.58}/src/rxnn/memory/norm.py +0 -0
  17. {rxnn-0.2.57 → rxnn-0.2.58}/src/rxnn/memory/stm.py +0 -0
  18. {rxnn-0.2.57 → rxnn-0.2.58}/src/rxnn/rxt/__init__.py +0 -0
  19. {rxnn-0.2.57 → rxnn-0.2.58}/src/rxnn/training/__init__.py +0 -0
  20. {rxnn-0.2.57 → rxnn-0.2.58}/src/rxnn/training/base.py +0 -0
  21. {rxnn-0.2.57 → rxnn-0.2.58}/src/rxnn/training/bml.py +0 -0
  22. {rxnn-0.2.57 → rxnn-0.2.58}/src/rxnn/training/callbacks.py +0 -0
  23. {rxnn-0.2.57 → rxnn-0.2.58}/src/rxnn/training/dataset.py +0 -0
  24. {rxnn-0.2.57 → rxnn-0.2.58}/src/rxnn/training/ddp.py +0 -0
  25. {rxnn-0.2.57 → rxnn-0.2.58}/src/rxnn/training/mrl.py +0 -0
  26. {rxnn-0.2.57 → rxnn-0.2.58}/src/rxnn/training/reward.py +0 -0
  27. {rxnn-0.2.57 → rxnn-0.2.58}/src/rxnn/training/rl.py +0 -0
  28. {rxnn-0.2.57 → rxnn-0.2.58}/src/rxnn/training/scheduler.py +0 -0
  29. {rxnn-0.2.57 → rxnn-0.2.58}/src/rxnn/training/tokenizer.py +0 -0
  30. {rxnn-0.2.57 → rxnn-0.2.58}/src/rxnn/training/utils.py +0 -0
  31. {rxnn-0.2.57 → rxnn-0.2.58}/src/rxnn/transformers/__init__.py +0 -0
  32. {rxnn-0.2.57 → rxnn-0.2.58}/src/rxnn/transformers/attention.py +0 -0
  33. {rxnn-0.2.57 → rxnn-0.2.58}/src/rxnn/transformers/ff.py +0 -0
  34. {rxnn-0.2.57 → rxnn-0.2.58}/src/rxnn/transformers/mask.py +0 -0
  35. {rxnn-0.2.57 → rxnn-0.2.58}/src/rxnn/transformers/models.py +0 -0
  36. {rxnn-0.2.57 → rxnn-0.2.58}/src/rxnn/transformers/moe.py +0 -0
  37. {rxnn-0.2.57 → rxnn-0.2.58}/src/rxnn/transformers/positional.py +0 -0
  38. {rxnn-0.2.57 → rxnn-0.2.58}/src/rxnn/transformers/sampler.py +0 -0
  39. {rxnn-0.2.57 → rxnn-0.2.58}/src/rxnn/utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: rxnn
3
- Version: 0.2.57
3
+ Version: 0.2.58
4
4
  Summary: RxNN: Reactive Neural Networks Platform
5
5
  License: Apache-2.0
6
6
  Keywords: deep-learning,ai,machine-learning
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
4
4
 
5
5
  [tool.poetry]
6
6
  name = "rxnn"
7
- version = "0.2.57"
7
+ version = "0.2.58"
8
8
  description = "RxNN: Reactive Neural Networks Platform"
9
9
 
10
10
  license = "Apache-2.0"
@@ -47,7 +47,8 @@ class StmMemoryAttention(nn.Module):
47
47
  else:
48
48
  return layer_gate * new_layer_stm + (1 - layer_gate) * layer_stm
49
49
 
50
- def forward(self, x: torch.Tensor) -> torch.Tensor:
50
+ def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
51
+ mem_mask = attention_mask.unsqueeze(1).unsqueeze(1).bool() if attention_mask else None
51
52
  new_stm = torch.zeros_like(self.stm.memory)
52
53
  for i in range(self.num_layers):
53
54
  layer_stm = self.stm(i)
@@ -56,7 +57,7 @@ class StmMemoryAttention(nn.Module):
56
57
  layer_stm = layer_stm.expand(x.size(0), -1, -1)
57
58
  encoded_layer_data = x[i]
58
59
  normalized_layer_stm = self.memory_norm_layers[i](layer_stm)
59
- new_layer_stm = self.attention_layers[i](normalized_layer_stm, encoded_layer_data, encoded_layer_data)
60
+ new_layer_stm = self.attention_layers[i](normalized_layer_stm, encoded_layer_data, encoded_layer_data, mask=mem_mask)
60
61
  if self.use_gated_residual:
61
62
  new_stm[i] = self._residual_gate(self.gate[i], layer_stm, new_layer_stm) # gated residual
62
63
  else:
@@ -306,8 +306,8 @@ class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2
306
306
  def clone_reset_memory(self):
307
307
  self.model.stm.clone_detach_reset()
308
308
 
309
- def forward(self, x: torch.Tensor) -> torch.Tensor:
310
- return self.model(x)
309
+ def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
310
+ return self.model(x, attention_mask=attention_mask)
311
311
 
312
312
  class RxTAlphaCriticEncoder(RxTAlphaComponentBase, pipeline_tag="text-classification", license="apache-2.0"):
313
313
  """RxT-Alpha (Reactive Transformer) encoder model"""
@@ -204,7 +204,7 @@ class MrlActorModel(nn.Module):
204
204
  return self.decoder(x, attention_mask=attention_mask)
205
205
  else:
206
206
  _, ed = self.encoder(x, attention_mask=attention_mask)
207
- return self.memory_attention(ed)
207
+ return self.memory_attention(ed, attention_mask=attention_mask)
208
208
 
209
209
 
210
210
  class MrlCriticModel(nn.Module, PyTorchModelHubMixin, license="apache-2.0", pipeline_tag="text-classification"):
@@ -110,7 +110,11 @@ class ReactiveTransformerLayer(nn.Module):
110
110
  residual = x
111
111
  if not self.use_post_norm:
112
112
  x = self.norm2(x)
113
- x = self.memory_cross_attention(x, stm, stm, mask=mask)
113
+
114
+ if mask is not None:
115
+ mem_mask = mask.squeeze(1).unsqueeze(-1).expand(-1, -1, -1, stm.size(1))
116
+
117
+ x = self.memory_cross_attention(x, stm, stm, mask=mem_mask)
114
118
  x = residual + x
115
119
  if self.use_post_norm:
116
120
  x = self.norm2(x)
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes