metacontroller-pytorch 0.0.35__tar.gz → 0.0.36__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of metacontroller-pytorch might be problematic.
- {metacontroller_pytorch-0.0.35 → metacontroller_pytorch-0.0.36}/PKG-INFO +2 -2
- {metacontroller_pytorch-0.0.35 → metacontroller_pytorch-0.0.36}/metacontroller/metacontroller.py +14 -7
- {metacontroller_pytorch-0.0.35 → metacontroller_pytorch-0.0.36}/metacontroller/metacontroller_with_binary_mapper.py +4 -2
- {metacontroller_pytorch-0.0.35 → metacontroller_pytorch-0.0.36}/pyproject.toml +2 -2
- {metacontroller_pytorch-0.0.35 → metacontroller_pytorch-0.0.36}/tests/test_metacontroller.py +56 -10
- {metacontroller_pytorch-0.0.35 → metacontroller_pytorch-0.0.36}/train_behavior_clone_babyai.py +82 -25
- {metacontroller_pytorch-0.0.35 → metacontroller_pytorch-0.0.36}/.github/workflows/python-publish.yml +0 -0
- {metacontroller_pytorch-0.0.35 → metacontroller_pytorch-0.0.36}/.github/workflows/test.yml +0 -0
- {metacontroller_pytorch-0.0.35 → metacontroller_pytorch-0.0.36}/.gitignore +0 -0
- {metacontroller_pytorch-0.0.35 → metacontroller_pytorch-0.0.36}/LICENSE +0 -0
- {metacontroller_pytorch-0.0.35 → metacontroller_pytorch-0.0.36}/README.md +0 -0
- {metacontroller_pytorch-0.0.35 → metacontroller_pytorch-0.0.36}/fig1.png +0 -0
- {metacontroller_pytorch-0.0.35 → metacontroller_pytorch-0.0.36}/gather_babyai_trajs.py +0 -0
- {metacontroller_pytorch-0.0.35 → metacontroller_pytorch-0.0.36}/metacontroller/__init__.py +0 -0
- {metacontroller_pytorch-0.0.35 → metacontroller_pytorch-0.0.36}/metacontroller/transformer_with_resnet.py +0 -0
- {metacontroller_pytorch-0.0.35 → metacontroller_pytorch-0.0.36}/test_babyai_e2e.sh +0 -0
- {metacontroller_pytorch-0.0.35 → metacontroller_pytorch-0.0.36}/train_babyai.py +0 -0
{metacontroller_pytorch-0.0.35 → metacontroller_pytorch-0.0.36}/PKG-INFO
RENAMED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: metacontroller-pytorch
-Version: 0.0.35
+Version: 0.0.36
 Summary: Transformer Metacontroller
 Project-URL: Homepage, https://pypi.org/project/metacontroller/
 Project-URL: Repository, https://github.com/lucidrains/metacontroller
@@ -39,7 +39,7 @@ Requires-Dist: discrete-continuous-embed-readout>=0.1.12
 Requires-Dist: einops>=0.8.1
 Requires-Dist: einx>=0.3.0
 Requires-Dist: loguru
-Requires-Dist: memmap-replay-buffer>=0.0.
+Requires-Dist: memmap-replay-buffer>=0.0.25
 Requires-Dist: torch-einops-utils>=0.0.19
 Requires-Dist: torch>=2.5
 Requires-Dist: vector-quantize-pytorch>=1.27.20
{metacontroller_pytorch-0.0.35 → metacontroller_pytorch-0.0.36}/metacontroller/metacontroller.py
RENAMED
@@ -336,6 +336,8 @@ class MetaController(Module):
 
         return control_signal, MetaControllerOutput(next_hiddens, residual_stream, action_dist, sampled_latent_action, switch_beta, kl_loss, switch_loss)
 
+MetaController.policy_loss = policy_loss
+
 # main transformer, which is subsumed into the environment after behavioral cloning
 
 Hiddens = namedtuple('Hiddens', (
@@ -414,17 +416,21 @@ class Transformer(Module):
     ):
         device = state.device
 
+        # meta controller is either given or already given at init
+
         meta_controller = default(meta_controller, self.meta_controller)
 
-
+        has_meta_controller = exists(meta_controller)
 
-
+        assert not (discovery_phase and not has_meta_controller), 'meta controller must be made available during discovery phase'
+
+        behavioral_cloning = not has_meta_controller and not return_raw_action_dist
 
         # by default, if meta controller is passed in, transformer is no grad
 
-        lower_transformer_context = nullcontext if not
-        meta_controller_context = nullcontext if
-        upper_transformer_context = nullcontext if (not
+        lower_transformer_context = nullcontext if not has_meta_controller else torch.no_grad
+        meta_controller_context = nullcontext if has_meta_controller else torch.no_grad
+        upper_transformer_context = nullcontext if (not has_meta_controller or discovery_phase) else torch.no_grad
 
         # handle cache
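The context-manager changes above make the gradient routing explicit: each stage of the stack is wrapped in either `nullcontext` or `torch.no_grad`, depending on whether a meta controller is present and whether this is the discovery phase. A minimal standalone sketch of the same selection logic (the function name here is illustrative, not part of the package):

```python
import torch
from contextlib import nullcontext

def pick_grad_contexts(has_meta_controller: bool, discovery_phase: bool):
    # lower transformer: frozen whenever a meta controller is driving it
    lower_ctx = nullcontext if not has_meta_controller else torch.no_grad
    # meta controller: only trained when it is actually present
    meta_ctx = nullcontext if has_meta_controller else torch.no_grad
    # upper transformer: trainable during plain behavioral cloning, and also during discovery
    upper_ctx = nullcontext if (not has_meta_controller or discovery_phase) else torch.no_grad
    return lower_ctx, meta_ctx, upper_ctx

# behavioral cloning (no meta controller): every stage receives gradients
lower_ctx, meta_ctx, upper_ctx = pick_grad_contexts(has_meta_controller = False, discovery_phase = False)

with lower_ctx(), meta_ctx(), upper_ctx():
    pass  # the staged forward pass would run here
```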
@@ -432,7 +438,8 @@ class Transformer(Module):
 
         # handle maybe behavioral cloning
 
-        if behavioral_cloning or
+        if behavioral_cloning or discovery_phase: # during behavior cloning and discovery phase, the network is predicting / reconstructing the next token
+
             assert exists(actions), f'`actions` cannot be empty when doing discovery or behavioral cloning'
 
             state, target_state = state[:, :-1], state[:, 1:]
@@ -495,7 +502,7 @@ class Transformer(Module):
 
             return state_clone_loss, action_clone_loss
 
-        elif
+        elif discovery_phase:
 
             action_recon_loss = self.action_readout.calculate_loss(dist_params, target_actions)
 
{metacontroller_pytorch-0.0.35 → metacontroller_pytorch-0.0.36}/metacontroller/metacontroller_with_binary_mapper.py
RENAMED

@@ -28,7 +28,7 @@ from torch_einops_utils.save_load import save_load
 
 from vector_quantize_pytorch import BinaryMapper
 
-from metacontroller.metacontroller import MetaControllerOutput
+from metacontroller.metacontroller import MetaControllerOutput, policy_loss
 
 # constants
 
@@ -170,7 +170,7 @@ class MetaControllerWithBinaryMapper(Module):
         action_log_probs = log_probs.gather(-1, codes)
         action_log_probs = rearrange(action_log_probs, '... 1 -> ...')
 
-        return action_log_probs
+        return action_log_probs
 
     def forward(
         self,
@@ -302,3 +302,5 @@ class MetaControllerWithBinaryMapper(Module):
         switch_beta = rearrange(switch_beta, '... 1 -> ...')
 
         return control_signal, MetaControllerOutput(next_hiddens, residual_stream, binary_logits, sampled_codes, switch_beta, kl_loss, switch_loss)
+
+MetaControllerWithBinaryMapper.policy_loss = policy_loss
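Both controller variants end up exposing the same `policy_loss` method: the free function defined in `metacontroller.py` is attached to `MetaController` there, imported here, and attached to `MetaControllerWithBinaryMapper` as well. A small sketch of that attach-after-definition pattern, with hypothetical stand-in names:

```python
import torch
from torch.nn import Module

def shared_policy_loss(self, old_log_probs, advantages):
    # written like a method: `self` is bound automatically once the function
    # is assigned as a class attribute (placeholder computation only)
    return -(old_log_probs * advantages.unsqueeze(-1)).mean()

class ControllerA(Module):
    pass

class ControllerB(Module):
    pass

# attach once, after both class definitions, so each exposes .policy_loss(...)
ControllerA.policy_loss = shared_policy_loss
ControllerB.policy_loss = shared_policy_loss

loss = ControllerA().policy_loss(torch.randn(2, 3), torch.ones(2))
```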
{metacontroller_pytorch-0.0.35 → metacontroller_pytorch-0.0.36}/pyproject.toml
RENAMED

@@ -1,6 +1,6 @@
 [project]
 name = "metacontroller-pytorch"
-version = "0.0.35"
+version = "0.0.36"
 description = "Transformer Metacontroller"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -29,7 +29,7 @@ dependencies = [
     "einx>=0.3.0",
     "einops>=0.8.1",
     "loguru",
-    "memmap-replay-buffer>=0.0.
+    "memmap-replay-buffer>=0.0.25",
     "torch>=2.5",
     "torch-einops-utils>=0.0.19",
     "vector-quantize-pytorch>=1.27.20",
{metacontroller_pytorch-0.0.35 → metacontroller_pytorch-0.0.36}/tests/test_metacontroller.py
RENAMED
@@ -1,6 +1,7 @@
 import pytest
 param = pytest.mark.parametrize
 
+from shutil import rmtree
 from pathlib import Path
 from functools import partial
 
@@ -9,6 +10,8 @@ from torch import cat
 from metacontroller.metacontroller import Transformer, MetaController, policy_loss, z_score
 from metacontroller.metacontroller_with_binary_mapper import MetaControllerWithBinaryMapper
 
+from memmap_replay_buffer import ReplayBuffer
+
 from einops import rearrange
 
 # functions
@@ -66,6 +69,12 @@ def test_metacontroller(
             dim_latent = 128,
             switch_per_latent_dim = switch_per_latent_dim
         )
+
+        field_shapes = dict(
+            log_probs = ('float', 128),
+            switch_betas = ('float', 128 if switch_per_latent_dim else 1),
+            latent_actions = ('float', 128)
+        )
     else:
         meta_controller = MetaControllerWithBinaryMapper(
             dim_model = 512,
@@ -74,6 +83,12 @@ def test_metacontroller(
             dim_code_bits = 8, # 2 ** 8 = 256 codes
         )
 
+        field_shapes = dict(
+            log_probs = ('float', 8),
+            switch_betas = ('float', 8 if switch_per_latent_dim else 1),
+            latent_actions = ('float', 256)
+        )
+
     # discovery phase
 
     (action_recon_loss, kl_loss, switch_loss) = model(state, actions, meta_controller = meta_controller, discovery_phase = True, episode_lens = episode_lens)
@@ -81,6 +96,23 @@ def test_metacontroller(
 
     # internal rl - done iteratively
 
+    # replay buffer
+
+    test_folder = './test-buffer-for-grpo'
+
+    replay_buffer = ReplayBuffer(
+        test_folder,
+        max_episodes = 3,
+        max_timesteps = 256,
+        fields = dict(
+            states = ('float', 512),
+            **field_shapes
+        ),
+        meta_fields = dict(
+            advantages = 'float'
+        )
+    )
+
     # simulate grpo
 
     all_episodes = []
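Condensed from the test changes above: rollout data now round-trips through a memory-mapped replay buffer before the policy update. The sketch below uses only the `ReplayBuffer` calls that appear in this diff; the per-timestep tensor shapes and field sizes are assumptions matching the non-binary-mapper branch of the test:

```python
import torch
from memmap_replay_buffer import ReplayBuffer

buffer = ReplayBuffer(
    './example-buffer',                       # on-disk folder backing the memmaps (hypothetical path)
    max_episodes = 3,
    max_timesteps = 256,
    fields = dict(                            # stored per timestep
        states = ('float', 512),
        log_probs = ('float', 128),
        switch_betas = ('float', 128),
        latent_actions = ('float', 128)
    ),
    meta_fields = dict(advantages = 'float')  # stored once per episode
)

timesteps = 16  # assumed per-episode layout: (timesteps, feature_dim)
buffer.store_episode(
    states = torch.randn(timesteps, 512),
    log_probs = torch.randn(timesteps, 128),
    switch_betas = torch.rand(timesteps, 128),
    latent_actions = torch.randn(timesteps, 128),
    advantages = torch.tensor(1.)
)

# batches come back as dict-like objects; '_lens' holds the true episode lengths
batch = next(iter(buffer.dataloader(batch_size = 1)))
print(batch['states'].shape, batch['_lens'])
```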
@@ -129,22 +161,34 @@ def test_metacontroller(
     # calculate advantages using z-score
 
     rewards = cat(all_rewards)
-
+    group_advantages = z_score(rewards)
 
-    assert
+    assert group_advantages.shape == (3,)
 
     # simulate a policy loss update over the entire group
 
     group_states, group_log_probs, group_switch_betas, group_latent_actions = map(partial(cat, dim = 0), zip(*all_episodes))
 
-
-
-
-
-
-
-
-
+    for states, log_probs, switch_betas, latent_actions, advantages in zip(group_states, group_log_probs, group_switch_betas, group_latent_actions, group_advantages):
+        replay_buffer.store_episode(
+            states = states,
+            log_probs = log_probs,
+            switch_betas = switch_betas,
+            latent_actions = latent_actions,
+            advantages = advantages
+        )
+
+    dl = replay_buffer.dataloader(batch_size = 3)
+
+    batch = next(iter(dl))
+
+    loss = meta_controller.policy_loss(
+        batch['states'],
+        batch['log_probs'],
+        batch['latent_actions'],
+        batch['advantages'],
+        batch['switch_betas'] == 1.,
+        episode_lens = batch['_lens']
     )
 
     loss.backward()
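The group advantage is a z-score of the episode rewards within the group, GRPO-style; `z_score` is imported from `metacontroller.metacontroller` in the test header. A plausible reference implementation (an assumption, not the package's exact code) and the shape check the test now asserts:

```python
import torch

def z_score(t, eps = 1e-5):
    # assumed behavior: center and scale rewards within the sampled group
    return (t - t.mean()) / t.std().clamp(min = eps)

rewards = torch.tensor([1., 0., 2.])      # one scalar reward per episode in the group of 3
group_advantages = z_score(rewards)

assert group_advantages.shape == (3,)     # mirrors the assertion added above
```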
@@ -167,3 +211,5 @@ def test_metacontroller(
 
     Path('./meta_controller.pt').unlink()
     Path('./trained.pt').unlink()
+
+    rmtree(test_folder, ignore_errors = True)
{metacontroller_pytorch-0.0.35 → metacontroller_pytorch-0.0.36}/train_behavior_clone_babyai.py
RENAMED
@@ -25,29 +25,35 @@ from accelerate import Accelerator
 from memmap_replay_buffer import ReplayBuffer
 from einops import rearrange
 
-from metacontroller.metacontroller import Transformer
+from metacontroller.metacontroller import Transformer, MetaController
 from metacontroller.transformer_with_resnet import TransformerWithResnet
 
 import minigrid
 import gymnasium as gym
 
 def train(
-    input_dir
-    env_id
-    cloning_epochs
-    discovery_epochs
-    batch_size
-    lr
-
-
-
-
-
-
-
-
-
-
+    input_dir = "babyai-minibosslevel-trajectories",
+    env_id = "BabyAI-MiniBossLevel-v0",
+    cloning_epochs = 10,
+    discovery_epochs = 10,
+    batch_size = 32,
+    lr = 1e-4,
+    discovery_lr = 1e-4,
+    dim = 512,
+    depth = 2,
+    heads = 8,
+    dim_head = 64,
+    use_wandb = False,
+    wandb_project = "metacontroller-babyai-bc",
+    checkpoint_path = "transformer_bc.pt",
+    meta_controller_checkpoint_path = "meta_controller_discovery.pt",
+    state_loss_weight = 1.,
+    action_loss_weight = 1.,
+    discovery_action_recon_loss_weight = 1.,
+    discovery_kl_loss_weight = 1.,
+    discovery_switch_loss_weight = 1.,
+    max_grad_norm = 1.,
+    use_resnet = False
 ):
     # accelerator
 
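With the trainer's configuration moved into keyword defaults, a run only needs to override what differs from those defaults. A hypothetical driver (the module name follows the file name; `input_dir` should point at the trajectory dump, presumably produced by `gather_babyai_trajs.py`):

```python
# hypothetical invocation of the training entrypoint shown above
from train_behavior_clone_babyai import train

train(
    input_dir = "babyai-minibosslevel-trajectories",
    env_id = "BabyAI-MiniBossLevel-v0",
    cloning_epochs = 10,       # behavioral cloning phase
    discovery_epochs = 10,     # discovery phase: meta controller tuned on top of the clone
    batch_size = 32,
    use_resnet = False         # flatten the (7, 7, 3) grid observation instead of using a resnet encoder
)
```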
@@ -96,6 +102,7 @@ def train(
     # transformer
 
     transformer_class = TransformerWithResnet if use_resnet else Transformer
+
     model = transformer_class(
         dim = dim,
         state_embed_readout = dict(num_continuous = state_dim),
@@ -104,23 +111,34 @@ def train(
         upper_body = dict(depth = depth, heads = heads, attn_dim_head = dim_head)
     )
 
+    # meta controller
+
+    meta_controller = MetaController(dim)
+
     # optimizer
 
-
+    optim_model = Adam(model.parameters(), lr = lr)
+
+    optim_meta_controller = Adam(meta_controller.discovery_parameters(), lr = discovery_lr)
 
     # prepare
 
-    model,
+    model, optim_model, optim_meta_controller, dataloader = accelerator.prepare(model, optim_model, optim_meta_controller, dataloader)
 
     # training
+
     for epoch in range(cloning_epochs + discovery_epochs):
+
         model.train()
         total_state_loss = 0.
         total_action_loss = 0.
 
         progress_bar = tqdm(dataloader, desc = f"Epoch {epoch}", disable = not accelerator.is_local_main_process)
+
         is_discovering = (epoch >= cloning_epochs) # discovery phase is BC with metacontroller tuning
 
+        optim = optim_model if not is_discovering else optim_meta_controller
+
         for batch in progress_bar:
             # batch is a NamedTuple (e.g. MemoryMappedBatch)
             # state: (B, T, 7, 7, 3), action: (B, T)
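The single-optimizer setup becomes a two-phase schedule: `optim_model` updates the full transformer during the cloning epochs, and `optim_meta_controller` updates only the meta controller's discovery parameters afterwards, with `is_discovering` selecting which optimizer steps each epoch. A stripped-down, runnable sketch of that schedule (hypothetical modules stand in for the transformer and meta controller, and the loss is a placeholder):

```python
import torch
from torch.optim import Adam

model = torch.nn.Linear(8, 8)             # stand-in for the Transformer
meta_controller = torch.nn.Linear(8, 8)   # stand-in for MetaController; the real code optimizes .discovery_parameters()

optim_model = Adam(model.parameters(), lr = 1e-4)
optim_meta_controller = Adam(meta_controller.parameters(), lr = 1e-4)

cloning_epochs, discovery_epochs = 2, 2

for epoch in range(cloning_epochs + discovery_epochs):
    is_discovering = epoch >= cloning_epochs
    optim = optim_model if not is_discovering else optim_meta_controller

    x = torch.randn(4, 8)
    loss = (model(x) if not is_discovering else meta_controller(x)).pow(2).mean()

    loss.backward()
    optim.step()
    optim.zero_grad()
```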
@@ -130,18 +148,53 @@ def train(
             episode_lens = batch.get('_lens')
 
             # use resnet18 to embed visual observations
+
             if use_resnet:
                 states = model.visual_encode(states)
             else: # flatten state: (B, T, 7, 7, 3) -> (B, T, 147)
                 states = rearrange(states, 'b t ... -> b t (...)')
 
             with accelerator.accumulate(model):
-
-
+                losses = model(
+                    states,
+                    actions,
+                    episode_lens = episode_lens,
+                    discovery_phase = is_discovering,
+                    meta_controller = meta_controller if is_discovering else None
+                )
+
+                if is_discovering:
+                    action_recon_loss, kl_loss, switch_loss = losses
+
+                    loss = (
+                        action_recon_loss * discovery_action_recon_loss_weight +
+                        kl_loss * discovery_kl_loss_weight +
+                        switch_loss * discovery_switch_loss_weight
+                    )
+
+                    log = dict(
+                        action_recon_loss = action_recon_loss.item(),
+                        kl_loss = kl_loss.item(),
+                        switch_loss = switch_loss.item()
+                    )
+                else:
+                    state_loss, action_loss = losses
+
+                    loss = (
+                        state_loss * state_loss_weight +
+                        action_loss * action_loss_weight
+                    )
+
+                    log = dict(
+                        state_loss = state_loss.item(),
+                        action_loss = action_loss.item(),
+                    )
+
+                # backprop
 
                 accelerator.backward(loss)
 
-                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=
+                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm = max_grad_norm)
 
                 optim.step()
                 optim.zero_grad()
@@ -152,8 +205,7 @@ def train(
                 total_action_loss += action_loss.item()
 
                 accelerator.log({
-
-                    "action_loss": action_loss.item(),
+                    **log,
                     "total_loss": loss.item(),
                     "grad_norm": grad_norm.item()
                 })
@@ -172,9 +224,14 @@ def train(
 
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
+
        unwrapped_model = accelerator.unwrap_model(model)
        unwrapped_model.save(checkpoint_path)
-
+
+        unwrapped_meta_controller = accelerator.unwrap_model(meta_controller)
+        unwrapped_meta_controller.save(meta_controller_checkpoint_path)
+
+        accelerator.print(f"Model saved to {checkpoint_path}, MetaControler to {meta_controller_checkpoint_path}")
 
    accelerator.end_training()
 
{metacontroller_pytorch-0.0.35 → metacontroller_pytorch-0.0.36}/.github/workflows/python-publish.yml
RENAMED
File without changes

The remaining renamed files (.github/workflows/test.yml, .gitignore, LICENSE, README.md, fig1.png, gather_babyai_trajs.py, metacontroller/__init__.py, metacontroller/transformer_with_resnet.py, test_babyai_e2e.sh, train_babyai.py) likewise have no content changes.