rxnn 0.2.43__tar.gz → 0.2.45__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {rxnn-0.2.43 → rxnn-0.2.45}/PKG-INFO +1 -1
- {rxnn-0.2.43 → rxnn-0.2.45}/pyproject.toml +1 -1
- {rxnn-0.2.43 → rxnn-0.2.45}/src/rxnn/memory/stm.py +3 -0
- {rxnn-0.2.43 → rxnn-0.2.45}/src/rxnn/rxt/models.py +3 -0
- {rxnn-0.2.43 → rxnn-0.2.45}/src/rxnn/training/models.py +3 -0
- {rxnn-0.2.43 → rxnn-0.2.45}/src/rxnn/training/mrl.py +16 -2
- {rxnn-0.2.43 → rxnn-0.2.45}/LICENSE +0 -0
- {rxnn-0.2.43 → rxnn-0.2.45}/README.md +0 -0
- {rxnn-0.2.43 → rxnn-0.2.45}/src/rxnn/.DS_Store +0 -0
- {rxnn-0.2.43 → rxnn-0.2.45}/src/rxnn/__init__.py +0 -0
- {rxnn-0.2.43 → rxnn-0.2.45}/src/rxnn/experimental/__init__.py +0 -0
- {rxnn-0.2.43 → rxnn-0.2.45}/src/rxnn/experimental/attention.py +0 -0
- {rxnn-0.2.43 → rxnn-0.2.45}/src/rxnn/experimental/models.py +0 -0
- {rxnn-0.2.43 → rxnn-0.2.45}/src/rxnn/experimental/moe.py +0 -0
- {rxnn-0.2.43 → rxnn-0.2.45}/src/rxnn/memory/__init__.py +0 -0
- {rxnn-0.2.43 → rxnn-0.2.45}/src/rxnn/memory/attention.py +0 -0
- {rxnn-0.2.43 → rxnn-0.2.45}/src/rxnn/memory/norm.py +0 -0
- {rxnn-0.2.43 → rxnn-0.2.45}/src/rxnn/rxt/__init__.py +0 -0
- {rxnn-0.2.43 → rxnn-0.2.45}/src/rxnn/training/__init__.py +0 -0
- {rxnn-0.2.43 → rxnn-0.2.45}/src/rxnn/training/base.py +0 -0
- {rxnn-0.2.43 → rxnn-0.2.45}/src/rxnn/training/bml.py +0 -0
- {rxnn-0.2.43 → rxnn-0.2.45}/src/rxnn/training/callbacks.py +0 -0
- {rxnn-0.2.43 → rxnn-0.2.45}/src/rxnn/training/dataset.py +0 -0
- {rxnn-0.2.43 → rxnn-0.2.45}/src/rxnn/training/ddp.py +0 -0
- {rxnn-0.2.43 → rxnn-0.2.45}/src/rxnn/training/reward.py +0 -0
- {rxnn-0.2.43 → rxnn-0.2.45}/src/rxnn/training/rl.py +0 -0
- {rxnn-0.2.43 → rxnn-0.2.45}/src/rxnn/training/scheduler.py +0 -0
- {rxnn-0.2.43 → rxnn-0.2.45}/src/rxnn/training/tokenizer.py +0 -0
- {rxnn-0.2.43 → rxnn-0.2.45}/src/rxnn/training/utils.py +0 -0
- {rxnn-0.2.43 → rxnn-0.2.45}/src/rxnn/transformers/__init__.py +0 -0
- {rxnn-0.2.43 → rxnn-0.2.45}/src/rxnn/transformers/attention.py +0 -0
- {rxnn-0.2.43 → rxnn-0.2.45}/src/rxnn/transformers/ff.py +0 -0
- {rxnn-0.2.43 → rxnn-0.2.45}/src/rxnn/transformers/layers.py +0 -0
- {rxnn-0.2.43 → rxnn-0.2.45}/src/rxnn/transformers/mask.py +0 -0
- {rxnn-0.2.43 → rxnn-0.2.45}/src/rxnn/transformers/models.py +0 -0
- {rxnn-0.2.43 → rxnn-0.2.45}/src/rxnn/transformers/moe.py +0 -0
- {rxnn-0.2.43 → rxnn-0.2.45}/src/rxnn/transformers/positional.py +0 -0
- {rxnn-0.2.43 → rxnn-0.2.45}/src/rxnn/transformers/sampler.py +0 -0
- {rxnn-0.2.43 → rxnn-0.2.45}/src/rxnn/utils.py +0 -0
src/rxnn/memory/stm.py

@@ -62,6 +62,9 @@ class ShortTermMemory(nn.Module):
     def reset(self, init_type: str = None):
         self.memory = self._init_tensor(init_type).to(self.memory.device)
 
+    def clone_detach_reset(self):
+        self.memory = self.memory.detach().clone()
+
     def resize(self, new_stm_size: int, init_type: str = None):
         self.stm_size = new_stm_size
         delattr(self, 'memory')
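The new `clone_detach_reset` keeps the current STM values but severs their autograd history: later backward passes stop at the copy instead of flowing into the computation that produced it. A minimal standalone sketch of that effect (illustrative PyTorch semantics, not package code):

```python
import torch

# Stand-in for the STM buffer: the output of earlier computation,
# so it still carries a grad_fn linking it into that graph.
memory = torch.randn(4, 8, requires_grad=True) * 2.0
print(memory.grad_fn is not None)   # True: attached to a graph

# What clone_detach_reset does to the buffer: same values, fresh
# storage, no autograd history.
memory = memory.detach().clone()
print(memory.grad_fn is None)       # True: backward passes stop here
```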
src/rxnn/rxt/models.py

@@ -301,6 +301,9 @@ class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2
     def reset_memory(self, init_type: str = None):
         self.model.stm.reset(init_type)
 
+    def clone_reset_memory(self):
+        self.model.stm.clone_detach_reset()
+
     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
         return self.model(x, attention_mask=attention_mask)
 
src/rxnn/training/models.py

@@ -146,6 +146,9 @@ class MrlActorModel(nn.Module):
     def reset_memory(self):
         self.memory_attention.reset_memory()
 
+    def clone_reset_memory(self):
+        self.memory_attention.clone_reset_memory()
+
     def memory_parameters(self) -> list[nn.Parameter]:
         return list(set(
             self.encoder.memory_parameters() +
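The `rxt/models.py` and `training/models.py` additions are thin wrappers around the new STM method: `MrlActorModel.clone_reset_memory()` calls the memory-attention module, which calls `ShortTermMemory.clone_detach_reset()`. A toy re-creation of that chain, reduced to just the attributes the diff touches (class bodies are otherwise simplified, and in rxnn the STM actually sits at `self.model.stm`):

```python
import torch
from torch import nn

class ShortTermMemory(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.memory = torch.zeros(2, 4)

    def clone_detach_reset(self) -> None:
        # Same values, detached from any autograd graph.
        self.memory = self.memory.detach().clone()

class MemoryAttention(nn.Module):
    """Reduced stand-in for RxTAlphaMemoryAttention."""
    def __init__(self) -> None:
        super().__init__()
        self.stm = ShortTermMemory()

    def clone_reset_memory(self) -> None:
        self.stm.clone_detach_reset()

class ActorModel(nn.Module):
    """Reduced stand-in for MrlActorModel."""
    def __init__(self) -> None:
        super().__init__()
        self.memory_attention = MemoryAttention()

    def clone_reset_memory(self) -> None:
        self.memory_attention.clone_reset_memory()

actor = ActorModel()
actor.clone_reset_memory()  # one call detaches the STM three levels down
```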
src/rxnn/training/mrl.py

@@ -91,6 +91,7 @@ class MrlTrajectoryEpisode(TypedDict):
     reset_stm: bool
     steps: list[MrlTrajectoryStep]
 
+OptimField: TypeAlias = Literal['lr', 'critic_lr', 'weight_decay', 'critic_weight_decay', 'separate_memory_lr', 'memory_lr']
 
 class MRLTrainer:
     def __init__(
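`OptimField` enumerates the optimizer-related keys a curriculum-stage config may set, so helpers that take a field name can be checked statically. A small illustration (the `read_field` helper is hypothetical, not part of the package):

```python
from typing import Literal, TypeAlias

OptimField: TypeAlias = Literal[
    'lr', 'critic_lr', 'weight_decay',
    'critic_weight_decay', 'separate_memory_lr', 'memory_lr',
]

def read_field(config: dict, field: OptimField):
    # `field` is constrained to the six known optimizer keys; a type
    # checker rejects a misspelled literal like read_field(cfg, 'critc_lr').
    return config.get(field)
```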
@@ -646,6 +647,8 @@ class MRLTrainer:
             step_critic_values = episode_critic_values[step_idx]
             step_advantages = episode_advantages[step_idx]
 
+            self.actor.clone_reset_memory()
+
             # 7. In memory aware critic version, encode and update STM before critic update, to include its gradients in critic loss too
             if self.memory_aware_critic:
                 self.encode_and_update_stm(query, answer)
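Our reading of this change (the diff itself gives no rationale): detaching the STM at the top of each step's update keeps its values while blocking gradients from crossing step boundaries, in the spirit of truncated BPTT, and it also prevents a second backward pass through a graph that an earlier step already freed. A minimal reproduction of that failure mode:

```python
import torch

stm = torch.zeros(4)                 # stand-in for the STM buffer
for step in range(3):
    # What clone_reset_memory amounts to; remove this line and the
    # second loss.backward() raises a "graph already freed" RuntimeError.
    stm = stm.detach().clone()
    update = torch.randn(4, requires_grad=True)
    stm = stm + update               # stand-in for encode_and_update_stm
    loss = stm.sum()
    loss.backward()                  # gradient reaches `update` only
```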
@@ -979,8 +982,19 @@ class MRLTrainer:
         self.reward = config.get('reward_model', self.shared_reward_model) # MRL Reward Model for curriculum stage
         self.update_epochs = config.get('update_epochs', self.shared_update_epochs) # Internal update epochs
         self.freeze_embeddings = config.get('freeze_embeddings', self.shared_freeze_embeddings)
-
-
+
+
+
+        def has_param(field: OptimField) -> bool:
+            return field in config and config[field] is not None
+
+        optim_params: list[OptimField] = ['lr', 'critic_lr', 'weight_decay', 'critic_weight_decay']
+
+        has_any_optim_param = any(
+            has_param(field) for field in optim_params
+        ) or (has_param('separate_memory_lr') and config['separate_memory_lr'] and has_param('memory_lr'))
+
+        if has_any_optim_param:
             if config.get('separate_memory_lr', False):
                 self.optim_config = {
                     'lr': config.get('lr', self.base_optim_config['lr']),
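The net effect of this hunk, as we read it: a curriculum stage rebuilds the optimizer config only when it explicitly sets at least one optimizer field, where `memory_lr` counts only in combination with `separate_memory_lr`. A self-contained sketch of the predicate (the `needs_optim_reset` wrapper is ours, not the package's):

```python
from typing import Literal, TypeAlias

OptimField: TypeAlias = Literal['lr', 'critic_lr', 'weight_decay',
                                'critic_weight_decay', 'separate_memory_lr', 'memory_lr']

def needs_optim_reset(config: dict) -> bool:
    def has_param(field: OptimField) -> bool:
        return field in config and config[field] is not None

    optim_params: list[OptimField] = ['lr', 'critic_lr', 'weight_decay', 'critic_weight_decay']
    return any(has_param(f) for f in optim_params) or (
        has_param('separate_memory_lr') and config['separate_memory_lr'] and has_param('memory_lr')
    )

assert needs_optim_reset({'lr': 1e-4})                                     # explicit lr
assert not needs_optim_reset({'epochs': 3})                                # no optimizer fields
assert needs_optim_reset({'separate_memory_lr': True, 'memory_lr': 5e-4})  # memory-lr pair
```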