metacontroller-pytorch 0.0.36__py3-none-any.whl → 0.0.38__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of metacontroller-pytorch might be problematic. Click here for more details.
- metacontroller/metacontroller.py +17 -2
- metacontroller/metacontroller_with_binary_mapper.py +10 -1
- {metacontroller_pytorch-0.0.36.dist-info → metacontroller_pytorch-0.0.38.dist-info}/METADATA +13 -1
- metacontroller_pytorch-0.0.38.dist-info/RECORD +8 -0
- metacontroller_pytorch-0.0.36.dist-info/RECORD +0 -8
- {metacontroller_pytorch-0.0.36.dist-info → metacontroller_pytorch-0.0.38.dist-info}/WHEEL +0 -0
- {metacontroller_pytorch-0.0.36.dist-info → metacontroller_pytorch-0.0.38.dist-info}/licenses/LICENSE +0 -0
metacontroller/metacontroller.py
CHANGED
|
@@ -126,6 +126,7 @@ class MetaController(Module):
|
|
|
126
126
|
)
|
|
127
127
|
):
|
|
128
128
|
super().__init__()
|
|
129
|
+
self.dim_model = dim_model
|
|
129
130
|
dim_meta = default(dim_meta_controller, dim_model)
|
|
130
131
|
|
|
131
132
|
# the linear that brings from model dimension
|
|
@@ -171,6 +172,15 @@ class MetaController(Module):
|
|
|
171
172
|
|
|
172
173
|
self.register_buffer('zero', tensor(0.), persistent = False)
|
|
173
174
|
|
|
175
|
+
@property
|
|
176
|
+
def replay_buffer_field_dict(self):
|
|
177
|
+
return dict(
|
|
178
|
+
states = ('float', self.dim_model),
|
|
179
|
+
log_probs = ('float', self.dim_latent),
|
|
180
|
+
switch_betas = ('float', self.dim_latent if self.switch_per_latent_dim else 1),
|
|
181
|
+
latent_actions = ('float', self.dim_latent)
|
|
182
|
+
)
|
|
183
|
+
|
|
174
184
|
def discovery_parameters(self):
|
|
175
185
|
return [
|
|
176
186
|
*self.model_to_meta.parameters(),
|
|
@@ -408,6 +418,7 @@ class Transformer(Module):
|
|
|
408
418
|
meta_controller: Module | None = None,
|
|
409
419
|
cache: TransformerOutput | None = None,
|
|
410
420
|
discovery_phase = False,
|
|
421
|
+
force_behavior_cloning = False,
|
|
411
422
|
meta_controller_temperature = 1.,
|
|
412
423
|
return_raw_action_dist = False,
|
|
413
424
|
return_latents = False,
|
|
@@ -420,11 +431,15 @@ class Transformer(Module):
|
|
|
420
431
|
|
|
421
432
|
meta_controller = default(meta_controller, self.meta_controller)
|
|
422
433
|
|
|
434
|
+
if force_behavior_cloning:
|
|
435
|
+
assert not discovery_phase, 'discovery phase cannot be set to True if force behavioral cloning is set to True'
|
|
436
|
+
meta_controller = None
|
|
437
|
+
|
|
423
438
|
has_meta_controller = exists(meta_controller)
|
|
424
439
|
|
|
425
440
|
assert not (discovery_phase and not has_meta_controller), 'meta controller must be made available during discovery phase'
|
|
426
441
|
|
|
427
|
-
behavioral_cloning = not has_meta_controller and not return_raw_action_dist
|
|
442
|
+
behavioral_cloning = force_behavior_cloning or (not has_meta_controller and not return_raw_action_dist)
|
|
428
443
|
|
|
429
444
|
# by default, if meta controller is passed in, transformer is no grad
|
|
430
445
|
|
|
@@ -472,7 +487,7 @@ class Transformer(Module):
|
|
|
472
487
|
|
|
473
488
|
with meta_controller_context():
|
|
474
489
|
|
|
475
|
-
if exists(meta_controller):
|
|
490
|
+
if exists(meta_controller) and not behavioral_cloning:
|
|
476
491
|
control_signal, next_meta_hiddens = meta_controller(residual_stream, cache = meta_hiddens, discovery_phase = discovery_phase, temperature = meta_controller_temperature, episode_lens = episode_lens)
|
|
477
492
|
else:
|
|
478
493
|
control_signal, next_meta_hiddens = self.zero, None
|
|
@@ -74,7 +74,7 @@ class MetaControllerWithBinaryMapper(Module):
|
|
|
74
74
|
kl_loss_threshold = 0.
|
|
75
75
|
):
|
|
76
76
|
super().__init__()
|
|
77
|
-
|
|
77
|
+
self.dim_model = dim_model
|
|
78
78
|
assert not switch_per_code, 'switch_per_code is not supported for binary mapper'
|
|
79
79
|
|
|
80
80
|
dim_meta = default(dim_meta_controller, dim_model)
|
|
@@ -126,6 +126,15 @@ class MetaControllerWithBinaryMapper(Module):
|
|
|
126
126
|
|
|
127
127
|
self.register_buffer('zero', tensor(0.), persistent = False)
|
|
128
128
|
|
|
129
|
+
@property
|
|
130
|
+
def replay_buffer_field_dict(self):
|
|
131
|
+
return dict(
|
|
132
|
+
states = ('float', self.dim_model),
|
|
133
|
+
log_probs = ('float', self.dim_code_bits),
|
|
134
|
+
switch_betas = ('float', self.num_codes if self.switch_per_code else 1),
|
|
135
|
+
latent_actions = ('float', self.num_codes)
|
|
136
|
+
)
|
|
137
|
+
|
|
129
138
|
def discovery_parameters(self):
|
|
130
139
|
return [
|
|
131
140
|
*self.model_to_meta.parameters(),
|
{metacontroller_pytorch-0.0.36.dist-info → metacontroller_pytorch-0.0.38.dist-info}/METADATA
RENAMED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: metacontroller-pytorch
|
|
3
|
-
Version: 0.0.36
|
|
3
|
+
Version: 0.0.38
|
|
4
4
|
Summary: Transformer Metacontroller
|
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/metacontroller/
|
|
6
6
|
Project-URL: Repository, https://github.com/lucidrains/metacontroller
|
|
@@ -94,6 +94,18 @@ $ pip install metacontroller-pytorch
|
|
|
94
94
|
}
|
|
95
95
|
```
|
|
96
96
|
|
|
97
|
+
```bibtex
|
|
98
|
+
@misc{hwang2025dynamicchunkingendtoendhierarchical,
|
|
99
|
+
title = {Dynamic Chunking for End-to-End Hierarchical Sequence Modeling},
|
|
100
|
+
author = {Sukjun Hwang and Brandon Wang and Albert Gu},
|
|
101
|
+
year = {2025},
|
|
102
|
+
eprint = {2507.07955},
|
|
103
|
+
archivePrefix = {arXiv},
|
|
104
|
+
primaryClass = {cs.LG},
|
|
105
|
+
url = {https://arxiv.org/abs/2507.07955},
|
|
106
|
+
}
|
|
107
|
+
```
|
|
108
|
+
|
|
97
109
|
```bibtex
|
|
98
110
|
@misc{fleuret2025freetransformer,
|
|
99
111
|
title = {The Free Transformer},
|
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
|
|
2
|
+
metacontroller/metacontroller.py,sha256=bhgCqqM-dfysGrMtZYe2w87lRVkf8fETjxUCdjrnI8Q,17386
|
|
3
|
+
metacontroller/metacontroller_with_binary_mapper.py,sha256=Ce5-O95_pLuWNA3aZTlKrTGbc5cemb61tBtJBdSiLx4,9843
|
|
4
|
+
metacontroller/transformer_with_resnet.py,sha256=R49ycusbq3kEX97WHZ41WY2ONc2mYPOuRUCmaFcBOEo,5546
|
|
5
|
+
metacontroller_pytorch-0.0.38.dist-info/METADATA,sha256=xFax5wjB3faXWgHNeHoRG1uOxl2S3-xeQuqJvrY3YTo,5114
|
|
6
|
+
metacontroller_pytorch-0.0.38.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
7
|
+
metacontroller_pytorch-0.0.38.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
8
|
+
metacontroller_pytorch-0.0.38.dist-info/RECORD,,
|
|
@@ -1,8 +0,0 @@
|
|
|
1
|
-
metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
|
|
2
|
-
metacontroller/metacontroller.py,sha256=rVYzBJ8jQx9tfkZ3B9NdxKTI7dyBxXtTl4kwfizYuis,16728
|
|
3
|
-
metacontroller/metacontroller_with_binary_mapper.py,sha256=odZs49ZWY7_FfEweYkD0moX7Vn0jGd91FjFTxzjLyr8,9480
|
|
4
|
-
metacontroller/transformer_with_resnet.py,sha256=R49ycusbq3kEX97WHZ41WY2ONc2mYPOuRUCmaFcBOEo,5546
|
|
5
|
-
metacontroller_pytorch-0.0.36.dist-info/METADATA,sha256=eLKG8B0gSZyIMkaLvjYE8SvWVN387BuuNFoOC_6lmT4,4747
|
|
6
|
-
metacontroller_pytorch-0.0.36.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
7
|
-
metacontroller_pytorch-0.0.36.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
8
|
-
metacontroller_pytorch-0.0.36.dist-info/RECORD,,
|
|
File without changes
|
{metacontroller_pytorch-0.0.36.dist-info → metacontroller_pytorch-0.0.38.dist-info}/licenses/LICENSE
RENAMED
|
File without changes
|