rxnn 0.2.44.tar.gz → 0.2.45.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. {rxnn-0.2.44 → rxnn-0.2.45}/PKG-INFO +1 -1
  2. {rxnn-0.2.44 → rxnn-0.2.45}/pyproject.toml +1 -1
  3. {rxnn-0.2.44 → rxnn-0.2.45}/src/rxnn/training/mrl.py +14 -2
  4. {rxnn-0.2.44 → rxnn-0.2.45}/LICENSE +0 -0
  5. {rxnn-0.2.44 → rxnn-0.2.45}/README.md +0 -0
  6. {rxnn-0.2.44 → rxnn-0.2.45}/src/rxnn/.DS_Store +0 -0
  7. {rxnn-0.2.44 → rxnn-0.2.45}/src/rxnn/__init__.py +0 -0
  8. {rxnn-0.2.44 → rxnn-0.2.45}/src/rxnn/experimental/__init__.py +0 -0
  9. {rxnn-0.2.44 → rxnn-0.2.45}/src/rxnn/experimental/attention.py +0 -0
  10. {rxnn-0.2.44 → rxnn-0.2.45}/src/rxnn/experimental/models.py +0 -0
  11. {rxnn-0.2.44 → rxnn-0.2.45}/src/rxnn/experimental/moe.py +0 -0
  12. {rxnn-0.2.44 → rxnn-0.2.45}/src/rxnn/memory/__init__.py +0 -0
  13. {rxnn-0.2.44 → rxnn-0.2.45}/src/rxnn/memory/attention.py +0 -0
  14. {rxnn-0.2.44 → rxnn-0.2.45}/src/rxnn/memory/norm.py +0 -0
  15. {rxnn-0.2.44 → rxnn-0.2.45}/src/rxnn/memory/stm.py +0 -0
  16. {rxnn-0.2.44 → rxnn-0.2.45}/src/rxnn/rxt/__init__.py +0 -0
  17. {rxnn-0.2.44 → rxnn-0.2.45}/src/rxnn/rxt/models.py +0 -0
  18. {rxnn-0.2.44 → rxnn-0.2.45}/src/rxnn/training/__init__.py +0 -0
  19. {rxnn-0.2.44 → rxnn-0.2.45}/src/rxnn/training/base.py +0 -0
  20. {rxnn-0.2.44 → rxnn-0.2.45}/src/rxnn/training/bml.py +0 -0
  21. {rxnn-0.2.44 → rxnn-0.2.45}/src/rxnn/training/callbacks.py +0 -0
  22. {rxnn-0.2.44 → rxnn-0.2.45}/src/rxnn/training/dataset.py +0 -0
  23. {rxnn-0.2.44 → rxnn-0.2.45}/src/rxnn/training/ddp.py +0 -0
  24. {rxnn-0.2.44 → rxnn-0.2.45}/src/rxnn/training/models.py +0 -0
  25. {rxnn-0.2.44 → rxnn-0.2.45}/src/rxnn/training/reward.py +0 -0
  26. {rxnn-0.2.44 → rxnn-0.2.45}/src/rxnn/training/rl.py +0 -0
  27. {rxnn-0.2.44 → rxnn-0.2.45}/src/rxnn/training/scheduler.py +0 -0
  28. {rxnn-0.2.44 → rxnn-0.2.45}/src/rxnn/training/tokenizer.py +0 -0
  29. {rxnn-0.2.44 → rxnn-0.2.45}/src/rxnn/training/utils.py +0 -0
  30. {rxnn-0.2.44 → rxnn-0.2.45}/src/rxnn/transformers/__init__.py +0 -0
  31. {rxnn-0.2.44 → rxnn-0.2.45}/src/rxnn/transformers/attention.py +0 -0
  32. {rxnn-0.2.44 → rxnn-0.2.45}/src/rxnn/transformers/ff.py +0 -0
  33. {rxnn-0.2.44 → rxnn-0.2.45}/src/rxnn/transformers/layers.py +0 -0
  34. {rxnn-0.2.44 → rxnn-0.2.45}/src/rxnn/transformers/mask.py +0 -0
  35. {rxnn-0.2.44 → rxnn-0.2.45}/src/rxnn/transformers/models.py +0 -0
  36. {rxnn-0.2.44 → rxnn-0.2.45}/src/rxnn/transformers/moe.py +0 -0
  37. {rxnn-0.2.44 → rxnn-0.2.45}/src/rxnn/transformers/positional.py +0 -0
  38. {rxnn-0.2.44 → rxnn-0.2.45}/src/rxnn/transformers/sampler.py +0 -0
  39. {rxnn-0.2.44 → rxnn-0.2.45}/src/rxnn/utils.py +0 -0
{rxnn-0.2.44 → rxnn-0.2.45}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.2.44
+Version: 0.2.45
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning
{rxnn-0.2.44 → rxnn-0.2.45}/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
 
 [tool.poetry]
 name = "rxnn"
-version = "0.2.44"
+version = "0.2.45"
 description = "RxNN: Reactive Neural Networks Platform"
 
 license = "Apache-2.0"
{rxnn-0.2.44 → rxnn-0.2.45}/src/rxnn/training/mrl.py
@@ -91,6 +91,7 @@ class MrlTrajectoryEpisode(TypedDict):
     reset_stm: bool
     steps: list[MrlTrajectoryStep]
 
+OptimField: TypeAlias = Literal['lr', 'critic_lr', 'weight_decay', 'critic_weight_decay', 'separate_memory_lr', 'memory_lr']
 
 class MRLTrainer:
     def __init__(
@@ -981,8 +982,19 @@ class MRLTrainer:
         self.reward = config.get('reward_model', self.shared_reward_model) # MRL Reward Model for curriculum stage
         self.update_epochs = config.get('update_epochs', self.shared_update_epochs) # Internal update epochs
         self.freeze_embeddings = config.get('freeze_embeddings', self.shared_freeze_embeddings)
-        if config['lr'] is not None or config['critic_lr'] is not None or config['weight_decay'] is not None or config[
-            'critic_weight_decay'] is not None or (config['separate_memory_lr'] and config['memory_lr'] is not None):
+
+
+
+        def has_param(field: OptimField) -> bool:
+            return field in config and config[field] is not None
+
+        optim_params: list[OptimField] = ['lr', 'critic_lr', 'weight_decay', 'critic_weight_decay']
+
+        has_any_optim_param = any(
+            has_param(field) for field in optim_params
+        ) or (has_param('separate_memory_lr') and config['separate_memory_lr'] and has_param('memory_lr'))
+
+        if has_any_optim_param:
             if config.get('separate_memory_lr', False):
                 self.optim_config = {
                     'lr': config.get('lr', self.base_optim_config['lr']),
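The mrl.py change replaces the long inline `config['lr'] is not None or ...` condition with the new `OptimField` alias and a small `has_param` helper, so optimizer keys that are absent from a curriculum-stage config are treated the same as keys set to None instead of raising a KeyError. A minimal standalone sketch of the same check is shown below; `needs_optimizer_reset` is a hypothetical wrapper name, and `has_param` here takes `config` explicitly rather than closing over it as the trainer code does:

from typing import Any, Literal, TypeAlias

# Mirrors the OptimField alias added to mrl.py.
OptimField: TypeAlias = Literal['lr', 'critic_lr', 'weight_decay', 'critic_weight_decay',
                                'separate_memory_lr', 'memory_lr']

def has_param(config: dict[str, Any], field: OptimField) -> bool:
    # A key counts as set only if it is present and not None.
    return field in config and config[field] is not None

def needs_optimizer_reset(config: dict[str, Any]) -> bool:
    # Hypothetical wrapper: True when a stage overrides any optimizer hyperparameter.
    optim_params: list[OptimField] = ['lr', 'critic_lr', 'weight_decay', 'critic_weight_decay']
    return any(has_param(config, field) for field in optim_params) or (
        has_param(config, 'separate_memory_lr')
        and config['separate_memory_lr']
        and has_param(config, 'memory_lr')
    )

# Stage config that only overrides the separate memory learning rate.
stage = {'separate_memory_lr': True, 'memory_lr': 5e-4}
print(needs_optimizer_reset(stage))   # True
print(needs_optimizer_reset({}))      # False: nothing overridden, keep the base optim config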