metacontroller-pytorch 0.0.1__py3-none-any.whl
@@ -0,0 +1 @@ metacontroller/__init__.py
from metacontroller.metacontroller import MetaController

@@ -0,0 +1,257 @@ metacontroller/metacontroller.py
from __future__ import annotations
from functools import partial

import torch
from torch import nn, cat, stack, tensor
from torch.nn import Module, GRU, Linear, Identity
import torch.nn.functional as F

# einops

import einx
from einops import einsum, rearrange, repeat, reduce
from einops.layers.torch import Rearrange

# external modules

from x_transformers import Decoder
from x_mlps_pytorch import Feedforwards
from x_evolution import EvoStrategy

from discrete_continuous_embed_readout import Embed, Readout

from assoc_scan import AssocScan

# constants

LinearNoBias = partial(Linear, bias = False)

GRU = partial(GRU, batch_first = True)

# helper functions

def exists(v):
    return v is not None

def identity(t):
    return t

def default(*args):
    for arg in args:
        if exists(arg):
            return arg
    return None

# meta controller

class MetaController(Module):
    def __init__(
        self,
        dim_latent,
        *,
        decoder_expansion_factor = 2.,
        decoder_depth = 1,
        hypernetwork_low_rank = 16,
        assoc_scan_kwargs: dict = dict()
    ):
        super().__init__()

        # there are two phases. the first (discovery ssl phase) is acausal - just use a bidirectional GRU as a placeholder for the ssm

        self.bidirectional_temporal_compressor = GRU(dim_latent, dim_latent, bidirectional = True) # revisit naming

        self.emitter = GRU(dim_latent * 2, dim_latent * 2)
        self.emitter_to_action_mean_log_var = LinearNoBias(dim_latent * 2, dim_latent * 2)

        # the internal rl phase substitutes the acausal compressor + emitter with a causal ssm

        self.action_proposer = GRU(dim_latent, dim_latent)
        self.action_proposer_mean_log_var = LinearNoBias(dim_latent, dim_latent * 2)

        # switching unit

        self.switching_unit = GRU(dim_latent, dim_latent)
        self.to_switching_unit_beta = nn.Linear(dim_latent, 1, bias = False)

        self.switch_gating = AssocScan(**assoc_scan_kwargs)

        # decoder - emits two rank-r factors per position for a low-rank hypernetwork

        assert hypernetwork_low_rank < dim_latent

        dim_decoder_hidden = int(dim_latent * decoder_expansion_factor)

        self.decoder = Feedforwards(
            dim_in = dim_latent,
            dim = dim_decoder_hidden,
            depth = decoder_depth,
            dim_out = 2 * hypernetwork_low_rank * dim_latent
        )

        self.to_hyper_network_weights = Rearrange('... (two d r) -> two ... d r', two = 2, r = hypernetwork_low_rank)

        self.register_buffer('zero', tensor(0.), persistent = False)
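
    # parameter groups - each phase trains its own proposal network, while the decoder and switch gating are shared across phases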

    def discovery_parameters(self):
        return [
            *self.bidirectional_temporal_compressor.parameters(),
            *self.emitter.parameters(),
            *self.emitter_to_action_mean_log_var.parameters(),
            *self.decoder.parameters(),
            *self.switch_gating.parameters()
        ]

    def internal_rl_parameters(self):
        return [
            *self.action_proposer.parameters(),
            *self.action_proposer_mean_log_var.parameters(),
            *self.decoder.parameters(),
            *self.switch_gating.parameters()
        ]

    def forward(
        self,
        residual_stream,
        discovery_phase = False
    ):

        if discovery_phase:
            temporal_compressed, _ = self.bidirectional_temporal_compressor(residual_stream)
            temporal_compressed = reduce(temporal_compressed, '... (two d) -> ... d', 'mean', two = 2)

            proposed_action_hidden, _ = self.emitter(cat((temporal_compressed, residual_stream), dim = -1))
            proposed_action = self.emitter_to_action_mean_log_var(proposed_action_hidden)

        else: # else internal rl phase
            proposed_action_hidden, _ = self.action_proposer(residual_stream)
            proposed_action = self.action_proposer_mean_log_var(proposed_action_hidden)

        # sample from the gaussian as the action from the meta controller

        mean, log_var = proposed_action.chunk(2, dim = -1)
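
        # reparameterization trick - the sample stays differentiable through mean and std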

        std = (0.5 * log_var).exp()
        sampled_action_intents = mean + torch.randn_like(mean) * std

        # need to encourage normal distribution

        vae_kl_loss = self.zero
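
        # closed-form kl divergence to the unit gaussian: 0.5 * sum(var + mean^2 - log var - 1)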

        if discovery_phase:
            vae_kl_loss = (0.5 * (
                log_var.exp()
                + mean.square()
                - log_var
                - 1.
            )).sum(dim = -1).mean()

        # switching unit timer

        batch, _, dim = sampled_action_intents.shape

        switching_unit_gru_out, switching_unit_gru_hidden = self.switching_unit(residual_stream)

        switch_beta = self.to_switching_unit_beta(switching_unit_gru_out).sigmoid()

        action_intent_for_gating = rearrange(sampled_action_intents, 'b n d -> (b d) n')
        switch_beta = repeat(switch_beta, 'b n 1 -> (b d) n', d = dim)

        # assoc scan takes (gates, inputs), realizing h_t = beta_t * h_{t-1} + (1 - beta_t) * x_t
        # the running intent is held until the switching unit lowers beta

        forget = 1. - switch_beta
        gated_action_intent = self.switch_gating(switch_beta, action_intent_for_gating * forget)

        gated_action_intent = rearrange(gated_action_intent, '(b d) n -> b n d', b = batch)

        # decoder

        decoder_out = self.decoder(gated_action_intent)

        # two rank-r factors per position, whose outer product reconstitutes a full (dim_latent x dim_latent) hypernetwork weight

        w1, w2 = self.to_hyper_network_weights(decoder_out)
        hypernetwork_weight = einsum(w1, w2, '... i r, ... j r -> ... i j')

        # generating the residual stream controlling signal
        # the control is additive, so a zeroed decoder output leaves the frozen transformer's computation untouched

        control_signal = einsum(gated_action_intent, hypernetwork_weight, '... d1, ... d1 d2 -> ... d2')

        modified_residual_stream = residual_stream + control_signal

        return modified_residual_stream, vae_kl_loss

# main transformer, which is subsumed into the environment after behavioral cloning

class Transformer(Module):
    def __init__(
        self,
        dim,
        *,
        embed: Embed | dict,
        lower_body: Decoder | dict,
        upper_body: Decoder | dict,
        readout: Readout | dict,
        meta_controller: MetaController | None = None
    ):
        super().__init__()

        if isinstance(embed, dict):
            embed = Embed(dim = dim, **embed)

        if isinstance(lower_body, dict):
            lower_body = Decoder(dim = dim, **lower_body)

        if isinstance(upper_body, dict):
            upper_body = Decoder(dim = dim, **upper_body)

        if isinstance(readout, dict):
            readout = Readout(dim = dim, **readout)

        self.embed = embed
        self.lower_body = lower_body
        self.upper_body = upper_body
        self.readout = readout

        # meta controller

        self.meta_controller = meta_controller
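
    # the meta controller is optional, so the transformer can first be behaviorally cloned on its own before being subsumed into the environment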

    def evolve(
        self,
        environment,
        **kwargs
    ):
        assert exists(self.meta_controller), '`meta_controller` must be defined on init for evolutionary strategies to be straightforwardly applied'

        evo_strat = EvoStrategy(
            self,
            environment = environment,
            params_to_optimize = self.meta_controller.internal_rl_parameters(),
            **kwargs
        )

        evo_strat()
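
    # evolution optimizes only the meta controller's internal rl parameters - the behaviorally cloned transformer weights stay frozen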

    def forward(
        self,
        ids,
        meta_controller: Module | None = None,
        discovery_phase = False,
        return_latents = False
    ):
        meta_controller = default(meta_controller, self.meta_controller)

        embed = self.embed(ids)

        residual_stream = self.lower_body(embed)

        # meta controller acts on residual stream here
        # without one, the residual stream passes through unchanged with a zero auxiliary loss

        if exists(meta_controller):
            modified_residual_stream, vae_aux_loss = meta_controller(residual_stream, discovery_phase = discovery_phase)
        else:
            modified_residual_stream = residual_stream
            vae_aux_loss = residual_stream.new_zeros(())

        # modified residual stream sent back

        attended = self.upper_body(modified_residual_stream)

        dist_params = self.readout(attended)

        if not return_latents:
            return dist_params

        # assumed intent: the pre- and post-control residual streams serve as the latents,
        # returned along with the vae kl auxiliary loss needed for the discovery phase

        latents = (residual_stream, modified_residual_stream)

        return dist_params, latents, vae_aux_loss

@@ -0,0 +1,79 @@ metacontroller_pytorch-0.0.1.dist-info/METADATA
Metadata-Version: 2.4
Name: metacontroller-pytorch
Version: 0.0.1
Summary: Transformer Metacontroller
Project-URL: Homepage, https://pypi.org/project/metacontroller/
Project-URL: Repository, https://github.com/lucidrains/metacontroller
Author-email: Phil Wang <lucidrains@gmail.com>
License: MIT License

Copyright (c) 2025 Phil Wang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
License-File: LICENSE
Keywords: artificial intelligence,deep learning,hierarchical reinforcement learning,latent steering
Classifier: Development Status :: 4 - Beta
Classifier: Intended Audience :: Developers
Classifier: License :: OSI Approved :: MIT License
Classifier: Programming Language :: Python :: 3.9
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Requires-Python: >=3.9
Requires-Dist: assoc-scan
Requires-Dist: discrete-continuous-embed-readout>=0.1.11
Requires-Dist: einops>=0.8.1
Requires-Dist: einx>=0.3.0
Requires-Dist: torch>=2.5
Requires-Dist: x-evolution>=0.1.23
Requires-Dist: x-mlps-pytorch
Requires-Dist: x-transformers
Provides-Extra: examples
Provides-Extra: test
Requires-Dist: pytest; extra == 'test'
Description-Content-Type: text/markdown

<img src="./fig1.png" width="400px"></img>

## metacontroller (wip)

Implementation of the MetaController proposed in [Emergent temporal abstractions in autoregressive models enable hierarchical reinforcement learning](https://arxiv.org/abs/2512.20605)
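
A minimal usage sketch for the discovery phase (the dimensions and the random residual stream below are illustrative stand-ins for latents recorded from a behaviorally cloned transformer, not part of the package API):

```python
import torch
from metacontroller import MetaController

meta_controller = MetaController(dim_latent = 256)

# stand-in for the residual stream recorded from the lower body of the frozen transformer

residual_stream = torch.randn(1, 128, 256)

# discovery ssl phase - returns the controlled residual stream and the vae kl auxiliary loss

controlled_residual_stream, vae_kl_loss = meta_controller(residual_stream, discovery_phase = True)

assert controlled_residual_stream.shape == residual_stream.shape
```

In the internal rl phase, the same call with `discovery_phase = False` routes through the causal action proposer instead, and `Transformer.evolve` optimizes only the `internal_rl_parameters()` via `x-evolution`.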

## Citations

```bibtex
@misc{kobayashi2025emergenttemporalabstractionsautoregressive,
    title   = {Emergent temporal abstractions in autoregressive models enable hierarchical reinforcement learning},
    author  = {Seijin Kobayashi and Yanick Schimpf and Maximilian Schlegel and Angelika Steger and Maciej Wolczyk and Johannes von Oswald and Nino Scherrer and Kaitlin Maile and Guillaume Lajoie and Blake A. Richards and Rif A. Saurous and James Manyika and Blaise Agüera y Arcas and Alexander Meulemans and João Sacramento},
    year    = {2025},
    eprint  = {2512.20605},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG},
    url     = {https://arxiv.org/abs/2512.20605}
}
```

```bibtex
@article{Wagenmaker2025SteeringYD,
    title   = {Steering Your Diffusion Policy with Latent Space Reinforcement Learning},
    author  = {Andrew Wagenmaker and Mitsuhiko Nakamoto and Yunchu Zhang and Seohong Park and Waleed Yagoub and Anusha Nagabandi and Abhishek Gupta and Sergey Levine},
    journal = {ArXiv},
    year    = {2025},
    volume  = {abs/2506.15799},
    url     = {https://api.semanticscholar.org/CorpusID:279464702}
}
```

@@ -0,0 +1,6 @@ metacontroller_pytorch-0.0.1.dist-info/RECORD
metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
metacontroller/metacontroller.py,sha256=Pyv6iNGj8qyNKTF8AVHo9MK-Jg8g7A0xo5jZTeSU4Ys,7717
metacontroller_pytorch-0.0.1.dist-info/METADATA,sha256=PuTRLQAP7vDagLO4f_bSQJJDS1tbhltwcczoQg_QZzk,3706
metacontroller_pytorch-0.0.1.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
metacontroller_pytorch-0.0.1.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
metacontroller_pytorch-0.0.1.dist-info/RECORD,,

@@ -0,0 +1,21 @@ metacontroller_pytorch-0.0.1.dist-info/licenses/LICENSE
MIT License

Copyright (c) 2025 Phil Wang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.