metacontroller-pytorch 0.0.37__tar.gz → 0.0.40__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (20)
  1. {metacontroller_pytorch-0.0.37 → metacontroller_pytorch-0.0.40}/PKG-INFO +13 -1
  2. {metacontroller_pytorch-0.0.37 → metacontroller_pytorch-0.0.40}/README.md +12 -0
  3. metacontroller_pytorch-0.0.40/babyai_env.py +41 -0
  4. {metacontroller_pytorch-0.0.37 → metacontroller_pytorch-0.0.40}/metacontroller/metacontroller.py +10 -0
  5. {metacontroller_pytorch-0.0.37 → metacontroller_pytorch-0.0.40}/metacontroller/metacontroller_with_binary_mapper.py +10 -1
  6. {metacontroller_pytorch-0.0.37 → metacontroller_pytorch-0.0.40}/pyproject.toml +1 -1
  7. metacontroller_pytorch-0.0.40/test_babyai_e2e.sh +35 -0
  8. {metacontroller_pytorch-0.0.37 → metacontroller_pytorch-0.0.40}/tests/test_metacontroller.py +1 -16
  9. metacontroller_pytorch-0.0.40/train_babyai.py +314 -0
  10. {metacontroller_pytorch-0.0.37 → metacontroller_pytorch-0.0.40}/train_behavior_clone_babyai.py +2 -2
  11. metacontroller_pytorch-0.0.37/test_babyai_e2e.sh +0 -14
  12. metacontroller_pytorch-0.0.37/train_babyai.py +0 -140
  13. {metacontroller_pytorch-0.0.37 → metacontroller_pytorch-0.0.40}/.github/workflows/python-publish.yml +0 -0
  14. {metacontroller_pytorch-0.0.37 → metacontroller_pytorch-0.0.40}/.github/workflows/test.yml +0 -0
  15. {metacontroller_pytorch-0.0.37 → metacontroller_pytorch-0.0.40}/.gitignore +0 -0
  16. {metacontroller_pytorch-0.0.37 → metacontroller_pytorch-0.0.40}/LICENSE +0 -0
  17. {metacontroller_pytorch-0.0.37 → metacontroller_pytorch-0.0.40}/fig1.png +0 -0
  18. {metacontroller_pytorch-0.0.37 → metacontroller_pytorch-0.0.40}/gather_babyai_trajs.py +0 -0
  19. {metacontroller_pytorch-0.0.37 → metacontroller_pytorch-0.0.40}/metacontroller/__init__.py +0 -0
  20. {metacontroller_pytorch-0.0.37 → metacontroller_pytorch-0.0.40}/metacontroller/transformer_with_resnet.py +0 -0
--- metacontroller_pytorch-0.0.37/PKG-INFO
+++ metacontroller_pytorch-0.0.40/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: metacontroller-pytorch
-Version: 0.0.37
+Version: 0.0.40
 Summary: Transformer Metacontroller
 Project-URL: Homepage, https://pypi.org/project/metacontroller/
 Project-URL: Repository, https://github.com/lucidrains/metacontroller
@@ -94,6 +94,18 @@ $ pip install metacontroller-pytorch
 }
 ```
 
+```bibtex
+@misc{hwang2025dynamicchunkingendtoendhierarchical,
+    title = {Dynamic Chunking for End-to-End Hierarchical Sequence Modeling},
+    author = {Sukjun Hwang and Brandon Wang and Albert Gu},
+    year = {2025},
+    eprint = {2507.07955},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.LG},
+    url = {https://arxiv.org/abs/2507.07955},
+}
+```
+
 ```bibtex
 @misc{fleuret2025freetransformer,
     title = {The Free Transformer},
--- metacontroller_pytorch-0.0.37/README.md
+++ metacontroller_pytorch-0.0.40/README.md
@@ -41,6 +41,18 @@ $ pip install metacontroller-pytorch
 }
 ```
 
+```bibtex
+@misc{hwang2025dynamicchunkingendtoendhierarchical,
+    title = {Dynamic Chunking for End-to-End Hierarchical Sequence Modeling},
+    author = {Sukjun Hwang and Brandon Wang and Albert Gu},
+    year = {2025},
+    eprint = {2507.07955},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.LG},
+    url = {https://arxiv.org/abs/2507.07955},
+}
+```
+
 ```bibtex
 @misc{fleuret2025freetransformer,
     title = {The Free Transformer},
--- /dev/null
+++ metacontroller_pytorch-0.0.40/babyai_env.py
@@ -0,0 +1,41 @@
+from pathlib import Path
+from shutil import rmtree
+
+import gymnasium as gym
+import minigrid
+from minigrid.wrappers import FullyObsWrapper, SymbolicObsWrapper
+
+# functions
+
+def divisible_by(num, den):
+    return (num % den) == 0
+
+# env creation
+
+def create_env(
+    env_id,
+    render_mode = 'rgb_array',
+    video_folder = None,
+    render_every_eps = 1000
+):
+    # register minigrid environments if needed
+    minigrid.register_minigrid_envs()
+
+    # environment
+    env = gym.make(env_id, render_mode = render_mode)
+    env = FullyObsWrapper(env)
+    env = SymbolicObsWrapper(env)
+
+    if video_folder is not None:
+        video_folder = Path(video_folder)
+        rmtree(video_folder, ignore_errors = True)
+
+        env = gym.wrappers.RecordVideo(
+            env = env,
+            video_folder = str(video_folder),
+            name_prefix = 'babyai',
+            episode_trigger = lambda eps_num: divisible_by(eps_num, render_every_eps),
+            disable_logger = True
+        )
+
+    return env
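
For orientation, a minimal usage sketch of the new `create_env` helper; the env id, folder, and seed below are illustrative and mirror how `train_babyai.py` calls it later in this diff.

```python
# hypothetical quick check of babyai_env.create_env; values are illustrative
from babyai_env import create_env

env = create_env(
    'BabyAI-MiniBossLevel-v0',       # any registered BabyAI env id
    render_mode = 'rgb_array',
    video_folder = './recordings',   # pass None to skip video recording
    render_every_eps = 1000          # record every 1000th episode
)

state, _ = env.reset(seed = 0)       # state is a dict with an 'image' key
print('num actions:', env.action_space.n)
env.close()
```
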
--- metacontroller_pytorch-0.0.37/metacontroller/metacontroller.py
+++ metacontroller_pytorch-0.0.40/metacontroller/metacontroller.py
@@ -126,6 +126,7 @@ class MetaController(Module):
         )
     ):
         super().__init__()
+        self.dim_model = dim_model
         dim_meta = default(dim_meta_controller, dim_model)
 
         # the linear that brings from model dimension
@@ -171,6 +172,15 @@ class MetaController(Module):
 
         self.register_buffer('zero', tensor(0.), persistent = False)
 
+    @property
+    def replay_buffer_field_dict(self):
+        return dict(
+            states = ('float', self.dim_model),
+            log_probs = ('float', self.dim_latent),
+            switch_betas = ('float', self.dim_latent if self.switch_per_latent_dim else 1),
+            latent_actions = ('float', self.dim_latent)
+        )
+
     def discovery_parameters(self):
         return [
             *self.model_to_meta.parameters(),
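
The new `replay_buffer_field_dict` property lets downstream code derive replay-buffer field shapes from the meta controller itself instead of hard-coding them. A sketch of the intended consumption, assuming a `meta_controller` instance is already constructed; it mirrors the updated test and `train_babyai.py` later in this diff, with illustrative sizes.

```python
# sketch: replay buffer field shapes come from the meta controller,
# not a hand-built dict (assumes `meta_controller` already exists)
from memmap_replay_buffer import ReplayBuffer

replay_buffer = ReplayBuffer(
    './replay-data',
    max_episodes = 1000,    # illustrative capacity
    max_timesteps = 256,
    fields = meta_controller.replay_buffer_field_dict,   # states / log_probs / switch_betas / latent_actions
    meta_fields = dict(advantages = 'float'),
    overwrite = True,
    circular = True
)
```
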
--- metacontroller_pytorch-0.0.37/metacontroller/metacontroller_with_binary_mapper.py
+++ metacontroller_pytorch-0.0.40/metacontroller/metacontroller_with_binary_mapper.py
@@ -74,7 +74,7 @@ class MetaControllerWithBinaryMapper(Module):
         kl_loss_threshold = 0.
     ):
         super().__init__()
-
+        self.dim_model = dim_model
         assert not switch_per_code, 'switch_per_code is not supported for binary mapper'
 
         dim_meta = default(dim_meta_controller, dim_model)
@@ -126,6 +126,15 @@ class MetaControllerWithBinaryMapper(Module):
 
         self.register_buffer('zero', tensor(0.), persistent = False)
 
+    @property
+    def replay_buffer_field_dict(self):
+        return dict(
+            states = ('float', self.dim_model),
+            log_probs = ('float', self.dim_code_bits),
+            switch_betas = ('float', self.num_codes if self.switch_per_code else 1),
+            latent_actions = ('float', self.num_codes)
+        )
+
     def discovery_parameters(self):
         return [
             *self.model_to_meta.parameters(),
--- metacontroller_pytorch-0.0.37/pyproject.toml
+++ metacontroller_pytorch-0.0.40/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "metacontroller-pytorch"
-version = "0.0.37"
+version = "0.0.40"
 description = "Transformer Metacontroller"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
--- /dev/null
+++ metacontroller_pytorch-0.0.40/test_babyai_e2e.sh
@@ -0,0 +1,35 @@
+#!/bin/bash
+set -e
+
+# 1. Gather trajectories
+echo "Gathering trajectories..."
+uv run gather_babyai_trajs.py \
+  --num_seeds 100 \
+  --num_episodes_per_seed 10 \
+  --num_steps 500 \
+  --output_dir end_to_end_trajectories \
+  --env_id BabyAI-MiniBossLevel-v0
+
+# 2. Behavioral cloning
+echo "Training behavioral cloning model..."
+ACCELERATE_USE_CPU=true ACCELERATE_MIXED_PRECISION=no uv run train_behavior_clone_babyai.py \
+  --cloning_epochs 10 \
+  --discovery_epochs 10 \
+  --batch_size 256 \
+  --input_dir end_to_end_trajectories \
+  --env_id BabyAI-MiniBossLevel-v0 \
+  --checkpoint_path end_to_end_model.pt \
+  --use_resnet
+
+# 3. Inference rollouts
+echo "Running inference rollouts..."
+uv run train_babyai.py \
+  --transformer_weights_path end_to_end_model.pt \
+  --meta_controller_weights_path meta_controller_discovery.pt \
+  --env_name BabyAI-MiniBossLevel-v0 \
+  --num_episodes 1000 \
+  --buffer_size 1000 \
+  --max_timesteps 100 \
+  --num_groups 16 \
+  --lr 1e-4 \
+  --use_resnet
--- metacontroller_pytorch-0.0.37/tests/test_metacontroller.py
+++ metacontroller_pytorch-0.0.40/tests/test_metacontroller.py
@@ -69,12 +69,6 @@ def test_metacontroller(
             dim_latent = 128,
             switch_per_latent_dim = switch_per_latent_dim
         )
-
-        field_shapes = dict(
-            log_probs = ('float', 128),
-            switch_betas = ('float', 128 if switch_per_latent_dim else 1),
-            latent_actions = ('float', 128)
-        )
     else:
         meta_controller = MetaControllerWithBinaryMapper(
             dim_model = 512,
@@ -83,12 +77,6 @@ def test_metacontroller(
             dim_code_bits = 8, # 2 ** 8 = 256 codes
         )
 
-        field_shapes = dict(
-            log_probs = ('float', 8),
-            switch_betas = ('float', 8 if switch_per_latent_dim else 1),
-            latent_actions = ('float', 256)
-        )
-
     # discovery phase
 
     (action_recon_loss, kl_loss, switch_loss) = model(state, actions, meta_controller = meta_controller, discovery_phase = True, episode_lens = episode_lens)
@@ -104,10 +92,7 @@
         test_folder,
         max_episodes = 3,
         max_timesteps = 256,
-        fields = dict(
-            states = ('float', 512),
-            **field_shapes
-        ),
+        fields = meta_controller.replay_buffer_field_dict,
         meta_fields = dict(
             advantages = 'float'
         )
--- /dev/null
+++ metacontroller_pytorch-0.0.40/train_babyai.py
@@ -0,0 +1,314 @@
+# /// script
+# dependencies = [
+#   "fire",
+#   "gymnasium",
+#   "gymnasium[other]",
+#   "memmap-replay-buffer>=0.0.12",
+#   "metacontroller-pytorch",
+#   "minigrid",
+#   "tqdm"
+# ]
+# ///
+
+from fire import Fire
+from pathlib import Path
+from functools import partial
+from shutil import rmtree
+from tqdm import tqdm
+
+import torch
+from torch import cat, tensor, stack
+from torch.optim import Adam
+
+from einops import rearrange
+
+from accelerate import Accelerator
+
+from babyai_env import create_env
+from memmap_replay_buffer import ReplayBuffer
+from metacontroller.metacontroller import Transformer, MetaController, policy_loss, z_score
+from metacontroller.transformer_with_resnet import TransformerWithResnet
+
+# research entry point
+
+def reward_shaping_fn(
+    cumulative_rewards: torch.Tensor,
+    all_rewards: torch.Tensor,
+    episode_lens: torch.Tensor
+) -> torch.Tensor | None:
+    """
+    researchers can modify this function to engineer rewards
+    or return None to reject the entire batch
+
+    cumulative_rewards: (num_episodes,)
+    all_rewards: (num_episodes, max_timesteps)
+    episode_lens: (num_episodes,)
+    """
+    return cumulative_rewards
+
+# helpers
+
+def exists(v):
+    return v is not None
+
+def default(v, d):
+    return v if exists(v) else d
+
+# main
+
+def main(
+    env_name: str = 'BabyAI-BossLevel-v0',
+    num_episodes: int = int(10e6),
+    max_timesteps: int = 500,
+    buffer_size: int = 5_000,
+    render_every_eps: int = 1_000,
+    video_folder: str = './recordings',
+    seed: int | None = None,
+    transformer_weights_path: str | None = None,
+    meta_controller_weights_path: str | None = None,
+    output_meta_controller_path: str = 'metacontroller_rl_trained.pt',
+    use_resnet: bool = False,
+    lr: float = 1e-4,
+    num_groups: int = 16,
+    max_grad_norm: float = 1.0,
+    use_wandb: bool = False,
+    wandb_project: str = 'metacontroller-babyai-rl'
+):
+    # accelerator
+
+    accelerator = Accelerator(log_with = 'wandb' if use_wandb else None)
+
+    if use_wandb:
+        accelerator.init_trackers(wandb_project)
+
+    # environment
+
+    env = create_env(
+        env_name,
+        render_mode = 'rgb_array',
+        video_folder = video_folder,
+        render_every_eps = render_every_eps
+    )
+
+    # load models
+
+    model = None
+    if exists(transformer_weights_path):
+        weights_path = Path(transformer_weights_path)
+        assert weights_path.exists(), f"transformer weights not found at {weights_path}"
+
+        transformer_klass = TransformerWithResnet if use_resnet else Transformer
+        model = transformer_klass.init_and_load(str(weights_path), strict = False)
+        model.eval()
+
+    meta_controller = None
+    if exists(meta_controller_weights_path):
+        weights_path = Path(meta_controller_weights_path)
+        assert weights_path.exists(), f"meta controller weights not found at {weights_path}"
+        meta_controller = MetaController.init_and_load(str(weights_path), strict = False)
+        meta_controller.eval()
+
+    meta_controller = default(meta_controller, getattr(model, 'meta_controller', None))
+    assert exists(meta_controller), "MetaController must be present for reinforcement learning"
+
+    # optimizer
+
+    optim = Adam(meta_controller.internal_rl_parameters(), lr = lr)
+
+    # prepare
+
+    model, meta_controller, optim = accelerator.prepare(model, meta_controller, optim)
+
+    unwrapped_model = accelerator.unwrap_model(model)
+    unwrapped_meta_controller = accelerator.unwrap_model(meta_controller)
+
+    # replay buffer
+
+    replay_buffer = ReplayBuffer(
+        './replay-data',
+        max_episodes = buffer_size,
+        max_timesteps = max_timesteps + 1,
+        fields = meta_controller.replay_buffer_field_dict,
+        meta_fields = dict(advantages = 'float'),
+        overwrite = True,
+        circular = True
+    )
+
+    # rollouts
+
+    num_batch_updates = num_episodes // num_groups
+
+    pbar = tqdm(range(num_batch_updates), desc = 'training')
+
+    for _ in pbar:
+
+        all_episodes = []
+        all_cumulative_rewards = []
+        all_step_rewards = []
+        all_episode_lens = []
+
+        group_seed = default(seed, torch.randint(0, 1000000, (1,)).item())
+
+        for _ in range(num_groups):
+
+            state, *_ = env.reset(seed = group_seed)
+
+            cache = None
+            past_action_id = None
+
+            states = []
+            log_probs = []
+            switch_betas = []
+            latent_actions = []
+
+            total_reward = 0.
+            step_rewards = []
+            episode_len = max_timesteps
+
+            for step in range(max_timesteps):
+
+                image = state['image']
+                image_tensor = torch.from_numpy(image).float().to(accelerator.device)
+
+                if use_resnet:
+                    image_tensor = rearrange(image_tensor, 'h w c -> 1 1 h w c')
+                    image_tensor = model.visual_encode(image_tensor)
+                else:
+                    image_tensor = rearrange(image_tensor, 'h w c -> 1 1 (h w c)')
+
+                if torch.is_tensor(past_action_id):
+                    past_action_id = past_action_id.long()
+
+                with torch.no_grad():
+                    logits, cache = unwrapped_model(
+                        image_tensor,
+                        past_action_id,
+                        meta_controller = unwrapped_meta_controller,
+                        return_cache = True,
+                        return_raw_action_dist = True,
+                        cache = cache
+                    )
+
+                action = unwrapped_model.action_readout.sample(logits)
+                past_action_id = action
+                action = action.squeeze()
+
+                # GRPO collection
+
+                meta_output = cache.prev_hiddens.meta_controller
+                old_log_probs = unwrapped_meta_controller.log_prob(meta_output.action_dist, meta_output.actions)
+
+                states.append(meta_output.input_residual_stream)
+                log_probs.append(old_log_probs)
+                switch_betas.append(meta_output.switch_beta)
+                latent_actions.append(meta_output.actions)
+
+                next_state, reward, terminated, truncated, *_ = env.step(action.cpu().numpy())
+
+                total_reward += reward
+                step_rewards.append(reward)
+                done = terminated or truncated
+
+                if done:
+                    episode_len = step + 1
+                    break
+
+                state = next_state
+
+            # store episode
+
+            all_episodes.append((
+                cat(states, dim = 1).squeeze(0),
+                cat(log_probs, dim = 1).squeeze(0),
+                cat(switch_betas, dim = 1).squeeze(0),
+                cat(latent_actions, dim = 1).squeeze(0)
+            ))
+
+            all_cumulative_rewards.append(tensor(total_reward))
+            all_step_rewards.append(tensor(step_rewards))
+            all_episode_lens.append(episode_len)
+
+        # compute advantages
+
+        cumulative_rewards = stack(all_cumulative_rewards)
+        episode_lens = tensor(all_episode_lens)
+
+        # pad step rewards
+
+        max_len = max(all_episode_lens)
+        padded_step_rewards = torch.zeros(num_episodes, max_len)
+
+        for i, (rewards, length) in enumerate(zip(all_step_rewards, all_episode_lens)):
+            padded_step_rewards[i, :length] = rewards
+
+        # reward shaping hook
+
+        shaped_rewards = reward_shaping_fn(cumulative_rewards, padded_step_rewards, episode_lens)
+
+        if not exists(shaped_rewards):
+            continue
+
+        group_advantages = z_score(shaped_rewards)
+
+        group_states, group_log_probs, group_switch_betas, group_latent_actions = zip(*all_episodes)
+
+        for states, log_probs, switch_betas, latent_actions, advantages in zip(group_states, group_log_probs, group_switch_betas, group_latent_actions, group_advantages):
+            replay_buffer.store_episode(
+                states = states,
+                log_probs = log_probs,
+                switch_betas = switch_betas,
+                latent_actions = latent_actions,
+                advantages = advantages
+            )
+
+        # learn
+
+        if len(replay_buffer) >= buffer_size:
+            dl = replay_buffer.dataloader(batch_size = num_groups)
+            dl = accelerator.prepare(dl)
+
+            meta_controller.train()
+
+            batch = next(iter(dl))
+
+            loss = meta_controller.policy_loss(
+                batch['states'],
+                batch['log_probs'],
+                batch['latent_actions'],
+                batch['advantages'],
+                batch['switch_betas'] == 1.,
+                episode_lens = batch['_lens']
+            )
+
+            accelerator.backward(loss)
+
+            grad_norm = accelerator.clip_grad_norm_(meta_controller.parameters(), max_grad_norm)
+
+            optim.step()
+            optim.zero_grad()
+
+            meta_controller.eval()
+
+            pbar.set_postfix(
+                loss = f'{loss.item():.4f}',
+                grad_norm = f'{grad_norm.item():.4f}',
+                reward = f'{cumulative_rewards.mean().item():.4f}'
+            )
+
+            accelerator.log({
+                'loss': loss.item(),
+                'grad_norm': grad_norm.item()
+            })
+
+            accelerator.print(f'loss: {loss.item():.4f}, grad_norm: {grad_norm.item():.4f}')
+
+    env.close()
+
+    # save
+
+    if exists(output_meta_controller_path):
+        unwrapped_meta_controller.save(output_meta_controller_path)
+        accelerator.print(f'MetaController weights saved to {output_meta_controller_path}')
+
+if __name__ == '__main__':
+    Fire(main)
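
The `reward_shaping_fn` hook at the top of `train_babyai.py` is the intended customization point: its docstring says researchers can engineer rewards or return None to reject the whole batch. A hedged example of what a modified hook might look like; the length normalization and batch rejection below are illustrative assumptions, not part of the released script.

```python
import torch

# illustrative replacement for reward_shaping_fn in train_babyai.py;
# the length-normalization and all-zero-group rejection are assumptions
def reward_shaping_fn(
    cumulative_rewards: torch.Tensor,   # (num_episodes,)
    all_rewards: torch.Tensor,          # (num_episodes, max_timesteps)
    episode_lens: torch.Tensor          # (num_episodes,)
) -> torch.Tensor | None:
    if cumulative_rewards.abs().sum() == 0:
        return None                     # reject the entire group of rollouts
    return cumulative_rewards / episode_lens.clamp(min = 1).float()
```
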
--- metacontroller_pytorch-0.0.37/train_behavior_clone_babyai.py
+++ metacontroller_pytorch-0.0.40/train_behavior_clone_babyai.py
@@ -92,8 +92,8 @@ def train(
     else: state_dim = int(torch.tensor(state_shape).prod().item())
 
     # deduce num_actions from the environment
-    minigrid.register_minigrid_envs()
-    temp_env = gym.make(env_id)
+    from babyai_env import create_env
+    temp_env = create_env(env_id)
     num_actions = int(temp_env.action_space.n)
     temp_env.close()
 
--- metacontroller_pytorch-0.0.37/test_babyai_e2e.sh
+++ /dev/null
@@ -1,14 +0,0 @@
-#!/bin/bash
-set -e
-
-# 1. Gather trajectories
-echo "Gathering trajectories..."
-uv run gather_babyai_trajs.py --num_seeds 1000 --num_episodes_per_seed 100 --output_dir end_to_end_trajectories --env_id BabyAI-MiniBossLevel-v0
-
-# 2. Behavioral cloning
-echo "Training behavioral cloning model..."
-uv run train_behavior_clone_babyai.py --cloning_epochs 10 --discovery_epochs 10 --batch_size 256 --input_dir end_to_end_trajectories --env_id BabyAI-MiniBossLevel-v0 --checkpoint_path end_to_end_model.pt --use_resnet
-
-# 3. Inference rollouts
-echo "Running inference rollouts..."
-uv run train_babyai.py --weights_path end_to_end_model.pt --env_name BabyAI-MiniBossLevel-v0 --num_episodes 5 --buffer_size 100 --max_timesteps 100
--- metacontroller_pytorch-0.0.37/train_babyai.py
+++ /dev/null
@@ -1,140 +0,0 @@
-# /// script
-# dependencies = [
-#   "fire",
-#   "gymnasium",
-#   "gymnasium[other]",
-#   "memmap-replay-buffer>=0.0.12",
-#   "metacontroller-pytorch",
-#   "minigrid",
-#   "tqdm"
-# ]
-# ///
-
-from fire import Fire
-from tqdm import tqdm
-from shutil import rmtree
-from pathlib import Path
-
-import torch
-from einops import rearrange
-
-import gymnasium as gym
-import minigrid
-from minigrid.wrappers import FullyObsWrapper, SymbolicObsWrapper
-
-from memmap_replay_buffer import ReplayBuffer
-from metacontroller.metacontroller import Transformer
-
-# functions
-
-def exists(v):
-    return v is not None
-
-def default(v, d):
-    return v if exists(v) else d
-
-def divisible_by(num, den):
-    return (num % den) == 0
-
-# main
-
-def main(
-    env_name = 'BabyAI-BossLevel-v0',
-    num_episodes = int(10e6),
-    max_timesteps = 500,
-    buffer_size = 5_000,
-    render_every_eps = 1_000,
-    video_folder = './recordings',
-    seed = None,
-    weights_path = None
-):
-
-    # environment
-
-    env = gym.make(env_name, render_mode = 'rgb_array')
-    env = FullyObsWrapper(env.unwrapped)
-    env = SymbolicObsWrapper(env.unwrapped)
-
-    rmtree(video_folder, ignore_errors = True)
-
-    env = gym.wrappers.RecordVideo(
-        env = env,
-        video_folder = video_folder,
-        name_prefix = 'babyai',
-        episode_trigger = lambda eps_num: divisible_by(eps_num, render_every_eps),
-        disable_logger = True
-    )
-
-    # maybe load model
-
-    model = None
-    if exists(weights_path):
-        weights_path = Path(weights_path)
-        assert weights_path.exists(), f"weights not found at {weights_path}"
-        model = Transformer.init_and_load(str(weights_path), strict = False)
-        model.eval()
-
-    # replay
-
-    replay_buffer = ReplayBuffer(
-        './replay-data',
-        max_episodes = buffer_size,
-        max_timesteps = max_timesteps + 1,
-        fields = dict(
-            action = 'int',
-            state_image = ('float', (7, 7, 3)),
-            state_direction = 'int'
-        ),
-        overwrite = True,
-        circular = True
-    )
-
-    # rollouts
-
-    for _ in tqdm(range(num_episodes)):
-
-        state, *_ = env.reset(seed = seed)
-
-        cache = None
-        past_action_id = None
-
-        for _ in range(max_timesteps):
-
-            if exists(model):
-                # preprocess state
-                # assume state is a dict with 'image'
-                image = state['image']
-                image_tensor = torch.from_numpy(image).float()
-                image_tensor = rearrange(image_tensor, 'h w c -> 1 1 (h w c)')
-
-                if exists(past_action_id) and torch.is_tensor(past_action_id):
-                    past_action_id = past_action_id.long()
-
-                with torch.no_grad():
-                    logits, cache = model(
-                        image_tensor,
-                        past_action_id,
-                        return_cache = True,
-                        return_raw_action_dist = True,
-                        cache = cache
-                    )
-
-                action = model.action_readout.sample(logits)
-                past_action_id = action
-                action = action.squeeze()
-            else:
-                action = torch.randint(0, 7, ())
-
-            next_state, reward, terminated, truncated, *_ = env.step(action.cpu().numpy())
-
-            done = terminated or truncated
-
-            if done:
-                break
-
-            state = next_state
-
-    env.close()
-
-if __name__ == '__main__':
-    Fire(main)