rxnn 0.2.43__py3-none-any.whl → 0.2.44__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rxnn/memory/stm.py +3 -0
- rxnn/rxt/models.py +3 -0
- rxnn/training/models.py +3 -0
- rxnn/training/mrl.py +2 -0
- {rxnn-0.2.43.dist-info → rxnn-0.2.44.dist-info}/METADATA +1 -1
- {rxnn-0.2.43.dist-info → rxnn-0.2.44.dist-info}/RECORD +8 -8
- {rxnn-0.2.43.dist-info → rxnn-0.2.44.dist-info}/LICENSE +0 -0
- {rxnn-0.2.43.dist-info → rxnn-0.2.44.dist-info}/WHEEL +0 -0
rxnn/memory/stm.py
CHANGED
@@ -62,6 +62,9 @@ class ShortTermMemory(nn.Module):
|
|
62
62
|
def reset(self, init_type: str = None):
|
63
63
|
self.memory = self._init_tensor(init_type).to(self.memory.device)
|
64
64
|
|
65
|
+
def clone_detach_reset(self):
|
66
|
+
self.memory = self.memory.detach().clone()
|
67
|
+
|
65
68
|
def resize(self, new_stm_size: int, init_type: str = None):
|
66
69
|
self.stm_size = new_stm_size
|
67
70
|
delattr(self, 'memory')
|
rxnn/rxt/models.py
CHANGED
@@ -301,6 +301,9 @@ class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2
|
|
301
301
|
def reset_memory(self, init_type: str = None):
|
302
302
|
self.model.stm.reset(init_type)
|
303
303
|
|
304
|
+
def clone_reset_memory(self):
|
305
|
+
self.model.stm.clone_detach_reset()
|
306
|
+
|
304
307
|
def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
|
305
308
|
return self.model(x, attention_mask=attention_mask)
|
306
309
|
|
rxnn/training/models.py
CHANGED
@@ -146,6 +146,9 @@ class MrlActorModel(nn.Module):
|
|
146
146
|
def reset_memory(self):
|
147
147
|
self.memory_attention.reset_memory()
|
148
148
|
|
149
|
+
def clone_reset_memory(self):
|
150
|
+
self.memory_attention.clone_reset_memory()
|
151
|
+
|
149
152
|
def memory_parameters(self) -> list[nn.Parameter]:
|
150
153
|
return list(set(
|
151
154
|
self.encoder.memory_parameters() +
|
rxnn/training/mrl.py
CHANGED
@@ -646,6 +646,8 @@ class MRLTrainer:
|
|
646
646
|
step_critic_values = episode_critic_values[step_idx]
|
647
647
|
step_advantages = episode_advantages[step_idx]
|
648
648
|
|
649
|
+
self.actor.clone_reset_memory()
|
650
|
+
|
649
651
|
# 7. In memory aware critic version, encode and update STM before critic update, to include its gradients in critic loss too
|
650
652
|
if self.memory_aware_critic:
|
651
653
|
self.encode_and_update_stm(query, answer)
|
@@ -7,17 +7,17 @@ rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
|
|
7
7
|
rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
8
8
|
rxnn/memory/attention.py,sha256=POszZeW0QBKOh4VTDVekmZGKKwUr1Zj0FOAilTv8Vyg,2411
|
9
9
|
rxnn/memory/norm.py,sha256=E98jOQEuIOFFhlkvS8s4fFN-D4tLO6vaOqnObv1oVmA,6592
|
10
|
-
rxnn/memory/stm.py,sha256=
|
10
|
+
rxnn/memory/stm.py,sha256=jv57gsH9XW19sLbxpRDqsp1yfsii_4Ef4Ncr_ztk-i4,3937
|
11
11
|
rxnn/rxt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
12
|
-
rxnn/rxt/models.py,sha256=
|
12
|
+
rxnn/rxt/models.py,sha256=jh7TNLu_7CL0PH_T99rMZHcLezFPiZi-xnPazNyn_dU,14563
|
13
13
|
rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
14
14
|
rxnn/training/base.py,sha256=CqaArEZYOdH64nmKfx28U3GI46TzO4oNkjf_hrF23Cw,11835
|
15
15
|
rxnn/training/bml.py,sha256=hw6gLpLkGvqLzxIvBg4MvCc5r8cHpEm2RDyh7nH6CtE,16914
|
16
16
|
rxnn/training/callbacks.py,sha256=p72lbzFAmFjpcUvyy4aUB3qd53I8C6Sk5w9nQvsKgTk,35852
|
17
17
|
rxnn/training/dataset.py,sha256=7hTilFWPpqUEc6zNcMqBPjxFKxCfvTKKF3E8tVlwccQ,51250
|
18
18
|
rxnn/training/ddp.py,sha256=VsNBjn3cY-uUj8hbsW7oKvb0_ZKnXnJ2KgObm-Mr9i4,836
|
19
|
-
rxnn/training/models.py,sha256=
|
20
|
-
rxnn/training/mrl.py,sha256=
|
19
|
+
rxnn/training/models.py,sha256=tqABOt_xEcWbZNEW2I2Jt-3eyaGICK011zILwuTk6Zc,9082
|
20
|
+
rxnn/training/mrl.py,sha256=BvrwqrIIyg_EmUA5p7c6UBcfFQ0ePIcl-EHEFQqyl2E,59472
|
21
21
|
rxnn/training/reward.py,sha256=B7nerPk9eNAv2i7umtNF88tVQVwijNNrchIrEITGHKk,11623
|
22
22
|
rxnn/training/rl.py,sha256=q4NzIZAmXRHVToT13IHrPTtEikWQUvT0NO0IjApjAO8,6171
|
23
23
|
rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
|
@@ -33,7 +33,7 @@ rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
|
|
33
33
|
rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
|
34
34
|
rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
|
35
35
|
rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
|
36
|
-
rxnn-0.2.
|
37
|
-
rxnn-0.2.
|
38
|
-
rxnn-0.2.
|
39
|
-
rxnn-0.2.
|
36
|
+
rxnn-0.2.44.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
|
37
|
+
rxnn-0.2.44.dist-info/METADATA,sha256=tW2Ve4whRK2LfCxix10dTLS5Dl_0C6KhcK8FsoKq-x0,25960
|
38
|
+
rxnn-0.2.44.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
|
39
|
+
rxnn-0.2.44.dist-info/RECORD,,
|
File without changes
|
File without changes
|