metacontroller-pytorch 0.0.14.tar.gz → 0.0.16.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: metacontroller-pytorch
-Version: 0.0.14
+Version: 0.0.16
 Summary: Transformer Metacontroller
 Project-URL: Homepage, https://pypi.org/project/metacontroller/
 Project-URL: Repository, https://github.com/lucidrains/metacontroller
@@ -60,7 +60,7 @@ Implementation of the MetaController proposed in [Emergent temporal abstractions
 @misc{kobayashi2025emergenttemporalabstractionsautoregressive,
     title = {Emergent temporal abstractions in autoregressive models enable hierarchical reinforcement learning},
     author = {Seijin Kobayashi and Yanick Schimpf and Maximilian Schlegel and Angelika Steger and Maciej Wolczyk and Johannes von Oswald and Nino Scherrer and Kaitlin Maile and Guillaume Lajoie and Blake A. Richards and Rif A. Saurous and James Manyika and Blaise Agüera y Arcas and Alexander Meulemans and João Sacramento},
-    year={2025},
+    year = {2025},
     eprint = {2512.20605},
     archivePrefix = {arXiv},
     primaryClass = {cs.LG},
@@ -10,7 +10,7 @@ Implementation of the MetaController proposed in [Emergent temporal abstractions
 @misc{kobayashi2025emergenttemporalabstractionsautoregressive,
     title = {Emergent temporal abstractions in autoregressive models enable hierarchical reinforcement learning},
     author = {Seijin Kobayashi and Yanick Schimpf and Maximilian Schlegel and Angelika Steger and Maciej Wolczyk and Johannes von Oswald and Nino Scherrer and Kaitlin Maile and Guillaume Lajoie and Blake A. Richards and Rif A. Saurous and James Manyika and Blaise Agüera y Arcas and Alexander Meulemans and João Sacramento},
-    year={2025},
+    year = {2025},
     eprint = {2512.20605},
     archivePrefix = {arXiv},
     primaryClass = {cs.LG},
@@ -22,7 +22,7 @@ from x_transformers import Decoder
 from x_mlps_pytorch import Feedforwards
 from x_evolution import EvoStrategy

-from discrete_continuous_embed_readout import Embed, Readout
+from discrete_continuous_embed_readout import Embed, Readout, EmbedAndReadout

 from assoc_scan import AssocScan

@@ -234,30 +234,25 @@ class Transformer(Module):
         self,
         dim,
         *,
-        embed: Embed | dict,
+        state_embed_readout: dict,
+        action_embed_readout: dict,
         lower_body: Decoder | dict,
         upper_body: Decoder | dict,
-        readout: Readout | dict,
         meta_controller: MetaController | None = None
     ):
         super().__init__()

-        if isinstance(embed, dict):
-            embed = Embed(dim = dim, **embed)
-
         if isinstance(lower_body, dict):
             lower_body = Decoder(dim = dim, **lower_body)

         if isinstance(upper_body, dict):
             upper_body = Decoder(dim = dim, **upper_body)

-        if isinstance(readout, dict):
-            readout = Readout(dim = dim, **readout)
+        self.state_embed, self.state_readout = EmbedAndReadout(dim, **state_embed_readout)
+        self.action_embed, self.action_readout = EmbedAndReadout(dim, **action_embed_readout)

-        self.embed = embed
         self.lower_body = lower_body
         self.upper_body = upper_body
-        self.readout = readout

         # meta controller

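For orientation (illustrative sketch, not part of the diff): the single `embed`/`readout` pair is replaced by per-modality `state_embed_readout` and `action_embed_readout` dicts, each expanded through `EmbedAndReadout` into an embed/readout pair. The keyword contents below are taken from the updated test later in this diff (384-dim continuous states, 4 discrete actions).

```python
# Sketch of the new constructor call; keyword contents mirror the updated test in this diff.
from metacontroller.metacontroller import Transformer

model = Transformer(
    dim = 512,
    state_embed_readout = dict(num_continuous = 384),  # expanded via EmbedAndReadout
    action_embed_readout = dict(num_discrete = 4),     # expanded via EmbedAndReadout
    lower_body = dict(depth = 2),
    upper_body = dict(depth = 2)
)
```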
@@ -285,11 +280,13 @@ class Transformer(Module):

     def forward(
         self,
-        ids,
+        state,
+        action_ids,
         meta_controller: Module | None = None,
         cache: TransformerOutput | None = None,
         discovery_phase = False,
         meta_controller_temperature = 1.,
+        return_raw_action_dist = False,
         return_latents = False,
         return_cache = False,
     ):
@@ -297,21 +294,33 @@

         meta_controlling = exists(meta_controller)

+        behavioral_cloning = not meta_controlling and not return_raw_action_dist
+
         # by default, if meta controller is passed in, transformer is no grad

         lower_transformer_context = nullcontext if not meta_controlling else torch.no_grad
         meta_controller_context = nullcontext if meta_controlling else torch.no_grad
-        upper_transformer_context = nullcontext if meta_controlling and discovery_phase else torch.no_grad
+        upper_transformer_context = nullcontext if (not meta_controlling or discovery_phase) else torch.no_grad

         # handle cache

         lower_transformer_hiddens, meta_hiddens, upper_transformer_hiddens = cache.prev_hiddens if exists(cache) else ((None,) * 3)

+        # handle maybe behavioral cloning
+
+        if behavioral_cloning or (meta_controlling and discovery_phase):
+
+            state, target_state = state[:, :-1], state[:, 1:]
+            action_ids, target_action_ids = action_ids[:, :-1], action_ids[:, 1:]
+
         # transformer lower body

         with lower_transformer_context():

-            embed = self.embed(ids)
+            state_embed = self.state_embed(state)
+            action_embed = self.action_embed(action_ids)
+
+            embed = state_embed + action_embed

             residual_stream, next_lower_hiddens = self.lower_body(embed, cache = lower_transformer_hiddens, return_hiddens = True)

@@ -332,7 +341,23 @@ class Transformer(Module):

         # head readout

-        dist_params = self.readout(attended)
+        dist_params = self.action_readout(attended)
+
+        # maybe return behavior cloning loss
+
+        if behavioral_cloning:
+            state_dist_params = self.state_readout(attended)
+            state_clone_loss = self.state_readout.calculate_loss(state_dist_params, target_state)
+
+            action_clone_loss = self.action_readout.calculate_loss(dist_params, target_action_ids)
+
+            return state_clone_loss, action_clone_loss
+
+        elif meta_controlling and discovery_phase:
+
+            action_recon_loss = self.action_readout.calculate_loss(dist_params, target_action_ids)
+
+            return action_recon_loss, next_meta_hiddens.kl_loss

         # returning

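For orientation (illustrative sketch, not part of the diff): `forward` now takes `state` and `action_ids` separately, and its return value depends on the training phase. The three call patterns below mirror the updated test in this diff; the sketch reuses `model` from the sketch above, and the `MetaController` keyword arguments also come from that test.

```python
# Sketch of the three forward modes introduced by this diff.
import torch
from metacontroller.metacontroller import MetaController

state = torch.randn(1, 1024, 384)          # continuous states
actions = torch.randint(0, 4, (1, 1024))   # discrete action ids

# 1. behavioral cloning: no meta controller passed, returns next-step state and action clone losses
state_clone_loss, action_clone_loss = model(state, actions)

# 2. discovery phase: meta controller passed with discovery_phase = True,
#    returns the action reconstruction loss and the meta controller's KL term
meta_controller = MetaController(dim_latent = 512, switch_per_latent_dim = False)
action_recon_loss, kl_loss = model(state, actions, meta_controller = meta_controller, discovery_phase = True)

# 3. internal RL / inference: returns raw action distribution parameters (plus the cache on request)
dist_params, cache = model(state, actions, meta_controller = meta_controller, return_cache = True)
```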
@@ -1,6 +1,6 @@
 [project]
 name = "metacontroller-pytorch"
-version = "0.0.14"
+version = "0.0.16"
 description = "Transformer Metacontroller"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -0,0 +1,64 @@
+import pytest
+param = pytest.mark.parametrize
+
+import torch
+from metacontroller.metacontroller import Transformer, MetaController
+
+@param('action_discrete', (False, True))
+@param('switch_per_latent_dim', (False, True))
+def test_metacontroller(
+    action_discrete,
+    switch_per_latent_dim
+):
+
+    state = torch.randn(1, 1024, 384)
+
+    if action_discrete:
+        actions = torch.randint(0, 4, (1, 1024))
+        action_embed_readout = dict(num_discrete = 4)
+        assert_shape = (4,)
+    else:
+        actions = torch.randn(1, 1024, 8)
+        action_embed_readout = dict(num_continuous = 8)
+        assert_shape = (8, 2)
+
+    # behavioral cloning phase
+
+    model = Transformer(
+        dim = 512,
+        action_embed_readout = action_embed_readout,
+        state_embed_readout = dict(num_continuous = 384),
+        lower_body = dict(depth = 2,),
+        upper_body = dict(depth = 2,),
+    )
+
+    state_clone_loss, action_clone_loss = model(state, actions)
+    (state_clone_loss + 0.5 * action_clone_loss).backward()
+
+    # discovery and internal rl phase with meta controller
+
+    meta_controller = MetaController(
+        dim_latent = 512,
+        switch_per_latent_dim = switch_per_latent_dim
+    )
+
+    # discovery phase
+
+    (action_recon_loss, kl_loss) = model(state, actions, meta_controller = meta_controller, discovery_phase = True)
+    (action_recon_loss + kl_loss * 0.1).backward()
+
+    # internal rl
+
+    logits, cache = model(state, actions, meta_controller = meta_controller, return_cache = True)
+
+    assert logits.shape == (1, 1024, *assert_shape)
+
+    logits, cache = model(state, actions, meta_controller = meta_controller, return_cache = True, cache = cache)
+    logits, cache = model(state, actions, meta_controller = meta_controller, return_cache = True, cache = cache)
+
+    assert logits.shape == (1, 1, *assert_shape)
+
+    # evolutionary strategies over grpo
+
+    model.meta_controller = meta_controller
+    model.evolve(1, lambda _: 1., noise_population_size = 2)
@@ -1,39 +0,0 @@
-import pytest
-param = pytest.mark.parametrize
-
-import torch
-from metacontroller.metacontroller import Transformer, MetaController
-
-@param('discovery_phase', (False, True))
-@param('switch_per_latent_dim', (False, True))
-def test_metacontroller(
-    discovery_phase,
-    switch_per_latent_dim
-):
-
-    ids = torch.randint(0, 256, (1, 1024))
-
-    model = Transformer(
-        512,
-        embed = dict(num_discrete = 256),
-        lower_body = dict(depth = 2,),
-        upper_body = dict(depth = 2,),
-        readout = dict(num_discrete = 256)
-    )
-
-    meta_controller = MetaController(
-        512,
-        switch_per_latent_dim = switch_per_latent_dim
-    )
-
-    logits, cache = model(ids, meta_controller = meta_controller, discovery_phase = discovery_phase, return_cache = True)
-
-    assert logits.shape == (1, 1024, 256)
-
-    logits, cache = model(ids, meta_controller = meta_controller, discovery_phase = discovery_phase, return_cache = True, cache = cache)
-    logits, cache = model(ids, meta_controller = meta_controller, discovery_phase = discovery_phase, return_cache = True, cache = cache)
-
-    assert logits.shape == (1, 1, 256)
-
-    model.meta_controller = meta_controller
-    model.evolve(1, lambda _: 1., noise_population_size = 2)