rxnn 0.2.33__tar.gz → 0.2.35__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {rxnn-0.2.33 → rxnn-0.2.35}/PKG-INFO +1 -1
- {rxnn-0.2.33 → rxnn-0.2.35}/pyproject.toml +1 -1
- {rxnn-0.2.33 → rxnn-0.2.35}/src/rxnn/training/models.py +15 -3
- {rxnn-0.2.33 → rxnn-0.2.35}/src/rxnn/training/mrl.py +16 -4
- {rxnn-0.2.33 → rxnn-0.2.35}/src/rxnn/training/rl.py +12 -10
- {rxnn-0.2.33 → rxnn-0.2.35}/LICENSE +0 -0
- {rxnn-0.2.33 → rxnn-0.2.35}/README.md +0 -0
- {rxnn-0.2.33 → rxnn-0.2.35}/src/rxnn/.DS_Store +0 -0
- {rxnn-0.2.33 → rxnn-0.2.35}/src/rxnn/__init__.py +0 -0
- {rxnn-0.2.33 → rxnn-0.2.35}/src/rxnn/experimental/__init__.py +0 -0
- {rxnn-0.2.33 → rxnn-0.2.35}/src/rxnn/experimental/attention.py +0 -0
- {rxnn-0.2.33 → rxnn-0.2.35}/src/rxnn/experimental/models.py +0 -0
- {rxnn-0.2.33 → rxnn-0.2.35}/src/rxnn/experimental/moe.py +0 -0
- {rxnn-0.2.33 → rxnn-0.2.35}/src/rxnn/memory/__init__.py +0 -0
- {rxnn-0.2.33 → rxnn-0.2.35}/src/rxnn/memory/attention.py +0 -0
- {rxnn-0.2.33 → rxnn-0.2.35}/src/rxnn/memory/norm.py +0 -0
- {rxnn-0.2.33 → rxnn-0.2.35}/src/rxnn/memory/stm.py +0 -0
- {rxnn-0.2.33 → rxnn-0.2.35}/src/rxnn/rxt/__init__.py +0 -0
- {rxnn-0.2.33 → rxnn-0.2.35}/src/rxnn/rxt/models.py +0 -0
- {rxnn-0.2.33 → rxnn-0.2.35}/src/rxnn/training/__init__.py +0 -0
- {rxnn-0.2.33 → rxnn-0.2.35}/src/rxnn/training/base.py +0 -0
- {rxnn-0.2.33 → rxnn-0.2.35}/src/rxnn/training/bml.py +0 -0
- {rxnn-0.2.33 → rxnn-0.2.35}/src/rxnn/training/callbacks.py +0 -0
- {rxnn-0.2.33 → rxnn-0.2.35}/src/rxnn/training/dataset.py +0 -0
- {rxnn-0.2.33 → rxnn-0.2.35}/src/rxnn/training/ddp.py +0 -0
- {rxnn-0.2.33 → rxnn-0.2.35}/src/rxnn/training/reward.py +0 -0
- {rxnn-0.2.33 → rxnn-0.2.35}/src/rxnn/training/scheduler.py +0 -0
- {rxnn-0.2.33 → rxnn-0.2.35}/src/rxnn/training/tokenizer.py +0 -0
- {rxnn-0.2.33 → rxnn-0.2.35}/src/rxnn/training/utils.py +0 -0
- {rxnn-0.2.33 → rxnn-0.2.35}/src/rxnn/transformers/__init__.py +0 -0
- {rxnn-0.2.33 → rxnn-0.2.35}/src/rxnn/transformers/attention.py +0 -0
- {rxnn-0.2.33 → rxnn-0.2.35}/src/rxnn/transformers/ff.py +0 -0
- {rxnn-0.2.33 → rxnn-0.2.35}/src/rxnn/transformers/layers.py +0 -0
- {rxnn-0.2.33 → rxnn-0.2.35}/src/rxnn/transformers/mask.py +0 -0
- {rxnn-0.2.33 → rxnn-0.2.35}/src/rxnn/transformers/models.py +0 -0
- {rxnn-0.2.33 → rxnn-0.2.35}/src/rxnn/transformers/moe.py +0 -0
- {rxnn-0.2.33 → rxnn-0.2.35}/src/rxnn/transformers/positional.py +0 -0
- {rxnn-0.2.33 → rxnn-0.2.35}/src/rxnn/transformers/sampler.py +0 -0
- {rxnn-0.2.33 → rxnn-0.2.35}/src/rxnn/utils.py +0 -0
{rxnn-0.2.33 → rxnn-0.2.35}/src/rxnn/training/models.py +15 -3

@@ -6,6 +6,7 @@ from huggingface_hub import PyTorchModelHubMixin
 from ..transformers.models import ReactiveTransformerEncoder, ReactiveTransformerDecoder
 from ..transformers.ff import GatedLinearUnit, get_activation_layer
 
+
 class MLMHead(nn.Module, PyTorchModelHubMixin, license="apache-2.0"):
     def __init__(self, embed_dim: int, vocab_size: int, *args, **kwargs):
         super(MLMHead, self).__init__(*args, **kwargs)
@@ -38,6 +39,7 @@ class MLMTrainingModel(nn.Module):
         y = self.mlm_head(h)
         return y
 
+
 class JointTrainingModel(nn.Module):
     def __init__(
             self,
@@ -59,10 +61,12 @@ class JointTrainingModel(nn.Module):
         y_d = self.decoder(x_d, attention_mask=attention_mask)
         return y_e, y_d
 
+
 class MrlActorAction(Enum):
     DECODE = 1
     UPDATE = 2
 
+
 class MrlActorModel(nn.Module):
     def __init__(
             self,
@@ -154,15 +158,18 @@ class MrlActorModel(nn.Module):
             list(self.memory_attention.parameters())
         ))
 
-    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None, action: MrlActorAction = MrlActorAction.DECODE) -> torch.Tensor:
+    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None,
+                action: MrlActorAction = MrlActorAction.DECODE) -> torch.Tensor:
         if action == MrlActorAction.DECODE:
             return self.decoder(x, attention_mask=attention_mask)
         else:
             _, ed = self.encoder(x, attention_mask=attention_mask)
             return self.memory_attention(ed, attention_mask=attention_mask)
 
+
 class MrlCriticModel(nn.Module, PyTorchModelHubMixin, license="apache-2.0", pipeline_tag="text-classification"):
-    def __init__(self, encoder: nn.Module, embed_dim: int, out_activation: Literal['sigmoid', 'tanh', 'linear'] = 'sigmoid', output_scale: float = 1.0, **kwargs):
+    def __init__(self, encoder: nn.Module, embed_dim: int,
+                 out_activation: Literal['sigmoid', 'tanh', 'linear'] = 'sigmoid', output_scale: float = 1.0, **kwargs):
         super(MrlCriticModel, self).__init__(**kwargs)
         self.encoder = encoder
         self.value_head = nn.Sequential(
@@ -173,6 +180,12 @@ class MrlCriticModel(nn.Module, PyTorchModelHubMixin, license="apache-2.0", pipe
         )
         self.output_scale = output_scale
 
+    def head_parameters(self) -> Iterator[nn.Parameter]:
+        return self.value_head.parameters()
+
+    def encoder_parameters(self) -> Iterator[nn.Parameter]:
+        return self.encoder.parameters()
+
     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
         x, _ = self.encoder(x, attention_mask=attention_mask)
 
@@ -183,4 +196,3 @@ class MrlCriticModel(nn.Module, PyTorchModelHubMixin, license="apache-2.0", pipe
         x = x.mean(dim=1)
 
         return self.value_head(x) * self.output_scale
-
{rxnn-0.2.33 → rxnn-0.2.35}/src/rxnn/training/mrl.py +16 -4

@@ -15,11 +15,13 @@ from .reward import MrlRewardMode, MrlRewardModel
 from .models import MrlActorAction, MrlActorModel, MrlCriticModel
 from .ddp import get_os_ddp_config, distributed_mean
 
+
 class MrlConfig(TypedDict):
     lr: float
     separate_memory_lr: Optional[bool]
     memory_lr: Optional[float]
     critic_lr: float
+    critic_encoder_lr: float
     max_seq_len: int
     critic_max_len: int
     weight_decay: float
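Note: critic_encoder_lr is a new key in MrlConfig; as the trainer hunks below show, it falls back to critic_lr when it is not provided. An illustrative fragment (only the LR-related keys, with made-up values; the other required fields are omitted):

```python
# Illustrative MrlConfig fragment - not a complete config.
config = {
    'lr': 1e-4,                 # actor learning rate
    'critic_lr': 1e-4,          # critic value-head learning rate
    'critic_encoder_lr': 5e-5,  # new key: learning rate for the critic's encoder parameters
}

# Mirrors the trainer's fallback below: if the key is missing, critic_lr is used instead.
critic_encoder_lr = config.get('critic_encoder_lr', config.get('critic_lr', 1e-4))
```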
@@ -58,6 +60,7 @@ class CurriculumConfig(TypedDict):
     lr: Optional[float]
     memory_lr: Optional[float]
     critic_lr: Optional[float]
+    critic_encoder_lr: Optional[float]
     weight_decay: Optional[float]
     critic_weight_decay: Optional[float]
     update_epochs: Optional[int]
@@ -158,6 +161,7 @@ class MRLTrainer:
                 'critic_lr': config.get('critic_lr', 1e-4),
                 'weight_decay': config.get('weight_decay', 0.01),
                 'critic_weight_decay': config.get('critic_weight_decay', 0.01),
+                'critic_encoder_lr': config.get('critic_encoder_lr', config.get('critic_lr', 1e-4)),
             }
         else:
             self.base_optim_config = {
@@ -165,6 +169,7 @@ class MRLTrainer:
                 'critic_lr': config.get('critic_lr', 1e-4),
                 'weight_decay': config.get('weight_decay', 0.01),
                 'critic_weight_decay': config.get('critic_weight_decay', 0.01),
+                'critic_encoder_lr': config.get('critic_encoder_lr', config.get('critic_lr', 1e-4)),
             }
 
         self.optim_config = self.base_optim_config
@@ -202,6 +207,7 @@ class MRLTrainer:
             critic_lr: float,
             weight_decay: float,
             critic_weight_decay: float,
+            critic_encoder_lr: float,
             memory_lr: Optional[float] = None,
     ) -> tuple[torch.optim.Optimizer, torch.optim.Optimizer]:
         if memory_lr is not None:
@@ -219,8 +225,10 @@ class MRLTrainer:
         )
 
         critic_optimizer = torch.optim.AdamW(
-            self.critic.parameters(),
-            lr=critic_lr,
+            [
+                {'params': self.critic.head_parameters(), 'lr': critic_lr},
+                {'params': self.critic.encoder_parameters(), 'lr': critic_encoder_lr},
+            ],
             weight_decay=critic_weight_decay,
         )
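Note: the change above relies on torch.optim.AdamW accepting a list of parameter groups, each with its own learning rate, while weight_decay is passed once as a shared default. A minimal, self-contained sketch of that pattern (the modules here are stand-ins, not rxnn classes):

```python
import torch
from torch import nn

# Stand-in critic: a small "encoder" plus a value head (illustrative modules only).
encoder = nn.Linear(16, 16)
value_head = nn.Sequential(nn.Linear(16, 1))

# One optimizer, two parameter groups with different learning rates;
# weight_decay is a global default applied to both groups.
optimizer = torch.optim.AdamW(
    [
        {'params': value_head.parameters(), 'lr': 1e-4},
        {'params': encoder.parameters(), 'lr': 5e-5},
    ],
    weight_decay=0.01,
)

# Each group keeps its own lr and is stepped independently.
print([group['lr'] for group in optimizer.param_groups])  # [0.0001, 5e-05]
```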
@@ -633,7 +641,8 @@ class MRLTrainer:
                 for i, t in enumerate(episode['steps'])
             ]
             values = torch.stack([
-                self._critic_values_with_memory(r, *self._move_multiple_batches(*t['state'])) for t, r in flat_trajectories
+                self._critic_values_with_memory(r, *self._move_multiple_batches(*t['state'])) for t, r in
+                flat_trajectories
             ]).to(self.device)
             rewards = torch.stack([torch.tensor(t['reward']) for t, _ in flat_trajectories]).to(self.device)
             dones = torch.stack([torch.tensor(t['done']) for t, _ in flat_trajectories]).to(self.device)
@@ -646,7 +655,8 @@ class MRLTrainer:
            dones = torch.stack([torch.tensor(t['done']) for t in flat_trajectories]).to(self.device)
            return values, rewards, dones
 
-    def _critic_values_with_memory(self, reset_stm: bool, *moved_state: tuple[TokenizedDict, TokenizedDict, TokenizedDict]) -> torch.Tensor:
+    def _critic_values_with_memory(self, reset_stm: bool,
+                                   *moved_state: tuple[TokenizedDict, TokenizedDict, TokenizedDict]) -> torch.Tensor:
         # 1. Calculate critic values in memory aware version - reset/update STM before calculating values
         with torch.no_grad():
             # 2. Reset STM if it was reset in trajectory collection
@@ -933,6 +943,7 @@ class MRLTrainer:
                 'weight_decay': config.get('weight_decay', self.base_optim_config['weight_decay']),
                 'critic_weight_decay': config.get('critic_weight_decay',
                                                   self.base_optim_config['critic_weight_decay']),
+                'critic_encoder_lr': config.get('critic_encoder_lr', self.base_optim_config['critic_encoder_lr']),
                 'memory_lr': config.get('memory_lr', self.base_optim_config['memory_lr']),
             }
         else:
@@ -942,6 +953,7 @@ class MRLTrainer:
                 'weight_decay': config.get('weight_decay', self.base_optim_config['weight_decay']),
                 'critic_weight_decay': config.get('critic_weight_decay',
                                                   self.base_optim_config['critic_weight_decay']),
+                'critic_encoder_lr': config.get('critic_encoder_lr', self.base_optim_config['critic_encoder_lr']),
             }
             self.optimizer, self.critic_optimizer = self._init_optimizers(**self.optim_config)
         elif self.optim_config != self.base_optim_config:
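Note: per-stage curriculum configs can now override critic_encoder_lr as well; any key a stage omits falls back to base_optim_config through config.get, exactly as in the hunks above. A small illustrative example of that fallback (values are made up):

```python
# Illustrative per-stage override: only the keys a stage sets are changed,
# everything else falls back to the base optimizer config.
base_optim_config = {'critic_lr': 1e-4, 'critic_encoder_lr': 1e-4, 'critic_weight_decay': 0.01}
stage_config = {'critic_encoder_lr': 1e-5}  # example override for one curriculum stage

critic_encoder_lr = stage_config.get('critic_encoder_lr', base_optim_config['critic_encoder_lr'])
print(critic_encoder_lr)  # 1e-05
```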
{rxnn-0.2.33 → rxnn-0.2.35}/src/rxnn/training/rl.py +12 -10

@@ -10,7 +10,7 @@ from .ddp import distributed_mean
 class RlAlgorithm(ABC):
     def __init__(self):
         super(RlAlgorithm, self).__init__()
-        self.critic_loss = nn.MSELoss()
+        self.critic_loss_fn = nn.MSELoss()
 
     @abstractmethod
     def policy_loss(self, query: TokenizedDict, answer: TokenizedDict, logits: torch.Tensor,
@@ -22,7 +22,7 @@ class RlAlgorithm(ABC):
         pass
 
     def critic_loss(self, values: torch.Tensor, ref_values: torch.Tensor) -> torch.Tensor:
-        return self.critic_loss(values, ref_values)
+        return self.critic_loss_fn(values, ref_values)
 
 
 class PPOConfig(TypedDict):
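Note: the rename to critic_loss_fn matters because, in a plain Python class, an attribute assigned in __init__ shadows a method of the same name on instance lookup, so an attribute named critic_loss would hide both the base critic_loss method and any subclass override (such as PPOAlgorithm's clipped version below). A minimal illustration of the pitfall, not taken from rxnn:

```python
from torch import nn

class Shadowed:
    def __init__(self):
        # Instance attribute with the same name as the method below.
        self.critic_loss = nn.MSELoss()

    def critic_loss(self, values, ref_values):
        # Never reached via self.critic_loss(...): the instance attribute wins the lookup.
        return self.critic_loss(values, ref_values)

obj = Shadowed()
print(type(obj.critic_loss))  # <class 'torch.nn.modules.loss.MSELoss'>, not a bound method
```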
@@ -55,7 +55,7 @@ class PPOAlgorithm(RlAlgorithm):
         # Critic loss with clipped values
         if self.clip_critic_values:
             values = torch.clamp(values, -self.critic_value_clip, self.critic_value_clip)
-        return self.critic_loss(values, ref_values)
+        return self.critic_loss_fn(values, ref_values)
 
     def policy_loss(self, query: TokenizedDict, answer: TokenizedDict, logits: torch.Tensor,
                     old_log_probs: torch.Tensor, advantages: torch.Tensor) -> torch.Tensor:
@@ -114,15 +114,17 @@ class PPOAlgorithm(RlAlgorithm):
         next_value = last_value
         next_done = torch.zeros(batch_size, device=dones.device)  # Last state is terminal
         dones = dones.float()
-        for t in reversed(range(trajectory_len)):
-            # Check if next state is terminal
-            non_terminal = 1.0 - next_done
 
-
-
-
+        for t in reversed(range(trajectory_len)):
+            if t == trajectory_len - 1:
+                # For the last step, use the provided last_value
+                delta = rewards[t] + self.gae_gamma * next_value * (1 - next_done) - values[t]
+            else:
+                # For other steps, use the next value in the trajectory
+                delta = rewards[t] + self.gae_gamma * values[t + 1] * (1 - dones[t + 1]) - values[t]
+
+            advantages[t] = delta + self.gae_gamma * self.gae_lambda * (1 - dones[t]) * last_advantage
             last_advantage = advantages[t]
-            next_value = values[t]
             next_done = dones[t]
 
         returns = advantages + values
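Note: the rewritten loop is a standard backward pass for Generalized Advantage Estimation (GAE): the final step bootstraps from the externally supplied last_value, earlier steps bootstrap from the stored value of the following step, and both are masked by the corresponding done flag. A self-contained numeric sketch of the same recursion (toy tensors and made-up gamma/lambda, not the rxnn trainer):

```python
import torch

gamma, lam = 0.99, 0.95
rewards = torch.tensor([1.0, 0.0, 2.0])   # per-step rewards for one toy trajectory
values = torch.tensor([0.5, 0.4, 0.3])    # critic estimates V(s_t)
dones = torch.tensor([0.0, 0.0, 0.0])     # no early termination in this example
last_value = torch.tensor(0.2)            # bootstrap value for the state after the last step

T = len(rewards)
advantages = torch.zeros(T)
last_advantage = torch.tensor(0.0)
next_value, next_done = last_value, torch.tensor(0.0)

for t in reversed(range(T)):
    if t == T - 1:
        # Last step: bootstrap from the provided last_value.
        delta = rewards[t] + gamma * next_value * (1 - next_done) - values[t]
    else:
        # Earlier steps: bootstrap from the stored value of the next step.
        delta = rewards[t] + gamma * values[t + 1] * (1 - dones[t + 1]) - values[t]
    advantages[t] = delta + gamma * lam * (1 - dones[t]) * last_advantage
    last_advantage = advantages[t]
    next_done = dones[t]

returns = advantages + values  # GAE returns, used as critic targets
print(advantages, returns)
```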