rxnn 0.2.31__tar.gz → 0.2.32__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {rxnn-0.2.31 → rxnn-0.2.32}/PKG-INFO +1 -1
- {rxnn-0.2.31 → rxnn-0.2.32}/pyproject.toml +1 -1
- {rxnn-0.2.31 → rxnn-0.2.32}/src/rxnn/memory/stm.py +0 -1
- {rxnn-0.2.31 → rxnn-0.2.32}/src/rxnn/rxt/models.py +27 -1
- {rxnn-0.2.31 → rxnn-0.2.32}/src/rxnn/transformers/models.py +33 -0
- {rxnn-0.2.31 → rxnn-0.2.32}/LICENSE +0 -0
- {rxnn-0.2.31 → rxnn-0.2.32}/README.md +0 -0
- {rxnn-0.2.31 → rxnn-0.2.32}/src/rxnn/.DS_Store +0 -0
- {rxnn-0.2.31 → rxnn-0.2.32}/src/rxnn/__init__.py +0 -0
- {rxnn-0.2.31 → rxnn-0.2.32}/src/rxnn/experimental/__init__.py +0 -0
- {rxnn-0.2.31 → rxnn-0.2.32}/src/rxnn/experimental/attention.py +0 -0
- {rxnn-0.2.31 → rxnn-0.2.32}/src/rxnn/experimental/models.py +0 -0
- {rxnn-0.2.31 → rxnn-0.2.32}/src/rxnn/experimental/moe.py +0 -0
- {rxnn-0.2.31 → rxnn-0.2.32}/src/rxnn/memory/__init__.py +0 -0
- {rxnn-0.2.31 → rxnn-0.2.32}/src/rxnn/memory/attention.py +0 -0
- {rxnn-0.2.31 → rxnn-0.2.32}/src/rxnn/memory/norm.py +0 -0
- {rxnn-0.2.31 → rxnn-0.2.32}/src/rxnn/rxt/__init__.py +0 -0
- {rxnn-0.2.31 → rxnn-0.2.32}/src/rxnn/training/__init__.py +0 -0
- {rxnn-0.2.31 → rxnn-0.2.32}/src/rxnn/training/base.py +0 -0
- {rxnn-0.2.31 → rxnn-0.2.32}/src/rxnn/training/bml.py +0 -0
- {rxnn-0.2.31 → rxnn-0.2.32}/src/rxnn/training/callbacks.py +0 -0
- {rxnn-0.2.31 → rxnn-0.2.32}/src/rxnn/training/dataset.py +0 -0
- {rxnn-0.2.31 → rxnn-0.2.32}/src/rxnn/training/ddp.py +0 -0
- {rxnn-0.2.31 → rxnn-0.2.32}/src/rxnn/training/models.py +0 -0
- {rxnn-0.2.31 → rxnn-0.2.32}/src/rxnn/training/mrl.py +0 -0
- {rxnn-0.2.31 → rxnn-0.2.32}/src/rxnn/training/reward.py +0 -0
- {rxnn-0.2.31 → rxnn-0.2.32}/src/rxnn/training/rl.py +0 -0
- {rxnn-0.2.31 → rxnn-0.2.32}/src/rxnn/training/scheduler.py +0 -0
- {rxnn-0.2.31 → rxnn-0.2.32}/src/rxnn/training/tokenizer.py +0 -0
- {rxnn-0.2.31 → rxnn-0.2.32}/src/rxnn/training/utils.py +0 -0
- {rxnn-0.2.31 → rxnn-0.2.32}/src/rxnn/transformers/__init__.py +0 -0
- {rxnn-0.2.31 → rxnn-0.2.32}/src/rxnn/transformers/attention.py +0 -0
- {rxnn-0.2.31 → rxnn-0.2.32}/src/rxnn/transformers/ff.py +0 -0
- {rxnn-0.2.31 → rxnn-0.2.32}/src/rxnn/transformers/layers.py +0 -0
- {rxnn-0.2.31 → rxnn-0.2.32}/src/rxnn/transformers/mask.py +0 -0
- {rxnn-0.2.31 → rxnn-0.2.32}/src/rxnn/transformers/moe.py +0 -0
- {rxnn-0.2.31 → rxnn-0.2.32}/src/rxnn/transformers/positional.py +0 -0
- {rxnn-0.2.31 → rxnn-0.2.32}/src/rxnn/transformers/sampler.py +0 -0
- {rxnn-0.2.31 → rxnn-0.2.32}/src/rxnn/utils.py +0 -0
@@ -5,7 +5,7 @@ from huggingface_hub import PyTorchModelHubMixin
|
|
5
5
|
from ..transformers.positional import RotaryPositionalEmbedding
|
6
6
|
from ..transformers.attention import init_attention
|
7
7
|
from ..transformers.layers import ReactiveTransformerLayer
|
8
|
-
from ..transformers.models import ReactiveTransformerBase, ReactiveTransformerEncoder, ReactiveTransformerDecoder
|
8
|
+
from ..transformers.models import ReactiveTransformerBase, ReactiveTransformerEncoder, ReactiveTransformerDecoder, ReactiveTransformerEncoderDetachStm
|
9
9
|
from ..transformers.ff import get_activation_layer
|
10
10
|
from ..memory.stm import ShortTermMemory
|
11
11
|
from ..memory.norm import init_memory_norm
|
@@ -293,3 +293,29 @@ class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2
|
|
293
293
|
|
294
294
|
def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
    """
    Delegate to the wrapped ``self.model``.

    Args:
        x: input forwarded unchanged to ``self.model``.
        attention_mask: optional mask forwarded as the ``attention_mask`` keyword.

    Returns:
        Whatever ``self.model`` returns for these inputs.
    """
    # NOTE(review): ``self.model`` is assigned outside this view (presumably in __init__) — confirm.
    output = self.model(x, attention_mask=attention_mask)
    return output
|
296
|
+
|
297
|
+
class RxTAlphaCriticEncoder(RxTAlphaComponentBase, pipeline_tag="text-classification", license="apache-2.0"):
    """
    RxT-Alpha (Reactive Transformer) critic encoder.

    Builds its inner model as a ``ReactiveTransformerEncoderDetachStm``, which detaches every
    Short-Term Memory layer before memory cross-attention.
    """

    def __init__(self, **kwargs: RxTAlphaComponentConfig):
        # NOTE(review): the meaning of the leading positional ``False`` is defined by
        # RxTAlphaComponentBase, which is not visible here — confirm against that class.
        super().__init__(False, **kwargs)

    def _init_model(
            self,
            stm: ShortTermMemory,
            layers: nn.ModuleList,
            embedding: nn.Embedding,
            use_flash_attention: bool,
            embed_dim: int,
            vocab_size: int
    ) -> ReactiveTransformerEncoderDetachStm:
        """
        Construct the STM-detaching encoder used as this component's model.

        ``layers`` become the encoder's own (non-shared) layers. ``embed_dim`` and
        ``vocab_size`` are accepted to match the hook's signature but are not used here.
        """
        return ReactiveTransformerEncoderDetachStm(
            embedding=embedding,
            stm=stm,
            use_flash_attention=use_flash_attention,
            own_layers=layers,
        )

    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Run the wrapped encoder.

        Returns whatever ``self.model`` returns. For the encoder built by ``_init_model`` that is
        ``(last_layer_output, stacked_layer_outputs)``.
        """
        # NOTE(review): ``self.model`` is presumably assigned by the base class from
        # ``_init_model``'s result — not visible here; confirm.
        encoded = self.model(x, attention_mask=attention_mask)
        return encoded
|
321
|
+
|
@@ -126,6 +126,39 @@ class ReactiveTransformerEncoder(ReactiveTransformerBase):
|
|
126
126
|
return x, torch.stack(hidden_states)
|
127
127
|
|
128
128
|
|
129
|
+
class ReactiveTransformerEncoderDetachStm(ReactiveTransformerBase):
    """
    Reactive Transformer encoder that detaches Short-Term Memory (STM) before memory cross-attention.

    It runs the same layer sequence as the regular reactive encoder. The difference is that each
    STM layer tensor is detached from the autograd graph before it is handed to a transformer
    layer. It is meant for Memory-Aware Critic models, so that critic optimization does not
    propagate gradients into the memory-update path.
    """

    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Encode the input, attending to detached STM in every layer.

        Args:
            x: input passed to the base-class ``forward`` (the embedding step).
            attention_mask: optional mask. When given, singleton dimensions are inserted at
                positions 1 and 2 and it is cast to ``bool`` before being passed to each layer.

        Returns:
            ``(last_output, stacked_outputs)``. ``last_output`` is the final layer's output.
            ``stacked_outputs`` is ``torch.stack`` of every layer's output, shared layers first,
            then own layers.
        """

        def detached_memory(layer_idx: int, batch_size: int) -> torch.Tensor:
            # Detach so that gradients from this encoder never flow back into the STM.
            memory = self.stm(layer_idx).detach()
            # An STM layer with leading size 1 is not in batch mode: broadcast it to the batch.
            return memory.expand(batch_size, -1, -1) if memory.size(0) == 1 else memory

        hidden = super().forward(x)  # apply embeddings
        mask = None if attention_mask is None else attention_mask.unsqueeze(1).unsqueeze(1).bool()

        outputs = []
        # Shared layers run first (only when present).
        if self.shared_layers is not None:
            for idx in range(self.num_shared_layers):
                hidden = self.shared_layers[idx](hidden, detached_memory(idx, hidden.size(0)), mask=mask)
                outputs.append(hidden)
        # Then this encoder's own layers.
        # NOTE(review): own layers also index STM from 0, exactly like the shared-layer loop;
        # confirm this is intended when shared layers are present.
        for idx in range(self.num_own_layers):
            hidden = self.layers[idx](hidden, detached_memory(idx, hidden.size(0)), mask=mask)
            outputs.append(hidden)
        return hidden, torch.stack(outputs)
|
160
|
+
|
161
|
+
|
129
162
|
class ClassicTransformerBase(nn.Module):
|
130
163
|
"""Base class for Classic Transformer models - common logic for both decoders and encoders."""
|
131
164
|
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|