metacontroller-pytorch 0.0.20__tar.gz → 0.0.22__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {metacontroller_pytorch-0.0.20 → metacontroller_pytorch-0.0.22}/.gitignore +3 -0
- {metacontroller_pytorch-0.0.20 → metacontroller_pytorch-0.0.22}/PKG-INFO +12 -1
- {metacontroller_pytorch-0.0.20 → metacontroller_pytorch-0.0.22}/README.md +10 -0
- {metacontroller_pytorch-0.0.20 → metacontroller_pytorch-0.0.22}/metacontroller/metacontroller.py +58 -13
- {metacontroller_pytorch-0.0.20 → metacontroller_pytorch-0.0.22}/pyproject.toml +2 -1
- {metacontroller_pytorch-0.0.20 → metacontroller_pytorch-0.0.22}/tests/test_metacontroller.py +10 -6
- metacontroller_pytorch-0.0.22/train.py +97 -0
- {metacontroller_pytorch-0.0.20 → metacontroller_pytorch-0.0.22}/.github/workflows/python-publish.yml +0 -0
- {metacontroller_pytorch-0.0.20 → metacontroller_pytorch-0.0.22}/.github/workflows/test.yml +0 -0
- {metacontroller_pytorch-0.0.20 → metacontroller_pytorch-0.0.22}/LICENSE +0 -0
- {metacontroller_pytorch-0.0.20 → metacontroller_pytorch-0.0.22}/fig1.png +0 -0
- {metacontroller_pytorch-0.0.20 → metacontroller_pytorch-0.0.22}/metacontroller/__init__.py +0 -0
{metacontroller_pytorch-0.0.20 → metacontroller_pytorch-0.0.22}/PKG-INFO RENAMED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: metacontroller-pytorch
-Version: 0.0.20
+Version: 0.0.22
 Summary: Transformer Metacontroller
 Project-URL: Homepage, https://pypi.org/project/metacontroller/
 Project-URL: Repository, https://github.com/lucidrains/metacontroller

@@ -39,6 +39,7 @@ Requires-Dist: discrete-continuous-embed-readout>=0.1.12
 Requires-Dist: einops>=0.8.1
 Requires-Dist: einx>=0.3.0
 Requires-Dist: loguru
+Requires-Dist: memmap-replay-buffer>=0.0.1
 Requires-Dist: torch>=2.5
 Requires-Dist: x-evolution>=0.1.23
 Requires-Dist: x-mlps-pytorch

@@ -54,6 +55,16 @@ Description-Content-Type: text/markdown

 Implementation of the MetaController proposed in [Emergent temporal abstractions in autoregressive models enable hierarchical reinforcement learning](https://arxiv.org/abs/2512.20605)

+## Install
+
+```shell
+$ pip install metacontroller-pytorch
+```
+
+## Appreciation
+
+- [Pranoy](https://github.com/pranoyr) for submitting a pull request for fixing the previous latent action not being included in the inputs to the switching unit
+
 ## Citations

 ```bibtex
{metacontroller_pytorch-0.0.20 → metacontroller_pytorch-0.0.22}/README.md RENAMED

@@ -4,6 +4,16 @@

 Implementation of the MetaController proposed in [Emergent temporal abstractions in autoregressive models enable hierarchical reinforcement learning](https://arxiv.org/abs/2512.20605)

+## Install
+
+```shell
+$ pip install metacontroller-pytorch
+```
+
+## Appreciation
+
+- [Pranoy](https://github.com/pranoyr) for submitting a pull request for fixing the previous latent action not being included in the inputs to the switching unit
+
 ## Citations

 ```bibtex
{metacontroller_pytorch-0.0.20 → metacontroller_pytorch-0.0.22}/metacontroller/metacontroller.py RENAMED

@@ -46,6 +46,14 @@ def default(*args):
             return arg
     return None

+def pad_at_dim(t, pad: tuple[int, int], dim = -1, value = 0.):
+    if pad == (0, 0):
+        return t
+
+    dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
+    zeros = ((0, 0) * dims_from_right)
+    return F.pad(t, (*zeros, *pad), value = value)
+
 # tensor helpers

 def straight_through(src, tgt):
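For orientation, a hedged, self-contained sketch of what the new `pad_at_dim` helper does (only `torch` is assumed; shapes are illustrative, not from the repo):

```python
import torch
import torch.nn.functional as F

def pad_at_dim(t, pad: tuple[int, int], dim = -1, value = 0.):
    if pad == (0, 0):
        return t

    # F.pad consumes (left, right) pairs starting from the LAST dimension,
    # so a (0, 0) pair is prepended for every dimension to the right of `dim`
    dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
    zeros = ((0, 0) * dims_from_right)
    return F.pad(t, (*zeros, *pad), value = value)

x = torch.randn(2, 3, 4)            # (batch, seq, dim)
y = pad_at_dim(x, (1, 0), dim = 1)  # left-pad one zero timestep on the seq dimension

assert y.shape == (2, 4, 4)
assert (y[:, 0] == 0).all()
```

This is what later lets the transformer left-pad action embeddings when they run one timestep short of the state embeddings.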
@@ -101,7 +109,9 @@ class MetaController(Module):

         self.switch_per_latent_dim = switch_per_latent_dim

-        …
+
+        self.dim_latent = dim_latent
+        self.switching_unit = GRU(dim_meta + dim_latent, dim_meta)
         self.to_switching_unit_beta = nn.Linear(dim_meta, dim_latent if switch_per_latent_dim else 1, bias = False)

         self.switch_gating = AssocScan(**assoc_scan_kwargs)
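The widened input size is easiest to see as a shape check. A minimal sketch, assuming `torch.nn.GRU` semantics with `batch_first = True` (the repo may bind its own `GRU`; dims are illustrative):

```python
import torch
from torch.nn import GRU

dim_meta, dim_latent = 16, 8

# the switching unit now consumes the meta embedding concatenated with the previous latent action
switching_unit = GRU(dim_meta + dim_latent, dim_meta, batch_first = True)

switch_input = torch.randn(2, 5, dim_meta + dim_latent)  # (batch, seq, dim_meta + dim_latent)
out, hidden = switching_unit(switch_input)
assert out.shape == (2, 5, dim_meta)
```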
@@ -147,10 +157,11 @@ class MetaController(Module):
         hard_switch = False,
         temperature = 1.
     ):
+        device = residual_stream.device

         # destruct prev cache

-        prev_action_proposer_hidden, prev_switching_unit_gru_hidden, prev_switch_gated_hiddens = cache.prev_hiddens if exists(cache) else ((None,) * 3)
+        prev_action_proposer_hidden, prev_switching_unit_gru_hidden, prev_switch_gated_hiddens, prev_sampled_latent_action = cache.prev_hiddens if exists(cache) else ((None,) * 4)

         # getting proposed action for the two phases

@@ -175,13 +186,34 @@ class MetaController(Module):

         action_dist = readout(proposed_action_hidden)

-        …
+        sampled_latent_action = readout.sample(action_dist, temperature = temperature)

         # switching unit timer

-        batch, …
+        batch, seq_len, dim = sampled_latent_action.shape
+
+        # initialize prev sampled latent action to be zeros if not available (for first timestep and for discovery phase)
+
+        if not exists(prev_sampled_latent_action):
+            prev_sampled_latent_action = torch.zeros(batch, 1, self.dim_latent, device = device)
+
+        if discovery_phase:
+            z_prev = cat((prev_sampled_latent_action, sampled_latent_action[:, :-1]), dim = 1)
+
+        else:
+            # else during inference, use the previous sampled latent action
+
+            assert seq_len == 1, f'inference RL phase must be done one token at a time'
+            z_prev = prev_sampled_latent_action

-        …
+        # switch input is previous latent action and the embedding
+
+        switch_input = torch.cat((meta_embed, z_prev), dim=-1)
+
+        switching_unit_gru_out, next_switching_unit_gru_hidden = self.switching_unit(
+            switch_input,
+            prev_switching_unit_gru_hidden
+        )

         switch_beta = self.to_switching_unit_beta(switching_unit_gru_out).sigmoid()

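This teacher-forcing shift is the fix credited in the Appreciation note above: at each timestep the switching unit sees the latent action from the previous step, never the current one. A hedged sketch of just that wiring, reusing the variable names from the diff (shapes are illustrative):

```python
import torch

batch, seq_len, dim_meta, dim_latent = 2, 5, 16, 8

meta_embed = torch.randn(batch, seq_len, dim_meta)
sampled_latent_action = torch.randn(batch, seq_len, dim_latent)

# no previous latent action exists at the first timestep, so start from zeros
prev_sampled_latent_action = torch.zeros(batch, 1, dim_latent)

# discovery phase: shift right by one so position t receives z_{t-1}
z_prev = torch.cat((prev_sampled_latent_action, sampled_latent_action[:, :-1]), dim = 1)

# inference instead runs one token at a time, so z_prev is just the cached last sample

switch_input = torch.cat((meta_embed, z_prev), dim = -1)
assert switch_input.shape == (batch, seq_len, dim_meta + dim_latent)
```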
@@ -213,7 +245,7 @@ class MetaController(Module):
         switch_beta = straight_through(switch_beta, hard_switch_beta)

         forget = 1. - switch_beta
-        gated_action = self.switch_gating(switch_beta, …)
+        gated_action = self.switch_gating(switch_beta, sampled_latent_action * forget, prev = prev_switch_gated_hiddens)

         next_switch_gated_action = gated_action[:, -1]

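A hedged reading of what the gating then computes, written as a plain loop for clarity. It assumes `AssocScan(gates, inputs, prev = ...)` realizes the linear recurrence h_t = gate_t * h_{t-1} + input_t, which the diff does not confirm; names and shapes are illustrative:

```python
import torch

def switch_gate_scan(switch_beta, sampled_latent_action, prev = None):
    # switch_beta: (batch, seq, 1) in [0, 1], sampled_latent_action: (batch, seq, dim)
    inputs = (1. - switch_beta) * sampled_latent_action
    hidden = prev if prev is not None else torch.zeros_like(sampled_latent_action[:, 0])

    out = []
    for beta_t, x_t in zip(switch_beta.unbind(dim = 1), inputs.unbind(dim = 1)):
        # beta near 1 holds the current latent action, near 0 switches to the newly proposed one
        hidden = beta_t * hidden + x_t
        out.append(hidden)

    return torch.stack(out, dim = 1)

gated = switch_gate_scan(torch.rand(2, 5, 1), torch.randn(2, 5, 8))
assert gated.shape == (2, 5, 8)
```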
@@ -233,10 +265,11 @@ class MetaController(Module):
         next_hiddens = (
             next_action_proposer_hidden,
             next_switching_unit_gru_hidden,
-            next_switch_gated_action
+            next_switch_gated_action,
+            sampled_latent_action[:, -1:]
         )

-        return control_signal, MetaControllerOutput(next_hiddens, action_dist, …)
+        return control_signal, MetaControllerOutput(next_hiddens, action_dist, sampled_latent_action, kl_loss, switch_loss)

 # main transformer, which is subsumed into the environment after behavioral cloning

@@ -297,7 +330,7 @@ class Transformer(Module):
     def forward(
         self,
         state,
-        …
+        actions: Tensor | None = None,
         meta_controller: Module | None = None,
         cache: TransformerOutput | None = None,
         discovery_phase = False,
@@ -306,6 +339,8 @@ class Transformer(Module):
         return_latents = False,
         return_cache = False,
     ):
+        device = state.device
+
         meta_controller = default(meta_controller, self.meta_controller)

         meta_controlling = exists(meta_controller)
@@ -325,16 +360,26 @@ class Transformer(Module):
         # handle maybe behavioral cloning

         if behavioral_cloning or (meta_controlling and discovery_phase):
+            assert exists(actions), f'`actions` cannot be empty when doing discovery or behavioral cloning'

             state, target_state = state[:, :-1], state[:, 1:]
-            …
+            actions, target_actions = actions[:, :-1], actions[:, 1:]

         # transformer lower body

         with lower_transformer_context():

             state_embed = self.state_embed(state)
-            …
+
+            # handle no past action for first timestep
+
+            if exists(actions):
+                action_embed = self.action_embed(actions)
+            else:
+                action_embed = state_embed[:, 0:0] # empty action embed
+
+            if action_embed.shape[-2] == (state_embed.shape[-2] - 1):
+                action_embed = pad_at_dim(action_embed, (1, 0), dim = 1)

             embed = state_embed + action_embed

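The first-timestep handling above is just a left shift: actions lag states by one, so when there is one fewer action embedding than state embedding, a zero action is padded in front before the two streams are summed. A self-contained sketch (shapes are illustrative; `pad_at_dim` repeated from the earlier sketch):

```python
import torch
import torch.nn.functional as F

def pad_at_dim(t, pad, dim = -1, value = 0.):
    dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
    return F.pad(t, (*((0, 0) * dims_from_right), *pad), value = value)

state_embed = torch.randn(1, 4, 32)   # embeddings for states s0..s3
action_embed = torch.randn(1, 3, 32)  # embeddings for actions a0..a2, one fewer

# actions lag states by one, so left-pad a zero "action" before a0
if action_embed.shape[-2] == state_embed.shape[-2] - 1:
    action_embed = pad_at_dim(action_embed, (1, 0), dim = 1)

embed = state_embed + action_embed    # both (1, 4, 32); position t sees (s_t, a_{t-1})
assert embed.shape == (1, 4, 32)
```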
@@ -367,13 +412,13 @@ class Transformer(Module):
             state_dist_params = self.state_readout(attended)
             state_clone_loss = self.state_readout.calculate_loss(state_dist_params, target_state)

-            action_clone_loss = self.action_readout.calculate_loss(dist_params, …)
+            action_clone_loss = self.action_readout.calculate_loss(dist_params, target_actions)

             return state_clone_loss, action_clone_loss

         elif meta_controlling and discovery_phase:

-            action_recon_loss = self.action_readout.calculate_loss(dist_params, …)
+            action_recon_loss = self.action_readout.calculate_loss(dist_params, target_actions)

             return action_recon_loss, next_meta_hiddens.kl_loss, next_meta_hiddens.switch_loss

{metacontroller_pytorch-0.0.20 → metacontroller_pytorch-0.0.22}/pyproject.toml RENAMED

@@ -1,6 +1,6 @@
 [project]
 name = "metacontroller-pytorch"
-version = "0.0.20"
+version = "0.0.22"
 description = "Transformer Metacontroller"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

@@ -29,6 +29,7 @@ dependencies = [
     "einx>=0.3.0",
     "einops>=0.8.1",
     "loguru",
+    "memmap-replay-buffer>=0.0.1",
     "torch>=2.5",
     "x-evolution>=0.1.23",
     "x-mlps-pytorch",
{metacontroller_pytorch-0.0.20 → metacontroller_pytorch-0.0.22}/tests/test_metacontroller.py RENAMED

@@ -4,6 +4,8 @@ param = pytest.mark.parametrize
 import torch
 from metacontroller.metacontroller import Transformer, MetaController

+from einops import rearrange
+
 @param('action_discrete', (False, True))
 @param('switch_per_latent_dim', (False, True))
 def test_metacontroller(

@@ -49,16 +51,18 @@ def test_metacontroller(
     (action_recon_loss, kl_loss, switch_loss) = model(state, actions, meta_controller = meta_controller, discovery_phase = True)
     (action_recon_loss + kl_loss * 0.1 + switch_loss * 0.2).backward()

-    # internal rl
+    # internal rl - done iteratively

-    …
+    cache = None
+    past_action_id = None

-    …
+    for one_state in state.unbind(dim = 1):
+        one_state = rearrange(one_state, 'b d -> b 1 d')

-    …
-    logits, cache = model(state, actions, meta_controller = meta_controller, return_cache = True, cache = cache)
+        logits, cache = model(one_state, past_action_id, meta_controller = meta_controller, return_cache = True)

-    …
+        assert logits.shape == (1, 1, *assert_shape)
+        past_action_id = model.action_readout.sample(logits)

     # evolutionary strategies over grpo

metacontroller_pytorch-0.0.22/train.py ADDED

@@ -0,0 +1,97 @@
+# /// script
+# dependencies = [
+#     "fire",
+#     "gymnasium",
+#     "gymnasium[other]",
+#     "memmap-replay-buffer>=0.0.10",
+#     "metacontroller-pytorch",
+#     "minigrid",
+#     "tqdm"
+# ]
+# ///
+
+from fire import Fire
+from tqdm import tqdm
+from shutil import rmtree
+
+import torch
+
+import gymnasium as gym
+import minigrid
+
+from memmap_replay_buffer import ReplayBuffer
+
+# functions
+
+def exists(v):
+    return v is not None
+
+def default(v, d):
+    return v if exists(v) else d
+
+def divisible_by(num, den):
+    return (num % den) == 0
+
+# main
+
+def main(
+    env_name = 'BabyAI-BossLevel-v0',
+    num_episodes = int(10e6),
+    max_timesteps = 500,
+    buffer_size = 5_000,
+    render_every_eps = 1_000,
+    video_folder = './recordings',
+    seed = None
+):
+
+    # environment
+
+    env = gym.make(env_name, render_mode = 'rgb_array')
+
+    rmtree(video_folder, ignore_errors = True)
+
+    env = gym.wrappers.RecordVideo(
+        env = env,
+        video_folder = video_folder,
+        name_prefix = 'babyai',
+        episode_trigger = lambda eps_num: divisible_by(eps_num, render_every_eps),
+        disable_logger = True
+    )
+
+    # replay
+
+    replay_buffer = ReplayBuffer(
+        './replay-data',
+        max_episodes = buffer_size,
+        max_timesteps = max_timesteps + 1,
+        fields = dict(
+            action = 'int',
+            state_image = ('float', (7, 7, 3)),
+            state_direction = 'int'
+        ),
+        overwrite = True,
+        circular = True
+    )
+
+    # rollouts
+
+    for _ in tqdm(range(num_episodes)):
+
+        state, *_ = env.reset(seed = seed)
+
+        for _ in range(max_timesteps):
+
+            action = torch.randint(0, 7, ())
+            next_state, reward, terminated, truncated, *_ = env.step(action.numpy())
+
+            done = terminated or truncated
+
+            if done:
+                break
+
+            state = next_state
+
+# running
+
+if __name__ == '__main__':
+    Fire(main)
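Since train.py carries PEP 723 inline script metadata (the `# /// script` block above), it can be run directly by a PEP 723 aware runner, e.g. `uv run train.py`; Fire exposes the parameters of `main` as command line flags, so an invocation like `uv run train.py --num_episodes 1000 --render_every_eps 100` should work (the specific runner is an assumption; the diff does not prescribe one).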
{metacontroller_pytorch-0.0.20 → metacontroller_pytorch-0.0.22}/.github/workflows/python-publish.yml RENAMED
File without changes

{metacontroller_pytorch-0.0.20 → metacontroller_pytorch-0.0.22}/.github/workflows/test.yml RENAMED
File without changes

{metacontroller_pytorch-0.0.20 → metacontroller_pytorch-0.0.22}/LICENSE RENAMED
File without changes

{metacontroller_pytorch-0.0.20 → metacontroller_pytorch-0.0.22}/fig1.png RENAMED
File without changes

{metacontroller_pytorch-0.0.20 → metacontroller_pytorch-0.0.22}/metacontroller/__init__.py RENAMED
File without changes