rxnn 0.2.43__py3-none-any.whl → 0.2.45__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rxnn/memory/stm.py CHANGED
@@ -62,6 +62,9 @@ class ShortTermMemory(nn.Module):
62
62
  def reset(self, init_type: str = None):
63
63
  self.memory = self._init_tensor(init_type).to(self.memory.device)
64
64
 
65
+ def clone_detach_reset(self):
66
+ self.memory = self.memory.detach().clone()
67
+
65
68
  def resize(self, new_stm_size: int, init_type: str = None):
66
69
  self.stm_size = new_stm_size
67
70
  delattr(self, 'memory')
rxnn/rxt/models.py CHANGED
@@ -301,6 +301,9 @@ class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2
301
301
  def reset_memory(self, init_type: str = None):
302
302
  self.model.stm.reset(init_type)
303
303
 
304
+ def clone_reset_memory(self):
305
+ self.model.stm.clone_detach_reset()
306
+
304
307
  def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
305
308
  return self.model(x, attention_mask=attention_mask)
306
309
 
rxnn/training/models.py CHANGED
@@ -146,6 +146,9 @@ class MrlActorModel(nn.Module):
146
146
  def reset_memory(self):
147
147
  self.memory_attention.reset_memory()
148
148
 
149
+ def clone_reset_memory(self):
150
+ self.memory_attention.clone_reset_memory()
151
+
149
152
  def memory_parameters(self) -> list[nn.Parameter]:
150
153
  return list(set(
151
154
  self.encoder.memory_parameters() +
rxnn/training/mrl.py CHANGED
@@ -91,6 +91,7 @@ class MrlTrajectoryEpisode(TypedDict):
91
91
  reset_stm: bool
92
92
  steps: list[MrlTrajectoryStep]
93
93
 
94
+ OptimField: TypeAlias = Literal['lr', 'critic_lr', 'weight_decay', 'critic_weight_decay', 'separate_memory_lr', 'memory_lr']
94
95
 
95
96
  class MRLTrainer:
96
97
  def __init__(
@@ -646,6 +647,8 @@ class MRLTrainer:
646
647
  step_critic_values = episode_critic_values[step_idx]
647
648
  step_advantages = episode_advantages[step_idx]
648
649
 
650
+ self.actor.clone_reset_memory()
651
+
649
652
  # 7. In memory aware critic version, encode and update STM before critic update, to include its gradients in critic loss too
650
653
  if self.memory_aware_critic:
651
654
  self.encode_and_update_stm(query, answer)
@@ -979,8 +982,19 @@ class MRLTrainer:
979
982
  self.reward = config.get('reward_model', self.shared_reward_model) # MRL Reward Model for curriculum stage
980
983
  self.update_epochs = config.get('update_epochs', self.shared_update_epochs) # Internal update epochs
981
984
  self.freeze_embeddings = config.get('freeze_embeddings', self.shared_freeze_embeddings)
982
- if config['lr'] is not None or config['critic_lr'] is not None or config['weight_decay'] is not None or config[
983
- 'critic_weight_decay'] is not None or (config['separate_memory_lr'] and config['memory_lr'] is not None):
985
+
986
+
987
+
988
+ def has_param(field: OptimField) -> bool:
989
+ return field in config and config[field] is not None
990
+
991
+ optim_params: list[OptimField] = ['lr', 'critic_lr', 'weight_decay', 'critic_weight_decay']
992
+
993
+ has_any_optim_param = any(
994
+ has_param(field) for field in optim_params
995
+ ) or (has_param('separate_memory_lr') and config['separate_memory_lr'] and has_param('memory_lr'))
996
+
997
+ if has_any_optim_param:
984
998
  if config.get('separate_memory_lr', False):
985
999
  self.optim_config = {
986
1000
  'lr': config.get('lr', self.base_optim_config['lr']),
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: rxnn
3
- Version: 0.2.43
3
+ Version: 0.2.45
4
4
  Summary: RxNN: Reactive Neural Networks Platform
5
5
  License: Apache-2.0
6
6
  Keywords: deep-learning,ai,machine-learning
@@ -7,17 +7,17 @@ rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
7
7
  rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
8
8
  rxnn/memory/attention.py,sha256=POszZeW0QBKOh4VTDVekmZGKKwUr1Zj0FOAilTv8Vyg,2411
9
9
  rxnn/memory/norm.py,sha256=E98jOQEuIOFFhlkvS8s4fFN-D4tLO6vaOqnObv1oVmA,6592
10
- rxnn/memory/stm.py,sha256=SSfc-RL9FE-RLkmOEkLB-9Rb00ZXbMLbsAEPdpIW89o,3851
10
+ rxnn/memory/stm.py,sha256=jv57gsH9XW19sLbxpRDqsp1yfsii_4Ef4Ncr_ztk-i4,3937
11
11
  rxnn/rxt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
12
- rxnn/rxt/models.py,sha256=lRn7NRIAAeCxr8hoIXanhaD-cGwVwA23hBdIQpBK6kc,14484
12
+ rxnn/rxt/models.py,sha256=jh7TNLu_7CL0PH_T99rMZHcLezFPiZi-xnPazNyn_dU,14563
13
13
  rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
14
14
  rxnn/training/base.py,sha256=CqaArEZYOdH64nmKfx28U3GI46TzO4oNkjf_hrF23Cw,11835
15
15
  rxnn/training/bml.py,sha256=hw6gLpLkGvqLzxIvBg4MvCc5r8cHpEm2RDyh7nH6CtE,16914
16
16
  rxnn/training/callbacks.py,sha256=p72lbzFAmFjpcUvyy4aUB3qd53I8C6Sk5w9nQvsKgTk,35852
17
17
  rxnn/training/dataset.py,sha256=7hTilFWPpqUEc6zNcMqBPjxFKxCfvTKKF3E8tVlwccQ,51250
18
18
  rxnn/training/ddp.py,sha256=VsNBjn3cY-uUj8hbsW7oKvb0_ZKnXnJ2KgObm-Mr9i4,836
19
- rxnn/training/models.py,sha256=7IcUsqMdQ3NldXAxHf9mlZw0rmCfSFlRK51nGe4qLAg,8996
20
- rxnn/training/mrl.py,sha256=4_FYwfI71adrbmDLq9TaBTgpiU8lLAjuR5QYMd57WN4,59423
19
+ rxnn/training/models.py,sha256=tqABOt_xEcWbZNEW2I2Jt-3eyaGICK011zILwuTk6Zc,9082
20
+ rxnn/training/mrl.py,sha256=L4G7xSPlxsymvNhvsSloCpaqYjOXxEm7GmKilM_Ojvc,59809
21
21
  rxnn/training/reward.py,sha256=B7nerPk9eNAv2i7umtNF88tVQVwijNNrchIrEITGHKk,11623
22
22
  rxnn/training/rl.py,sha256=q4NzIZAmXRHVToT13IHrPTtEikWQUvT0NO0IjApjAO8,6171
23
23
  rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
@@ -33,7 +33,7 @@ rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
33
33
  rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
34
34
  rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
35
35
  rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
36
- rxnn-0.2.43.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
37
- rxnn-0.2.43.dist-info/METADATA,sha256=5k76J0DHE9dpJAteFEPaJO8TYp-D8DrQpD26UeUerJw,25960
38
- rxnn-0.2.43.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
39
- rxnn-0.2.43.dist-info/RECORD,,
36
+ rxnn-0.2.45.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
37
+ rxnn-0.2.45.dist-info/METADATA,sha256=g8IqXAR2tXEyaNQOs--IPFFtSOnrWe4oouPK1PQBITw,25960
38
+ rxnn-0.2.45.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
39
+ rxnn-0.2.45.dist-info/RECORD,,
File without changes
File without changes