metacontroller-pytorch 0.0.41__tar.gz → 0.0.42__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.42}/PKG-INFO +2 -2
- {metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.42}/README.md +1 -1
- {metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.42}/metacontroller/metacontroller.py +18 -0
- {metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.42}/pyproject.toml +1 -1
- {metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.42}/tests/test_metacontroller.py +7 -15
- {metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.42}/train_babyai.py +6 -7
- {metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.42}/.github/workflows/python-publish.yml +0 -0
- {metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.42}/.github/workflows/test.yml +0 -0
- {metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.42}/.gitignore +0 -0
- {metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.42}/LICENSE +0 -0
- {metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.42}/babyai_env.py +0 -0
- {metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.42}/fig1.png +0 -0
- {metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.42}/gather_babyai_trajs.py +0 -0
- {metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.42}/metacontroller/__init__.py +0 -0
- {metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.42}/metacontroller/metacontroller_with_binary_mapper.py +0 -0
- {metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.42}/metacontroller/transformer_with_resnet.py +0 -0
- {metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.42}/test_babyai_e2e.sh +0 -0
- {metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.42}/train_behavior_clone_babyai.py +0 -0
|
{metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.42}/PKG-INFO
RENAMED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: metacontroller-pytorch
|
|
3
|
-
Version: 0.0.41
|
|
3
|
+
Version: 0.0.42
|
|
4
4
|
Summary: Transformer Metacontroller
|
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/metacontroller/
|
|
6
6
|
Project-URL: Repository, https://github.com/lucidrains/metacontroller
|
|
@@ -53,7 +53,7 @@ Description-Content-Type: text/markdown
|
|
|
53
53
|
|
|
54
54
|
<img src="./fig1.png" width="400px"></img>
|
|
55
55
|
|
|
56
|
-
## metacontroller
|
|
56
|
+
## metacontroller
|
|
57
57
|
|
|
58
58
|
Implementation of the MetaController proposed in [Emergent temporal abstractions in autoregressive models enable hierarchical reinforcement learning](https://arxiv.org/abs/2512.20605)
|
|
59
59
|
|
|
{metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.42}/README.md
RENAMED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
<img src="./fig1.png" width="400px"></img>
|
|
2
2
|
|
|
3
|
-
## metacontroller
|
|
3
|
+
## metacontroller
|
|
4
4
|
|
|
5
5
|
Implementation of the MetaController proposed in [Emergent temporal abstractions in autoregressive models enable hierarchical reinforcement learning](https://arxiv.org/abs/2512.20605)
|
|
6
6
|
|
{metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.42}/metacontroller/metacontroller.py
RENAMED
|
@@ -66,6 +66,13 @@ MetaControllerOutput = namedtuple('MetaControllerOutput', (
|
|
|
66
66
|
'switch_loss'
|
|
67
67
|
))
|
|
68
68
|
|
|
69
|
+
GRPOOutput = namedtuple('GRPOOutput', (
|
|
70
|
+
'state',
|
|
71
|
+
'action',
|
|
72
|
+
'log_prob',
|
|
73
|
+
'switch_beta'
|
|
74
|
+
))
|
|
75
|
+
|
|
69
76
|
def z_score(t, eps = 1e-8):
|
|
70
77
|
return (t - t.mean()) / (t.std() + eps)
|
|
71
78
|
|
|
@@ -107,6 +114,17 @@ def policy_loss(
|
|
|
107
114
|
|
|
108
115
|
return masked_mean(losses, mask)
|
|
109
116
|
|
|
117
|
+
def extract_grpo_data(meta_controller, transformer_output):
|
|
118
|
+
meta_output = transformer_output.prev_hiddens.meta_controller
|
|
119
|
+
|
|
120
|
+
state = meta_output.input_residual_stream
|
|
121
|
+
action = meta_output.actions
|
|
122
|
+
switch_beta = meta_output.switch_beta
|
|
123
|
+
|
|
124
|
+
log_prob = meta_controller.log_prob(meta_output.action_dist, action)
|
|
125
|
+
|
|
126
|
+
return GRPOOutput(state, action, log_prob, switch_beta)
|
|
127
|
+
|
|
110
128
|
@save_load()
|
|
111
129
|
class MetaController(Module):
|
|
112
130
|
def __init__(
|
{metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.42}/tests/test_metacontroller.py
RENAMED
|
@@ -7,7 +7,7 @@ from functools import partial
|
|
|
7
7
|
|
|
8
8
|
import torch
|
|
9
9
|
from torch import cat
|
|
10
|
-
from metacontroller.metacontroller import Transformer, MetaController, policy_loss, z_score
|
|
10
|
+
from metacontroller.metacontroller import Transformer, MetaController, policy_loss, z_score, extract_grpo_data
|
|
11
11
|
from metacontroller.metacontroller_with_binary_mapper import MetaControllerWithBinaryMapper
|
|
12
12
|
|
|
13
13
|
from memmap_replay_buffer import ReplayBuffer
|
|
@@ -109,10 +109,7 @@ def test_metacontroller(
|
|
|
109
109
|
cache = None
|
|
110
110
|
past_action_id = None
|
|
111
111
|
|
|
112
|
-
|
|
113
|
-
log_probs = []
|
|
114
|
-
switch_betas = []
|
|
115
|
-
latent_actions = []
|
|
112
|
+
grpo_data_list = []
|
|
116
113
|
|
|
117
114
|
for one_state in subset_state.unbind(dim = 1):
|
|
118
115
|
one_state = rearrange(one_state, 'b d -> b 1 d')
|
|
@@ -121,24 +118,19 @@ def test_metacontroller(
|
|
|
121
118
|
|
|
122
119
|
past_action_id = model.action_readout.sample(logits)
|
|
123
120
|
|
|
124
|
-
#
|
|
121
|
+
# extract grpo data and store
|
|
125
122
|
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
old_log_probs = meta_controller.log_prob(meta_output.action_dist, meta_output.actions)
|
|
129
|
-
|
|
130
|
-
states.append(meta_output.input_residual_stream)
|
|
131
|
-
log_probs.append(old_log_probs)
|
|
132
|
-
switch_betas.append(meta_output.switch_beta)
|
|
133
|
-
latent_actions.append(meta_output.actions)
|
|
123
|
+
grpo_data_list.append(extract_grpo_data(meta_controller, cache))
|
|
134
124
|
|
|
135
125
|
# accumulate across time for the episode data
|
|
136
126
|
|
|
127
|
+
states, actions, log_probs, switch_betas = zip(*grpo_data_list)
|
|
128
|
+
|
|
137
129
|
all_episodes.append((
|
|
138
130
|
cat(states, dim = 1),
|
|
139
131
|
cat(log_probs, dim = 1),
|
|
140
132
|
cat(switch_betas, dim = 1),
|
|
141
|
-
cat(latent_actions, dim = 1)
|
|
133
|
+
cat(actions, dim = 1)
|
|
142
134
|
))
|
|
143
135
|
|
|
144
136
|
all_rewards.append(torch.randn(1))
|
|
{metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.42}/train_babyai.py
RENAMED
|
@@ -26,7 +26,7 @@ from accelerate import Accelerator
|
|
|
26
26
|
|
|
27
27
|
from babyai_env import create_env
|
|
28
28
|
from memmap_replay_buffer import ReplayBuffer
|
|
29
|
-
from metacontroller.metacontroller import Transformer, MetaController, policy_loss, z_score
|
|
29
|
+
from metacontroller.metacontroller import Transformer, MetaController, policy_loss, z_score, extract_grpo_data
|
|
30
30
|
from metacontroller.transformer_with_resnet import TransformerWithResnet
|
|
31
31
|
|
|
32
32
|
# research entry point
|
|
@@ -195,13 +195,12 @@ def main(
|
|
|
195
195
|
|
|
196
196
|
# GRPO collection
|
|
197
197
|
|
|
198
|
-
|
|
199
|
-
old_log_probs = unwrapped_meta_controller.log_prob(meta_output.action_dist, meta_output.actions)
|
|
198
|
+
grpo_data = extract_grpo_data(unwrapped_meta_controller, cache)
|
|
200
199
|
|
|
201
|
-
states.append(meta_output.input_residual_stream)
|
|
202
|
-
log_probs.append(old_log_probs)
|
|
203
|
-
switch_betas.append(meta_output.switch_beta)
|
|
204
|
-
latent_actions.append(meta_output.actions)
|
|
200
|
+
states.append(grpo_data.state)
|
|
201
|
+
log_probs.append(grpo_data.log_prob)
|
|
202
|
+
switch_betas.append(grpo_data.switch_beta)
|
|
203
|
+
latent_actions.append(grpo_data.action)
|
|
205
204
|
|
|
206
205
|
next_state, reward, terminated, truncated, *_ = env.step(action.cpu().numpy())
|
|
207
206
|
|
{metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.42}/.github/workflows/python-publish.yml
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{metacontroller_pytorch-0.0.41 → metacontroller_pytorch-0.0.42}/train_behavior_clone_babyai.py
RENAMED
|
File without changes
|