rxnn 0.2.50__tar.gz → 0.2.51__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {rxnn-0.2.50 → rxnn-0.2.51}/PKG-INFO +1 -1
- {rxnn-0.2.50 → rxnn-0.2.51}/pyproject.toml +1 -1
- {rxnn-0.2.50 → rxnn-0.2.51}/src/rxnn/training/models.py +7 -6
- {rxnn-0.2.50 → rxnn-0.2.51}/src/rxnn/training/rl.py +1 -0
- {rxnn-0.2.50 → rxnn-0.2.51}/LICENSE +0 -0
- {rxnn-0.2.50 → rxnn-0.2.51}/README.md +0 -0
- {rxnn-0.2.50 → rxnn-0.2.51}/src/rxnn/.DS_Store +0 -0
- {rxnn-0.2.50 → rxnn-0.2.51}/src/rxnn/__init__.py +0 -0
- {rxnn-0.2.50 → rxnn-0.2.51}/src/rxnn/experimental/__init__.py +0 -0
- {rxnn-0.2.50 → rxnn-0.2.51}/src/rxnn/experimental/attention.py +0 -0
- {rxnn-0.2.50 → rxnn-0.2.51}/src/rxnn/experimental/models.py +0 -0
- {rxnn-0.2.50 → rxnn-0.2.51}/src/rxnn/experimental/moe.py +0 -0
- {rxnn-0.2.50 → rxnn-0.2.51}/src/rxnn/memory/__init__.py +0 -0
- {rxnn-0.2.50 → rxnn-0.2.51}/src/rxnn/memory/attention.py +0 -0
- {rxnn-0.2.50 → rxnn-0.2.51}/src/rxnn/memory/norm.py +0 -0
- {rxnn-0.2.50 → rxnn-0.2.51}/src/rxnn/memory/stm.py +0 -0
- {rxnn-0.2.50 → rxnn-0.2.51}/src/rxnn/rxt/__init__.py +0 -0
- {rxnn-0.2.50 → rxnn-0.2.51}/src/rxnn/rxt/models.py +0 -0
- {rxnn-0.2.50 → rxnn-0.2.51}/src/rxnn/training/__init__.py +0 -0
- {rxnn-0.2.50 → rxnn-0.2.51}/src/rxnn/training/base.py +0 -0
- {rxnn-0.2.50 → rxnn-0.2.51}/src/rxnn/training/bml.py +0 -0
- {rxnn-0.2.50 → rxnn-0.2.51}/src/rxnn/training/callbacks.py +0 -0
- {rxnn-0.2.50 → rxnn-0.2.51}/src/rxnn/training/dataset.py +0 -0
- {rxnn-0.2.50 → rxnn-0.2.51}/src/rxnn/training/ddp.py +0 -0
- {rxnn-0.2.50 → rxnn-0.2.51}/src/rxnn/training/mrl.py +0 -0
- {rxnn-0.2.50 → rxnn-0.2.51}/src/rxnn/training/reward.py +0 -0
- {rxnn-0.2.50 → rxnn-0.2.51}/src/rxnn/training/scheduler.py +0 -0
- {rxnn-0.2.50 → rxnn-0.2.51}/src/rxnn/training/tokenizer.py +0 -0
- {rxnn-0.2.50 → rxnn-0.2.51}/src/rxnn/training/utils.py +0 -0
- {rxnn-0.2.50 → rxnn-0.2.51}/src/rxnn/transformers/__init__.py +0 -0
- {rxnn-0.2.50 → rxnn-0.2.51}/src/rxnn/transformers/attention.py +0 -0
- {rxnn-0.2.50 → rxnn-0.2.51}/src/rxnn/transformers/ff.py +0 -0
- {rxnn-0.2.50 → rxnn-0.2.51}/src/rxnn/transformers/layers.py +0 -0
- {rxnn-0.2.50 → rxnn-0.2.51}/src/rxnn/transformers/mask.py +0 -0
- {rxnn-0.2.50 → rxnn-0.2.51}/src/rxnn/transformers/models.py +0 -0
- {rxnn-0.2.50 → rxnn-0.2.51}/src/rxnn/transformers/moe.py +0 -0
- {rxnn-0.2.50 → rxnn-0.2.51}/src/rxnn/transformers/positional.py +0 -0
- {rxnn-0.2.50 → rxnn-0.2.51}/src/rxnn/transformers/sampler.py +0 -0
- {rxnn-0.2.50 → rxnn-0.2.51}/src/rxnn/utils.py +0 -0
src/rxnn/training/models.py

@@ -208,17 +208,18 @@ class MrlActorModel(nn.Module):
 
 
 class MrlCriticModel(nn.Module, PyTorchModelHubMixin, license="apache-2.0", pipeline_tag="text-classification"):
-    def __init__(self, encoder: nn.Module, embed_dim: int,
-                 out_activation: Literal['sigmoid', 'tanh', 'linear'] = 'sigmoid', output_scale: float = 1.0, **kwargs):
+    def __init__(self, encoder: nn.Module, embed_dim: int, **kwargs):
         super(MrlCriticModel, self).__init__(**kwargs)
         self.encoder = encoder
         self.value_head = nn.Sequential(
             GatedLinearUnit(embed_dim, embed_dim, nn.SiLU()),
             nn.LayerNorm(embed_dim),
-            nn.Linear(embed_dim, 1),
-            get_activation_layer(out_activation)
+            nn.Linear(embed_dim, 1)
         )
-
+        # Learnable scaling parameters
+        self.scale = nn.Parameter(torch.tensor(1.0))
+        self.shift = nn.Parameter(torch.tensor(0.0))
+
 
     def head_parameters(self) -> Iterator[nn.Parameter]:
         return self.value_head.parameters()
@@ -235,4 +236,4 @@ class MrlCriticModel(nn.Module, PyTorchModelHubMixin, license="apache-2.0", pipe
         else:
             x = x.mean(dim=1)
 
-        return self.value_head(x) * self.output_scale
+        return self.value_head(x) * self.scale + self.shift
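Taken together, the models.py hunks replace the critic's fixed output transform (a configurable sigmoid/tanh/linear activation plus a constant output_scale) with a raw linear head followed by a learnable affine transform, so the value range is learned during training rather than bounded up front. Below is a minimal, self-contained sketch of the new head's behavior; it is an illustration, not rxnn's actual code, and it substitutes a plain nn.Linear for rxnn's GatedLinearUnit to keep the example dependency-free.

import torch
import torch.nn as nn

class CriticHeadSketch(nn.Module):
    # Hypothetical stand-alone reconstruction of the 0.2.51 critic head.
    def __init__(self, embed_dim: int):
        super().__init__()
        self.value_head = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),  # stand-in for GatedLinearUnit(embed_dim, embed_dim, nn.SiLU())
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, 1),
        )
        # Learnable scaling parameters (new in 0.2.51): an unbounded affine
        # transform replaces the fixed out_activation/output_scale pair.
        self.scale = nn.Parameter(torch.tensor(1.0))
        self.shift = nn.Parameter(torch.tensor(0.0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.mean(dim=1)  # pool encoder hidden states over the sequence
        return self.value_head(x) * self.scale + self.shift

head = CriticHeadSketch(embed_dim=16)
values = head(torch.randn(2, 8, 16))  # -> shape [2, 1], unbounded

A plausible motivation: a sigmoid or tanh output confines the critic to a fixed interval, so returns outside that interval can never be matched exactly; the learnable scale and shift remove that ceiling while still starting from an identity-like transform (scale=1, shift=0).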
src/rxnn/training/rl.py

@@ -55,6 +55,7 @@ class PPOAlgorithm(RlAlgorithm):
         # Critic loss with clipped values
         if self.clip_critic_values:
             values = torch.clamp(values, -self.critic_value_clip, self.critic_value_clip)
+            ref_values = torch.clamp(ref_values, -self.critic_value_clip, self.critic_value_clip)
         return self.critic_loss_fn(values, ref_values)
 
     def policy_loss(self, query: TokenizedDict, answer: TokenizedDict, logits: torch.Tensor,
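The rl.py change makes the clipping in PPO's critic loss symmetric: previously only the predicted values were clamped to [-critic_value_clip, critic_value_clip], while the reference values could lie outside that range and pull the loss toward targets the critic is not allowed to predict. A hedged sketch of the resulting computation, assuming an MSE-style critic_loss_fn as is conventional in PPO (the function name, signature, and default clip below are illustrative, not rxnn's API):

import torch
import torch.nn.functional as F

def clipped_critic_loss(values: torch.Tensor, ref_values: torch.Tensor,
                        value_clip: float = 10.0, clip_values: bool = True) -> torch.Tensor:
    if clip_values:
        # Clamp predictions to the allowed value range.
        values = torch.clamp(values, -value_clip, value_clip)
        # New in 0.2.51: targets are clamped to the same range, so an
        # out-of-range return cannot dominate the regression.
        ref_values = torch.clamp(ref_values, -value_clip, value_clip)
    return F.mse_loss(values, ref_values)

loss = clipped_critic_loss(torch.randn(4, 1) * 20, torch.randn(4, 1) * 20)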