rxnn 0.2.42__py3-none-any.whl → 0.2.43__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rxnn/training/models.py CHANGED
@@ -168,6 +168,9 @@ class MrlActorModel(nn.Module):
168
168
  self.decoder.not_memory_parameters()
169
169
  ))
170
170
 
171
+ def embedding_parameters(self) -> Iterator[nn.Parameter]:
172
+ return self.encoder.model.embedding.parameters()
173
+
171
174
  def unique_parameters(self, with_embedding: bool = True):
172
175
  if with_embedding:
173
176
  return list(set(
rxnn/training/mrl.py CHANGED
@@ -225,7 +225,7 @@ class MRLTrainer:
225
225
  ) -> tuple[torch.optim.Optimizer, torch.optim.Optimizer]:
226
226
  if memory_lr is not None:
227
227
  optimizer = torch.optim.AdamW([
228
- {'params': self.actor.encoder.embedding.parameters(), 'lr': embedding_lr},
228
+ {'params': self.actor.embedding_parameters(), 'lr': embedding_lr},
229
229
  {'params': self.actor.not_memory_parameters(), 'lr': lr},
230
230
  {'params': self.actor.memory_parameters(), 'lr': memory_lr},
231
231
  ],
@@ -233,7 +233,7 @@ class MRLTrainer:
233
233
  )
234
234
  else:
235
235
  optimizer = torch.optim.AdamW([
236
- {'params': self.actor.encoder.embedding.parameters(), 'lr': embedding_lr},
236
+ {'params': self.actor.embedding_parameters(), 'lr': embedding_lr},
237
237
  {'params': self.actor.unique_parameters(with_embedding=False), 'lr': lr},
238
238
  ],
239
239
  weight_decay=weight_decay,
@@ -929,7 +929,7 @@ class MRLTrainer:
929
929
 
930
930
  if mode == 'update':
931
931
  params = [
932
- {'params': self.actor.encoder.embedding.parameters(), 'lr': embedding_lr},
932
+ {'params': self.actor.embedding_parameters(), 'lr': embedding_lr},
933
933
  {'params': self.actor.encoder.not_memory_parameters(), 'lr': model_lr},
934
934
  {'params': self.actor.encoder.memory_parameters(), 'lr': memory_lr},
935
935
  {'params': self.actor.memory_attention_parameters(), 'lr': memory_lr},
@@ -938,7 +938,7 @@ class MRLTrainer:
938
938
  ]
939
939
  elif mode == 'fetch':
940
940
  params = [
941
- {'params': self.actor.encoder.embedding.parameters(), 'lr': unfreeze_lr},
941
+ {'params': self.actor.embedding_parameters(), 'lr': unfreeze_lr},
942
942
  {'params': self.actor.encoder.not_memory_parameters(), 'lr': unfreeze_lr},
943
943
  {'params': self.actor.encoder.memory_parameters(), 'lr': unfreeze_lr},
944
944
  {'params': self.actor.memory_attention_parameters(), 'lr': unfreeze_lr},
@@ -947,7 +947,7 @@ class MRLTrainer:
947
947
  ]
948
948
  elif mode == 'joint':
949
949
  params = [
950
- {'params': self.actor.encoder.embedding.parameters(), 'lr': unfreeze_lr},
950
+ {'params': self.actor.embedding_parameters(), 'lr': unfreeze_lr},
951
951
  {'params': self.actor.encoder.not_memory_parameters(), 'lr': unfreeze_lr},
952
952
  {'params': self.actor.encoder.memory_parameters(), 'lr': memory_lr},
953
953
  {'params': self.actor.memory_attention_parameters(), 'lr': memory_lr},
@@ -956,7 +956,7 @@ class MRLTrainer:
956
956
  ]
957
957
  else:
958
958
  params = [
959
- {'params': self.actor.encoder.embedding.parameters(), 'lr': embedding_lr},
959
+ {'params': self.actor.embedding_parameters(), 'lr': embedding_lr},
960
960
  {'params': self.actor.encoder.not_memory_parameters(), 'lr': model_lr},
961
961
  {'params': self.actor.encoder.memory_parameters(), 'lr': memory_lr},
962
962
  {'params': self.actor.memory_attention_parameters(), 'lr': memory_lr},
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: rxnn
3
- Version: 0.2.42
3
+ Version: 0.2.43
4
4
  Summary: RxNN: Reactive Neural Networks Platform
5
5
  License: Apache-2.0
6
6
  Keywords: deep-learning,ai,machine-learning
@@ -16,8 +16,8 @@ rxnn/training/bml.py,sha256=hw6gLpLkGvqLzxIvBg4MvCc5r8cHpEm2RDyh7nH6CtE,16914
16
16
  rxnn/training/callbacks.py,sha256=p72lbzFAmFjpcUvyy4aUB3qd53I8C6Sk5w9nQvsKgTk,35852
17
17
  rxnn/training/dataset.py,sha256=7hTilFWPpqUEc6zNcMqBPjxFKxCfvTKKF3E8tVlwccQ,51250
18
18
  rxnn/training/ddp.py,sha256=VsNBjn3cY-uUj8hbsW7oKvb0_ZKnXnJ2KgObm-Mr9i4,836
19
- rxnn/training/models.py,sha256=4hDH-R9l1lNvBMW_CGG_QgmCVrkyG7Lyo40PPzvkovQ,8876
20
- rxnn/training/mrl.py,sha256=xHH-tcmvwmwV5wwiAa3DaXLuF5OipmVDDYxLL5wOYVM,59471
19
+ rxnn/training/models.py,sha256=7IcUsqMdQ3NldXAxHf9mlZw0rmCfSFlRK51nGe4qLAg,8996
20
+ rxnn/training/mrl.py,sha256=4_FYwfI71adrbmDLq9TaBTgpiU8lLAjuR5QYMd57WN4,59423
21
21
  rxnn/training/reward.py,sha256=B7nerPk9eNAv2i7umtNF88tVQVwijNNrchIrEITGHKk,11623
22
22
  rxnn/training/rl.py,sha256=q4NzIZAmXRHVToT13IHrPTtEikWQUvT0NO0IjApjAO8,6171
23
23
  rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
@@ -33,7 +33,7 @@ rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
33
33
  rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
34
34
  rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
35
35
  rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
36
- rxnn-0.2.42.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
37
- rxnn-0.2.42.dist-info/METADATA,sha256=5ZND9je7xzC5qCXQmyFB0XKedtqe5gicSqZnRui1K0Q,25960
38
- rxnn-0.2.42.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
39
- rxnn-0.2.42.dist-info/RECORD,,
36
+ rxnn-0.2.43.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
37
+ rxnn-0.2.43.dist-info/METADATA,sha256=5k76J0DHE9dpJAteFEPaJO8TYp-D8DrQpD26UeUerJw,25960
38
+ rxnn-0.2.43.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
39
+ rxnn-0.2.43.dist-info/RECORD,,
File without changes
File without changes