rxnn 0.2.44__tar.gz → 0.2.45__tar.gz
This diff compares publicly released package versions from a supported registry. It is provided for informational purposes only and reflects the changes between the versions exactly as they appear in their public registries.
- {rxnn-0.2.44 → rxnn-0.2.45}/PKG-INFO +1 -1
- {rxnn-0.2.44 → rxnn-0.2.45}/pyproject.toml +1 -1
- {rxnn-0.2.44 → rxnn-0.2.45}/src/rxnn/training/mrl.py +14 -2
- {rxnn-0.2.44 → rxnn-0.2.45}/LICENSE +0 -0
- {rxnn-0.2.44 → rxnn-0.2.45}/README.md +0 -0
- {rxnn-0.2.44 → rxnn-0.2.45}/src/rxnn/.DS_Store +0 -0
- {rxnn-0.2.44 → rxnn-0.2.45}/src/rxnn/__init__.py +0 -0
- {rxnn-0.2.44 → rxnn-0.2.45}/src/rxnn/experimental/__init__.py +0 -0
- {rxnn-0.2.44 → rxnn-0.2.45}/src/rxnn/experimental/attention.py +0 -0
- {rxnn-0.2.44 → rxnn-0.2.45}/src/rxnn/experimental/models.py +0 -0
- {rxnn-0.2.44 → rxnn-0.2.45}/src/rxnn/experimental/moe.py +0 -0
- {rxnn-0.2.44 → rxnn-0.2.45}/src/rxnn/memory/__init__.py +0 -0
- {rxnn-0.2.44 → rxnn-0.2.45}/src/rxnn/memory/attention.py +0 -0
- {rxnn-0.2.44 → rxnn-0.2.45}/src/rxnn/memory/norm.py +0 -0
- {rxnn-0.2.44 → rxnn-0.2.45}/src/rxnn/memory/stm.py +0 -0
- {rxnn-0.2.44 → rxnn-0.2.45}/src/rxnn/rxt/__init__.py +0 -0
- {rxnn-0.2.44 → rxnn-0.2.45}/src/rxnn/rxt/models.py +0 -0
- {rxnn-0.2.44 → rxnn-0.2.45}/src/rxnn/training/__init__.py +0 -0
- {rxnn-0.2.44 → rxnn-0.2.45}/src/rxnn/training/base.py +0 -0
- {rxnn-0.2.44 → rxnn-0.2.45}/src/rxnn/training/bml.py +0 -0
- {rxnn-0.2.44 → rxnn-0.2.45}/src/rxnn/training/callbacks.py +0 -0
- {rxnn-0.2.44 → rxnn-0.2.45}/src/rxnn/training/dataset.py +0 -0
- {rxnn-0.2.44 → rxnn-0.2.45}/src/rxnn/training/ddp.py +0 -0
- {rxnn-0.2.44 → rxnn-0.2.45}/src/rxnn/training/models.py +0 -0
- {rxnn-0.2.44 → rxnn-0.2.45}/src/rxnn/training/reward.py +0 -0
- {rxnn-0.2.44 → rxnn-0.2.45}/src/rxnn/training/rl.py +0 -0
- {rxnn-0.2.44 → rxnn-0.2.45}/src/rxnn/training/scheduler.py +0 -0
- {rxnn-0.2.44 → rxnn-0.2.45}/src/rxnn/training/tokenizer.py +0 -0
- {rxnn-0.2.44 → rxnn-0.2.45}/src/rxnn/training/utils.py +0 -0
- {rxnn-0.2.44 → rxnn-0.2.45}/src/rxnn/transformers/__init__.py +0 -0
- {rxnn-0.2.44 → rxnn-0.2.45}/src/rxnn/transformers/attention.py +0 -0
- {rxnn-0.2.44 → rxnn-0.2.45}/src/rxnn/transformers/ff.py +0 -0
- {rxnn-0.2.44 → rxnn-0.2.45}/src/rxnn/transformers/layers.py +0 -0
- {rxnn-0.2.44 → rxnn-0.2.45}/src/rxnn/transformers/mask.py +0 -0
- {rxnn-0.2.44 → rxnn-0.2.45}/src/rxnn/transformers/models.py +0 -0
- {rxnn-0.2.44 → rxnn-0.2.45}/src/rxnn/transformers/moe.py +0 -0
- {rxnn-0.2.44 → rxnn-0.2.45}/src/rxnn/transformers/positional.py +0 -0
- {rxnn-0.2.44 → rxnn-0.2.45}/src/rxnn/transformers/sampler.py +0 -0
- {rxnn-0.2.44 → rxnn-0.2.45}/src/rxnn/utils.py +0 -0
src/rxnn/training/mrl.py

@@ -91,6 +91,7 @@ class MrlTrajectoryEpisode(TypedDict):
     reset_stm: bool
     steps: list[MrlTrajectoryStep]
 
+OptimField: TypeAlias = Literal['lr', 'critic_lr', 'weight_decay', 'critic_weight_decay', 'separate_memory_lr', 'memory_lr']
 
 class MRLTrainer:
     def __init__(
@@ -981,8 +982,19 @@ class MRLTrainer:
         self.reward = config.get('reward_model', self.shared_reward_model) # MRL Reward Model for curriculum stage
         self.update_epochs = config.get('update_epochs', self.shared_update_epochs) # Internal update epochs
         self.freeze_embeddings = config.get('freeze_embeddings', self.shared_freeze_embeddings)
-
-
+
+
+
+        def has_param(field: OptimField) -> bool:
+            return field in config and config[field] is not None
+
+        optim_params: list[OptimField] = ['lr', 'critic_lr', 'weight_decay', 'critic_weight_decay']
+
+        has_any_optim_param = any(
+            has_param(field) for field in optim_params
+        ) or (has_param('separate_memory_lr') and config['separate_memory_lr'] and has_param('memory_lr'))
+
+        if has_any_optim_param:
             if config.get('separate_memory_lr', False):
                 self.optim_config = {
                     'lr': config.get('lr', self.base_optim_config['lr']),
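What the 0.2.45 change does: `MRLTrainer` now rebuilds its per-stage optimizer settings only when the curriculum stage config explicitly sets at least one optimizer field, using the new `OptimField` alias and the `has_param` helper. The sketch below restates that detection logic as a standalone function so it can be run in isolation; wrapping it as `has_any_optim_param(config)` is my framing (in the package it is a local variable inside `__init__`), and the stage configs at the end are hypothetical values for illustration only:

```python
from typing import Literal, TypeAlias

# Mirrors the alias added at module level in mrl.py
OptimField: TypeAlias = Literal[
    'lr', 'critic_lr', 'weight_decay', 'critic_weight_decay',
    'separate_memory_lr', 'memory_lr'
]

def has_any_optim_param(config: dict) -> bool:
    """True when a stage config explicitly sets at least one optimizer
    field; keys that are present but set to None are ignored."""
    def has_param(field: OptimField) -> bool:
        return field in config and config[field] is not None

    optim_params: list[OptimField] = ['lr', 'critic_lr', 'weight_decay', 'critic_weight_decay']
    return any(has_param(field) for field in optim_params) or (
        has_param('separate_memory_lr')
        and config['separate_memory_lr']
        and has_param('memory_lr')
    )

# Hypothetical stage configs, for illustration only:
assert has_any_optim_param({'lr': 1e-4})
assert not has_any_optim_param({'lr': None})  # explicit None is ignored
assert not has_any_optim_param({'separate_memory_lr': True})  # needs 'memory_lr' too
assert has_any_optim_param({'separate_memory_lr': True, 'memory_lr': 5e-4})
```

Note the asymmetry in the check: `separate_memory_lr` only counts when it is truthy and `memory_lr` is also provided, so a stage cannot opt into a separate memory learning rate without actually supplying one.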