metacontroller-pytorch 0.0.18__py3-none-any.whl → 0.0.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
metacontroller/metacontroller.py

@@ -18,7 +18,7 @@ from einops.layers.torch import Rearrange

  # external modules

- from x_transformers import Decoder
+ from x_transformers import Encoder, Decoder
  from x_mlps_pytorch import Feedforwards
  from x_evolution import EvoStrategy

@@ -72,7 +72,11 @@ class MetaController(Module):
  decoder_expansion_factor = 2.,
  decoder_depth = 1,
  hypernetwork_low_rank = 16,
- assoc_scan_kwargs: dict = dict()
+ assoc_scan_kwargs: dict = dict(),
+ bidirectional_temporal_encoder_kwargs: dict = dict(
+     attn_dim_head = 32,
+     heads = 8
+ )
  ):
  super().__init__()
  dim_meta = default(dim_meta_controller, dim_model)
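
The new `bidirectional_temporal_encoder_kwargs` dict is splatted into an x_transformers `Encoder` alongside `dim` and `depth`, as the next hunk shows. A runnable sketch of just that consumption (the dim value is illustrative):

```python
# how the new kwargs dict is consumed (mirrors the hunk below); dim is illustrative
from x_transformers import Encoder

bidirectional_temporal_encoder_kwargs = dict(attn_dim_head = 32, heads = 8)
encoder = Encoder(dim = 64, depth = 1, **bidirectional_temporal_encoder_kwargs)
```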
@@ -81,9 +85,9 @@ class MetaController(Module):

  self.model_to_meta = Linear(dim_model, dim_meta)

- # there are two phases, the first (discovery ssl phase) uses acausal with some ssm i don't really believe in - let's just use a bidirectional GRU as placeholders
+ # there are two phases, the first (discovery ssl phase) uses acausal with some ssm i don't really believe in - let's just use bidirectional attention as placeholder

- self.bidirectional_temporal_compressor = GRU(dim_meta, dim_meta, bidirectional = True) # revisit naming
+ self.bidirectional_temporal_encoder = Encoder(dim = dim_meta, depth = 1, **bidirectional_temporal_encoder_kwargs)

  self.emitter = GRU(dim_meta * 2, dim_meta * 2)
  self.emitter_to_action_mean_log_var = Readout(dim_meta * 2, num_continuous = dim_latent)
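
Swapping the bidirectional GRU for full self-attention also changes the output shape, which is why the forward-pass hunk further down drops a mean-reduce: a bidirectional `nn.GRU` concatenates forward and backward states, doubling the feature dimension, while an `Encoder` preserves it. A shape-level sketch, not the package's code:

```python
import torch
from torch.nn import GRU
from x_transformers import Encoder

x = torch.randn(2, 16, 64)  # (batch, seq, dim_meta)

# old placeholder: bidirectional GRU concatenates both directions -> 2 * dim,
# which the old forward pass mean-reduced with '... (two d) -> ... d'
gru = GRU(64, 64, bidirectional = True, batch_first = True)
out, _ = gru(x)
assert out.shape == (2, 16, 128)

# new placeholder: bidirectional attention keeps the feature dimension as-is
enc = Encoder(dim = 64, depth = 1, attn_dim_head = 32, heads = 8)
assert enc(x).shape == (2, 16, 64)
```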
@@ -122,7 +126,7 @@ class MetaController(Module):
  def discovery_parameters(self):
      return [
          *self.model_to_meta.parameters(),
-         *self.bidirectional_temporal_compressor.parameters(),
+         *self.bidirectional_temporal_encoder.parameters(),
          *self.emitter.parameters(),
          *self.emitter_to_action_mean_log_var.parameters(),
          *self.decoder.parameters(),
@@ -157,10 +161,9 @@ class MetaController(Module):
  if discovery_phase:
      logger.warning('meta controller cache being passed back in for discovery phase, which does not make sense given bidirectional encoder')

-     temporal_compressed, _ = self.bidirectional_temporal_compressor(meta_embed)
-     temporal_compressed = reduce(temporal_compressed, '... (two d) -> ... d', 'mean', two = 2)
+     encoded_temporal = self.bidirectional_temporal_encoder(meta_embed)

-     proposed_action_hidden, _ = self.emitter(cat((temporal_compressed, meta_embed), dim = -1))
+     proposed_action_hidden, _ = self.emitter(cat((encoded_temporal, meta_embed), dim = -1))
      readout = self.emitter_to_action_mean_log_var

  else: # else internal rl phase
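
Both the old mean-reduced GRU summary and the new encoder output are `dim_meta` wide, so the emitter GRU's input size (`dim_meta * 2` after concatenation with `meta_embed`) is unchanged; only the reduce step disappears. A standalone sketch of the rewritten discovery branch with illustrative dimensions (module names mirror the diff, not imports from the package):

```python
import torch
from torch.nn import GRU
from x_transformers import Encoder

dim_meta = 64
bidirectional_temporal_encoder = Encoder(dim = dim_meta, depth = 1, attn_dim_head = 32, heads = 8)
emitter = GRU(dim_meta * 2, dim_meta * 2, batch_first = True)

meta_embed = torch.randn(2, 16, dim_meta)  # (batch, seq, dim_meta)

encoded_temporal = bidirectional_temporal_encoder(meta_embed)        # (2, 16, 64)
emitter_input = torch.cat((encoded_temporal, meta_embed), dim = -1)  # (2, 16, 128)
proposed_action_hidden, _ = emitter(emitter_input)                   # (2, 16, 128)
```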
@@ -256,7 +259,7 @@ class Transformer(Module):
  super().__init__()

  if isinstance(lower_body, dict):
-     lower_body = Decoder(dim = dim, **lower_body)
+     lower_body = Decoder(dim = dim, pre_norm_has_final_norm = False, **lower_body)

  if isinstance(upper_body, dict):
      upper_body = Decoder(dim = dim, **upper_body)
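
`pre_norm_has_final_norm` is an existing x_transformers flag controlling the trailing `LayerNorm` that a pre-norm stack otherwise ends with. A plausible reading of disabling it on the lower body only (an assumption, the diff does not state the motivation) is to avoid normalizing twice at the seam, since the upper body's pre-norm layers normalize their input anyway:

```python
# sketch of the lower/upper seam under the double-norm-avoidance reading above
from x_transformers import Decoder

dim = 512  # illustrative
lower_body = Decoder(dim = dim, depth = 2, pre_norm_has_final_norm = False)  # no trailing LayerNorm
upper_body = Decoder(dim = dim, depth = 2)                                   # keeps its final norm
```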
metacontroller_pytorch-0.0.20.dist-info/METADATA

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: metacontroller-pytorch
- Version: 0.0.18
+ Version: 0.0.20
  Summary: Transformer Metacontroller
  Project-URL: Homepage, https://pypi.org/project/metacontroller/
  Project-URL: Repository, https://github.com/lucidrains/metacontroller
@@ -78,3 +78,15 @@ Implementation of the MetaController proposed in [Emergent temporal abstractions
      url = {https://api.semanticscholar.org/CorpusID:279464702}
  }
  ```
+
+ ```bibtex
+ @misc{fleuret2025freetransformer,
+     title = {The Free Transformer},
+     author = {François Fleuret},
+     year = {2025},
+     eprint = {2510.17558},
+     archivePrefix = {arXiv},
+     primaryClass = {cs.LG},
+     url = {https://arxiv.org/abs/2510.17558},
+ }
+ ```
metacontroller_pytorch-0.0.20.dist-info/RECORD

@@ -0,0 +1,6 @@
+ metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
+ metacontroller/metacontroller.py,sha256=3QZrId9z8I6MMQ3GhEQ6Xb5LFRTFJq4EAU4JCvRmm-4,12368
+ metacontroller_pytorch-0.0.20.dist-info/METADATA,sha256=5t4rDJiJzbx7m9BNsTTgO5JOnavaX-3jv31HTGuLP6A,4034
+ metacontroller_pytorch-0.0.20.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+ metacontroller_pytorch-0.0.20.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+ metacontroller_pytorch-0.0.20.dist-info/RECORD,,
metacontroller_pytorch-0.0.18.dist-info/RECORD

@@ -1,6 +0,0 @@
- metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
- metacontroller/metacontroller.py,sha256=xrsP8YyYFZ_Z4rZx0BiYJZT2Q3zzXZppZnPKZfg-mtg,12306
- metacontroller_pytorch-0.0.18.dist-info/METADATA,sha256=fwk9OgNmoPBbZYsezBTerlcAd6iPMpK2zWXFPDlhrs4,3741
- metacontroller_pytorch-0.0.18.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
- metacontroller_pytorch-0.0.18.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
- metacontroller_pytorch-0.0.18.dist-info/RECORD,,