metacontroller-pytorch 0.0.43__tar.gz → 0.0.46__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (19)
  1. {metacontroller_pytorch-0.0.43 → metacontroller_pytorch-0.0.46}/.gitignore +6 -0
  2. {metacontroller_pytorch-0.0.43 → metacontroller_pytorch-0.0.46}/PKG-INFO +1 -1
  3. {metacontroller_pytorch-0.0.43 → metacontroller_pytorch-0.0.46}/gather_babyai_trajs.py +81 -9
  4. {metacontroller_pytorch-0.0.43 → metacontroller_pytorch-0.0.46}/metacontroller/metacontroller.py +39 -4
  5. {metacontroller_pytorch-0.0.43 → metacontroller_pytorch-0.0.46}/pyproject.toml +1 -1
  6. {metacontroller_pytorch-0.0.43 → metacontroller_pytorch-0.0.46}/train_babyai.py +71 -43
  7. {metacontroller_pytorch-0.0.43 → metacontroller_pytorch-0.0.46}/train_behavior_clone_babyai.py +57 -24
  8. metacontroller_pytorch-0.0.43/test_babyai_e2e.sh +0 -35
  9. {metacontroller_pytorch-0.0.43 → metacontroller_pytorch-0.0.46}/.github/workflows/python-publish.yml +0 -0
  10. {metacontroller_pytorch-0.0.43 → metacontroller_pytorch-0.0.46}/.github/workflows/test.yml +0 -0
  11. {metacontroller_pytorch-0.0.43 → metacontroller_pytorch-0.0.46}/LICENSE +0 -0
  12. {metacontroller_pytorch-0.0.43 → metacontroller_pytorch-0.0.46}/README.md +0 -0
  13. {metacontroller_pytorch-0.0.43 → metacontroller_pytorch-0.0.46}/babyai_env.py +0 -0
  14. {metacontroller_pytorch-0.0.43 → metacontroller_pytorch-0.0.46}/fig1.png +0 -0
  15. {metacontroller_pytorch-0.0.43 → metacontroller_pytorch-0.0.46}/metacontroller/__init__.py +0 -0
  16. {metacontroller_pytorch-0.0.43 → metacontroller_pytorch-0.0.46}/metacontroller/metacontroller_with_binary_mapper.py +0 -0
  17. {metacontroller_pytorch-0.0.43 → metacontroller_pytorch-0.0.46}/metacontroller/transformer_with_resnet.py +0 -0
  18. {metacontroller_pytorch-0.0.43 → metacontroller_pytorch-0.0.46}/tests/test_metacontroller.py +0 -0
  19. /metacontroller_pytorch-0.0.43/train_baby_evo_strat.py → /metacontroller_pytorch-0.0.46/train_babyai_evo_strat.py +0 -0
{metacontroller_pytorch-0.0.43 → metacontroller_pytorch-0.0.46}/.gitignore

@@ -1,5 +1,11 @@
 replay-data/
 recordings/
+trajectories/
+wandb/
+checkpoints/
+*.sh
+*.out
+*.slurm
 
 # Byte-compiled / optimized / DLL files
 __pycache__/
{metacontroller_pytorch-0.0.43 → metacontroller_pytorch-0.0.46}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: metacontroller-pytorch
-Version: 0.0.43
+Version: 0.0.46
 Summary: Transformer Metacontroller
 Project-URL: Homepage, https://pypi.org/project/metacontroller/
 Project-URL: Repository, https://github.com/lucidrains/metacontroller
{metacontroller_pytorch-0.0.43 → metacontroller_pytorch-0.0.46}/gather_babyai_trajs.py

@@ -37,6 +37,9 @@ from minigrid.core.constants import OBJECT_TO_IDX
 
 from memmap_replay_buffer import ReplayBuffer
 
+# Difficulty thresholds based on mission length
+EASY_MAX_LENGTH = 30    # easy: 0 to 30
+MEDIUM_MAX_LENGTH = 75  # medium: 30 to 75, hard: > 75
 
 # helpers
 
@@ -46,6 +49,67 @@ def exists(val):
 def sample(prob):
     return random.random() < prob
 
+def get_mission_length(env_id, seed):
+    """
+    Get the mission length for a given seed.
+    Returns the length of the mission string.
+    """
+    env = gym.make(env_id, render_mode="rgb_array")
+    env.reset(seed=seed)
+    length = len(env.unwrapped.mission)
+    env.close()
+    return length
+
+def categorize_seeds_by_difficulty(env_id, num_seeds_per_level, level_difficulty=None):
+    """
+    Scan seeds and categorize them by difficulty based on mission length.
+
+    Args:
+        env_id: Environment ID
+        num_seeds_per_level: Number of seeds needed per difficulty level
+        level_difficulty: List of levels to collect seeds for.
+            Supported: 'easy', 'medium', 'hard'
+            If None, collects for ['easy', 'hard'].
+        max_seed_to_scan: Maximum seed value to scan
+
+    Returns:
+        dict with keys for each requested level, each containing a list of seeds
+    """
+
+    seeds = {level: [] for level in level_difficulty}
+
+    total_needed = sum(num_seeds_per_level for _ in level_difficulty)
+    print(f"Scanning seeds to categorize by difficulty (need {num_seeds_per_level} per level for {level_difficulty})...")
+
+    with tqdm(total=total_needed, desc="Categorizing seeds") as pbar:
+        seed = 1
+        all_done = False
+        while not all_done:
+            # Check if we have enough seeds for all requested levels
+            all_done = all(len(seeds[level]) >= num_seeds_per_level for level in level_difficulty)
+
+            try:
+                mission_length = get_mission_length(env_id, seed)
+
+                # easy: mission length <= 30
+                if 'easy' in level_difficulty and mission_length <= EASY_MAX_LENGTH and len(seeds['easy']) < num_seeds_per_level:
+                    seeds['easy'].append(seed)
+                    pbar.update(1)
+                # medium: mission length <= 75 (combines easy and medium)
+                elif 'medium' in level_difficulty and mission_length <= MEDIUM_MAX_LENGTH and len(seeds['medium']) < num_seeds_per_level:
+                    seeds['medium'].append(seed)
+                    pbar.update(1)
+                # hard: mission length > 75
+                elif 'hard' in level_difficulty and mission_length > MEDIUM_MAX_LENGTH and len(seeds['hard']) < num_seeds_per_level:
+                    seeds['hard'].append(seed)
+                    pbar.update(1)
+            except Exception as e:
+                logger.warning(f"Error getting mission length for seed {seed}: {e}")
+
+            seed += 1
+
+    return seeds
+
 # wrapper, necessarily modified to allow for both rgb obs (policy) and symbolic obs (bot)
 
 class RGBImgPartialObsWrapper(ObservationWrapper):
@@ -128,7 +192,7 @@ def collect_single_episode(env_id, seed, num_steps, random_action_prob, state_sh
             env.close()
             return None, None, False, 0
 
-        episode_state[_step] = state_obs["rgb_image"] / 255. # normalizd to 0 to 1
+        episode_state[_step] = state_obs["rgb_image"]
         episode_action[_step] = action
 
         state_obs, reward, terminated, truncated, info = env.step(action)
@@ -151,6 +215,7 @@ def collect_demonstrations(
    num_steps = 500,
    random_action_prob = 0.05,
    num_workers = None,
+   difficulty = "easy",
    output_dir = "babyai-minibosslevel-trajectories"
 ):
    """
@@ -178,11 +243,9 @@
 
    total_episodes = num_seeds * num_episodes_per_seed
 
-   # Prepare seeds for all episodes
-   seeds = []
-   for count in range(num_seeds):
-       for it in range(num_episodes_per_seed):
-           seeds.append(count + 1)
+   # Collect seeds by difficulty
+   assert difficulty in ['easy', 'medium', 'hard']
+   seeds = categorize_seeds_by_difficulty(env_id, num_seeds_per_level=num_seeds, level_difficulty=[difficulty])
 
    successful = 0
    progressbar = tqdm(total=total_episodes)
@@ -203,14 +266,17 @@
    )
 
    # Parallel execution with bounded pending futures to avoid OOM
-   max_pending = num_workers * 4
+   max_pending = num_workers
+
+   # Flatten seeds: repeat each seed num_episodes_per_seed times
+   all_seeds = seeds[difficulty] * num_episodes_per_seed
 
    with ProcessPoolExecutor(max_workers=num_workers) as executor:
-       seed_iter = iter(seeds)
+       seed_iter = iter(all_seeds)
        futures = {}
 
        # Initial batch of submissions
-       for _ in range(min(max_pending, len(seeds))):
+       for _ in range(min(max_pending, len(all_seeds))):
            seed = next(seed_iter, None)
            if exists(seed):
                future = executor.submit(collect_single_episode, env_id, seed, num_steps, random_action_prob, state_shape)
@@ -244,7 +310,13 @@
    buffer.flush()
    progressbar.close()
 
+   # Save the seeds used for reproducibility
+   seeds_array = np.array(seeds[difficulty])
+   seeds_path = output_folder / "seeds.npy"
+   np.save(seeds_path, seeds_array)
+
    logger.info(f"Saved {successful} trajectories to {output_dir}")
+   logger.info(f"Saved {len(seeds_array)} seeds to {seeds_path}")
 
 if __name__ == "__main__":
    fire.Fire(collect_demonstrations)
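Note: to illustrate the new difficulty-aware seed selection, here is a minimal sketch (not part of the package) that calls categorize_seeds_by_difficulty directly; the environment id and seed counts are placeholders, and an explicit level list is passed since the diff adds no handling for the documented None default. The same filter is reached from the command line through the new difficulty argument of collect_demonstrations (e.g. --difficulty hard).

    from gather_babyai_trajs import categorize_seeds_by_difficulty

    # seeds whose BabyAI mission string is <= 30 chars land in 'easy',
    # > 75 chars in 'hard' (see EASY_MAX_LENGTH / MEDIUM_MAX_LENGTH above)
    seeds = categorize_seeds_by_difficulty(
        env_id = "BabyAI-MiniBossLevel-v0",
        num_seeds_per_level = 10,
        level_difficulty = ['easy', 'hard']
    )

    print(len(seeds['easy']), len(seeds['hard']))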
{metacontroller_pytorch-0.0.43 → metacontroller_pytorch-0.0.46}/metacontroller/metacontroller.py

@@ -54,6 +54,19 @@ def default(*args):
 def straight_through(src, tgt):
     return tgt + src - src.detach()
 
+# losses
+
+BehavioralCloningLosses = namedtuple('BehavioralCloningLosses', (
+    'state',
+    'action'
+))
+
+DiscoveryLosses = namedtuple('DiscoveryLosses', (
+    'action_recon',
+    'kl',
+    'switch'
+))
+
 # meta controller
 
 MetaControllerOutput = namedtuple('MetaControllerOutput', (
@@ -407,10 +420,19 @@ class Transformer(Module):
 
         # meta controller
 
-        self.meta_controller = meta_controller
+        self.meta_controller = meta_controller
 
         self.register_buffer('zero', tensor(0.), persistent = False)
 
+        # ensure devices match
+
+        if exists(self.meta_controller): self._ensure_consistent_device(self.meta_controller)
+
+    def _ensure_consistent_device(self, network):
+        self.model_device = next(self.parameters()).device
+        if next(network.parameters()).device != self.model_device:
+            network.to(self.model_device)
+
     def evolve(
         self,
         num_generations,
@@ -441,12 +463,15 @@ class Transformer(Module):
        return_raw_action_dist = False,
        return_latents = False,
        return_cache = False,
-       episode_lens: Tensor | None = None
+       episode_lens: Tensor | None = None,
+       return_meta_controller_output = False
    ):
        device = state.device
 
        # meta controller is either given or already given at init
 
+       if exists(meta_controller): self._ensure_consistent_device(meta_controller)
+
        meta_controller = default(meta_controller, self.meta_controller)
 
        if force_behavior_cloning:
@@ -533,13 +558,23 @@ class Transformer(Module):
 
            action_clone_loss = self.action_readout.calculate_loss(dist_params, target_actions, mask = loss_mask)
 
-           return state_clone_loss, action_clone_loss
+           losses = BehavioralCloningLosses(state_clone_loss, action_clone_loss)
+
+           if not return_meta_controller_output:
+               return losses
+
+           return losses, next_meta_hiddens
 
        elif discovery_phase:
 
            action_recon_loss = self.action_readout.calculate_loss(dist_params, target_actions)
 
-           return action_recon_loss, next_meta_hiddens.kl_loss, next_meta_hiddens.switch_loss
+           losses = DiscoveryLosses(action_recon_loss, next_meta_hiddens.kl_loss, next_meta_hiddens.switch_loss)
+
+           if not return_meta_controller_output:
+               return losses
+
+           return losses, next_meta_hiddens
 
        # returning
 
{metacontroller_pytorch-0.0.43 → metacontroller_pytorch-0.0.46}/pyproject.toml

@@ -1,6 +1,6 @@
 [project]
 name = "metacontroller-pytorch"
-version = "0.0.43"
+version = "0.0.46"
 description = "Transformer Metacontroller"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
{metacontroller_pytorch-0.0.43 → metacontroller_pytorch-0.0.46}/train_babyai.py

@@ -6,7 +6,8 @@
 # "memmap-replay-buffer>=0.0.12",
 # "metacontroller-pytorch",
 # "minigrid",
-# "tqdm"
+# "tqdm",
+# "wandb"
 # ]
 # ///
 
@@ -17,35 +18,46 @@ from shutil import rmtree
 from tqdm import tqdm
 
 import torch
-from torch import cat, tensor, stack
+from torch import cat, tensor, stack, Tensor
 from torch.optim import Adam
 
 from einops import rearrange
 
+from torch_einops_utils import pad_sequence
+
 from accelerate import Accelerator
 
 from babyai_env import create_env
+
 from memmap_replay_buffer import ReplayBuffer
+
 from metacontroller.metacontroller import Transformer, MetaController, policy_loss, z_score, extract_grpo_data
 from metacontroller.transformer_with_resnet import TransformerWithResnet
 
 # research entry point
 
 def reward_shaping_fn(
-    cumulative_rewards: torch.Tensor,
-    all_rewards: torch.Tensor,
-    episode_lens: torch.Tensor
-) -> torch.Tensor | None:
+    cumulative_rewards: Tensor,   # float(num_episodes,)
+    all_rewards: Tensor,          # float(num_episodes, max_timesteps)
+    episode_lens: Tensor,         # int(num_episodes,)
+    reject_threshold_cumulative_reward_variance: float = 0.
+) -> Tensor | None:
     """
     researchers can modify this function to engineer rewards
     or return None to reject the entire batch
-
-    cumulative_rewards: (num_episodes,)
-    all_rewards: (num_episodes, max_timesteps)
-    episode_lens: (num_episodes,)
     """
+
+    if cumulative_rewards.var() < reject_threshold_cumulative_reward_variance:
+        return None
+
     return cumulative_rewards
 
+def should_reject_group_based_on_switch_betas(
+    switch_betas: Tensor,
+    episode_lens: Tensor
+):
+    return switch_betas.sum().item() == 0.
+
 # helpers
 
 def exists(v):
@@ -68,12 +80,14 @@ def main(
     meta_controller_weights_path: str | None = None,
     output_meta_controller_path = 'metacontroller_rl_trained.pt',
     use_resnet = False,
+    num_epochs = 3,
     lr = 1e-4,
     batch_size = 16,
     num_groups = 16,
     max_grad_norm = 1.0,
     use_wandb = False,
-    wandb_project = 'metacontroller-babyai-rl'
+    wandb_project = 'metacontroller-babyai-rl',
+    reject_threshold_cumulative_reward_variance = 0.
 ):
     # accelerator
 
@@ -233,25 +247,37 @@ def main(
         cumulative_rewards = stack(all_cumulative_rewards)
         episode_lens = tensor(all_episode_lens)
 
-        # pad step rewards
-
         max_len = max(all_episode_lens)
-        padded_step_rewards = torch.zeros(num_episodes, max_len)
 
-        for i, (rewards, length) in enumerate(zip(all_step_rewards, all_episode_lens)):
-            padded_step_rewards[i, :length] = rewards
+        # pad step rewards
+
+        padded_step_rewards = pad_sequence(all_step_rewards, dim = 0)
 
         # reward shaping hook
 
-        shaped_rewards = reward_shaping_fn(cumulative_rewards, padded_step_rewards, episode_lens)
+        shaped_rewards = reward_shaping_fn(
+            cumulative_rewards,
+            padded_step_rewards,
+            episode_lens,
+            reject_threshold_cumulative_reward_variance = reject_threshold_cumulative_reward_variance
+        )
 
         if not exists(shaped_rewards):
+            accelerator.print(f'group rejected - variance of {cumulative_rewards.var().item():.4f} is lower than threshold of {reject_threshold_cumulative_reward_variance}')
             continue
 
         group_advantages = z_score(shaped_rewards)
 
         group_states, group_log_probs, group_switch_betas, group_latent_actions = zip(*all_episodes)
 
+        # whether to reject group based on switch betas (as it determines the mask for learning)
+
+        padded_group_switch_betas, episode_lens = pad_sequence(group_switch_betas, dim = 0, return_lens = True)
+
+        if should_reject_group_based_on_switch_betas(padded_group_switch_betas, episode_lens):
+            accelerator.print(f'group rejected - switch betas for the entire group does not meet criteria for learning')
+            continue
+
         for states, log_probs, switch_betas, latent_actions, advantages in zip(group_states, group_log_probs, group_switch_betas, group_latent_actions, group_advantages):
             replay_buffer.store_episode(
                 states = states,
@@ -264,43 +290,45 @@
         # learn
 
         if len(replay_buffer) >= buffer_size:
-            dl = replay_buffer.dataloader(batch_size = batch_size)
+            dl = replay_buffer.dataloader(batch_size = batch_size, shuffle = True)
             dl = accelerator.prepare(dl)
 
             meta_controller.train()
 
-            batch = next(iter(dl))
-
-            loss = meta_controller.policy_loss(
-                batch['states'],
-                batch['log_probs'],
-                batch['latent_actions'],
-                batch['advantages'],
-                batch['switch_betas'] == 1.,
-                episode_lens = batch['_lens']
-            )
+            for epoch in range(num_epochs):
+                for batch in dl:
+                    loss = meta_controller.policy_loss(
+                        batch['states'],
+                        batch['log_probs'],
+                        batch['latent_actions'],
+                        batch['advantages'],
+                        batch['switch_betas'] == 1.,
+                        episode_lens = batch['_lens']
+                    )
 
-            accelerator.backward(loss)
+                    accelerator.backward(loss)
 
-            grad_norm = accelerator.clip_grad_norm_(meta_controller.parameters(), max_grad_norm)
+                    grad_norm = accelerator.clip_grad_norm_(meta_controller.parameters(), max_grad_norm)
 
-            optim.step()
-            optim.zero_grad()
+                    optim.step()
+                    optim.zero_grad()
 
-            meta_controller.eval()
+                    pbar.set_postfix(
+                        epoch = epoch,
+                        loss = f'{loss.item():.4f}',
+                        grad_norm = f'{grad_norm.item():.4f}',
+                        reward = f'{cumulative_rewards.mean().item():.4f}'
+                    )
 
-            pbar.set_postfix(
-                loss = f'{loss.item():.4f}',
-                grad_norm = f'{grad_norm.item():.4f}',
-                reward = f'{cumulative_rewards.mean().item():.4f}'
-            )
+                    accelerator.log({
+                        'loss': loss.item(),
+                        'grad_norm': grad_norm.item(),
+                        'reward': cumulative_rewards.mean().item()
+                    })
 
-            accelerator.log({
-                'loss': loss.item(),
-                'grad_norm': grad_norm.item()
-            })
+            meta_controller.eval()
 
-            accelerator.print(f'loss: {loss.item():.4f}, grad_norm: {grad_norm.item():.4f}')
+            accelerator.print(f'loss: {loss.item():.4f}, grad_norm: {grad_norm.item():.4f}, reward: {cumulative_rewards.mean().item():.4f}')
 
     env.close()
 
{metacontroller_pytorch-0.0.43 → metacontroller_pytorch-0.0.46}/train_behavior_clone_babyai.py

@@ -18,7 +18,7 @@ from tqdm import tqdm
 from pathlib import Path
 
 import torch
-from torch.optim import Adam
+from torch.optim import AdamW
 from torch.utils.data import DataLoader
 
 from accelerate import Accelerator
@@ -31,14 +31,20 @@ from metacontroller.transformer_with_resnet import TransformerWithResnet
 import minigrid
 import gymnasium as gym
 
+# TODO: loss is still ~300 and it could be the resnet output?
+# TODO: changelog (paper hparams, checkpointing, difficulty levels in trajectory collection)
+
 def train(
     input_dir = "babyai-minibosslevel-trajectories",
     env_id = "BabyAI-MiniBossLevel-v0",
     cloning_epochs = 10,
     discovery_epochs = 10,
-    batch_size = 32,
+    batch_size = 128,
+    gradient_accumulation_steps = None,
     lr = 1e-4,
     discovery_lr = 1e-4,
+    weight_decay = 0.03,
+    discovery_weight_decay = 0.03,
     dim = 512,
     depth = 2,
     heads = 8,
@@ -47,6 +53,7 @@ def train(
     wandb_project = "metacontroller-babyai-bc",
     checkpoint_path = "transformer_bc.pt",
     meta_controller_checkpoint_path = "meta_controller_discovery.pt",
+    save_steps = 50,
     state_loss_weight = 1.,
     action_loss_weight = 1.,
     discovery_action_recon_loss_weight = 1.,
@@ -55,6 +62,22 @@ def train(
     max_grad_norm = 1.,
     use_resnet = False
 ):
+
+    def store_checkpoint(step:int):
+        if accelerator.is_main_process:
+
+            # Add step to checkpoint filenames
+            checkpoint_path_with_step = checkpoint_path.replace('.pt', f'_step_{step}.pt')
+            meta_controller_checkpoint_path_with_step = meta_controller_checkpoint_path.replace('.pt', f'_step_{step}.pt')
+
+            unwrapped_model = accelerator.unwrap_model(model)
+            unwrapped_model.save(checkpoint_path_with_step)
+
+            unwrapped_meta_controller = accelerator.unwrap_model(meta_controller)
+            unwrapped_meta_controller.save(meta_controller_checkpoint_path_with_step)
+
+            accelerator.print(f"Model saved to {checkpoint_path_with_step}, MetaController to {meta_controller_checkpoint_path_with_step}")
+
     # accelerator
 
     accelerator = Accelerator(log_with = "wandb" if use_wandb else None)
@@ -87,11 +110,13 @@ def train(
 
     # state shape and action dimension
     # state: (B, T, H, W, C) or (B, T, D)
+
     state_shape = replay_buffer.shapes['state']
     if use_resnet: state_dim = 256
     else: state_dim = int(torch.tensor(state_shape).prod().item())
 
     # deduce num_actions from the environment
+
     from babyai_env import create_env
     temp_env = create_env(env_id)
     num_actions = int(temp_env.action_space.n)
@@ -99,6 +124,10 @@ def train(
 
     accelerator.print(f"Detected state_dim: {state_dim}, num_actions: {num_actions} from env: {env_id}")
 
+    # meta controller
+
+    meta_controller = MetaController(dim)
+
     # transformer
 
     transformer_class = TransformerWithResnet if use_resnet else Transformer
@@ -108,18 +137,15 @@ def train(
         state_embed_readout = dict(num_continuous = state_dim),
         action_embed_readout = dict(num_discrete = num_actions),
         lower_body = dict(depth = depth, heads = heads, attn_dim_head = dim_head),
-        upper_body = dict(depth = depth, heads = heads, attn_dim_head = dim_head)
+        upper_body = dict(depth = depth, heads = heads, attn_dim_head = dim_head),
+        meta_controller = meta_controller
     )
 
-    # meta controller
-
-    meta_controller = MetaController(dim)
-
 
     # optimizer
 
-    optim_model = Adam(model.parameters(), lr = lr)
+    optim_model = AdamW(model.parameters(), lr = lr, weight_decay = weight_decay)
 
-    optim_meta_controller = Adam(meta_controller.discovery_parameters(), lr = discovery_lr)
+    optim_meta_controller = AdamW(meta_controller.discovery_parameters(), lr = discovery_lr, weight_decay = discovery_weight_decay)
 
     # prepare
 
@@ -127,6 +153,7 @@
 
     # training
 
+    gradient_step = 0
     for epoch in range(cloning_epochs + discovery_epochs):
 
         model.train()
@@ -154,13 +181,15 @@
             else: # flatten state: (B, T, 7, 7, 3) -> (B, T, 147)
                 states = rearrange(states, 'b t ... -> b t (...)')
 
+
             with accelerator.accumulate(model):
-                losses = model(
+                losses, meta_controller_output = model(
                     states,
                     actions,
                     episode_lens = episode_lens,
                     discovery_phase = is_discovering,
-                    meta_controller = meta_controller if is_discovering else None
+                    force_behavior_cloning = not is_discovering,
+                    return_meta_controller_output = True
                 )
 
                 if is_discovering:
@@ -175,7 +204,8 @@
                     log = dict(
                         action_recon_loss = action_recon_loss.item(),
                         kl_loss = kl_loss.item(),
-                        switch_loss = switch_loss.item()
+                        switch_loss = switch_loss.item(),
+                        switch_density = meta_controller_output.switch_beta.mean().item()
                     )
                 else:
                     state_loss, action_loss = losses
@@ -190,14 +220,19 @@
                         action_loss = action_loss.item(),
                     )
 
+                # gradient accumulation
+
+                if gradient_accumulation_steps is not None: loss /= gradient_accumulation_steps
+
                 # backprop
 
                 accelerator.backward(loss)
 
                 grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm = max_grad_norm)
 
-                optim.step()
-                optim.zero_grad()
+                if gradient_accumulation_steps is None or gradient_step % gradient_accumulation_steps == 0:
+                    optim.step()
+                    optim.zero_grad()
 
                 # log
 
@@ -211,23 +246,21 @@
                 })
 
             progress_bar.set_postfix(**log)
+            gradient_step += 1
+
+            # checkpoint
+
+            if gradient_step % save_steps == 0:
+                accelerator.wait_for_everyone()
+                store_checkpoint(gradient_step)
 
         avg_losses = {k: v / len(dataloader) for k, v in total_losses.items()}
         avg_losses_str = ", ".join([f"{k}={v:.4f}" for k, v in avg_losses.items()])
         accelerator.print(f"Epoch {epoch}: {avg_losses_str}")
 
     # save weights
-
     accelerator.wait_for_everyone()
-    if accelerator.is_main_process:
-
-        unwrapped_model = accelerator.unwrap_model(model)
-        unwrapped_model.save(checkpoint_path)
-
-        unwrapped_meta_controller = accelerator.unwrap_model(meta_controller)
-        unwrapped_meta_controller.save(meta_controller_checkpoint_path)
-
-        accelerator.print(f"Model saved to {checkpoint_path}, MetaController to {meta_controller_checkpoint_path}")
+    store_checkpoint(gradient_step)
 
     accelerator.end_training()
 
metacontroller_pytorch-0.0.43/test_babyai_e2e.sh (deleted)

@@ -1,35 +0,0 @@
-#!/bin/bash
-set -e
-
-# 1. Gather trajectories
-echo "Gathering trajectories..."
-uv run gather_babyai_trajs.py \
-    --num_seeds 100 \
-    --num_episodes_per_seed 10 \
-    --num_steps 500 \
-    --output_dir end_to_end_trajectories \
-    --env_id BabyAI-MiniBossLevel-v0
-
-# 2. Behavioral cloning
-echo "Training behavioral cloning model..."
-ACCELERATE_USE_CPU=true ACCELERATE_MIXED_PRECISION=no uv run train_behavior_clone_babyai.py \
-    --cloning_epochs 10 \
-    --discovery_epochs 10 \
-    --batch_size 256 \
-    --input_dir end_to_end_trajectories \
-    --env_id BabyAI-MiniBossLevel-v0 \
-    --checkpoint_path end_to_end_model.pt \
-    --use_resnet
-
-# 3. Inference rollouts
-echo "Running inference rollouts..."
-uv run train_babyai.py \
-    --transformer_weights_path end_to_end_model.pt \
-    --meta_controller_weights_path meta_controller_discovery.pt \
-    --env_name BabyAI-MiniBossLevel-v0 \
-    --num_episodes 1000 \
-    --buffer_size 1000 \
-    --max_timesteps 100 \
-    --num_groups 16 \
-    --lr 1e-4 \
-    --use_resnet