metacontroller-pytorch 0.0.20__py3-none-any.whl → 0.0.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- metacontroller/metacontroller.py +58 -10
- {metacontroller_pytorch-0.0.20.dist-info → metacontroller_pytorch-0.0.21.dist-info}/METADATA +12 -1
- metacontroller_pytorch-0.0.21.dist-info/RECORD +6 -0
- metacontroller_pytorch-0.0.20.dist-info/RECORD +0 -6
- {metacontroller_pytorch-0.0.20.dist-info → metacontroller_pytorch-0.0.21.dist-info}/WHEEL +0 -0
- {metacontroller_pytorch-0.0.20.dist-info → metacontroller_pytorch-0.0.21.dist-info}/licenses/LICENSE +0 -0
metacontroller/metacontroller.py
CHANGED
@@ -46,6 +46,17 @@ def default(*args):
             return arg
     return None
 
+def is_empty(t):
+    return t.numel() == 0
+
+def pad_at_dim(t, pad: tuple[int, int], dim = -1, value = 0.):
+    if pad == (0, 0):
+        return t
+
+    dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
+    zeros = ((0, 0) * dims_from_right)
+    return F.pad(t, (*zeros, *pad), value = value)
+
 # tensor helpers
 
 def straight_through(src, tgt):
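The two new helpers are small: `is_empty` checks for zero elements, and `pad_at_dim` builds the flat pad tuple that `F.pad` expects (pairs ordered from the last dimension inward) so a tensor can be padded along an arbitrary dimension. A minimal sketch of `pad_at_dim` as defined above, with a hypothetical shape used purely for illustration:

```python
import torch
import torch.nn.functional as F

def pad_at_dim(t, pad: tuple[int, int], dim = -1, value = 0.):
    # no-op when no padding is requested
    if pad == (0, 0):
        return t

    # one (0, 0) pair per dimension to the right of `dim`, as F.pad pads from the last dim inward
    dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
    zeros = ((0, 0) * dims_from_right)
    return F.pad(t, (*zeros, *pad), value = value)

x = torch.randn(2, 3, 4)              # (batch, seq, dim)
y = pad_at_dim(x, (1, 0), dim = 1)    # prepend one zero step along the sequence dim
assert y.shape == (2, 4, 4)
assert torch.equal(y[:, 0], torch.zeros(2, 4))
```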
@@ -101,7 +112,9 @@ class MetaController(Module):
 
         self.switch_per_latent_dim = switch_per_latent_dim
 
-        self.switching_unit = GRU(dim_meta, dim_meta)
+
+        self.dim_latent = dim_latent
+        self.switching_unit = GRU(dim_meta + dim_latent, dim_meta)
         self.to_switching_unit_beta = nn.Linear(dim_meta, dim_latent if switch_per_latent_dim else 1, bias = False)
 
         self.switch_gating = AssocScan(**assoc_scan_kwargs)
@@ -147,10 +160,11 @@ class MetaController(Module):
         hard_switch = False,
         temperature = 1.
     ):
+        device = residual_stream.device
 
         # destruct prev cache
 
-        prev_action_proposer_hidden, prev_switching_unit_gru_hidden, prev_switch_gated_hiddens = cache.prev_hiddens if exists(cache) else ((None,) *
+        prev_action_proposer_hidden, prev_switching_unit_gru_hidden, prev_switch_gated_hiddens, prev_sampled_latent_action = cache.prev_hiddens if exists(cache) else ((None,) * 4)
 
         # getting proposed action for the two phases
 
@@ -175,13 +189,34 @@ class MetaController(Module):
 
         action_dist = readout(proposed_action_hidden)
 
-
+        sampled_latent_action = readout.sample(action_dist, temperature = temperature)
 
         # switching unit timer
 
-        batch,
+        batch, seq_len, dim = sampled_latent_action.shape
+
+        # initialize prev sampled latent action to be zeros if not available (for first timestep and for discovery phase)
+
+        if not exists(prev_sampled_latent_action):
+            prev_sampled_latent_action = torch.zeros(batch, 1, self.dim_latent, device = device)
+
+        if discovery_phase:
+            z_prev = cat((prev_sampled_latent_action, sampled_latent_action[:, :-1]), dim = 1)
+
+        else:
+            # else during inference, use the previous sampled latent action
 
-
+            assert seq_len == 1, f'inference RL phase must be done one token at a time'
+            z_prev = prev_sampled_latent_action
+
+        # switch input is previous latent action and the embedding
+
+        switch_input = torch.cat((meta_embed, z_prev), dim=-1)
+
+        switching_unit_gru_out, next_switching_unit_gru_hidden = self.switching_unit(
+            switch_input,
+            prev_switching_unit_gru_hidden
+        )
 
         switch_beta = self.to_switching_unit_beta(switching_unit_gru_out).sigmoid()
 
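This is the substance of the release: the switching unit GRU now also conditions on the latent action sampled at the previous step. During discovery the previous action is obtained by shifting the sampled sequence right by one step (zeros for the first step); during stepwise inference it comes from the cache. A rough sketch of the tensor manipulation, using hypothetical sizes:

```python
import torch

batch, seq_len, dim_latent, dim_meta = 2, 5, 8, 16   # hypothetical sizes

meta_embed = torch.randn(batch, seq_len, dim_meta)
sampled_latent_action = torch.randn(batch, seq_len, dim_latent)

# first timestep has no previous action, so start from zeros
prev_sampled_latent_action = torch.zeros(batch, 1, dim_latent)

# discovery phase: pair timestep t with the latent action sampled at t - 1
z_prev = torch.cat((prev_sampled_latent_action, sampled_latent_action[:, :-1]), dim = 1)

# the switching unit GRU now consumes the embedding and the previous latent action together,
# hence GRU(dim_meta + dim_latent, dim_meta) in __init__
switch_input = torch.cat((meta_embed, z_prev), dim = -1)
assert switch_input.shape == (batch, seq_len, dim_meta + dim_latent)
```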
@@ -213,7 +248,7 @@ class MetaController(Module):
         switch_beta = straight_through(switch_beta, hard_switch_beta)
 
         forget = 1. - switch_beta
-        gated_action = self.switch_gating(switch_beta,
+        gated_action = self.switch_gating(switch_beta, sampled_latent_action * forget, prev = prev_switch_gated_hiddens)
 
         next_switch_gated_action = gated_action[:, -1]
 
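`switch_gating` is an associative scan over the betas. Assuming it implements the usual linear recurrence h_t = beta_t * h_{t-1} + x_t, the update above interpolates between holding the previously gated latent action (beta near 1) and switching to the newly sampled one (beta near 0). A sequential reference of that assumed recurrence, not the library call itself:

```python
import torch

def gated_scan_reference(gates, inputs, prev = None):
    # assumed recurrence: h_t = gates_t * h_{t-1} + inputs_t
    # (the AssocScan used by `switch_gating` would compute the same result in parallel)
    batch, seq_len, dim = inputs.shape
    h = prev if prev is not None else inputs.new_zeros(batch, dim)
    outs = []
    for t in range(seq_len):
        h = gates[:, t] * h + inputs[:, t]
        outs.append(h)
    return torch.stack(outs, dim = 1)

switch_beta = torch.rand(2, 5, 1)             # beta near 1 keeps the old latent action
sampled_latent_action = torch.randn(2, 5, 8)

forget = 1. - switch_beta
gated_action = gated_scan_reference(switch_beta, sampled_latent_action * forget)
next_switch_gated_action = gated_action[:, -1]   # carried in the cache for the next step
```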
@@ -233,10 +268,11 @@ class MetaController(Module):
         next_hiddens = (
             next_action_proposer_hidden,
             next_switching_unit_gru_hidden,
-            next_switch_gated_action
+            next_switch_gated_action,
+            sampled_latent_action[:, -1:]
         )
 
-        return control_signal, MetaControllerOutput(next_hiddens, action_dist,
+        return control_signal, MetaControllerOutput(next_hiddens, action_dist, sampled_latent_action, kl_loss, switch_loss)
 
 # main transformer, which is subsumed into the environment after behavioral cloning
 
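The cache tuple grows to four entries: the last sampled latent action (`sampled_latent_action[:, -1:]`) is now threaded through so the next decoding step can feed it back as `z_prev`. A hypothetical mirror of the new layout, with field names assumed for illustration only:

```python
import torch
from collections import namedtuple

# hypothetical mirror of the four cached hidden states
PrevHiddens = namedtuple('PrevHiddens', [
    'action_proposer_hidden',
    'switching_unit_gru_hidden',
    'switch_gated_hiddens',
    'sampled_latent_action',   # new in 0.0.21, shape (batch, 1, dim_latent)
])

batch, seq_len, dim_latent = 2, 5, 8
sampled_latent_action = torch.randn(batch, seq_len, dim_latent)

# only the final step's latent action is carried forward to the next call
next_hiddens = PrevHiddens(None, None, None, sampled_latent_action[:, -1:])
assert next_hiddens.sampled_latent_action.shape == (batch, 1, dim_latent)
```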
@@ -297,7 +333,7 @@ class Transformer(Module):
     def forward(
         self,
         state,
-        action_ids,
+        action_ids: Tensor | None = None,
         meta_controller: Module | None = None,
         cache: TransformerOutput | None = None,
         discovery_phase = False,
@@ -306,6 +342,8 @@ class Transformer(Module):
         return_latents = False,
         return_cache = False,
     ):
+        device = state.device
+
         meta_controller = default(meta_controller, self.meta_controller)
 
         meta_controlling = exists(meta_controller)
@@ -325,6 +363,7 @@ class Transformer(Module):
         # handle maybe behavioral cloning
 
         if behavioral_cloning or (meta_controlling and discovery_phase):
+            assert not is_empty(action_ids), f'`action_ids` cannot be empty when doing discovery or behavioral cloning'
 
             state, target_state = state[:, :-1], state[:, 1:]
             action_ids, target_action_ids = action_ids[:, :-1], action_ids[:, 1:]
@@ -334,7 +373,16 @@ class Transformer(Module):
         with lower_transformer_context():
 
             state_embed = self.state_embed(state)
-            action_embed = self.action_embed(action_ids)
+
+            # handle no past action for first timestep
+
+            if exists(action_ids):
+                action_embed = self.action_embed(action_ids)
+            else:
+                action_embed = state_embed[:, 0:0] # empty action embed
+
+            if action_embed.shape[-2] == (state_embed.shape[-2] - 1):
+                action_embed = pad_at_dim(action_embed, (1, 0), dim = 1)
 
             embed = state_embed + action_embed
 
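`action_ids` is now optional in `Transformer.forward`, and the action embedding is aligned to the state sequence: with no actions at all an empty embedding is used, and when there is one fewer action than states, a zero step is prepended so each state is summed with the action taken just before it. A small sketch of that alignment with hypothetical embedders and sizes:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# hypothetical embedders and sizes, for illustration only
dim, num_actions = 16, 6
state_embedder = nn.Linear(4, dim)
action_embedder = nn.Embedding(num_actions, dim)

state = torch.randn(1, 7, 4)
action_ids = torch.randint(0, num_actions, (1, 6))   # one fewer past action than states

state_embed = state_embedder(state)
action_embed = action_embedder(action_ids) if action_ids is not None else state_embed[:, 0:0]

# prepend a zero step so state t is summed with the action taken at t - 1
# (equivalent to pad_at_dim(action_embed, (1, 0), dim = 1) from the helpers above)
if action_embed.shape[-2] == (state_embed.shape[-2] - 1):
    action_embed = F.pad(action_embed, (0, 0, 1, 0))

embed = state_embed + action_embed
assert embed.shape == (1, 7, dim)
```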
{metacontroller_pytorch-0.0.20.dist-info → metacontroller_pytorch-0.0.21.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: metacontroller-pytorch
-Version: 0.0.20
+Version: 0.0.21
 Summary: Transformer Metacontroller
 Project-URL: Homepage, https://pypi.org/project/metacontroller/
 Project-URL: Repository, https://github.com/lucidrains/metacontroller
@@ -39,6 +39,7 @@ Requires-Dist: discrete-continuous-embed-readout>=0.1.12
 Requires-Dist: einops>=0.8.1
 Requires-Dist: einx>=0.3.0
 Requires-Dist: loguru
+Requires-Dist: memmap-replay-buffer>=0.0.1
 Requires-Dist: torch>=2.5
 Requires-Dist: x-evolution>=0.1.23
 Requires-Dist: x-mlps-pytorch
@@ -54,6 +55,16 @@ Description-Content-Type: text/markdown
 
 Implementation of the MetaController proposed in [Emergent temporal abstractions in autoregressive models enable hierarchical reinforcement learning](https://arxiv.org/abs/2512.20605)
 
+## Install
+
+```shell
+$ pip install metacontroller-pytorch
+```
+
+## Appreciation
+
+- [Pranoy](https://github.com/pranoyr) for submitting a pull request for fixing the previous latent action not being included in the inputs to the switching unit
+
 ## Citations
 
 ```bibtex
metacontroller_pytorch-0.0.21.dist-info/RECORD
ADDED
@@ -0,0 +1,6 @@
+metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
+metacontroller/metacontroller.py,sha256=EP2N1Qtw4WTNthQrMz6bBT9rxTtMFikdOyYtcwSPdHM,14167
+metacontroller_pytorch-0.0.21.dist-info/METADATA,sha256=scUJVoSZ6Tl3RYNiNjK_wIeWVrpVLbQhya-XkCqdieQ,4320
+metacontroller_pytorch-0.0.21.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+metacontroller_pytorch-0.0.21.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+metacontroller_pytorch-0.0.21.dist-info/RECORD,,
metacontroller_pytorch-0.0.20.dist-info/RECORD
DELETED
@@ -1,6 +0,0 @@
-metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
-metacontroller/metacontroller.py,sha256=3QZrId9z8I6MMQ3GhEQ6Xb5LFRTFJq4EAU4JCvRmm-4,12368
-metacontroller_pytorch-0.0.20.dist-info/METADATA,sha256=5t4rDJiJzbx7m9BNsTTgO5JOnavaX-3jv31HTGuLP6A,4034
-metacontroller_pytorch-0.0.20.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
-metacontroller_pytorch-0.0.20.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-metacontroller_pytorch-0.0.20.dist-info/RECORD,,
{metacontroller_pytorch-0.0.20.dist-info → metacontroller_pytorch-0.0.21.dist-info}/WHEEL
RENAMED
File without changes
{metacontroller_pytorch-0.0.20.dist-info → metacontroller_pytorch-0.0.21.dist-info}/licenses/LICENSE
RENAMED
File without changes