metacontroller-pytorch 0.0.41__tar.gz → 0.0.43__tar.gz

This diff shows the changes between two publicly released versions of this package, as they appear in their public registry, and is provided for informational purposes only.
Files changed (19)
  1. {metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.43}/PKG-INFO +2 -2
  2. {metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.43}/README.md +1 -1
  3. {metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.43}/metacontroller/metacontroller.py +19 -1
  4. {metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.43}/metacontroller/metacontroller_with_binary_mapper.py +1 -1
  5. {metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.43}/pyproject.toml +1 -1
  6. {metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.43}/tests/test_metacontroller.py +16 -16
  7. metacontroller_pytorch-0.0.43/train_baby_evo_strat.py +213 -0
  8. {metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.43}/train_babyai.py +21 -21
  9. {metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.43}/.github/workflows/python-publish.yml +0 -0
  10. {metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.43}/.github/workflows/test.yml +0 -0
  11. {metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.43}/.gitignore +0 -0
  12. {metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.43}/LICENSE +0 -0
  13. {metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.43}/babyai_env.py +0 -0
  14. {metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.43}/fig1.png +0 -0
  15. {metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.43}/gather_babyai_trajs.py +0 -0
  16. {metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.43}/metacontroller/__init__.py +0 -0
  17. {metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.43}/metacontroller/transformer_with_resnet.py +0 -0
  18. {metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.43}/test_babyai_e2e.sh +0 -0
  19. {metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.43}/train_behavior_clone_babyai.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: metacontroller-pytorch
-Version: 0.0.41
+Version: 0.0.43
 Summary: Transformer Metacontroller
 Project-URL: Homepage, https://pypi.org/project/metacontroller/
 Project-URL: Repository, https://github.com/lucidrains/metacontroller
@@ -53,7 +53,7 @@ Description-Content-Type: text/markdown
 
 <img src="./fig1.png" width="400px"></img>
 
-## metacontroller (wip)
+## metacontroller
 
 Implementation of the MetaController proposed in [Emergent temporal abstractions in autoregressive models enable hierarchical reinforcement learning](https://arxiv.org/abs/2512.20605)
 
README.md
@@ -1,6 +1,6 @@
 <img src="./fig1.png" width="400px"></img>
 
-## metacontroller (wip)
+## metacontroller
 
 Implementation of the MetaController proposed in [Emergent temporal abstractions in autoregressive models enable hierarchical reinforcement learning](https://arxiv.org/abs/2512.20605)
 
metacontroller/metacontroller.py
@@ -66,6 +66,13 @@ MetaControllerOutput = namedtuple('MetaControllerOutput', (
     'switch_loss'
 ))
 
+GRPOOutput = namedtuple('GRPOOutput', (
+    'state',
+    'action',
+    'log_prob',
+    'switch_beta'
+))
+
 def z_score(t, eps = 1e-8):
     return (t - t.mean()) / (t.std() + eps)
 
@@ -107,6 +114,17 @@ def policy_loss(
 
     return masked_mean(losses, mask)
 
+def extract_grpo_data(meta_controller, transformer_output):
+    meta_output = transformer_output.prev_hiddens.meta_controller
+
+    state = meta_output.input_residual_stream
+    action = meta_output.actions
+    switch_beta = meta_output.switch_beta
+
+    log_prob = meta_controller.log_prob(meta_output.action_dist, action)
+
+    return GRPOOutput(state, action, log_prob, switch_beta)
+
 @save_load()
 class MetaController(Module):
     def __init__(
@@ -273,7 +291,7 @@ class MetaController(Module):
         else:
             # else during inference, use the previous sampled latent action
 
-            assert seq_len == 1, f'inference RL phase must be done one token at a time'
+            assert seq_len == 1, 'inference RL phase must be done one token at a time - if replaying for policy optimization, please use `get_action_dist_for_internal_rl`'
             z_prev = prev_sampled_latent_action
 
         # switch input is previous latent action and the embedding
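Taken together, `GRPOOutput` and `extract_grpo_data` bundle the per-step quantities needed for GRPO into a single call against the returned cache. Roughly, a rollout loop would use it as below; this is a sketch assembled from the test and training-script changes further down, and `states_over_time` (the per-step `(b, 1, d)` state tensors) plus the surrounding `model` / `meta_controller` setup are assumed to come from those files:

    from torch import cat
    from metacontroller.metacontroller import extract_grpo_data

    cache = None
    past_action_id = None
    grpo_data_list = []

    for one_state in states_over_time:
        # step the transformer with the meta controller attached, carrying the cache forward
        logits, cache = model(one_state, past_action_id, meta_controller = meta_controller, cache = cache, return_cache = True)
        past_action_id = model.action_readout.sample(logits)

        # pull (state, action, log_prob, switch_beta) out of the cached meta controller output
        grpo_data_list.append(extract_grpo_data(meta_controller, cache))

    # transpose into per-field tuples and concatenate over time for storage
    states, actions, log_probs, switch_betas = zip(*grpo_data_list)
    episode = tuple(cat(t, dim = 1) for t in (states, log_probs, switch_betas, actions))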
metacontroller/metacontroller_with_binary_mapper.py
@@ -241,7 +241,7 @@ class MetaControllerWithBinaryMapper(Module):
         if discovery_phase:
             z_prev = cat((prev_sampled_code, sampled_codes[:, :-1]), dim = 1)
         else:
-            assert seq_len == 1, f'inference RL phase must be done one token at a time'
+            assert seq_len == 1, 'inference RL phase must be done one token at a time - if replaying for policy optimization, please use `get_action_dist_for_internal_rl`'
             z_prev = prev_sampled_code
 
         switch_input = torch.cat((meta_embed, z_prev), dim=-1)
pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "metacontroller-pytorch"
-version = "0.0.41"
+version = "0.0.43"
 description = "Transformer Metacontroller"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
tests/test_metacontroller.py
@@ -7,7 +7,7 @@ from functools import partial
 
 import torch
 from torch import cat
-from metacontroller.metacontroller import Transformer, MetaController, policy_loss, z_score
+from metacontroller.metacontroller import Transformer, MetaController, policy_loss, z_score, extract_grpo_data
 from metacontroller.metacontroller_with_binary_mapper import MetaControllerWithBinaryMapper
 
 from memmap_replay_buffer import ReplayBuffer
@@ -109,36 +109,29 @@ def test_metacontroller(
     cache = None
     past_action_id = None
 
-    states = []
-    log_probs = []
-    switch_betas = []
-    latent_actions = []
+    grpo_data_list = []
 
     for one_state in subset_state.unbind(dim = 1):
        one_state = rearrange(one_state, 'b d -> b 1 d')
 
-        logits, cache = model(one_state, past_action_id, meta_controller = meta_controller, return_cache = True)
+        logits, cache = model(one_state, past_action_id, meta_controller = meta_controller, cache = cache, return_cache = True)
 
         past_action_id = model.action_readout.sample(logits)
 
-        # get log prob from meta controller latent actions
+        # extract grpo data and store
 
-        meta_output = cache.prev_hiddens.meta_controller
-
-        old_log_probs = meta_controller.log_prob(meta_output.action_dist, meta_output.actions)
-
-        states.append(meta_output.input_residual_stream)
-        log_probs.append(old_log_probs)
-        switch_betas.append(meta_output.switch_beta)
-        latent_actions.append(meta_output.actions)
+        grpo_data = extract_grpo_data(meta_controller, cache)
+        grpo_data_list.append(grpo_data)
 
     # accumulate across time for the episode data
 
+    states, actions, log_probs, switch_betas = zip(*grpo_data_list)
+
     all_episodes.append((
         cat(states, dim = 1),
         cat(log_probs, dim = 1),
         cat(switch_betas, dim = 1),
-        cat(latent_actions, dim = 1)
+        cat(actions, dim = 1)
     ))
 
     all_rewards.append(torch.randn(1))
@@ -153,6 +146,13 @@
     # simulate a policy loss update over the entire group
 
    group_states, group_log_probs, group_switch_betas, group_latent_actions = map(partial(cat, dim = 0), zip(*all_episodes))
+
+    # parallel verification
+
+    parallel_action_dist = meta_controller.get_action_dist_for_internal_rl(group_states)
+    parallel_log_probs = meta_controller.log_prob(parallel_action_dist, group_latent_actions)
+
+    assert torch.allclose(parallel_log_probs, group_log_probs, atol = 1e-5), 'parallel log probs do not match stored log probs'
 
     for states, log_probs, switch_betas, latent_actions, advantages in zip(group_states, group_log_probs, group_switch_betas, group_latent_actions, group_advantages):
         replay_buffer.store_episode(
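The refactored test leans on the fact that a namedtuple unpacks like a plain tuple, so `zip(*grpo_data_list)` transposes a list of per-step `GRPOOutput`s into per-field tuples ready for `cat`. A tiny self-contained illustration (the shapes here are made up and unrelated to the real test):

    import torch
    from collections import namedtuple

    GRPOOutput = namedtuple('GRPOOutput', ('state', 'action', 'log_prob', 'switch_beta'))

    # five fake timesteps for a batch of 2
    steps = [GRPOOutput(torch.randn(2, 1, 4), torch.randn(2, 1, 3), torch.randn(2, 1), torch.randn(2, 1)) for _ in range(5)]

    states, actions, log_probs, switch_betas = zip(*steps)
    assert torch.cat(states, dim = 1).shape == (2, 5, 4)   # time concatenated along dim 1

The added parallel check then replays the concatenated states through `get_action_dist_for_internal_rl` and asserts that the recomputed log probs match the ones stored step by step during the rollout.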
train_baby_evo_strat.py (new file)
@@ -0,0 +1,213 @@
+# /// script
+# dependencies = [
+#     "fire",
+#     "gymnasium",
+#     "gymnasium[other]",
+#     "metacontroller-pytorch",
+#     "minigrid",
+#     "tqdm",
+#     "x-evolution",
+#     "einops"
+# ]
+# ///
+
+from __future__ import annotations
+import fire
+from pathlib import Path
+from shutil import rmtree
+import numpy as np
+
+import torch
+from torch import nn, Tensor, tensor
+from torch.nn import Module
+from einops import rearrange
+
+from babyai_env import create_env
+from metacontroller.metacontroller import Transformer, MetaController
+
+# functions
+
+def exists(v):
+    return v is not None
+
+def default(v, d):
+    return v if exists(v) else d
+
+# default fitness function
+
+def default_fitness_fn(
+    rewards: list[float],
+    states: list[any],
+    actions: list[any],
+    next_states: list[any],
+    infos: list[any]
+) -> float:
+    """
+    researchers can modify this function to engineer their own rewards and fitness scores
+    processing the entire episode at once for every noise vector of the population separately
+    """
+    return sum(rewards)
+
+# babyai environment for ES
+
+class BabyAIEnvironment(Module):
+    def __init__(
+        self,
+        env_id = 'BabyAI-BossLevel-v0',
+        video_folder = './recordings_babyai_es',
+        render_every_eps = 100,
+        max_steps = 500,
+        use_resnet = False,
+        fitness_fn = default_fitness_fn
+    ):
+        super().__init__()
+
+        self.env_id = env_id
+        self.video_folder = video_folder
+        self.render_every_eps = render_every_eps
+        self.max_steps = max_steps
+        self.use_resnet = use_resnet
+        self.fitness_fn = fitness_fn
+
+        # initial env creation for observation space etc. if needed
+        # but create_env is called inside pre_main_callback or reset
+        self.env = None
+
+    def pre_main_callback(self):
+        # clean up and initialize environment
+        rmtree(self.video_folder, ignore_errors = True)
+
+        self.env = create_env(
+            self.env_id,
+            render_mode = 'rgb_array',
+            video_folder = self.video_folder,
+            render_every_eps = self.render_every_eps
+        )
+
+    def forward(self, model):
+        device = next(model.parameters()).device
+
+        seed = torch.randint(0, int(1e6), ()).item()
+        state, _ = self.env.reset(seed = seed)
+
+        step = 0
+        cache = None
+        past_action_id = None
+
+        unwrapped_model = getattr(model, 'model', model)
+
+        episode_rewards = []
+        episode_states = []
+        episode_actions = []
+        episode_next_states = []
+        episode_infos = []
+
+        while step < self.max_steps:
+            image = state['image']
+            image_tensor = torch.from_numpy(image).float().to(device)
+
+            if self.use_resnet:
+                image_tensor = rearrange(image_tensor, 'h w c -> 1 1 h w c')
+                image_tensor = unwrapped_model.visual_encode(image_tensor)
+            else:
+                image_tensor = rearrange(image_tensor, 'h w c -> 1 1 (h w c)')
+
+            if torch.is_tensor(past_action_id):
+                past_action_id = past_action_id.long()
+
+            with torch.no_grad():
+                logits, cache = model(
+                    image_tensor,
+                    past_action_id,
+                    return_cache = True,
+                    return_raw_action_dist = True,
+                    cache = cache
+                )
+
+            action = unwrapped_model.action_readout.sample(logits)
+            past_action_id = action
+            action_id = action.squeeze()
+
+            next_state, reward, terminated, truncated, info = self.env.step(action_id.cpu().numpy().item())
+
+            episode_rewards.append(reward)
+            episode_states.append(state)
+            episode_actions.append(action_id)
+            episode_next_states.append(next_state)
+            episode_infos.append(info)
+
+            done = terminated or truncated
+            if done:
+                break
+
+            state = next_state
+            step += 1
+
+        return self.fitness_fn(
+            episode_rewards,
+            episode_states,
+            episode_actions,
+            episode_next_states,
+            episode_infos
+        )
+
+def main(
+    env_id = 'BabyAI-BossLevel-v0',
+    num_generations = 100,
+    max_steps = 500,
+    render_every_eps = 100,
+    video_folder = './recordings_babyai_es',
+    transformer_weights_path: str | None = None,
+    meta_controller_weights_path: str | None = None,
+    output_meta_controller_path = 'metacontroller_es_trained.pt',
+    use_resnet = False,
+    noise_population_size = 50,
+    noise_scale = 1e-2,
+    learning_rate = 1e-3,
+    fitness_fn = default_fitness_fn
+):
+    # load model
+
+    assert exists(transformer_weights_path), "Transformer weights must be provided"
+
+    # lazy import to avoid unnecessary dependencies if not used
+    from metacontroller.transformer_with_resnet import TransformerWithResnet as TransformerResnet
+    transformer_klass = TransformerResnet if use_resnet else Transformer
+
+    model = transformer_klass.init_and_load(transformer_weights_path, strict = False)
+    model.eval()
+
+    if exists(meta_controller_weights_path):
+        meta_controller = MetaController.init_and_load(meta_controller_weights_path, strict = False)
+        model.meta_controller = meta_controller
+
+    assert exists(model.meta_controller), "MetaController must be present for evolution"
+
+    # setup environment
+
+    babyai_env = BabyAIEnvironment(
+        env_id = env_id,
+        video_folder = video_folder,
+        render_every_eps = render_every_eps,
+        max_steps = max_steps,
+        use_resnet = use_resnet,
+        fitness_fn = fitness_fn
+    )
+
+    # evolve
+
+    model.evolve(
+        num_generations = num_generations,
+        environment = babyai_env,
+        noise_population_size = noise_population_size,
+        noise_scale = noise_scale,
+        learning_rate = learning_rate
+    )
+
+    # save
+
+    model.meta_controller.save(output_meta_controller_path)
+    print(f'MetaController weights saved to {output_meta_controller_path}')
+
+if __name__ == '__main__':
+    fire.Fire(main)
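The new `train_baby_evo_strat.py` carries PEP 723 inline script metadata (the `# /// script` block) and exposes `main` through `fire.Fire`, so it can presumably be launched with a PEP 723-aware runner such as `uv run`, passing flags that mirror the keyword arguments of `main` (the weights path below is a placeholder):

    uv run train_baby_evo_strat.py --transformer_weights_path ./transformer.pt --noise_population_size 50 --num_generations 100

The evolution strategy behind `model.evolve(...)` appears to be supplied via the `x-evolution` dependency listed in the script header, and the fitness function defaults to summing the episode rewards and is explicitly meant to be swapped out.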
train_babyai.py
@@ -26,7 +26,7 @@ from accelerate import Accelerator
 
 from babyai_env import create_env
 from memmap_replay_buffer import ReplayBuffer
-from metacontroller.metacontroller import Transformer, MetaController, policy_loss, z_score
+from metacontroller.metacontroller import Transformer, MetaController, policy_loss, z_score, extract_grpo_data
 from metacontroller.transformer_with_resnet import TransformerWithResnet
 
 # research entry point
@@ -57,22 +57,23 @@ def default(v, d):
 # main
 
 def main(
-    env_name: str = 'BabyAI-BossLevel-v0',
-    num_episodes: int = int(10e6),
-    max_timesteps: int = 500,
-    buffer_size: int = 5_000,
-    render_every_eps: int = 1_000,
-    video_folder: str = './recordings',
+    env_name = 'BabyAI-BossLevel-v0',
+    num_episodes = int(10e6),
+    max_timesteps = 500,
+    buffer_size = 5_000,
+    render_every_eps = 1_000,
+    video_folder = './recordings',
     seed: int | None = None,
     transformer_weights_path: str | None = None,
     meta_controller_weights_path: str | None = None,
-    output_meta_controller_path: str = 'metacontroller_rl_trained.pt',
-    use_resnet: bool = False,
-    lr: float = 1e-4,
-    num_groups: int = 16,
-    max_grad_norm: float = 1.0,
-    use_wandb: bool = False,
-    wandb_project: str = 'metacontroller-babyai-rl'
+    output_meta_controller_path = 'metacontroller_rl_trained.pt',
+    use_resnet = False,
+    lr = 1e-4,
+    batch_size = 16,
+    num_groups = 16,
+    max_grad_norm = 1.0,
+    use_wandb = False,
+    wandb_project = 'metacontroller-babyai-rl'
 ):
     # accelerator
 
@@ -195,13 +196,12 @@
 
             # GRPO collection
 
-            meta_output = cache.prev_hiddens.meta_controller
-            old_log_probs = unwrapped_meta_controller.log_prob(meta_output.action_dist, meta_output.actions)
+            grpo_data = extract_grpo_data(unwrapped_meta_controller, cache)
 
-            states.append(meta_output.input_residual_stream)
-            log_probs.append(old_log_probs)
-            switch_betas.append(meta_output.switch_beta)
-            latent_actions.append(meta_output.actions)
+            states.append(grpo_data.state)
+            log_probs.append(grpo_data.log_prob)
+            switch_betas.append(grpo_data.switch_beta)
+            latent_actions.append(grpo_data.action)
 
             next_state, reward, terminated, truncated, *_ = env.step(action.cpu().numpy())
 
@@ -264,7 +264,7 @@
         # learn
 
         if len(replay_buffer) >= buffer_size:
-            dl = replay_buffer.dataloader(batch_size = num_groups)
+            dl = replay_buffer.dataloader(batch_size = batch_size)
             dl = accelerator.prepare(dl)
 
             meta_controller.train()
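Two small things change in the trainer itself: per-step GRPO bookkeeping now goes through `extract_grpo_data`, and the replay dataloader's batch size is decoupled from `num_groups`, which previously did double duty as both group size and dataloader batch size. How group advantages are computed is not shown in this diff, but the exported `z_score` helper suggests rewards are standardized within each group of rollouts; a minimal sketch under that assumption:

    import torch

    def z_score(t, eps = 1e-8):
        # same helper exported by metacontroller.metacontroller
        return (t - t.mean()) / (t.std() + eps)

    group_rewards = torch.tensor([0.1, 0.9, 0.4, 0.0])   # one scalar return per rollout in a group
    group_advantages = z_score(group_rewards)            # GRPO-style advantage: reward standardized within the group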