metacontroller-pytorch 0.0.15 (metacontroller_pytorch-0.0.15-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1 @@
+ from metacontroller.metacontroller import MetaController
@@ -0,0 +1,362 @@
+ from __future__ import annotations
+ from contextlib import nullcontext
+
+ from functools import partial
+ from collections import namedtuple
+ from loguru import logger
+
+ import torch
+ from torch import nn, cat, stack, tensor
+ from torch.nn import Module, GRU, Linear, Identity
+ import torch.nn.functional as F
+
+ # einops
+
+ import einx
+ from einops import einsum, rearrange, repeat, reduce
+ from einops.layers.torch import Rearrange
+
+ # external modules
+
+ from x_transformers import Decoder
+ from x_mlps_pytorch import Feedforwards
+ from x_evolution import EvoStrategy
+
+ from discrete_continuous_embed_readout import Embed, Readout, EmbedAndReadout
+
+ from assoc_scan import AssocScan
+
+ # constants
+
+ LinearNoBias = partial(Linear, bias = False)
+
+ GRU = partial(GRU, batch_first = True)
+
+ # helper functions
+
+ def exists(v):
+     return v is not None
+
+ def identity(t):
+     return t
+
+ def default(*args):
+     for arg in args:
+         if exists(arg):
+             return arg
+     return None
+
+ # tensor helpers
+
+ def straight_through(src, tgt):
+     return tgt + src - src.detach()
+
+ # meta controller
+
+ MetaControllerOutput = namedtuple('MetaControllerOutput', (
+     'prev_hiddens',
+     'action_dist',
+     'actions',
+     'kl_loss'
+ ))
+
+ class MetaController(Module):
+     def __init__(
+         self,
+         dim_latent,
+         *,
+         switch_per_latent_dim = True,
+         decoder_expansion_factor = 2.,
+         decoder_depth = 1,
+         hypernetwork_low_rank = 16,
+         assoc_scan_kwargs: dict = dict()
+     ):
+         super().__init__()
+
+         # there are two phases - the first (discovery ssl phase) uses an acausal encoder; a bidirectional GRU serves as the placeholder here
+
+         self.bidirectional_temporal_compressor = GRU(dim_latent, dim_latent, bidirectional = True) # revisit naming
+
+         self.emitter = GRU(dim_latent * 2, dim_latent * 2)
+         self.emitter_to_action_mean_log_var = Readout(dim_latent * 2, num_continuous = dim_latent)
+
+         # internal rl phase substitutes the acausal encoder + emitter with a causal ssm
+
+         self.action_proposer = GRU(dim_latent, dim_latent)
+         self.action_proposer_mean_log_var = Readout(dim_latent, num_continuous = dim_latent)
+
+         # switching unit
+
+         self.switch_per_latent_dim = switch_per_latent_dim
+
+         self.switching_unit = GRU(dim_latent, dim_latent)
+         self.to_switching_unit_beta = nn.Linear(dim_latent, dim_latent if switch_per_latent_dim else 1, bias = False)
+
+         self.switch_gating = AssocScan(**assoc_scan_kwargs)
+
+         # decoder
+
+         assert hypernetwork_low_rank < dim_latent
+
+         dim_decoder_hidden = int(dim_latent * decoder_expansion_factor)
+
+         self.decoder = Feedforwards(
+             dim_in = dim_latent,
+             dim = dim_decoder_hidden,
+             depth = decoder_depth,
+             dim_out = 2 * hypernetwork_low_rank * dim_latent
+         )
+
+         self.to_hyper_network_weights = Rearrange('... (two d r) -> two ... d r', two = 2, r = hypernetwork_low_rank)
+
+         self.register_buffer('zero', tensor(0.), persistent = False)
+
+     def discovery_parameters(self):
+         return [
+             *self.bidirectional_temporal_compressor.parameters(),
+             *self.emitter.parameters(),
+             *self.emitter_to_action_mean_log_var.parameters(),
+             *self.decoder.parameters(),
+             *self.switch_gating.parameters()
+         ]
+
+     def internal_rl_parameters(self):
+         return [
+             *self.action_proposer.parameters(),
+             *self.action_proposer_mean_log_var.parameters()
+         ]
+
+     def forward(
+         self,
+         residual_stream,
+         cache: MetaControllerOutput | None = None,
+         discovery_phase = False,
+         hard_switch = False,
+         temperature = 1.
+     ):
+
+         # destructure prev cache
+
+         prev_action_proposer_hidden, prev_switching_unit_gru_hidden, prev_switch_gated_hiddens = cache.prev_hiddens if exists(cache) else ((None,) * 3)
+
+         # getting proposed action for the two phases
+
+         next_action_proposer_hidden = None
+
+         if discovery_phase:
+             if exists(cache):
+                 logger.warning('meta controller cache being passed back in for discovery phase, which does not make sense given the bidirectional encoder')
+
+             temporal_compressed, _ = self.bidirectional_temporal_compressor(residual_stream)
+             temporal_compressed = reduce(temporal_compressed, '... (two d) -> ... d', 'mean', two = 2)
+
+             proposed_action_hidden, _ = self.emitter(cat((temporal_compressed, residual_stream), dim = -1))
+             readout = self.emitter_to_action_mean_log_var
+
+         else: # else internal rl phase
+
+             proposed_action_hidden, next_action_proposer_hidden = self.action_proposer(residual_stream, prev_action_proposer_hidden)
+             readout = self.action_proposer_mean_log_var
+
+         # sample from the gaussian as the action from the meta controller
+
+         action_dist = readout(proposed_action_hidden)
+
+         sampled_action = readout.sample(action_dist, temperature = temperature)
+
+         # switching unit timer
+
+         batch, _, dim = sampled_action.shape
+
+         switching_unit_gru_out, next_switching_unit_gru_hidden = self.switching_unit(residual_stream, prev_switching_unit_gru_hidden)
+
+         switch_beta = self.to_switching_unit_beta(switching_unit_gru_out).sigmoid()
+
+         # need to encourage normal distribution
+
+         kl_loss = self.zero
+
+         if discovery_phase:
+             mean, log_var = action_dist.unbind(dim = -1)
+
+             kl_loss = (0.5 * (
+                 log_var.exp()
+                 + mean.square()
+                 - log_var
+                 - 1.
+             ))
+
+             kl_loss = kl_loss * switch_beta
+             kl_loss = kl_loss.sum(dim = -1).mean()
+
+         # maybe hard switch, then use associative scan
+
+         if hard_switch:
+             hard_switch_beta = (switch_beta > 0.5).float()
+             switch_beta = straight_through(switch_beta, hard_switch_beta)
+
+         forget = 1. - switch_beta
+         gated_action = self.switch_gating(switch_beta, sampled_action * forget, prev = prev_switch_gated_hiddens)
+
+         next_switch_gated_action = gated_action[:, -1]
+
+         # decoder generates the low rank hypernetwork weights from the gated action
+
+         decoder_out = self.decoder(gated_action)
+
+         w1, w2 = self.to_hyper_network_weights(decoder_out)
+         hypernetwork_weight = einsum(w1, w2, '... i r, ... j r -> ... i j')
+
+         # generating the residual stream controlling signal by applying the generated weight to the gated action
+
+         control_signal = einsum(gated_action, hypernetwork_weight, '... d1, ... d1 d2 -> ... d2')
+
+         modified_residual_stream = residual_stream + control_signal
+
+         # returning
+
+         next_hiddens = (
+             next_action_proposer_hidden,
+             next_switching_unit_gru_hidden,
+             next_switch_gated_action
+         )
+
+         return modified_residual_stream, MetaControllerOutput(next_hiddens, action_dist, sampled_action, kl_loss)
+
+ # main transformer, which is subsumed into the environment after behavioral cloning
+
+ TransformerOutput = namedtuple('TransformerOutput', (
+     'residual_stream_latent',
+     'prev_hiddens'
+ ))
+
+ class Transformer(Module):
+     def __init__(
+         self,
+         dim,
+         *,
+         state_embed_readout: dict,
+         action_embed_readout: dict,
+         lower_body: Decoder | dict,
+         upper_body: Decoder | dict,
+         meta_controller: MetaController | None = None
+     ):
+         super().__init__()
+
+         if isinstance(lower_body, dict):
+             lower_body = Decoder(dim = dim, **lower_body)
+
+         if isinstance(upper_body, dict):
+             upper_body = Decoder(dim = dim, **upper_body)
+
+         self.state_embed, self.state_readout = EmbedAndReadout(dim, **state_embed_readout)
+         self.action_embed, self.action_readout = EmbedAndReadout(dim, **action_embed_readout)
+
+         self.lower_body = lower_body
+         self.upper_body = upper_body
+
+         # meta controller
+
+         self.meta_controller = meta_controller
+
+         self.register_buffer('zero', tensor(0.), persistent = False)
+
+     def evolve(
+         self,
+         num_generations,
+         environment,
+         **kwargs
+     ):
+         assert exists(self.meta_controller), '`meta_controller` must be passed in or defined on init for evolutionary strategies to be straightforwardly applied'
+
+         evo_strat = EvoStrategy(
+             self,
+             num_generations = num_generations,
+             environment = environment,
+             params_to_optimize = self.meta_controller.internal_rl_parameters(),
+             **kwargs
+         )
+
+         evo_strat()
+
+     def forward(
+         self,
+         state,
+         action_ids,
+         meta_controller: Module | None = None,
+         cache: TransformerOutput | None = None,
+         discovery_phase = False,
+         meta_controller_temperature = 1.,
+         return_raw_action_dist = False,
+         return_latents = False,
+         return_cache = False,
+     ):
+         meta_controller = default(meta_controller, self.meta_controller)
+
+         meta_controlling = exists(meta_controller)
+
+         behavioral_cloning = not meta_controlling and not return_raw_action_dist
+
+         # by default, if meta controller is passed in, transformer is no grad
+
+         lower_transformer_context = nullcontext if not meta_controlling else torch.no_grad
+         meta_controller_context = nullcontext if meta_controlling else torch.no_grad
+         upper_transformer_context = nullcontext if (not meta_controlling or discovery_phase) else torch.no_grad
+
+         # handle cache
+
+         lower_transformer_hiddens, meta_hiddens, upper_transformer_hiddens = cache.prev_hiddens if exists(cache) else ((None,) * 3)
+
+         # handle maybe behavioral cloning
+
+         if behavioral_cloning:
+             state, target_state = state[:, :-1], state[:, 1:]
+             action_ids, target_action_ids = action_ids[:, :-1], action_ids[:, 1:]
+
+         # transformer lower body
+
+         with lower_transformer_context():
+
+             state_embed = self.state_embed(state)
+             action_embed = self.action_embed(action_ids)
+
+             embed = state_embed + action_embed
+
+             residual_stream, next_lower_hiddens = self.lower_body(embed, cache = lower_transformer_hiddens, return_hiddens = True)
+
+         # meta controller acts on residual stream here
+
+         with meta_controller_context():
+
+             if exists(meta_controller):
+                 modified_residual_stream, next_meta_hiddens = meta_controller(residual_stream, cache = meta_hiddens, discovery_phase = discovery_phase, temperature = meta_controller_temperature)
+             else:
+                 modified_residual_stream, next_meta_hiddens = residual_stream, None
+
+         # modified residual stream sent back to transformer upper body
+
+         with upper_transformer_context():
+
+             attended, next_upper_hiddens = self.upper_body(modified_residual_stream, cache = upper_transformer_hiddens, return_hiddens = True)
+
+         # head readout
+
+         dist_params = self.action_readout(attended)
+
+         # maybe return behavior cloning loss
+
+         if behavioral_cloning:
+             state_dist_params = self.state_readout(attended)
+             state_clone_loss = self.state_readout.calculate_loss(state_dist_params, target_state)
+
+             action_clone_loss = self.action_readout.calculate_loss(dist_params, target_action_ids)
+
+             return state_clone_loss, action_clone_loss
+
+         # returning
+
+         return_one = not (return_latents or return_cache)
+
+         if return_one:
+             return dist_params
+
+         return dist_params, TransformerOutput(residual_stream, (next_lower_hiddens, next_meta_hiddens, next_upper_hiddens))
@@ -0,0 +1,80 @@
+ Metadata-Version: 2.4
+ Name: metacontroller-pytorch
+ Version: 0.0.15
+ Summary: Transformer Metacontroller
+ Project-URL: Homepage, https://pypi.org/project/metacontroller/
+ Project-URL: Repository, https://github.com/lucidrains/metacontroller
+ Author-email: Phil Wang <lucidrains@gmail.com>
+ License: MIT License
+
+ Copyright (c) 2025 Phil Wang
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
+ License-File: LICENSE
+ Keywords: artificial intelligence,deep learning,hierarchical reinforcement learning,latent steering
+ Classifier: Development Status :: 4 - Beta
+ Classifier: Intended Audience :: Developers
+ Classifier: License :: OSI Approved :: MIT License
+ Classifier: Programming Language :: Python :: 3.9
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+ Requires-Python: >=3.9
+ Requires-Dist: assoc-scan>=0.0.3
+ Requires-Dist: discrete-continuous-embed-readout>=0.1.12
+ Requires-Dist: einops>=0.8.1
+ Requires-Dist: einx>=0.3.0
+ Requires-Dist: loguru
+ Requires-Dist: torch>=2.5
+ Requires-Dist: x-evolution>=0.1.23
+ Requires-Dist: x-mlps-pytorch
+ Requires-Dist: x-transformers
+ Provides-Extra: examples
+ Provides-Extra: test
+ Requires-Dist: pytest; extra == 'test'
+ Description-Content-Type: text/markdown
+
+ <img src="./fig1.png" width="400px"></img>
+
+ ## metacontroller (wip)
+
+ Implementation of the MetaController proposed in [Emergent temporal abstractions in autoregressive models enable hierarchical reinforcement learning](https://arxiv.org/abs/2512.20605)
+
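+ ## Usage
+
+ A minimal sketch of the `MetaController` on its own, based on the module in this diff. The batch, sequence, and latent sizes below are illustrative, and the example assumes the `Readout` from `discrete-continuous-embed-readout` behaves as the module expects (stacked mean / log variance with a `.sample` method).
+
+ ```python
+ import torch
+ from metacontroller import MetaController
+
+ # the meta controller reads and steers the transformer residual stream
+
+ meta_controller = MetaController(dim_latent = 512)
+
+ residual_stream = torch.randn(1, 128, 512)
+
+ # discovery (ssl) phase - acausal compressor + emitter, with a vae style kl loss on the proposed actions
+
+ modified_residual_stream, out = meta_controller(residual_stream, discovery_phase = True)
+
+ out.kl_loss.backward()
+
+ # internal rl phase - causal action proposer, cache can be passed back in for stepwise rollouts
+
+ modified_residual_stream, out = meta_controller(residual_stream)
+ modified_residual_stream, out = meta_controller(residual_stream, cache = out)
+ ```
+
+ The `Transformer` wrapper in the same module wires a `MetaController` between a lower and upper `x-transformers` decoder, and `Transformer.evolve(...)` optimizes the internal RL parameters with `x-evolution`.
+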
+ ## Citations
+
+ ```bibtex
+ @misc{kobayashi2025emergenttemporalabstractionsautoregressive,
+     title = {Emergent temporal abstractions in autoregressive models enable hierarchical reinforcement learning},
+     author = {Seijin Kobayashi and Yanick Schimpf and Maximilian Schlegel and Angelika Steger and Maciej Wolczyk and Johannes von Oswald and Nino Scherrer and Kaitlin Maile and Guillaume Lajoie and Blake A. Richards and Rif A. Saurous and James Manyika and Blaise Agüera y Arcas and Alexander Meulemans and João Sacramento},
+     year = {2025},
+     eprint = {2512.20605},
+     archivePrefix = {arXiv},
+     primaryClass = {cs.LG},
+     url = {https://arxiv.org/abs/2512.20605},
+ }
+ ```
+
+ ```bibtex
+ @article{Wagenmaker2025SteeringYD,
+     title = {Steering Your Diffusion Policy with Latent Space Reinforcement Learning},
+     author = {Andrew Wagenmaker and Mitsuhiko Nakamoto and Yunchu Zhang and Seohong Park and Waleed Yagoub and Anusha Nagabandi and Abhishek Gupta and Sergey Levine},
+     journal = {ArXiv},
+     year = {2025},
+     volume = {abs/2506.15799},
+     url = {https://api.semanticscholar.org/CorpusID:279464702}
+ }
+ ```
@@ -0,0 +1,6 @@
+ metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
+ metacontroller/metacontroller.py,sha256=ug3xeMTZKApTF8oOPx9hWypeDjRflf1IJp8RiysXgTo,11618
+ metacontroller_pytorch-0.0.15.dist-info/METADATA,sha256=9d39BpcuVeOVVSD66lCVHCK1GjrkeKzRtxKOPOc-7xQ,3736
+ metacontroller_pytorch-0.0.15.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+ metacontroller_pytorch-0.0.15.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+ metacontroller_pytorch-0.0.15.dist-info/RECORD,,
@@ -0,0 +1,4 @@
+ Wheel-Version: 1.0
+ Generator: hatchling 1.28.0
+ Root-Is-Purelib: true
+ Tag: py3-none-any
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2025 Phil Wang
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.