metacontroller-pytorch 0.0.38__tar.gz → 0.0.41__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (21)
  1. {metacontroller_pytorch-0.0.38 → metacontroller_pytorch-0.0.41}/PKG-INFO +86 -1
  2. {metacontroller_pytorch-0.0.38 → metacontroller_pytorch-0.0.41}/README.md +85 -0
  3. metacontroller_pytorch-0.0.41/babyai_env.py +41 -0
  4. metacontroller_pytorch-0.0.41/metacontroller/__init__.py +1 -0
  5. {metacontroller_pytorch-0.0.38 → metacontroller_pytorch-0.0.41}/pyproject.toml +1 -1
  6. metacontroller_pytorch-0.0.41/test_babyai_e2e.sh +35 -0
  7. metacontroller_pytorch-0.0.41/train_babyai.py +314 -0
  8. {metacontroller_pytorch-0.0.38 → metacontroller_pytorch-0.0.41}/train_behavior_clone_babyai.py +2 -2
  9. metacontroller_pytorch-0.0.38/metacontroller/__init__.py +0 -1
  10. metacontroller_pytorch-0.0.38/test_babyai_e2e.sh +0 -14
  11. metacontroller_pytorch-0.0.38/train_babyai.py +0 -140
  12. {metacontroller_pytorch-0.0.38 → metacontroller_pytorch-0.0.41}/.github/workflows/python-publish.yml +0 -0
  13. {metacontroller_pytorch-0.0.38 → metacontroller_pytorch-0.0.41}/.github/workflows/test.yml +0 -0
  14. {metacontroller_pytorch-0.0.38 → metacontroller_pytorch-0.0.41}/.gitignore +0 -0
  15. {metacontroller_pytorch-0.0.38 → metacontroller_pytorch-0.0.41}/LICENSE +0 -0
  16. {metacontroller_pytorch-0.0.38 → metacontroller_pytorch-0.0.41}/fig1.png +0 -0
  17. {metacontroller_pytorch-0.0.38 → metacontroller_pytorch-0.0.41}/gather_babyai_trajs.py +0 -0
  18. {metacontroller_pytorch-0.0.38 → metacontroller_pytorch-0.0.41}/metacontroller/metacontroller.py +0 -0
  19. {metacontroller_pytorch-0.0.38 → metacontroller_pytorch-0.0.41}/metacontroller/metacontroller_with_binary_mapper.py +0 -0
  20. {metacontroller_pytorch-0.0.38 → metacontroller_pytorch-0.0.41}/metacontroller/transformer_with_resnet.py +0 -0
  21. {metacontroller_pytorch-0.0.38 → metacontroller_pytorch-0.0.41}/tests/test_metacontroller.py +0 -0
{metacontroller_pytorch-0.0.38 → metacontroller_pytorch-0.0.41}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: metacontroller-pytorch
- Version: 0.0.38
+ Version: 0.0.41
  Summary: Transformer Metacontroller
  Project-URL: Homepage, https://pypi.org/project/metacontroller/
  Project-URL: Repository, https://github.com/lucidrains/metacontroller
@@ -69,6 +69,91 @@ $ pip install metacontroller-pytorch

  - [Diego Calanzone](https://github.com/ddidacus) for proposing testing on BabyAI gridworld task, and submitting the [pull request](https://github.com/lucidrains/metacontroller/pull/3) for behavior cloning and discovery phase training for it!

+ ## Usage
+
+ ```python
+ import torch
+ from metacontroller import Transformer, MetaController
+
+ # 1. initialize model
+
+ model = Transformer(
+     dim = 512,
+     action_embed_readout = dict(num_discrete = 4),
+     state_embed_readout = dict(num_continuous = 384),
+     lower_body = dict(depth = 2),
+     upper_body = dict(depth = 2)
+ )
+
+ state = torch.randn(2, 128, 384)
+ actions = torch.randint(0, 4, (2, 128))
+
+ # 2. behavioral cloning (BC)
+
+ state_loss, action_loss = model(state, actions)
+ (state_loss + action_loss).backward()
+
+ # 3. discovery phase
+
+ meta_controller = MetaController(
+     dim_model = 512,
+     dim_meta_controller = 256,
+     dim_latent = 128
+ )
+
+ action_recon_loss, kl_loss, switch_loss = model(
+     state,
+     actions,
+     meta_controller = meta_controller,
+     discovery_phase = True
+ )
+
+ (action_recon_loss + kl_loss + switch_loss).backward()
+
+ # 4. internal rl phase (GRPO)
+
+ # ... collect trajectories ...
+
+ logits, cache = model(
+     one_state,
+     past_action_id,
+     meta_controller = meta_controller,
+     return_cache = True
+ )
+
+ meta_output = cache.prev_hiddens.meta_controller
+ old_log_probs = meta_controller.log_prob(meta_output.action_dist, meta_output.actions)
+
+ # ... calculate advantages ...
+
+ loss = meta_controller.policy_loss(
+     group_states,
+     group_old_log_probs,
+     group_latent_actions,
+     group_advantages,
+     group_switch_betas
+ )
+
+ loss.backward()
+ ```
+
+ Or using [evolutionary strategies](https://arxiv.org/abs/2511.16652) for the last portion
+
+ ```python
+ # 5. evolve (ES over GRPO)
+
+ model.meta_controller = meta_controller
+
+ def environment_callable(model):
+     # return a fitness score
+     return 1.0
+
+ model.evolve(
+     num_generations = 10,
+     environment = environment_callable
+ )
+ ```
+
  ## Citations

  ```bibtex
{metacontroller_pytorch-0.0.38 → metacontroller_pytorch-0.0.41}/README.md
@@ -16,6 +16,91 @@ $ pip install metacontroller-pytorch

  - [Diego Calanzone](https://github.com/ddidacus) for proposing testing on BabyAI gridworld task, and submitting the [pull request](https://github.com/lucidrains/metacontroller/pull/3) for behavior cloning and discovery phase training for it!

+ ## Usage
+
+ ```python
+ import torch
+ from metacontroller import Transformer, MetaController
+
+ # 1. initialize model
+
+ model = Transformer(
+     dim = 512,
+     action_embed_readout = dict(num_discrete = 4),
+     state_embed_readout = dict(num_continuous = 384),
+     lower_body = dict(depth = 2),
+     upper_body = dict(depth = 2)
+ )
+
+ state = torch.randn(2, 128, 384)
+ actions = torch.randint(0, 4, (2, 128))
+
+ # 2. behavioral cloning (BC)
+
+ state_loss, action_loss = model(state, actions)
+ (state_loss + action_loss).backward()
+
+ # 3. discovery phase
+
+ meta_controller = MetaController(
+     dim_model = 512,
+     dim_meta_controller = 256,
+     dim_latent = 128
+ )
+
+ action_recon_loss, kl_loss, switch_loss = model(
+     state,
+     actions,
+     meta_controller = meta_controller,
+     discovery_phase = True
+ )
+
+ (action_recon_loss + kl_loss + switch_loss).backward()
+
+ # 4. internal rl phase (GRPO)
+
+ # ... collect trajectories ...
+
+ logits, cache = model(
+     one_state,
+     past_action_id,
+     meta_controller = meta_controller,
+     return_cache = True
+ )
+
+ meta_output = cache.prev_hiddens.meta_controller
+ old_log_probs = meta_controller.log_prob(meta_output.action_dist, meta_output.actions)
+
+ # ... calculate advantages ...
+
+ loss = meta_controller.policy_loss(
+     group_states,
+     group_old_log_probs,
+     group_latent_actions,
+     group_advantages,
+     group_switch_betas
+ )
+
+ loss.backward()
+ ```
+
+ Or using [evolutionary strategies](https://arxiv.org/abs/2511.16652) for the last portion
+
+ ```python
+ # 5. evolve (ES over GRPO)
+
+ model.meta_controller = meta_controller
+
+ def environment_callable(model):
+     # return a fitness score
+     return 1.0
+
+ model.evolve(
+     num_generations = 10,
+     environment = environment_callable
+ )
+ ```
+
  ## Citations

  ```bibtex
metacontroller_pytorch-0.0.41/babyai_env.py
@@ -0,0 +1,41 @@
+ from pathlib import Path
+ from shutil import rmtree
+
+ import gymnasium as gym
+ import minigrid
+ from minigrid.wrappers import FullyObsWrapper, SymbolicObsWrapper
+
+ # functions
+
+ def divisible_by(num, den):
+     return (num % den) == 0
+
+ # env creation
+
+ def create_env(
+     env_id,
+     render_mode = 'rgb_array',
+     video_folder = None,
+     render_every_eps = 1000
+ ):
+     # register minigrid environments if needed
+     minigrid.register_minigrid_envs()
+
+     # environment
+     env = gym.make(env_id, render_mode = render_mode)
+     env = FullyObsWrapper(env)
+     env = SymbolicObsWrapper(env)
+
+     if video_folder is not None:
+         video_folder = Path(video_folder)
+         rmtree(video_folder, ignore_errors = True)
+
+         env = gym.wrappers.RecordVideo(
+             env = env,
+             video_folder = str(video_folder),
+             name_prefix = 'babyai',
+             episode_trigger = lambda eps_num: divisible_by(eps_num, render_every_eps),
+             disable_logger = True
+         )
+
+     return env
metacontroller_pytorch-0.0.41/metacontroller/__init__.py
@@ -0,0 +1 @@
+ from metacontroller.metacontroller import MetaController, Transformer
{metacontroller_pytorch-0.0.38 → metacontroller_pytorch-0.0.41}/pyproject.toml
@@ -1,6 +1,6 @@
  [project]
  name = "metacontroller-pytorch"
- version = "0.0.38"
+ version = "0.0.41"
  description = "Transformer Metacontroller"
  authors = [
      { name = "Phil Wang", email = "lucidrains@gmail.com" }
metacontroller_pytorch-0.0.41/test_babyai_e2e.sh
@@ -0,0 +1,35 @@
+ #!/bin/bash
+ set -e
+
+ # 1. Gather trajectories
+ echo "Gathering trajectories..."
+ uv run gather_babyai_trajs.py \
+     --num_seeds 100 \
+     --num_episodes_per_seed 10 \
+     --num_steps 500 \
+     --output_dir end_to_end_trajectories \
+     --env_id BabyAI-MiniBossLevel-v0
+
+ # 2. Behavioral cloning
+ echo "Training behavioral cloning model..."
+ ACCELERATE_USE_CPU=true ACCELERATE_MIXED_PRECISION=no uv run train_behavior_clone_babyai.py \
+     --cloning_epochs 10 \
+     --discovery_epochs 10 \
+     --batch_size 256 \
+     --input_dir end_to_end_trajectories \
+     --env_id BabyAI-MiniBossLevel-v0 \
+     --checkpoint_path end_to_end_model.pt \
+     --use_resnet
+
+ # 3. Inference rollouts
+ echo "Running inference rollouts..."
+ uv run train_babyai.py \
+     --transformer_weights_path end_to_end_model.pt \
+     --meta_controller_weights_path meta_controller_discovery.pt \
+     --env_name BabyAI-MiniBossLevel-v0 \
+     --num_episodes 1000 \
+     --buffer_size 1000 \
+     --max_timesteps 100 \
+     --num_groups 16 \
+     --lr 1e-4 \
+     --use_resnet
metacontroller_pytorch-0.0.41/train_babyai.py
@@ -0,0 +1,314 @@
+ # /// script
+ # dependencies = [
+ #   "fire",
+ #   "gymnasium",
+ #   "gymnasium[other]",
+ #   "memmap-replay-buffer>=0.0.12",
+ #   "metacontroller-pytorch",
+ #   "minigrid",
+ #   "tqdm"
+ # ]
+ # ///
+
+ from fire import Fire
+ from pathlib import Path
+ from functools import partial
+ from shutil import rmtree
+ from tqdm import tqdm
+
+ import torch
+ from torch import cat, tensor, stack
+ from torch.optim import Adam
+
+ from einops import rearrange
+
+ from accelerate import Accelerator
+
+ from babyai_env import create_env
+ from memmap_replay_buffer import ReplayBuffer
+ from metacontroller.metacontroller import Transformer, MetaController, policy_loss, z_score
+ from metacontroller.transformer_with_resnet import TransformerWithResnet
+
+ # research entry point
+
+ def reward_shaping_fn(
+     cumulative_rewards: torch.Tensor,
+     all_rewards: torch.Tensor,
+     episode_lens: torch.Tensor
+ ) -> torch.Tensor | None:
+     """
+     researchers can modify this function to engineer rewards
+     or return None to reject the entire batch
+
+     cumulative_rewards: (num_episodes,)
+     all_rewards: (num_episodes, max_timesteps)
+     episode_lens: (num_episodes,)
+     """
+     return cumulative_rewards
+
+ # helpers
+
+ def exists(v):
+     return v is not None
+
+ def default(v, d):
+     return v if exists(v) else d
+
+ # main
+
+ def main(
+     env_name: str = 'BabyAI-BossLevel-v0',
+     num_episodes: int = int(10e6),
+     max_timesteps: int = 500,
+     buffer_size: int = 5_000,
+     render_every_eps: int = 1_000,
+     video_folder: str = './recordings',
+     seed: int | None = None,
+     transformer_weights_path: str | None = None,
+     meta_controller_weights_path: str | None = None,
+     output_meta_controller_path: str = 'metacontroller_rl_trained.pt',
+     use_resnet: bool = False,
+     lr: float = 1e-4,
+     num_groups: int = 16,
+     max_grad_norm: float = 1.0,
+     use_wandb: bool = False,
+     wandb_project: str = 'metacontroller-babyai-rl'
+ ):
+     # accelerator
+
+     accelerator = Accelerator(log_with = 'wandb' if use_wandb else None)
+
+     if use_wandb:
+         accelerator.init_trackers(wandb_project)
+
+     # environment
+
+     env = create_env(
+         env_name,
+         render_mode = 'rgb_array',
+         video_folder = video_folder,
+         render_every_eps = render_every_eps
+     )
+
+     # load models
+
+     model = None
+     if exists(transformer_weights_path):
+         weights_path = Path(transformer_weights_path)
+         assert weights_path.exists(), f"transformer weights not found at {weights_path}"
+
+         transformer_klass = TransformerWithResnet if use_resnet else Transformer
+         model = transformer_klass.init_and_load(str(weights_path), strict = False)
+         model.eval()
+
+     meta_controller = None
+     if exists(meta_controller_weights_path):
+         weights_path = Path(meta_controller_weights_path)
+         assert weights_path.exists(), f"meta controller weights not found at {weights_path}"
+         meta_controller = MetaController.init_and_load(str(weights_path), strict = False)
+         meta_controller.eval()
+
+     meta_controller = default(meta_controller, getattr(model, 'meta_controller', None))
+     assert exists(meta_controller), "MetaController must be present for reinforcement learning"
+
+     # optimizer
+
+     optim = Adam(meta_controller.internal_rl_parameters(), lr = lr)
+
+     # prepare
+
+     model, meta_controller, optim = accelerator.prepare(model, meta_controller, optim)
+
+     unwrapped_model = accelerator.unwrap_model(model)
+     unwrapped_meta_controller = accelerator.unwrap_model(meta_controller)
+
+     # replay buffer
+
+     replay_buffer = ReplayBuffer(
+         './replay-data',
+         max_episodes = buffer_size,
+         max_timesteps = max_timesteps + 1,
+         fields = meta_controller.replay_buffer_field_dict,
+         meta_fields = dict(advantages = 'float'),
+         overwrite = True,
+         circular = True
+     )
+
+     # rollouts
+
+     num_batch_updates = num_episodes // num_groups
+
+     pbar = tqdm(range(num_batch_updates), desc = 'training')
+
+     for _ in pbar:
+
+         all_episodes = []
+         all_cumulative_rewards = []
+         all_step_rewards = []
+         all_episode_lens = []
+
+         group_seed = default(seed, torch.randint(0, 1000000, (1,)).item())
+
+         for _ in range(num_groups):
+
+             state, *_ = env.reset(seed = group_seed)
+
+             cache = None
+             past_action_id = None
+
+             states = []
+             log_probs = []
+             switch_betas = []
+             latent_actions = []
+
+             total_reward = 0.
+             step_rewards = []
+             episode_len = max_timesteps
+
+             for step in range(max_timesteps):
+
+                 image = state['image']
+                 image_tensor = torch.from_numpy(image).float().to(accelerator.device)
+
+                 if use_resnet:
+                     image_tensor = rearrange(image_tensor, 'h w c -> 1 1 h w c')
+                     image_tensor = model.visual_encode(image_tensor)
+                 else:
+                     image_tensor = rearrange(image_tensor, 'h w c -> 1 1 (h w c)')
+
+                 if torch.is_tensor(past_action_id):
+                     past_action_id = past_action_id.long()
+
+                 with torch.no_grad():
+                     logits, cache = unwrapped_model(
+                         image_tensor,
+                         past_action_id,
+                         meta_controller = unwrapped_meta_controller,
+                         return_cache = True,
+                         return_raw_action_dist = True,
+                         cache = cache
+                     )
+
+                 action = unwrapped_model.action_readout.sample(logits)
+                 past_action_id = action
+                 action = action.squeeze()
+
+                 # GRPO collection
+
+                 meta_output = cache.prev_hiddens.meta_controller
+                 old_log_probs = unwrapped_meta_controller.log_prob(meta_output.action_dist, meta_output.actions)
+
+                 states.append(meta_output.input_residual_stream)
+                 log_probs.append(old_log_probs)
+                 switch_betas.append(meta_output.switch_beta)
+                 latent_actions.append(meta_output.actions)
+
+                 next_state, reward, terminated, truncated, *_ = env.step(action.cpu().numpy())
+
+                 total_reward += reward
+                 step_rewards.append(reward)
+                 done = terminated or truncated
+
+                 if done:
+                     episode_len = step + 1
+                     break
+
+                 state = next_state
+
+             # store episode
+
+             all_episodes.append((
+                 cat(states, dim = 1).squeeze(0),
+                 cat(log_probs, dim = 1).squeeze(0),
+                 cat(switch_betas, dim = 1).squeeze(0),
+                 cat(latent_actions, dim = 1).squeeze(0)
+             ))
+
+             all_cumulative_rewards.append(tensor(total_reward))
+             all_step_rewards.append(tensor(step_rewards))
+             all_episode_lens.append(episode_len)
+
+         # compute advantages
+
+         cumulative_rewards = stack(all_cumulative_rewards)
+         episode_lens = tensor(all_episode_lens)
+
+         # pad step rewards
+
+         max_len = max(all_episode_lens)
+         padded_step_rewards = torch.zeros(num_episodes, max_len)
+
+         for i, (rewards, length) in enumerate(zip(all_step_rewards, all_episode_lens)):
+             padded_step_rewards[i, :length] = rewards
+
+         # reward shaping hook
+
+         shaped_rewards = reward_shaping_fn(cumulative_rewards, padded_step_rewards, episode_lens)
+
+         if not exists(shaped_rewards):
+             continue
+
+         group_advantages = z_score(shaped_rewards)
+
+         group_states, group_log_probs, group_switch_betas, group_latent_actions = zip(*all_episodes)
+
+         for states, log_probs, switch_betas, latent_actions, advantages in zip(group_states, group_log_probs, group_switch_betas, group_latent_actions, group_advantages):
+             replay_buffer.store_episode(
+                 states = states,
+                 log_probs = log_probs,
+                 switch_betas = switch_betas,
+                 latent_actions = latent_actions,
+                 advantages = advantages
+             )
+
+         # learn
+
+         if len(replay_buffer) >= buffer_size:
+             dl = replay_buffer.dataloader(batch_size = num_groups)
+             dl = accelerator.prepare(dl)
+
+             meta_controller.train()
+
+             batch = next(iter(dl))
+
+             loss = meta_controller.policy_loss(
+                 batch['states'],
+                 batch['log_probs'],
+                 batch['latent_actions'],
+                 batch['advantages'],
+                 batch['switch_betas'] == 1.,
+                 episode_lens = batch['_lens']
+             )
+
+             accelerator.backward(loss)
+
+             grad_norm = accelerator.clip_grad_norm_(meta_controller.parameters(), max_grad_norm)
+
+             optim.step()
+             optim.zero_grad()
+
+             meta_controller.eval()
+
+             pbar.set_postfix(
+                 loss = f'{loss.item():.4f}',
+                 grad_norm = f'{grad_norm.item():.4f}',
+                 reward = f'{cumulative_rewards.mean().item():.4f}'
+             )
+
+             accelerator.log({
+                 'loss': loss.item(),
+                 'grad_norm': grad_norm.item()
+             })
+
+             accelerator.print(f'loss: {loss.item():.4f}, grad_norm: {grad_norm.item():.4f}')
+
+     env.close()
+
+     # save
+
+     if exists(output_meta_controller_path):
+         unwrapped_meta_controller.save(output_meta_controller_path)
+         accelerator.print(f'MetaController weights saved to {output_meta_controller_path}')
+
+ if __name__ == '__main__':
+     Fire(main)
{metacontroller_pytorch-0.0.38 → metacontroller_pytorch-0.0.41}/train_behavior_clone_babyai.py
@@ -92,8 +92,8 @@ def train(
      else: state_dim = int(torch.tensor(state_shape).prod().item())

      # deduce num_actions from the environment
-     minigrid.register_minigrid_envs()
-     temp_env = gym.make(env_id)
+     from babyai_env import create_env
+     temp_env = create_env(env_id)
      num_actions = int(temp_env.action_space.n)
      temp_env.close()

metacontroller_pytorch-0.0.38/metacontroller/__init__.py
@@ -1 +0,0 @@
- from metacontroller.metacontroller import MetaController
metacontroller_pytorch-0.0.38/test_babyai_e2e.sh
@@ -1,14 +0,0 @@
- #!/bin/bash
- set -e
-
- # 1. Gather trajectories
- echo "Gathering trajectories..."
- uv run gather_babyai_trajs.py --num_seeds 1000 --num_episodes_per_seed 100 --output_dir end_to_end_trajectories --env_id BabyAI-MiniBossLevel-v0
-
- # 2. Behavioral cloning
- echo "Training behavioral cloning model..."
- uv run train_behavior_clone_babyai.py --cloning_epochs 10 --discovery_epochs 10 --batch_size 256 --input_dir end_to_end_trajectories --env_id BabyAI-MiniBossLevel-v0 --checkpoint_path end_to_end_model.pt --use_resnet
-
- # 3. Inference rollouts
- echo "Running inference rollouts..."
- uv run train_babyai.py --weights_path end_to_end_model.pt --env_name BabyAI-MiniBossLevel-v0 --num_episodes 5 --buffer_size 100 --max_timesteps 100
metacontroller_pytorch-0.0.38/train_babyai.py
@@ -1,140 +0,0 @@
- # /// script
- # dependencies = [
- #   "fire",
- #   "gymnasium",
- #   "gymnasium[other]",
- #   "memmap-replay-buffer>=0.0.12",
- #   "metacontroller-pytorch",
- #   "minigrid",
- #   "tqdm"
- # ]
- # ///
-
- from fire import Fire
- from tqdm import tqdm
- from shutil import rmtree
- from pathlib import Path
-
- import torch
- from einops import rearrange
-
- import gymnasium as gym
- import minigrid
- from minigrid.wrappers import FullyObsWrapper, SymbolicObsWrapper
-
- from memmap_replay_buffer import ReplayBuffer
- from metacontroller.metacontroller import Transformer
-
- # functions
-
- def exists(v):
-     return v is not None
-
- def default(v, d):
-     return v if exists(v) else d
-
- def divisible_by(num, den):
-     return (num % den) == 0
-
- # main
-
- def main(
-     env_name = 'BabyAI-BossLevel-v0',
-     num_episodes = int(10e6),
-     max_timesteps = 500,
-     buffer_size = 5_000,
-     render_every_eps = 1_000,
-     video_folder = './recordings',
-     seed = None,
-     weights_path = None
- ):
-
-     # environment
-
-     env = gym.make(env_name, render_mode = 'rgb_array')
-     env = FullyObsWrapper(env.unwrapped)
-     env = SymbolicObsWrapper(env.unwrapped)
-
-     rmtree(video_folder, ignore_errors = True)
-
-     env = gym.wrappers.RecordVideo(
-         env = env,
-         video_folder = video_folder,
-         name_prefix = 'babyai',
-         episode_trigger = lambda eps_num: divisible_by(eps_num, render_every_eps),
-         disable_logger = True
-     )
-
-     # maybe load model
-
-     model = None
-     if exists(weights_path):
-         weights_path = Path(weights_path)
-         assert weights_path.exists(), f"weights not found at {weights_path}"
-         model = Transformer.init_and_load(str(weights_path), strict = False)
-         model.eval()
-
-     # replay
-
-     replay_buffer = ReplayBuffer(
-         './replay-data',
-         max_episodes = buffer_size,
-         max_timesteps = max_timesteps + 1,
-         fields = dict(
-             action = 'int',
-             state_image = ('float', (7, 7, 3)),
-             state_direction = 'int'
-         ),
-         overwrite = True,
-         circular = True
-     )
-
-     # rollouts
-
-     for _ in tqdm(range(num_episodes)):
-
-         state, *_ = env.reset(seed = seed)
-
-         cache = None
-         past_action_id = None
-
-         for _ in range(max_timesteps):
-
-             if exists(model):
-                 # preprocess state
-                 # assume state is a dict with 'image'
-                 image = state['image']
-                 image_tensor = torch.from_numpy(image).float()
-                 image_tensor = rearrange(image_tensor, 'h w c -> 1 1 (h w c)')
-
-                 if exists(past_action_id) and torch.is_tensor(past_action_id):
-                     past_action_id = past_action_id.long()
-
-                 with torch.no_grad():
-                     logits, cache = model(
-                         image_tensor,
-                         past_action_id,
-                         return_cache = True,
-                         return_raw_action_dist = True,
-                         cache = cache
-                     )
-
-                 action = model.action_readout.sample(logits)
-                 past_action_id = action
-                 action = action.squeeze()
-             else:
-                 action = torch.randint(0, 7, ())
-
-             next_state, reward, terminated, truncated, *_ = env.step(action.cpu().numpy())
-
-             done = terminated or truncated
-
-             if done:
-                 break
-
-             state = next_state
-
-     env.close()
-
- if __name__ == '__main__':
-     Fire(main)
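
The new `train_babyai.py` in 0.0.41 routes each batch of rollout rewards through the `reward_shaping_fn` hook before advantages are z-scored for the GRPO update, and skips storage and learning entirely when the hook returns `None`. As a purely illustrative sketch (not part of either released version), a researcher might replace the default pass-through with something like the following; the tensor shapes follow the hook's docstring, and the length-penalty coefficient is an arbitrary placeholder:

```python
import torch

def reward_shaping_fn(
    cumulative_rewards: torch.Tensor,   # (num_episodes,)
    all_rewards: torch.Tensor,          # (num_episodes, max_timesteps), zero padded
    episode_lens: torch.Tensor          # (num_episodes,)
) -> torch.Tensor | None:
    # reject the whole batch if no episode collected any reward,
    # since every advantage would be zero after z-scoring anyway
    if cumulative_rewards.abs().sum() == 0:
        return None

    # favor shorter successful episodes with a small length penalty
    # (0.01 per step is an arbitrary, illustrative coefficient)
    length_penalty = 0.01 * episode_lens.float()
    return cumulative_rewards - length_penalty
```

Returning `None` causes the training loop to `continue` without writing the group to the replay buffer, which matches how the released script handles a rejected batch.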