metacontroller-pytorch 0.0.40__py3-none-any.whl → 0.0.42__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1 +1 @@
1
- from metacontroller.metacontroller import MetaController
1
+ from metacontroller.metacontroller import MetaController, Transformer
@@ -66,6 +66,13 @@ MetaControllerOutput = namedtuple('MetaControllerOutput', (
66
66
  'switch_loss'
67
67
  ))
68
68
 
69
+ GRPOOutput = namedtuple('GRPOOutput', (
70
+ 'state',
71
+ 'action',
72
+ 'log_prob',
73
+ 'switch_beta'
74
+ ))
75
+
69
76
  def z_score(t, eps = 1e-8):
70
77
  return (t - t.mean()) / (t.std() + eps)
71
78
 
@@ -107,6 +114,17 @@ def policy_loss(
107
114
 
108
115
  return masked_mean(losses, mask)
109
116
 
117
+ def extract_grpo_data(meta_controller, transformer_output):
118
+ meta_output = transformer_output.prev_hiddens.meta_controller
119
+
120
+ state = meta_output.input_residual_stream
121
+ action = meta_output.actions
122
+ switch_beta = meta_output.switch_beta
123
+
124
+ log_prob = meta_controller.log_prob(meta_output.action_dist, action)
125
+
126
+ return GRPOOutput(state, action, log_prob, switch_beta)
127
+
110
128
  @save_load()
111
129
  class MetaController(Module):
112
130
  def __init__(
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: metacontroller-pytorch
3
- Version: 0.0.40
3
+ Version: 0.0.42
4
4
  Summary: Transformer Metacontroller
5
5
  Project-URL: Homepage, https://pypi.org/project/metacontroller/
6
6
  Project-URL: Repository, https://github.com/lucidrains/metacontroller
@@ -53,7 +53,7 @@ Description-Content-Type: text/markdown
53
53
 
54
54
  <img src="./fig1.png" width="400px"></img>
55
55
 
56
- ## metacontroller (wip)
56
+ ## metacontroller
57
57
 
58
58
  Implementation of the MetaController proposed in [Emergent temporal abstractions in autoregressive models enable hierarchical reinforcement learning](https://arxiv.org/abs/2512.20605)
59
59
 
@@ -69,6 +69,91 @@ $ pip install metacontroller-pytorch
69
69
 
70
70
  - [Diego Calanzone](https://github.com/ddidacus) for proposing testing on BabyAI gridworld task, and submitting the [pull request](https://github.com/lucidrains/metacontroller/pull/3) for behavior cloning and discovery phase training for it!
71
71
 
72
+ ## Usage
73
+
74
+ ```python
75
+ import torch
76
+ from metacontroller import Transformer, MetaController
77
+
78
+ # 1. initialize model
79
+
80
+ model = Transformer(
81
+ dim = 512,
82
+ action_embed_readout = dict(num_discrete = 4),
83
+ state_embed_readout = dict(num_continuous = 384),
84
+ lower_body = dict(depth = 2),
85
+ upper_body = dict(depth = 2)
86
+ )
87
+
88
+ state = torch.randn(2, 128, 384)
89
+ actions = torch.randint(0, 4, (2, 128))
90
+
91
+ # 2. behavioral cloning (BC)
92
+
93
+ state_loss, action_loss = model(state, actions)
94
+ (state_loss + action_loss).backward()
95
+
96
+ # 3. discovery phase
97
+
98
+ meta_controller = MetaController(
99
+ dim_model = 512,
100
+ dim_meta_controller = 256,
101
+ dim_latent = 128
102
+ )
103
+
104
+ action_recon_loss, kl_loss, switch_loss = model(
105
+ state,
106
+ actions,
107
+ meta_controller = meta_controller,
108
+ discovery_phase = True
109
+ )
110
+
111
+ (action_recon_loss + kl_loss + switch_loss).backward()
112
+
113
+ # 4. internal rl phase (GRPO)
114
+
115
+ # ... collect trajectories ...
116
+
117
+ logits, cache = model(
118
+ one_state,
119
+ past_action_id,
120
+ meta_controller = meta_controller,
121
+ return_cache = True
122
+ )
123
+
124
+ meta_output = cache.prev_hiddens.meta_controller
125
+ old_log_probs = meta_controller.log_prob(meta_output.action_dist, meta_output.actions)
126
+
127
+ # ... calculate advantages ...
128
+
129
+ loss = meta_controller.policy_loss(
130
+ group_states,
131
+ group_old_log_probs,
132
+ group_latent_actions,
133
+ group_advantages,
134
+ group_switch_betas
135
+ )
136
+
137
+ loss.backward()
138
+ ```
139
+
140
+ Alternatively, use [evolutionary strategies](https://arxiv.org/abs/2511.16652) in place of GRPO for the last portion (the internal RL phase):
141
+
142
+ ```python
143
+ # 5. evolve (ES in place of GRPO)
144
+
145
+ model.meta_controller = meta_controller
146
+
147
+ def environment_callable(model):
148
+ # return a fitness score
149
+ return 1.0
150
+
151
+ model.evolve(
152
+ num_generations = 10,
153
+ environment = environment_callable
154
+ )
155
+ ```
156
+
72
157
  ## Citations
73
158
 
74
159
  ```bibtex
@@ -0,0 +1,8 @@
1
+ metacontroller/__init__.py,sha256=iSKbCDp3UrWhZg7SIJFYNjdVQU56u-vqZarE6qCSX74,70
2
+ metacontroller/metacontroller.py,sha256=hOzMIeBwNZhrzpt6tnLahxuHJ4pPQ7JlEGBOxYHI_88,17875
3
+ metacontroller/metacontroller_with_binary_mapper.py,sha256=Ce5-O95_pLuWNA3aZTlKrTGbc5cemb61tBtJBdSiLx4,9843
4
+ metacontroller/transformer_with_resnet.py,sha256=R49ycusbq3kEX97WHZ41WY2ONc2mYPOuRUCmaFcBOEo,5546
5
+ metacontroller_pytorch-0.0.42.dist-info/METADATA,sha256=f9KRrtFWHgZrx5HZYBGNrtfXrcfSOeZlRFfx7VYMOd0,6816
6
+ metacontroller_pytorch-0.0.42.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
7
+ metacontroller_pytorch-0.0.42.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
+ metacontroller_pytorch-0.0.42.dist-info/RECORD,,
@@ -1,8 +0,0 @@
1
- metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
2
- metacontroller/metacontroller.py,sha256=bhgCqqM-dfysGrMtZYe2w87lRVkf8fETjxUCdjrnI8Q,17386
3
- metacontroller/metacontroller_with_binary_mapper.py,sha256=Ce5-O95_pLuWNA3aZTlKrTGbc5cemb61tBtJBdSiLx4,9843
4
- metacontroller/transformer_with_resnet.py,sha256=R49ycusbq3kEX97WHZ41WY2ONc2mYPOuRUCmaFcBOEo,5546
5
- metacontroller_pytorch-0.0.40.dist-info/METADATA,sha256=0mvv1dI8pe3mNRLb8SJ3ZRWmBDvLgmz2wgI8j04xmQI,5114
6
- metacontroller_pytorch-0.0.40.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
7
- metacontroller_pytorch-0.0.40.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
- metacontroller_pytorch-0.0.40.dist-info/RECORD,,