metacontroller-pytorch 0.0.19__py3-none-any.whl → 0.0.21__py3-none-any.whl
This diff compares publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
- metacontroller/metacontroller.py +69 -18
- {metacontroller_pytorch-0.0.19.dist-info → metacontroller_pytorch-0.0.21.dist-info}/METADATA +24 -1
- metacontroller_pytorch-0.0.21.dist-info/RECORD +6 -0
- metacontroller_pytorch-0.0.19.dist-info/RECORD +0 -6
- {metacontroller_pytorch-0.0.19.dist-info → metacontroller_pytorch-0.0.21.dist-info}/WHEEL +0 -0
- {metacontroller_pytorch-0.0.19.dist-info → metacontroller_pytorch-0.0.21.dist-info}/licenses/LICENSE +0 -0
metacontroller/metacontroller.py
CHANGED
@@ -18,7 +18,7 @@ from einops.layers.torch import Rearrange

 # external modules

-from x_transformers import Decoder
+from x_transformers import Encoder, Decoder
 from x_mlps_pytorch import Feedforwards
 from x_evolution import EvoStrategy

@@ -46,6 +46,17 @@ def default(*args):
             return arg
     return None

+def is_empty(t):
+    return t.numel() == 0
+
+def pad_at_dim(t, pad: tuple[int, int], dim = -1, value = 0.):
+    if pad == (0, 0):
+        return t
+
+    dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
+    zeros = ((0, 0) * dims_from_right)
+    return F.pad(t, (*zeros, *pad), value = value)
+
 # tensor helpers

 def straight_through(src, tgt):
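For reference, the new `pad_at_dim` helper converts a `(left, right)` pad on an arbitrary dimension into the last-dim-first pair layout that `torch.nn.functional.pad` expects. A standalone check of the behavior the transformer relies on later in this diff (prepending one zero timestep), with the helper reproduced from the hunk above and toy shapes assumed:

```python
import torch
import torch.nn.functional as F

def pad_at_dim(t, pad, dim = -1, value = 0.):
    # reproduced from the hunk above
    if pad == (0, 0):
        return t

    dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
    zeros = ((0, 0) * dims_from_right)
    return F.pad(t, (*zeros, *pad), value = value)

x = torch.randn(2, 3, 8)            # (batch, time, dim) - toy shapes
y = pad_at_dim(x, (1, 0), dim = 1)  # prepend one zero step on the time axis

assert y.shape == (2, 4, 8)
assert (y[:, 0] == 0).all()
```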
@@ -72,7 +83,11 @@ class MetaController(Module):
         decoder_expansion_factor = 2.,
         decoder_depth = 1,
         hypernetwork_low_rank = 16,
-        assoc_scan_kwargs: dict = dict()
+        assoc_scan_kwargs: dict = dict(),
+        bidirectional_temporal_encoder_kwargs: dict = dict(
+            attn_dim_head = 32,
+            heads = 8
+        )
     ):
         super().__init__()
         dim_meta = default(dim_meta_controller, dim_model)
@@ -81,9 +96,9 @@ class MetaController(Module):

         self.model_to_meta = Linear(dim_model, dim_meta)

-        # there are two phases, the first (discovery ssl phase) uses acausal with some ssm i don't really believe in - let's just use
+        # there are two phases, the first (discovery ssl phase) uses acausal with some ssm i don't really believe in - let's just use bidirectional attention as placeholder

-        self.
+        self.bidirectional_temporal_encoder = Encoder(dim = dim_meta, depth = 1, **bidirectional_temporal_encoder_kwargs)

         self.emitter = GRU(dim_meta * 2, dim_meta * 2)
         self.emitter_to_action_mean_log_var = Readout(dim_meta * 2, num_continuous = dim_latent)
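The discovery-phase encoder swapped in here is a one-layer bidirectional attention stack from `x-transformers`, so every position of the meta embedding can see the whole trajectory. A minimal sketch of the shape contract, with illustrative sizes (`dim_meta = 64`, batch and length arbitrary):

```python
import torch
from x_transformers import Encoder

dim_meta = 64  # illustrative size

encoder = Encoder(dim = dim_meta, depth = 1, attn_dim_head = 32, heads = 8)

meta_embed = torch.randn(2, 16, dim_meta)  # (batch, time, dim_meta)
encoded = encoder(meta_embed)              # same shape, each position attends both directions

assert encoded.shape == meta_embed.shape
```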
@@ -97,7 +112,9 @@ class MetaController(Module):

         self.switch_per_latent_dim = switch_per_latent_dim

-
+
+        self.dim_latent = dim_latent
+        self.switching_unit = GRU(dim_meta + dim_latent, dim_meta)
         self.to_switching_unit_beta = nn.Linear(dim_meta, dim_latent if switch_per_latent_dim else 1, bias = False)

         self.switch_gating = AssocScan(**assoc_scan_kwargs)
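Per the Appreciation note further down, the fix in this release is that the switching unit now receives the previous latent action concatenated onto the meta embedding, hence the `dim_meta + dim_latent` input width. A minimal sketch of that wiring, assuming a batch-first `torch.nn.GRU` stands in for the package's `GRU` wrapper, with illustrative sizes:

```python
import torch
from torch import nn

dim_meta, dim_latent = 64, 16  # illustrative sizes

# stand-in for the package's GRU wrapper (assumed batch-first)
switching_unit = nn.GRU(dim_meta + dim_latent, dim_meta, batch_first = True)

meta_embed = torch.randn(2, 5, dim_meta)
z_prev = torch.randn(2, 5, dim_latent)  # previous latent action per timestep

switch_input = torch.cat((meta_embed, z_prev), dim = -1)
out, hidden = switching_unit(switch_input)

assert out.shape == (2, 5, dim_meta)
```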
@@ -122,7 +139,7 @@ class MetaController(Module):
     def discovery_parameters(self):
         return [
             *self.model_to_meta.parameters(),
-            *self.
+            *self.bidirectional_temporal_encoder.parameters(),
             *self.emitter.parameters(),
             *self.emitter_to_action_mean_log_var.parameters(),
             *self.decoder.parameters(),
@@ -143,10 +160,11 @@
         hard_switch = False,
         temperature = 1.
     ):
+        device = residual_stream.device

         # destruct prev cache

-        prev_action_proposer_hidden, prev_switching_unit_gru_hidden, prev_switch_gated_hiddens = cache.prev_hiddens if exists(cache) else ((None,) * 3)
+        prev_action_proposer_hidden, prev_switching_unit_gru_hidden, prev_switch_gated_hiddens, prev_sampled_latent_action = cache.prev_hiddens if exists(cache) else ((None,) * 4)

         # getting proposed action for the two phases

@@ -157,10 +175,9 @@
         if discovery_phase:
             logger.warning('meta controller cache being passed back in for discovery phase, which does not make sense given bidirectional encoder')

-
-            temporal_compressed = reduce(temporal_compressed, '... (two d) -> ... d', 'mean', two = 2)
+            encoded_temporal = self.bidirectional_temporal_encoder(meta_embed)

-            proposed_action_hidden, _ = self.emitter(cat((
+            proposed_action_hidden, _ = self.emitter(cat((encoded_temporal, meta_embed), dim = -1))
             readout = self.emitter_to_action_mean_log_var

         else: # else internal rl phase
@@ -172,13 +189,34 @@

         action_dist = readout(proposed_action_hidden)

-
+        sampled_latent_action = readout.sample(action_dist, temperature = temperature)

         # switching unit timer

-        batch,
+        batch, seq_len, dim = sampled_latent_action.shape
+
+        # initialize prev sampled latent action to be zeros if not available (for first timestep and for discovery phase)
+
+        if not exists(prev_sampled_latent_action):
+            prev_sampled_latent_action = torch.zeros(batch, 1, self.dim_latent, device = device)
+
+        if discovery_phase:
+            z_prev = cat((prev_sampled_latent_action, sampled_latent_action[:, :-1]), dim = 1)
+
+        else:
+            # else during inference, use the previous sampled latent action

-
+            assert seq_len == 1, f'inference RL phase must be done one token at a time'
+            z_prev = prev_sampled_latent_action
+
+        # switch input is previous latent action and the embedding
+
+        switch_input = torch.cat((meta_embed, z_prev), dim=-1)
+
+        switching_unit_gru_out, next_switching_unit_gru_hidden = self.switching_unit(
+            switch_input,
+            prev_switching_unit_gru_hidden
+        )

         switch_beta = self.to_switching_unit_beta(switching_unit_gru_out).sigmoid()

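The two branches above differ only in where the previous latent action comes from: during discovery the whole sampled sequence is shifted right one step (zeros for the first position), while at inference the single cached action from the last step is used. A toy illustration of the discovery-phase shift, with assumed shapes:

```python
import torch

batch, seq_len, dim_latent = 2, 5, 4  # toy shapes

sampled = torch.randn(batch, seq_len, dim_latent)
z0 = torch.zeros(batch, 1, dim_latent)  # no action precedes the first timestep

# right-shift so position t sees the action sampled at t - 1
z_prev = torch.cat((z0, sampled[:, :-1]), dim = 1)

assert z_prev.shape == sampled.shape
assert (z_prev[:, 0] == 0).all()
assert torch.equal(z_prev[:, 1:], sampled[:, :-1])
```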
@@ -210,7 +248,7 @@
         switch_beta = straight_through(switch_beta, hard_switch_beta)

         forget = 1. - switch_beta
-        gated_action = self.switch_gating(switch_beta,
+        gated_action = self.switch_gating(switch_beta, sampled_latent_action * forget, prev = prev_switch_gated_hiddens)

         next_switch_gated_action = gated_action[:, -1]

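Read together with `forget = 1. - switch_beta`, the new call suggests the scan realizes h_t = β_t · h_{t-1} + (1 - β_t) · a_t: the gated latent action is held while the switch stays closed (β near 1) and overwritten when it fires. A naive reference loop under that assumed `AssocScan(gates, inputs)` convention, with toy shapes:

```python
import torch

def gated_scan_reference(beta, inputs, prev = None):
    # loop version of h_t = beta_t * h_{t-1} + inputs_t (assumed AssocScan convention)
    batch, seq_len, dim = inputs.shape
    h = prev if prev is not None else torch.zeros(batch, dim)

    outs = []
    for t in range(seq_len):
        h = beta[:, t] * h + inputs[:, t]
        outs.append(h)

    return torch.stack(outs, dim = 1)

beta = torch.rand(2, 5, 1)     # switch_beta in (0, 1), one gate per timestep
action = torch.randn(2, 5, 4)  # sampled latent actions

gated = gated_scan_reference(beta, action * (1. - beta))

assert gated.shape == (2, 5, 4)
```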
@@ -230,10 +268,11 @@
         next_hiddens = (
             next_action_proposer_hidden,
             next_switching_unit_gru_hidden,
-            next_switch_gated_action
+            next_switch_gated_action,
+            sampled_latent_action[:, -1:]
         )

-        return control_signal, MetaControllerOutput(next_hiddens, action_dist,
+        return control_signal, MetaControllerOutput(next_hiddens, action_dist, sampled_latent_action, kl_loss, switch_loss)

 # main transformer, which is subsumed into the environment after behavioral cloning

@@ -294,7 +333,7 @@ class Transformer(Module):
     def forward(
         self,
         state,
-        action_ids,
+        action_ids: Tensor | None = None,
         meta_controller: Module | None = None,
         cache: TransformerOutput | None = None,
         discovery_phase = False,
@@ -303,6 +342,8 @@
         return_latents = False,
         return_cache = False,
     ):
+        device = state.device
+
         meta_controller = default(meta_controller, self.meta_controller)

         meta_controlling = exists(meta_controller)
@@ -322,6 +363,7 @@
         # handle maybe behavioral cloning

         if behavioral_cloning or (meta_controlling and discovery_phase):
+            assert not is_empty(action_ids), f'`action_ids` cannot be empty when doing discovery or behavioral cloning'

             state, target_state = state[:, :-1], state[:, 1:]
             action_ids, target_action_ids = action_ids[:, :-1], action_ids[:, 1:]
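The shift in the hunk above is the standard next-token setup: inputs are all but the last step, targets are the same sequence offset by one. A toy illustration:

```python
import torch

state = torch.arange(6).view(1, 6)         # toy ids standing in for a state sequence

inp, target = state[:, :-1], state[:, 1:]  # predict each next state from the current one

assert inp.tolist() == [[0, 1, 2, 3, 4]]
assert target.tolist() == [[1, 2, 3, 4, 5]]
```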
@@ -331,7 +373,16 @@
         with lower_transformer_context():

             state_embed = self.state_embed(state)
-
+
+            # handle no past action for first timestep
+
+            if exists(action_ids):
+                action_embed = self.action_embed(action_ids)
+            else:
+                action_embed = state_embed[:, 0:0] # empty action embed
+
+            if action_embed.shape[-2] == (state_embed.shape[-2] - 1):
+                action_embed = pad_at_dim(action_embed, (1, 0), dim = 1)

             embed = state_embed + action_embed

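The padding path added above covers the first timestep, where there is one fewer past action than states. A toy check of the alignment, using `F.pad` directly in place of `pad_at_dim` (assumed shapes):

```python
import torch
import torch.nn.functional as F

state_embed  = torch.randn(2, 6, 64)   # 6 states...
action_embed = torch.randn(2, 5, 64)   # ...but only 5 past actions

if action_embed.shape[-2] == (state_embed.shape[-2] - 1):
    # prepend a zero action embedding for the first state, as pad_at_dim(action_embed, (1, 0), dim = 1) does
    action_embed = F.pad(action_embed, (0, 0, 1, 0))

embed = state_embed + action_embed
assert embed.shape == (2, 6, 64)
```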
{metacontroller_pytorch-0.0.19.dist-info → metacontroller_pytorch-0.0.21.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: metacontroller-pytorch
-Version: 0.0.19
+Version: 0.0.21
 Summary: Transformer Metacontroller
 Project-URL: Homepage, https://pypi.org/project/metacontroller/
 Project-URL: Repository, https://github.com/lucidrains/metacontroller
@@ -39,6 +39,7 @@ Requires-Dist: discrete-continuous-embed-readout>=0.1.12
 Requires-Dist: einops>=0.8.1
 Requires-Dist: einx>=0.3.0
 Requires-Dist: loguru
+Requires-Dist: memmap-replay-buffer>=0.0.1
 Requires-Dist: torch>=2.5
 Requires-Dist: x-evolution>=0.1.23
 Requires-Dist: x-mlps-pytorch
@@ -54,6 +55,16 @@ Description-Content-Type: text/markdown

 Implementation of the MetaController proposed in [Emergent temporal abstractions in autoregressive models enable hierarchical reinforcement learning](https://arxiv.org/abs/2512.20605)

+## Install
+
+```shell
+$ pip install metacontroller-pytorch
+```
+
+## Appreciation
+
+- [Pranoy](https://github.com/pranoyr) for submitting a pull request for fixing the previous latent action not being included in the inputs to the switching unit
+
 ## Citations

 ```bibtex
@@ -78,3 +89,15 @@ Implementation of the MetaController proposed in [Emergent temporal abstractions
     url = {https://api.semanticscholar.org/CorpusID:279464702}
 }
 ```
+
+```bibtex
+@misc{fleuret2025freetransformer,
+    title = {The Free Transformer},
+    author = {François Fleuret},
+    year = {2025},
+    eprint = {2510.17558},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.LG},
+    url = {https://arxiv.org/abs/2510.17558},
+}
+```
metacontroller_pytorch-0.0.21.dist-info/RECORD
ADDED
@@ -0,0 +1,6 @@
+metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
+metacontroller/metacontroller.py,sha256=EP2N1Qtw4WTNthQrMz6bBT9rxTtMFikdOyYtcwSPdHM,14167
+metacontroller_pytorch-0.0.21.dist-info/METADATA,sha256=scUJVoSZ6Tl3RYNiNjK_wIeWVrpVLbQhya-XkCqdieQ,4320
+metacontroller_pytorch-0.0.21.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+metacontroller_pytorch-0.0.21.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+metacontroller_pytorch-0.0.21.dist-info/RECORD,,
metacontroller_pytorch-0.0.19.dist-info/RECORD
DELETED
@@ -1,6 +0,0 @@
-metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
-metacontroller/metacontroller.py,sha256=GTErzikqVd8XDY8pmDnY8t4uIjbGCUd1GZBJX13peo8,12339
-metacontroller_pytorch-0.0.19.dist-info/METADATA,sha256=lX3L7J3CKoSyxvJniLdSJsCu0UMEbJTxQLEw6zzT7dY,3741
-metacontroller_pytorch-0.0.19.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
-metacontroller_pytorch-0.0.19.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-metacontroller_pytorch-0.0.19.dist-info/RECORD,,
{metacontroller_pytorch-0.0.19.dist-info → metacontroller_pytorch-0.0.21.dist-info}/WHEEL
RENAMED
File without changes
{metacontroller_pytorch-0.0.19.dist-info → metacontroller_pytorch-0.0.21.dist-info}/licenses/LICENSE
RENAMED
File without changes