rxnn 0.2.43.tar.gz → 0.2.45.tar.gz

This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.
Files changed (39)
  1. {rxnn-0.2.43 → rxnn-0.2.45}/PKG-INFO +1 -1
  2. {rxnn-0.2.43 → rxnn-0.2.45}/pyproject.toml +1 -1
  3. {rxnn-0.2.43 → rxnn-0.2.45}/src/rxnn/memory/stm.py +3 -0
  4. {rxnn-0.2.43 → rxnn-0.2.45}/src/rxnn/rxt/models.py +3 -0
  5. {rxnn-0.2.43 → rxnn-0.2.45}/src/rxnn/training/models.py +3 -0
  6. {rxnn-0.2.43 → rxnn-0.2.45}/src/rxnn/training/mrl.py +16 -2
  7. {rxnn-0.2.43 → rxnn-0.2.45}/LICENSE +0 -0
  8. {rxnn-0.2.43 → rxnn-0.2.45}/README.md +0 -0
  9. {rxnn-0.2.43 → rxnn-0.2.45}/src/rxnn/.DS_Store +0 -0
  10. {rxnn-0.2.43 → rxnn-0.2.45}/src/rxnn/__init__.py +0 -0
  11. {rxnn-0.2.43 → rxnn-0.2.45}/src/rxnn/experimental/__init__.py +0 -0
  12. {rxnn-0.2.43 → rxnn-0.2.45}/src/rxnn/experimental/attention.py +0 -0
  13. {rxnn-0.2.43 → rxnn-0.2.45}/src/rxnn/experimental/models.py +0 -0
  14. {rxnn-0.2.43 → rxnn-0.2.45}/src/rxnn/experimental/moe.py +0 -0
  15. {rxnn-0.2.43 → rxnn-0.2.45}/src/rxnn/memory/__init__.py +0 -0
  16. {rxnn-0.2.43 → rxnn-0.2.45}/src/rxnn/memory/attention.py +0 -0
  17. {rxnn-0.2.43 → rxnn-0.2.45}/src/rxnn/memory/norm.py +0 -0
  18. {rxnn-0.2.43 → rxnn-0.2.45}/src/rxnn/rxt/__init__.py +0 -0
  19. {rxnn-0.2.43 → rxnn-0.2.45}/src/rxnn/training/__init__.py +0 -0
  20. {rxnn-0.2.43 → rxnn-0.2.45}/src/rxnn/training/base.py +0 -0
  21. {rxnn-0.2.43 → rxnn-0.2.45}/src/rxnn/training/bml.py +0 -0
  22. {rxnn-0.2.43 → rxnn-0.2.45}/src/rxnn/training/callbacks.py +0 -0
  23. {rxnn-0.2.43 → rxnn-0.2.45}/src/rxnn/training/dataset.py +0 -0
  24. {rxnn-0.2.43 → rxnn-0.2.45}/src/rxnn/training/ddp.py +0 -0
  25. {rxnn-0.2.43 → rxnn-0.2.45}/src/rxnn/training/reward.py +0 -0
  26. {rxnn-0.2.43 → rxnn-0.2.45}/src/rxnn/training/rl.py +0 -0
  27. {rxnn-0.2.43 → rxnn-0.2.45}/src/rxnn/training/scheduler.py +0 -0
  28. {rxnn-0.2.43 → rxnn-0.2.45}/src/rxnn/training/tokenizer.py +0 -0
  29. {rxnn-0.2.43 → rxnn-0.2.45}/src/rxnn/training/utils.py +0 -0
  30. {rxnn-0.2.43 → rxnn-0.2.45}/src/rxnn/transformers/__init__.py +0 -0
  31. {rxnn-0.2.43 → rxnn-0.2.45}/src/rxnn/transformers/attention.py +0 -0
  32. {rxnn-0.2.43 → rxnn-0.2.45}/src/rxnn/transformers/ff.py +0 -0
  33. {rxnn-0.2.43 → rxnn-0.2.45}/src/rxnn/transformers/layers.py +0 -0
  34. {rxnn-0.2.43 → rxnn-0.2.45}/src/rxnn/transformers/mask.py +0 -0
  35. {rxnn-0.2.43 → rxnn-0.2.45}/src/rxnn/transformers/models.py +0 -0
  36. {rxnn-0.2.43 → rxnn-0.2.45}/src/rxnn/transformers/moe.py +0 -0
  37. {rxnn-0.2.43 → rxnn-0.2.45}/src/rxnn/transformers/positional.py +0 -0
  38. {rxnn-0.2.43 → rxnn-0.2.45}/src/rxnn/transformers/sampler.py +0 -0
  39. {rxnn-0.2.43 → rxnn-0.2.45}/src/rxnn/utils.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.2.43
+Version: 0.2.45
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning
pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
 
 [tool.poetry]
 name = "rxnn"
-version = "0.2.43"
+version = "0.2.45"
 description = "RxNN: Reactive Neural Networks Platform"
 
 license = "Apache-2.0"
src/rxnn/memory/stm.py
@@ -62,6 +62,9 @@ class ShortTermMemory(nn.Module):
     def reset(self, init_type: str = None):
         self.memory = self._init_tensor(init_type).to(self.memory.device)
 
+    def clone_detach_reset(self):
+        self.memory = self.memory.detach().clone()
+
     def resize(self, new_stm_size: int, init_type: str = None):
         self.stm_size = new_stm_size
         delattr(self, 'memory')
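
The new clone_detach_reset() keeps the current STM values but severs their autograd history, so subsequent backward passes stop at this boundary. A minimal sketch of the tensor-level effect (the (4, 8) shape is arbitrary, for illustration only):

import torch

# Stand-in for an STM tensor that took part in a differentiable update.
memory = torch.randn(4, 8, requires_grad=True)
updated = memory * 2.0                  # in-graph update: grad_fn is set
assert updated.grad_fn is not None

detached = updated.detach().clone()     # what clone_detach_reset does
assert detached.grad_fn is None         # autograd history dropped
assert torch.equal(detached, updated)   # values preserved exactly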
src/rxnn/rxt/models.py
@@ -301,6 +301,9 @@ class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2
     def reset_memory(self, init_type: str = None):
         self.model.stm.reset(init_type)
 
+    def clone_reset_memory(self):
+        self.model.stm.clone_detach_reset()
+
     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
         return self.model(x, attention_mask=attention_mask)
 
src/rxnn/training/models.py
@@ -146,6 +146,9 @@ class MrlActorModel(nn.Module):
     def reset_memory(self):
         self.memory_attention.reset_memory()
 
+    def clone_reset_memory(self):
+        self.memory_attention.clone_reset_memory()
+
     def memory_parameters(self) -> list[nn.Parameter]:
         return list(set(
             self.encoder.memory_parameters() +
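
Together with the two hunks above, this completes a single delegation chain: MrlActorModel.clone_reset_memory() → RxTAlphaMemoryAttention.clone_reset_memory() → ShortTermMemory.clone_detach_reset(). A runnable stub sketch of the chain (the Stm/MemAttn/Actor stand-ins are hypothetical, reduced to the delegation itself):

import torch

class Stm:                          # stand-in for ShortTermMemory
    def __init__(self):
        self.memory = torch.zeros(2, 4, requires_grad=True) + 1.0
    def clone_detach_reset(self):
        self.memory = self.memory.detach().clone()

class MemAttn:                      # stand-in for RxTAlphaMemoryAttention
    def __init__(self):
        self.stm = Stm()
    def clone_reset_memory(self):
        self.stm.clone_detach_reset()

class Actor:                        # stand-in for MrlActorModel
    def __init__(self):
        self.memory_attention = MemAttn()
    def clone_reset_memory(self):
        self.memory_attention.clone_reset_memory()

actor = Actor()
actor.clone_reset_memory()          # one call walks the whole chain
assert actor.memory_attention.stm.memory.grad_fn is None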
src/rxnn/training/mrl.py
@@ -91,6 +91,7 @@ class MrlTrajectoryEpisode(TypedDict):
     reset_stm: bool
     steps: list[MrlTrajectoryStep]
 
+OptimField: TypeAlias = Literal['lr', 'critic_lr', 'weight_decay', 'critic_weight_decay', 'separate_memory_lr', 'memory_lr']
 
 class MRLTrainer:
     def __init__(
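
OptimField is a plain Literal alias: it costs nothing at runtime and lets a static type checker verify the optimizer-config key strings used further down. A standalone sketch of how such an alias behaves (not the trainer's actual config handling):

from typing import Literal, TypeAlias

OptimField: TypeAlias = Literal['lr', 'critic_lr', 'weight_decay',
                                'critic_weight_decay', 'separate_memory_lr', 'memory_lr']

def describe(field: OptimField) -> str:
    # At runtime the value is just a str; Literal constrains callers statically.
    return f"optimizer field: {field}"

describe('critic_lr')   # accepted
describe('critc_lr')    # typo: mypy/pyright flag it, though it still runs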
src/rxnn/training/mrl.py
@@ -646,6 +647,8 @@ class MRLTrainer:
             step_critic_values = episode_critic_values[step_idx]
             step_advantages = episode_advantages[step_idx]
 
+            self.actor.clone_reset_memory()
+
             # 7. In memory aware critic version, encode and update STM before critic update, to include its gradients in critic loss too
             if self.memory_aware_critic:
                 self.encode_and_update_stm(query, answer)
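
The new call detaches the STM carried over from earlier steps before each step's update, so backpropagation stops at the step boundary rather than unrolling through the whole trajectory, in the spirit of truncated BPTT. A minimal sketch of the pattern (placeholder model and loss, not the trainer's real update):

import torch

stm = torch.zeros(1, 8)                  # carried-over short-term memory
layer = torch.nn.Linear(8, 8)
opt = torch.optim.SGD(layer.parameters(), lr=0.01)

for step in range(3):
    stm = stm.detach().clone()           # the clone_reset_memory() moment:
                                         # gradients stop at the step boundary
    stm = torch.tanh(layer(stm))         # placeholder in-step memory update
    loss = stm.pow(2).mean()             # placeholder per-step loss
    opt.zero_grad()
    loss.backward()                      # reaches only this step's graph
    opt.step()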
src/rxnn/training/mrl.py
@@ -979,8 +982,19 @@ class MRLTrainer:
         self.reward = config.get('reward_model', self.shared_reward_model) # MRL Reward Model for curriculum stage
         self.update_epochs = config.get('update_epochs', self.shared_update_epochs) # Internal update epochs
         self.freeze_embeddings = config.get('freeze_embeddings', self.shared_freeze_embeddings)
-        if config['lr'] is not None or config['critic_lr'] is not None or config['weight_decay'] is not None or config[
-            'critic_weight_decay'] is not None or (config['separate_memory_lr'] and config['memory_lr'] is not None):
+
+
+
+        def has_param(field: OptimField) -> bool:
+            return field in config and config[field] is not None
+
+        optim_params: list[OptimField] = ['lr', 'critic_lr', 'weight_decay', 'critic_weight_decay']
+
+        has_any_optim_param = any(
+            has_param(field) for field in optim_params
+        ) or (has_param('separate_memory_lr') and config['separate_memory_lr'] and has_param('memory_lr'))
+
+        if has_any_optim_param:
             if config.get('separate_memory_lr', False):
                 self.optim_config = {
                     'lr': config.get('lr', self.base_optim_config['lr']),
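
The rewritten check is more than cosmetic: field in config and config[field] is not None tolerates keys that are absent entirely, whereas the old chained config['lr'] lookups raised KeyError for any stage config that omitted a field. A standalone sketch of the difference (a made-up config, not a real stage definition):

config = {'lr': 3e-4}   # hypothetical stage config with most keys omitted

def has_param(field: str) -> bool:
    return field in config and config[field] is not None

print(has_param('lr'))          # True
print(has_param('critic_lr'))   # False: missing key handled gracefully

try:
    _ = config['critic_lr'] is not None   # the old style of lookup
except KeyError as err:
    print('old-style lookup raised KeyError:', err)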