metacontroller-pytorch 0.0.42__tar.gz → 0.0.43__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {metacontroller_pytorch-0.0.42 → metacontroller_pytorch-0.0.43}/PKG-INFO +1 -1
- {metacontroller_pytorch-0.0.42 → metacontroller_pytorch-0.0.43}/metacontroller/metacontroller.py +1 -1
- {metacontroller_pytorch-0.0.42 → metacontroller_pytorch-0.0.43}/metacontroller/metacontroller_with_binary_mapper.py +1 -1
- {metacontroller_pytorch-0.0.42 → metacontroller_pytorch-0.0.43}/pyproject.toml +1 -1
- {metacontroller_pytorch-0.0.42 → metacontroller_pytorch-0.0.43}/tests/test_metacontroller.py +10 -2
- metacontroller_pytorch-0.0.43/train_baby_evo_strat.py +213 -0
- {metacontroller_pytorch-0.0.42 → metacontroller_pytorch-0.0.43}/train_babyai.py +15 -14
- {metacontroller_pytorch-0.0.42 → metacontroller_pytorch-0.0.43}/.github/workflows/python-publish.yml +0 -0
- {metacontroller_pytorch-0.0.42 → metacontroller_pytorch-0.0.43}/.github/workflows/test.yml +0 -0
- {metacontroller_pytorch-0.0.42 → metacontroller_pytorch-0.0.43}/.gitignore +0 -0
- {metacontroller_pytorch-0.0.42 → metacontroller_pytorch-0.0.43}/LICENSE +0 -0
- {metacontroller_pytorch-0.0.42 → metacontroller_pytorch-0.0.43}/README.md +0 -0
- {metacontroller_pytorch-0.0.42 → metacontroller_pytorch-0.0.43}/babyai_env.py +0 -0
- {metacontroller_pytorch-0.0.42 → metacontroller_pytorch-0.0.43}/fig1.png +0 -0
- {metacontroller_pytorch-0.0.42 → metacontroller_pytorch-0.0.43}/gather_babyai_trajs.py +0 -0
- {metacontroller_pytorch-0.0.42 → metacontroller_pytorch-0.0.43}/metacontroller/__init__.py +0 -0
- {metacontroller_pytorch-0.0.42 → metacontroller_pytorch-0.0.43}/metacontroller/transformer_with_resnet.py +0 -0
- {metacontroller_pytorch-0.0.42 → metacontroller_pytorch-0.0.43}/test_babyai_e2e.sh +0 -0
- {metacontroller_pytorch-0.0.42 → metacontroller_pytorch-0.0.43}/train_behavior_clone_babyai.py +0 -0
{metacontroller_pytorch-0.0.42 → metacontroller_pytorch-0.0.43}/metacontroller/metacontroller.py
RENAMED
|
@@ -291,7 +291,7 @@ class MetaController(Module):
|
|
|
291
291
|
else:
|
|
292
292
|
# else during inference, use the previous sampled latent action
|
|
293
293
|
|
|
294
|
-
assert seq_len == 1,
|
|
294
|
+
assert seq_len == 1, 'inference RL phase must be done one token at a time - if replaying for policy optimization, please use `get_action_dist_for_internal_rl`'
|
|
295
295
|
z_prev = prev_sampled_latent_action
|
|
296
296
|
|
|
297
297
|
# switch input is previous latent action and the embedding
|
|
@@ -241,7 +241,7 @@ class MetaControllerWithBinaryMapper(Module):
|
|
|
241
241
|
if discovery_phase:
|
|
242
242
|
z_prev = cat((prev_sampled_code, sampled_codes[:, :-1]), dim = 1)
|
|
243
243
|
else:
|
|
244
|
-
assert seq_len == 1,
|
|
244
|
+
assert seq_len == 1, 'inference RL phase must be done one token at a time - if replaying for policy optimization, please use `get_action_dist_for_internal_rl`'
|
|
245
245
|
z_prev = prev_sampled_code
|
|
246
246
|
|
|
247
247
|
switch_input = torch.cat((meta_embed, z_prev), dim=-1)
|
{metacontroller_pytorch-0.0.42 → metacontroller_pytorch-0.0.43}/tests/test_metacontroller.py
RENAMED
|
@@ -114,13 +114,14 @@ def test_metacontroller(
|
|
|
114
114
|
for one_state in subset_state.unbind(dim = 1):
|
|
115
115
|
one_state = rearrange(one_state, 'b d -> b 1 d')
|
|
116
116
|
|
|
117
|
-
logits, cache = model(one_state, past_action_id, meta_controller = meta_controller, return_cache = True)
|
|
117
|
+
logits, cache = model(one_state, past_action_id, meta_controller = meta_controller, cache = cache, return_cache = True)
|
|
118
118
|
|
|
119
119
|
past_action_id = model.action_readout.sample(logits)
|
|
120
120
|
|
|
121
121
|
# extract grpo data and store
|
|
122
122
|
|
|
123
|
-
|
|
123
|
+
grpo_data = extract_grpo_data(meta_controller, cache)
|
|
124
|
+
grpo_data_list.append(grpo_data)
|
|
124
125
|
|
|
125
126
|
# accumulate across time for the episode data
|
|
126
127
|
|
|
@@ -145,6 +146,13 @@ def test_metacontroller(
|
|
|
145
146
|
# simulate a policy loss update over the entire group
|
|
146
147
|
|
|
147
148
|
group_states, group_log_probs, group_switch_betas, group_latent_actions = map(partial(cat, dim = 0), zip(*all_episodes))
|
|
149
|
+
|
|
150
|
+
# parallel verification
|
|
151
|
+
|
|
152
|
+
parallel_action_dist = meta_controller.get_action_dist_for_internal_rl(group_states)
|
|
153
|
+
parallel_log_probs = meta_controller.log_prob(parallel_action_dist, group_latent_actions)
|
|
154
|
+
|
|
155
|
+
assert torch.allclose(parallel_log_probs, group_log_probs, atol = 1e-5), 'parallel log probs do not match stored log probs'
|
|
148
156
|
|
|
149
157
|
for states, log_probs, switch_betas, latent_actions, advantages in zip(group_states, group_log_probs, group_switch_betas, group_latent_actions, group_advantages):
|
|
150
158
|
replay_buffer.store_episode(
|
|
@@ -0,0 +1,213 @@
|
|
|
1
|
+
# /// script
|
|
2
|
+
# dependencies = [
|
|
3
|
+
# "fire",
|
|
4
|
+
# "gymnasium",
|
|
5
|
+
# "gymnasium[other]",
|
|
6
|
+
# "metacontroller-pytorch",
|
|
7
|
+
# "minigrid",
|
|
8
|
+
# "tqdm",
|
|
9
|
+
# "x-evolution",
|
|
10
|
+
# "einops"
|
|
11
|
+
# ]
|
|
12
|
+
# ///
|
|
13
|
+
|
|
14
|
+
from __future__ import annotations
|
|
15
|
+
import fire
|
|
16
|
+
from pathlib import Path
|
|
17
|
+
from shutil import rmtree
|
|
18
|
+
import numpy as np
|
|
19
|
+
|
|
20
|
+
import torch
|
|
21
|
+
from torch import nn, Tensor, tensor
|
|
22
|
+
from torch.nn import Module
|
|
23
|
+
from einops import rearrange
|
|
24
|
+
|
|
25
|
+
from babyai_env import create_env
|
|
26
|
+
from metacontroller.metacontroller import Transformer, MetaController
|
|
27
|
+
|
|
28
|
+
# functions
|
|
29
|
+
|
|
30
|
+
def exists(v):
    """Report whether ``v`` holds a value, i.e. is anything other than ``None``."""
    is_missing = v is None
    return not is_missing
|
|
32
|
+
|
|
33
|
+
def default(v, d):
    """Return ``v`` unless it is ``None``, in which case fall back to ``d``."""
    if v is None:
        return d
    return v
|
|
35
|
+
|
|
36
|
+
# default fitness function
|
|
37
|
+
|
|
38
|
+
def default_fitness_fn(
    rewards: list[float],
    states: list[object],
    actions: list[object],
    next_states: list[object],
    infos: list[object]
) -> float:
    """
    Score one whole episode; the default fitness is the undiscounted sum of rewards.

    Researchers can replace this function to engineer their own rewards and
    fitness scores. It receives the entire episode at once, once for every
    noise vector of the population separately.

    Args:
        rewards: per-step rewards returned by the environment.
        states: per-step observations the action was chosen from.
        actions: per-step actions taken.
        next_states: per-step observations returned after each action.
        infos: per-step info objects returned by the environment.

    Returns:
        The episode fitness as a float (``0.0`` for an empty episode).
    """
    # start from 0.0 so an empty episode still yields a float, matching the annotation
    return float(sum(rewards, 0.0))
|
|
50
|
+
|
|
51
|
+
# babyai environment for ES
|
|
52
|
+
|
|
53
|
+
class BabyAIEnvironment(Module):
    """
    Episode-level fitness evaluator for evolution strategies on a BabyAI environment.

    Calling the module with a model rolls out a single episode of at most
    `max_steps` environment steps. Actions are sampled from the model's action
    readout. The per-step records are handed to `fitness_fn`, whose result is
    returned.
    """

    def __init__(
        self,
        env_id = 'BabyAI-BossLevel-v0',
        video_folder = './recordings_babyai_es',
        render_every_eps = 100,
        max_steps = 500,
        use_resnet = False,
        fitness_fn = default_fitness_fn
    ):
        super().__init__()

        self.env_id = env_id
        self.video_folder = video_folder
        self.render_every_eps = render_every_eps
        self.max_steps = max_steps
        self.use_resnet = use_resnet
        self.fitness_fn = fitness_fn

        # the environment is created lazily: by `pre_main_callback`
        # (presumably invoked by the evolution driver before any rollout - TODO confirm)
        # or, if that never happened, on the first call to `forward`
        self.env = None

    def pre_main_callback(self):
        """Delete any previous recordings and (re)create the environment."""
        rmtree(self.video_folder, ignore_errors = True)

        self.env = create_env(
            self.env_id,
            render_mode = 'rgb_array',
            video_folder = self.video_folder,
            render_every_eps = self.render_every_eps
        )

    def forward(self, model):
        """
        Roll out one episode with `model` and score it.

        Args:
            model: the policy to evaluate. If it exposes the underlying transformer
                as `.model`, that object is used for `visual_encode` and
                `action_readout`; otherwise `model` itself is used.

        Returns:
            Whatever `self.fitness_fn` returns for the episode's per-step lists
            (rewards, states, actions, next_states, infos).
        """
        # guard against use before `pre_main_callback`; otherwise `self.env` is None
        if self.env is None:
            self.pre_main_callback()

        device = next(model.parameters()).device

        seed = torch.randint(0, int(1e6), ()).item()
        state, _ = self.env.reset(seed = seed)

        cache = None
        past_action_id = None

        unwrapped_model = getattr(model, 'model', model)

        episode_rewards = []
        episode_states = []
        episode_actions = []
        episode_next_states = []
        episode_infos = []

        # evolution strategies never backpropagate, so the whole rollout
        # (including the visual encoder) runs without building autograd graphs
        with torch.no_grad():
            for _ in range(self.max_steps):
                image_tensor = torch.from_numpy(state['image']).float().to(device)

                if self.use_resnet:
                    image_tensor = rearrange(image_tensor, 'h w c -> 1 1 h w c')
                    image_tensor = unwrapped_model.visual_encode(image_tensor)
                else:
                    image_tensor = rearrange(image_tensor, 'h w c -> 1 1 (h w c)')

                if torch.is_tensor(past_action_id):
                    past_action_id = past_action_id.long()

                logits, cache = model(
                    image_tensor,
                    past_action_id,
                    return_cache = True,
                    return_raw_action_dist = True,
                    cache = cache
                )

                action = unwrapped_model.action_readout.sample(logits)
                past_action_id = action
                action_id = action.squeeze()

                # `.item()` converts the single-element tensor to a Python scalar on any device
                next_state, reward, terminated, truncated, info = self.env.step(action_id.item())

                episode_rewards.append(reward)
                episode_states.append(state)
                episode_actions.append(action_id)
                episode_next_states.append(next_state)
                episode_infos.append(info)

                if terminated or truncated:
                    break

                state = next_state

        return self.fitness_fn(
            episode_rewards,
            episode_states,
            episode_actions,
            episode_next_states,
            episode_infos
        )
|
|
153
|
+
|
|
154
|
+
def main(
    env_id = 'BabyAI-BossLevel-v0',
    num_generations = 100,
    max_steps = 500,
    render_every_eps = 100,
    video_folder = './recordings_babyai_es',
    transformer_weights_path: str | None = None,
    meta_controller_weights_path: str | None = None,
    output_meta_controller_path = 'metacontroller_es_trained.pt',
    use_resnet = False,
    noise_population_size = 50,
    noise_scale = 1e-2,
    learning_rate = 1e-3,
    fitness_fn = default_fitness_fn
):
    """
    Evolve the meta controller of a pretrained transformer on a BabyAI task.

    Steps:
    - Load the transformer from `transformer_weights_path`.
    - Optionally replace its meta controller with weights from
      `meta_controller_weights_path`.
    - Run `model.evolve` for `num_generations` against a `BabyAIEnvironment`.
    - Save the resulting meta controller to `output_meta_controller_path`.

    Raises:
        AssertionError: if no transformer weights are given, or the loaded model
            ends up without a meta controller.
    """
    # load model

    assert exists(transformer_weights_path), "Transformer weights must be provided"

    if use_resnet:
        # lazy import so the resnet dependencies are only required when actually requested
        from metacontroller.transformer_with_resnet import TransformerWithResnet
        transformer_klass = TransformerWithResnet
    else:
        transformer_klass = Transformer

    model = transformer_klass.init_and_load(transformer_weights_path, strict = False)
    model.eval()

    if exists(meta_controller_weights_path):
        meta_controller = MetaController.init_and_load(meta_controller_weights_path, strict = False)
        model.meta_controller = meta_controller

    # getattr so a model without the attribute fails with the intended message
    assert exists(getattr(model, 'meta_controller', None)), "MetaController must be present for evolution"

    # setup environment

    babyai_env = BabyAIEnvironment(
        env_id = env_id,
        video_folder = video_folder,
        render_every_eps = render_every_eps,
        max_steps = max_steps,
        use_resnet = use_resnet,
        fitness_fn = fitness_fn
    )

    # evolve - `evolve` is presumably provided by the x-evolution integration
    # on the transformer (NOTE(review): confirm)

    model.evolve(
        num_generations = num_generations,
        environment = babyai_env,
        noise_population_size = noise_population_size,
        noise_scale = noise_scale,
        learning_rate = learning_rate
    )

    # save

    model.meta_controller.save(output_meta_controller_path)
    print(f'MetaController weights saved to {output_meta_controller_path}')
|
|
211
|
+
|
|
212
|
+
# script entry point: hand `main` to python-fire, which builds the CLI from its signature
if __name__ == '__main__':
    fire.Fire(component = main)
|
|
@@ -57,22 +57,23 @@ def default(v, d):
|
|
|
57
57
|
# main
|
|
58
58
|
|
|
59
59
|
def main(
|
|
60
|
-
env_name
|
|
61
|
-
num_episodes
|
|
62
|
-
max_timesteps
|
|
63
|
-
buffer_size
|
|
64
|
-
render_every_eps
|
|
65
|
-
video_folder
|
|
60
|
+
env_name = 'BabyAI-BossLevel-v0',
|
|
61
|
+
num_episodes = int(10e6),
|
|
62
|
+
max_timesteps = 500,
|
|
63
|
+
buffer_size = 5_000,
|
|
64
|
+
render_every_eps = 1_000,
|
|
65
|
+
video_folder = './recordings',
|
|
66
66
|
seed: int | None = None,
|
|
67
67
|
transformer_weights_path: str | None = None,
|
|
68
68
|
meta_controller_weights_path: str | None = None,
|
|
69
|
-
output_meta_controller_path
|
|
70
|
-
use_resnet
|
|
71
|
-
lr
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
69
|
+
output_meta_controller_path = 'metacontroller_rl_trained.pt',
|
|
70
|
+
use_resnet = False,
|
|
71
|
+
lr = 1e-4,
|
|
72
|
+
batch_size = 16,
|
|
73
|
+
num_groups = 16,
|
|
74
|
+
max_grad_norm = 1.0,
|
|
75
|
+
use_wandb = False,
|
|
76
|
+
wandb_project = 'metacontroller-babyai-rl'
|
|
76
77
|
):
|
|
77
78
|
# accelerator
|
|
78
79
|
|
|
@@ -263,7 +264,7 @@ def main(
|
|
|
263
264
|
# learn
|
|
264
265
|
|
|
265
266
|
if len(replay_buffer) >= buffer_size:
|
|
266
|
-
dl = replay_buffer.dataloader(batch_size =
|
|
267
|
+
dl = replay_buffer.dataloader(batch_size = batch_size)
|
|
267
268
|
dl = accelerator.prepare(dl)
|
|
268
269
|
|
|
269
270
|
meta_controller.train()
|
{metacontroller_pytorch-0.0.42 → metacontroller_pytorch-0.0.43}/.github/workflows/python-publish.yml
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{metacontroller_pytorch-0.0.42 → metacontroller_pytorch-0.0.43}/train_behavior_clone_babyai.py
RENAMED
|
File without changes
|