metacontroller-pytorch 0.0.19-py3-none-any.whl → 0.0.21-py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in the public registry.
@@ -18,7 +18,7 @@ from einops.layers.torch import Rearrange
 
 # external modules
 
-from x_transformers import Decoder
+from x_transformers import Encoder, Decoder
 from x_mlps_pytorch import Feedforwards
 from x_evolution import EvoStrategy
 
@@ -46,6 +46,17 @@ def default(*args):
             return arg
     return None
 
+def is_empty(t):
+    return t.numel() == 0
+
+def pad_at_dim(t, pad: tuple[int, int], dim = -1, value = 0.):
+    if pad == (0, 0):
+        return t
+
+    dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
+    zeros = ((0, 0) * dims_from_right)
+    return F.pad(t, (*zeros, *pad), value = value)
+
 # tensor helpers
 
 def straight_through(src, tgt):
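For context on the new `pad_at_dim` helper: it translates a pad on an arbitrary dimension into the right-to-left pair convention that `torch.nn.functional.pad` expects. A minimal standalone sketch of the behavior, with the tensor shape chosen arbitrarily for illustration:

```python
import torch
import torch.nn.functional as F

def pad_at_dim(t, pad: tuple[int, int], dim = -1, value = 0.):
    # count the dims to the right of `dim`, emit a (0, 0) pair for each,
    # then append the requested pad for `dim` itself (F.pad reads pairs right-to-left)
    if pad == (0, 0):
        return t
    dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
    zeros = ((0, 0) * dims_from_right)
    return F.pad(t, (*zeros, *pad), value = value)

x = torch.randn(2, 3, 5)                      # (batch, seq, dim)
print(pad_at_dim(x, (1, 0), dim = 1).shape)   # torch.Size([2, 4, 5]) - one zero step prepended on the seq dim
```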
@@ -72,7 +83,11 @@ class MetaController(Module):
         decoder_expansion_factor = 2.,
         decoder_depth = 1,
         hypernetwork_low_rank = 16,
-        assoc_scan_kwargs: dict = dict()
+        assoc_scan_kwargs: dict = dict(),
+        bidirectional_temporal_encoder_kwargs: dict = dict(
+            attn_dim_head = 32,
+            heads = 8
+        )
     ):
         super().__init__()
         dim_meta = default(dim_meta_controller, dim_model)
@@ -81,9 +96,9 @@ class MetaController(Module):
 
         self.model_to_meta = Linear(dim_model, dim_meta)
 
-        # there are two phases, the first (discovery ssl phase) uses acausal with some ssm i don't really believe in - let's just use a bidirectional GRU as placeholders
+        # there are two phases, the first (discovery ssl phase) uses acausal with some ssm i don't really believe in - let's just use bidirectional attention as placeholder
 
-        self.bidirectional_temporal_compressor = GRU(dim_meta, dim_meta, bidirectional = True) # revisit naming
+        self.bidirectional_temporal_encoder = Encoder(dim = dim_meta, depth = 1, **bidirectional_temporal_encoder_kwargs)
 
         self.emitter = GRU(dim_meta * 2, dim_meta * 2)
         self.emitter_to_action_mean_log_var = Readout(dim_meta * 2, num_continuous = dim_latent)
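The construction above swaps the bidirectional GRU (whose output was `2 * dim_meta` wide and had to be mean-reduced back down, as seen in the forward-pass hunk further below) for a depth-1 non-causal attention `Encoder` from `x_transformers`, which keeps the feature width at `dim_meta`. A minimal shape check, with `dim_meta = 64` chosen arbitrarily:

```python
import torch
from x_transformers import Encoder

dim_meta = 64
meta_embed = torch.randn(2, 16, dim_meta)    # (batch, seq, dim_meta)

encoder = Encoder(dim = dim_meta, depth = 1, attn_dim_head = 32, heads = 8)
encoded_temporal = encoder(meta_embed)       # (2, 16, dim_meta) - bidirectional context, width unchanged
```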
@@ -97,7 +112,9 @@ class MetaController(Module):
 
         self.switch_per_latent_dim = switch_per_latent_dim
 
-        self.switching_unit = GRU(dim_meta, dim_meta)
+
+        self.dim_latent = dim_latent
+        self.switching_unit = GRU(dim_meta + dim_latent, dim_meta)
         self.to_switching_unit_beta = nn.Linear(dim_meta, dim_latent if switch_per_latent_dim else 1, bias = False)
 
         self.switch_gating = AssocScan(**assoc_scan_kwargs)
@@ -122,7 +139,7 @@ class MetaController(Module):
     def discovery_parameters(self):
         return [
             *self.model_to_meta.parameters(),
-            *self.bidirectional_temporal_compressor.parameters(),
+            *self.bidirectional_temporal_encoder.parameters(),
             *self.emitter.parameters(),
             *self.emitter_to_action_mean_log_var.parameters(),
             *self.decoder.parameters(),
@@ -143,10 +160,11 @@ class MetaController(Module):
         hard_switch = False,
         temperature = 1.
     ):
+        device = residual_stream.device
 
         # destruct prev cache
 
-        prev_action_proposer_hidden, prev_switching_unit_gru_hidden, prev_switch_gated_hiddens = cache.prev_hiddens if exists(cache) else ((None,) * 3)
+        prev_action_proposer_hidden, prev_switching_unit_gru_hidden, prev_switch_gated_hiddens, prev_sampled_latent_action = cache.prev_hiddens if exists(cache) else ((None,) * 4)
 
         # getting proposed action for the two phases
 
@@ -157,10 +175,9 @@ class MetaController(Module):
         if discovery_phase:
             logger.warning('meta controller cache being passed back in for discovery phase, which does not make sense given bidirectional encoder')
 
-            temporal_compressed, _ = self.bidirectional_temporal_compressor(meta_embed)
-            temporal_compressed = reduce(temporal_compressed, '... (two d) -> ... d', 'mean', two = 2)
+            encoded_temporal = self.bidirectional_temporal_encoder(meta_embed)
 
-            proposed_action_hidden, _ = self.emitter(cat((temporal_compressed, meta_embed), dim = -1))
+            proposed_action_hidden, _ = self.emitter(cat((encoded_temporal, meta_embed), dim = -1))
             readout = self.emitter_to_action_mean_log_var
 
         else: # else internal rl phase
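Note that the emitter interface is unchanged by this swap: its input is still the encoded sequence concatenated with `meta_embed`, i.e. `2 * dim_meta` features, which is what `GRU(dim_meta * 2, dim_meta * 2)` in `__init__` expects. A quick sanity sketch, with stand-in tensors for both the old mean-reduced GRU path and the new encoder path (shapes illustrative only):

```python
import torch
from einops import reduce

dim_meta = 64
meta_embed = torch.randn(2, 16, dim_meta)

# old path: bidirectional GRU output was (..., 2 * dim_meta), averaged over the two directions
gru_out = torch.randn(2, 16, dim_meta * 2)                      # stand-in for the removed GRU output
temporal_compressed = reduce(gru_out, '... (two d) -> ... d', 'mean', two = 2)

# both paths then feed the emitter the same way
emitter_input = torch.cat((temporal_compressed, meta_embed), dim = -1)
print(emitter_input.shape)    # torch.Size([2, 16, 128]) == (..., dim_meta * 2)
```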
@@ -172,13 +189,34 @@ class MetaController(Module):
 
         action_dist = readout(proposed_action_hidden)
 
-        sampled_action = readout.sample(action_dist, temperature = temperature)
+        sampled_latent_action = readout.sample(action_dist, temperature = temperature)
 
         # switching unit timer
 
-        batch, _, dim = sampled_action.shape
+        batch, seq_len, dim = sampled_latent_action.shape
+
+        # initialize prev sampled latent action to be zeros if not available (for first timestep and for discovery phase)
+
+        if not exists(prev_sampled_latent_action):
+            prev_sampled_latent_action = torch.zeros(batch, 1, self.dim_latent, device = device)
+
+        if discovery_phase:
+            z_prev = cat((prev_sampled_latent_action, sampled_latent_action[:, :-1]), dim = 1)
+
+        else:
+            # else during inference, use the previous sampled latent action
 
-        switching_unit_gru_out, next_switching_unit_gru_hidden = self.switching_unit(meta_embed, prev_switching_unit_gru_hidden)
+            assert seq_len == 1, f'inference RL phase must be done one token at a time'
+            z_prev = prev_sampled_latent_action
+
+        # switch input is previous latent action and the embedding
+
+        switch_input = torch.cat((meta_embed, z_prev), dim=-1)
+
+        switching_unit_gru_out, next_switching_unit_gru_hidden = self.switching_unit(
+            switch_input,
+            prev_switching_unit_gru_hidden
+        )
 
         switch_beta = self.to_switching_unit_beta(switching_unit_gru_out).sigmoid()
 
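This is the core change of the release (credited in the README's Appreciation section below): the switching unit now also sees the latent action sampled at the previous step. In the discovery phase the whole sampled sequence is shifted right by one, with zeros standing in for the step before t = 0; at inference, decoding one token at a time, the cached last sample is used directly. A minimal sketch of the discovery-phase shift, with all dimensions chosen arbitrarily:

```python
import torch

batch, seq_len, dim_meta, dim_latent = 2, 6, 64, 8

meta_embed            = torch.randn(batch, seq_len, dim_meta)
sampled_latent_action = torch.randn(batch, seq_len, dim_latent)

# shift right by one step: timestep t sees the action sampled at t - 1, zeros before t = 0
prev_sampled_latent_action = torch.zeros(batch, 1, dim_latent)
z_prev = torch.cat((prev_sampled_latent_action, sampled_latent_action[:, :-1]), dim = 1)

switch_input = torch.cat((meta_embed, z_prev), dim = -1)
print(switch_input.shape)    # torch.Size([2, 6, 72]) -> matches GRU(dim_meta + dim_latent, dim_meta)
```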
@@ -210,7 +248,7 @@ class MetaController(Module):
             switch_beta = straight_through(switch_beta, hard_switch_beta)
 
         forget = 1. - switch_beta
-        gated_action = self.switch_gating(switch_beta, sampled_action * forget, prev = prev_switch_gated_hiddens)
+        gated_action = self.switch_gating(switch_beta, sampled_latent_action * forget, prev = prev_switch_gated_hiddens)
 
         next_switch_gated_action = gated_action[:, -1]
 
@@ -230,10 +268,11 @@ class MetaController(Module):
         next_hiddens = (
             next_action_proposer_hidden,
             next_switching_unit_gru_hidden,
-            next_switch_gated_action
+            next_switch_gated_action,
+            sampled_latent_action[:, -1:]
         )
 
-        return control_signal, MetaControllerOutput(next_hiddens, action_dist, sampled_action, kl_loss, switch_loss)
+        return control_signal, MetaControllerOutput(next_hiddens, action_dist, sampled_latent_action, kl_loss, switch_loss)
 
 # main transformer, which is subsumed into the environment after behavioral cloning
 
@@ -294,7 +333,7 @@ class Transformer(Module):
     def forward(
         self,
         state,
-        action_ids,
+        action_ids: Tensor | None = None,
         meta_controller: Module | None = None,
         cache: TransformerOutput | None = None,
         discovery_phase = False,
@@ -303,6 +342,8 @@ class Transformer(Module):
         return_latents = False,
         return_cache = False,
     ):
+        device = state.device
+
         meta_controller = default(meta_controller, self.meta_controller)
 
         meta_controlling = exists(meta_controller)
@@ -322,6 +363,7 @@ class Transformer(Module):
         # handle maybe behavioral cloning
 
         if behavioral_cloning or (meta_controlling and discovery_phase):
+            assert not is_empty(action_ids), f'`action_ids` cannot be empty when doing discovery or behavioral cloning'
 
             state, target_state = state[:, :-1], state[:, 1:]
             action_ids, target_action_ids = action_ids[:, :-1], action_ids[:, 1:]
@@ -331,7 +373,16 @@ class Transformer(Module):
         with lower_transformer_context():
 
             state_embed = self.state_embed(state)
-            action_embed = self.action_embed(action_ids)
+
+            # handle no past action for first timestep
+
+            if exists(action_ids):
+                action_embed = self.action_embed(action_ids)
+            else:
+                action_embed = state_embed[:, 0:0] # empty action embed
+
+            if action_embed.shape[-2] == (state_embed.shape[-2] - 1):
+                action_embed = pad_at_dim(action_embed, (1, 0), dim = 1)
 
             embed = state_embed + action_embed
 
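This is where the new `pad_at_dim` helper and the now-optional `action_ids` come together: on the very first environment step there is no previous action, so an empty `(batch, 0, dim)` embedding is left-padded with a single zero step to line up with the state embedding. A small sketch of that path, with illustrative shapes only:

```python
import torch
import torch.nn.functional as F

batch, dim = 2, 512                        # illustrative transformer width

state_embed  = torch.randn(batch, 1, dim)  # first timestep: one state token, no past action yet
action_embed = state_embed[:, 0:0]         # empty (batch, 0, dim) stand-in for self.action_embed(action_ids)

# one action fewer than states -> prepend a single zero action so the two sequences align
if action_embed.shape[-2] == (state_embed.shape[-2] - 1):
    action_embed = F.pad(action_embed, (0, 0, 1, 0), value = 0.)   # what pad_at_dim(t, (1, 0), dim = 1) expands to

embed = state_embed + action_embed         # (batch, 1, dim)
```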
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: metacontroller-pytorch
-Version: 0.0.19
+Version: 0.0.21
 Summary: Transformer Metacontroller
 Project-URL: Homepage, https://pypi.org/project/metacontroller/
 Project-URL: Repository, https://github.com/lucidrains/metacontroller
@@ -39,6 +39,7 @@ Requires-Dist: discrete-continuous-embed-readout>=0.1.12
 Requires-Dist: einops>=0.8.1
 Requires-Dist: einx>=0.3.0
 Requires-Dist: loguru
+Requires-Dist: memmap-replay-buffer>=0.0.1
 Requires-Dist: torch>=2.5
 Requires-Dist: x-evolution>=0.1.23
 Requires-Dist: x-mlps-pytorch
@@ -54,6 +55,16 @@ Description-Content-Type: text/markdown
 
 Implementation of the MetaController proposed in [Emergent temporal abstractions in autoregressive models enable hierarchical reinforcement learning](https://arxiv.org/abs/2512.20605)
 
+## Install
+
+```shell
+$ pip install metacontroller-pytorch
+```
+
+## Appreciation
+
+- [Pranoy](https://github.com/pranoyr) for submitting a pull request for fixing the previous latent action not being included in the inputs to the switching unit
+
 ## Citations
 
 ```bibtex
@@ -78,3 +89,15 @@ Implementation of the MetaController proposed in [Emergent temporal abstractions
     url = {https://api.semanticscholar.org/CorpusID:279464702}
 }
 ```
+
+```bibtex
+@misc{fleuret2025freetransformer,
+    title = {The Free Transformer},
+    author = {François Fleuret},
+    year = {2025},
+    eprint = {2510.17558},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.LG},
+    url = {https://arxiv.org/abs/2510.17558},
+}
+```
@@ -0,0 +1,6 @@
+metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
+metacontroller/metacontroller.py,sha256=EP2N1Qtw4WTNthQrMz6bBT9rxTtMFikdOyYtcwSPdHM,14167
+metacontroller_pytorch-0.0.21.dist-info/METADATA,sha256=scUJVoSZ6Tl3RYNiNjK_wIeWVrpVLbQhya-XkCqdieQ,4320
+metacontroller_pytorch-0.0.21.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+metacontroller_pytorch-0.0.21.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+metacontroller_pytorch-0.0.21.dist-info/RECORD,,
@@ -1,6 +0,0 @@
-metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
-metacontroller/metacontroller.py,sha256=GTErzikqVd8XDY8pmDnY8t4uIjbGCUd1GZBJX13peo8,12339
-metacontroller_pytorch-0.0.19.dist-info/METADATA,sha256=lX3L7J3CKoSyxvJniLdSJsCu0UMEbJTxQLEw6zzT7dY,3741
-metacontroller_pytorch-0.0.19.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
-metacontroller_pytorch-0.0.19.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-metacontroller_pytorch-0.0.19.dist-info/RECORD,,