metacontroller-pytorch 0.0.37__tar.gz → 0.0.38__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of metacontroller-pytorch might be problematic. See the package's registry page for more details.

Files changed (17)
  1. {metacontroller_pytorch-0.0.37 → metacontroller_pytorch-0.0.38}/PKG-INFO +13 -1
  2. {metacontroller_pytorch-0.0.37 → metacontroller_pytorch-0.0.38}/README.md +12 -0
  3. {metacontroller_pytorch-0.0.37 → metacontroller_pytorch-0.0.38}/metacontroller/metacontroller.py +10 -0
  4. {metacontroller_pytorch-0.0.37 → metacontroller_pytorch-0.0.38}/metacontroller/metacontroller_with_binary_mapper.py +10 -1
  5. {metacontroller_pytorch-0.0.37 → metacontroller_pytorch-0.0.38}/pyproject.toml +1 -1
  6. {metacontroller_pytorch-0.0.37 → metacontroller_pytorch-0.0.38}/tests/test_metacontroller.py +1 -16
  7. {metacontroller_pytorch-0.0.37 → metacontroller_pytorch-0.0.38}/.github/workflows/python-publish.yml +0 -0
  8. {metacontroller_pytorch-0.0.37 → metacontroller_pytorch-0.0.38}/.github/workflows/test.yml +0 -0
  9. {metacontroller_pytorch-0.0.37 → metacontroller_pytorch-0.0.38}/.gitignore +0 -0
  10. {metacontroller_pytorch-0.0.37 → metacontroller_pytorch-0.0.38}/LICENSE +0 -0
  11. {metacontroller_pytorch-0.0.37 → metacontroller_pytorch-0.0.38}/fig1.png +0 -0
  12. {metacontroller_pytorch-0.0.37 → metacontroller_pytorch-0.0.38}/gather_babyai_trajs.py +0 -0
  13. {metacontroller_pytorch-0.0.37 → metacontroller_pytorch-0.0.38}/metacontroller/__init__.py +0 -0
  14. {metacontroller_pytorch-0.0.37 → metacontroller_pytorch-0.0.38}/metacontroller/transformer_with_resnet.py +0 -0
  15. {metacontroller_pytorch-0.0.37 → metacontroller_pytorch-0.0.38}/test_babyai_e2e.sh +0 -0
  16. {metacontroller_pytorch-0.0.37 → metacontroller_pytorch-0.0.38}/train_babyai.py +0 -0
  17. {metacontroller_pytorch-0.0.37 → metacontroller_pytorch-0.0.38}/train_behavior_clone_babyai.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: metacontroller-pytorch
3
- Version: 0.0.37
3
+ Version: 0.0.38
4
4
  Summary: Transformer Metacontroller
5
5
  Project-URL: Homepage, https://pypi.org/project/metacontroller/
6
6
  Project-URL: Repository, https://github.com/lucidrains/metacontroller
@@ -94,6 +94,18 @@ $ pip install metacontroller-pytorch
94
94
  }
95
95
  ```
96
96
 
97
+ ```bibtex
98
+ @misc{hwang2025dynamicchunkingendtoendhierarchical,
99
+ title = {Dynamic Chunking for End-to-End Hierarchical Sequence Modeling},
100
+ author = {Sukjun Hwang and Brandon Wang and Albert Gu},
101
+ year = {2025},
102
+ eprint = {2507.07955},
103
+ archivePrefix = {arXiv},
104
+ primaryClass = {cs.LG},
105
+ url = {https://arxiv.org/abs/2507.07955},
106
+ }
107
+ ```
108
+
97
109
  ```bibtex
98
110
  @misc{fleuret2025freetransformer,
99
111
  title = {The Free Transformer},
@@ -41,6 +41,18 @@ $ pip install metacontroller-pytorch
41
41
  }
42
42
  ```
43
43
 
44
+ ```bibtex
45
+ @misc{hwang2025dynamicchunkingendtoendhierarchical,
46
+ title = {Dynamic Chunking for End-to-End Hierarchical Sequence Modeling},
47
+ author = {Sukjun Hwang and Brandon Wang and Albert Gu},
48
+ year = {2025},
49
+ eprint = {2507.07955},
50
+ archivePrefix = {arXiv},
51
+ primaryClass = {cs.LG},
52
+ url = {https://arxiv.org/abs/2507.07955},
53
+ }
54
+ ```
55
+
44
56
  ```bibtex
45
57
  @misc{fleuret2025freetransformer,
46
58
  title = {The Free Transformer},
@@ -126,6 +126,7 @@ class MetaController(Module):
126
126
  )
127
127
  ):
128
128
  super().__init__()
129
+ self.dim_model = dim_model
129
130
  dim_meta = default(dim_meta_controller, dim_model)
130
131
 
131
132
  # the linear that brings from model dimension
@@ -171,6 +172,15 @@ class MetaController(Module):
171
172
 
172
173
  self.register_buffer('zero', tensor(0.), persistent = False)
173
174
 
175
+ @property
176
+ def replay_buffer_field_dict(self):
177
+ return dict(
178
+ states = ('float', self.dim_model),
179
+ log_probs = ('float', self.dim_latent),
180
+ switch_betas = ('float', self.dim_latent if self.switch_per_latent_dim else 1),
181
+ latent_actions = ('float', self.dim_latent)
182
+ )
183
+
174
184
  def discovery_parameters(self):
175
185
  return [
176
186
  *self.model_to_meta.parameters(),
@@ -74,7 +74,7 @@ class MetaControllerWithBinaryMapper(Module):
74
74
  kl_loss_threshold = 0.
75
75
  ):
76
76
  super().__init__()
77
-
77
+ self.dim_model = dim_model
78
78
  assert not switch_per_code, 'switch_per_code is not supported for binary mapper'
79
79
 
80
80
  dim_meta = default(dim_meta_controller, dim_model)
@@ -126,6 +126,15 @@ class MetaControllerWithBinaryMapper(Module):
126
126
 
127
127
  self.register_buffer('zero', tensor(0.), persistent = False)
128
128
 
129
+ @property
130
+ def replay_buffer_field_dict(self):
131
+ return dict(
132
+ states = ('float', self.dim_model),
133
+ log_probs = ('float', self.dim_code_bits),
134
+ switch_betas = ('float', self.num_codes if self.switch_per_code else 1),
135
+ latent_actions = ('float', self.num_codes)
136
+ )
137
+
129
138
  def discovery_parameters(self):
130
139
  return [
131
140
  *self.model_to_meta.parameters(),
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "metacontroller-pytorch"
3
- version = "0.0.37"
3
+ version = "0.0.38"
4
4
  description = "Transformer Metacontroller"
5
5
  authors = [
6
6
  { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -69,12 +69,6 @@ def test_metacontroller(
69
69
  dim_latent = 128,
70
70
  switch_per_latent_dim = switch_per_latent_dim
71
71
  )
72
-
73
- field_shapes = dict(
74
- log_probs = ('float', 128),
75
- switch_betas = ('float', 128 if switch_per_latent_dim else 1),
76
- latent_actions = ('float', 128)
77
- )
78
72
  else:
79
73
  meta_controller = MetaControllerWithBinaryMapper(
80
74
  dim_model = 512,
@@ -83,12 +77,6 @@ def test_metacontroller(
83
77
  dim_code_bits = 8, # 2 ** 8 = 256 codes
84
78
  )
85
79
 
86
- field_shapes = dict(
87
- log_probs = ('float', 8),
88
- switch_betas = ('float', 8 if switch_per_latent_dim else 1),
89
- latent_actions = ('float', 256)
90
- )
91
-
92
80
  # discovery phase
93
81
 
94
82
  (action_recon_loss, kl_loss, switch_loss) = model(state, actions, meta_controller = meta_controller, discovery_phase = True, episode_lens = episode_lens)
@@ -104,10 +92,7 @@ def test_metacontroller(
104
92
  test_folder,
105
93
  max_episodes = 3,
106
94
  max_timesteps = 256,
107
- fields = dict(
108
- states = ('float', 512),
109
- **field_shapes
110
- ),
95
+ fields = meta_controller.replay_buffer_field_dict,
111
96
  meta_fields = dict(
112
97
  advantages = 'float'
113
98
  )