metacontroller-pytorch 0.0.41__tar.gz → 0.0.42__tar.gz

This diff compares the contents of two publicly released versions of the package as they appear in the public registry, and is provided for informational purposes only.
Files changed (18)
  1. {metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.42}/PKG-INFO +2 -2
  2. {metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.42}/README.md +1 -1
  3. {metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.42}/metacontroller/metacontroller.py +18 -0
  4. {metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.42}/pyproject.toml +1 -1
  5. {metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.42}/tests/test_metacontroller.py +7 -15
  6. {metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.42}/train_babyai.py +6 -7
  7. {metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.42}/.github/workflows/python-publish.yml +0 -0
  8. {metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.42}/.github/workflows/test.yml +0 -0
  9. {metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.42}/.gitignore +0 -0
  10. {metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.42}/LICENSE +0 -0
  11. {metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.42}/babyai_env.py +0 -0
  12. {metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.42}/fig1.png +0 -0
  13. {metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.42}/gather_babyai_trajs.py +0 -0
  14. {metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.42}/metacontroller/__init__.py +0 -0
  15. {metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.42}/metacontroller/metacontroller_with_binary_mapper.py +0 -0
  16. {metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.42}/metacontroller/transformer_with_resnet.py +0 -0
  17. {metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.42}/test_babyai_e2e.sh +0 -0
  18. {metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.42}/train_behavior_clone_babyai.py +0 -0
{metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.42}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: metacontroller-pytorch
- Version: 0.0.41
+ Version: 0.0.42
  Summary: Transformer Metacontroller
  Project-URL: Homepage, https://pypi.org/project/metacontroller/
  Project-URL: Repository, https://github.com/lucidrains/metacontroller
@@ -53,7 +53,7 @@ Description-Content-Type: text/markdown
 
  <img src="./fig1.png" width="400px"></img>
 
- ## metacontroller (wip)
+ ## metacontroller
 
  Implementation of the MetaController proposed in [Emergent temporal abstractions in autoregressive models enable hierarchical reinforcement learning](https://arxiv.org/abs/2512.20605)
 
{metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.42}/README.md
@@ -1,6 +1,6 @@
  <img src="./fig1.png" width="400px"></img>
 
- ## metacontroller (wip)
+ ## metacontroller
 
  Implementation of the MetaController proposed in [Emergent temporal abstractions in autoregressive models enable hierarchical reinforcement learning](https://arxiv.org/abs/2512.20605)
 
{metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.42}/metacontroller/metacontroller.py
@@ -66,6 +66,13 @@ MetaControllerOutput = namedtuple('MetaControllerOutput', (
      'switch_loss'
  ))
 
+ GRPOOutput = namedtuple('GRPOOutput', (
+     'state',
+     'action',
+     'log_prob',
+     'switch_beta'
+ ))
+
  def z_score(t, eps = 1e-8):
      return (t - t.mean()) / (t.std() + eps)
 
@@ -107,6 +114,17 @@ def policy_loss(
 
      return masked_mean(losses, mask)
 
+ def extract_grpo_data(meta_controller, transformer_output):
+     meta_output = transformer_output.prev_hiddens.meta_controller
+
+     state = meta_output.input_residual_stream
+     action = meta_output.actions
+     switch_beta = meta_output.switch_beta
+
+     log_prob = meta_controller.log_prob(meta_output.action_dist, action)
+
+     return GRPOOutput(state, action, log_prob, switch_beta)
+
  @save_load()
  class MetaController(Module):
      def __init__(
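For orientation, the new helper only repackages fields that already live on the cached meta controller output. The sketch below shows that data flow with stand-in objects in place of the real Transformer cache and MetaController (the actual classes are in metacontroller/metacontroller.py); the tensor shapes and the mock log_prob are illustrative assumptions, only the field names come from the diff.

```python
from collections import namedtuple
from types import SimpleNamespace
import torch

# same shape as the GRPOOutput / extract_grpo_data added above,
# repeated here so the sketch runs standalone
GRPOOutput = namedtuple('GRPOOutput', ('state', 'action', 'log_prob', 'switch_beta'))

def extract_grpo_data(meta_controller, transformer_output):
    meta_output = transformer_output.prev_hiddens.meta_controller
    state = meta_output.input_residual_stream
    action = meta_output.actions
    switch_beta = meta_output.switch_beta
    log_prob = meta_controller.log_prob(meta_output.action_dist, action)
    return GRPOOutput(state, action, log_prob, switch_beta)

# stand-ins for the real Transformer cache and MetaController, illustration only
dist = torch.distributions.Normal(torch.zeros(1, 1, 4), torch.ones(1, 1, 4))
actions = dist.sample()

meta_output = SimpleNamespace(
    input_residual_stream = torch.randn(1, 1, 4),  # residual stream the meta controller read
    actions = actions,                             # latent actions it sampled
    action_dist = dist,                            # distribution they came from
    switch_beta = torch.rand(1, 1)                 # switch gate value
)
cache = SimpleNamespace(prev_hiddens = SimpleNamespace(meta_controller = meta_output))
meta_controller = SimpleNamespace(log_prob = lambda dist, action: dist.log_prob(action).sum(dim = -1))

out = extract_grpo_data(meta_controller, cache)
print(out.state.shape, out.log_prob.shape, out.switch_beta.shape)
```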
{metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.42}/pyproject.toml
@@ -1,6 +1,6 @@
  [project]
  name = "metacontroller-pytorch"
- version = "0.0.41"
+ version = "0.0.42"
  description = "Transformer Metacontroller"
  authors = [
      { name = "Phil Wang", email = "lucidrains@gmail.com" }
{metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.42}/tests/test_metacontroller.py
@@ -7,7 +7,7 @@ from functools import partial
 
  import torch
  from torch import cat
- from metacontroller.metacontroller import Transformer, MetaController, policy_loss, z_score
+ from metacontroller.metacontroller import Transformer, MetaController, policy_loss, z_score, extract_grpo_data
  from metacontroller.metacontroller_with_binary_mapper import MetaControllerWithBinaryMapper
 
  from memmap_replay_buffer import ReplayBuffer
@@ -109,10 +109,7 @@ def test_metacontroller(
      cache = None
      past_action_id = None
 
-     states = []
-     log_probs = []
-     switch_betas = []
-     latent_actions = []
+     grpo_data_list = []
 
      for one_state in subset_state.unbind(dim = 1):
          one_state = rearrange(one_state, 'b d -> b 1 d')
@@ -121,24 +118,19 @@ def test_metacontroller(
 
          past_action_id = model.action_readout.sample(logits)
 
-         # get log prob from meta controller latent actions
+         # extract grpo data and store
 
-         meta_output = cache.prev_hiddens.meta_controller
-
-         old_log_probs = meta_controller.log_prob(meta_output.action_dist, meta_output.actions)
-
-         states.append(meta_output.input_residual_stream)
-         log_probs.append(old_log_probs)
-         switch_betas.append(meta_output.switch_beta)
-         latent_actions.append(meta_output.actions)
+         grpo_data_list.append(extract_grpo_data(meta_controller, cache))
 
      # accumulate across time for the episode data
 
+     states, actions, log_probs, switch_betas = zip(*grpo_data_list)
+
      all_episodes.append((
          cat(states, dim = 1),
          cat(log_probs, dim = 1),
          cat(switch_betas, dim = 1),
-         cat(latent_actions, dim = 1)
+         cat(actions, dim = 1)
      ))
 
      all_rewards.append(torch.randn(1))
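The test now leans on zip(*grpo_data_list) to transpose a list of per-timestep GRPOOutput tuples into per-field sequences before concatenating along the time dimension. A minimal sketch of just that pattern, with dummy tensors whose shapes are illustrative assumptions:

```python
from collections import namedtuple
import torch
from torch import cat

GRPOOutput = namedtuple('GRPOOutput', ('state', 'action', 'log_prob', 'switch_beta'))

# one GRPOOutput per timestep, each holding (batch, 1, ...) tensors
grpo_data_list = [
    GRPOOutput(
        state = torch.randn(2, 1, 8),
        action = torch.randn(2, 1, 4),
        log_prob = torch.randn(2, 1),
        switch_beta = torch.rand(2, 1),
    )
    for _ in range(5)
]

# zip(*...) transposes the list of namedtuples into one sequence per field
states, actions, log_probs, switch_betas = zip(*grpo_data_list)

# concatenate each field along the time dimension for the episode
episode = (
    cat(states, dim = 1),        # (2, 5, 8)
    cat(log_probs, dim = 1),     # (2, 5)
    cat(switch_betas, dim = 1),  # (2, 5)
    cat(actions, dim = 1),       # (2, 5, 4)
)
print([t.shape for t in episode])
```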
{metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.42}/train_babyai.py
@@ -26,7 +26,7 @@ from accelerate import Accelerator
 
  from babyai_env import create_env
  from memmap_replay_buffer import ReplayBuffer
- from metacontroller.metacontroller import Transformer, MetaController, policy_loss, z_score
+ from metacontroller.metacontroller import Transformer, MetaController, policy_loss, z_score, extract_grpo_data
  from metacontroller.transformer_with_resnet import TransformerWithResnet
 
  # research entry point
@@ -195,13 +195,12 @@ def main(
 
      # GRPO collection
 
-     meta_output = cache.prev_hiddens.meta_controller
-     old_log_probs = unwrapped_meta_controller.log_prob(meta_output.action_dist, meta_output.actions)
+     grpo_data = extract_grpo_data(unwrapped_meta_controller, cache)
 
-     states.append(meta_output.input_residual_stream)
-     log_probs.append(old_log_probs)
-     switch_betas.append(meta_output.switch_beta)
-     latent_actions.append(meta_output.actions)
+     states.append(grpo_data.state)
+     log_probs.append(grpo_data.log_prob)
+     switch_betas.append(grpo_data.switch_beta)
+     latent_actions.append(grpo_data.action)
 
      next_state, reward, terminated, truncated, *_ = env.step(action.cpu().numpy())
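Both the test and the training script also import z_score, whose one-line definition appears in the metacontroller.py hunk above. In GRPO-style updates that kind of standardization is typically applied across a group of episode returns to form advantages; the sketch below shows only that normalization step with made-up rewards, and does not reproduce the actual policy_loss call in train_babyai.py.

```python
import torch

def z_score(t, eps = 1e-8):
    # as defined in metacontroller/metacontroller.py (shown in the diff above)
    return (t - t.mean()) / (t.std() + eps)

# hypothetical returns for a group of rollouts of the same task
episode_rewards = torch.tensor([1.0, 0.0, 0.5, 2.0])

# group-relative advantages: each return standardized against the group
advantages = z_score(episode_rewards)
print(advantages)  # roughly zero mean, unit std across the group
```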