rxnn 0.2.31__tar.gz → 0.2.32__tar.gz

This diff shows the changes between publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (39)
  1. {rxnn-0.2.31 → rxnn-0.2.32}/PKG-INFO +1 -1
  2. {rxnn-0.2.31 → rxnn-0.2.32}/pyproject.toml +1 -1
  3. {rxnn-0.2.31 → rxnn-0.2.32}/src/rxnn/memory/stm.py +0 -1
  4. {rxnn-0.2.31 → rxnn-0.2.32}/src/rxnn/rxt/models.py +27 -1
  5. {rxnn-0.2.31 → rxnn-0.2.32}/src/rxnn/transformers/models.py +33 -0
  6. {rxnn-0.2.31 → rxnn-0.2.32}/LICENSE +0 -0
  7. {rxnn-0.2.31 → rxnn-0.2.32}/README.md +0 -0
  8. {rxnn-0.2.31 → rxnn-0.2.32}/src/rxnn/.DS_Store +0 -0
  9. {rxnn-0.2.31 → rxnn-0.2.32}/src/rxnn/__init__.py +0 -0
  10. {rxnn-0.2.31 → rxnn-0.2.32}/src/rxnn/experimental/__init__.py +0 -0
  11. {rxnn-0.2.31 → rxnn-0.2.32}/src/rxnn/experimental/attention.py +0 -0
  12. {rxnn-0.2.31 → rxnn-0.2.32}/src/rxnn/experimental/models.py +0 -0
  13. {rxnn-0.2.31 → rxnn-0.2.32}/src/rxnn/experimental/moe.py +0 -0
  14. {rxnn-0.2.31 → rxnn-0.2.32}/src/rxnn/memory/__init__.py +0 -0
  15. {rxnn-0.2.31 → rxnn-0.2.32}/src/rxnn/memory/attention.py +0 -0
  16. {rxnn-0.2.31 → rxnn-0.2.32}/src/rxnn/memory/norm.py +0 -0
  17. {rxnn-0.2.31 → rxnn-0.2.32}/src/rxnn/rxt/__init__.py +0 -0
  18. {rxnn-0.2.31 → rxnn-0.2.32}/src/rxnn/training/__init__.py +0 -0
  19. {rxnn-0.2.31 → rxnn-0.2.32}/src/rxnn/training/base.py +0 -0
  20. {rxnn-0.2.31 → rxnn-0.2.32}/src/rxnn/training/bml.py +0 -0
  21. {rxnn-0.2.31 → rxnn-0.2.32}/src/rxnn/training/callbacks.py +0 -0
  22. {rxnn-0.2.31 → rxnn-0.2.32}/src/rxnn/training/dataset.py +0 -0
  23. {rxnn-0.2.31 → rxnn-0.2.32}/src/rxnn/training/ddp.py +0 -0
  24. {rxnn-0.2.31 → rxnn-0.2.32}/src/rxnn/training/models.py +0 -0
  25. {rxnn-0.2.31 → rxnn-0.2.32}/src/rxnn/training/mrl.py +0 -0
  26. {rxnn-0.2.31 → rxnn-0.2.32}/src/rxnn/training/reward.py +0 -0
  27. {rxnn-0.2.31 → rxnn-0.2.32}/src/rxnn/training/rl.py +0 -0
  28. {rxnn-0.2.31 → rxnn-0.2.32}/src/rxnn/training/scheduler.py +0 -0
  29. {rxnn-0.2.31 → rxnn-0.2.32}/src/rxnn/training/tokenizer.py +0 -0
  30. {rxnn-0.2.31 → rxnn-0.2.32}/src/rxnn/training/utils.py +0 -0
  31. {rxnn-0.2.31 → rxnn-0.2.32}/src/rxnn/transformers/__init__.py +0 -0
  32. {rxnn-0.2.31 → rxnn-0.2.32}/src/rxnn/transformers/attention.py +0 -0
  33. {rxnn-0.2.31 → rxnn-0.2.32}/src/rxnn/transformers/ff.py +0 -0
  34. {rxnn-0.2.31 → rxnn-0.2.32}/src/rxnn/transformers/layers.py +0 -0
  35. {rxnn-0.2.31 → rxnn-0.2.32}/src/rxnn/transformers/mask.py +0 -0
  36. {rxnn-0.2.31 → rxnn-0.2.32}/src/rxnn/transformers/moe.py +0 -0
  37. {rxnn-0.2.31 → rxnn-0.2.32}/src/rxnn/transformers/positional.py +0 -0
  38. {rxnn-0.2.31 → rxnn-0.2.32}/src/rxnn/transformers/sampler.py +0 -0
  39. {rxnn-0.2.31 → rxnn-0.2.32}/src/rxnn/utils.py +0 -0
{rxnn-0.2.31 → rxnn-0.2.32}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.2.31
+Version: 0.2.32
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning
{rxnn-0.2.31 → rxnn-0.2.32}/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
 
 [tool.poetry]
 name = "rxnn"
-version = "0.2.31"
+version = "0.2.32"
 description = "RxNN: Reactive Neural Networks Platform"
 
 license = "Apache-2.0"
{rxnn-0.2.31 → rxnn-0.2.32}/src/rxnn/memory/stm.py
@@ -44,7 +44,6 @@ class ShortTermMemory(nn.Module):
 
     def update_all(self, new_stm: torch.Tensor):
        self.memory = new_stm
-        # self.memory.copy_(new_stm)
 
     def make_trainable(self):
        if not self.is_trainable:
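The stm.py change above drops a leftover commented-out alternative, settling on rebinding self.memory rather than copying in place. For readers unfamiliar with the distinction, a minimal sketch with plain tensors (not package code): rebinding replaces the tensor object, so other references keep seeing the old data, while copy_ writes through the shared storage.

import torch

memory = torch.zeros(2, 3)
alias = memory              # a second reference to the same storage
memory = torch.ones(2, 3)   # rebind: new tensor object; `alias` still sees zeros
print(alias.sum())          # tensor(0.)

memory = torch.zeros(2, 3)
alias = memory
memory.copy_(torch.ones(2, 3))  # in-place: writes into the shared storage
print(alias.sum())              # tensor(6.)

Removing the commented copy_ variant leaves the rebinding semantics as the single, unambiguous behavior of update_all.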
{rxnn-0.2.31 → rxnn-0.2.32}/src/rxnn/rxt/models.py
@@ -5,7 +5,7 @@ from huggingface_hub import PyTorchModelHubMixin
 from ..transformers.positional import RotaryPositionalEmbedding
 from ..transformers.attention import init_attention
 from ..transformers.layers import ReactiveTransformerLayer
-from ..transformers.models import ReactiveTransformerBase, ReactiveTransformerEncoder, ReactiveTransformerDecoder
+from ..transformers.models import ReactiveTransformerBase, ReactiveTransformerEncoder, ReactiveTransformerDecoder, ReactiveTransformerEncoderDetachStm
 from ..transformers.ff import get_activation_layer
 from ..memory.stm import ShortTermMemory
 from ..memory.norm import init_memory_norm
@@ -293,3 +293,29 @@ class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2
 
     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
         return self.model(x, attention_mask=attention_mask)
+
+class RxTAlphaCriticEncoder(RxTAlphaComponentBase, pipeline_tag="text-classification", license="apache-2.0"):
+    """RxT-Alpha (Reactive Transformer) encoder model"""
+
+    def __init__(self, **kwargs: RxTAlphaComponentConfig):
+        super(RxTAlphaCriticEncoder, self).__init__(False, **kwargs)
+
+    def _init_model(
+            self,
+            stm: ShortTermMemory,
+            layers: nn.ModuleList,
+            embedding: nn.Embedding,
+            use_flash_attention: bool,
+            embed_dim: int,
+            vocab_size: int
+    ) -> ReactiveTransformerEncoderDetachStm:
+        return ReactiveTransformerEncoderDetachStm(
+            stm=stm,
+            embedding=embedding,
+            own_layers=layers,
+            use_flash_attention=use_flash_attention,
+        )
+
+    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> tuple[torch.Tensor, torch.Tensor]:
+        return self.model(x, attention_mask=attention_mask)
+
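For orientation, a hypothetical usage sketch of the new critic encoder follows. The pipeline_tag and license class kwargs suggest RxTAlphaComponentBase mixes in PyTorchModelHubMixin, which would provide from_pretrained; the repo id and tensor shapes below are placeholders, not published checkpoints or confirmed defaults.

import torch
from rxnn.rxt.models import RxTAlphaCriticEncoder

# Placeholder repo id - assumes the PyTorchModelHubMixin loading interface.
critic_encoder = RxTAlphaCriticEncoder.from_pretrained("ReactiveAI/rxt-alpha-critic")

input_ids = torch.randint(0, 1000, (2, 16))  # (batch, seq_len) token ids
attention_mask = torch.ones(2, 16)           # 1 = attend, 0 = padding

# Per the forward signature added above: returns the final hidden state
# plus the stacked per-layer hidden states, computed with STM detached.
last_hidden, layer_hidden = critic_encoder(input_ids, attention_mask=attention_mask)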
{rxnn-0.2.31 → rxnn-0.2.32}/src/rxnn/transformers/models.py
@@ -126,6 +126,39 @@ class ReactiveTransformerEncoder(ReactiveTransformerBase):
         return x, torch.stack(hidden_states)
 
 
+class ReactiveTransformerEncoderDetachStm(ReactiveTransformerBase):
+    """
+    Reactive Transformer encoder, DetachStm version - a reactive transformer encoder that detaches Short-Term Memory
+    tensors before processing them in layers (memory cross-attention). Made for Memory-Aware Critic models, so that
+    memory update gradients are not included in Critic optimization.
+    """
+
+    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> tuple[torch.Tensor, torch.Tensor]:
+        x = super().forward(x)  # apply embeddings
+        if attention_mask is not None:
+            attention_mask = attention_mask.unsqueeze(1).unsqueeze(1).bool()
+
+        hidden_states = []
+        # Process shared layers
+        if self.shared_layers is not None:
+            for i in range(self.num_shared_layers):
+                layer_stm = self.stm(i).detach()  # <- detach STM layer
+                # expand layer STM to batch size, if it's not in batch mode
+                if layer_stm.size(0) == 1:
+                    layer_stm = layer_stm.expand(x.size(0), -1, -1)
+                x = self.shared_layers[i](x, layer_stm, mask=attention_mask)
+                hidden_states.append(x)
+        # Process own layers
+        for i in range(self.num_own_layers):
+            layer_stm = self.stm(i).detach()  # <- detach STM layer
+            # expand layer STM to batch size, if it's not in batch mode
+            if layer_stm.size(0) == 1:
+                layer_stm = layer_stm.expand(x.size(0), -1, -1)
+            x = self.layers[i](x, layer_stm, mask=attention_mask)
+            hidden_states.append(x)
+        return x, torch.stack(hidden_states)
+
+
 class ClassicTransformerBase(nn.Module):
     """Base class for Classic Transformer models - common logic for both decoders and encoders."""
 
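The detach() calls are the entire point of the new encoder subclass: the critic's loss should be able to read the short-term memory without backpropagating into it. A standalone illustration of that gradient boundary (assumed shapes and names, not package code):

import torch

stm_layer = torch.randn(1, 4, 8, requires_grad=True)   # stand-in for one STM layer
critic_weight = torch.randn(8, 1, requires_grad=True)  # stand-in for critic parameters

# Critic computation on the detached memory: the critic's parameters get
# gradients, the STM does not.
loss = (stm_layer.detach() @ critic_weight).sum()
loss.backward()
print(critic_weight.grad is not None)  # True  - the critic still trains
print(stm_layer.grad)                  # None  - nothing flowed into the STM

# The same computation without detach() would also populate stm_layer.grad,
# dragging memory-update gradients into critic optimization.
(stm_layer @ critic_weight).sum().backward()
print(stm_layer.grad is not None)      # True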