metacontroller-pytorch 0.0.40__tar.gz → 0.0.42__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (19)
  1. {metacontroller_pytorch-0.0.40 → metacontroller_pytorch-0.0.42}/PKG-INFO +87 -2
  2. {metacontroller_pytorch-0.0.40 → metacontroller_pytorch-0.0.42}/README.md +86 -1
  3. metacontroller_pytorch-0.0.42/metacontroller/__init__.py +1 -0
  4. {metacontroller_pytorch-0.0.40 → metacontroller_pytorch-0.0.42}/metacontroller/metacontroller.py +18 -0
  5. {metacontroller_pytorch-0.0.40 → metacontroller_pytorch-0.0.42}/pyproject.toml +1 -1
  6. {metacontroller_pytorch-0.0.40 → metacontroller_pytorch-0.0.42}/tests/test_metacontroller.py +7 -15
  7. {metacontroller_pytorch-0.0.40 → metacontroller_pytorch-0.0.42}/train_babyai.py +6 -7
  8. metacontroller_pytorch-0.0.40/metacontroller/__init__.py +0 -1
  9. {metacontroller_pytorch-0.0.40 → metacontroller_pytorch-0.0.42}/.github/workflows/python-publish.yml +0 -0
  10. {metacontroller_pytorch-0.0.40 → metacontroller_pytorch-0.0.42}/.github/workflows/test.yml +0 -0
  11. {metacontroller_pytorch-0.0.40 → metacontroller_pytorch-0.0.42}/.gitignore +0 -0
  12. {metacontroller_pytorch-0.0.40 → metacontroller_pytorch-0.0.42}/LICENSE +0 -0
  13. {metacontroller_pytorch-0.0.40 → metacontroller_pytorch-0.0.42}/babyai_env.py +0 -0
  14. {metacontroller_pytorch-0.0.40 → metacontroller_pytorch-0.0.42}/fig1.png +0 -0
  15. {metacontroller_pytorch-0.0.40 → metacontroller_pytorch-0.0.42}/gather_babyai_trajs.py +0 -0
  16. {metacontroller_pytorch-0.0.40 → metacontroller_pytorch-0.0.42}/metacontroller/metacontroller_with_binary_mapper.py +0 -0
  17. {metacontroller_pytorch-0.0.40 → metacontroller_pytorch-0.0.42}/metacontroller/transformer_with_resnet.py +0 -0
  18. {metacontroller_pytorch-0.0.40 → metacontroller_pytorch-0.0.42}/test_babyai_e2e.sh +0 -0
  19. {metacontroller_pytorch-0.0.40 → metacontroller_pytorch-0.0.42}/train_behavior_clone_babyai.py +0 -0
{metacontroller_pytorch-0.0.40 → metacontroller_pytorch-0.0.42}/PKG-INFO

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: metacontroller-pytorch
- Version: 0.0.40
+ Version: 0.0.42
  Summary: Transformer Metacontroller
  Project-URL: Homepage, https://pypi.org/project/metacontroller/
  Project-URL: Repository, https://github.com/lucidrains/metacontroller
@@ -53,7 +53,7 @@ Description-Content-Type: text/markdown

  <img src="./fig1.png" width="400px"></img>

- ## metacontroller (wip)
+ ## metacontroller

  Implementation of the MetaController proposed in [Emergent temporal abstractions in autoregressive models enable hierarchical reinforcement learning](https://arxiv.org/abs/2512.20605)

@@ -69,6 +69,91 @@ $ pip install metacontroller-pytorch

  - [Diego Calanzone](https://github.com/ddidacus) for proposing testing on BabyAI gridworld task, and submitting the [pull request](https://github.com/lucidrains/metacontroller/pull/3) for behavior cloning and discovery phase training for it!

+ ## Usage
+
+ ```python
+ import torch
+ from metacontroller import Transformer, MetaController
+
+ # 1. initialize model
+
+ model = Transformer(
+     dim = 512,
+     action_embed_readout = dict(num_discrete = 4),
+     state_embed_readout = dict(num_continuous = 384),
+     lower_body = dict(depth = 2),
+     upper_body = dict(depth = 2)
+ )
+
+ state = torch.randn(2, 128, 384)
+ actions = torch.randint(0, 4, (2, 128))
+
+ # 2. behavioral cloning (BC)
+
+ state_loss, action_loss = model(state, actions)
+ (state_loss + action_loss).backward()
+
+ # 3. discovery phase
+
+ meta_controller = MetaController(
+     dim_model = 512,
+     dim_meta_controller = 256,
+     dim_latent = 128
+ )
+
+ action_recon_loss, kl_loss, switch_loss = model(
+     state,
+     actions,
+     meta_controller = meta_controller,
+     discovery_phase = True
+ )
+
+ (action_recon_loss + kl_loss + switch_loss).backward()
+
+ # 4. internal rl phase (GRPO)
+
+ # ... collect trajectories ...
+
+ logits, cache = model(
+     one_state,
+     past_action_id,
+     meta_controller = meta_controller,
+     return_cache = True
+ )
+
+ meta_output = cache.prev_hiddens.meta_controller
+ old_log_probs = meta_controller.log_prob(meta_output.action_dist, meta_output.actions)
+
+ # ... calculate advantages ...
+
+ loss = meta_controller.policy_loss(
+     group_states,
+     group_old_log_probs,
+     group_latent_actions,
+     group_advantages,
+     group_switch_betas
+ )
+
+ loss.backward()
+ ```
+
+ Or using [evolutionary strategies](https://arxiv.org/abs/2511.16652) for the last portion
+
+ ```python
+ # 5. evolve (ES over GRPO)
+
+ model.meta_controller = meta_controller
+
+ def environment_callable(model):
+     # return a fitness score
+     return 1.0
+
+ model.evolve(
+     num_generations = 10,
+     environment = environment_callable
+ )
+ ```
+
  ## Citations

  ```bibtex
{metacontroller_pytorch-0.0.40 → metacontroller_pytorch-0.0.42}/README.md

@@ -1,6 +1,6 @@
  <img src="./fig1.png" width="400px"></img>

- ## metacontroller (wip)
+ ## metacontroller

  Implementation of the MetaController proposed in [Emergent temporal abstractions in autoregressive models enable hierarchical reinforcement learning](https://arxiv.org/abs/2512.20605)

@@ -16,6 +16,91 @@ $ pip install metacontroller-pytorch

  - [Diego Calanzone](https://github.com/ddidacus) for proposing testing on BabyAI gridworld task, and submitting the [pull request](https://github.com/lucidrains/metacontroller/pull/3) for behavior cloning and discovery phase training for it!

+ ## Usage
+
+ ```python
+ import torch
+ from metacontroller import Transformer, MetaController
+
+ # 1. initialize model
+
+ model = Transformer(
+     dim = 512,
+     action_embed_readout = dict(num_discrete = 4),
+     state_embed_readout = dict(num_continuous = 384),
+     lower_body = dict(depth = 2),
+     upper_body = dict(depth = 2)
+ )
+
+ state = torch.randn(2, 128, 384)
+ actions = torch.randint(0, 4, (2, 128))
+
+ # 2. behavioral cloning (BC)
+
+ state_loss, action_loss = model(state, actions)
+ (state_loss + action_loss).backward()
+
+ # 3. discovery phase
+
+ meta_controller = MetaController(
+     dim_model = 512,
+     dim_meta_controller = 256,
+     dim_latent = 128
+ )
+
+ action_recon_loss, kl_loss, switch_loss = model(
+     state,
+     actions,
+     meta_controller = meta_controller,
+     discovery_phase = True
+ )
+
+ (action_recon_loss + kl_loss + switch_loss).backward()
+
+ # 4. internal rl phase (GRPO)
+
+ # ... collect trajectories ...
+
+ logits, cache = model(
+     one_state,
+     past_action_id,
+     meta_controller = meta_controller,
+     return_cache = True
+ )
+
+ meta_output = cache.prev_hiddens.meta_controller
+ old_log_probs = meta_controller.log_prob(meta_output.action_dist, meta_output.actions)
+
+ # ... calculate advantages ...
+
+ loss = meta_controller.policy_loss(
+     group_states,
+     group_old_log_probs,
+     group_latent_actions,
+     group_advantages,
+     group_switch_betas
+ )
+
+ loss.backward()
+ ```
+
+ Or using [evolutionary strategies](https://arxiv.org/abs/2511.16652) for the last portion
+
+ ```python
+ # 5. evolve (ES over GRPO)
+
+ model.meta_controller = meta_controller
+
+ def environment_callable(model):
+     # return a fitness score
+     return 1.0
+
+ model.evolve(
+     num_generations = 10,
+     environment = environment_callable
+ )
+ ```
+
  ## Citations

  ```bibtex
metacontroller_pytorch-0.0.42/metacontroller/__init__.py

@@ -0,0 +1 @@
+ from metacontroller.metacontroller import MetaController, Transformer
{metacontroller_pytorch-0.0.40 → metacontroller_pytorch-0.0.42}/metacontroller/metacontroller.py

@@ -66,6 +66,13 @@ MetaControllerOutput = namedtuple('MetaControllerOutput', (
      'switch_loss'
  ))

+ GRPOOutput = namedtuple('GRPOOutput', (
+     'state',
+     'action',
+     'log_prob',
+     'switch_beta'
+ ))
+
  def z_score(t, eps = 1e-8):
      return (t - t.mean()) / (t.std() + eps)

@@ -107,6 +114,17 @@ def policy_loss(

      return masked_mean(losses, mask)

+ def extract_grpo_data(meta_controller, transformer_output):
+     meta_output = transformer_output.prev_hiddens.meta_controller
+
+     state = meta_output.input_residual_stream
+     action = meta_output.actions
+     switch_beta = meta_output.switch_beta
+
+     log_prob = meta_controller.log_prob(meta_output.action_dist, action)
+
+     return GRPOOutput(state, action, log_prob, switch_beta)
+
  @save_load()
  class MetaController(Module):
      def __init__(
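The two additions above, `GRPOOutput` and `extract_grpo_data`, replace the manual unpacking of `cache.prev_hiddens.meta_controller` during GRPO trajectory collection (see the test and training script hunks further down). The following is a minimal sketch of that collection loop using the new helper; it is not part of the diff. The model configuration and forward-call keywords are copied from the Usage section above, the dummy state tensor is for illustration only, and threading the running `cache` back into the forward call via a `cache` keyword is an assumption inferred from the `cache = None` initialization in the test.

```python
import torch
from torch import cat
from metacontroller import Transformer, MetaController
from metacontroller.metacontroller import extract_grpo_data

# model and meta controller configured as in the Usage section above
model = Transformer(
    dim = 512,
    action_embed_readout = dict(num_discrete = 4),
    state_embed_readout = dict(num_continuous = 384),
    lower_body = dict(depth = 2),
    upper_body = dict(depth = 2)
)

meta_controller = MetaController(
    dim_model = 512,
    dim_meta_controller = 256,
    dim_latent = 128
)

states = torch.randn(2, 16, 384)  # dummy (batch, time, state dim) trajectory, illustration only

cache = None
past_action_id = None
grpo_data_list = []

for one_state in states.unbind(dim = 1):
    one_state = one_state.unsqueeze(1)  # (batch, 1, state dim)

    logits, cache = model(
        one_state,
        past_action_id,
        meta_controller = meta_controller,
        cache = cache,        # assumed keyword for passing the running cache back in
        return_cache = True
    )

    past_action_id = model.action_readout.sample(logits)

    # one GRPOOutput(state, action, log_prob, switch_beta) per timestep
    grpo_data_list.append(extract_grpo_data(meta_controller, cache))

# transpose per-timestep namedtuples into per-field tuples, then concatenate over time
episode_states, episode_actions, episode_log_probs, episode_switch_betas = zip(*grpo_data_list)

episode = (
    cat(episode_states, dim = 1),
    cat(episode_log_probs, dim = 1),
    cat(episode_switch_betas, dim = 1),
    cat(episode_actions, dim = 1)
)
```

The test hunk below gathers these episode tuples into `all_episodes` before computing advantages and calling `policy_loss`, exactly as in the Usage section above.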
{metacontroller_pytorch-0.0.40 → metacontroller_pytorch-0.0.42}/pyproject.toml

@@ -1,6 +1,6 @@
  [project]
  name = "metacontroller-pytorch"
- version = "0.0.40"
+ version = "0.0.42"
  description = "Transformer Metacontroller"
  authors = [
      { name = "Phil Wang", email = "lucidrains@gmail.com" }
{metacontroller_pytorch-0.0.40 → metacontroller_pytorch-0.0.42}/tests/test_metacontroller.py

@@ -7,7 +7,7 @@ from functools import partial

  import torch
  from torch import cat
- from metacontroller.metacontroller import Transformer, MetaController, policy_loss, z_score
+ from metacontroller.metacontroller import Transformer, MetaController, policy_loss, z_score, extract_grpo_data
  from metacontroller.metacontroller_with_binary_mapper import MetaControllerWithBinaryMapper

  from memmap_replay_buffer import ReplayBuffer
@@ -109,10 +109,7 @@ def test_metacontroller(
  cache = None
  past_action_id = None

- states = []
- log_probs = []
- switch_betas = []
- latent_actions = []
+ grpo_data_list = []

  for one_state in subset_state.unbind(dim = 1):
      one_state = rearrange(one_state, 'b d -> b 1 d')
@@ -121,24 +118,19 @@

      past_action_id = model.action_readout.sample(logits)

-     # get log prob from meta controller latent actions
+     # extract grpo data and store

-     meta_output = cache.prev_hiddens.meta_controller
-
-     old_log_probs = meta_controller.log_prob(meta_output.action_dist, meta_output.actions)
-
-     states.append(meta_output.input_residual_stream)
-     log_probs.append(old_log_probs)
-     switch_betas.append(meta_output.switch_beta)
-     latent_actions.append(meta_output.actions)
+     grpo_data_list.append(extract_grpo_data(meta_controller, cache))

  # accumulate across time for the episode data

+ states, actions, log_probs, switch_betas = zip(*grpo_data_list)
+
  all_episodes.append((
      cat(states, dim = 1),
      cat(log_probs, dim = 1),
      cat(switch_betas, dim = 1),
-     cat(latent_actions, dim = 1)
+     cat(actions, dim = 1)
  ))

  all_rewards.append(torch.randn(1))
{metacontroller_pytorch-0.0.40 → metacontroller_pytorch-0.0.42}/train_babyai.py

@@ -26,7 +26,7 @@ from accelerate import Accelerator

  from babyai_env import create_env
  from memmap_replay_buffer import ReplayBuffer
- from metacontroller.metacontroller import Transformer, MetaController, policy_loss, z_score
+ from metacontroller.metacontroller import Transformer, MetaController, policy_loss, z_score, extract_grpo_data
  from metacontroller.transformer_with_resnet import TransformerWithResnet

  # research entry point
@@ -195,13 +195,12 @@ def main(

  # GRPO collection

- meta_output = cache.prev_hiddens.meta_controller
- old_log_probs = unwrapped_meta_controller.log_prob(meta_output.action_dist, meta_output.actions)
+ grpo_data = extract_grpo_data(unwrapped_meta_controller, cache)

- states.append(meta_output.input_residual_stream)
- log_probs.append(old_log_probs)
- switch_betas.append(meta_output.switch_beta)
- latent_actions.append(meta_output.actions)
+ states.append(grpo_data.state)
+ log_probs.append(grpo_data.log_prob)
+ switch_betas.append(grpo_data.switch_beta)
+ latent_actions.append(grpo_data.action)

  next_state, reward, terminated, truncated, *_ = env.step(action.cpu().numpy())

metacontroller_pytorch-0.0.40/metacontroller/__init__.py

@@ -1 +0,0 @@
- from metacontroller.metacontroller import MetaController