rxnn 0.2.50__tar.gz → 0.2.51__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. {rxnn-0.2.50 → rxnn-0.2.51}/PKG-INFO +1 -1
  2. {rxnn-0.2.50 → rxnn-0.2.51}/pyproject.toml +1 -1
  3. {rxnn-0.2.50 → rxnn-0.2.51}/src/rxnn/training/models.py +7 -6
  4. {rxnn-0.2.50 → rxnn-0.2.51}/src/rxnn/training/rl.py +1 -0
  5. {rxnn-0.2.50 → rxnn-0.2.51}/LICENSE +0 -0
  6. {rxnn-0.2.50 → rxnn-0.2.51}/README.md +0 -0
  7. {rxnn-0.2.50 → rxnn-0.2.51}/src/rxnn/.DS_Store +0 -0
  8. {rxnn-0.2.50 → rxnn-0.2.51}/src/rxnn/__init__.py +0 -0
  9. {rxnn-0.2.50 → rxnn-0.2.51}/src/rxnn/experimental/__init__.py +0 -0
  10. {rxnn-0.2.50 → rxnn-0.2.51}/src/rxnn/experimental/attention.py +0 -0
  11. {rxnn-0.2.50 → rxnn-0.2.51}/src/rxnn/experimental/models.py +0 -0
  12. {rxnn-0.2.50 → rxnn-0.2.51}/src/rxnn/experimental/moe.py +0 -0
  13. {rxnn-0.2.50 → rxnn-0.2.51}/src/rxnn/memory/__init__.py +0 -0
  14. {rxnn-0.2.50 → rxnn-0.2.51}/src/rxnn/memory/attention.py +0 -0
  15. {rxnn-0.2.50 → rxnn-0.2.51}/src/rxnn/memory/norm.py +0 -0
  16. {rxnn-0.2.50 → rxnn-0.2.51}/src/rxnn/memory/stm.py +0 -0
  17. {rxnn-0.2.50 → rxnn-0.2.51}/src/rxnn/rxt/__init__.py +0 -0
  18. {rxnn-0.2.50 → rxnn-0.2.51}/src/rxnn/rxt/models.py +0 -0
  19. {rxnn-0.2.50 → rxnn-0.2.51}/src/rxnn/training/__init__.py +0 -0
  20. {rxnn-0.2.50 → rxnn-0.2.51}/src/rxnn/training/base.py +0 -0
  21. {rxnn-0.2.50 → rxnn-0.2.51}/src/rxnn/training/bml.py +0 -0
  22. {rxnn-0.2.50 → rxnn-0.2.51}/src/rxnn/training/callbacks.py +0 -0
  23. {rxnn-0.2.50 → rxnn-0.2.51}/src/rxnn/training/dataset.py +0 -0
  24. {rxnn-0.2.50 → rxnn-0.2.51}/src/rxnn/training/ddp.py +0 -0
  25. {rxnn-0.2.50 → rxnn-0.2.51}/src/rxnn/training/mrl.py +0 -0
  26. {rxnn-0.2.50 → rxnn-0.2.51}/src/rxnn/training/reward.py +0 -0
  27. {rxnn-0.2.50 → rxnn-0.2.51}/src/rxnn/training/scheduler.py +0 -0
  28. {rxnn-0.2.50 → rxnn-0.2.51}/src/rxnn/training/tokenizer.py +0 -0
  29. {rxnn-0.2.50 → rxnn-0.2.51}/src/rxnn/training/utils.py +0 -0
  30. {rxnn-0.2.50 → rxnn-0.2.51}/src/rxnn/transformers/__init__.py +0 -0
  31. {rxnn-0.2.50 → rxnn-0.2.51}/src/rxnn/transformers/attention.py +0 -0
  32. {rxnn-0.2.50 → rxnn-0.2.51}/src/rxnn/transformers/ff.py +0 -0
  33. {rxnn-0.2.50 → rxnn-0.2.51}/src/rxnn/transformers/layers.py +0 -0
  34. {rxnn-0.2.50 → rxnn-0.2.51}/src/rxnn/transformers/mask.py +0 -0
  35. {rxnn-0.2.50 → rxnn-0.2.51}/src/rxnn/transformers/models.py +0 -0
  36. {rxnn-0.2.50 → rxnn-0.2.51}/src/rxnn/transformers/moe.py +0 -0
  37. {rxnn-0.2.50 → rxnn-0.2.51}/src/rxnn/transformers/positional.py +0 -0
  38. {rxnn-0.2.50 → rxnn-0.2.51}/src/rxnn/transformers/sampler.py +0 -0
  39. {rxnn-0.2.50 → rxnn-0.2.51}/src/rxnn/utils.py +0 -0
{rxnn-0.2.50 → rxnn-0.2.51}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.2.50
+Version: 0.2.51
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning
{rxnn-0.2.50 → rxnn-0.2.51}/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
 
 [tool.poetry]
 name = "rxnn"
-version = "0.2.50"
+version = "0.2.51"
 description = "RxNN: Reactive Neural Networks Platform"
 
 license = "Apache-2.0"
{rxnn-0.2.50 → rxnn-0.2.51}/src/rxnn/training/models.py
@@ -208,17 +208,18 @@ class MrlActorModel(nn.Module):
 
 
 class MrlCriticModel(nn.Module, PyTorchModelHubMixin, license="apache-2.0", pipeline_tag="text-classification"):
-    def __init__(self, encoder: nn.Module, embed_dim: int,
-                 out_activation: Literal['sigmoid', 'tanh', 'linear'] = 'sigmoid', output_scale: float = 1.0, **kwargs):
+    def __init__(self, encoder: nn.Module, embed_dim: int, **kwargs):
         super(MrlCriticModel, self).__init__(**kwargs)
         self.encoder = encoder
         self.value_head = nn.Sequential(
             GatedLinearUnit(embed_dim, embed_dim, nn.SiLU()),
             nn.LayerNorm(embed_dim),
-            nn.Linear(embed_dim, 1),
-            get_activation_layer(out_activation)
+            nn.Linear(embed_dim, 1)
         )
-        self.output_scale = output_scale
+        # Learnable scaling parameters
+        self.scale = nn.Parameter(torch.tensor(1.0))
+        self.shift = nn.Parameter(torch.tensor(0.0))
+
 
     def head_parameters(self) -> Iterator[nn.Parameter]:
         return self.value_head.parameters()
@@ -235,4 +236,4 @@ class MrlCriticModel(nn.Module, PyTorchModelHubMixin, license="apache-2.0", pipe
         else:
             x = x.mean(dim=1)
 
-        return self.value_head(x) * self.output_scale
+        return self.value_head(x) * self.scale + self.shift
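Net effect of this change: MrlCriticModel no longer squashes its value estimate through a fixed sigmoid/tanh activation multiplied by a constant output_scale; the final Linear layer is left unbounded and the pooled output goes through a learnable affine transform. The snippet below is a minimal, self-contained sketch of the new output path only; the plain Linear + SiLU stack and unmasked mean pooling are stand-ins for the package's GatedLinearUnit and masked pooling, which are not reproduced here.

import torch
import torch.nn as nn

class CriticHeadSketch(nn.Module):
    """Illustrative value head: unbounded linear output with a learnable affine rescale."""
    def __init__(self, embed_dim: int):
        super().__init__()
        self.value_head = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),  # stand-in for GatedLinearUnit(embed_dim, embed_dim, nn.SiLU())
            nn.SiLU(),
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, 1),          # no output activation: raw, unbounded value
        )
        # Learnable scaling parameters, replacing the former fixed activation + output_scale constant
        self.scale = nn.Parameter(torch.tensor(1.0))
        self.shift = nn.Parameter(torch.tensor(0.0))

    def forward(self, hidden: torch.Tensor) -> torch.Tensor:
        x = hidden.mean(dim=1)                # pool encoder states over the sequence (unmasked branch)
        return self.value_head(x) * self.scale + self.shift

head = CriticHeadSketch(embed_dim=16)
print(head(torch.randn(2, 8, 16)).shape)      # torch.Size([2, 1])

With the old design the value estimate was confined to a range fixed at construction time; a learnable scale and shift let the critic adapt its output range to whatever returns the rewards actually produce during training.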
{rxnn-0.2.50 → rxnn-0.2.51}/src/rxnn/training/rl.py
@@ -55,6 +55,7 @@ class PPOAlgorithm(RlAlgorithm):
         # Critic loss with clipped values
         if self.clip_critic_values:
             values = torch.clamp(values, -self.critic_value_clip, self.critic_value_clip)
+            ref_values = torch.clamp(ref_values, -self.critic_value_clip, self.critic_value_clip)
         return self.critic_loss_fn(values, ref_values)
 
     def policy_loss(self, query: TokenizedDict, answer: TokenizedDict, logits: torch.Tensor,
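In 0.2.50 only the critic's predictions were clamped, so the regression targets (ref_values) could still lie outside [-critic_value_clip, critic_value_clip] and pull the clipped critic toward values it is not allowed to predict; 0.2.51 clamps both sides to the same range. A minimal sketch of that symmetric clipping, assuming an MSE critic loss and a purely illustrative clip of 10.0 (the actual loss function and clip value are configured on PPOAlgorithm):

import torch
import torch.nn.functional as F

def clipped_critic_loss(values: torch.Tensor, ref_values: torch.Tensor, clip: float = 10.0) -> torch.Tensor:
    # Clamp predictions and targets to the same range, so the target can never sit
    # outside the region the clipped critic is allowed to reach.
    values = torch.clamp(values, -clip, clip)
    ref_values = torch.clamp(ref_values, -clip, clip)  # the line added in 0.2.51
    return F.mse_loss(values, ref_values)

print(clipped_critic_loss(torch.randn(4, 1) * 20, torch.randn(4, 1) * 20))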