metacontroller-pytorch 0.0.20__py3-none-any.whl → 0.0.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
metacontroller/metacontroller.py

@@ -46,6 +46,17 @@ def default(*args):
             return arg
     return None
 
+def is_empty(t):
+    return t.numel() == 0
+
+def pad_at_dim(t, pad: tuple[int, int], dim = -1, value = 0.):
+    if pad == (0, 0):
+        return t
+
+    dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
+    zeros = ((0, 0) * dims_from_right)
+    return F.pad(t, (*zeros, *pad), value = value)
+
 # tensor helpers
 
 def straight_through(src, tgt):
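
For context, a minimal standalone sketch of what the two new helpers do, with made-up tensors (an illustration of the hunk above, not code shipped in the package):

```python
import torch
import torch.nn.functional as F

def is_empty(t):
    return t.numel() == 0

def pad_at_dim(t, pad, dim = -1, value = 0.):
    # pad one dimension of `t` on the (left, right) by `pad`, leaving the other dims untouched
    if pad == (0, 0):
        return t
    dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
    zeros = ((0, 0) * dims_from_right)
    return F.pad(t, (*zeros, *pad), value = value)

actions = torch.randn(2, 3, 8)                    # (batch, seq, dim)
padded = pad_at_dim(actions, (1, 0), dim = 1)     # prepend one zero step along the sequence dim
assert padded.shape == (2, 4, 8)
assert not is_empty(actions) and is_empty(actions[:, 0:0])
```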
@@ -101,7 +112,9 @@ class MetaController(Module):
 
         self.switch_per_latent_dim = switch_per_latent_dim
 
-        self.switching_unit = GRU(dim_meta, dim_meta)
+
+        self.dim_latent = dim_latent
+        self.switching_unit = GRU(dim_meta + dim_latent, dim_meta)
         self.to_switching_unit_beta = nn.Linear(dim_meta, dim_latent if switch_per_latent_dim else 1, bias = False)
 
         self.switch_gating = AssocScan(**assoc_scan_kwargs)
@@ -147,10 +160,11 @@ class MetaController(Module):
         hard_switch = False,
         temperature = 1.
     ):
+        device = residual_stream.device
 
         # destruct prev cache
 
-        prev_action_proposer_hidden, prev_switching_unit_gru_hidden, prev_switch_gated_hiddens = cache.prev_hiddens if exists(cache) else ((None,) * 3)
+        prev_action_proposer_hidden, prev_switching_unit_gru_hidden, prev_switch_gated_hiddens, prev_sampled_latent_action = cache.prev_hiddens if exists(cache) else ((None,) * 4)
 
         # getting proposed action for the two phases
 
@@ -175,13 +189,34 @@
 
         action_dist = readout(proposed_action_hidden)
 
-        sampled_action = readout.sample(action_dist, temperature = temperature)
+        sampled_latent_action = readout.sample(action_dist, temperature = temperature)
 
         # switching unit timer
 
-        batch, _, dim = sampled_action.shape
+        batch, seq_len, dim = sampled_latent_action.shape
+
+        # initialize prev sampled latent action to be zeros if not available (for first timestep and for discovery phase)
+
+        if not exists(prev_sampled_latent_action):
+            prev_sampled_latent_action = torch.zeros(batch, 1, self.dim_latent, device = device)
+
+        if discovery_phase:
+            z_prev = cat((prev_sampled_latent_action, sampled_latent_action[:, :-1]), dim = 1)
+
+        else:
+            # else during inference, use the previous sampled latent action
 
-        switching_unit_gru_out, next_switching_unit_gru_hidden = self.switching_unit(meta_embed, prev_switching_unit_gru_hidden)
+            assert seq_len == 1, f'inference RL phase must be done one token at a time'
+            z_prev = prev_sampled_latent_action
+
+        # switch input is previous latent action and the embedding
+
+        switch_input = torch.cat((meta_embed, z_prev), dim=-1)
+
+        switching_unit_gru_out, next_switching_unit_gru_hidden = self.switching_unit(
+            switch_input,
+            prev_switching_unit_gru_hidden
+        )
 
         switch_beta = self.to_switching_unit_beta(switching_unit_gru_out).sigmoid()
 
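
Roughly, the effect of this hunk: the switching unit now sees the previous latent action alongside the meta embedding. A self-contained sketch with hypothetical dimensions (a plain batch-first `torch.nn.GRU` stands in for the package's switching unit here):

```python
import torch
from torch import nn, cat

batch, seq_len, dim_meta, dim_latent = 2, 5, 16, 8

meta_embed = torch.randn(batch, seq_len, dim_meta)
sampled_latent_action = torch.randn(batch, seq_len, dim_latent)
prev_sampled_latent_action = torch.zeros(batch, 1, dim_latent)   # zeros when no previous action exists

# discovery phase: shift the sampled latent actions right by one step,
# so each position is paired with its predecessor's latent action
z_prev = cat((prev_sampled_latent_action, sampled_latent_action[:, :-1]), dim = 1)

switch_input = cat((meta_embed, z_prev), dim = -1)                # (batch, seq_len, dim_meta + dim_latent)

switching_unit = nn.GRU(dim_meta + dim_latent, dim_meta, batch_first = True)
out, hidden = switching_unit(switch_input)
assert out.shape == (batch, seq_len, dim_meta)
```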
@@ -213,7 +248,7 @@ class MetaController(Module):
         switch_beta = straight_through(switch_beta, hard_switch_beta)
 
         forget = 1. - switch_beta
-        gated_action = self.switch_gating(switch_beta, sampled_action * forget, prev = prev_switch_gated_hiddens)
+        gated_action = self.switch_gating(switch_beta, sampled_latent_action * forget, prev = prev_switch_gated_hiddens)
 
         next_switch_gated_action = gated_action[:, -1]
 
@@ -233,10 +268,11 @@ class MetaController(Module):
         next_hiddens = (
             next_action_proposer_hidden,
             next_switching_unit_gru_hidden,
-            next_switch_gated_action
+            next_switch_gated_action,
+            sampled_latent_action[:, -1:]
         )
 
-        return control_signal, MetaControllerOutput(next_hiddens, action_dist, sampled_action, kl_loss, switch_loss)
+        return control_signal, MetaControllerOutput(next_hiddens, action_dist, sampled_latent_action, kl_loss, switch_loss)
 
 # main transformer, which is subsumed into the environment after behavioral cloning
 
@@ -297,7 +333,7 @@ class Transformer(Module):
     def forward(
         self,
         state,
-        action_ids,
+        action_ids: Tensor | None = None,
         meta_controller: Module | None = None,
         cache: TransformerOutput | None = None,
         discovery_phase = False,
@@ -306,6 +342,8 @@ class Transformer(Module):
         return_latents = False,
         return_cache = False,
     ):
+        device = state.device
+
         meta_controller = default(meta_controller, self.meta_controller)
 
         meta_controlling = exists(meta_controller)
@@ -325,6 +363,7 @@ class Transformer(Module):
         # handle maybe behavioral cloning
 
         if behavioral_cloning or (meta_controlling and discovery_phase):
+            assert not is_empty(action_ids), f'`action_ids` cannot be empty when doing discovery or behavioral cloning'
 
             state, target_state = state[:, :-1], state[:, 1:]
             action_ids, target_action_ids = action_ids[:, :-1], action_ids[:, 1:]
@@ -334,7 +373,16 @@ class Transformer(Module):
         with lower_transformer_context():
 
             state_embed = self.state_embed(state)
-            action_embed = self.action_embed(action_ids)
+
+            # handle no past action for first timestep
+
+            if exists(action_ids):
+                action_embed = self.action_embed(action_ids)
+            else:
+                action_embed = state_embed[:, 0:0] # empty action embed
+
+            if action_embed.shape[-2] == (state_embed.shape[-2] - 1):
+                action_embed = pad_at_dim(action_embed, (1, 0), dim = 1)
 
             embed = state_embed + action_embed
 
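
Sketched below with hypothetical shapes: at the first inference step there is no past action, so an empty action embedding is used and then left-padded by one zero step to line up with the state embedding (here `pad_at_dim(action_embed, (1, 0), dim = 1)` is expanded into the equivalent `F.pad` call so the snippet stands alone):

```python
import torch
import torch.nn.functional as F

batch, dim = 2, 16

state_embed = torch.randn(batch, 1, dim)    # a single state token
action_embed = state_embed[:, 0:0]          # empty (batch, 0, dim) tensor standing in for "no past action"

# the action embeddings are one step shorter than the states, so prepend a zero step
if action_embed.shape[-2] == (state_embed.shape[-2] - 1):
    action_embed = F.pad(action_embed, (0, 0, 1, 0))   # pad the sequence dim on the left with zeros

embed = state_embed + action_embed          # both are (batch, 1, dim) now
```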
metacontroller_pytorch-{0.0.20 → 0.0.21}.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: metacontroller-pytorch
-Version: 0.0.20
+Version: 0.0.21
 Summary: Transformer Metacontroller
 Project-URL: Homepage, https://pypi.org/project/metacontroller/
 Project-URL: Repository, https://github.com/lucidrains/metacontroller
@@ -39,6 +39,7 @@ Requires-Dist: discrete-continuous-embed-readout>=0.1.12
 Requires-Dist: einops>=0.8.1
 Requires-Dist: einx>=0.3.0
 Requires-Dist: loguru
+Requires-Dist: memmap-replay-buffer>=0.0.1
 Requires-Dist: torch>=2.5
 Requires-Dist: x-evolution>=0.1.23
 Requires-Dist: x-mlps-pytorch
@@ -54,6 +55,16 @@ Description-Content-Type: text/markdown
 
 Implementation of the MetaController proposed in [Emergent temporal abstractions in autoregressive models enable hierarchical reinforcement learning](https://arxiv.org/abs/2512.20605)
 
+## Install
+
+```shell
+$ pip install metacontroller-pytorch
+```
+
+## Appreciation
+
+- [Pranoy](https://github.com/pranoyr) for submitting a pull request fixing the previous latent action not being included in the inputs to the switching unit
+
 ## Citations
 
 ```bibtex
metacontroller_pytorch-0.0.21.dist-info/RECORD (added)

@@ -0,0 +1,6 @@
+metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
+metacontroller/metacontroller.py,sha256=EP2N1Qtw4WTNthQrMz6bBT9rxTtMFikdOyYtcwSPdHM,14167
+metacontroller_pytorch-0.0.21.dist-info/METADATA,sha256=scUJVoSZ6Tl3RYNiNjK_wIeWVrpVLbQhya-XkCqdieQ,4320
+metacontroller_pytorch-0.0.21.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+metacontroller_pytorch-0.0.21.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+metacontroller_pytorch-0.0.21.dist-info/RECORD,,

metacontroller_pytorch-0.0.20.dist-info/RECORD (removed)

@@ -1,6 +0,0 @@
-metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
-metacontroller/metacontroller.py,sha256=3QZrId9z8I6MMQ3GhEQ6Xb5LFRTFJq4EAU4JCvRmm-4,12368
-metacontroller_pytorch-0.0.20.dist-info/METADATA,sha256=5t4rDJiJzbx7m9BNsTTgO5JOnavaX-3jv31HTGuLP6A,4034
-metacontroller_pytorch-0.0.20.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
-metacontroller_pytorch-0.0.20.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-metacontroller_pytorch-0.0.20.dist-info/RECORD,,