metacontroller-pytorch 0.0.42__tar.gz → 0.0.44__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (19)
  1. {metacontroller_pytorch-0.0.42 → metacontroller_pytorch-0.0.44}/.gitignore +6 -0
  2. {metacontroller_pytorch-0.0.42 → metacontroller_pytorch-0.0.44}/PKG-INFO +1 -1
  3. {metacontroller_pytorch-0.0.42 → metacontroller_pytorch-0.0.44}/gather_babyai_trajs.py +81 -9
  4. {metacontroller_pytorch-0.0.42 → metacontroller_pytorch-0.0.44}/metacontroller/metacontroller.py +13 -2
  5. {metacontroller_pytorch-0.0.42 → metacontroller_pytorch-0.0.44}/metacontroller/metacontroller_with_binary_mapper.py +1 -1
  6. {metacontroller_pytorch-0.0.42 → metacontroller_pytorch-0.0.44}/pyproject.toml +1 -1
  7. {metacontroller_pytorch-0.0.42 → metacontroller_pytorch-0.0.44}/tests/test_metacontroller.py +10 -2
  8. {metacontroller_pytorch-0.0.42 → metacontroller_pytorch-0.0.44}/train_babyai.py +20 -17
  9. metacontroller_pytorch-0.0.44/train_babyai_evo_strat.py +213 -0
  10. {metacontroller_pytorch-0.0.42 → metacontroller_pytorch-0.0.44}/train_behavior_clone_babyai.py +51 -22
  11. metacontroller_pytorch-0.0.42/test_babyai_e2e.sh +0 -35
  12. {metacontroller_pytorch-0.0.42 → metacontroller_pytorch-0.0.44}/.github/workflows/python-publish.yml +0 -0
  13. {metacontroller_pytorch-0.0.42 → metacontroller_pytorch-0.0.44}/.github/workflows/test.yml +0 -0
  14. {metacontroller_pytorch-0.0.42 → metacontroller_pytorch-0.0.44}/LICENSE +0 -0
  15. {metacontroller_pytorch-0.0.42 → metacontroller_pytorch-0.0.44}/README.md +0 -0
  16. {metacontroller_pytorch-0.0.42 → metacontroller_pytorch-0.0.44}/babyai_env.py +0 -0
  17. {metacontroller_pytorch-0.0.42 → metacontroller_pytorch-0.0.44}/fig1.png +0 -0
  18. {metacontroller_pytorch-0.0.42 → metacontroller_pytorch-0.0.44}/metacontroller/__init__.py +0 -0
  19. {metacontroller_pytorch-0.0.42 → metacontroller_pytorch-0.0.44}/metacontroller/transformer_with_resnet.py +0 -0

.gitignore
@@ -1,5 +1,11 @@
  replay-data/
  recordings/
+ trajectories/
+ wandb/
+ checkpoints/
+ *.sh
+ *.out
+ *.slurm

  # Byte-compiled / optimized / DLL files
  __pycache__/

PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: metacontroller-pytorch
- Version: 0.0.42
+ Version: 0.0.44
  Summary: Transformer Metacontroller
  Project-URL: Homepage, https://pypi.org/project/metacontroller/
  Project-URL: Repository, https://github.com/lucidrains/metacontroller

gather_babyai_trajs.py
@@ -37,6 +37,9 @@ from minigrid.core.constants import OBJECT_TO_IDX

  from memmap_replay_buffer import ReplayBuffer

+ # Difficulty thresholds based on mission length
+ EASY_MAX_LENGTH = 30 # easy: 0 to 30
+ MEDIUM_MAX_LENGTH = 75 # medium: 30 to 75, hard: > 75

  # helpers

@@ -46,6 +49,67 @@ def exists(val):
  def sample(prob):
      return random.random() < prob

+ def get_mission_length(env_id, seed):
+     """
+     Get the mission length for a given seed.
+     Returns the length of the mission string.
+     """
+     env = gym.make(env_id, render_mode="rgb_array")
+     env.reset(seed=seed)
+     length = len(env.unwrapped.mission)
+     env.close()
+     return length
+
+ def categorize_seeds_by_difficulty(env_id, num_seeds_per_level, level_difficulty=None):
+     """
+     Scan seeds and categorize them by difficulty based on mission length.
+
+     Args:
+         env_id: Environment ID
+         num_seeds_per_level: Number of seeds needed per difficulty level
+         level_difficulty: List of levels to collect seeds for.
+             Supported: 'easy', 'medium', 'hard'
+             If None, collects for ['easy', 'hard'].
+         max_seed_to_scan: Maximum seed value to scan
+
+     Returns:
+         dict with keys for each requested level, each containing a list of seeds
+     """
+
+     seeds = {level: [] for level in level_difficulty}
+
+     total_needed = sum(num_seeds_per_level for _ in level_difficulty)
+     print(f"Scanning seeds to categorize by difficulty (need {num_seeds_per_level} per level for {level_difficulty})...")
+
+     with tqdm(total=total_needed, desc="Categorizing seeds") as pbar:
+         seed = 1
+         all_done = False
+         while not all_done:
+             # Check if we have enough seeds for all requested levels
+             all_done = all(len(seeds[level]) >= num_seeds_per_level for level in level_difficulty)
+
+             try:
+                 mission_length = get_mission_length(env_id, seed)
+
+                 # easy: mission length <= 30
+                 if 'easy' in level_difficulty and mission_length <= EASY_MAX_LENGTH and len(seeds['easy']) < num_seeds_per_level:
+                     seeds['easy'].append(seed)
+                     pbar.update(1)
+                 # medium: mission length <= 75 (combines easy and medium)
+                 elif 'medium' in level_difficulty and mission_length <= MEDIUM_MAX_LENGTH and len(seeds['medium']) < num_seeds_per_level:
+                     seeds['medium'].append(seed)
+                     pbar.update(1)
+                 # hard: mission length > 75
+                 elif 'hard' in level_difficulty and mission_length > MEDIUM_MAX_LENGTH and len(seeds['hard']) < num_seeds_per_level:
+                     seeds['hard'].append(seed)
+                     pbar.update(1)
+             except Exception as e:
+                 logger.warning(f"Error getting mission length for seed {seed}: {e}")
+
+             seed += 1
+
+     return seeds
+
  # wrapper, necessarily modified to allow for both rgb obs (policy) and symbolic obs (bot)

  class RGBImgPartialObsWrapper(ObservationWrapper):
@@ -128,7 +192,7 @@ def collect_single_episode(env_id, seed, num_steps, random_action_prob, state_sh
  env.close()
  return None, None, False, 0

- episode_state[_step] = state_obs["rgb_image"] / 255. # normalizd to 0 to 1
+ episode_state[_step] = state_obs["rgb_image"]
  episode_action[_step] = action

  state_obs, reward, terminated, truncated, info = env.step(action)
@@ -151,6 +215,7 @@ def collect_demonstrations(
  num_steps = 500,
  random_action_prob = 0.05,
  num_workers = None,
+ difficulty = "easy",
  output_dir = "babyai-minibosslevel-trajectories"
  ):
  """
@@ -178,11 +243,9 @@

  total_episodes = num_seeds * num_episodes_per_seed

- # Prepare seeds for all episodes
- seeds = []
- for count in range(num_seeds):
-     for it in range(num_episodes_per_seed):
-         seeds.append(count + 1)
+ # Collect seeds by difficulty
+ assert difficulty in ['easy', 'medium', 'hard']
+ seeds = categorize_seeds_by_difficulty(env_id, num_seeds_per_level=num_seeds, level_difficulty=[difficulty])

  successful = 0
  progressbar = tqdm(total=total_episodes)
@@ -203,14 +266,17 @@
  )

  # Parallel execution with bounded pending futures to avoid OOM
- max_pending = num_workers * 4
+ max_pending = num_workers
+
+ # Flatten seeds: repeat each seed num_episodes_per_seed times
+ all_seeds = seeds[difficulty] * num_episodes_per_seed

  with ProcessPoolExecutor(max_workers=num_workers) as executor:
-     seed_iter = iter(seeds)
+     seed_iter = iter(all_seeds)
      futures = {}

      # Initial batch of submissions
-     for _ in range(min(max_pending, len(seeds))):
+     for _ in range(min(max_pending, len(all_seeds))):
          seed = next(seed_iter, None)
          if exists(seed):
              future = executor.submit(collect_single_episode, env_id, seed, num_steps, random_action_prob, state_shape)
@@ -244,7 +310,13 @@
  buffer.flush()
  progressbar.close()

+ # Save the seeds used for reproducibility
+ seeds_array = np.array(seeds[difficulty])
+ seeds_path = output_folder / "seeds.npy"
+ np.save(seeds_path, seeds_array)
+
  logger.info(f"Saved {successful} trajectories to {output_dir}")
+ logger.info(f"Saved {len(seeds_array)} seeds to {seeds_path}")

  if __name__ == "__main__":
      fire.Fire(collect_demonstrations)
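
Note: with the `difficulty` parameter introduced above, a collection run can be restricted to a single mission-length band. A minimal sketch of driving it from Python rather than the CLI (treating the script as an importable module is an assumption; the argument values are illustrative, not defaults):

    # hypothetical invocation: collect only 'hard' missions (mission string longer than MEDIUM_MAX_LENGTH)
    from gather_babyai_trajs import collect_demonstrations

    collect_demonstrations(
        env_id = 'BabyAI-MiniBossLevel-v0',
        num_seeds = 100,
        num_episodes_per_seed = 10,
        difficulty = 'hard',
        output_dir = 'babyai-hard-trajectories',
    )

The `seeds.npy` file written next to the trajectories records which seeds ended up in that band.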

metacontroller/metacontroller.py
@@ -291,7 +291,7 @@ class MetaController(Module):
  else:
      # else during inference, use the previous sampled latent action

-     assert seq_len == 1, f'inference RL phase must be done one token at a time'
+     assert seq_len == 1, 'inference RL phase must be done one token at a time - if replaying for policy optimization, please use `get_action_dist_for_internal_rl`'
      z_prev = prev_sampled_latent_action

  # switch input is previous latent action and the embedding
@@ -407,10 +407,19 @@ class Transformer(Module):

  # meta controller

- self.meta_controller = meta_controller
+ self.meta_controller = meta_controller

  self.register_buffer('zero', tensor(0.), persistent = False)

+ # ensure devices match
+
+ if exists(self.meta_controller): self._ensure_consistent_device(self.meta_controller)
+
+ def _ensure_consistent_device(self, network):
+     self.model_device = next(self.parameters()).device
+     if next(network.parameters()).device != self.model_device:
+         network.to(self.model_device)
+
  def evolve(
      self,
      num_generations,
@@ -447,6 +456,8 @@ class Transformer(Module):

  # meta controller is either given or already given at init

+ if exists(meta_controller): self._ensure_consistent_device(meta_controller)
+
  meta_controller = default(meta_controller, self.meta_controller)

  if force_behavior_cloning:

metacontroller/metacontroller_with_binary_mapper.py
@@ -241,7 +241,7 @@ class MetaControllerWithBinaryMapper(Module):
  if discovery_phase:
      z_prev = cat((prev_sampled_code, sampled_codes[:, :-1]), dim = 1)
  else:
-     assert seq_len == 1, f'inference RL phase must be done one token at a time'
+     assert seq_len == 1, 'inference RL phase must be done one token at a time - if replaying for policy optimization, please use `get_action_dist_for_internal_rl`'
      z_prev = prev_sampled_code

  switch_input = torch.cat((meta_embed, z_prev), dim=-1)

pyproject.toml
@@ -1,6 +1,6 @@
  [project]
  name = "metacontroller-pytorch"
- version = "0.0.42"
+ version = "0.0.44"
  description = "Transformer Metacontroller"
  authors = [
      { name = "Phil Wang", email = "lucidrains@gmail.com" }

tests/test_metacontroller.py
@@ -114,13 +114,14 @@ def test_metacontroller(
  for one_state in subset_state.unbind(dim = 1):
      one_state = rearrange(one_state, 'b d -> b 1 d')

-     logits, cache = model(one_state, past_action_id, meta_controller = meta_controller, return_cache = True)
+     logits, cache = model(one_state, past_action_id, meta_controller = meta_controller, cache = cache, return_cache = True)

      past_action_id = model.action_readout.sample(logits)

      # extract grpo data and store

-     grpo_data_list.append(extract_grpo_data(meta_controller, cache))
+     grpo_data = extract_grpo_data(meta_controller, cache)
+     grpo_data_list.append(grpo_data)

  # accumulate across time for the episode data

@@ -145,6 +146,13 @@
  # simulate a policy loss update over the entire group

  group_states, group_log_probs, group_switch_betas, group_latent_actions = map(partial(cat, dim = 0), zip(*all_episodes))
+
+ # parallel verification
+
+ parallel_action_dist = meta_controller.get_action_dist_for_internal_rl(group_states)
+ parallel_log_probs = meta_controller.log_prob(parallel_action_dist, group_latent_actions)
+
+ assert torch.allclose(parallel_log_probs, group_log_probs, atol = 1e-5), 'parallel log probs do not match stored log probs'

  for states, log_probs, switch_betas, latent_actions, advantages in zip(group_states, group_log_probs, group_switch_betas, group_latent_actions, group_advantages):
      replay_buffer.store_episode(

train_babyai.py
@@ -6,7 +6,8 @@
  # "memmap-replay-buffer>=0.0.12",
  # "metacontroller-pytorch",
  # "minigrid",
- # "tqdm"
+ # "tqdm",
+ # "wandb"
  # ]
  # ///

@@ -57,22 +58,23 @@ def default(v, d):
  # main

  def main(
-     env_name: str = 'BabyAI-BossLevel-v0',
-     num_episodes: int = int(10e6),
-     max_timesteps: int = 500,
-     buffer_size: int = 5_000,
-     render_every_eps: int = 1_000,
-     video_folder: str = './recordings',
+     env_name = 'BabyAI-BossLevel-v0',
+     num_episodes = int(10e6),
+     max_timesteps = 500,
+     buffer_size = 5_000,
+     render_every_eps = 1_000,
+     video_folder = './recordings',
      seed: int | None = None,
      transformer_weights_path: str | None = None,
      meta_controller_weights_path: str | None = None,
-     output_meta_controller_path: str = 'metacontroller_rl_trained.pt',
-     use_resnet: bool = False,
-     lr: float = 1e-4,
-     num_groups: int = 16,
-     max_grad_norm: float = 1.0,
-     use_wandb: bool = False,
-     wandb_project: str = 'metacontroller-babyai-rl'
+     output_meta_controller_path = 'metacontroller_rl_trained.pt',
+     use_resnet = False,
+     lr = 1e-4,
+     batch_size = 16,
+     num_groups = 16,
+     max_grad_norm = 1.0,
+     use_wandb = False,
+     wandb_project = 'metacontroller-babyai-rl'
  ):
  # accelerator

@@ -263,7 +265,7 @@ def main(
  # learn

  if len(replay_buffer) >= buffer_size:
-     dl = replay_buffer.dataloader(batch_size = num_groups)
+     dl = replay_buffer.dataloader(batch_size = batch_size)
      dl = accelerator.prepare(dl)

      meta_controller.train()
@@ -296,10 +298,11 @@

  accelerator.log({
      'loss': loss.item(),
-     'grad_norm': grad_norm.item()
+     'grad_norm': grad_norm.item(),
+     'reward': cumulative_rewards.mean().item()
  })

- accelerator.print(f'loss: {loss.item():.4f}, grad_norm: {grad_norm.item():.4f}')
+ accelerator.print(f'loss: {loss.item():.4f}, grad_norm: {grad_norm.item():.4f}, reward: {cumulative_rewards.mean().item():.4f}')

  env.close()


train_babyai_evo_strat.py (new file)
@@ -0,0 +1,213 @@
+ # /// script
+ # dependencies = [
+ # "fire",
+ # "gymnasium",
+ # "gymnasium[other]",
+ # "metacontroller-pytorch",
+ # "minigrid",
+ # "tqdm",
+ # "x-evolution",
+ # "einops"
+ # ]
+ # ///
+
+ from __future__ import annotations
+ import fire
+ from pathlib import Path
+ from shutil import rmtree
+ import numpy as np
+
+ import torch
+ from torch import nn, Tensor, tensor
+ from torch.nn import Module
+ from einops import rearrange
+
+ from babyai_env import create_env
+ from metacontroller.metacontroller import Transformer, MetaController
+
+ # functions
+
+ def exists(v):
+     return v is not None
+
+ def default(v, d):
+     return v if exists(v) else d
+
+ # default fitness function
+
+ def default_fitness_fn(
+     rewards: list[float],
+     states: list[any],
+     actions: list[any],
+     next_states: list[any],
+     infos: list[any]
+ ) -> float:
+     """
+     researchers can modify this function to engineer their own rewards and fitness scores
+     processing the entire episode at once for every noise vector of the population separately
+     """
+     return sum(rewards)
+
+ # babyai environment for ES
+
+ class BabyAIEnvironment(Module):
+     def __init__(
+         self,
+         env_id = 'BabyAI-BossLevel-v0',
+         video_folder = './recordings_babyai_es',
+         render_every_eps = 100,
+         max_steps = 500,
+         use_resnet = False,
+         fitness_fn = default_fitness_fn
+     ):
+         super().__init__()
+
+         self.env_id = env_id
+         self.video_folder = video_folder
+         self.render_every_eps = render_every_eps
+         self.max_steps = max_steps
+         self.use_resnet = use_resnet
+         self.fitness_fn = fitness_fn
+
+         # initial env creation for observation space etc. if needed
+         # but create_env is called inside pre_main_callback or reset
+         self.env = None
+
+     def pre_main_callback(self):
+         # clean up and initialize environment
+         rmtree(self.video_folder, ignore_errors = True)
+
+         self.env = create_env(
+             self.env_id,
+             render_mode = 'rgb_array',
+             video_folder = self.video_folder,
+             render_every_eps = self.render_every_eps
+         )
+
+     def forward(self, model):
+         device = next(model.parameters()).device
+
+         seed = torch.randint(0, int(1e6), ()).item()
+         state, _ = self.env.reset(seed = seed)
+
+         step = 0
+         cache = None
+         past_action_id = None
+
+         unwrapped_model = getattr(model, 'model', model)
+
+         episode_rewards = []
+         episode_states = []
+         episode_actions = []
+         episode_next_states = []
+         episode_infos = []
+
+         while step < self.max_steps:
+             image = state['image']
+             image_tensor = torch.from_numpy(image).float().to(device)
+
+             if self.use_resnet:
+                 image_tensor = rearrange(image_tensor, 'h w c -> 1 1 h w c')
+                 image_tensor = unwrapped_model.visual_encode(image_tensor)
+             else:
+                 image_tensor = rearrange(image_tensor, 'h w c -> 1 1 (h w c)')
+
+             if torch.is_tensor(past_action_id):
+                 past_action_id = past_action_id.long()
+
+             with torch.no_grad():
+                 logits, cache = model(
+                     image_tensor,
+                     past_action_id,
+                     return_cache = True,
+                     return_raw_action_dist = True,
+                     cache = cache
+                 )
+
+             action = unwrapped_model.action_readout.sample(logits)
+             past_action_id = action
+             action_id = action.squeeze()
+
+             next_state, reward, terminated, truncated, info = self.env.step(action_id.cpu().numpy().item())
+
+             episode_rewards.append(reward)
+             episode_states.append(state)
+             episode_actions.append(action_id)
+             episode_next_states.append(next_state)
+             episode_infos.append(info)
+
+             done = terminated or truncated
+             if done:
+                 break
+
+             state = next_state
+             step += 1
+
+         return self.fitness_fn(
+             episode_rewards,
+             episode_states,
+             episode_actions,
+             episode_next_states,
+             episode_infos
+         )
+
+ def main(
+     env_id = 'BabyAI-BossLevel-v0',
+     num_generations = 100,
+     max_steps = 500,
+     render_every_eps = 100,
+     video_folder = './recordings_babyai_es',
+     transformer_weights_path: str | None = None,
+     meta_controller_weights_path: str | None = None,
+     output_meta_controller_path = 'metacontroller_es_trained.pt',
+     use_resnet = False,
+     noise_population_size = 50,
+     noise_scale = 1e-2,
+     learning_rate = 1e-3,
+     fitness_fn = default_fitness_fn
+ ):
+     # load model
+
+     assert exists(transformer_weights_path), "Transformer weights must be provided"
+
+     # lazy import to avoid unnecessary dependencies if not used
+     from metacontroller.transformer_with_resnet import TransformerWithResnet as TransformerResnet
+     transformer_klass = TransformerResnet if use_resnet else Transformer
+
+     model = transformer_klass.init_and_load(transformer_weights_path, strict = False)
+     model.eval()
+
+     if exists(meta_controller_weights_path):
+         meta_controller = MetaController.init_and_load(meta_controller_weights_path, strict = False)
+         model.meta_controller = meta_controller
+
+     assert exists(model.meta_controller), "MetaController must be present for evolution"
+
+     # setup environment
+
+     babyai_env = BabyAIEnvironment(
+         env_id = env_id,
+         video_folder = video_folder,
+         render_every_eps = render_every_eps,
+         max_steps = max_steps,
+         use_resnet = use_resnet,
+         fitness_fn = fitness_fn
+     )
+
+     # evolve
+
+     model.evolve(
+         num_generations = num_generations,
+         environment = babyai_env,
+         noise_population_size = noise_population_size,
+         noise_scale = noise_scale,
+         learning_rate = learning_rate
+     )
+
+     # save
+
+     model.meta_controller.save(output_meta_controller_path)
+     print(f'MetaController weights saved to {output_meta_controller_path}')
+
+ if __name__ == '__main__':
+     fire.Fire(main)
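
The `fitness_fn` hook in the new script receives the whole episode (rewards, states, actions, next states, infos), so reward shaping never has to touch the rollout loop. A minimal sketch of a custom fitness function, matching only the signature of `default_fitness_fn` above (the shaping terms themselves are hypothetical):

    # hypothetical shaping: discount later rewards and penalize long episodes
    def length_penalized_fitness(rewards, states, actions, next_states, infos) -> float:
        discounted = sum(r * (0.99 ** t) for t, r in enumerate(rewards))
        return discounted - 0.001 * len(rewards)

It would be passed through `main(..., fitness_fn = length_penalized_fitness)` or directly to `BabyAIEnvironment(fitness_fn = ...)`.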

train_behavior_clone_babyai.py
@@ -18,7 +18,7 @@ from tqdm import tqdm
  from pathlib import Path

  import torch
- from torch.optim import Adam
+ from torch.optim import AdamW
  from torch.utils.data import DataLoader

  from accelerate import Accelerator
@@ -31,14 +31,20 @@ from metacontroller.transformer_with_resnet import TransformerWithResnet
  import minigrid
  import gymnasium as gym

+ # TODO: loss is still ~300 and it could be the resnet output?
+ # TODO: changelog (paper hparams, checkpointing, difficulty levels in trajectory collection)
+
  def train(
      input_dir = "babyai-minibosslevel-trajectories",
      env_id = "BabyAI-MiniBossLevel-v0",
      cloning_epochs = 10,
      discovery_epochs = 10,
-     batch_size = 32,
+     batch_size = 128,
+     gradient_accumulation_steps = None,
      lr = 1e-4,
      discovery_lr = 1e-4,
+     weight_decay = 0.03,
+     discovery_weight_decay = 0.03,
      dim = 512,
      depth = 2,
      heads = 8,
@@ -47,6 +53,7 @@ def train(
      wandb_project = "metacontroller-babyai-bc",
      checkpoint_path = "transformer_bc.pt",
      meta_controller_checkpoint_path = "meta_controller_discovery.pt",
+     save_steps = 50,
      state_loss_weight = 1.,
      action_loss_weight = 1.,
      discovery_action_recon_loss_weight = 1.,
@@ -55,6 +62,22 @@
      max_grad_norm = 1.,
      use_resnet = False
  ):
+
+     def store_checkpoint(step:int):
+         if accelerator.is_main_process:
+
+             # Add step to checkpoint filenames
+             checkpoint_path_with_step = checkpoint_path.replace('.pt', f'_step_{step}.pt')
+             meta_controller_checkpoint_path_with_step = meta_controller_checkpoint_path.replace('.pt', f'_step_{step}.pt')
+
+             unwrapped_model = accelerator.unwrap_model(model)
+             unwrapped_model.save(checkpoint_path_with_step)
+
+             unwrapped_meta_controller = accelerator.unwrap_model(meta_controller)
+             unwrapped_meta_controller.save(meta_controller_checkpoint_path_with_step)
+
+             accelerator.print(f"Model saved to {checkpoint_path_with_step}, MetaController to {meta_controller_checkpoint_path_with_step}")
+
  # accelerator

  accelerator = Accelerator(log_with = "wandb" if use_wandb else None)
@@ -99,6 +122,10 @@

  accelerator.print(f"Detected state_dim: {state_dim}, num_actions: {num_actions} from env: {env_id}")

+ # meta controller
+
+ meta_controller = MetaController(dim)
+
  # transformer

  transformer_class = TransformerWithResnet if use_resnet else Transformer
@@ -108,18 +135,15 @@
      state_embed_readout = dict(num_continuous = state_dim),
      action_embed_readout = dict(num_discrete = num_actions),
      lower_body = dict(depth = depth, heads = heads, attn_dim_head = dim_head),
-     upper_body = dict(depth = depth, heads = heads, attn_dim_head = dim_head)
+     upper_body = dict(depth = depth, heads = heads, attn_dim_head = dim_head),
+     meta_controller = meta_controller
  )

- # meta controller
-
- meta_controller = MetaController(dim)
-
  # optimizer

- optim_model = Adam(model.parameters(), lr = lr)
+ optim_model = AdamW(model.parameters(), lr = lr, weight_decay = weight_decay)

- optim_meta_controller = Adam(meta_controller.discovery_parameters(), lr = discovery_lr)
+ optim_meta_controller = AdamW(meta_controller.discovery_parameters(), lr = discovery_lr, weight_decay = discovery_weight_decay)

  # prepare

@@ -127,6 +151,7 @@

  # training

+ gradient_step = 0
  for epoch in range(cloning_epochs + discovery_epochs):

      model.train()
@@ -154,13 +179,14 @@
  else: # flatten state: (B, T, 7, 7, 3) -> (B, T, 147)
      states = rearrange(states, 'b t ... -> b t (...)')

+
  with accelerator.accumulate(model):
      losses = model(
          states,
          actions,
          episode_lens = episode_lens,
          discovery_phase = is_discovering,
-         meta_controller = meta_controller if is_discovering else None
+         force_behavior_cloning = not is_discovering
      )

      if is_discovering:
@@ -190,14 +216,19 @@
      action_loss = action_loss.item(),
  )

+ # gradient accumulation
+
+ if gradient_accumulation_steps is not None: loss /= gradient_accumulation_steps
+
  # backprop

  accelerator.backward(loss)

  grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm = max_grad_norm)

- optim.step()
- optim.zero_grad()
+ if gradient_accumulation_steps is None or gradient_step % gradient_accumulation_steps == 0:
+     optim.step()
+     optim.zero_grad()

  # log

@@ -211,23 +242,21 @@
  })

  progress_bar.set_postfix(**log)
+ gradient_step += 1
+
+ # checkpoint
+
+ if gradient_step % save_steps == 0:
+     accelerator.wait_for_everyone()
+     store_checkpoint(gradient_step)

  avg_losses = {k: v / len(dataloader) for k, v in total_losses.items()}
  avg_losses_str = ", ".join([f"{k}={v:.4f}" for k, v in avg_losses.items()])
  accelerator.print(f"Epoch {epoch}: {avg_losses_str}")

  # save weights
-
  accelerator.wait_for_everyone()
- if accelerator.is_main_process:
-
-     unwrapped_model = accelerator.unwrap_model(model)
-     unwrapped_model.save(checkpoint_path)
-
-     unwrapped_meta_controller = accelerator.unwrap_model(meta_controller)
-     unwrapped_meta_controller.save(meta_controller_checkpoint_path)
-
-     accelerator.print(f"Model saved to {checkpoint_path}, MetaController to {meta_controller_checkpoint_path}")
+ store_checkpoint(gradient_step)

  accelerator.end_training()

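For orientation, the accumulation flag added above only divides the loss and defers the optimizer step, so the effective batch per update is the product of the two settings, and checkpoints are written every `save_steps` gradient steps under suffixed names such as `transformer_bc_step_50.pt`. A small illustrative calculation (the accumulation value is hypothetical; only `batch_size = 128` is a default from this diff):

    batch_size = 128                  # per-forward batch (new default above)
    gradient_accumulation_steps = 4   # hypothetical; default is None, i.e. no accumulation
    effective_batch = batch_size * gradient_accumulation_steps  # 512 samples per optimizer update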

test_babyai_e2e.sh (deleted)
@@ -1,35 +0,0 @@
- #!/bin/bash
- set -e
-
- # 1. Gather trajectories
- echo "Gathering trajectories..."
- uv run gather_babyai_trajs.py \
-     --num_seeds 100 \
-     --num_episodes_per_seed 10 \
-     --num_steps 500 \
-     --output_dir end_to_end_trajectories \
-     --env_id BabyAI-MiniBossLevel-v0
-
- # 2. Behavioral cloning
- echo "Training behavioral cloning model..."
- ACCELERATE_USE_CPU=true ACCELERATE_MIXED_PRECISION=no uv run train_behavior_clone_babyai.py \
-     --cloning_epochs 10 \
-     --discovery_epochs 10 \
-     --batch_size 256 \
-     --input_dir end_to_end_trajectories \
-     --env_id BabyAI-MiniBossLevel-v0 \
-     --checkpoint_path end_to_end_model.pt \
-     --use_resnet
-
- # 3. Inference rollouts
- echo "Running inference rollouts..."
- uv run train_babyai.py \
-     --transformer_weights_path end_to_end_model.pt \
-     --meta_controller_weights_path meta_controller_discovery.pt \
-     --env_name BabyAI-MiniBossLevel-v0 \
-     --num_episodes 1000 \
-     --buffer_size 1000 \
-     --max_timesteps 100 \
-     --num_groups 16 \
-     --lr 1e-4 \
-     --use_resnet