metacontroller-pytorch 0.0.38__tar.gz → 0.0.40__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (20)
  1. {metacontroller_pytorch-0.0.38 → metacontroller_pytorch-0.0.40}/PKG-INFO +1 -1
  2. metacontroller_pytorch-0.0.40/babyai_env.py +41 -0
  3. {metacontroller_pytorch-0.0.38 → metacontroller_pytorch-0.0.40}/pyproject.toml +1 -1
  4. metacontroller_pytorch-0.0.40/test_babyai_e2e.sh +35 -0
  5. metacontroller_pytorch-0.0.40/train_babyai.py +314 -0
  6. {metacontroller_pytorch-0.0.38 → metacontroller_pytorch-0.0.40}/train_behavior_clone_babyai.py +2 -2
  7. metacontroller_pytorch-0.0.38/test_babyai_e2e.sh +0 -14
  8. metacontroller_pytorch-0.0.38/train_babyai.py +0 -140
  9. {metacontroller_pytorch-0.0.38 → metacontroller_pytorch-0.0.40}/.github/workflows/python-publish.yml +0 -0
  10. {metacontroller_pytorch-0.0.38 → metacontroller_pytorch-0.0.40}/.github/workflows/test.yml +0 -0
  11. {metacontroller_pytorch-0.0.38 → metacontroller_pytorch-0.0.40}/.gitignore +0 -0
  12. {metacontroller_pytorch-0.0.38 → metacontroller_pytorch-0.0.40}/LICENSE +0 -0
  13. {metacontroller_pytorch-0.0.38 → metacontroller_pytorch-0.0.40}/README.md +0 -0
  14. {metacontroller_pytorch-0.0.38 → metacontroller_pytorch-0.0.40}/fig1.png +0 -0
  15. {metacontroller_pytorch-0.0.38 → metacontroller_pytorch-0.0.40}/gather_babyai_trajs.py +0 -0
  16. {metacontroller_pytorch-0.0.38 → metacontroller_pytorch-0.0.40}/metacontroller/__init__.py +0 -0
  17. {metacontroller_pytorch-0.0.38 → metacontroller_pytorch-0.0.40}/metacontroller/metacontroller.py +0 -0
  18. {metacontroller_pytorch-0.0.38 → metacontroller_pytorch-0.0.40}/metacontroller/metacontroller_with_binary_mapper.py +0 -0
  19. {metacontroller_pytorch-0.0.38 → metacontroller_pytorch-0.0.40}/metacontroller/transformer_with_resnet.py +0 -0
  20. {metacontroller_pytorch-0.0.38 → metacontroller_pytorch-0.0.40}/tests/test_metacontroller.py +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: metacontroller-pytorch
- Version: 0.0.38
+ Version: 0.0.40
  Summary: Transformer Metacontroller
  Project-URL: Homepage, https://pypi.org/project/metacontroller/
  Project-URL: Repository, https://github.com/lucidrains/metacontroller
@@ -0,0 +1,41 @@
+ from pathlib import Path
+ from shutil import rmtree
+
+ import gymnasium as gym
+ import minigrid
+ from minigrid.wrappers import FullyObsWrapper, SymbolicObsWrapper
+
+ # functions
+
+ def divisible_by(num, den):
+     return (num % den) == 0
+
+ # env creation
+
+ def create_env(
+     env_id,
+     render_mode = 'rgb_array',
+     video_folder = None,
+     render_every_eps = 1000
+ ):
+     # register minigrid environments if needed
+     minigrid.register_minigrid_envs()
+
+     # environment
+     env = gym.make(env_id, render_mode = render_mode)
+     env = FullyObsWrapper(env)
+     env = SymbolicObsWrapper(env)
+
+     if video_folder is not None:
+         video_folder = Path(video_folder)
+         rmtree(video_folder, ignore_errors = True)
+
+         env = gym.wrappers.RecordVideo(
+             env = env,
+             video_folder = str(video_folder),
+             name_prefix = 'babyai',
+             episode_trigger = lambda eps_num: divisible_by(eps_num, render_every_eps),
+             disable_logger = True
+         )
+
+     return env
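A minimal usage sketch of the new create_env helper, assuming a registered BabyAI level such as BabyAI-GoToLocal-v0 (this snippet is illustrative and not part of the package):

from babyai_env import create_env

# fully observable, symbolic BabyAI env; video recording enabled every episode
env = create_env('BabyAI-GoToLocal-v0', video_folder = './recordings', render_every_eps = 1)

state, _ = env.reset(seed = 0)
state, reward, terminated, truncated, _ = env.step(env.action_space.sample())
env.close()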
@@ -1,6 +1,6 @@
  [project]
  name = "metacontroller-pytorch"
- version = "0.0.38"
+ version = "0.0.40"
  description = "Transformer Metacontroller"
  authors = [
      { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -0,0 +1,35 @@
+ #!/bin/bash
+ set -e
+
+ # 1. Gather trajectories
+ echo "Gathering trajectories..."
+ uv run gather_babyai_trajs.py \
+     --num_seeds 100 \
+     --num_episodes_per_seed 10 \
+     --num_steps 500 \
+     --output_dir end_to_end_trajectories \
+     --env_id BabyAI-MiniBossLevel-v0
+
+ # 2. Behavioral cloning
+ echo "Training behavioral cloning model..."
+ ACCELERATE_USE_CPU=true ACCELERATE_MIXED_PRECISION=no uv run train_behavior_clone_babyai.py \
+     --cloning_epochs 10 \
+     --discovery_epochs 10 \
+     --batch_size 256 \
+     --input_dir end_to_end_trajectories \
+     --env_id BabyAI-MiniBossLevel-v0 \
+     --checkpoint_path end_to_end_model.pt \
+     --use_resnet
+
+ # 3. Inference rollouts
+ echo "Running inference rollouts..."
+ uv run train_babyai.py \
+     --transformer_weights_path end_to_end_model.pt \
+     --meta_controller_weights_path meta_controller_discovery.pt \
+     --env_name BabyAI-MiniBossLevel-v0 \
+     --num_episodes 1000 \
+     --buffer_size 1000 \
+     --max_timesteps 100 \
+     --num_groups 16 \
+     --lr 1e-4 \
+     --use_resnet
@@ -0,0 +1,314 @@
+ # /// script
+ # dependencies = [
+ #     "fire",
+ #     "gymnasium",
+ #     "gymnasium[other]",
+ #     "memmap-replay-buffer>=0.0.12",
+ #     "metacontroller-pytorch",
+ #     "minigrid",
+ #     "tqdm"
+ # ]
+ # ///
+
+ from fire import Fire
+ from pathlib import Path
+ from functools import partial
+ from shutil import rmtree
+ from tqdm import tqdm
+
+ import torch
+ from torch import cat, tensor, stack
+ from torch.optim import Adam
+
+ from einops import rearrange
+
+ from accelerate import Accelerator
+
+ from babyai_env import create_env
+ from memmap_replay_buffer import ReplayBuffer
+ from metacontroller.metacontroller import Transformer, MetaController, policy_loss, z_score
+ from metacontroller.transformer_with_resnet import TransformerWithResnet
+
+ # research entry point
+
+ def reward_shaping_fn(
+     cumulative_rewards: torch.Tensor,
+     all_rewards: torch.Tensor,
+     episode_lens: torch.Tensor
+ ) -> torch.Tensor | None:
+     """
+     researchers can modify this function to engineer rewards
+     or return None to reject the entire batch
+
+     cumulative_rewards: (num_episodes,)
+     all_rewards: (num_episodes, max_timesteps)
+     episode_lens: (num_episodes,)
+     """
+     return cumulative_rewards
+
+ # helpers
+
+ def exists(v):
+     return v is not None
+
+ def default(v, d):
+     return v if exists(v) else d
+
+ # main
+
+ def main(
+     env_name: str = 'BabyAI-BossLevel-v0',
+     num_episodes: int = int(10e6),
+     max_timesteps: int = 500,
+     buffer_size: int = 5_000,
+     render_every_eps: int = 1_000,
+     video_folder: str = './recordings',
+     seed: int | None = None,
+     transformer_weights_path: str | None = None,
+     meta_controller_weights_path: str | None = None,
+     output_meta_controller_path: str = 'metacontroller_rl_trained.pt',
+     use_resnet: bool = False,
+     lr: float = 1e-4,
+     num_groups: int = 16,
+     max_grad_norm: float = 1.0,
+     use_wandb: bool = False,
+     wandb_project: str = 'metacontroller-babyai-rl'
+ ):
+     # accelerator
+
+     accelerator = Accelerator(log_with = 'wandb' if use_wandb else None)
+
+     if use_wandb:
+         accelerator.init_trackers(wandb_project)
+
+     # environment
+
+     env = create_env(
+         env_name,
+         render_mode = 'rgb_array',
+         video_folder = video_folder,
+         render_every_eps = render_every_eps
+     )
+
+     # load models
+
+     model = None
+     if exists(transformer_weights_path):
+         weights_path = Path(transformer_weights_path)
+         assert weights_path.exists(), f"transformer weights not found at {weights_path}"
+
+         transformer_klass = TransformerWithResnet if use_resnet else Transformer
+         model = transformer_klass.init_and_load(str(weights_path), strict = False)
+         model.eval()
+
+     meta_controller = None
+     if exists(meta_controller_weights_path):
+         weights_path = Path(meta_controller_weights_path)
+         assert weights_path.exists(), f"meta controller weights not found at {weights_path}"
+         meta_controller = MetaController.init_and_load(str(weights_path), strict = False)
+         meta_controller.eval()
+
+     meta_controller = default(meta_controller, getattr(model, 'meta_controller', None))
+     assert exists(meta_controller), "MetaController must be present for reinforcement learning"
+
+     # optimizer
+
+     optim = Adam(meta_controller.internal_rl_parameters(), lr = lr)
+
+     # prepare
+
+     model, meta_controller, optim = accelerator.prepare(model, meta_controller, optim)
+
+     unwrapped_model = accelerator.unwrap_model(model)
+     unwrapped_meta_controller = accelerator.unwrap_model(meta_controller)
+
+     # replay buffer
+
+     replay_buffer = ReplayBuffer(
+         './replay-data',
+         max_episodes = buffer_size,
+         max_timesteps = max_timesteps + 1,
+         fields = meta_controller.replay_buffer_field_dict,
+         meta_fields = dict(advantages = 'float'),
+         overwrite = True,
+         circular = True
+     )
+
+     # rollouts
+
+     num_batch_updates = num_episodes // num_groups
+
+     pbar = tqdm(range(num_batch_updates), desc = 'training')
+
+     for _ in pbar:
+
+         all_episodes = []
+         all_cumulative_rewards = []
+         all_step_rewards = []
+         all_episode_lens = []
+
+         group_seed = default(seed, torch.randint(0, 1000000, (1,)).item())
+
+         for _ in range(num_groups):
+
+             state, *_ = env.reset(seed = group_seed)
+
+             cache = None
+             past_action_id = None
+
+             states = []
+             log_probs = []
+             switch_betas = []
+             latent_actions = []
+
+             total_reward = 0.
+             step_rewards = []
+             episode_len = max_timesteps
+
+             for step in range(max_timesteps):
+
+                 image = state['image']
+                 image_tensor = torch.from_numpy(image).float().to(accelerator.device)
+
+                 if use_resnet:
+                     image_tensor = rearrange(image_tensor, 'h w c -> 1 1 h w c')
+                     image_tensor = model.visual_encode(image_tensor)
+                 else:
+                     image_tensor = rearrange(image_tensor, 'h w c -> 1 1 (h w c)')
+
+                 if torch.is_tensor(past_action_id):
+                     past_action_id = past_action_id.long()
+
+                 with torch.no_grad():
+                     logits, cache = unwrapped_model(
+                         image_tensor,
+                         past_action_id,
+                         meta_controller = unwrapped_meta_controller,
+                         return_cache = True,
+                         return_raw_action_dist = True,
+                         cache = cache
+                     )
+
+                 action = unwrapped_model.action_readout.sample(logits)
+                 past_action_id = action
+                 action = action.squeeze()
+
+                 # GRPO collection
+
+                 meta_output = cache.prev_hiddens.meta_controller
+                 old_log_probs = unwrapped_meta_controller.log_prob(meta_output.action_dist, meta_output.actions)
+
+                 states.append(meta_output.input_residual_stream)
+                 log_probs.append(old_log_probs)
+                 switch_betas.append(meta_output.switch_beta)
+                 latent_actions.append(meta_output.actions)
+
+                 next_state, reward, terminated, truncated, *_ = env.step(action.cpu().numpy())
+
+                 total_reward += reward
+                 step_rewards.append(reward)
+                 done = terminated or truncated
+
+                 if done:
+                     episode_len = step + 1
+                     break
+
+                 state = next_state
+
+             # store episode
+
+             all_episodes.append((
+                 cat(states, dim = 1).squeeze(0),
+                 cat(log_probs, dim = 1).squeeze(0),
+                 cat(switch_betas, dim = 1).squeeze(0),
+                 cat(latent_actions, dim = 1).squeeze(0)
+             ))
+
+             all_cumulative_rewards.append(tensor(total_reward))
+             all_step_rewards.append(tensor(step_rewards))
+             all_episode_lens.append(episode_len)
+
+         # compute advantages
+
+         cumulative_rewards = stack(all_cumulative_rewards)
+         episode_lens = tensor(all_episode_lens)
+
+         # pad step rewards
+
+         max_len = max(all_episode_lens)
+         padded_step_rewards = torch.zeros(num_groups, max_len)
+
+         for i, (rewards, length) in enumerate(zip(all_step_rewards, all_episode_lens)):
+             padded_step_rewards[i, :length] = rewards
+
+         # reward shaping hook
+
+         shaped_rewards = reward_shaping_fn(cumulative_rewards, padded_step_rewards, episode_lens)
+
+         if not exists(shaped_rewards):
+             continue
+
+         group_advantages = z_score(shaped_rewards)
+
+         group_states, group_log_probs, group_switch_betas, group_latent_actions = zip(*all_episodes)
+
+         for states, log_probs, switch_betas, latent_actions, advantages in zip(group_states, group_log_probs, group_switch_betas, group_latent_actions, group_advantages):
+             replay_buffer.store_episode(
+                 states = states,
+                 log_probs = log_probs,
+                 switch_betas = switch_betas,
+                 latent_actions = latent_actions,
+                 advantages = advantages
+             )
+
+         # learn
+
+         if len(replay_buffer) >= buffer_size:
+             dl = replay_buffer.dataloader(batch_size = num_groups)
+             dl = accelerator.prepare(dl)
+
+             meta_controller.train()
+
+             batch = next(iter(dl))
+
+             loss = meta_controller.policy_loss(
+                 batch['states'],
+                 batch['log_probs'],
+                 batch['latent_actions'],
+                 batch['advantages'],
+                 batch['switch_betas'] == 1.,
+                 episode_lens = batch['_lens']
+             )
+
+             accelerator.backward(loss)
+
+             grad_norm = accelerator.clip_grad_norm_(meta_controller.parameters(), max_grad_norm)
+
+             optim.step()
+             optim.zero_grad()
+
+             meta_controller.eval()
+
+             pbar.set_postfix(
+                 loss = f'{loss.item():.4f}',
+                 grad_norm = f'{grad_norm.item():.4f}',
+                 reward = f'{cumulative_rewards.mean().item():.4f}'
+             )
+
+             accelerator.log({
+                 'loss': loss.item(),
+                 'grad_norm': grad_norm.item()
+             })
+
+             accelerator.print(f'loss: {loss.item():.4f}, grad_norm: {grad_norm.item():.4f}')
+
+     env.close()
+
+     # save
+
+     if exists(output_meta_controller_path):
+         unwrapped_meta_controller.save(output_meta_controller_path)
+         accelerator.print(f'MetaController weights saved to {output_meta_controller_path}')
+
+ if __name__ == '__main__':
+     Fire(main)
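reward_shaping_fn above is the research entry point: it receives the group's cumulative rewards, the per-step rewards padded to the longest episode, and the episode lengths, and may return reshaped rewards or None to reject the whole batch before advantages are computed. A minimal sketch of a custom hook (the length-penalty coefficient is purely illustrative):

import torch

def reward_shaping_fn(
    cumulative_rewards: torch.Tensor,   # (num_episodes,)
    all_rewards: torch.Tensor,          # (num_episodes, max_timesteps)
    episode_lens: torch.Tensor          # (num_episodes,)
) -> torch.Tensor | None:
    # reject the batch when no episode collected any reward - nothing to learn from
    if cumulative_rewards.sum() == 0:
        return None

    # favor shorter successful episodes with a small length penalty (0.001 is illustrative)
    return cumulative_rewards - 0.001 * episode_lens.float()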
@@ -92,8 +92,8 @@ def train(
      else: state_dim = int(torch.tensor(state_shape).prod().item())

      # deduce num_actions from the environment
-     minigrid.register_minigrid_envs()
-     temp_env = gym.make(env_id)
+     from babyai_env import create_env
+     temp_env = create_env(env_id)
      num_actions = int(temp_env.action_space.n)
      temp_env.close()

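With this change, num_actions is deduced from the same wrapped environment stack used for rollouts rather than from a bare gym.make. A quick standalone check, assuming the standard MiniGrid dict observation with an 'image' key (illustrative snippet, not part of the package):

from babyai_env import create_env

temp_env = create_env('BabyAI-MiniBossLevel-v0')

num_actions = int(temp_env.action_space.n)               # MiniGrid exposes 7 discrete actions
state_shape = temp_env.observation_space['image'].shape  # symbolic (width, height, 3) grid

temp_env.close()
print(num_actions, state_shape)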
@@ -1,14 +0,0 @@
- #!/bin/bash
- set -e
-
- # 1. Gather trajectories
- echo "Gathering trajectories..."
- uv run gather_babyai_trajs.py --num_seeds 1000 --num_episodes_per_seed 100 --output_dir end_to_end_trajectories --env_id BabyAI-MiniBossLevel-v0
-
- # 2. Behavioral cloning
- echo "Training behavioral cloning model..."
- uv run train_behavior_clone_babyai.py --cloning_epochs 10 --discovery_epochs 10 --batch_size 256 --input_dir end_to_end_trajectories --env_id BabyAI-MiniBossLevel-v0 --checkpoint_path end_to_end_model.pt --use_resnet
-
- # 3. Inference rollouts
- echo "Running inference rollouts..."
- uv run train_babyai.py --weights_path end_to_end_model.pt --env_name BabyAI-MiniBossLevel-v0 --num_episodes 5 --buffer_size 100 --max_timesteps 100
@@ -1,140 +0,0 @@
- # /// script
- # dependencies = [
- #     "fire",
- #     "gymnasium",
- #     "gymnasium[other]",
- #     "memmap-replay-buffer>=0.0.12",
- #     "metacontroller-pytorch",
- #     "minigrid",
- #     "tqdm"
- # ]
- # ///
-
- from fire import Fire
- from tqdm import tqdm
- from shutil import rmtree
- from pathlib import Path
-
- import torch
- from einops import rearrange
-
- import gymnasium as gym
- import minigrid
- from minigrid.wrappers import FullyObsWrapper, SymbolicObsWrapper
-
- from memmap_replay_buffer import ReplayBuffer
- from metacontroller.metacontroller import Transformer
-
- # functions
-
- def exists(v):
-     return v is not None
-
- def default(v, d):
-     return v if exists(v) else d
-
- def divisible_by(num, den):
-     return (num % den) == 0
-
- # main
-
- def main(
-     env_name = 'BabyAI-BossLevel-v0',
-     num_episodes = int(10e6),
-     max_timesteps = 500,
-     buffer_size = 5_000,
-     render_every_eps = 1_000,
-     video_folder = './recordings',
-     seed = None,
-     weights_path = None
- ):
-
-     # environment
-
-     env = gym.make(env_name, render_mode = 'rgb_array')
-     env = FullyObsWrapper(env.unwrapped)
-     env = SymbolicObsWrapper(env.unwrapped)
-
-     rmtree(video_folder, ignore_errors = True)
-
-     env = gym.wrappers.RecordVideo(
-         env = env,
-         video_folder = video_folder,
-         name_prefix = 'babyai',
-         episode_trigger = lambda eps_num: divisible_by(eps_num, render_every_eps),
-         disable_logger = True
-     )
-
-     # maybe load model
-
-     model = None
-     if exists(weights_path):
-         weights_path = Path(weights_path)
-         assert weights_path.exists(), f"weights not found at {weights_path}"
-         model = Transformer.init_and_load(str(weights_path), strict = False)
-         model.eval()
-
-     # replay
-
-     replay_buffer = ReplayBuffer(
-         './replay-data',
-         max_episodes = buffer_size,
-         max_timesteps = max_timesteps + 1,
-         fields = dict(
-             action = 'int',
-             state_image = ('float', (7, 7, 3)),
-             state_direction = 'int'
-         ),
-         overwrite = True,
-         circular = True
-     )
-
-     # rollouts
-
-     for _ in tqdm(range(num_episodes)):
-
-         state, *_ = env.reset(seed = seed)
-
-         cache = None
-         past_action_id = None
-
-         for _ in range(max_timesteps):
-
-             if exists(model):
-                 # preprocess state
-                 # assume state is a dict with 'image'
-                 image = state['image']
-                 image_tensor = torch.from_numpy(image).float()
-                 image_tensor = rearrange(image_tensor, 'h w c -> 1 1 (h w c)')
-
-                 if exists(past_action_id) and torch.is_tensor(past_action_id):
-                     past_action_id = past_action_id.long()
-
-                 with torch.no_grad():
-                     logits, cache = model(
-                         image_tensor,
-                         past_action_id,
-                         return_cache = True,
-                         return_raw_action_dist = True,
-                         cache = cache
-                     )
-
-                 action = model.action_readout.sample(logits)
-                 past_action_id = action
-                 action = action.squeeze()
-             else:
-                 action = torch.randint(0, 7, ())
-
-             next_state, reward, terminated, truncated, *_ = env.step(action.cpu().numpy())
-
-             done = terminated or truncated
-
-             if done:
-                 break
-
-             state = next_state
-
-     env.close()
-
- if __name__ == '__main__':
-     Fire(main)