metacontroller-pytorch 0.0.19__py3-none-any.whl → 0.0.20__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
- metacontroller/metacontroller.py +11 -8
- {metacontroller_pytorch-0.0.19.dist-info → metacontroller_pytorch-0.0.20.dist-info}/METADATA +13 -1
- metacontroller_pytorch-0.0.20.dist-info/RECORD +6 -0
- metacontroller_pytorch-0.0.19.dist-info/RECORD +0 -6
- {metacontroller_pytorch-0.0.19.dist-info → metacontroller_pytorch-0.0.20.dist-info}/WHEEL +0 -0
- {metacontroller_pytorch-0.0.19.dist-info → metacontroller_pytorch-0.0.20.dist-info}/licenses/LICENSE +0 -0
metacontroller/metacontroller.py
CHANGED
@@ -18,7 +18,7 @@ from einops.layers.torch import Rearrange
 
 # external modules
 
-from x_transformers import Decoder
+from x_transformers import Encoder, Decoder
 from x_mlps_pytorch import Feedforwards
 from x_evolution import EvoStrategy
 
@@ -72,7 +72,11 @@ class MetaController(Module):
         decoder_expansion_factor = 2.,
         decoder_depth = 1,
         hypernetwork_low_rank = 16,
-        assoc_scan_kwargs: dict = dict()
+        assoc_scan_kwargs: dict = dict(),
+        bidirectional_temporal_encoder_kwargs: dict = dict(
+            attn_dim_head = 32,
+            heads = 8
+        )
     ):
         super().__init__()
         dim_meta = default(dim_meta_controller, dim_model)
@@ -81,9 +85,9 @@ class MetaController(Module):
 
         self.model_to_meta = Linear(dim_model, dim_meta)
 
-        # there are two phases, the first (discovery ssl phase) uses acausal with some ssm i don't really believe in - let's just use
+        # there are two phases, the first (discovery ssl phase) uses acausal with some ssm i don't really believe in - let's just use bidirectional attention as placeholder
 
-        self.
+        self.bidirectional_temporal_encoder = Encoder(dim = dim_meta, depth = 1, **bidirectional_temporal_encoder_kwargs)
 
         self.emitter = GRU(dim_meta * 2, dim_meta * 2)
         self.emitter_to_action_mean_log_var = Readout(dim_meta * 2, num_continuous = dim_latent)
@@ -122,7 +126,7 @@ class MetaController(Module):
     def discovery_parameters(self):
        return [
            *self.model_to_meta.parameters(),
-           *self.
+           *self.bidirectional_temporal_encoder.parameters(),
            *self.emitter.parameters(),
            *self.emitter_to_action_mean_log_var.parameters(),
            *self.decoder.parameters(),
@@ -157,10 +161,9 @@ class MetaController(Module):
        if discovery_phase:
            logger.warning('meta controller cache being passed back in for discovery phase, which does not make sense given bidirectional encoder')
 
-
-           temporal_compressed = reduce(temporal_compressed, '... (two d) -> ... d', 'mean', two = 2)
+           encoded_temporal = self.bidirectional_temporal_encoder(meta_embed)
 
-           proposed_action_hidden, _ = self.emitter(cat((
+           proposed_action_hidden, _ = self.emitter(cat((encoded_temporal, meta_embed), dim = -1))
            readout = self.emitter_to_action_mean_log_var
 
        else: # else internal rl phase
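For orientation, here is a minimal standalone sketch of the discovery-phase path this diff introduces: an `x_transformers` `Encoder` applies bidirectional (non-causal) attention over the meta embeddings, and its output is concatenated with the raw embeddings before being passed to the GRU emitter. The tensor sizes, the use of `torch.nn.GRU` with `batch_first = True`, and any names not shown in the diff above are illustrative assumptions, not the package's exact internals.

```python
# illustrative sketch only - mirrors the diff above, not the package's exact internals
import torch
from torch import cat, nn
from x_transformers import Encoder

dim_meta = 64                                    # assumed meta-controller dimension

# bidirectional (non-causal) attention over the meta embedding sequence,
# using the new default kwargs (attn_dim_head = 32, heads = 8)
bidirectional_temporal_encoder = Encoder(
    dim = dim_meta,
    depth = 1,
    attn_dim_head = 32,
    heads = 8
)

# the emitter consumes the encoded sequence concatenated with the raw meta embeddings
emitter = nn.GRU(dim_meta * 2, dim_meta * 2, batch_first = True)   # batch_first is an assumption for this sketch

meta_embed = torch.randn(2, 16, dim_meta)        # (batch, time, dim_meta)

encoded_temporal = bidirectional_temporal_encoder(meta_embed)

proposed_action_hidden, _ = emitter(cat((encoded_temporal, meta_embed), dim = -1))

print(proposed_action_hidden.shape)              # torch.Size([2, 16, 128])
```

The new `bidirectional_temporal_encoder_kwargs` constructor argument is forwarded to this `Encoder`, so the attention width can be tuned per use; the shipped defaults are `attn_dim_head = 32` and `heads = 8`.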
{metacontroller_pytorch-0.0.19.dist-info → metacontroller_pytorch-0.0.20.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: metacontroller-pytorch
-Version: 0.0.19
+Version: 0.0.20
 Summary: Transformer Metacontroller
 Project-URL: Homepage, https://pypi.org/project/metacontroller/
 Project-URL: Repository, https://github.com/lucidrains/metacontroller
@@ -78,3 +78,15 @@ Implementation of the MetaController proposed in [Emergent temporal abstractions
     url = {https://api.semanticscholar.org/CorpusID:279464702}
 }
 ```
+
+```bibtex
+@misc{fleuret2025freetransformer,
+    title = {The Free Transformer},
+    author = {François Fleuret},
+    year = {2025},
+    eprint = {2510.17558},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.LG},
+    url = {https://arxiv.org/abs/2510.17558},
+}
+```
metacontroller_pytorch-0.0.20.dist-info/RECORD
ADDED
@@ -0,0 +1,6 @@
+metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
+metacontroller/metacontroller.py,sha256=3QZrId9z8I6MMQ3GhEQ6Xb5LFRTFJq4EAU4JCvRmm-4,12368
+metacontroller_pytorch-0.0.20.dist-info/METADATA,sha256=5t4rDJiJzbx7m9BNsTTgO5JOnavaX-3jv31HTGuLP6A,4034
+metacontroller_pytorch-0.0.20.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+metacontroller_pytorch-0.0.20.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+metacontroller_pytorch-0.0.20.dist-info/RECORD,,
metacontroller_pytorch-0.0.19.dist-info/RECORD
DELETED
@@ -1,6 +0,0 @@
-metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
-metacontroller/metacontroller.py,sha256=GTErzikqVd8XDY8pmDnY8t4uIjbGCUd1GZBJX13peo8,12339
-metacontroller_pytorch-0.0.19.dist-info/METADATA,sha256=lX3L7J3CKoSyxvJniLdSJsCu0UMEbJTxQLEw6zzT7dY,3741
-metacontroller_pytorch-0.0.19.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
-metacontroller_pytorch-0.0.19.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-metacontroller_pytorch-0.0.19.dist-info/RECORD,,
{metacontroller_pytorch-0.0.19.dist-info → metacontroller_pytorch-0.0.20.dist-info}/WHEEL
RENAMED
File without changes
{metacontroller_pytorch-0.0.19.dist-info → metacontroller_pytorch-0.0.20.dist-info}/licenses/LICENSE
RENAMED
File without changes