metacontroller-pytorch 0.0.42__tar.gz → 0.0.44__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {metacontroller_pytorch-0.0.42 → metacontroller_pytorch-0.0.44}/.gitignore +6 -0
- {metacontroller_pytorch-0.0.42 → metacontroller_pytorch-0.0.44}/PKG-INFO +1 -1
- {metacontroller_pytorch-0.0.42 → metacontroller_pytorch-0.0.44}/gather_babyai_trajs.py +81 -9
- {metacontroller_pytorch-0.0.42 → metacontroller_pytorch-0.0.44}/metacontroller/metacontroller.py +13 -2
- {metacontroller_pytorch-0.0.42 → metacontroller_pytorch-0.0.44}/metacontroller/metacontroller_with_binary_mapper.py +1 -1
- {metacontroller_pytorch-0.0.42 → metacontroller_pytorch-0.0.44}/pyproject.toml +1 -1
- {metacontroller_pytorch-0.0.42 → metacontroller_pytorch-0.0.44}/tests/test_metacontroller.py +10 -2
- {metacontroller_pytorch-0.0.42 → metacontroller_pytorch-0.0.44}/train_babyai.py +20 -17
- metacontroller_pytorch-0.0.44/train_babyai_evo_strat.py +213 -0
- {metacontroller_pytorch-0.0.42 → metacontroller_pytorch-0.0.44}/train_behavior_clone_babyai.py +51 -22
- metacontroller_pytorch-0.0.42/test_babyai_e2e.sh +0 -35
- {metacontroller_pytorch-0.0.42 → metacontroller_pytorch-0.0.44}/.github/workflows/python-publish.yml +0 -0
- {metacontroller_pytorch-0.0.42 → metacontroller_pytorch-0.0.44}/.github/workflows/test.yml +0 -0
- {metacontroller_pytorch-0.0.42 → metacontroller_pytorch-0.0.44}/LICENSE +0 -0
- {metacontroller_pytorch-0.0.42 → metacontroller_pytorch-0.0.44}/README.md +0 -0
- {metacontroller_pytorch-0.0.42 → metacontroller_pytorch-0.0.44}/babyai_env.py +0 -0
- {metacontroller_pytorch-0.0.42 → metacontroller_pytorch-0.0.44}/fig1.png +0 -0
- {metacontroller_pytorch-0.0.42 → metacontroller_pytorch-0.0.44}/metacontroller/__init__.py +0 -0
- {metacontroller_pytorch-0.0.42 → metacontroller_pytorch-0.0.44}/metacontroller/transformer_with_resnet.py +0 -0
{metacontroller_pytorch-0.0.42 → metacontroller_pytorch-0.0.44}/gather_babyai_trajs.py
RENAMED
@@ -37,6 +37,9 @@ from minigrid.core.constants import OBJECT_TO_IDX
 
 from memmap_replay_buffer import ReplayBuffer
 
+# Difficulty thresholds based on mission length
+EASY_MAX_LENGTH = 30    # easy: 0 to 30
+MEDIUM_MAX_LENGTH = 75  # medium: 30 to 75, hard: > 75
 
 # helpers
 
@@ -46,6 +49,67 @@ def exists(val):
 def sample(prob):
     return random.random() < prob
 
+def get_mission_length(env_id, seed):
+    """
+    Get the mission length for a given seed.
+    Returns the length of the mission string.
+    """
+    env = gym.make(env_id, render_mode="rgb_array")
+    env.reset(seed=seed)
+    length = len(env.unwrapped.mission)
+    env.close()
+    return length
+
+def categorize_seeds_by_difficulty(env_id, num_seeds_per_level, level_difficulty=None):
+    """
+    Scan seeds and categorize them by difficulty based on mission length.
+
+    Args:
+        env_id: Environment ID
+        num_seeds_per_level: Number of seeds needed per difficulty level
+        level_difficulty: List of levels to collect seeds for.
+            Supported: 'easy', 'medium', 'hard'
+            If None, collects for ['easy', 'hard'].
+        max_seed_to_scan: Maximum seed value to scan
+
+    Returns:
+        dict with keys for each requested level, each containing a list of seeds
+    """
+
+    seeds = {level: [] for level in level_difficulty}
+
+    total_needed = sum(num_seeds_per_level for _ in level_difficulty)
+    print(f"Scanning seeds to categorize by difficulty (need {num_seeds_per_level} per level for {level_difficulty})...")
+
+    with tqdm(total=total_needed, desc="Categorizing seeds") as pbar:
+        seed = 1
+        all_done = False
+        while not all_done:
+            # Check if we have enough seeds for all requested levels
+            all_done = all(len(seeds[level]) >= num_seeds_per_level for level in level_difficulty)
+
+            try:
+                mission_length = get_mission_length(env_id, seed)
+
+                # easy: mission length <= 30
+                if 'easy' in level_difficulty and mission_length <= EASY_MAX_LENGTH and len(seeds['easy']) < num_seeds_per_level:
+                    seeds['easy'].append(seed)
+                    pbar.update(1)
+                # medium: mission length <= 75 (combines easy and medium)
+                elif 'medium' in level_difficulty and mission_length <= MEDIUM_MAX_LENGTH and len(seeds['medium']) < num_seeds_per_level:
+                    seeds['medium'].append(seed)
+                    pbar.update(1)
+                # hard: mission length > 75
+                elif 'hard' in level_difficulty and mission_length > MEDIUM_MAX_LENGTH and len(seeds['hard']) < num_seeds_per_level:
+                    seeds['hard'].append(seed)
+                    pbar.update(1)
+            except Exception as e:
+                logger.warning(f"Error getting mission length for seed {seed}: {e}")
+
+            seed += 1
+
+    return seeds
+
 # wrapper, necessarily modified to allow for both rgb obs (policy) and symbolic obs (bot)
 
 class RGBImgPartialObsWrapper(ObservationWrapper):
@@ -128,7 +192,7 @@ def collect_single_episode(env_id, seed, num_steps, random_action_prob, state_sh
         env.close()
         return None, None, False, 0
 
-    episode_state[_step] = state_obs["rgb_image"]
+    episode_state[_step] = state_obs["rgb_image"]
     episode_action[_step] = action
 
     state_obs, reward, terminated, truncated, info = env.step(action)
@@ -151,6 +215,7 @@ def collect_demonstrations(
     num_steps = 500,
     random_action_prob = 0.05,
     num_workers = None,
+    difficulty = "easy",
    output_dir = "babyai-minibosslevel-trajectories"
 ):
     """
@@ -178,11 +243,9 @@ def collect_demonstrations(
 
     total_episodes = num_seeds * num_episodes_per_seed
 
-    #
-
-
-    for it in range(num_episodes_per_seed):
-        seeds.append(count + 1)
+    # Collect seeds by difficulty
+    assert difficulty in ['easy', 'medium', 'hard']
+    seeds = categorize_seeds_by_difficulty(env_id, num_seeds_per_level=num_seeds, level_difficulty=[difficulty])
 
     successful = 0
     progressbar = tqdm(total=total_episodes)
@@ -203,14 +266,17 @@ def collect_demonstrations(
     )
 
     # Parallel execution with bounded pending futures to avoid OOM
-    max_pending = num_workers
+    max_pending = num_workers
+
+    # Flatten seeds: repeat each seed num_episodes_per_seed times
+    all_seeds = seeds[difficulty] * num_episodes_per_seed
 
     with ProcessPoolExecutor(max_workers=num_workers) as executor:
-        seed_iter = iter(
+        seed_iter = iter(all_seeds)
         futures = {}
 
         # Initial batch of submissions
-        for _ in range(min(max_pending, len(
+        for _ in range(min(max_pending, len(all_seeds))):
             seed = next(seed_iter, None)
             if exists(seed):
                 future = executor.submit(collect_single_episode, env_id, seed, num_steps, random_action_prob, state_shape)
@@ -244,7 +310,13 @@ def collect_demonstrations(
     buffer.flush()
     progressbar.close()
 
+    # Save the seeds used for reproducibility
+    seeds_array = np.array(seeds[difficulty])
+    seeds_path = output_folder / "seeds.npy"
+    np.save(seeds_path, seeds_array)
+
     logger.info(f"Saved {successful} trajectories to {output_dir}")
+    logger.info(f"Saved {len(seeds_array)} seeds to {seeds_path}")
 
 if __name__ == "__main__":
     fire.Fire(collect_demonstrations)
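Note: the collector is exposed through fire.Fire, so the new difficulty argument is available both from Python and as a CLI flag. A minimal usage sketch (not part of the diff; argument values are illustrative):

from gather_babyai_trajs import collect_demonstrations

# gather only medium-difficulty missions (mission length <= 75 characters, per MEDIUM_MAX_LENGTH)
collect_demonstrations(
    env_id = "BabyAI-MiniBossLevel-v0",
    num_seeds = 50,                # seeds per difficulty level, found by scanning from seed 1 upward
    num_episodes_per_seed = 10,
    difficulty = "medium",         # one of 'easy', 'medium', 'hard'
    output_dir = "babyai-medium-trajectories",
)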
{metacontroller_pytorch-0.0.42 → metacontroller_pytorch-0.0.44}/metacontroller/metacontroller.py
RENAMED
@@ -291,7 +291,7 @@ class MetaController(Module):
         else:
             # else during inference, use the previous sampled latent action
 
-            assert seq_len == 1,
+            assert seq_len == 1, 'inference RL phase must be done one token at a time - if replaying for policy optimization, please use `get_action_dist_for_internal_rl`'
             z_prev = prev_sampled_latent_action
 
         # switch input is previous latent action and the embedding
@@ -407,10 +407,19 @@ class Transformer(Module):
 
         # meta controller
 
-        self.meta_controller = meta_controller
+        self.meta_controller = meta_controller
 
         self.register_buffer('zero', tensor(0.), persistent = False)
 
+        # ensure devices match
+
+        if exists(self.meta_controller): self._ensure_consistent_device(self.meta_controller)
+
+    def _ensure_consistent_device(self, network):
+        self.model_device = next(self.parameters()).device
+        if next(network.parameters()).device != self.model_device:
+            network.to(self.model_device)
+
     def evolve(
         self,
         num_generations,
@@ -447,6 +456,8 @@ class Transformer(Module):
 
         # meta controller is either given or already given at init
 
+        if exists(meta_controller): self._ensure_consistent_device(meta_controller)
+
         meta_controller = default(meta_controller, self.meta_controller)
 
         if force_behavior_cloning:

{metacontroller_pytorch-0.0.42 → metacontroller_pytorch-0.0.44}/metacontroller/metacontroller_with_binary_mapper.py
RENAMED
@@ -241,7 +241,7 @@ class MetaControllerWithBinaryMapper(Module):
         if discovery_phase:
             z_prev = cat((prev_sampled_code, sampled_codes[:, :-1]), dim = 1)
         else:
-            assert seq_len == 1,
+            assert seq_len == 1, 'inference RL phase must be done one token at a time - if replaying for policy optimization, please use `get_action_dist_for_internal_rl`'
             z_prev = prev_sampled_code
 
         switch_input = torch.cat((meta_embed, z_prev), dim=-1)
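The new _ensure_consistent_device hook guards against a MetaController that is attached after the Transformer has already been moved to an accelerator, which can otherwise fail on the first forward pass with mixed-device tensors. A standalone sketch of the same pattern (hypothetical Parent/child names, not the library's classes):

import torch
from torch import nn

class Parent(nn.Module):
    def __init__(self, child: nn.Module | None = None):
        super().__init__()
        self.proj = nn.Linear(4, 4)
        self.child = child
        if self.child is not None:
            self._ensure_consistent_device(self.child)

    def _ensure_consistent_device(self, network: nn.Module):
        # move the sub-network onto whatever device the parent's parameters live on
        device = next(self.parameters()).device
        if next(network.parameters()).device != device:
            network.to(device)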
{metacontroller_pytorch-0.0.42 → metacontroller_pytorch-0.0.44}/tests/test_metacontroller.py
RENAMED
@@ -114,13 +114,14 @@ def test_metacontroller(
     for one_state in subset_state.unbind(dim = 1):
         one_state = rearrange(one_state, 'b d -> b 1 d')
 
-        logits, cache = model(one_state, past_action_id, meta_controller = meta_controller, return_cache = True)
+        logits, cache = model(one_state, past_action_id, meta_controller = meta_controller, cache = cache, return_cache = True)
 
         past_action_id = model.action_readout.sample(logits)
 
         # extract grpo data and store
 
-
+        grpo_data = extract_grpo_data(meta_controller, cache)
+        grpo_data_list.append(grpo_data)
 
         # accumulate across time for the episode data
 
@@ -145,6 +146,13 @@ def test_metacontroller(
     # simulate a policy loss update over the entire group
 
     group_states, group_log_probs, group_switch_betas, group_latent_actions = map(partial(cat, dim = 0), zip(*all_episodes))
+
+    # parallel verification
+
+    parallel_action_dist = meta_controller.get_action_dist_for_internal_rl(group_states)
+    parallel_log_probs = meta_controller.log_prob(parallel_action_dist, group_latent_actions)
+
+    assert torch.allclose(parallel_log_probs, group_log_probs, atol = 1e-5), 'parallel log probs do not match stored log probs'
 
     for states, log_probs, switch_betas, latent_actions, advantages in zip(group_states, group_log_probs, group_switch_betas, group_latent_actions, group_advantages):
         replay_buffer.store_episode(
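The added parallel-verification block asserts that log-probabilities gathered step by step during the cached rollout match a single parallel pass over the whole group, which is what get_action_dist_for_internal_rl is for. The invariant itself can be illustrated with a toy categorical policy (torch.distributions only; the MetaController API above is analogous):

import torch
from torch.distributions import Categorical

logits = torch.randn(1, 8, 5)                 # one episode, 8 steps, 5 latent actions
actions, stepwise_log_probs = [], []

for t in range(logits.shape[1]):              # sequential rollout, one step at a time
    dist = Categorical(logits = logits[:, t])
    action = dist.sample()
    actions.append(action)
    stepwise_log_probs.append(dist.log_prob(action))

actions = torch.stack(actions, dim = 1)
stepwise_log_probs = torch.stack(stepwise_log_probs, dim = 1)

parallel_log_probs = Categorical(logits = logits).log_prob(actions)   # all steps in one pass
assert torch.allclose(parallel_log_probs, stepwise_log_probs, atol = 1e-5)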
{metacontroller_pytorch-0.0.42 → metacontroller_pytorch-0.0.44}/train_babyai.py
RENAMED
@@ -6,7 +6,8 @@
 # "memmap-replay-buffer>=0.0.12",
 # "metacontroller-pytorch",
 # "minigrid",
-# "tqdm"
+# "tqdm",
+# "wandb"
 # ]
 # ///
 
@@ -57,22 +58,23 @@ def default(v, d):
 # main
 
 def main(
-    env_name
-    num_episodes
-    max_timesteps
-    buffer_size
-    render_every_eps
-    video_folder
+    env_name = 'BabyAI-BossLevel-v0',
+    num_episodes = int(10e6),
+    max_timesteps = 500,
+    buffer_size = 5_000,
+    render_every_eps = 1_000,
+    video_folder = './recordings',
     seed: int | None = None,
     transformer_weights_path: str | None = None,
     meta_controller_weights_path: str | None = None,
-    output_meta_controller_path
-    use_resnet
-    lr
-
-
-
-
+    output_meta_controller_path = 'metacontroller_rl_trained.pt',
+    use_resnet = False,
+    lr = 1e-4,
+    batch_size = 16,
+    num_groups = 16,
+    max_grad_norm = 1.0,
+    use_wandb = False,
+    wandb_project = 'metacontroller-babyai-rl'
 ):
     # accelerator
 
@@ -263,7 +265,7 @@ def main(
         # learn
 
         if len(replay_buffer) >= buffer_size:
-            dl = replay_buffer.dataloader(batch_size =
+            dl = replay_buffer.dataloader(batch_size = batch_size)
             dl = accelerator.prepare(dl)
 
             meta_controller.train()
@@ -296,10 +298,11 @@ def main(
 
             accelerator.log({
                 'loss': loss.item(),
-                'grad_norm': grad_norm.item()
+                'grad_norm': grad_norm.item(),
+                'reward': cumulative_rewards.mean().item()
             })
 
-            accelerator.print(f'loss: {loss.item():.4f}, grad_norm: {grad_norm.item():.4f}')
+            accelerator.print(f'loss: {loss.item():.4f}, grad_norm: {grad_norm.item():.4f}, reward: {cumulative_rewards.mean().item():.4f}')
 
     env.close()
 
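With the previously unset keyword defaults now filled in, the RL script only strictly needs the two checkpoint paths; a hedged invocation sketch (values illustrative, parameter names as in the new signature):

from train_babyai import main

main(
    transformer_weights_path = 'transformer_bc.pt',
    meta_controller_weights_path = 'meta_controller_discovery.pt',
    env_name = 'BabyAI-MiniBossLevel-v0',
    batch_size = 16,
    num_groups = 16,
    use_wandb = True,   # logs loss, grad_norm and mean cumulative reward each update
)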
metacontroller_pytorch-0.0.44/train_babyai_evo_strat.py
@@ -0,0 +1,213 @@
+# /// script
+# dependencies = [
+# "fire",
+# "gymnasium",
+# "gymnasium[other]",
+# "metacontroller-pytorch",
+# "minigrid",
+# "tqdm",
+# "x-evolution",
+# "einops"
+# ]
+# ///
+
+from __future__ import annotations
+import fire
+from pathlib import Path
+from shutil import rmtree
+import numpy as np
+
+import torch
+from torch import nn, Tensor, tensor
+from torch.nn import Module
+from einops import rearrange
+
+from babyai_env import create_env
+from metacontroller.metacontroller import Transformer, MetaController
+
+# functions
+
+def exists(v):
+    return v is not None
+
+def default(v, d):
+    return v if exists(v) else d
+
+# default fitness function
+
+def default_fitness_fn(
+    rewards: list[float],
+    states: list[any],
+    actions: list[any],
+    next_states: list[any],
+    infos: list[any]
+) -> float:
+    """
+    researchers can modify this function to engineer their own rewards and fitness scores
+    processing the entire episode at once for every noise vector of the population separately
+    """
+    return sum(rewards)
+
+# babyai environment for ES
+
+class BabyAIEnvironment(Module):
+    def __init__(
+        self,
+        env_id = 'BabyAI-BossLevel-v0',
+        video_folder = './recordings_babyai_es',
+        render_every_eps = 100,
+        max_steps = 500,
+        use_resnet = False,
+        fitness_fn = default_fitness_fn
+    ):
+        super().__init__()
+
+        self.env_id = env_id
+        self.video_folder = video_folder
+        self.render_every_eps = render_every_eps
+        self.max_steps = max_steps
+        self.use_resnet = use_resnet
+        self.fitness_fn = fitness_fn
+
+        # initial env creation for observation space etc. if needed
+        # but create_env is called inside pre_main_callback or reset
+        self.env = None
+
+    def pre_main_callback(self):
+        # clean up and initialize environment
+        rmtree(self.video_folder, ignore_errors = True)
+
+        self.env = create_env(
+            self.env_id,
+            render_mode = 'rgb_array',
+            video_folder = self.video_folder,
+            render_every_eps = self.render_every_eps
+        )
+
+    def forward(self, model):
+        device = next(model.parameters()).device
+
+        seed = torch.randint(0, int(1e6), ()).item()
+        state, _ = self.env.reset(seed = seed)
+
+        step = 0
+        cache = None
+        past_action_id = None
+
+        unwrapped_model = getattr(model, 'model', model)
+
+        episode_rewards = []
+        episode_states = []
+        episode_actions = []
+        episode_next_states = []
+        episode_infos = []
+
+        while step < self.max_steps:
+            image = state['image']
+            image_tensor = torch.from_numpy(image).float().to(device)
+
+            if self.use_resnet:
+                image_tensor = rearrange(image_tensor, 'h w c -> 1 1 h w c')
+                image_tensor = unwrapped_model.visual_encode(image_tensor)
+            else:
+                image_tensor = rearrange(image_tensor, 'h w c -> 1 1 (h w c)')
+
+            if torch.is_tensor(past_action_id):
+                past_action_id = past_action_id.long()
+
+            with torch.no_grad():
+                logits, cache = model(
+                    image_tensor,
+                    past_action_id,
+                    return_cache = True,
+                    return_raw_action_dist = True,
+                    cache = cache
+                )
+
+            action = unwrapped_model.action_readout.sample(logits)
+            past_action_id = action
+            action_id = action.squeeze()
+
+            next_state, reward, terminated, truncated, info = self.env.step(action_id.cpu().numpy().item())
+
+            episode_rewards.append(reward)
+            episode_states.append(state)
+            episode_actions.append(action_id)
+            episode_next_states.append(next_state)
+            episode_infos.append(info)
+
+            done = terminated or truncated
+            if done:
+                break
+
+            state = next_state
+            step += 1
+
+        return self.fitness_fn(
+            episode_rewards,
+            episode_states,
+            episode_actions,
+            episode_next_states,
+            episode_infos
+        )
+
+def main(
+    env_id = 'BabyAI-BossLevel-v0',
+    num_generations = 100,
+    max_steps = 500,
+    render_every_eps = 100,
+    video_folder = './recordings_babyai_es',
+    transformer_weights_path: str | None = None,
+    meta_controller_weights_path: str | None = None,
+    output_meta_controller_path = 'metacontroller_es_trained.pt',
+    use_resnet = False,
+    noise_population_size = 50,
+    noise_scale = 1e-2,
+    learning_rate = 1e-3,
+    fitness_fn = default_fitness_fn
+):
+    # load model
+
+    assert exists(transformer_weights_path), "Transformer weights must be provided"
+
+    # lazy import to avoid unnecessary dependencies if not used
+    from metacontroller.transformer_with_resnet import TransformerWithResnet as TransformerResnet
+    transformer_klass = TransformerResnet if use_resnet else Transformer
+
+    model = transformer_klass.init_and_load(transformer_weights_path, strict = False)
+    model.eval()
+
+    if exists(meta_controller_weights_path):
+        meta_controller = MetaController.init_and_load(meta_controller_weights_path, strict = False)
+        model.meta_controller = meta_controller
+
+    assert exists(model.meta_controller), "MetaController must be present for evolution"
+
+    # setup environment
+
+    babyai_env = BabyAIEnvironment(
+        env_id = env_id,
+        video_folder = video_folder,
+        render_every_eps = render_every_eps,
+        max_steps = max_steps,
+        use_resnet = use_resnet,
+        fitness_fn = fitness_fn
+    )
+
+    # evolve
+
+    model.evolve(
+        num_generations = num_generations,
+        environment = babyai_env,
+        noise_population_size = noise_population_size,
+        noise_scale = noise_scale,
+        learning_rate = learning_rate
+    )
+
+    # save
+
+    model.meta_controller.save(output_meta_controller_path)
+    print(f'MetaController weights saved to {output_meta_controller_path}')
+
+if __name__ == '__main__':
+    fire.Fire(main)
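Because fitness_fn is injected into both BabyAIEnvironment and main, reward shaping can be swapped without touching the rollout loop. A minimal sketch (the shaping itself is illustrative, not from the source):

from train_babyai_evo_strat import main

def length_penalized_fitness(rewards, states, actions, next_states, infos):
    # favor return, lightly penalize long episodes
    return sum(rewards) - 0.001 * len(rewards)

main(
    transformer_weights_path = 'transformer_bc.pt',
    meta_controller_weights_path = 'meta_controller_discovery.pt',
    noise_population_size = 50,
    noise_scale = 1e-2,
    fitness_fn = length_penalized_fitness,
)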
{metacontroller_pytorch-0.0.42 → metacontroller_pytorch-0.0.44}/train_behavior_clone_babyai.py
RENAMED
@@ -18,7 +18,7 @@ from tqdm import tqdm
 from pathlib import Path
 
 import torch
-from torch.optim import
+from torch.optim import AdamW
 from torch.utils.data import DataLoader
 
 from accelerate import Accelerator
@@ -31,14 +31,20 @@ from metacontroller.transformer_with_resnet import TransformerWithResnet
 import minigrid
 import gymnasium as gym
 
+# TODO: loss is still ~300 and it could be the resnet output?
+# TODO: changelog (paper hparams, checkpointing, difficulty levels in trajectory collection)
+
 def train(
     input_dir = "babyai-minibosslevel-trajectories",
     env_id = "BabyAI-MiniBossLevel-v0",
     cloning_epochs = 10,
     discovery_epochs = 10,
-    batch_size =
+    batch_size = 128,
+    gradient_accumulation_steps = None,
     lr = 1e-4,
     discovery_lr = 1e-4,
+    weight_decay = 0.03,
+    discovery_weight_decay = 0.03,
     dim = 512,
     depth = 2,
     heads = 8,
@@ -47,6 +53,7 @@ def train(
     wandb_project = "metacontroller-babyai-bc",
     checkpoint_path = "transformer_bc.pt",
     meta_controller_checkpoint_path = "meta_controller_discovery.pt",
+    save_steps = 50,
     state_loss_weight = 1.,
     action_loss_weight = 1.,
     discovery_action_recon_loss_weight = 1.,
@@ -55,6 +62,22 @@ def train(
     max_grad_norm = 1.,
     use_resnet = False
 ):
+
+    def store_checkpoint(step:int):
+        if accelerator.is_main_process:
+
+            # Add step to checkpoint filenames
+            checkpoint_path_with_step = checkpoint_path.replace('.pt', f'_step_{step}.pt')
+            meta_controller_checkpoint_path_with_step = meta_controller_checkpoint_path.replace('.pt', f'_step_{step}.pt')
+
+            unwrapped_model = accelerator.unwrap_model(model)
+            unwrapped_model.save(checkpoint_path_with_step)
+
+            unwrapped_meta_controller = accelerator.unwrap_model(meta_controller)
+            unwrapped_meta_controller.save(meta_controller_checkpoint_path_with_step)
+
+            accelerator.print(f"Model saved to {checkpoint_path_with_step}, MetaController to {meta_controller_checkpoint_path_with_step}")
+
     # accelerator
 
     accelerator = Accelerator(log_with = "wandb" if use_wandb else None)
@@ -99,6 +122,10 @@ def train(
 
     accelerator.print(f"Detected state_dim: {state_dim}, num_actions: {num_actions} from env: {env_id}")
 
+    # meta controller
+
+    meta_controller = MetaController(dim)
+
     # transformer
 
     transformer_class = TransformerWithResnet if use_resnet else Transformer
@@ -108,18 +135,15 @@ def train(
         state_embed_readout = dict(num_continuous = state_dim),
         action_embed_readout = dict(num_discrete = num_actions),
         lower_body = dict(depth = depth, heads = heads, attn_dim_head = dim_head),
-        upper_body = dict(depth = depth, heads = heads, attn_dim_head = dim_head)
+        upper_body = dict(depth = depth, heads = heads, attn_dim_head = dim_head),
+        meta_controller = meta_controller
     )
 
-    # meta controller
-
-    meta_controller = MetaController(dim)
-
     # optimizer
 
-    optim_model =
+    optim_model = AdamW(model.parameters(), lr = lr, weight_decay = weight_decay)
 
-    optim_meta_controller =
+    optim_meta_controller = AdamW(meta_controller.discovery_parameters(), lr = discovery_lr, weight_decay = discovery_weight_decay)
 
     # prepare
 
@@ -127,6 +151,7 @@ def train(
 
     # training
 
+    gradient_step = 0
     for epoch in range(cloning_epochs + discovery_epochs):
 
         model.train()
@@ -154,13 +179,14 @@ def train(
             else: # flatten state: (B, T, 7, 7, 3) -> (B, T, 147)
                 states = rearrange(states, 'b t ... -> b t (...)')
 
+
             with accelerator.accumulate(model):
                 losses = model(
                     states,
                     actions,
                     episode_lens = episode_lens,
                     discovery_phase = is_discovering,
-
+                    force_behavior_cloning = not is_discovering
                 )
 
                 if is_discovering:
@@ -190,14 +216,19 @@ def train(
                     action_loss = action_loss.item(),
                 )
 
+                # gradient accumulation
+
+                if gradient_accumulation_steps is not None: loss /= gradient_accumulation_steps
+
                 # backprop
 
                 accelerator.backward(loss)
 
                 grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm = max_grad_norm)
 
-
-
+                if gradient_accumulation_steps is None or gradient_step % gradient_accumulation_steps == 0:
+                    optim.step()
+                    optim.zero_grad()
 
             # log
 
@@ -211,23 +242,21 @@ def train(
                 })
 
             progress_bar.set_postfix(**log)
+            gradient_step += 1
+
+            # checkpoint
+
+            if gradient_step % save_steps == 0:
+                accelerator.wait_for_everyone()
+                store_checkpoint(gradient_step)
 
         avg_losses = {k: v / len(dataloader) for k, v in total_losses.items()}
         avg_losses_str = ", ".join([f"{k}={v:.4f}" for k, v in avg_losses.items()])
         accelerator.print(f"Epoch {epoch}: {avg_losses_str}")
 
     # save weights
-
     accelerator.wait_for_everyone()
-
-
-    unwrapped_model = accelerator.unwrap_model(model)
-    unwrapped_model.save(checkpoint_path)
-
-    unwrapped_meta_controller = accelerator.unwrap_model(meta_controller)
-    unwrapped_meta_controller.save(meta_controller_checkpoint_path)
-
-    accelerator.print(f"Model saved to {checkpoint_path}, MetaController to {meta_controller_checkpoint_path}")
+    store_checkpoint(gradient_step)
 
     accelerator.end_training()
 
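A hedged sketch of how the new behavior-cloning knobs combine (values illustrative; with gradient_accumulation_steps = 4 the loss is scaled by 1/4 and the optimizer steps every fourth batch, while save_steps writes step-suffixed checkpoints via store_checkpoint):

from train_behavior_clone_babyai import train

train(
    input_dir = 'babyai-minibosslevel-trajectories',
    env_id = 'BabyAI-MiniBossLevel-v0',
    batch_size = 128,
    gradient_accumulation_steps = 4,
    weight_decay = 0.03,
    discovery_weight_decay = 0.03,
    save_steps = 50,
    use_resnet = True,
)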
metacontroller_pytorch-0.0.42/test_babyai_e2e.sh
@@ -1,35 +0,0 @@
-#!/bin/bash
-set -e
-
-# 1. Gather trajectories
-echo "Gathering trajectories..."
-uv run gather_babyai_trajs.py \
-    --num_seeds 100 \
-    --num_episodes_per_seed 10 \
-    --num_steps 500 \
-    --output_dir end_to_end_trajectories \
-    --env_id BabyAI-MiniBossLevel-v0
-
-# 2. Behavioral cloning
-echo "Training behavioral cloning model..."
-ACCELERATE_USE_CPU=true ACCELERATE_MIXED_PRECISION=no uv run train_behavior_clone_babyai.py \
-    --cloning_epochs 10 \
-    --discovery_epochs 10 \
-    --batch_size 256 \
-    --input_dir end_to_end_trajectories \
-    --env_id BabyAI-MiniBossLevel-v0 \
-    --checkpoint_path end_to_end_model.pt \
-    --use_resnet
-
-# 3. Inference rollouts
-echo "Running inference rollouts..."
-uv run train_babyai.py \
-    --transformer_weights_path end_to_end_model.pt \
-    --meta_controller_weights_path meta_controller_discovery.pt \
-    --env_name BabyAI-MiniBossLevel-v0 \
-    --num_episodes 1000 \
-    --buffer_size 1000 \
-    --max_timesteps 100 \
-    --num_groups 16 \
-    --lr 1e-4 \
-    --use_resnet
{metacontroller_pytorch-0.0.42 → metacontroller_pytorch-0.0.44}/.github/workflows/python-publish.yml
RENAMED
File without changes
{metacontroller_pytorch-0.0.42 → metacontroller_pytorch-0.0.44}/.github/workflows/test.yml
RENAMED
File without changes
{metacontroller_pytorch-0.0.42 → metacontroller_pytorch-0.0.44}/LICENSE
RENAMED
File without changes
{metacontroller_pytorch-0.0.42 → metacontroller_pytorch-0.0.44}/README.md
RENAMED
File without changes
{metacontroller_pytorch-0.0.42 → metacontroller_pytorch-0.0.44}/babyai_env.py
RENAMED
File without changes
{metacontroller_pytorch-0.0.42 → metacontroller_pytorch-0.0.44}/fig1.png
RENAMED
File without changes
{metacontroller_pytorch-0.0.42 → metacontroller_pytorch-0.0.44}/metacontroller/__init__.py
RENAMED
File without changes
{metacontroller_pytorch-0.0.42 → metacontroller_pytorch-0.0.44}/metacontroller/transformer_with_resnet.py
RENAMED
File without changes