metacontroller-pytorch 0.0.15__tar.gz → 0.0.17__tar.gz

This diff shows the changes between two publicly released versions of the package, as published to the public registry. It is provided for informational purposes only.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: metacontroller-pytorch
-Version: 0.0.15
+Version: 0.0.17
 Summary: Transformer Metacontroller
 Project-URL: Homepage, https://pypi.org/project/metacontroller/
 Project-URL: Repository, https://github.com/lucidrains/metacontroller
@@ -60,7 +60,7 @@ Implementation of the MetaController proposed in [Emergent temporal abstractions
 @misc{kobayashi2025emergenttemporalabstractionsautoregressive,
     title = {Emergent temporal abstractions in autoregressive models enable hierarchical reinforcement learning},
     author = {Seijin Kobayashi and Yanick Schimpf and Maximilian Schlegel and Angelika Steger and Maciej Wolczyk and Johannes von Oswald and Nino Scherrer and Kaitlin Maile and Guillaume Lajoie and Blake A. Richards and Rif A. Saurous and James Manyika and Blaise Agüera y Arcas and Alexander Meulemans and João Sacramento},
-    year={2025},
+    year = {2025},
     eprint = {2512.20605},
     archivePrefix = {arXiv},
     primaryClass = {cs.LG},
@@ -10,7 +10,7 @@ Implementation of the MetaController proposed in [Emergent temporal abstractions
 @misc{kobayashi2025emergenttemporalabstractionsautoregressive,
     title = {Emergent temporal abstractions in autoregressive models enable hierarchical reinforcement learning},
     author = {Seijin Kobayashi and Yanick Schimpf and Maximilian Schlegel and Angelika Steger and Maciej Wolczyk and Johannes von Oswald and Nino Scherrer and Kaitlin Maile and Guillaume Lajoie and Blake A. Richards and Rif A. Saurous and James Manyika and Blaise Agüera y Arcas and Alexander Meulemans and João Sacramento},
-    year={2025},
+    year = {2025},
     eprint = {2512.20605},
     archivePrefix = {arXiv},
     primaryClass = {cs.LG},
@@ -57,7 +57,8 @@ MetaControllerOutput = namedtuple('MetaControllerOutput', (
     'prev_hiddens',
     'action_dist',
     'actions',
-    'kl_loss'
+    'kl_loss',
+    'switch_loss'
 ))

 class MetaController(Module):
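
The widened namedtuple means anything unpacking MetaControllerOutput positionally now receives five fields instead of four. A minimal sketch of the implication for downstream callers, with dummy values for illustration:

```python
from collections import namedtuple

import torch

# field order taken from the diff above
MetaControllerOutput = namedtuple('MetaControllerOutput', (
    'prev_hiddens',
    'action_dist',
    'actions',
    'kl_loss',
    'switch_loss'
))

out = MetaControllerOutput(None, None, None, torch.tensor(0.), torch.tensor(0.))

# positional unpacking must now account for the fifth field
hiddens, action_dist, actions, kl_loss, switch_loss = out
```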
@@ -173,7 +174,7 @@ class MetaController(Module):

         # need to encourage normal distribution

-        kl_loss = self.zero
+        kl_loss = switch_loss = self.zero

         if discovery_phase:
             mean, log_var = action_dist.unbind(dim = -1)
@@ -188,6 +189,10 @@ class MetaController(Module):
             kl_loss = kl_loss * switch_beta
             kl_loss = kl_loss.sum(dim = -1).mean()

+            # encourage less switching
+
+            switch_loss = switch_beta.mean()
+
         # maybe hard switch, then use associative scan

         if hard_switch:
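
Taken together, the two hunks above define the discovery-phase regularization: the per-dimension Gaussian KL is gated by the soft switch probabilities, and the new switch_loss penalizes the mean switch probability itself. A hedged sketch, with toy shapes and the standard-normal KL formula assumed for illustration:

```python
import torch

# toy shapes: (batch, time, latent dim); switch_beta holds soft switch
# probabilities in [0, 1]
mean = torch.randn(2, 16, 8)
log_var = torch.randn(2, 16, 8)
switch_beta = torch.rand(2, 16, 1)

# KL(N(mean, var) || N(0, 1)) per latent dimension
kl = 0.5 * (mean.pow(2) + log_var.exp() - log_var - 1.)

# pay the KL only where a switch occurs, as in the diff
kl_loss = (kl * switch_beta).sum(dim = -1).mean()

# encourage less switching
switch_loss = switch_beta.mean()
```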
@@ -220,7 +225,7 @@ class MetaController(Module):
             next_switch_gated_action
         )

-        return modified_residual_stream, MetaControllerOutput(next_hiddens, action_dist, sampled_action, kl_loss)
+        return modified_residual_stream, MetaControllerOutput(next_hiddens, action_dist, sampled_action, kl_loss, switch_loss)

 # main transformer, which is subsumed into the environment after behavioral cloning

@@ -308,7 +313,8 @@ class Transformer(Module):

         # handle maybe behavioral cloning

-        if behavioral_cloning:
+        if behavioral_cloning or (meta_controlling and discovery_phase):
+
             state, target_state = state[:, :-1], state[:, 1:]
             action_ids, target_action_ids = action_ids[:, :-1], action_ids[:, 1:]

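The widened condition means the discovery phase now reuses the behavioral-cloning teacher-forcing shift: inputs are every timestep but the last, targets every timestep but the first. A toy illustration of that slicing:

```python
import torch

state = torch.randn(1, 8, 4)              # (batch, time, state dim), toy sizes
action_ids = torch.randint(0, 5, (1, 8))  # (batch, time)

state, target_state = state[:, :-1], state[:, 1:]
action_ids, target_action_ids = action_ids[:, :-1], action_ids[:, 1:]

# each input position predicts the next timestep
assert state.shape[1] == target_state.shape[1] == 7
```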
@@ -352,6 +358,12 @@ class Transformer(Module):

             return state_clone_loss, action_clone_loss

+        elif meta_controlling and discovery_phase:
+
+            action_recon_loss = self.action_readout.calculate_loss(dist_params, target_action_ids)
+
+            return action_recon_loss, next_meta_hiddens.kl_loss, next_meta_hiddens.switch_loss
+
         # returning

         return_one = not (return_latents or return_cache)
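
In discovery phase the forward pass now returns a triple of losses rather than logits. A hedged sketch of consuming it, assuming model, state, and actions are constructed as in the test file further down; the 0.1 / 0.2 weights mirror the updated test rather than any prescribed setting:

```python
action_recon_loss, kl_loss, switch_loss = model(
    state,
    actions,
    meta_controller = meta_controller,
    discovery_phase = True
)

# weight and sum the three discovery-phase objectives, then backprop
loss = action_recon_loss + 0.1 * kl_loss + 0.2 * switch_loss
loss.backward()
```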
@@ -1,6 +1,6 @@
 [project]
 name = "metacontroller-pytorch"
-version = "0.0.15"
+version = "0.0.17"
 description = "Transformer Metacontroller"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -5,11 +5,9 @@ import torch
 from metacontroller.metacontroller import Transformer, MetaController

 @param('action_discrete', (False, True))
-@param('discovery_phase', (False, True))
 @param('switch_per_latent_dim', (False, True))
 def test_metacontroller(
     action_discrete,
-    discovery_phase,
     switch_per_latent_dim
 ):

@@ -24,7 +22,7 @@ def test_metacontroller(
     action_embed_readout = dict(num_continuous = 8)
     assert_shape = (8, 2)

-    # behavioral cloning pahse
+    # behavioral cloning phase

     model = Transformer(
         dim = 512,
@@ -44,14 +42,23 @@ def test_metacontroller(
         switch_per_latent_dim = switch_per_latent_dim
     )

-    logits, cache = model(state, actions, meta_controller = meta_controller, discovery_phase = discovery_phase, return_cache = True)
+    # discovery phase
+
+    (action_recon_loss, kl_loss, switch_loss) = model(state, actions, meta_controller = meta_controller, discovery_phase = True)
+    (action_recon_loss + kl_loss * 0.1 + switch_loss * 0.2).backward()
+
+    # internal rl
+
+    logits, cache = model(state, actions, meta_controller = meta_controller, return_cache = True)

     assert logits.shape == (1, 1024, *assert_shape)

-    logits, cache = model(state, actions, meta_controller = meta_controller, discovery_phase = discovery_phase, return_cache = True, cache = cache)
-    logits, cache = model(state, actions, meta_controller = meta_controller, discovery_phase = discovery_phase, return_cache = True, cache = cache)
+    logits, cache = model(state, actions, meta_controller = meta_controller, return_cache = True, cache = cache)
+    logits, cache = model(state, actions, meta_controller = meta_controller, return_cache = True, cache = cache)

     assert logits.shape == (1, 1, *assert_shape)

+    # evolutionary strategies over grpo
+
     model.meta_controller = meta_controller
     model.evolve(1, lambda _: 1., noise_population_size = 2)
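
The tail of the test sketches the third phase: evolution strategies in place of GRPO. A hedged variant with the test's constant reward swapped for a stand-in callable; evolve's arguments are taken verbatim from the test, and reward_fn here is hypothetical:

```python
def reward_fn(model):
    # hypothetical: roll out the policy in an environment and return the
    # scalar episode return; the test simply returns 1.
    return 1.

model.meta_controller = meta_controller
model.evolve(1, reward_fn, noise_population_size = 2)
```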