metacontroller-pytorch 0.0.14-py3-none-any.whl → 0.0.16-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- metacontroller/metacontroller.py +39 -14
- {metacontroller_pytorch-0.0.14.dist-info → metacontroller_pytorch-0.0.16.dist-info}/METADATA +2 -2
- metacontroller_pytorch-0.0.16.dist-info/RECORD +6 -0
- metacontroller_pytorch-0.0.14.dist-info/RECORD +0 -6
- {metacontroller_pytorch-0.0.14.dist-info → metacontroller_pytorch-0.0.16.dist-info}/WHEEL +0 -0
- {metacontroller_pytorch-0.0.14.dist-info → metacontroller_pytorch-0.0.16.dist-info}/licenses/LICENSE +0 -0
metacontroller/metacontroller.py
CHANGED
@@ -22,7 +22,7 @@ from x_transformers import Decoder
 from x_mlps_pytorch import Feedforwards
 from x_evolution import EvoStrategy

-from discrete_continuous_embed_readout import Embed, Readout
+from discrete_continuous_embed_readout import Embed, Readout, EmbedAndReadout

 from assoc_scan import AssocScan

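Note: the hunks below unpack EmbedAndReadout(dim, **config) into an (embed, readout) pair built from one shared config. A minimal hypothetical stand-in sketching just that contract (this is not the actual discrete_continuous_embed_readout API, and num_tokens is an invented parameter):

import torch.nn as nn

# hypothetical stand-in: one shared config yields a paired embedder and readout head
def embed_and_readout_stub(dim, *, num_tokens):
    embed = nn.Embedding(num_tokens, dim)  # token ids -> dim-sized vectors
    readout = nn.Linear(dim, num_tokens)   # dim-sized vectors -> logits over ids
    return embed, readout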
@@ -234,30 +234,25 @@ class Transformer(Module):
         self,
         dim,
         *,
-        embed: Embed | dict,
+        state_embed_readout: dict,
+        action_embed_readout: dict,
         lower_body: Decoder | dict,
         upper_body: Decoder | dict,
-        readout: Readout | dict,
         meta_controller: MetaController | None = None
     ):
         super().__init__()

-        if isinstance(embed, dict):
-            embed = Embed(dim = dim, **embed)
-
         if isinstance(lower_body, dict):
             lower_body = Decoder(dim = dim, **lower_body)

         if isinstance(upper_body, dict):
             upper_body = Decoder(dim = dim, **upper_body)

-        if isinstance(readout, dict):
-            readout = Readout(dim = dim, **readout)
+        self.state_embed, self.state_readout = EmbedAndReadout(dim, **state_embed_readout)
+        self.action_embed, self.action_readout = EmbedAndReadout(dim, **action_embed_readout)

-        self.embed = embed
         self.lower_body = lower_body
         self.upper_body = upper_body
-        self.readout = readout

         # meta controller

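For orientation, a hypothetical construction under the new 0.0.16 signature. Only the keyword names come from the diff; the keys inside the two embed/readout dicts are illustrative guesses (the diff only shows they are forwarded as EmbedAndReadout(dim, **config)), while depth and heads are standard x_transformers Decoder kwargs:

from metacontroller import Transformer  # import path assumed

transformer = Transformer(
    dim = 512,
    state_embed_readout = dict(num_continuous = 32),  # guessed key, forwarded to EmbedAndReadout
    action_embed_readout = dict(num_discrete = 10),   # guessed key, forwarded to EmbedAndReadout
    lower_body = dict(depth = 4, heads = 8),          # becomes Decoder(dim = 512, depth = 4, heads = 8)
    upper_body = dict(depth = 4, heads = 8),
)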
@@ -285,11 +280,13 @@

     def forward(
         self,
-
+        state,
+        action_ids,
         meta_controller: Module | None = None,
         cache: TransformerOutput | None = None,
         discovery_phase = False,
         meta_controller_temperature = 1.,
+        return_raw_action_dist = False,
         return_latents = False,
         return_cache = False,
     ):
@@ -297,21 +294,33 @@

         meta_controlling = exists(meta_controller)

+        behavioral_cloning = not meta_controlling and not return_raw_action_dist
+
         # by default, if meta controller is passed in, transformer is no grad

         lower_transformer_context = nullcontext if not meta_controlling else torch.no_grad
         meta_controller_context = nullcontext if meta_controlling else torch.no_grad
-        upper_transformer_context = nullcontext if meta_controlling
+        upper_transformer_context = nullcontext if (not meta_controlling or discovery_phase) else torch.no_grad

         # handle cache

         lower_transformer_hiddens, meta_hiddens, upper_transformer_hiddens = cache.prev_hiddens if exists(cache) else ((None,) * 3)

+        # handle maybe behavioral cloning
+
+        if behavioral_cloning or (meta_controlling and discovery_phase):
+
+            state, target_state = state[:, :-1], state[:, 1:]
+            action_ids, target_action_ids = action_ids[:, :-1], action_ids[:, 1:]
+
         # transformer lower body

         with lower_transformer_context():

-
+            state_embed = self.state_embed(state)
+            action_embed = self.action_embed(action_ids)
+
+            embed = state_embed + action_embed

             residual_stream, next_lower_hiddens = self.lower_body(embed, cache = lower_transformer_hiddens, return_hiddens = True)

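Two patterns in this hunk are worth unpacking. The context-manager lines select a class (nullcontext or torch.no_grad) and call it at the with site, conditionally freezing whole sub-networks; the [:, :-1] / [:, 1:] split is standard teacher forcing, pairing each position with its successor as the prediction target. A self-contained sketch of both:

import torch
from contextlib import nullcontext

# conditionally disable autograd: pick the context manager class, call it later
frozen = True
ctx = torch.no_grad if frozen else nullcontext
with ctx():
    y = torch.randn(2, 3).sum()  # no autograd graph is built when frozen

# teacher forcing: inputs are positions 0..T-2, targets are positions 1..T-1
seq = torch.arange(6).unsqueeze(0)     # shape (1, 6)
inp, target = seq[:, :-1], seq[:, 1:]  # shape (1, 5) each
assert (inp + 1 == target).all()       # every input predicts its successor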
@@ -332,7 +341,23 @@

         # head readout

-        dist_params = self.readout(attended)
+        dist_params = self.action_readout(attended)
+
+        # maybe return behavior cloning loss
+
+        if behavioral_cloning:
+            state_dist_params = self.state_readout(attended)
+            state_clone_loss = self.state_readout.calculate_loss(state_dist_params, target_state)
+
+            action_clone_loss = self.action_readout.calculate_loss(dist_params, target_action_ids)
+
+            return state_clone_loss, action_clone_loss
+
+        elif meta_controlling and discovery_phase:
+
+            action_recon_loss = self.action_readout.calculate_loss(dist_params, target_action_ids)
+
+            return action_recon_loss, next_meta_hiddens.kl_loss

         # returning

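Taken together with the forward changes above, training without a meta controller now amounts to a supervised cloning step. A hypothetical sketch (shapes and the continuous-state layout are guesses; the diff only shows that the two losses are returned):

import torch

# assumes `transformer` built as in the earlier sketch
state = torch.randn(1, 16, 512)             # (batch, time, dim), assumed layout
action_ids = torch.randint(0, 10, (1, 16))  # (batch, time) discrete action ids

state_clone_loss, action_clone_loss = transformer(state, action_ids)
(state_clone_loss + action_clone_loss).backward()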
{metacontroller_pytorch-0.0.14.dist-info → metacontroller_pytorch-0.0.16.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: metacontroller-pytorch
-Version: 0.0.14
+Version: 0.0.16
 Summary: Transformer Metacontroller
 Project-URL: Homepage, https://pypi.org/project/metacontroller/
 Project-URL: Repository, https://github.com/lucidrains/metacontroller
@@ -60,7 +60,7 @@ Implementation of the MetaController proposed in [Emergent temporal abstractions
 @misc{kobayashi2025emergenttemporalabstractionsautoregressive,
     title = {Emergent temporal abstractions in autoregressive models enable hierarchical reinforcement learning},
     author = {Seijin Kobayashi and Yanick Schimpf and Maximilian Schlegel and Angelika Steger and Maciej Wolczyk and Johannes von Oswald and Nino Scherrer and Kaitlin Maile and Guillaume Lajoie and Blake A. Richards and Rif A. Saurous and James Manyika and Blaise Agüera y Arcas and Alexander Meulemans and João Sacramento},
-    year={2025},
+    year = {2025},
     eprint = {2512.20605},
     archivePrefix = {arXiv},
     primaryClass = {cs.LG},
metacontroller_pytorch-0.0.16.dist-info/RECORD
ADDED
@@ -0,0 +1,6 @@
+metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
+metacontroller/metacontroller.py,sha256=BT7GH8F9NkEIYLEueBkkZ8glQ3Oht1FRoV84SIaTWdQ,11878
+metacontroller_pytorch-0.0.16.dist-info/METADATA,sha256=eyECb3994X58zyExLnnffMl3pOoMlIb-WAUhepIt0r8,3741
+metacontroller_pytorch-0.0.16.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+metacontroller_pytorch-0.0.16.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+metacontroller_pytorch-0.0.16.dist-info/RECORD,,
metacontroller_pytorch-0.0.14.dist-info/RECORD
DELETED
@@ -1,6 +0,0 @@
-metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
-metacontroller/metacontroller.py,sha256=rQJLyXMHHNCZm0iohWqAkcMpYSi8b1Z0dgB-8AJVqqo,10751
-metacontroller_pytorch-0.0.14.dist-info/METADATA,sha256=-CP3Ak1NPaqTpyF4tTgwn-T47Pv2OiPzPFxecwGe3Ng,3736
-metacontroller_pytorch-0.0.14.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
-metacontroller_pytorch-0.0.14.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-metacontroller_pytorch-0.0.14.dist-info/RECORD,,
{metacontroller_pytorch-0.0.14.dist-info → metacontroller_pytorch-0.0.16.dist-info}/WHEEL
RENAMED
File without changes
{metacontroller_pytorch-0.0.14.dist-info → metacontroller_pytorch-0.0.16.dist-info}/licenses/LICENSE
RENAMED
File without changes