metacontroller-pytorch 0.0.43__tar.gz → 0.0.46__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {metacontroller_pytorch-0.0.43 → metacontroller_pytorch-0.0.46}/.gitignore +6 -0
- {metacontroller_pytorch-0.0.43 → metacontroller_pytorch-0.0.46}/PKG-INFO +1 -1
- {metacontroller_pytorch-0.0.43 → metacontroller_pytorch-0.0.46}/gather_babyai_trajs.py +81 -9
- {metacontroller_pytorch-0.0.43 → metacontroller_pytorch-0.0.46}/metacontroller/metacontroller.py +39 -4
- {metacontroller_pytorch-0.0.43 → metacontroller_pytorch-0.0.46}/pyproject.toml +1 -1
- {metacontroller_pytorch-0.0.43 → metacontroller_pytorch-0.0.46}/train_babyai.py +71 -43
- {metacontroller_pytorch-0.0.43 → metacontroller_pytorch-0.0.46}/train_behavior_clone_babyai.py +57 -24
- metacontroller_pytorch-0.0.43/test_babyai_e2e.sh +0 -35
- {metacontroller_pytorch-0.0.43 → metacontroller_pytorch-0.0.46}/.github/workflows/python-publish.yml +0 -0
- {metacontroller_pytorch-0.0.43 → metacontroller_pytorch-0.0.46}/.github/workflows/test.yml +0 -0
- {metacontroller_pytorch-0.0.43 → metacontroller_pytorch-0.0.46}/LICENSE +0 -0
- {metacontroller_pytorch-0.0.43 → metacontroller_pytorch-0.0.46}/README.md +0 -0
- {metacontroller_pytorch-0.0.43 → metacontroller_pytorch-0.0.46}/babyai_env.py +0 -0
- {metacontroller_pytorch-0.0.43 → metacontroller_pytorch-0.0.46}/fig1.png +0 -0
- {metacontroller_pytorch-0.0.43 → metacontroller_pytorch-0.0.46}/metacontroller/__init__.py +0 -0
- {metacontroller_pytorch-0.0.43 → metacontroller_pytorch-0.0.46}/metacontroller/metacontroller_with_binary_mapper.py +0 -0
- {metacontroller_pytorch-0.0.43 → metacontroller_pytorch-0.0.46}/metacontroller/transformer_with_resnet.py +0 -0
- {metacontroller_pytorch-0.0.43 → metacontroller_pytorch-0.0.46}/tests/test_metacontroller.py +0 -0
- /metacontroller_pytorch-0.0.43/train_baby_evo_strat.py → /metacontroller_pytorch-0.0.46/train_babyai_evo_strat.py +0 -0
{metacontroller_pytorch-0.0.43 → metacontroller_pytorch-0.0.46}/gather_babyai_trajs.py
RENAMED
@@ -37,6 +37,9 @@ from minigrid.core.constants import OBJECT_TO_IDX
 
 from memmap_replay_buffer import ReplayBuffer
 
+# Difficulty thresholds based on mission length
+EASY_MAX_LENGTH = 30 # easy: 0 to 30
+MEDIUM_MAX_LENGTH = 75 # medium: 30 to 75, hard: > 75
 
 # helpers
 
@@ -46,6 +49,67 @@ def exists(val):
 def sample(prob):
     return random.random() < prob
 
+def get_mission_length(env_id, seed):
+    """
+    Get the mission length for a given seed.
+    Returns the length of the mission string.
+    """
+    env = gym.make(env_id, render_mode="rgb_array")
+    env.reset(seed=seed)
+    length = len(env.unwrapped.mission)
+    env.close()
+    return length
+
+def categorize_seeds_by_difficulty(env_id, num_seeds_per_level, level_difficulty=None):
+    """
+    Scan seeds and categorize them by difficulty based on mission length.
+
+    Args:
+        env_id: Environment ID
+        num_seeds_per_level: Number of seeds needed per difficulty level
+        level_difficulty: List of levels to collect seeds for.
+            Supported: 'easy', 'medium', 'hard'
+            If None, collects for ['easy', 'hard'].
+        max_seed_to_scan: Maximum seed value to scan
+
+    Returns:
+        dict with keys for each requested level, each containing a list of seeds
+    """
+
+    seeds = {level: [] for level in level_difficulty}
+
+    total_needed = sum(num_seeds_per_level for _ in level_difficulty)
+    print(f"Scanning seeds to categorize by difficulty (need {num_seeds_per_level} per level for {level_difficulty})...")
+
+    with tqdm(total=total_needed, desc="Categorizing seeds") as pbar:
+        seed = 1
+        all_done = False
+        while not all_done:
+            # Check if we have enough seeds for all requested levels
+            all_done = all(len(seeds[level]) >= num_seeds_per_level for level in level_difficulty)
+
+            try:
+                mission_length = get_mission_length(env_id, seed)
+
+                # easy: mission length <= 30
+                if 'easy' in level_difficulty and mission_length <= EASY_MAX_LENGTH and len(seeds['easy']) < num_seeds_per_level:
+                    seeds['easy'].append(seed)
+                    pbar.update(1)
+                # medium: mission length <= 75 (combines easy and medium)
+                elif 'medium' in level_difficulty and mission_length <= MEDIUM_MAX_LENGTH and len(seeds['medium']) < num_seeds_per_level:
+                    seeds['medium'].append(seed)
+                    pbar.update(1)
+                # hard: mission length > 75
+                elif 'hard' in level_difficulty and mission_length > MEDIUM_MAX_LENGTH and len(seeds['hard']) < num_seeds_per_level:
+                    seeds['hard'].append(seed)
+                    pbar.update(1)
+            except Exception as e:
+                logger.warning(f"Error getting mission length for seed {seed}: {e}")
+
+            seed += 1
+
+    return seeds
+
 # wrapper, necessarily modified to allow for both rgb obs (policy) and symbolic obs (bot)
 
 class RGBImgPartialObsWrapper(ObservationWrapper):
@@ -128,7 +192,7 @@ def collect_single_episode(env_id, seed, num_steps, random_action_prob, state_sh
             env.close()
             return None, None, False, 0
 
-        episode_state[_step] = state_obs["rgb_image"]
+        episode_state[_step] = state_obs["rgb_image"]
         episode_action[_step] = action
 
         state_obs, reward, terminated, truncated, info = env.step(action)
@@ -151,6 +215,7 @@ def collect_demonstrations(
     num_steps = 500,
     random_action_prob = 0.05,
     num_workers = None,
+    difficulty = "easy",
     output_dir = "babyai-minibosslevel-trajectories"
 ):
     """
@@ -178,11 +243,9 @@ def collect_demonstrations(
 
     total_episodes = num_seeds * num_episodes_per_seed
 
-    #
-
-
-    for it in range(num_episodes_per_seed):
-        seeds.append(count + 1)
+    # Collect seeds by difficulty
+    assert difficulty in ['easy', 'medium', 'hard']
+    seeds = categorize_seeds_by_difficulty(env_id, num_seeds_per_level=num_seeds, level_difficulty=[difficulty])
 
     successful = 0
     progressbar = tqdm(total=total_episodes)
@@ -203,14 +266,17 @@ def collect_demonstrations(
     )
 
     # Parallel execution with bounded pending futures to avoid OOM
-    max_pending = num_workers
+    max_pending = num_workers
+
+    # Flatten seeds: repeat each seed num_episodes_per_seed times
+    all_seeds = seeds[difficulty] * num_episodes_per_seed
 
     with ProcessPoolExecutor(max_workers=num_workers) as executor:
-        seed_iter = iter(
+        seed_iter = iter(all_seeds)
        futures = {}
 
        # Initial batch of submissions
-        for _ in range(min(max_pending, len(
+        for _ in range(min(max_pending, len(all_seeds))):
            seed = next(seed_iter, None)
            if exists(seed):
                future = executor.submit(collect_single_episode, env_id, seed, num_steps, random_action_prob, state_shape)
@@ -244,7 +310,13 @@ def collect_demonstrations(
     buffer.flush()
     progressbar.close()
 
+    # Save the seeds used for reproducibility
+    seeds_array = np.array(seeds[difficulty])
+    seeds_path = output_folder / "seeds.npy"
+    np.save(seeds_path, seeds_array)
+
     logger.info(f"Saved {successful} trajectories to {output_dir}")
+    logger.info(f"Saved {len(seeds_array)} seeds to {seeds_path}")
 
 if __name__ == "__main__":
     fire.Fire(collect_demonstrations)
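
The new seed-difficulty machinery above gates which seeds are used for demonstration collection by mission-string length. A minimal usage sketch, not part of the package, assuming gather_babyai_trajs.py is importable from the repository root (the seed counts and environment id below are illustrative):

# a minimal sketch: scan for 'hard' seeds (mission length > MEDIUM_MAX_LENGTH),
# then collect demonstrations for them
from gather_babyai_trajs import categorize_seeds_by_difficulty, collect_demonstrations

seeds = categorize_seeds_by_difficulty(
    "BabyAI-MiniBossLevel-v0",
    num_seeds_per_level = 5,
    level_difficulty = ['hard']
)
print(seeds)  # dict mapping 'hard' to the first 5 qualifying seeds

# equivalent end-to-end call through the new `difficulty` argument
collect_demonstrations(
    env_id = "BabyAI-MiniBossLevel-v0",
    num_seeds = 5,
    num_episodes_per_seed = 2,
    difficulty = "hard",
    output_dir = "babyai-hard-trajectories"
)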
{metacontroller_pytorch-0.0.43 → metacontroller_pytorch-0.0.46}/metacontroller/metacontroller.py
RENAMED
@@ -54,6 +54,19 @@ def default(*args):
 def straight_through(src, tgt):
     return tgt + src - src.detach()
 
+# losses
+
+BehavioralCloningLosses = namedtuple('BehavioralCloningLosses', (
+    'state',
+    'action'
+))
+
+DiscoveryLosses = namedtuple('DiscoveryLosses', (
+    'action_recon',
+    'kl',
+    'switch'
+))
+
 # meta controller
 
 MetaControllerOutput = namedtuple('MetaControllerOutput', (
@@ -407,10 +420,19 @@ class Transformer(Module):
 
         # meta controller
 
-        self.meta_controller = meta_controller
+        self.meta_controller = meta_controller
 
         self.register_buffer('zero', tensor(0.), persistent = False)
 
+        # ensure devices match
+
+        if exists(self.meta_controller): self._ensure_consistent_device(self.meta_controller)
+
+    def _ensure_consistent_device(self, network):
+        self.model_device = next(self.parameters()).device
+        if next(network.parameters()).device != self.model_device:
+            network.to(self.model_device)
+
     def evolve(
         self,
         num_generations,
@@ -441,12 +463,15 @@ class Transformer(Module):
         return_raw_action_dist = False,
         return_latents = False,
         return_cache = False,
-        episode_lens: Tensor | None = None
+        episode_lens: Tensor | None = None,
+        return_meta_controller_output = False
     ):
         device = state.device
 
         # meta controller is either given or already given at init
 
+        if exists(meta_controller): self._ensure_consistent_device(meta_controller)
+
         meta_controller = default(meta_controller, self.meta_controller)
 
         if force_behavior_cloning:
@@ -533,13 +558,23 @@ class Transformer(Module):
 
             action_clone_loss = self.action_readout.calculate_loss(dist_params, target_actions, mask = loss_mask)
 
-
+            losses = BehavioralCloningLosses(state_clone_loss, action_clone_loss)
+
+            if not return_meta_controller_output:
+                return losses
+
+            return losses, next_meta_hiddens
 
         elif discovery_phase:
 
             action_recon_loss = self.action_readout.calculate_loss(dist_params, target_actions)
 
-
+            losses = DiscoveryLosses(action_recon_loss, next_meta_hiddens.kl_loss, next_meta_hiddens.switch_loss)
+
+            if not return_meta_controller_output:
+                return losses
+
+            return losses, next_meta_hiddens
 
         # returning
 
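
With the named loss tuples and the new return_meta_controller_output flag, the Transformer's behavior-cloning and discovery branches can be consumed by field name. A minimal sketch, not part of the package, mirroring the constructor call in train_behavior_clone_babyai.py; the dimensions, batch shapes, and values below are illustrative assumptions:

# a minimal sketch: build a small Transformer with an attached MetaController
# and read the new named losses
import torch
from metacontroller.metacontroller import Transformer, MetaController

meta_controller = MetaController(64)

model = Transformer(
    dim = 64,
    state_embed_readout = dict(num_continuous = 147),   # flattened 7x7x3 BabyAI observation
    action_embed_readout = dict(num_discrete = 7),      # BabyAI action space
    lower_body = dict(depth = 1, heads = 4, attn_dim_head = 16),
    upper_body = dict(depth = 1, heads = 4, attn_dim_head = 16),
    meta_controller = meta_controller
)

states = torch.randn(2, 10, 147)
actions = torch.randint(0, 7, (2, 10))
episode_lens = torch.tensor([10, 8])

losses = model(states, actions, episode_lens = episode_lens, force_behavior_cloning = True)
total = losses.state + losses.action                      # BehavioralCloningLosses fields

losses, meta_out = model(
    states,
    actions,
    episode_lens = episode_lens,
    discovery_phase = True,
    return_meta_controller_output = True                  # new flag in this release
)
total = losses.action_recon + losses.kl + losses.switch   # DiscoveryLosses fields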
{metacontroller_pytorch-0.0.43 → metacontroller_pytorch-0.0.46}/train_babyai.py
RENAMED
@@ -6,7 +6,8 @@
 # "memmap-replay-buffer>=0.0.12",
 # "metacontroller-pytorch",
 # "minigrid",
-# "tqdm"
+# "tqdm",
+# "wandb"
 # ]
 # ///
 
@@ -17,35 +18,46 @@ from shutil import rmtree
 from tqdm import tqdm
 
 import torch
-from torch import cat, tensor, stack
+from torch import cat, tensor, stack, Tensor
 from torch.optim import Adam
 
 from einops import rearrange
 
+from torch_einops_utils import pad_sequence
+
 from accelerate import Accelerator
 
 from babyai_env import create_env
+
 from memmap_replay_buffer import ReplayBuffer
+
 from metacontroller.metacontroller import Transformer, MetaController, policy_loss, z_score, extract_grpo_data
 from metacontroller.transformer_with_resnet import TransformerWithResnet
 
 # research entry point
 
 def reward_shaping_fn(
-    cumulative_rewards:
-    all_rewards:
-    episode_lens:
-
+    cumulative_rewards: Tensor, # float(num_episodes,)
+    all_rewards: Tensor, # float(num_episodes, max_timesteps)
+    episode_lens: Tensor, # int(num_episodes,)
+    reject_threshold_cumulative_reward_variance: float = 0.
+) -> Tensor | None:
     """
     researchers can modify this function to engineer rewards
    or return None to reject the entire batch
-
-    cumulative_rewards: (num_episodes,)
-    all_rewards: (num_episodes, max_timesteps)
-    episode_lens: (num_episodes,)
     """
+
+    if cumulative_rewards.var() < reject_threshold_cumulative_reward_variance:
+        return None
+
     return cumulative_rewards
 
+def should_reject_group_based_on_switch_betas(
+    switch_betas: Tensor,
+    episode_lens: Tensor
+):
+    return switch_betas.sum().item() == 0.
+
 # helpers
 
 def exists(v):
@@ -68,12 +80,14 @@ def main(
     meta_controller_weights_path: str | None = None,
     output_meta_controller_path = 'metacontroller_rl_trained.pt',
     use_resnet = False,
+    num_epochs = 3,
     lr = 1e-4,
     batch_size = 16,
     num_groups = 16,
     max_grad_norm = 1.0,
     use_wandb = False,
-    wandb_project = 'metacontroller-babyai-rl'
+    wandb_project = 'metacontroller-babyai-rl',
+    reject_threshold_cumulative_reward_variance = 0.
 ):
     # accelerator
 
@@ -233,25 +247,37 @@ def main(
         cumulative_rewards = stack(all_cumulative_rewards)
         episode_lens = tensor(all_episode_lens)
 
-        # pad step rewards
-
         max_len = max(all_episode_lens)
-        padded_step_rewards = torch.zeros(num_episodes, max_len)
 
-
-
+        # pad step rewards
+
+        padded_step_rewards = pad_sequence(all_step_rewards, dim = 0)
 
         # reward shaping hook
 
-        shaped_rewards = reward_shaping_fn(
+        shaped_rewards = reward_shaping_fn(
+            cumulative_rewards,
+            padded_step_rewards,
+            episode_lens,
+            reject_threshold_cumulative_reward_variance = reject_threshold_cumulative_reward_variance
+        )
 
         if not exists(shaped_rewards):
+            accelerator.print(f'group rejected - variance of {cumulative_rewards.var().item():.4f} is lower than threshold of {reject_threshold_cumulative_reward_variance}')
             continue
 
         group_advantages = z_score(shaped_rewards)
 
         group_states, group_log_probs, group_switch_betas, group_latent_actions = zip(*all_episodes)
 
+        # whether to reject group based on switch betas (as it determines the mask for learning)
+
+        padded_group_switch_betas, episode_lens = pad_sequence(group_switch_betas, dim = 0, return_lens = True)
+
+        if should_reject_group_based_on_switch_betas(padded_group_switch_betas, episode_lens):
+            accelerator.print(f'group rejected - switch betas for the entire group does not meet criteria for learning')
+            continue
+
         for states, log_probs, switch_betas, latent_actions, advantages in zip(group_states, group_log_probs, group_switch_betas, group_latent_actions, group_advantages):
             replay_buffer.store_episode(
                 states = states,
@@ -264,43 +290,45 @@ def main(
         # learn
 
         if len(replay_buffer) >= buffer_size:
-            dl = replay_buffer.dataloader(batch_size = batch_size)
+            dl = replay_buffer.dataloader(batch_size = batch_size, shuffle = True)
             dl = accelerator.prepare(dl)
 
             meta_controller.train()
 
-
-
-
-
-
-
-
-
-
-
+            for epoch in range(num_epochs):
+                for batch in dl:
+                    loss = meta_controller.policy_loss(
+                        batch['states'],
+                        batch['log_probs'],
+                        batch['latent_actions'],
+                        batch['advantages'],
+                        batch['switch_betas'] == 1.,
+                        episode_lens = batch['_lens']
+                    )
 
-
+                    accelerator.backward(loss)
 
-
+                    grad_norm = accelerator.clip_grad_norm_(meta_controller.parameters(), max_grad_norm)
 
-
-
+                    optim.step()
+                    optim.zero_grad()
 
-
+                    pbar.set_postfix(
+                        epoch = epoch,
+                        loss = f'{loss.item():.4f}',
+                        grad_norm = f'{grad_norm.item():.4f}',
+                        reward = f'{cumulative_rewards.mean().item():.4f}'
+                    )
 
-
-
-
-
-
+                    accelerator.log({
+                        'loss': loss.item(),
+                        'grad_norm': grad_norm.item(),
+                        'reward': cumulative_rewards.mean().item()
+                    })
 
-
-                'loss': loss.item(),
-                'grad_norm': grad_norm.item()
-            })
+            meta_controller.eval()
 
-            accelerator.print(f'loss: {loss.item():.4f}, grad_norm: {grad_norm.item():.4f}')
+            accelerator.print(f'loss: {loss.item():.4f}, grad_norm: {grad_norm.item():.4f}, reward: {cumulative_rewards.mean().item():.4f}')
 
     env.close()
 
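
Both new hooks in train_babyai.py act as group-level rejection gates before GRPO advantages are computed: reward_shaping_fn now rejects a group whose cumulative-reward variance falls below the threshold, and should_reject_group_based_on_switch_betas rejects a group whose switch betas never fire. A minimal sketch, not part of the package, assuming train_babyai.py imports without side effects (tensor values are illustrative):

# a minimal sketch of the group-rejection behaviour
import torch
from train_babyai import reward_shaping_fn, should_reject_group_based_on_switch_betas

cumulative_rewards = torch.tensor([0.5, 0.5, 0.5, 0.5])   # identical returns -> zero variance
all_rewards = torch.zeros(4, 10)
episode_lens = torch.tensor([10, 10, 10, 10])

shaped = reward_shaping_fn(
    cumulative_rewards,
    all_rewards,
    episode_lens,
    reject_threshold_cumulative_reward_variance = 1e-3
)
assert shaped is None   # group rejected: no learning signal when every episode scores the same

# a group whose switch betas never fire is also skipped
switch_betas = torch.zeros(4, 10)
assert should_reject_group_based_on_switch_betas(switch_betas, episode_lens)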
{metacontroller_pytorch-0.0.43 → metacontroller_pytorch-0.0.46}/train_behavior_clone_babyai.py
RENAMED
@@ -18,7 +18,7 @@ from tqdm import tqdm
 from pathlib import Path
 
 import torch
-from torch.optim import
+from torch.optim import AdamW
 from torch.utils.data import DataLoader
 
 from accelerate import Accelerator
@@ -31,14 +31,20 @@ from metacontroller.transformer_with_resnet import TransformerWithResnet
 import minigrid
 import gymnasium as gym
 
+# TODO: loss is still ~300 and it could be the resnet output?
+# TODO: changelog (paper hparams, checkpointing, difficulty levels in trajectory collection)
+
 def train(
     input_dir = "babyai-minibosslevel-trajectories",
     env_id = "BabyAI-MiniBossLevel-v0",
     cloning_epochs = 10,
     discovery_epochs = 10,
-    batch_size =
+    batch_size = 128,
+    gradient_accumulation_steps = None,
     lr = 1e-4,
     discovery_lr = 1e-4,
+    weight_decay = 0.03,
+    discovery_weight_decay = 0.03,
     dim = 512,
     depth = 2,
     heads = 8,
@@ -47,6 +53,7 @@ def train(
     wandb_project = "metacontroller-babyai-bc",
     checkpoint_path = "transformer_bc.pt",
     meta_controller_checkpoint_path = "meta_controller_discovery.pt",
+    save_steps = 50,
     state_loss_weight = 1.,
     action_loss_weight = 1.,
     discovery_action_recon_loss_weight = 1.,
@@ -55,6 +62,22 @@ def train(
     max_grad_norm = 1.,
     use_resnet = False
 ):
+
+    def store_checkpoint(step: int):
+        if accelerator.is_main_process:
+
+            # Add step to checkpoint filenames
+            checkpoint_path_with_step = checkpoint_path.replace('.pt', f'_step_{step}.pt')
+            meta_controller_checkpoint_path_with_step = meta_controller_checkpoint_path.replace('.pt', f'_step_{step}.pt')
+
+            unwrapped_model = accelerator.unwrap_model(model)
+            unwrapped_model.save(checkpoint_path_with_step)
+
+            unwrapped_meta_controller = accelerator.unwrap_model(meta_controller)
+            unwrapped_meta_controller.save(meta_controller_checkpoint_path_with_step)
+
+            accelerator.print(f"Model saved to {checkpoint_path_with_step}, MetaController to {meta_controller_checkpoint_path_with_step}")
+
     # accelerator
 
     accelerator = Accelerator(log_with = "wandb" if use_wandb else None)
@@ -87,11 +110,13 @@ def train(
 
     # state shape and action dimension
     # state: (B, T, H, W, C) or (B, T, D)
+
     state_shape = replay_buffer.shapes['state']
     if use_resnet: state_dim = 256
     else: state_dim = int(torch.tensor(state_shape).prod().item())
 
     # deduce num_actions from the environment
+
     from babyai_env import create_env
     temp_env = create_env(env_id)
     num_actions = int(temp_env.action_space.n)
@@ -99,6 +124,10 @@ def train(
 
     accelerator.print(f"Detected state_dim: {state_dim}, num_actions: {num_actions} from env: {env_id}")
 
+    # meta controller
+
+    meta_controller = MetaController(dim)
+
     # transformer
 
     transformer_class = TransformerWithResnet if use_resnet else Transformer
@@ -108,18 +137,15 @@ def train(
         state_embed_readout = dict(num_continuous = state_dim),
         action_embed_readout = dict(num_discrete = num_actions),
         lower_body = dict(depth = depth, heads = heads, attn_dim_head = dim_head),
-        upper_body = dict(depth = depth, heads = heads, attn_dim_head = dim_head)
+        upper_body = dict(depth = depth, heads = heads, attn_dim_head = dim_head),
+        meta_controller = meta_controller
     )
 
-    # meta controller
-
-    meta_controller = MetaController(dim)
-
     # optimizer
 
-    optim_model =
+    optim_model = AdamW(model.parameters(), lr = lr, weight_decay = weight_decay)
 
-    optim_meta_controller =
+    optim_meta_controller = AdamW(meta_controller.discovery_parameters(), lr = discovery_lr, weight_decay = discovery_weight_decay)
 
     # prepare
 
@@ -127,6 +153,7 @@ def train(
 
     # training
 
+    gradient_step = 0
     for epoch in range(cloning_epochs + discovery_epochs):
 
         model.train()
@@ -154,13 +181,15 @@ def train(
             else: # flatten state: (B, T, 7, 7, 3) -> (B, T, 147)
                 states = rearrange(states, 'b t ... -> b t (...)')
 
+
             with accelerator.accumulate(model):
-                losses = model(
+                losses, meta_controller_output = model(
                     states,
                     actions,
                     episode_lens = episode_lens,
                     discovery_phase = is_discovering,
-
+                    force_behavior_cloning = not is_discovering,
+                    return_meta_controller_output = True
                 )
 
                if is_discovering:
@@ -175,7 +204,8 @@ def train(
                 log = dict(
                     action_recon_loss = action_recon_loss.item(),
                     kl_loss = kl_loss.item(),
-                    switch_loss = switch_loss.item()
+                    switch_loss = switch_loss.item(),
+                    switch_density = meta_controller_output.switch_beta.mean().item()
                 )
             else:
                 state_loss, action_loss = losses
@@ -190,14 +220,19 @@ def train(
                     action_loss = action_loss.item(),
                 )
 
+                # gradient accumulation
+
+                if gradient_accumulation_steps is not None: loss /= gradient_accumulation_steps
+
                 # backprop
 
                 accelerator.backward(loss)
 
                 grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm = max_grad_norm)
 
-
-
+                if gradient_accumulation_steps is None or gradient_step % gradient_accumulation_steps == 0:
+                    optim.step()
+                    optim.zero_grad()
 
                 # log
 
@@ -211,23 +246,21 @@ def train(
                 })
 
                 progress_bar.set_postfix(**log)
+                gradient_step += 1
+
+                # checkpoint
+
+                if gradient_step % save_steps == 0:
+                    accelerator.wait_for_everyone()
+                    store_checkpoint(gradient_step)
 
         avg_losses = {k: v / len(dataloader) for k, v in total_losses.items()}
         avg_losses_str = ", ".join([f"{k}={v:.4f}" for k, v in avg_losses.items()])
         accelerator.print(f"Epoch {epoch}: {avg_losses_str}")
 
     # save weights
-
     accelerator.wait_for_everyone()
-
-
-    unwrapped_model = accelerator.unwrap_model(model)
-    unwrapped_model.save(checkpoint_path)
-
-    unwrapped_meta_controller = accelerator.unwrap_model(meta_controller)
-    unwrapped_meta_controller.save(meta_controller_checkpoint_path)
-
-    accelerator.print(f"Model saved to {checkpoint_path}, MetaController to {meta_controller_checkpoint_path}")
+    store_checkpoint(gradient_step)
 
     accelerator.end_training()
 
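
Checkpointing in train_behavior_clone_babyai.py now happens every save_steps gradient steps through store_checkpoint, which suffixes both checkpoint paths with the current step before saving the unwrapped model and MetaController. A minimal sketch of the naming rule; the helper below is illustrative and not part of the package:

# a minimal sketch of the '.pt' -> '_step_{n}.pt' naming used by store_checkpoint
def checkpoint_name(path: str, step: int) -> str:
    return path.replace('.pt', f'_step_{step}.pt')

assert checkpoint_name('transformer_bc.pt', 50) == 'transformer_bc_step_50.pt'
assert checkpoint_name('meta_controller_discovery.pt', 100) == 'meta_controller_discovery_step_100.pt'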
metacontroller_pytorch-0.0.43/test_babyai_e2e.sh
DELETED
@@ -1,35 +0,0 @@
-#!/bin/bash
-set -e
-
-# 1. Gather trajectories
-echo "Gathering trajectories..."
-uv run gather_babyai_trajs.py \
-  --num_seeds 100 \
-  --num_episodes_per_seed 10 \
-  --num_steps 500 \
-  --output_dir end_to_end_trajectories \
-  --env_id BabyAI-MiniBossLevel-v0
-
-# 2. Behavioral cloning
-echo "Training behavioral cloning model..."
-ACCELERATE_USE_CPU=true ACCELERATE_MIXED_PRECISION=no uv run train_behavior_clone_babyai.py \
-  --cloning_epochs 10 \
-  --discovery_epochs 10 \
-  --batch_size 256 \
-  --input_dir end_to_end_trajectories \
-  --env_id BabyAI-MiniBossLevel-v0 \
-  --checkpoint_path end_to_end_model.pt \
-  --use_resnet
-
-# 3. Inference rollouts
-echo "Running inference rollouts..."
-uv run train_babyai.py \
-  --transformer_weights_path end_to_end_model.pt \
-  --meta_controller_weights_path meta_controller_discovery.pt \
-  --env_name BabyAI-MiniBossLevel-v0 \
-  --num_episodes 1000 \
-  --buffer_size 1000 \
-  --max_timesteps 100 \
-  --num_groups 16 \
-  --lr 1e-4 \
-  --use_resnet
{metacontroller_pytorch-0.0.43 → metacontroller_pytorch-0.0.46}/.github/workflows/python-publish.yml
RENAMED
File without changes
{metacontroller_pytorch-0.0.43 → metacontroller_pytorch-0.0.46}/.github/workflows/test.yml
RENAMED
File without changes
{metacontroller_pytorch-0.0.43 → metacontroller_pytorch-0.0.46}/LICENSE
RENAMED
File without changes
{metacontroller_pytorch-0.0.43 → metacontroller_pytorch-0.0.46}/README.md
RENAMED
File without changes
{metacontroller_pytorch-0.0.43 → metacontroller_pytorch-0.0.46}/babyai_env.py
RENAMED
File without changes
{metacontroller_pytorch-0.0.43 → metacontroller_pytorch-0.0.46}/fig1.png
RENAMED
File without changes
{metacontroller_pytorch-0.0.43 → metacontroller_pytorch-0.0.46}/metacontroller/__init__.py
RENAMED
File without changes
{metacontroller_pytorch-0.0.43 → metacontroller_pytorch-0.0.46}/metacontroller/metacontroller_with_binary_mapper.py
RENAMED
File without changes
{metacontroller_pytorch-0.0.43 → metacontroller_pytorch-0.0.46}/metacontroller/transformer_with_resnet.py
RENAMED
File without changes
{metacontroller_pytorch-0.0.43 → metacontroller_pytorch-0.0.46}/tests/test_metacontroller.py
RENAMED
File without changes
/metacontroller_pytorch-0.0.43/train_baby_evo_strat.py → /metacontroller_pytorch-0.0.46/train_babyai_evo_strat.py
RENAMED
File without changes