rxnn 0.2.41__py3-none-any.whl → 0.2.43__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rxnn/training/models.py +3 -0
- rxnn/training/mrl.py +8 -8
- {rxnn-0.2.41.dist-info → rxnn-0.2.43.dist-info}/METADATA +1 -1
- {rxnn-0.2.41.dist-info → rxnn-0.2.43.dist-info}/RECORD +6 -6
- {rxnn-0.2.41.dist-info → rxnn-0.2.43.dist-info}/LICENSE +0 -0
- {rxnn-0.2.41.dist-info → rxnn-0.2.43.dist-info}/WHEEL +0 -0
rxnn/training/models.py
CHANGED
@@ -168,6 +168,9 @@ class MrlActorModel(nn.Module):
|
|
168
168
|
self.decoder.not_memory_parameters()
|
169
169
|
))
|
170
170
|
|
171
|
+
def embedding_parameters(self) -> Iterator[nn.Parameter]:
|
172
|
+
return self.encoder.model.embedding.parameters()
|
173
|
+
|
171
174
|
def unique_parameters(self, with_embedding: bool = True):
|
172
175
|
if with_embedding:
|
173
176
|
return list(set(
|
rxnn/training/mrl.py
CHANGED
@@ -225,7 +225,7 @@ class MRLTrainer:
|
|
225
225
|
) -> tuple[torch.optim.Optimizer, torch.optim.Optimizer]:
|
226
226
|
if memory_lr is not None:
|
227
227
|
optimizer = torch.optim.AdamW([
|
228
|
-
{'params': self.actor.
|
228
|
+
{'params': self.actor.embedding_parameters(), 'lr': embedding_lr},
|
229
229
|
{'params': self.actor.not_memory_parameters(), 'lr': lr},
|
230
230
|
{'params': self.actor.memory_parameters(), 'lr': memory_lr},
|
231
231
|
],
|
@@ -233,7 +233,7 @@ class MRLTrainer:
|
|
233
233
|
)
|
234
234
|
else:
|
235
235
|
optimizer = torch.optim.AdamW([
|
236
|
-
{'params': self.actor.
|
236
|
+
{'params': self.actor.embedding_parameters(), 'lr': embedding_lr},
|
237
237
|
{'params': self.actor.unique_parameters(with_embedding=False), 'lr': lr},
|
238
238
|
],
|
239
239
|
weight_decay=weight_decay,
|
@@ -929,7 +929,7 @@ class MRLTrainer:
|
|
929
929
|
|
930
930
|
if mode == 'update':
|
931
931
|
params = [
|
932
|
-
{'params': self.actor.
|
932
|
+
{'params': self.actor.embedding_parameters(), 'lr': embedding_lr},
|
933
933
|
{'params': self.actor.encoder.not_memory_parameters(), 'lr': model_lr},
|
934
934
|
{'params': self.actor.encoder.memory_parameters(), 'lr': memory_lr},
|
935
935
|
{'params': self.actor.memory_attention_parameters(), 'lr': memory_lr},
|
@@ -938,7 +938,7 @@ class MRLTrainer:
|
|
938
938
|
]
|
939
939
|
elif mode == 'fetch':
|
940
940
|
params = [
|
941
|
-
{'params': self.actor.
|
941
|
+
{'params': self.actor.embedding_parameters(), 'lr': unfreeze_lr},
|
942
942
|
{'params': self.actor.encoder.not_memory_parameters(), 'lr': unfreeze_lr},
|
943
943
|
{'params': self.actor.encoder.memory_parameters(), 'lr': unfreeze_lr},
|
944
944
|
{'params': self.actor.memory_attention_parameters(), 'lr': unfreeze_lr},
|
@@ -947,7 +947,7 @@ class MRLTrainer:
|
|
947
947
|
]
|
948
948
|
elif mode == 'joint':
|
949
949
|
params = [
|
950
|
-
{'params': self.actor.
|
950
|
+
{'params': self.actor.embedding_parameters(), 'lr': unfreeze_lr},
|
951
951
|
{'params': self.actor.encoder.not_memory_parameters(), 'lr': unfreeze_lr},
|
952
952
|
{'params': self.actor.encoder.memory_parameters(), 'lr': memory_lr},
|
953
953
|
{'params': self.actor.memory_attention_parameters(), 'lr': memory_lr},
|
@@ -956,7 +956,7 @@ class MRLTrainer:
|
|
956
956
|
]
|
957
957
|
else:
|
958
958
|
params = [
|
959
|
-
{'params': self.actor.
|
959
|
+
{'params': self.actor.embedding_parameters(), 'lr': embedding_lr},
|
960
960
|
{'params': self.actor.encoder.not_memory_parameters(), 'lr': model_lr},
|
961
961
|
{'params': self.actor.encoder.memory_parameters(), 'lr': memory_lr},
|
962
962
|
{'params': self.actor.memory_attention_parameters(), 'lr': memory_lr},
|
@@ -1022,7 +1022,7 @@ class MRLTrainer:
|
|
1022
1022
|
|
1023
1023
|
return (epochs, unfreeze_epoch), (random_resets, random_resets_from, random_resets_ratio)
|
1024
1024
|
|
1025
|
-
def __call__(self, curriculum_config: list[CurriculumConfig], batch_size: int):
|
1025
|
+
def __call__(self, curriculum_config: list[CurriculumConfig], batch_size: int, ddp_find_unused_parameters: bool = False):
|
1026
1026
|
"""Start Memory Reinforcement Learning Curriculum."""
|
1027
1027
|
|
1028
1028
|
# 0. Set global epoch count for all stages
|
@@ -1033,7 +1033,7 @@ class MRLTrainer:
|
|
1033
1033
|
if self.use_ddp:
|
1034
1034
|
rank, world_size = get_os_ddp_config()
|
1035
1035
|
dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)
|
1036
|
-
self.actor = DistributedDataParallel(self.actor, device_ids=[self.device.index])
|
1036
|
+
self.actor = DistributedDataParallel(self.actor, device_ids=[self.device.index], find_unused_parameters=ddp_find_unused_parameters)
|
1037
1037
|
self.critic = DistributedDataParallel(self.critic, device_ids=[self.device.index])
|
1038
1038
|
|
1039
1039
|
# 2. Init BatchSampler with actor model (we have to run it after DDP init)
|
@@ -16,8 +16,8 @@ rxnn/training/bml.py,sha256=hw6gLpLkGvqLzxIvBg4MvCc5r8cHpEm2RDyh7nH6CtE,16914
|
|
16
16
|
rxnn/training/callbacks.py,sha256=p72lbzFAmFjpcUvyy4aUB3qd53I8C6Sk5w9nQvsKgTk,35852
|
17
17
|
rxnn/training/dataset.py,sha256=7hTilFWPpqUEc6zNcMqBPjxFKxCfvTKKF3E8tVlwccQ,51250
|
18
18
|
rxnn/training/ddp.py,sha256=VsNBjn3cY-uUj8hbsW7oKvb0_ZKnXnJ2KgObm-Mr9i4,836
|
19
|
-
rxnn/training/models.py,sha256=
|
20
|
-
rxnn/training/mrl.py,sha256=
|
19
|
+
rxnn/training/models.py,sha256=7IcUsqMdQ3NldXAxHf9mlZw0rmCfSFlRK51nGe4qLAg,8996
|
20
|
+
rxnn/training/mrl.py,sha256=4_FYwfI71adrbmDLq9TaBTgpiU8lLAjuR5QYMd57WN4,59423
|
21
21
|
rxnn/training/reward.py,sha256=B7nerPk9eNAv2i7umtNF88tVQVwijNNrchIrEITGHKk,11623
|
22
22
|
rxnn/training/rl.py,sha256=q4NzIZAmXRHVToT13IHrPTtEikWQUvT0NO0IjApjAO8,6171
|
23
23
|
rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
|
@@ -33,7 +33,7 @@ rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
|
|
33
33
|
rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
|
34
34
|
rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
|
35
35
|
rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
|
36
|
-
rxnn-0.2.
|
37
|
-
rxnn-0.2.
|
38
|
-
rxnn-0.2.
|
39
|
-
rxnn-0.2.
|
36
|
+
rxnn-0.2.43.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
|
37
|
+
rxnn-0.2.43.dist-info/METADATA,sha256=5k76J0DHE9dpJAteFEPaJO8TYp-D8DrQpD26UeUerJw,25960
|
38
|
+
rxnn-0.2.43.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
|
39
|
+
rxnn-0.2.43.dist-info/RECORD,,
|
File without changes
|
File without changes
|