rxnn 0.2.31__py3-none-any.whl → 0.2.33__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rxnn/memory/stm.py +0 -1
- rxnn/rxt/models.py +27 -1
- rxnn/training/models.py +11 -3
- rxnn/training/mrl.py +2 -2
- rxnn/training/reward.py +9 -1
- rxnn/training/rl.py +28 -17
- rxnn/transformers/ff.py +2 -0
- rxnn/transformers/models.py +33 -0
- {rxnn-0.2.31.dist-info → rxnn-0.2.33.dist-info}/METADATA +1 -1
- {rxnn-0.2.31.dist-info → rxnn-0.2.33.dist-info}/RECORD +12 -12
- {rxnn-0.2.31.dist-info → rxnn-0.2.33.dist-info}/LICENSE +0 -0
- {rxnn-0.2.31.dist-info → rxnn-0.2.33.dist-info}/WHEEL +0 -0
rxnn/memory/stm.py
CHANGED
rxnn/rxt/models.py
CHANGED
@@ -5,7 +5,7 @@ from huggingface_hub import PyTorchModelHubMixin
 from ..transformers.positional import RotaryPositionalEmbedding
 from ..transformers.attention import init_attention
 from ..transformers.layers import ReactiveTransformerLayer
-from ..transformers.models import ReactiveTransformerBase, ReactiveTransformerEncoder, ReactiveTransformerDecoder
+from ..transformers.models import ReactiveTransformerBase, ReactiveTransformerEncoder, ReactiveTransformerDecoder, ReactiveTransformerEncoderDetachStm
 from ..transformers.ff import get_activation_layer
 from ..memory.stm import ShortTermMemory
 from ..memory.norm import init_memory_norm
@@ -293,3 +293,29 @@ class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2
 
     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
         return self.model(x, attention_mask=attention_mask)
+
+class RxTAlphaCriticEncoder(RxTAlphaComponentBase, pipeline_tag="text-classification", license="apache-2.0"):
+    """RxT-Alpha (Reactive Transformer) encoder model"""
+
+    def __init__(self, **kwargs: RxTAlphaComponentConfig):
+        super(RxTAlphaCriticEncoder, self).__init__(False, **kwargs)
+
+    def _init_model(
+            self,
+            stm: ShortTermMemory,
+            layers: nn.ModuleList,
+            embedding: nn.Embedding,
+            use_flash_attention: bool,
+            embed_dim: int,
+            vocab_size: int
+    ) -> ReactiveTransformerEncoderDetachStm:
+        return ReactiveTransformerEncoderDetachStm(
+            stm=stm,
+            embedding=embedding,
+            own_layers=layers,
+            use_flash_attention=use_flash_attention,
+        )
+
+    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> tuple[torch.Tensor, torch.Tensor]:
+        return self.model(x, attention_mask=attention_mask)
+
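For orientation: the new RxTAlphaCriticEncoder differs from the regular encoder component only in that _init_model builds a ReactiveTransformerEncoderDetachStm (added in rxnn/transformers/models.py further down), so critic training cannot backpropagate into the short-term memory. A minimal usage sketch follows, with the caveat that RxTAlphaComponentConfig fields are not shown in this diff, so every keyword argument below is a hypothetical placeholder:

```python
import torch
from rxnn.rxt.models import RxTAlphaCriticEncoder

# All config fields here are hypothetical placeholders - RxTAlphaComponentConfig is not part of this diff.
critic_encoder = RxTAlphaCriticEncoder(
    num_layers=6,       # assumed field name
    embed_dim=256,      # assumed field name
    vocab_size=32000,   # assumed field name
)

input_ids = torch.randint(0, 32000, (2, 32))
mask = torch.ones_like(input_ids)
# forward mirrors ReactiveTransformerEncoder: returns (final hidden state, stacked per-layer states)
hidden, layer_states = critic_encoder(input_ids, attention_mask=mask)
```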
rxnn/training/models.py
CHANGED
@@ -4,6 +4,7 @@ from enum import Enum
 from typing import Literal, Iterator
 from huggingface_hub import PyTorchModelHubMixin
 from ..transformers.models import ReactiveTransformerEncoder, ReactiveTransformerDecoder
+from ..transformers.ff import GatedLinearUnit, get_activation_layer
 
 class MLMHead(nn.Module, PyTorchModelHubMixin, license="apache-2.0"):
     def __init__(self, embed_dim: int, vocab_size: int, *args, **kwargs):
@@ -161,10 +162,16 @@ class MrlActorModel(nn.Module):
         return self.memory_attention(ed, attention_mask=attention_mask)
 
 class MrlCriticModel(nn.Module, PyTorchModelHubMixin, license="apache-2.0", pipeline_tag="text-classification"):
-    def __init__(self, encoder: nn.Module, embed_dim: int, **kwargs):
+    def __init__(self, encoder: nn.Module, embed_dim: int, out_activation: Literal['sigmoid', 'tanh', 'linear'] = 'sigmoid', output_scale: float = 1.0, **kwargs):
         super(MrlCriticModel, self).__init__(**kwargs)
         self.encoder = encoder
-        self.value_head = nn.
+        self.value_head = nn.Sequential(
+            GatedLinearUnit(embed_dim, embed_dim, nn.SiLU()),
+            nn.LayerNorm(embed_dim),
+            nn.Linear(embed_dim, 1),
+            get_activation_layer(out_activation)
+        )
+        self.output_scale = output_scale
 
     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
         x, _ = self.encoder(x, attention_mask=attention_mask)
@@ -175,4 +182,5 @@ class MrlCriticModel(nn.Module, PyTorchModelHubMixin, license="apache-2.0", pipe
         else:
             x = x.mean(dim=1)
 
-        return self.value_head(x)
+        return self.value_head(x) * self.output_scale
+
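To make the new value head concrete, here is a self-contained sketch of the same pipeline in plain PyTorch. GatedLinearUnitSketch is an assumed SwiGLU-style stand-in for rxnn's GatedLinearUnit (which lives in rxnn/transformers/ff.py and is not shown in this diff), and the dimensions are arbitrary:

```python
import torch
import torch.nn as nn

class GatedLinearUnitSketch(nn.Module):
    """Assumed stand-in for rxnn's GatedLinearUnit: activation(gate(x)) * value(x)."""
    def __init__(self, in_dim: int, out_dim: int, activation: nn.Module):
        super().__init__()
        self.gate = nn.Linear(in_dim, out_dim)
        self.value = nn.Linear(in_dim, out_dim)
        self.activation = activation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.activation(self.gate(x)) * self.value(x)

embed_dim, output_scale = 64, 10.0
value_head = nn.Sequential(
    GatedLinearUnitSketch(embed_dim, embed_dim, nn.SiLU()),
    nn.LayerNorm(embed_dim),
    nn.Linear(embed_dim, 1),
    nn.Tanh(),  # out_activation='tanh'
)

pooled = torch.randn(4, embed_dim)          # pooled encoder output, one vector per sequence
values = value_head(pooled) * output_scale  # bounded to (-10, 10) by tanh * output_scale
print(values.shape)                         # torch.Size([4, 1])
```

With out_activation='tanh' and output_scale=10.0, critic values stay inside (-10, 10), which lines up with the critic_value_clip=10.0 default added to PPOAlgorithm further down.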
rxnn/training/mrl.py
CHANGED
@@ -481,7 +481,7 @@ class MRLTrainer:
             critic_state = smart_concat_critic_states(*state, max_length=self.critic_max_len,
                                                        pad_token_id=self.pad_token_id)
             values = self.critic(critic_state['input_ids'], attention_mask=critic_state['attention_mask']).squeeze()
-            critic_loss = self.rl_algorithm.critic_loss(values, ref_values)
+            critic_loss = self.rl_algorithm.critic_loss(values, ref_values.detach())
         # 2.2 Run backpropagation with scaler
         self.critic_scaler.scale(critic_loss).backward()
         # 2.3 Unscale and clip gradients
@@ -495,7 +495,7 @@ class MRLTrainer:
             critic_state = smart_concat_critic_states(*state, max_length=self.critic_max_len,
                                                        pad_token_id=self.pad_token_id)
             values = self.critic(critic_state['input_ids'], attention_mask=critic_state['attention_mask']).squeeze()
-            critic_loss = self.rl_algorithm.critic_loss(values, ref_values)
+            critic_loss = self.rl_algorithm.critic_loss(values, ref_values.detach())
         # 2.2 Run backpropagation
         critic_loss.backward()
         # 2.3 Clip gradients
rxnn/training/reward.py
CHANGED
@@ -42,6 +42,7 @@ class MrlRewardModel:
             running_mean_decay: float = 0.2,
             bleu_saved_weights: tuple = (0.5, 0.5),
             bleu_ref_weights: tuple = (0.5, 0.5),
+            tanh_reward_scale: bool = False,
             rewards_scale: float = 1.0,
     ):
         self.shared_embedding = shared_embedding.to(device)
@@ -71,6 +72,7 @@ class MrlRewardModel:
         self.running_mean_decay = running_mean_decay
         self.bleu_ref_weights = bleu_ref_weights
         self.bleu_saved_weights = bleu_saved_weights
+        self.tanh_reward_scale = tanh_reward_scale
         self.rewards_scale = rewards_scale
 
         self.prev_data_running_mean = None
@@ -175,6 +177,12 @@ class MrlRewardModel:
             self.prev_data_running_mean = (1 - self.running_mean_decay) * self._sequence_embedding(
                 prev_data) + self.running_mean_decay * self.prev_data_running_mean
 
+    def _pre_scale_rewards(self, rewards: torch.Tensor) -> torch.Tensor:
+        if self.tanh_reward_scale:
+            return (rewards * 2) - 1  # Convert [0,1] to [-1,1]
+        else:
+            return rewards
+
     def __call__(
             self,
             generated: TokenizedDict,
@@ -204,5 +212,5 @@ class MrlRewardModel:
             cosine = self.negative_cosine(generated['input_ids'], reference['input_ids'], saved_data['input_ids'])
             sim_rewards = self.neg_bleu_factor * torch.tensor(bleu, device=self.device) + self.neg_cos_factor * cosine
 
-        rewards = (sim_rewards + self.len_factor * self.len_reward(generated) if self.reward_len else sim_rewards) * self.rewards_scale
+        rewards = self._pre_scale_rewards(sim_rewards + self.len_factor * self.len_reward(generated) if self.reward_len else sim_rewards) * self.rewards_scale
         return rewards.tolist()
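The new tanh_reward_scale flag is a simple affine remap from [0, 1] to [-1, 1], applied before the global rewards_scale; it pairs naturally with a tanh-activated critic head. A quick sketch of the mapping:

```python
import torch

rewards = torch.tensor([0.0, 0.25, 0.5, 1.0])
scaled = rewards * 2 - 1   # tanh_reward_scale=True branch of _pre_scale_rewards
print(scaled)              # tensor([-1.0000, -0.5000,  0.0000,  1.0000])
```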
rxnn/training/rl.py
CHANGED
@@ -21,8 +21,8 @@ class RlAlgorithm(ABC):
     def calculate_advantages(self, rewards: torch.Tensor, values: torch.Tensor, dones: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
         pass
 
-    def critic_loss(self,
-        return self.critic_loss(
+    def critic_loss(self, values: torch.Tensor, ref_values: torch.Tensor) -> torch.Tensor:
+        return self.critic_loss(values, ref_values)
 
 
 class PPOConfig(TypedDict):
@@ -31,6 +31,8 @@ class PPOConfig(TypedDict):
     gae_gamma: Optional[float]
     entropy_coef: Optional[float]
     use_distributed_advantage_norm: Optional[bool]
+    clip_critic_values: Optional[bool]
+    critic_value_clip: Optional[float]
 
 
 class PPOAlgorithm(RlAlgorithm):
@@ -46,6 +48,14 @@ class PPOAlgorithm(RlAlgorithm):
         self.gae_gamma = config.get('gae_gamma', 0.99)
         self.entropy_coef = config.get('entropy_coef', 0.01)
         self.use_distributed_advantage_norm = config.get('use_distributed_advantage_norm', False)
+        self.clip_critic_values = config.get('clip_critic_values', True)
+        self.critic_value_clip = config.get('critic_value_clip', 10.0)
+
+    def critic_loss(self, values: torch.Tensor, ref_values: torch.Tensor) -> torch.Tensor:
+        # Critic loss with clipped values
+        if self.clip_critic_values:
+            values = torch.clamp(values, -self.critic_value_clip, self.critic_value_clip)
+        return self.critic_loss(values, ref_values)
 
     def policy_loss(self, query: TokenizedDict, answer: TokenizedDict, logits: torch.Tensor,
                     old_log_probs: torch.Tensor, advantages: torch.Tensor) -> torch.Tensor:
@@ -96,23 +106,24 @@ class PPOAlgorithm(RlAlgorithm):
 
         return policy_loss
 
-    def _compute_gae(self, rewards: torch.Tensor, values: torch.Tensor,
-
-
+    def _compute_gae(self, rewards: torch.Tensor, values: torch.Tensor,
+                     last_value: torch.Tensor, dones: torch.Tensor):
+        trajectory_len, batch_size = rewards.shape
+        advantages = torch.zeros_like(rewards, device=rewards.device)
         last_advantage = 0
-
-
-
-
-
-
-
-
-
-
-            delta = rewards[t] + self.gae_gamma * next_values - values[t]
-            advantages[t] = delta + self.gae_gamma * self.gae_lambda * last_advantage
+        next_value = last_value
+        next_done = torch.zeros(batch_size, device=dones.device)  # Last state is terminal
+        dones = dones.float()
+        for t in reversed(range(trajectory_len)):
+            # Check if next state is terminal
+            non_terminal = 1.0 - next_done
+
+            # Delta should not include next_value if next is terminal
+            delta = rewards[t] + self.gae_gamma * next_value * non_terminal - values[t]
+            advantages[t] = delta + self.gae_gamma * self.gae_lambda * non_terminal * last_advantage
             last_advantage = advantages[t]
+            next_value = values[t]
+            next_done = dones[t]
 
         returns = advantages + values
         return advantages, returns
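The rewritten _compute_gae masks both the bootstrap value and the carried advantage at episode boundaries. Below is a standalone re-expression of the same backward recursion that can be run in isolation; gae_lambda=0.95 is an assumed value (only the gae_gamma=0.99 default appears in this diff), and the toy rollout is made up:

```python
import torch

def compute_gae(rewards, values, last_value, dones, gamma=0.99, lam=0.95):
    """GAE over a (trajectory_len, batch_size) rollout, zeroing bootstrap terms at dones."""
    trajectory_len, batch_size = rewards.shape
    advantages = torch.zeros_like(rewards)
    last_advantage = torch.zeros(batch_size)
    next_value = last_value
    next_done = torch.zeros(batch_size)   # bootstrap from last_value at the final step
    dones = dones.float()
    for t in reversed(range(trajectory_len)):
        non_terminal = 1.0 - next_done    # 0 where the following step is an episode boundary
        delta = rewards[t] + gamma * next_value * non_terminal - values[t]
        advantages[t] = delta + gamma * lam * non_terminal * last_advantage
        last_advantage = advantages[t]
        next_value = values[t]
        next_done = dones[t]
    return advantages, advantages + values

# Toy rollout: 4 steps, 2 environments; step 1 of env 0 is marked as a boundary,
# so step 0 of env 0 does not bootstrap from it.
rewards = torch.ones(4, 2)
values = torch.zeros(4, 2)
dones = torch.tensor([[0., 0.], [1., 0.], [0., 0.], [0., 0.]])
adv, ret = compute_gae(rewards, values, last_value=torch.zeros(2), dones=dones)
print(adv)
```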
rxnn/transformers/ff.py
CHANGED
rxnn/transformers/models.py
CHANGED
@@ -126,6 +126,39 @@ class ReactiveTransformerEncoder(ReactiveTransformerBase):
         return x, torch.stack(hidden_states)
 
 
+class ReactiveTransformerEncoderDetachStm(ReactiveTransformerBase):
+    """
+    Reactive Transformer encoder DetachStm version - reactive transformer encoder that's detaching Short-Term Memory tensors,
+    before processing them in layers (memory cross-attention). Made for Memory-Aware Critic models, to not include memory
+    update gradients in Critic optimization.
+    """
+
+    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> tuple[torch.Tensor, torch.Tensor]:
+        x = super().forward(x)  # apply embeddings
+        if attention_mask is not None:
+            attention_mask = attention_mask.unsqueeze(1).unsqueeze(1).bool()
+
+        hidden_states = []
+        # Process shared layers
+        if self.shared_layers is not None:
+            for i in range(self.num_shared_layers):
+                layer_stm = self.stm(i).detach()  # <- Detach STM layer
+                # expand layer STM to batch size, if it's not in batch mode
+                if layer_stm.size(0) == 1:
+                    layer_stm = layer_stm.expand(x.size(0), -1, -1)
+                x = self.shared_layers[i](x, layer_stm, mask=attention_mask)
+                hidden_states.append(x)
+        # Process own layers
+        for i in range(self.num_own_layers):
+            layer_stm = self.stm(i).detach()  # <- Detach STM layer
+            # expand layer STM to batch size, if it's not in batch mode
+            if layer_stm.size(0) == 1:
+                layer_stm = layer_stm.expand(x.size(0), -1, -1)
+            x = self.layers[i](x, layer_stm, mask=attention_mask)
+            hidden_states.append(x)
+        return x, torch.stack(hidden_states)
+
+
 class ClassicTransformerBase(nn.Module):
     """Base class for Classic Transformer models - common logic for both decoders and encoders."""
 
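The only difference from ReactiveTransformerEncoder is the .detach() applied to each STM layer before memory cross-attention. A tiny sketch, independent of the rxnn classes, of what detaching guarantees:

```python
import torch

stm = torch.randn(1, 8, 16, requires_grad=True)    # stand-in for one short-term memory layer
x = torch.randn(2, 4, 16, requires_grad=True)       # stand-in for token embeddings

# Cross-attention stand-in: bilinear scores between tokens and the detached, batch-expanded memory.
scores = torch.einsum('btd,bsd->bts', x, stm.detach().expand(2, -1, -1))
scores.sum().backward()

print(x.grad is not None)   # True  - the critic path still receives gradients
print(stm.grad is None)     # True  - nothing flows back into the memory state
```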
{rxnn-0.2.31.dist-info → rxnn-0.2.33.dist-info}/RECORD
CHANGED
@@ -7,33 +7,33 @@ rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/memory/attention.py,sha256=PXVBZQYNsRraZh7QDBgUOdPy3lTI8B0d8CzduojBjG0,1747
 rxnn/memory/norm.py,sha256=E98jOQEuIOFFhlkvS8s4fFN-D4tLO6vaOqnObv1oVmA,6592
-rxnn/memory/stm.py,sha256=
+rxnn/memory/stm.py,sha256=SSfc-RL9FE-RLkmOEkLB-9Rb00ZXbMLbsAEPdpIW89o,3851
 rxnn/rxt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/rxt/models.py,sha256=
+rxnn/rxt/models.py,sha256=CzFELVv5-ybAwl1s1ptpmwM7wdJ07M4jaT1-I8PYrR0,13999
 rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/training/base.py,sha256=TGz_37RfI1qLI31GNRV5rLowW1kAHnJwqPm7DNfLfe4,11730
 rxnn/training/bml.py,sha256=hw6gLpLkGvqLzxIvBg4MvCc5r8cHpEm2RDyh7nH6CtE,16914
 rxnn/training/callbacks.py,sha256=p72lbzFAmFjpcUvyy4aUB3qd53I8C6Sk5w9nQvsKgTk,35852
 rxnn/training/dataset.py,sha256=7hTilFWPpqUEc6zNcMqBPjxFKxCfvTKKF3E8tVlwccQ,51250
 rxnn/training/ddp.py,sha256=VsNBjn3cY-uUj8hbsW7oKvb0_ZKnXnJ2KgObm-Mr9i4,836
-rxnn/training/models.py,sha256=
-rxnn/training/mrl.py,sha256=
-rxnn/training/reward.py,sha256=
-rxnn/training/rl.py,sha256=
+rxnn/training/models.py,sha256=8FV5eZx1HxtqRSgikwfKoB_bNhPuMYyNi0uSXB65-M4,7223
+rxnn/training/mrl.py,sha256=1pYzjXI17FDZGPTVpmbaBvMYpB-a6SLv-84RHXA4JEA,55142
+rxnn/training/reward.py,sha256=B7nerPk9eNAv2i7umtNF88tVQVwijNNrchIrEITGHKk,11623
+rxnn/training/rl.py,sha256=ckx1nlzIGZBabzwZNRj4isvHqRZwg0y0jGOT-SN6KZc,5841
 rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
 rxnn/training/tokenizer.py,sha256=umaLByMBx_NMrQElA45HLm9gkuzyKWDTFaKVd-CjXl0,8344
 rxnn/training/utils.py,sha256=Bw8nZLKIt7NQpUVCYkb_79kWKChVFOYgYXwODo4SvNc,5718
 rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/transformers/attention.py,sha256=d0Igo1Nrn76BphbHrzekiKJfT3RCy4iSXSB6FLAOTgc,16247
-rxnn/transformers/ff.py,sha256=
+rxnn/transformers/ff.py,sha256=WDjO-H9XWInoWnUnxiseIH6Kx5GlHP0zGJygwhcb1gc,2589
 rxnn/transformers/layers.py,sha256=l0bXmhN7KOkCw0KTVLixWSo9Op4SesGabWJ4R4EQBMY,7988
 rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
-rxnn/transformers/models.py,sha256=
+rxnn/transformers/models.py,sha256=hey6tFN9gmLfWCZLjtl_9OcvIjGpWLI1IDeVnr5y8YM,10583
 rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
 rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
 rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
-rxnn-0.2.
-rxnn-0.2.
-rxnn-0.2.
-rxnn-0.2.
+rxnn-0.2.33.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.2.33.dist-info/METADATA,sha256=im17irb58IYMXOzMXE6QaSPF31Akx0iYS4ay-aRqA9Q,25960
+rxnn-0.2.33.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+rxnn-0.2.33.dist-info/RECORD,,
{rxnn-0.2.31.dist-info → rxnn-0.2.33.dist-info}/LICENSE
File without changes
{rxnn-0.2.31.dist-info → rxnn-0.2.33.dist-info}/WHEEL
File without changes