metacontroller-pytorch 0.0.36__tar.gz → 0.0.38__tar.gz

This diff shows the changes between publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.

This version of metacontroller-pytorch has been flagged as a potentially problematic release.
Files changed (17)
  1. {metacontroller_pytorch-0.0.36 → metacontroller_pytorch-0.0.38}/PKG-INFO +13 -1
  2. {metacontroller_pytorch-0.0.36 → metacontroller_pytorch-0.0.38}/README.md +12 -0
  3. {metacontroller_pytorch-0.0.36 → metacontroller_pytorch-0.0.38}/metacontroller/metacontroller.py +17 -2
  4. {metacontroller_pytorch-0.0.36 → metacontroller_pytorch-0.0.38}/metacontroller/metacontroller_with_binary_mapper.py +10 -1
  5. {metacontroller_pytorch-0.0.36 → metacontroller_pytorch-0.0.38}/pyproject.toml +1 -1
  6. {metacontroller_pytorch-0.0.36 → metacontroller_pytorch-0.0.38}/tests/test_metacontroller.py +1 -16
  7. {metacontroller_pytorch-0.0.36 → metacontroller_pytorch-0.0.38}/train_behavior_clone_babyai.py +10 -14
  8. {metacontroller_pytorch-0.0.36 → metacontroller_pytorch-0.0.38}/.github/workflows/python-publish.yml +0 -0
  9. {metacontroller_pytorch-0.0.36 → metacontroller_pytorch-0.0.38}/.github/workflows/test.yml +0 -0
  10. {metacontroller_pytorch-0.0.36 → metacontroller_pytorch-0.0.38}/.gitignore +0 -0
  11. {metacontroller_pytorch-0.0.36 → metacontroller_pytorch-0.0.38}/LICENSE +0 -0
  12. {metacontroller_pytorch-0.0.36 → metacontroller_pytorch-0.0.38}/fig1.png +0 -0
  13. {metacontroller_pytorch-0.0.36 → metacontroller_pytorch-0.0.38}/gather_babyai_trajs.py +0 -0
  14. {metacontroller_pytorch-0.0.36 → metacontroller_pytorch-0.0.38}/metacontroller/__init__.py +0 -0
  15. {metacontroller_pytorch-0.0.36 → metacontroller_pytorch-0.0.38}/metacontroller/transformer_with_resnet.py +0 -0
  16. {metacontroller_pytorch-0.0.36 → metacontroller_pytorch-0.0.38}/test_babyai_e2e.sh +0 -0
  17. {metacontroller_pytorch-0.0.36 → metacontroller_pytorch-0.0.38}/train_babyai.py +0 -0

PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: metacontroller-pytorch
-Version: 0.0.36
+Version: 0.0.38
 Summary: Transformer Metacontroller
 Project-URL: Homepage, https://pypi.org/project/metacontroller/
 Project-URL: Repository, https://github.com/lucidrains/metacontroller
@@ -94,6 +94,18 @@ $ pip install metacontroller-pytorch
 }
 ```
 
+```bibtex
+@misc{hwang2025dynamicchunkingendtoendhierarchical,
+    title = {Dynamic Chunking for End-to-End Hierarchical Sequence Modeling},
+    author = {Sukjun Hwang and Brandon Wang and Albert Gu},
+    year = {2025},
+    eprint = {2507.07955},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.LG},
+    url = {https://arxiv.org/abs/2507.07955},
+}
+```
+
 ```bibtex
 @misc{fleuret2025freetransformer,
     title = {The Free Transformer},

README.md

@@ -41,6 +41,18 @@ $ pip install metacontroller-pytorch
 }
 ```
 
+```bibtex
+@misc{hwang2025dynamicchunkingendtoendhierarchical,
+    title = {Dynamic Chunking for End-to-End Hierarchical Sequence Modeling},
+    author = {Sukjun Hwang and Brandon Wang and Albert Gu},
+    year = {2025},
+    eprint = {2507.07955},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.LG},
+    url = {https://arxiv.org/abs/2507.07955},
+}
+```
+
 ```bibtex
 @misc{fleuret2025freetransformer,
     title = {The Free Transformer},

metacontroller/metacontroller.py

@@ -126,6 +126,7 @@ class MetaController(Module):
         )
     ):
         super().__init__()
+        self.dim_model = dim_model
         dim_meta = default(dim_meta_controller, dim_model)
 
         # the linear that brings from model dimension
@@ -171,6 +172,15 @@ class MetaController(Module):
 
         self.register_buffer('zero', tensor(0.), persistent = False)
 
+    @property
+    def replay_buffer_field_dict(self):
+        return dict(
+            states = ('float', self.dim_model),
+            log_probs = ('float', self.dim_latent),
+            switch_betas = ('float', self.dim_latent if self.switch_per_latent_dim else 1),
+            latent_actions = ('float', self.dim_latent)
+        )
+
     def discovery_parameters(self):
         return [
             *self.model_to_meta.parameters(),
@@ -408,6 +418,7 @@ class Transformer(Module):
         meta_controller: Module | None = None,
         cache: TransformerOutput | None = None,
         discovery_phase = False,
+        force_behavior_cloning = False,
         meta_controller_temperature = 1.,
         return_raw_action_dist = False,
         return_latents = False,
@@ -420,11 +431,15 @@
 
         meta_controller = default(meta_controller, self.meta_controller)
 
+        if force_behavior_cloning:
+            assert not discovery_phase, 'discovery phase cannot be set to True if force behavioral cloning is set to True'
+            meta_controller = None
+
         has_meta_controller = exists(meta_controller)
 
         assert not (discovery_phase and not has_meta_controller), 'meta controller must be made available during discovery phase'
 
-        behavioral_cloning = not has_meta_controller and not return_raw_action_dist
+        behavioral_cloning = force_behavior_cloning or (not has_meta_controller and not return_raw_action_dist)
 
         # by default, if meta controller is passed in, transformer is no grad
 
@@ -472,7 +487,7 @@
 
         with meta_controller_context():
 
-            if exists(meta_controller):
+            if exists(meta_controller) and not behavioral_cloning:
                 control_signal, next_meta_hiddens = meta_controller(residual_stream, cache = meta_hiddens, discovery_phase = discovery_phase, temperature = meta_controller_temperature, episode_lens = episode_lens)
             else:
                 control_signal, next_meta_hiddens = self.zero, None
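Taken together, these hunks let a caller force pure behavioral cloning even while a meta controller is attached. A minimal usage sketch under stated assumptions: `model`, `meta_controller`, `state`, `actions`, and `episode_lens` are placeholders prepared as in tests/test_metacontroller.py, and return values are not spelled out here since they differ by mode.

```python
# hedged sketch -- only the keyword arguments shown in the diff above are
# taken from the library; everything else is a placeholder

# force behavioral cloning: the meta controller is dropped for this forward
# pass, so the transformer trains its own action head
bc_out = model(
    state,
    actions,
    meta_controller = meta_controller,
    force_behavior_cloning = True
)

# discovery phase still requires the meta controller; combining it with
# force_behavior_cloning trips the new assert
discovery_out = model(
    state,
    actions,
    meta_controller = meta_controller,
    discovery_phase = True,
    episode_lens = episode_lens
)
```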

metacontroller/metacontroller_with_binary_mapper.py

@@ -74,7 +74,7 @@ class MetaControllerWithBinaryMapper(Module):
         kl_loss_threshold = 0.
     ):
         super().__init__()
-
+        self.dim_model = dim_model
         assert not switch_per_code, 'switch_per_code is not supported for binary mapper'
 
         dim_meta = default(dim_meta_controller, dim_model)
@@ -126,6 +126,15 @@ class MetaControllerWithBinaryMapper(Module):
 
         self.register_buffer('zero', tensor(0.), persistent = False)
 
+    @property
+    def replay_buffer_field_dict(self):
+        return dict(
+            states = ('float', self.dim_model),
+            log_probs = ('float', self.dim_code_bits),
+            switch_betas = ('float', self.num_codes if self.switch_per_code else 1),
+            latent_actions = ('float', self.num_codes)
+        )
+
     def discovery_parameters(self):
         return [
             *self.model_to_meta.parameters(),
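Both MetaController and MetaControllerWithBinaryMapper now expose the same replay_buffer_field_dict property, mapping each replay-buffer field name to a (dtype, dimension) pair derived from the controller's own configuration, so callers no longer hand-build field shapes. A minimal sketch of the intent; the `ReplayBuffer` name below is a hypothetical stand-in (the diff only shows the keyword arguments it receives), and the import path is assumed.

```python
# hedged sketch -- `ReplayBuffer` is a placeholder name and the import path
# is assumed; constructor kwargs mirror those in the test diff further down
from metacontroller import MetaController

meta_controller = MetaController(
    dim_model = 512,
    dim_latent = 128,
    switch_per_latent_dim = False
)

# e.g. states -> ('float', 512), log_probs -> ('float', 128), ...
fields = meta_controller.replay_buffer_field_dict

buffer = ReplayBuffer(                        # hypothetical constructor
    './replay-data',                          # episode folder (positional, as in the test)
    max_episodes = 3,
    max_timesteps = 256,
    fields = fields,
    meta_fields = dict(advantages = 'float')
)
```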

pyproject.toml

@@ -1,6 +1,6 @@
 [project]
 name = "metacontroller-pytorch"
-version = "0.0.36"
+version = "0.0.38"
 description = "Transformer Metacontroller"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

tests/test_metacontroller.py

@@ -69,12 +69,6 @@ def test_metacontroller(
             dim_latent = 128,
             switch_per_latent_dim = switch_per_latent_dim
         )
-
-        field_shapes = dict(
-            log_probs = ('float', 128),
-            switch_betas = ('float', 128 if switch_per_latent_dim else 1),
-            latent_actions = ('float', 128)
-        )
     else:
         meta_controller = MetaControllerWithBinaryMapper(
             dim_model = 512,
@@ -83,12 +77,6 @@ def test_metacontroller(
             dim_code_bits = 8, # 2 ** 8 = 256 codes
         )
 
-        field_shapes = dict(
-            log_probs = ('float', 8),
-            switch_betas = ('float', 8 if switch_per_latent_dim else 1),
-            latent_actions = ('float', 256)
-        )
-
     # discovery phase
 
     (action_recon_loss, kl_loss, switch_loss) = model(state, actions, meta_controller = meta_controller, discovery_phase = True, episode_lens = episode_lens)
@@ -104,10 +92,7 @@
         test_folder,
         max_episodes = 3,
         max_timesteps = 256,
-        fields = dict(
-            states = ('float', 512),
-            **field_shapes
-        ),
+        fields = meta_controller.replay_buffer_field_dict,
         meta_fields = dict(
             advantages = 'float'
         )

train_behavior_clone_babyai.py

@@ -130,8 +130,8 @@ def train(
     for epoch in range(cloning_epochs + discovery_epochs):
 
         model.train()
-        total_state_loss = 0.
-        total_action_loss = 0.
+        from collections import defaultdict
+        total_losses = defaultdict(float)
 
         progress_bar = tqdm(dataloader, desc = f"Epoch {epoch}", disable = not accelerator.is_local_main_process)
@@ -200,9 +200,9 @@ def train(
             optim.zero_grad()
 
             # log
-
-            total_state_loss += state_loss.item()
-            total_action_loss += action_loss.item()
+
+            for key, value in log.items():
+                total_losses[key] += value
 
             accelerator.log({
                 **log,
@@ -210,15 +210,11 @@
                 "grad_norm": grad_norm.item()
             })
 
-            progress_bar.set_postfix(
-                state_loss = state_loss.item(),
-                action_loss = action_loss.item()
-            )
-
-        avg_state_loss = total_state_loss / len(dataloader)
-        avg_action_loss = total_action_loss / len(dataloader)
+            progress_bar.set_postfix(**log)
 
-        accelerator.print(f"Epoch {epoch}: state_loss={avg_state_loss:.4f}, action_loss={avg_action_loss:.4f}")
+        avg_losses = {k: v / len(dataloader) for k, v in total_losses.items()}
+        avg_losses_str = ", ".join([f"{k}={v:.4f}" for k, v in avg_losses.items()])
+        accelerator.print(f"Epoch {epoch}: {avg_losses_str}")
 
     # save weights
 
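The training-script change above swaps the two hard-coded state/action loss accumulators for a defaultdict keyed by whatever the per-step `log` dict contains, so the epoch summary prints every tracked loss. A self-contained sketch of that accumulation pattern, with made-up loss values standing in for the real per-step logs:

```python
from collections import defaultdict

# stand-in per-step log dicts; in the script these come from the forward pass
step_logs = [
    {'state_loss': 0.9, 'action_loss': 1.2},
    {'state_loss': 0.7, 'action_loss': 1.0},
]

total_losses = defaultdict(float)

for log in step_logs:
    for key, value in log.items():
        total_losses[key] += value        # accumulate every logged scalar

avg_losses = {k: v / len(step_logs) for k, v in total_losses.items()}
avg_losses_str = ", ".join(f"{k}={v:.4f}" for k, v in avg_losses.items())
print(f"Epoch 0: {avg_losses_str}")       # Epoch 0: state_loss=0.8000, action_loss=1.1000
```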
@@ -231,7 +227,7 @@ def train(
         unwrapped_meta_controller = accelerator.unwrap_model(meta_controller)
         unwrapped_meta_controller.save(meta_controller_checkpoint_path)
 
-        accelerator.print(f"Model saved to {checkpoint_path}, MetaControler to {meta_controller_checkpoint_path}")
+        accelerator.print(f"Model saved to {checkpoint_path}, MetaController to {meta_controller_checkpoint_path}")
 
     accelerator.end_training()