metacontroller-pytorch 0.0.36__py3-none-any.whl → 0.0.38__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of metacontroller-pytorch might be problematic; consult the package registry's advisory page for more details.

@@ -126,6 +126,7 @@ class MetaController(Module):
126
126
  )
127
127
  ):
128
128
  super().__init__()
129
+ self.dim_model = dim_model
129
130
  dim_meta = default(dim_meta_controller, dim_model)
130
131
 
131
132
  # the linear that brings from model dimension
@@ -171,6 +172,15 @@ class MetaController(Module):
171
172
 
172
173
  self.register_buffer('zero', tensor(0.), persistent = False)
173
174
 
175
+ @property
176
+ def replay_buffer_field_dict(self):
177
+ return dict(
178
+ states = ('float', self.dim_model),
179
+ log_probs = ('float', self.dim_latent),
180
+ switch_betas = ('float', self.dim_latent if self.switch_per_latent_dim else 1),
181
+ latent_actions = ('float', self.dim_latent)
182
+ )
183
+
174
184
  def discovery_parameters(self):
175
185
  return [
176
186
  *self.model_to_meta.parameters(),
@@ -408,6 +418,7 @@ class Transformer(Module):
408
418
  meta_controller: Module | None = None,
409
419
  cache: TransformerOutput | None = None,
410
420
  discovery_phase = False,
421
+ force_behavior_cloning = False,
411
422
  meta_controller_temperature = 1.,
412
423
  return_raw_action_dist = False,
413
424
  return_latents = False,
@@ -420,11 +431,15 @@ class Transformer(Module):
420
431
 
421
432
  meta_controller = default(meta_controller, self.meta_controller)
422
433
 
434
+ if force_behavior_cloning:
435
+ assert not discovery_phase, 'discovery phase cannot be set to True if force behavioral cloning is set to True'
436
+ meta_controller = None
437
+
423
438
  has_meta_controller = exists(meta_controller)
424
439
 
425
440
  assert not (discovery_phase and not has_meta_controller), 'meta controller must be made available during discovery phase'
426
441
 
427
- behavioral_cloning = not has_meta_controller and not return_raw_action_dist
442
+ behavioral_cloning = force_behavior_cloning or (not has_meta_controller and not return_raw_action_dist)
428
443
 
429
444
  # by default, if meta controller is passed in, transformer is no grad
430
445
 
@@ -472,7 +487,7 @@ class Transformer(Module):
472
487
 
473
488
  with meta_controller_context():
474
489
 
475
- if exists(meta_controller):
490
+ if exists(meta_controller) and not behavioral_cloning:
476
491
  control_signal, next_meta_hiddens = meta_controller(residual_stream, cache = meta_hiddens, discovery_phase = discovery_phase, temperature = meta_controller_temperature, episode_lens = episode_lens)
477
492
  else:
478
493
  control_signal, next_meta_hiddens = self.zero, None
@@ -74,7 +74,7 @@ class MetaControllerWithBinaryMapper(Module):
74
74
  kl_loss_threshold = 0.
75
75
  ):
76
76
  super().__init__()
77
-
77
+ self.dim_model = dim_model
78
78
  assert not switch_per_code, 'switch_per_code is not supported for binary mapper'
79
79
 
80
80
  dim_meta = default(dim_meta_controller, dim_model)
@@ -126,6 +126,15 @@ class MetaControllerWithBinaryMapper(Module):
126
126
 
127
127
  self.register_buffer('zero', tensor(0.), persistent = False)
128
128
 
129
+ @property
130
+ def replay_buffer_field_dict(self):
131
+ return dict(
132
+ states = ('float', self.dim_model),
133
+ log_probs = ('float', self.dim_code_bits),
134
+ switch_betas = ('float', self.num_codes if self.switch_per_code else 1),
135
+ latent_actions = ('float', self.num_codes)
136
+ )
137
+
129
138
  def discovery_parameters(self):
130
139
  return [
131
140
  *self.model_to_meta.parameters(),
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: metacontroller-pytorch
3
- Version: 0.0.36
3
+ Version: 0.0.38
4
4
  Summary: Transformer Metacontroller
5
5
  Project-URL: Homepage, https://pypi.org/project/metacontroller/
6
6
  Project-URL: Repository, https://github.com/lucidrains/metacontroller
@@ -94,6 +94,18 @@ $ pip install metacontroller-pytorch
94
94
  }
95
95
  ```
96
96
 
97
+ ```bibtex
98
+ @misc{hwang2025dynamicchunkingendtoendhierarchical,
99
+ title = {Dynamic Chunking for End-to-End Hierarchical Sequence Modeling},
100
+ author = {Sukjun Hwang and Brandon Wang and Albert Gu},
101
+ year = {2025},
102
+ eprint = {2507.07955},
103
+ archivePrefix = {arXiv},
104
+ primaryClass = {cs.LG},
105
+ url = {https://arxiv.org/abs/2507.07955},
106
+ }
107
+ ```
108
+
97
109
  ```bibtex
98
110
  @misc{fleuret2025freetransformer,
99
111
  title = {The Free Transformer},
@@ -0,0 +1,8 @@
1
+ metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
2
+ metacontroller/metacontroller.py,sha256=bhgCqqM-dfysGrMtZYe2w87lRVkf8fETjxUCdjrnI8Q,17386
3
+ metacontroller/metacontroller_with_binary_mapper.py,sha256=Ce5-O95_pLuWNA3aZTlKrTGbc5cemb61tBtJBdSiLx4,9843
4
+ metacontroller/transformer_with_resnet.py,sha256=R49ycusbq3kEX97WHZ41WY2ONc2mYPOuRUCmaFcBOEo,5546
5
+ metacontroller_pytorch-0.0.38.dist-info/METADATA,sha256=xFax5wjB3faXWgHNeHoRG1uOxl2S3-xeQuqJvrY3YTo,5114
6
+ metacontroller_pytorch-0.0.38.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
7
+ metacontroller_pytorch-0.0.38.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
+ metacontroller_pytorch-0.0.38.dist-info/RECORD,,
@@ -1,8 +0,0 @@
1
- metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
2
- metacontroller/metacontroller.py,sha256=rVYzBJ8jQx9tfkZ3B9NdxKTI7dyBxXtTl4kwfizYuis,16728
3
- metacontroller/metacontroller_with_binary_mapper.py,sha256=odZs49ZWY7_FfEweYkD0moX7Vn0jGd91FjFTxzjLyr8,9480
4
- metacontroller/transformer_with_resnet.py,sha256=R49ycusbq3kEX97WHZ41WY2ONc2mYPOuRUCmaFcBOEo,5546
5
- metacontroller_pytorch-0.0.36.dist-info/METADATA,sha256=eLKG8B0gSZyIMkaLvjYE8SvWVN387BuuNFoOC_6lmT4,4747
6
- metacontroller_pytorch-0.0.36.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
7
- metacontroller_pytorch-0.0.36.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
- metacontroller_pytorch-0.0.36.dist-info/RECORD,,