rxnn 0.2.42__tar.gz → 0.2.43__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {rxnn-0.2.42 → rxnn-0.2.43}/PKG-INFO +1 -1
- {rxnn-0.2.42 → rxnn-0.2.43}/pyproject.toml +1 -1
- {rxnn-0.2.42 → rxnn-0.2.43}/src/rxnn/training/models.py +3 -0
- {rxnn-0.2.42 → rxnn-0.2.43}/src/rxnn/training/mrl.py +6 -6
- {rxnn-0.2.42 → rxnn-0.2.43}/LICENSE +0 -0
- {rxnn-0.2.42 → rxnn-0.2.43}/README.md +0 -0
- {rxnn-0.2.42 → rxnn-0.2.43}/src/rxnn/.DS_Store +0 -0
- {rxnn-0.2.42 → rxnn-0.2.43}/src/rxnn/__init__.py +0 -0
- {rxnn-0.2.42 → rxnn-0.2.43}/src/rxnn/experimental/__init__.py +0 -0
- {rxnn-0.2.42 → rxnn-0.2.43}/src/rxnn/experimental/attention.py +0 -0
- {rxnn-0.2.42 → rxnn-0.2.43}/src/rxnn/experimental/models.py +0 -0
- {rxnn-0.2.42 → rxnn-0.2.43}/src/rxnn/experimental/moe.py +0 -0
- {rxnn-0.2.42 → rxnn-0.2.43}/src/rxnn/memory/__init__.py +0 -0
- {rxnn-0.2.42 → rxnn-0.2.43}/src/rxnn/memory/attention.py +0 -0
- {rxnn-0.2.42 → rxnn-0.2.43}/src/rxnn/memory/norm.py +0 -0
- {rxnn-0.2.42 → rxnn-0.2.43}/src/rxnn/memory/stm.py +0 -0
- {rxnn-0.2.42 → rxnn-0.2.43}/src/rxnn/rxt/__init__.py +0 -0
- {rxnn-0.2.42 → rxnn-0.2.43}/src/rxnn/rxt/models.py +0 -0
- {rxnn-0.2.42 → rxnn-0.2.43}/src/rxnn/training/__init__.py +0 -0
- {rxnn-0.2.42 → rxnn-0.2.43}/src/rxnn/training/base.py +0 -0
- {rxnn-0.2.42 → rxnn-0.2.43}/src/rxnn/training/bml.py +0 -0
- {rxnn-0.2.42 → rxnn-0.2.43}/src/rxnn/training/callbacks.py +0 -0
- {rxnn-0.2.42 → rxnn-0.2.43}/src/rxnn/training/dataset.py +0 -0
- {rxnn-0.2.42 → rxnn-0.2.43}/src/rxnn/training/ddp.py +0 -0
- {rxnn-0.2.42 → rxnn-0.2.43}/src/rxnn/training/reward.py +0 -0
- {rxnn-0.2.42 → rxnn-0.2.43}/src/rxnn/training/rl.py +0 -0
- {rxnn-0.2.42 → rxnn-0.2.43}/src/rxnn/training/scheduler.py +0 -0
- {rxnn-0.2.42 → rxnn-0.2.43}/src/rxnn/training/tokenizer.py +0 -0
- {rxnn-0.2.42 → rxnn-0.2.43}/src/rxnn/training/utils.py +0 -0
- {rxnn-0.2.42 → rxnn-0.2.43}/src/rxnn/transformers/__init__.py +0 -0
- {rxnn-0.2.42 → rxnn-0.2.43}/src/rxnn/transformers/attention.py +0 -0
- {rxnn-0.2.42 → rxnn-0.2.43}/src/rxnn/transformers/ff.py +0 -0
- {rxnn-0.2.42 → rxnn-0.2.43}/src/rxnn/transformers/layers.py +0 -0
- {rxnn-0.2.42 → rxnn-0.2.43}/src/rxnn/transformers/mask.py +0 -0
- {rxnn-0.2.42 → rxnn-0.2.43}/src/rxnn/transformers/models.py +0 -0
- {rxnn-0.2.42 → rxnn-0.2.43}/src/rxnn/transformers/moe.py +0 -0
- {rxnn-0.2.42 → rxnn-0.2.43}/src/rxnn/transformers/positional.py +0 -0
- {rxnn-0.2.42 → rxnn-0.2.43}/src/rxnn/transformers/sampler.py +0 -0
- {rxnn-0.2.42 → rxnn-0.2.43}/src/rxnn/utils.py +0 -0
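Only four files differ between the two archives: the version bump in PKG-INFO and pyproject.toml, a new `embedding_parameters()` helper on `MrlActorModel` in training/models.py, and six one-line substitutions in training/mrl.py that route `MRLTrainer`'s optimizer parameter groups through that helper so the embedding can be trained at its own learning rate.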
{rxnn-0.2.42 → rxnn-0.2.43}/src/rxnn/training/models.py

@@ -168,6 +168,9 @@ class MrlActorModel(nn.Module):
             self.decoder.not_memory_parameters()
         ))
 
+    def embedding_parameters(self) -> Iterator[nn.Parameter]:
+        return self.encoder.model.embedding.parameters()
+
     def unique_parameters(self, with_embedding: bool = True):
         if with_embedding:
             return list(set(
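The trainer changes below lean on a standard PyTorch feature: each parameter group passed to an optimizer can carry its own `'lr'`, overriding the constructor default. A self-contained demonstration of just that mechanism (the model and rates here are illustrative, not from rxnn):

```python
import torch
import torch.nn as nn

# Each parameter group may override the default learning rate; the rxnn
# change uses this to train the embedding at its own embedding_lr.
model = nn.ModuleDict({
    'embedding': nn.Embedding(100, 32),
    'head': nn.Linear(32, 2),
})

optimizer = torch.optim.AdamW(
    [
        {'params': model['embedding'].parameters(), 'lr': 1e-5},  # embedding group
        {'params': model['head'].parameters()},  # falls back to the default lr
    ],
    lr=3e-4,
    weight_decay=0.01,
)

print([group['lr'] for group in optimizer.param_groups])  # [1e-05, 0.0003]
```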
{rxnn-0.2.42 → rxnn-0.2.43}/src/rxnn/training/mrl.py

@@ -225,7 +225,7 @@ class MRLTrainer:
     ) -> tuple[torch.optim.Optimizer, torch.optim.Optimizer]:
         if memory_lr is not None:
             optimizer = torch.optim.AdamW([
-                {'params': self.actor.
+                {'params': self.actor.embedding_parameters(), 'lr': embedding_lr},
                 {'params': self.actor.not_memory_parameters(), 'lr': lr},
                 {'params': self.actor.memory_parameters(), 'lr': memory_lr},
             ],
@@ -233,7 +233,7 @@ class MRLTrainer:
             )
         else:
             optimizer = torch.optim.AdamW([
-                {'params': self.actor.
+                {'params': self.actor.embedding_parameters(), 'lr': embedding_lr},
                 {'params': self.actor.unique_parameters(with_embedding=False), 'lr': lr},
             ],
                 weight_decay=weight_decay,
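Condensed, both branches of `MRLTrainer`'s optimizer setup now start from the same `embedding_parameters()` group. A minimal sketch of the resulting construction, using only names visible in the hunks above; the function wrapper, the default `weight_decay`, and anything the hunk context truncates (including the second optimizer of the returned tuple) are assumptions:

```python
import torch

def build_actor_optimizer(actor, lr, embedding_lr, memory_lr=None, weight_decay=0.01):
    # Mirrors the two branches above: with a dedicated memory_lr, memory and
    # non-memory parameters get separate groups; otherwise the deduplicated
    # unique_parameters() (minus the embedding) share one rate.
    if memory_lr is not None:
        groups = [
            {'params': actor.embedding_parameters(), 'lr': embedding_lr},
            {'params': actor.not_memory_parameters(), 'lr': lr},
            {'params': actor.memory_parameters(), 'lr': memory_lr},
        ]
    else:
        groups = [
            {'params': actor.embedding_parameters(), 'lr': embedding_lr},
            {'params': actor.unique_parameters(with_embedding=False), 'lr': lr},
        ]
    return torch.optim.AdamW(groups, weight_decay=weight_decay)
```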
@@ -929,7 +929,7 @@ class MRLTrainer:
 
         if mode == 'update':
             params = [
-                {'params': self.actor.
+                {'params': self.actor.embedding_parameters(), 'lr': embedding_lr},
                 {'params': self.actor.encoder.not_memory_parameters(), 'lr': model_lr},
                 {'params': self.actor.encoder.memory_parameters(), 'lr': memory_lr},
                 {'params': self.actor.memory_attention_parameters(), 'lr': memory_lr},
@@ -938,7 +938,7 @@ class MRLTrainer:
             ]
         elif mode == 'fetch':
             params = [
-                {'params': self.actor.
+                {'params': self.actor.embedding_parameters(), 'lr': unfreeze_lr},
                 {'params': self.actor.encoder.not_memory_parameters(), 'lr': unfreeze_lr},
                 {'params': self.actor.encoder.memory_parameters(), 'lr': unfreeze_lr},
                 {'params': self.actor.memory_attention_parameters(), 'lr': unfreeze_lr},
@@ -947,7 +947,7 @@ class MRLTrainer:
             ]
         elif mode == 'joint':
             params = [
-                {'params': self.actor.
+                {'params': self.actor.embedding_parameters(), 'lr': unfreeze_lr},
                 {'params': self.actor.encoder.not_memory_parameters(), 'lr': unfreeze_lr},
                 {'params': self.actor.encoder.memory_parameters(), 'lr': memory_lr},
                 {'params': self.actor.memory_attention_parameters(), 'lr': memory_lr},
@@ -956,7 +956,7 @@ class MRLTrainer:
             ]
         else:
             params = [
-                {'params': self.actor.
+                {'params': self.actor.embedding_parameters(), 'lr': embedding_lr},
                 {'params': self.actor.encoder.not_memory_parameters(), 'lr': model_lr},
                 {'params': self.actor.encoder.memory_parameters(), 'lr': memory_lr},
                 {'params': self.actor.memory_attention_parameters(), 'lr': memory_lr},
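The last four hunks (around line 929) change the same thing in each branch of a mode switch. Read together, they imply the per-mode learning-rate assignment sketched below; this is a reading of the diff rather than the trainer's actual code, and the parameter groups the hunk context cuts off after `memory_attention_parameters()` are likewise omitted:

```python
def unfreeze_param_groups(actor, mode, embedding_lr, model_lr, memory_lr, unfreeze_lr):
    # Per-mode rates as shown in the hunks above ('update' / 'fetch' /
    # 'joint' / fallback). Only the first four groups are visible in the diff.
    if mode == 'update':
        emb, enc, mem = embedding_lr, model_lr, memory_lr
    elif mode == 'fetch':
        emb, enc, mem = unfreeze_lr, unfreeze_lr, unfreeze_lr
    elif mode == 'joint':
        emb, enc, mem = unfreeze_lr, unfreeze_lr, memory_lr
    else:
        emb, enc, mem = embedding_lr, model_lr, memory_lr
    return [
        {'params': actor.embedding_parameters(), 'lr': emb},
        {'params': actor.encoder.not_memory_parameters(), 'lr': enc},
        {'params': actor.encoder.memory_parameters(), 'lr': mem},
        {'params': actor.memory_attention_parameters(), 'lr': mem},
    ]
```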