metacontroller-pytorch 0.0.15__tar.gz → 0.0.16__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: metacontroller-pytorch
-Version: 0.0.15
+Version: 0.0.16
 Summary: Transformer Metacontroller
 Project-URL: Homepage, https://pypi.org/project/metacontroller/
 Project-URL: Repository, https://github.com/lucidrains/metacontroller
@@ -60,7 +60,7 @@ Implementation of the MetaController proposed in [Emergent temporal abstractions
 @misc{kobayashi2025emergenttemporalabstractionsautoregressive,
     title = {Emergent temporal abstractions in autoregressive models enable hierarchical reinforcement learning},
     author = {Seijin Kobayashi and Yanick Schimpf and Maximilian Schlegel and Angelika Steger and Maciej Wolczyk and Johannes von Oswald and Nino Scherrer and Kaitlin Maile and Guillaume Lajoie and Blake A. Richards and Rif A. Saurous and James Manyika and Blaise Agüera y Arcas and Alexander Meulemans and João Sacramento},
-    year={2025},
+    year = {2025},
     eprint = {2512.20605},
     archivePrefix = {arXiv},
     primaryClass = {cs.LG},
@@ -10,7 +10,7 @@ Implementation of the MetaController proposed in [Emergent temporal abstractions
 @misc{kobayashi2025emergenttemporalabstractionsautoregressive,
     title = {Emergent temporal abstractions in autoregressive models enable hierarchical reinforcement learning},
     author = {Seijin Kobayashi and Yanick Schimpf and Maximilian Schlegel and Angelika Steger and Maciej Wolczyk and Johannes von Oswald and Nino Scherrer and Kaitlin Maile and Guillaume Lajoie and Blake A. Richards and Rif A. Saurous and James Manyika and Blaise Agüera y Arcas and Alexander Meulemans and João Sacramento},
-    year={2025},
+    year = {2025},
     eprint = {2512.20605},
     archivePrefix = {arXiv},
     primaryClass = {cs.LG},
@@ -308,7 +308,8 @@ class Transformer(Module):
 
         # handle maybe behavioral cloning
 
-        if behavioral_cloning:
+        if behavioral_cloning or (meta_controlling and discovery_phase):
+
             state, target_state = state[:, :-1], state[:, 1:]
             action_ids, target_action_ids = action_ids[:, :-1], action_ids[:, 1:]
 
@@ -352,6 +353,12 @@ class Transformer(Module):
 
             return state_clone_loss, action_clone_loss
 
+        elif meta_controlling and discovery_phase:
+
+            action_recon_loss = self.action_readout.calculate_loss(dist_params, target_action_ids)
+
+            return action_recon_loss, next_meta_hiddens.kl_loss
+
         # returning
 
         return_one = not (return_latents or return_cache)
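With this change, a forward pass with a meta controller attached and discovery_phase = True returns an (action reconstruction loss, KL loss) pair instead of logits. A minimal sketch of the new discovery-phase training step follows, mirroring the updated test further down; it assumes model, meta_controller, state and actions are already constructed as in that test, and the 0.1 KL weight is only the illustrative choice the test happens to use:

# sketch of the discovery-phase step (assumes model, meta_controller, state and
# actions are set up as in the updated test file of this release)

action_recon_loss, kl_loss = model(
    state,
    actions,
    meta_controller = meta_controller,
    discovery_phase = True
)

# combine the reconstruction and KL terms; the 0.1 weight is illustrative
(action_recon_loss + kl_loss * 0.1).backward()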
@@ -1,6 +1,6 @@
 [project]
 name = "metacontroller-pytorch"
-version = "0.0.15"
+version = "0.0.16"
 description = "Transformer Metacontroller"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -5,11 +5,9 @@ import torch
 from metacontroller.metacontroller import Transformer, MetaController
 
 @param('action_discrete', (False, True))
-@param('discovery_phase', (False, True))
 @param('switch_per_latent_dim', (False, True))
 def test_metacontroller(
     action_discrete,
-    discovery_phase,
     switch_per_latent_dim
 ):
 
@@ -24,7 +22,7 @@ def test_metacontroller(
     action_embed_readout = dict(num_continuous = 8)
     assert_shape = (8, 2)
 
-    # behavioral cloning pahse
+    # behavioral cloning phase
 
     model = Transformer(
         dim = 512,
@@ -44,14 +42,23 @@ def test_metacontroller(
         switch_per_latent_dim = switch_per_latent_dim
     )
 
-    logits, cache = model(state, actions, meta_controller = meta_controller, discovery_phase = discovery_phase, return_cache = True)
+    # discovery phase
+
+    (action_recon_loss, kl_loss) = model(state, actions, meta_controller = meta_controller, discovery_phase = True)
+    (action_recon_loss + kl_loss * 0.1).backward()
+
+    # internal rl
+
+    logits, cache = model(state, actions, meta_controller = meta_controller, return_cache = True)
 
     assert logits.shape == (1, 1024, *assert_shape)
 
-    logits, cache = model(state, actions, meta_controller = meta_controller, discovery_phase = discovery_phase, return_cache = True, cache = cache)
-    logits, cache = model(state, actions, meta_controller = meta_controller, discovery_phase = discovery_phase, return_cache = True, cache = cache)
+    logits, cache = model(state, actions, meta_controller = meta_controller, return_cache = True, cache = cache)
+    logits, cache = model(state, actions, meta_controller = meta_controller, return_cache = True, cache = cache)
 
     assert logits.shape == (1, 1, *assert_shape)
 
+    # evolutionary strategies over grpo
+
     model.meta_controller = meta_controller
     model.evolve(1, lambda _: 1., noise_population_size = 2)
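For the final evolution-strategies step, the test drives model.evolve with a constant stub fitness (lambda _: 1.). A hedged sketch of passing a named scoring callable instead; what the callable receives and what it should return are assumptions read off that stub, not documented behavior:

# sketch: the evolve call with a named fitness callable in place of the test's stub.
# assumed from `lambda _: 1.`: the callable takes one argument and returns a scalar
# score (e.g. total reward from environment rollouts).

def fitness(candidate):
    # placeholder score mirroring the stub; replace with real rollout-based scoring
    return 1.

model.meta_controller = meta_controller
model.evolve(1, fitness, noise_population_size = 2)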