rxnn 0.2.42.tar.gz → 0.2.44.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {rxnn-0.2.42 → rxnn-0.2.44}/PKG-INFO +1 -1
- {rxnn-0.2.42 → rxnn-0.2.44}/pyproject.toml +1 -1
- {rxnn-0.2.42 → rxnn-0.2.44}/src/rxnn/memory/stm.py +3 -0
- {rxnn-0.2.42 → rxnn-0.2.44}/src/rxnn/rxt/models.py +3 -0
- {rxnn-0.2.42 → rxnn-0.2.44}/src/rxnn/training/models.py +6 -0
- {rxnn-0.2.42 → rxnn-0.2.44}/src/rxnn/training/mrl.py +8 -6
- {rxnn-0.2.42 → rxnn-0.2.44}/LICENSE +0 -0
- {rxnn-0.2.42 → rxnn-0.2.44}/README.md +0 -0
- {rxnn-0.2.42 → rxnn-0.2.44}/src/rxnn/.DS_Store +0 -0
- {rxnn-0.2.42 → rxnn-0.2.44}/src/rxnn/__init__.py +0 -0
- {rxnn-0.2.42 → rxnn-0.2.44}/src/rxnn/experimental/__init__.py +0 -0
- {rxnn-0.2.42 → rxnn-0.2.44}/src/rxnn/experimental/attention.py +0 -0
- {rxnn-0.2.42 → rxnn-0.2.44}/src/rxnn/experimental/models.py +0 -0
- {rxnn-0.2.42 → rxnn-0.2.44}/src/rxnn/experimental/moe.py +0 -0
- {rxnn-0.2.42 → rxnn-0.2.44}/src/rxnn/memory/__init__.py +0 -0
- {rxnn-0.2.42 → rxnn-0.2.44}/src/rxnn/memory/attention.py +0 -0
- {rxnn-0.2.42 → rxnn-0.2.44}/src/rxnn/memory/norm.py +0 -0
- {rxnn-0.2.42 → rxnn-0.2.44}/src/rxnn/rxt/__init__.py +0 -0
- {rxnn-0.2.42 → rxnn-0.2.44}/src/rxnn/training/__init__.py +0 -0
- {rxnn-0.2.42 → rxnn-0.2.44}/src/rxnn/training/base.py +0 -0
- {rxnn-0.2.42 → rxnn-0.2.44}/src/rxnn/training/bml.py +0 -0
- {rxnn-0.2.42 → rxnn-0.2.44}/src/rxnn/training/callbacks.py +0 -0
- {rxnn-0.2.42 → rxnn-0.2.44}/src/rxnn/training/dataset.py +0 -0
- {rxnn-0.2.42 → rxnn-0.2.44}/src/rxnn/training/ddp.py +0 -0
- {rxnn-0.2.42 → rxnn-0.2.44}/src/rxnn/training/reward.py +0 -0
- {rxnn-0.2.42 → rxnn-0.2.44}/src/rxnn/training/rl.py +0 -0
- {rxnn-0.2.42 → rxnn-0.2.44}/src/rxnn/training/scheduler.py +0 -0
- {rxnn-0.2.42 → rxnn-0.2.44}/src/rxnn/training/tokenizer.py +0 -0
- {rxnn-0.2.42 → rxnn-0.2.44}/src/rxnn/training/utils.py +0 -0
- {rxnn-0.2.42 → rxnn-0.2.44}/src/rxnn/transformers/__init__.py +0 -0
- {rxnn-0.2.42 → rxnn-0.2.44}/src/rxnn/transformers/attention.py +0 -0
- {rxnn-0.2.42 → rxnn-0.2.44}/src/rxnn/transformers/ff.py +0 -0
- {rxnn-0.2.42 → rxnn-0.2.44}/src/rxnn/transformers/layers.py +0 -0
- {rxnn-0.2.42 → rxnn-0.2.44}/src/rxnn/transformers/mask.py +0 -0
- {rxnn-0.2.42 → rxnn-0.2.44}/src/rxnn/transformers/models.py +0 -0
- {rxnn-0.2.42 → rxnn-0.2.44}/src/rxnn/transformers/moe.py +0 -0
- {rxnn-0.2.42 → rxnn-0.2.44}/src/rxnn/transformers/positional.py +0 -0
- {rxnn-0.2.42 → rxnn-0.2.44}/src/rxnn/transformers/sampler.py +0 -0
- {rxnn-0.2.42 → rxnn-0.2.44}/src/rxnn/utils.py +0 -0
src/rxnn/memory/stm.py

```diff
@@ -62,6 +62,9 @@ class ShortTermMemory(nn.Module):
     def reset(self, init_type: str = None):
         self.memory = self._init_tensor(init_type).to(self.memory.device)
 
+    def clone_detach_reset(self):
+        self.memory = self.memory.detach().clone()
+
     def resize(self, new_stm_size: int, init_type: str = None):
         self.stm_size = new_stm_size
         delattr(self, 'memory')
```
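The new `clone_detach_reset` keeps the current STM values but swaps the stored tensor for a detached copy, so it no longer carries autograd history. A minimal sketch of that effect on a plain tensor (illustrative only, not the package's `ShortTermMemory` class):

```python
import torch

memory = torch.zeros(4, 8)                      # stand-in for the STM buffer
update = torch.randn(4, 8, requires_grad=True)

memory = memory + update                        # memory is now part of update's graph
assert memory.grad_fn is not None

memory = memory.detach().clone()                # same values, no graph history
assert memory.grad_fn is None and not memory.requires_grad
```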
src/rxnn/rxt/models.py

```diff
@@ -301,6 +301,9 @@ class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2
     def reset_memory(self, init_type: str = None):
         self.model.stm.reset(init_type)
 
+    def clone_reset_memory(self):
+        self.model.stm.clone_detach_reset()
+
     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
         return self.model(x, attention_mask=attention_mask)
 
```
src/rxnn/training/models.py

```diff
@@ -146,6 +146,9 @@ class MrlActorModel(nn.Module):
     def reset_memory(self):
         self.memory_attention.reset_memory()
 
+    def clone_reset_memory(self):
+        self.memory_attention.clone_reset_memory()
+
     def memory_parameters(self) -> list[nn.Parameter]:
         return list(set(
             self.encoder.memory_parameters() +
@@ -168,6 +171,9 @@ class MrlActorModel(nn.Module):
             self.decoder.not_memory_parameters()
         ))
 
+    def embedding_parameters(self) -> Iterator[nn.Parameter]:
+        return self.encoder.model.embedding.parameters()
+
     def unique_parameters(self, with_embedding: bool = True):
         if with_embedding:
             return list(set(
```
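`embedding_parameters()` exposes the encoder's embedding weights as their own parameter iterator, which the trainer below uses to give embeddings a dedicated learning rate. A self-contained sketch of the same pattern on a toy module (the class and sizes are illustrative, not the actual RxT models):

```python
from typing import Iterator

import torch.nn as nn


class ToyActor(nn.Module):
    """Illustrative module that exposes its embedding parameters separately."""

    def __init__(self, vocab_size: int = 100, dim: int = 16):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, dim)
        self.head = nn.Linear(dim, vocab_size)

    def embedding_parameters(self) -> Iterator[nn.Parameter]:
        # Returning the iterator directly is enough for torch.optim parameter groups.
        return self.embedding.parameters()
```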
src/rxnn/training/mrl.py

```diff
@@ -225,7 +225,7 @@ class MRLTrainer:
     ) -> tuple[torch.optim.Optimizer, torch.optim.Optimizer]:
         if memory_lr is not None:
             optimizer = torch.optim.AdamW([
-                {'params': self.actor.
+                {'params': self.actor.embedding_parameters(), 'lr': embedding_lr},
                 {'params': self.actor.not_memory_parameters(), 'lr': lr},
                 {'params': self.actor.memory_parameters(), 'lr': memory_lr},
             ],
@@ -233,7 +233,7 @@ class MRLTrainer:
             )
         else:
             optimizer = torch.optim.AdamW([
-                {'params': self.actor.
+                {'params': self.actor.embedding_parameters(), 'lr': embedding_lr},
                 {'params': self.actor.unique_parameters(with_embedding=False), 'lr': lr},
             ],
                 weight_decay=weight_decay,
```
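Both optimizer branches now place the embedding parameters in their own AdamW group driven by `embedding_lr`. A minimal, self-contained sketch of per-group learning rates (modules and rates are placeholders, not the trainer's defaults):

```python
import torch
import torch.nn as nn

embedding = nn.Embedding(100, 16)
head = nn.Linear(16, 100)

optimizer = torch.optim.AdamW(
    [
        {'params': embedding.parameters(), 'lr': 1e-5},  # dedicated embedding group
        {'params': head.parameters(), 'lr': 3e-4},       # remaining parameters
    ],
    weight_decay=0.01,
)

for group in optimizer.param_groups:
    print(group['lr'])  # 1e-05, then 0.0003
```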
src/rxnn/training/mrl.py (continued)

```diff
@@ -646,6 +646,8 @@ class MRLTrainer:
             step_critic_values = episode_critic_values[step_idx]
             step_advantages = episode_advantages[step_idx]
 
+            self.actor.clone_reset_memory()
+
             # 7. In memory aware critic version, encode and update STM before critic update, to include its gradients in critic loss too
             if self.memory_aware_critic:
                 self.encode_and_update_stm(query, answer)
```
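Calling `clone_reset_memory()` at the start of each step's update detaches the carried-over STM from earlier computation graphs before new losses are backpropagated. A small sketch of why that detach matters, using a generic carried state rather than the actual trainer:

```python
import torch

state = torch.zeros(3)                        # carried-over state, e.g. a memory buffer
w = torch.ones(3, requires_grad=True)

# Step 1: the state becomes part of w's graph; backward() then frees that graph.
state = state + w
state.sum().backward()

# Without detaching, reusing `state` would backprop into the freed step-1 graph and
# raise "Trying to backward through the graph a second time".
state = state.detach().clone()                # what a clone-and-detach reset provides

(state + w).sum().backward()                  # fine: the new graph starts at a leaf
```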
src/rxnn/training/mrl.py (continued)

```diff
@@ -929,7 +931,7 @@ class MRLTrainer:
 
         if mode == 'update':
             params = [
-                {'params': self.actor.
+                {'params': self.actor.embedding_parameters(), 'lr': embedding_lr},
                 {'params': self.actor.encoder.not_memory_parameters(), 'lr': model_lr},
                 {'params': self.actor.encoder.memory_parameters(), 'lr': memory_lr},
                 {'params': self.actor.memory_attention_parameters(), 'lr': memory_lr},
@@ -938,7 +940,7 @@ class MRLTrainer:
             ]
         elif mode == 'fetch':
             params = [
-                {'params': self.actor.
+                {'params': self.actor.embedding_parameters(), 'lr': unfreeze_lr},
                 {'params': self.actor.encoder.not_memory_parameters(), 'lr': unfreeze_lr},
                 {'params': self.actor.encoder.memory_parameters(), 'lr': unfreeze_lr},
                 {'params': self.actor.memory_attention_parameters(), 'lr': unfreeze_lr},
@@ -947,7 +949,7 @@ class MRLTrainer:
             ]
         elif mode == 'joint':
             params = [
-                {'params': self.actor.
+                {'params': self.actor.embedding_parameters(), 'lr': unfreeze_lr},
                 {'params': self.actor.encoder.not_memory_parameters(), 'lr': unfreeze_lr},
                 {'params': self.actor.encoder.memory_parameters(), 'lr': memory_lr},
                 {'params': self.actor.memory_attention_parameters(), 'lr': memory_lr},
@@ -956,7 +958,7 @@ class MRLTrainer:
             ]
         else:
             params = [
-                {'params': self.actor.
+                {'params': self.actor.embedding_parameters(), 'lr': embedding_lr},
                 {'params': self.actor.encoder.not_memory_parameters(), 'lr': model_lr},
                 {'params': self.actor.encoder.memory_parameters(), 'lr': memory_lr},
                 {'params': self.actor.memory_attention_parameters(), 'lr': memory_lr},
```
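All four unfreeze modes now start their parameter groups with `embedding_parameters()`; only the learning rate assigned to each group differs by mode. A condensed, illustrative summary of that mapping (the helper name and grouping are assumptions made for readability, not the trainer's API):

```python
def group_lrs(mode: str, model_lr: float, memory_lr: float,
              unfreeze_lr: float, embedding_lr: float) -> dict[str, float]:
    """Learning rate per parameter group, as suggested by the hunks above."""
    if mode == 'update':
        return {'embedding': embedding_lr, 'not_memory': model_lr, 'memory': memory_lr}
    if mode == 'fetch':
        return {'embedding': unfreeze_lr, 'not_memory': unfreeze_lr, 'memory': unfreeze_lr}
    if mode == 'joint':
        return {'embedding': unfreeze_lr, 'not_memory': unfreeze_lr, 'memory': memory_lr}
    return {'embedding': embedding_lr, 'not_memory': model_lr, 'memory': memory_lr}
```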