metacontroller-pytorch 0.0.20__tar.gz → 0.0.22__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of metacontroller-pytorch might be problematic.

@@ -1,3 +1,6 @@
+ replay-data/
+ recordings/
+
  # Byte-compiled / optimized / DLL files
  __pycache__/
  *.py[codz]
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: metacontroller-pytorch
- Version: 0.0.20
+ Version: 0.0.22
  Summary: Transformer Metacontroller
  Project-URL: Homepage, https://pypi.org/project/metacontroller/
  Project-URL: Repository, https://github.com/lucidrains/metacontroller
@@ -39,6 +39,7 @@ Requires-Dist: discrete-continuous-embed-readout>=0.1.12
  Requires-Dist: einops>=0.8.1
  Requires-Dist: einx>=0.3.0
  Requires-Dist: loguru
+ Requires-Dist: memmap-replay-buffer>=0.0.1
  Requires-Dist: torch>=2.5
  Requires-Dist: x-evolution>=0.1.23
  Requires-Dist: x-mlps-pytorch
@@ -54,6 +55,16 @@ Description-Content-Type: text/markdown
 
  Implementation of the MetaController proposed in [Emergent temporal abstractions in autoregressive models enable hierarchical reinforcement learning](https://arxiv.org/abs/2512.20605)
 
+ ## Install
+
+ ```shell
+ $ pip install metacontroller-pytorch
+ ```
+
+ ## Appreciation
+
+ - [Pranoy](https://github.com/pranoyr) for submitting a pull request fixing the previous latent action not being included in the inputs to the switching unit
+
  ## Citations
 
  ```bibtex
@@ -4,6 +4,16 @@
 
  Implementation of the MetaController proposed in [Emergent temporal abstractions in autoregressive models enable hierarchical reinforcement learning](https://arxiv.org/abs/2512.20605)
 
+ ## Install
+
+ ```shell
+ $ pip install metacontroller-pytorch
+ ```
+
+ ## Appreciation
+
+ - [Pranoy](https://github.com/pranoyr) for submitting a pull request fixing the previous latent action not being included in the inputs to the switching unit
+
  ## Citations
 
  ```bibtex
@@ -46,6 +46,14 @@ def default(*args):
              return arg
      return None
 
+ def pad_at_dim(t, pad: tuple[int, int], dim = -1, value = 0.):
+     if pad == (0, 0):
+         return t
+
+     dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
+     zeros = ((0, 0) * dims_from_right)
+     return F.pad(t, (*zeros, *pad), value = value)
+
  # tensor helpers
 
  def straight_through(src, tgt):
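The new `pad_at_dim` helper builds on `F.pad`, which consumes `(left, right)` padding pairs starting from the *last* dimension; the helper prepends as many `(0, 0)` pairs as needed so the requested pair lands on the target dimension. A minimal sketch of the behavior, with illustrative shapes:

```python
import torch
import torch.nn.functional as F

def pad_at_dim(t, pad, dim = -1, value = 0.):
    # same logic as the helper added in the hunk above
    if pad == (0, 0):
        return t

    dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
    zeros = ((0, 0) * dims_from_right)
    return F.pad(t, (*zeros, *pad), value = value)

x = torch.randn(2, 3, 4)

assert pad_at_dim(x, (1, 0), dim = 1).shape == (2, 4, 4)   # one zero step prepended on dim 1
assert pad_at_dim(x, (0, 2), dim = -1).shape == (2, 3, 6)  # two zero steps appended on the last dim
```

This is the same call the `Transformer` makes further down, `pad_at_dim(action_embed, (1, 0), dim = 1)`, to prepend a zero action embedding at the first timestep.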
@@ -101,7 +109,9 @@ class MetaController(Module):
 
          self.switch_per_latent_dim = switch_per_latent_dim
 
-         self.switching_unit = GRU(dim_meta, dim_meta)
+
+         self.dim_latent = dim_latent
+         self.switching_unit = GRU(dim_meta + dim_latent, dim_meta)
          self.to_switching_unit_beta = nn.Linear(dim_meta, dim_latent if switch_per_latent_dim else 1, bias = False)
 
          self.switch_gating = AssocScan(**assoc_scan_kwargs)
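This widening of the switching unit's input, from `dim_meta` to `dim_meta + dim_latent`, is the fix credited in the README's new Appreciation entry: the GRU now also sees the previous latent action. A dimensional sanity check, using `torch.nn.GRU` as a stand-in for the `GRU` used here, with hypothetical sizes:

```python
import torch
from torch.nn import GRU

dim_meta, dim_latent = 16, 8  # hypothetical sizes

# input width grows to dim_meta + dim_latent; hidden/output width stays dim_meta
switching_unit = GRU(dim_meta + dim_latent, dim_meta, batch_first = True)

meta_embed = torch.randn(2, 10, dim_meta)
z_prev = torch.randn(2, 10, dim_latent)  # previous latent action per timestep

out, hidden = switching_unit(torch.cat((meta_embed, z_prev), dim = -1))

assert out.shape == (2, 10, dim_meta)
```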
@@ -147,10 +157,11 @@ class MetaController(Module):
          hard_switch = False,
          temperature = 1.
      ):
+         device = residual_stream.device
 
          # destruct prev cache
 
-         prev_action_proposer_hidden, prev_switching_unit_gru_hidden, prev_switch_gated_hiddens = cache.prev_hiddens if exists(cache) else ((None,) * 3)
+         prev_action_proposer_hidden, prev_switching_unit_gru_hidden, prev_switch_gated_hiddens, prev_sampled_latent_action = cache.prev_hiddens if exists(cache) else ((None,) * 4)
 
          # getting proposed action for the two phases
 
@@ -175,13 +186,34 @@
 
          action_dist = readout(proposed_action_hidden)
 
-         sampled_action = readout.sample(action_dist, temperature = temperature)
+         sampled_latent_action = readout.sample(action_dist, temperature = temperature)
 
          # switching unit timer
 
-         batch, _, dim = sampled_action.shape
+         batch, seq_len, dim = sampled_latent_action.shape
+
+         # initialize prev sampled latent action to be zeros if not available (for first timestep and for discovery phase)
+
+         if not exists(prev_sampled_latent_action):
+             prev_sampled_latent_action = torch.zeros(batch, 1, self.dim_latent, device = device)
+
+         if discovery_phase:
+             z_prev = cat((prev_sampled_latent_action, sampled_latent_action[:, :-1]), dim = 1)
+
+         else:
+             # else during inference, use the previous sampled latent action
+
+             assert seq_len == 1, f'inference RL phase must be done one token at a time'
+             z_prev = prev_sampled_latent_action
 
-         switching_unit_gru_out, next_switching_unit_gru_hidden = self.switching_unit(meta_embed, prev_switching_unit_gru_hidden)
+         # switch input is previous latent action and the embedding
+
+         switch_input = torch.cat((meta_embed, z_prev), dim=-1)
+
+         switching_unit_gru_out, next_switching_unit_gru_hidden = self.switching_unit(
+             switch_input,
+             prev_switching_unit_gru_hidden
+         )
 
          switch_beta = self.to_switching_unit_beta(switching_unit_gru_out).sigmoid()
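During the discovery phase, the previous latent action for every timestep is produced in parallel by shifting the sampled sequence right one step and filling the first slot with zeros; at inference, the single cached sample is used instead. A shape-only sketch of the shift, with illustrative sizes:

```python
import torch
from torch import cat

batch, seq_len, dim_latent = 2, 5, 8  # illustrative sizes

sampled_latent_action = torch.randn(batch, seq_len, dim_latent)

# zeros stand in for the missing action before the first timestep
prev_sampled_latent_action = torch.zeros(batch, 1, dim_latent)

# shift right one step: timestep t sees the latent action sampled at t - 1
z_prev = cat((prev_sampled_latent_action, sampled_latent_action[:, :-1]), dim = 1)

assert z_prev.shape == sampled_latent_action.shape
assert torch.equal(z_prev[:, 1:], sampled_latent_action[:, :-1])
```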
@@ -213,7 +245,7 @@ class MetaController(Module):
          switch_beta = straight_through(switch_beta, hard_switch_beta)
 
          forget = 1. - switch_beta
-         gated_action = self.switch_gating(switch_beta, sampled_action * forget, prev = prev_switch_gated_hiddens)
+         gated_action = self.switch_gating(switch_beta, sampled_latent_action * forget, prev = prev_switch_gated_hiddens)
 
          next_switch_gated_action = gated_action[:, -1]
 
@@ -233,10 +265,11 @@
          next_hiddens = (
              next_action_proposer_hidden,
              next_switching_unit_gru_hidden,
-             next_switch_gated_action
+             next_switch_gated_action,
+             sampled_latent_action[:, -1:]
          )
 
-         return control_signal, MetaControllerOutput(next_hiddens, action_dist, sampled_action, kl_loss, switch_loss)
+         return control_signal, MetaControllerOutput(next_hiddens, action_dist, sampled_latent_action, kl_loss, switch_loss)
 
  # main transformer, which is subsumed into the environment after behavioral cloning
 
@@ -297,7 +330,7 @@ class Transformer(Module):
      def forward(
          self,
          state,
-         action_ids,
+         actions: Tensor | None = None,
          meta_controller: Module | None = None,
          cache: TransformerOutput | None = None,
          discovery_phase = False,
@@ -306,6 +339,8 @@ class Transformer(Module):
          return_latents = False,
          return_cache = False,
      ):
+         device = state.device
+
          meta_controller = default(meta_controller, self.meta_controller)
 
          meta_controlling = exists(meta_controller)
@@ -325,16 +360,26 @@ class Transformer(Module):
          # handle maybe behavioral cloning
 
          if behavioral_cloning or (meta_controlling and discovery_phase):
+             assert exists(actions), f'`actions` cannot be empty when doing discovery or behavioral cloning'
 
              state, target_state = state[:, :-1], state[:, 1:]
-             action_ids, target_action_ids = action_ids[:, :-1], action_ids[:, 1:]
+             actions, target_actions = actions[:, :-1], actions[:, 1:]
 
          # transformer lower body
 
          with lower_transformer_context():
 
              state_embed = self.state_embed(state)
-             action_embed = self.action_embed(action_ids)
+
+             # handle no past action for first timestep
+
+             if exists(actions):
+                 action_embed = self.action_embed(actions)
+             else:
+                 action_embed = state_embed[:, 0:0] # empty action embed
+
+             if action_embed.shape[-2] == (state_embed.shape[-2] - 1):
+                 action_embed = pad_at_dim(action_embed, (1, 0), dim = 1)
 
              embed = state_embed + action_embed
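With `actions` now optional, the first inference step can run from a state alone: the action embedding starts out empty, and whenever the action sequence trails the state sequence by one timestep it is front-padded with a zero embedding before the two are summed. A sketch of that alignment, with illustrative shapes and the same `pad_at_dim` logic as above:

```python
import torch
import torch.nn.functional as F

def pad_at_dim(t, pad, dim = -1, value = 0.):
    if pad == (0, 0):
        return t
    dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
    return F.pad(t, (*((0, 0) * dims_from_right), *pad), value = value)

dim = 16  # illustrative embedding width

state_embed = torch.randn(1, 6, dim)   # embeddings for six states
action_embed = torch.randn(1, 5, dim)  # only five past actions exist

# actions trail states by one, so prepend a zero action embedding
if action_embed.shape[-2] == (state_embed.shape[-2] - 1):
    action_embed = pad_at_dim(action_embed, (1, 0), dim = 1)

embed = state_embed + action_embed  # both now (1, 6, dim)
```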
@@ -367,13 +412,13 @@ class Transformer(Module):
              state_dist_params = self.state_readout(attended)
              state_clone_loss = self.state_readout.calculate_loss(state_dist_params, target_state)
 
-             action_clone_loss = self.action_readout.calculate_loss(dist_params, target_action_ids)
+             action_clone_loss = self.action_readout.calculate_loss(dist_params, target_actions)
 
              return state_clone_loss, action_clone_loss
 
          elif meta_controlling and discovery_phase:
 
-             action_recon_loss = self.action_readout.calculate_loss(dist_params, target_action_ids)
+             action_recon_loss = self.action_readout.calculate_loss(dist_params, target_actions)
 
              return action_recon_loss, next_meta_hiddens.kl_loss, next_meta_hiddens.switch_loss
 
@@ -1,6 +1,6 @@
  [project]
  name = "metacontroller-pytorch"
- version = "0.0.20"
+ version = "0.0.22"
  description = "Transformer Metacontroller"
  authors = [
      { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -29,6 +29,7 @@ dependencies = [
      "einx>=0.3.0",
      "einops>=0.8.1",
      "loguru",
+     "memmap-replay-buffer>=0.0.1",
      "torch>=2.5",
      "x-evolution>=0.1.23",
      "x-mlps-pytorch",
@@ -4,6 +4,8 @@ param = pytest.mark.parametrize
  import torch
  from metacontroller.metacontroller import Transformer, MetaController
 
+ from einops import rearrange
+
  @param('action_discrete', (False, True))
  @param('switch_per_latent_dim', (False, True))
  def test_metacontroller(
@@ -49,16 +51,18 @@ def test_metacontroller(
      (action_recon_loss, kl_loss, switch_loss) = model(state, actions, meta_controller = meta_controller, discovery_phase = True)
      (action_recon_loss + kl_loss * 0.1 + switch_loss * 0.2).backward()
 
-     # internal rl
+     # internal rl - done iteratively
 
-     logits, cache = model(state, actions, meta_controller = meta_controller, return_cache = True)
+     cache = None
+     past_action_id = None
 
-     assert logits.shape == (1, 1024, *assert_shape)
+     for one_state in state.unbind(dim = 1):
+         one_state = rearrange(one_state, 'b d -> b 1 d')
 
-     logits, cache = model(state, actions, meta_controller = meta_controller, return_cache = True, cache = cache)
-     logits, cache = model(state, actions, meta_controller = meta_controller, return_cache = True, cache = cache)
+         logits, cache = model(one_state, past_action_id, meta_controller = meta_controller, return_cache = True)
 
-     assert logits.shape == (1, 1, *assert_shape)
+         assert logits.shape == (1, 1, *assert_shape)
+         past_action_id = model.action_readout.sample(logits)
 
      # evolutionary strategies over grpo
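The rewritten test exercises genuine step-by-step inference: each state is fed as a single timestep, and the sampled action is fed back as the past action of the next call. The pattern in isolation, with stub stand-ins (everything here is hypothetical) so the sketch runs on its own:

```python
import torch
from einops import rearrange

# stubs standing in for the Transformer and its readout, purely illustrative
class StubReadout:
    def sample(self, logits):
        return logits.argmax(dim = -1)

class StubModel:
    action_readout = StubReadout()
    def __call__(self, one_state, past_action_id, **kwargs):
        return torch.randn(one_state.shape[0], 1, 7), None  # (logits, cache)

model = StubModel()
state = torch.randn(1, 16, 4)  # (batch, time, state dim), illustrative

cache = None
past_action_id = None

for one_state in state.unbind(dim = 1):
    one_state = rearrange(one_state, 'b d -> b 1 d')  # (b, d) -> (b, 1, d)
    logits, cache = model(one_state, past_action_id, return_cache = True)
    past_action_id = model.action_readout.sample(logits)  # feeds the next step
```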
@@ -0,0 +1,97 @@
+ # /// script
+ # dependencies = [
+ #     "fire",
+ #     "gymnasium",
+ #     "gymnasium[other]",
+ #     "memmap-replay-buffer>=0.0.10",
+ #     "metacontroller-pytorch",
+ #     "minigrid",
+ #     "tqdm"
+ # ]
+ # ///
+
+ from fire import Fire
+ from tqdm import tqdm
+ from shutil import rmtree
+
+ import torch
+
+ import gymnasium as gym
+ import minigrid
+
+ from memmap_replay_buffer import ReplayBuffer
+
+ # functions
+
+ def exists(v):
+     return v is not None
+
+ def default(v, d):
+     return v if exists(v) else d
+
+ def divisible_by(num, den):
+     return (num % den) == 0
+
+ # main
+
+ def main(
+     env_name = 'BabyAI-BossLevel-v0',
+     num_episodes = int(10e6),
+     max_timesteps = 500,
+     buffer_size = 5_000,
+     render_every_eps = 1_000,
+     video_folder = './recordings',
+     seed = None
+ ):
+
+     # environment
+
+     env = gym.make(env_name, render_mode = 'rgb_array')
+
+     rmtree(video_folder, ignore_errors = True)
+
+     env = gym.wrappers.RecordVideo(
+         env = env,
+         video_folder = video_folder,
+         name_prefix = 'babyai',
+         episode_trigger = lambda eps_num: divisible_by(eps_num, render_every_eps),
+         disable_logger = True
+     )
+
+     # replay
+
+     replay_buffer = ReplayBuffer(
+         './replay-data',
+         max_episodes = buffer_size,
+         max_timesteps = max_timesteps + 1,
+         fields = dict(
+             action = 'int',
+             state_image = ('float', (7, 7, 3)),
+             state_direction = 'int'
+         ),
+         overwrite = True,
+         circular = True
+     )
+
+     # rollouts
+
+     for _ in tqdm(range(num_episodes)):
+
+         state, *_ = env.reset(seed = seed)
+
+         for _ in range(max_timesteps):
+
+             action = torch.randint(0, 7, ())
+             next_state, reward, terminated, truncated, *_ = env.step(action.numpy())
+
+             done = terminated or truncated
+
+             if done:
+                 break
+
+             state = next_state
+
+ # running
+
+ if __name__ == '__main__':
+     Fire(main)
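The new data-gathering script carries inline dependency metadata (PEP 723), so it can be run directly with a PEP 723-aware runner. For example, assuming it were saved as `gather_data.py` (the filename is not shown in this diff), something like:

```shell
$ uv run gather_data.py --env_name 'BabyAI-BossLevel-v0' --render_every_eps 1000
```

Fire exposes the keyword arguments of `main` as command-line flags, so `num_episodes`, `max_timesteps`, `buffer_size`, `video_folder`, and `seed` can be overridden the same way.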