metacontroller-pytorch 0.0.38__tar.gz → 0.0.41__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {metacontroller_pytorch-0.0.38 → metacontroller_pytorch-0.0.41}/PKG-INFO +86 -1
- {metacontroller_pytorch-0.0.38 → metacontroller_pytorch-0.0.41}/README.md +85 -0
- metacontroller_pytorch-0.0.41/babyai_env.py +41 -0
- metacontroller_pytorch-0.0.41/metacontroller/__init__.py +1 -0
- {metacontroller_pytorch-0.0.38 → metacontroller_pytorch-0.0.41}/pyproject.toml +1 -1
- metacontroller_pytorch-0.0.41/test_babyai_e2e.sh +35 -0
- metacontroller_pytorch-0.0.41/train_babyai.py +314 -0
- {metacontroller_pytorch-0.0.38 → metacontroller_pytorch-0.0.41}/train_behavior_clone_babyai.py +2 -2
- metacontroller_pytorch-0.0.38/metacontroller/__init__.py +0 -1
- metacontroller_pytorch-0.0.38/test_babyai_e2e.sh +0 -14
- metacontroller_pytorch-0.0.38/train_babyai.py +0 -140
- {metacontroller_pytorch-0.0.38 → metacontroller_pytorch-0.0.41}/.github/workflows/python-publish.yml +0 -0
- {metacontroller_pytorch-0.0.38 → metacontroller_pytorch-0.0.41}/.github/workflows/test.yml +0 -0
- {metacontroller_pytorch-0.0.38 → metacontroller_pytorch-0.0.41}/.gitignore +0 -0
- {metacontroller_pytorch-0.0.38 → metacontroller_pytorch-0.0.41}/LICENSE +0 -0
- {metacontroller_pytorch-0.0.38 → metacontroller_pytorch-0.0.41}/fig1.png +0 -0
- {metacontroller_pytorch-0.0.38 → metacontroller_pytorch-0.0.41}/gather_babyai_trajs.py +0 -0
- {metacontroller_pytorch-0.0.38 → metacontroller_pytorch-0.0.41}/metacontroller/metacontroller.py +0 -0
- {metacontroller_pytorch-0.0.38 → metacontroller_pytorch-0.0.41}/metacontroller/metacontroller_with_binary_mapper.py +0 -0
- {metacontroller_pytorch-0.0.38 → metacontroller_pytorch-0.0.41}/metacontroller/transformer_with_resnet.py +0 -0
- {metacontroller_pytorch-0.0.38 → metacontroller_pytorch-0.0.41}/tests/test_metacontroller.py +0 -0
{metacontroller_pytorch-0.0.38 → metacontroller_pytorch-0.0.41}/PKG-INFO
RENAMED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: metacontroller-pytorch
-Version: 0.0.38
+Version: 0.0.41
 Summary: Transformer Metacontroller
 Project-URL: Homepage, https://pypi.org/project/metacontroller/
 Project-URL: Repository, https://github.com/lucidrains/metacontroller
@@ -69,6 +69,91 @@ $ pip install metacontroller-pytorch
 
 - [Diego Calanzone](https://github.com/ddidacus) for proposing testing on BabyAI gridworld task, and submitting the [pull request](https://github.com/lucidrains/metacontroller/pull/3) for behavior cloning and discovery phase training for it!
 
+## Usage
+
+```python
+import torch
+from metacontroller import Transformer, MetaController
+
+# 1. initialize model
+
+model = Transformer(
+    dim = 512,
+    action_embed_readout = dict(num_discrete = 4),
+    state_embed_readout = dict(num_continuous = 384),
+    lower_body = dict(depth = 2),
+    upper_body = dict(depth = 2)
+)
+
+state = torch.randn(2, 128, 384)
+actions = torch.randint(0, 4, (2, 128))
+
+# 2. behavioral cloning (BC)
+
+state_loss, action_loss = model(state, actions)
+(state_loss + action_loss).backward()
+
+# 3. discovery phase
+
+meta_controller = MetaController(
+    dim_model = 512,
+    dim_meta_controller = 256,
+    dim_latent = 128
+)
+
+action_recon_loss, kl_loss, switch_loss = model(
+    state,
+    actions,
+    meta_controller = meta_controller,
+    discovery_phase = True
+)
+
+(action_recon_loss + kl_loss + switch_loss).backward()
+
+# 4. internal rl phase (GRPO)
+
+# ... collect trajectories ...
+
+logits, cache = model(
+    one_state,
+    past_action_id,
+    meta_controller = meta_controller,
+    return_cache = True
+)
+
+meta_output = cache.prev_hiddens.meta_controller
+old_log_probs = meta_controller.log_prob(meta_output.action_dist, meta_output.actions)
+
+# ... calculate advantages ...
+
+loss = meta_controller.policy_loss(
+    group_states,
+    group_old_log_probs,
+    group_latent_actions,
+    group_advantages,
+    group_switch_betas
+)
+
+loss.backward()
+```
+
+Or using [evolutionary strategies](https://arxiv.org/abs/2511.16652) for the last portion
+
+```python
+# 5. evolve (ES over GRPO)
+
+model.meta_controller = meta_controller
+
+def environment_callable(model):
+    # return a fitness score
+    return 1.0
+
+model.evolve(
+    num_generations = 10,
+    environment = environment_callable
+)
+```
+
 ## Citations
 
 ```bibtex
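The GRPO step in the Usage section above leaves trajectory collection as a placeholder (`one_state`, `past_action_id`, and the `group_*` tensors). The sketch below is not part of the package diff; it shows one way those placeholders could be filled in, mirroring the calls that the bundled `train_babyai.py` makes during its rollouts, with the episode length and the random dummy observation as illustrative assumptions.

```python
# hypothetical rollout collection for step 4, following the same calls used in train_babyai.py
import torch
from torch import cat

states, old_log_probs, switch_betas, latent_actions = [], [], [], []

cache = None
past_action_id = None

with torch.no_grad():
    for _ in range(16):                      # illustrative episode length
        one_state = torch.randn(1, 1, 384)   # stand-in for a real encoded observation

        logits, cache = model(
            one_state,
            past_action_id,
            meta_controller = meta_controller,
            return_cache = True,
            return_raw_action_dist = True,
            cache = cache
        )

        past_action_id = model.action_readout.sample(logits)

        meta_output = cache.prev_hiddens.meta_controller
        states.append(meta_output.input_residual_stream)
        old_log_probs.append(meta_controller.log_prob(meta_output.action_dist, meta_output.actions))
        switch_betas.append(meta_output.switch_beta)
        latent_actions.append(meta_output.actions)

# one episode's worth of tensors; repeating this over a group of rollouts and
# z-scoring each episode's return yields the group_* arguments passed to policy_loss above
episode_states = cat(states, dim = 1)
episode_log_probs = cat(old_log_probs, dim = 1)
```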
{metacontroller_pytorch-0.0.38 → metacontroller_pytorch-0.0.41}/README.md
RENAMED

@@ -16,6 +16,91 @@ $ pip install metacontroller-pytorch
 
 - [Diego Calanzone](https://github.com/ddidacus) for proposing testing on BabyAI gridworld task, and submitting the [pull request](https://github.com/lucidrains/metacontroller/pull/3) for behavior cloning and discovery phase training for it!
 
+## Usage
+
+```python
+import torch
+from metacontroller import Transformer, MetaController
+
+# 1. initialize model
+
+model = Transformer(
+    dim = 512,
+    action_embed_readout = dict(num_discrete = 4),
+    state_embed_readout = dict(num_continuous = 384),
+    lower_body = dict(depth = 2),
+    upper_body = dict(depth = 2)
+)
+
+state = torch.randn(2, 128, 384)
+actions = torch.randint(0, 4, (2, 128))
+
+# 2. behavioral cloning (BC)
+
+state_loss, action_loss = model(state, actions)
+(state_loss + action_loss).backward()
+
+# 3. discovery phase
+
+meta_controller = MetaController(
+    dim_model = 512,
+    dim_meta_controller = 256,
+    dim_latent = 128
+)
+
+action_recon_loss, kl_loss, switch_loss = model(
+    state,
+    actions,
+    meta_controller = meta_controller,
+    discovery_phase = True
+)
+
+(action_recon_loss + kl_loss + switch_loss).backward()
+
+# 4. internal rl phase (GRPO)
+
+# ... collect trajectories ...
+
+logits, cache = model(
+    one_state,
+    past_action_id,
+    meta_controller = meta_controller,
+    return_cache = True
+)
+
+meta_output = cache.prev_hiddens.meta_controller
+old_log_probs = meta_controller.log_prob(meta_output.action_dist, meta_output.actions)
+
+# ... calculate advantages ...
+
+loss = meta_controller.policy_loss(
+    group_states,
+    group_old_log_probs,
+    group_latent_actions,
+    group_advantages,
+    group_switch_betas
+)
+
+loss.backward()
+```
+
+Or using [evolutionary strategies](https://arxiv.org/abs/2511.16652) for the last portion
+
+```python
+# 5. evolve (ES over GRPO)
+
+model.meta_controller = meta_controller
+
+def environment_callable(model):
+    # return a fitness score
+    return 1.0
+
+model.evolve(
+    num_generations = 10,
+    environment = environment_callable
+)
+```
+
 ## Citations
 
 ```bibtex
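In the README's evolutionary example, `environment_callable` returns a constant fitness. A more realistic callable for the BabyAI task could score a candidate model by the cumulative reward of a single episode. The following is a hedged sketch, not code from the package: the function name, environment id, and timestep budget are assumptions, while `create_env` and the flattened-observation rollout reuse `babyai_env.py` and the pattern in `train_babyai.py` added in this release.

```python
import torch
from einops import rearrange

from babyai_env import create_env

def babyai_fitness(model, env_id = 'BabyAI-MiniBossLevel-v0', max_timesteps = 100):
    # roll out one episode and return the cumulative reward as the fitness score
    env = create_env(env_id)
    state, *_ = env.reset()

    cache = None
    past_action_id = None
    total_reward = 0.

    with torch.no_grad():
        for _ in range(max_timesteps):
            # flatten the symbolic grid observation into the transformer's state dimension
            image = torch.from_numpy(state['image']).float()
            image = rearrange(image, 'h w c -> 1 1 (h w c)')

            logits, cache = model(
                image,
                past_action_id,
                meta_controller = model.meta_controller,
                return_cache = True,
                return_raw_action_dist = True,
                cache = cache
            )

            action = model.action_readout.sample(logits)
            past_action_id = action.long()

            state, reward, terminated, truncated, *_ = env.step(action.squeeze().cpu().numpy())
            total_reward += reward

            if terminated or truncated:
                break

    env.close()
    return total_reward

# model.evolve(num_generations = 10, environment = babyai_fitness)
```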
metacontroller_pytorch-0.0.41/babyai_env.py
ADDED

@@ -0,0 +1,41 @@
+from pathlib import Path
+from shutil import rmtree
+
+import gymnasium as gym
+import minigrid
+from minigrid.wrappers import FullyObsWrapper, SymbolicObsWrapper
+
+# functions
+
+def divisible_by(num, den):
+    return (num % den) == 0
+
+# env creation
+
+def create_env(
+    env_id,
+    render_mode = 'rgb_array',
+    video_folder = None,
+    render_every_eps = 1000
+):
+    # register minigrid environments if needed
+    minigrid.register_minigrid_envs()
+
+    # environment
+    env = gym.make(env_id, render_mode = render_mode)
+    env = FullyObsWrapper(env)
+    env = SymbolicObsWrapper(env)
+
+    if video_folder is not None:
+        video_folder = Path(video_folder)
+        rmtree(video_folder, ignore_errors = True)
+
+        env = gym.wrappers.RecordVideo(
+            env = env,
+            video_folder = str(video_folder),
+            name_prefix = 'babyai',
+            episode_trigger = lambda eps_num: divisible_by(eps_num, render_every_eps),
+            disable_logger = True
+        )
+
+    return env
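For reference, this helper is used two ways elsewhere in the diff: `train_behavior_clone_babyai.py` builds a throwaway environment just to read the action space, while `train_babyai.py` also records a video every `render_every_eps` episodes. A minimal usage sketch (variable names here are illustrative, not from the package):

```python
from babyai_env import create_env

# plain environment, no recording (video_folder defaults to None)
probe_env = create_env('BabyAI-MiniBossLevel-v0')
num_actions = int(probe_env.action_space.n)
probe_env.close()

# rollout environment that records an episode to ./recordings every 1000 episodes
rollout_env = create_env(
    'BabyAI-BossLevel-v0',
    video_folder = './recordings',
    render_every_eps = 1000
)
```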
metacontroller_pytorch-0.0.41/metacontroller/__init__.py
ADDED

@@ -0,0 +1 @@
+from metacontroller.metacontroller import MetaController, Transformer
metacontroller_pytorch-0.0.41/test_babyai_e2e.sh
ADDED

@@ -0,0 +1,35 @@
+#!/bin/bash
+set -e
+
+# 1. Gather trajectories
+echo "Gathering trajectories..."
+uv run gather_babyai_trajs.py \
+    --num_seeds 100 \
+    --num_episodes_per_seed 10 \
+    --num_steps 500 \
+    --output_dir end_to_end_trajectories \
+    --env_id BabyAI-MiniBossLevel-v0
+
+# 2. Behavioral cloning
+echo "Training behavioral cloning model..."
+ACCELERATE_USE_CPU=true ACCELERATE_MIXED_PRECISION=no uv run train_behavior_clone_babyai.py \
+    --cloning_epochs 10 \
+    --discovery_epochs 10 \
+    --batch_size 256 \
+    --input_dir end_to_end_trajectories \
+    --env_id BabyAI-MiniBossLevel-v0 \
+    --checkpoint_path end_to_end_model.pt \
+    --use_resnet
+
+# 3. Inference rollouts
+echo "Running inference rollouts..."
+uv run train_babyai.py \
+    --transformer_weights_path end_to_end_model.pt \
+    --meta_controller_weights_path meta_controller_discovery.pt \
+    --env_name BabyAI-MiniBossLevel-v0 \
+    --num_episodes 1000 \
+    --buffer_size 1000 \
+    --max_timesteps 100 \
+    --num_groups 16 \
+    --lr 1e-4 \
+    --use_resnet
metacontroller_pytorch-0.0.41/train_babyai.py
ADDED

@@ -0,0 +1,314 @@
+# /// script
+# dependencies = [
+#   "fire",
+#   "gymnasium",
+#   "gymnasium[other]",
+#   "memmap-replay-buffer>=0.0.12",
+#   "metacontroller-pytorch",
+#   "minigrid",
+#   "tqdm"
+# ]
+# ///
+
+from fire import Fire
+from pathlib import Path
+from functools import partial
+from shutil import rmtree
+from tqdm import tqdm
+
+import torch
+from torch import cat, tensor, stack
+from torch.optim import Adam
+
+from einops import rearrange
+
+from accelerate import Accelerator
+
+from babyai_env import create_env
+from memmap_replay_buffer import ReplayBuffer
+from metacontroller.metacontroller import Transformer, MetaController, policy_loss, z_score
+from metacontroller.transformer_with_resnet import TransformerWithResnet
+
+# research entry point
+
+def reward_shaping_fn(
+    cumulative_rewards: torch.Tensor,
+    all_rewards: torch.Tensor,
+    episode_lens: torch.Tensor
+) -> torch.Tensor | None:
+    """
+    researchers can modify this function to engineer rewards
+    or return None to reject the entire batch
+
+    cumulative_rewards: (num_episodes,)
+    all_rewards: (num_episodes, max_timesteps)
+    episode_lens: (num_episodes,)
+    """
+    return cumulative_rewards
+
+# helpers
+
+def exists(v):
+    return v is not None
+
+def default(v, d):
+    return v if exists(v) else d
+
+# main
+
+def main(
+    env_name: str = 'BabyAI-BossLevel-v0',
+    num_episodes: int = int(10e6),
+    max_timesteps: int = 500,
+    buffer_size: int = 5_000,
+    render_every_eps: int = 1_000,
+    video_folder: str = './recordings',
+    seed: int | None = None,
+    transformer_weights_path: str | None = None,
+    meta_controller_weights_path: str | None = None,
+    output_meta_controller_path: str = 'metacontroller_rl_trained.pt',
+    use_resnet: bool = False,
+    lr: float = 1e-4,
+    num_groups: int = 16,
+    max_grad_norm: float = 1.0,
+    use_wandb: bool = False,
+    wandb_project: str = 'metacontroller-babyai-rl'
+):
+    # accelerator
+
+    accelerator = Accelerator(log_with = 'wandb' if use_wandb else None)
+
+    if use_wandb:
+        accelerator.init_trackers(wandb_project)
+
+    # environment
+
+    env = create_env(
+        env_name,
+        render_mode = 'rgb_array',
+        video_folder = video_folder,
+        render_every_eps = render_every_eps
+    )
+
+    # load models
+
+    model = None
+    if exists(transformer_weights_path):
+        weights_path = Path(transformer_weights_path)
+        assert weights_path.exists(), f"transformer weights not found at {weights_path}"
+
+        transformer_klass = TransformerWithResnet if use_resnet else Transformer
+        model = transformer_klass.init_and_load(str(weights_path), strict = False)
+        model.eval()
+
+    meta_controller = None
+    if exists(meta_controller_weights_path):
+        weights_path = Path(meta_controller_weights_path)
+        assert weights_path.exists(), f"meta controller weights not found at {weights_path}"
+        meta_controller = MetaController.init_and_load(str(weights_path), strict = False)
+        meta_controller.eval()
+
+    meta_controller = default(meta_controller, getattr(model, 'meta_controller', None))
+    assert exists(meta_controller), "MetaController must be present for reinforcement learning"
+
+    # optimizer
+
+    optim = Adam(meta_controller.internal_rl_parameters(), lr = lr)
+
+    # prepare
+
+    model, meta_controller, optim = accelerator.prepare(model, meta_controller, optim)
+
+    unwrapped_model = accelerator.unwrap_model(model)
+    unwrapped_meta_controller = accelerator.unwrap_model(meta_controller)
+
+    # replay buffer
+
+    replay_buffer = ReplayBuffer(
+        './replay-data',
+        max_episodes = buffer_size,
+        max_timesteps = max_timesteps + 1,
+        fields = meta_controller.replay_buffer_field_dict,
+        meta_fields = dict(advantages = 'float'),
+        overwrite = True,
+        circular = True
+    )
+
+    # rollouts
+
+    num_batch_updates = num_episodes // num_groups
+
+    pbar = tqdm(range(num_batch_updates), desc = 'training')
+
+    for _ in pbar:
+
+        all_episodes = []
+        all_cumulative_rewards = []
+        all_step_rewards = []
+        all_episode_lens = []
+
+        group_seed = default(seed, torch.randint(0, 1000000, (1,)).item())
+
+        for _ in range(num_groups):
+
+            state, *_ = env.reset(seed = group_seed)
+
+            cache = None
+            past_action_id = None
+
+            states = []
+            log_probs = []
+            switch_betas = []
+            latent_actions = []
+
+            total_reward = 0.
+            step_rewards = []
+            episode_len = max_timesteps
+
+            for step in range(max_timesteps):
+
+                image = state['image']
+                image_tensor = torch.from_numpy(image).float().to(accelerator.device)
+
+                if use_resnet:
+                    image_tensor = rearrange(image_tensor, 'h w c -> 1 1 h w c')
+                    image_tensor = model.visual_encode(image_tensor)
+                else:
+                    image_tensor = rearrange(image_tensor, 'h w c -> 1 1 (h w c)')
+
+                if torch.is_tensor(past_action_id):
+                    past_action_id = past_action_id.long()
+
+                with torch.no_grad():
+                    logits, cache = unwrapped_model(
+                        image_tensor,
+                        past_action_id,
+                        meta_controller = unwrapped_meta_controller,
+                        return_cache = True,
+                        return_raw_action_dist = True,
+                        cache = cache
+                    )
+
+                action = unwrapped_model.action_readout.sample(logits)
+                past_action_id = action
+                action = action.squeeze()
+
+                # GRPO collection
+
+                meta_output = cache.prev_hiddens.meta_controller
+                old_log_probs = unwrapped_meta_controller.log_prob(meta_output.action_dist, meta_output.actions)
+
+                states.append(meta_output.input_residual_stream)
+                log_probs.append(old_log_probs)
+                switch_betas.append(meta_output.switch_beta)
+                latent_actions.append(meta_output.actions)
+
+                next_state, reward, terminated, truncated, *_ = env.step(action.cpu().numpy())
+
+                total_reward += reward
+                step_rewards.append(reward)
+                done = terminated or truncated
+
+                if done:
+                    episode_len = step + 1
+                    break
+
+                state = next_state
+
+            # store episode
+
+            all_episodes.append((
+                cat(states, dim = 1).squeeze(0),
+                cat(log_probs, dim = 1).squeeze(0),
+                cat(switch_betas, dim = 1).squeeze(0),
+                cat(latent_actions, dim = 1).squeeze(0)
+            ))
+
+            all_cumulative_rewards.append(tensor(total_reward))
+            all_step_rewards.append(tensor(step_rewards))
+            all_episode_lens.append(episode_len)
+
+        # compute advantages
+
+        cumulative_rewards = stack(all_cumulative_rewards)
+        episode_lens = tensor(all_episode_lens)
+
+        # pad step rewards
+
+        max_len = max(all_episode_lens)
+        padded_step_rewards = torch.zeros(num_episodes, max_len)
+
+        for i, (rewards, length) in enumerate(zip(all_step_rewards, all_episode_lens)):
+            padded_step_rewards[i, :length] = rewards
+
+        # reward shaping hook
+
+        shaped_rewards = reward_shaping_fn(cumulative_rewards, padded_step_rewards, episode_lens)
+
+        if not exists(shaped_rewards):
+            continue
+
+        group_advantages = z_score(shaped_rewards)
+
+        group_states, group_log_probs, group_switch_betas, group_latent_actions = zip(*all_episodes)
+
+        for states, log_probs, switch_betas, latent_actions, advantages in zip(group_states, group_log_probs, group_switch_betas, group_latent_actions, group_advantages):
+            replay_buffer.store_episode(
+                states = states,
+                log_probs = log_probs,
+                switch_betas = switch_betas,
+                latent_actions = latent_actions,
+                advantages = advantages
+            )
+
+        # learn
+
+        if len(replay_buffer) >= buffer_size:
+            dl = replay_buffer.dataloader(batch_size = num_groups)
+            dl = accelerator.prepare(dl)
+
+            meta_controller.train()
+
+            batch = next(iter(dl))
+
+            loss = meta_controller.policy_loss(
+                batch['states'],
+                batch['log_probs'],
+                batch['latent_actions'],
+                batch['advantages'],
+                batch['switch_betas'] == 1.,
+                episode_lens = batch['_lens']
+            )
+
+            accelerator.backward(loss)
+
+            grad_norm = accelerator.clip_grad_norm_(meta_controller.parameters(), max_grad_norm)
+
+            optim.step()
+            optim.zero_grad()
+
+            meta_controller.eval()
+
+            pbar.set_postfix(
+                loss = f'{loss.item():.4f}',
+                grad_norm = f'{grad_norm.item():.4f}',
+                reward = f'{cumulative_rewards.mean().item():.4f}'
+            )
+
+            accelerator.log({
+                'loss': loss.item(),
+                'grad_norm': grad_norm.item()
+            })
+
+            accelerator.print(f'loss: {loss.item():.4f}, grad_norm: {grad_norm.item():.4f}')
+
+    env.close()
+
+    # save
+
+    if exists(output_meta_controller_path):
+        unwrapped_meta_controller.save(output_meta_controller_path)
+        accelerator.print(f'MetaController weights saved to {output_meta_controller_path}')
+
+if __name__ == '__main__':
+    Fire(main)
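The `reward_shaping_fn` hook above simply passes the cumulative rewards through. As a purely hypothetical illustration of the two behaviours its docstring allows, shaping the per-episode return and rejecting an uninformative batch, one might write something like the following (the length-penalty coefficient is made up):

```python
import torch

def reward_shaping_fn(
    cumulative_rewards: torch.Tensor,   # (num_episodes,)
    all_rewards: torch.Tensor,          # (num_episodes, max_timesteps)
    episode_lens: torch.Tensor          # (num_episodes,)
) -> torch.Tensor | None:

    # reject the whole group when no episode earned any reward, since there is no learning signal
    if (cumulative_rewards == 0.).all():
        return None

    # otherwise favor shorter successful episodes with a small length penalty
    return cumulative_rewards - 0.01 * episode_lens.float()
```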
{metacontroller_pytorch-0.0.38 → metacontroller_pytorch-0.0.41}/train_behavior_clone_babyai.py
RENAMED

@@ -92,8 +92,8 @@ def train(
     else: state_dim = int(torch.tensor(state_shape).prod().item())
 
     # deduce num_actions from the environment
-
-    temp_env =
+    from babyai_env import create_env
+    temp_env = create_env(env_id)
     num_actions = int(temp_env.action_space.n)
     temp_env.close()
 
metacontroller_pytorch-0.0.38/metacontroller/__init__.py
DELETED

@@ -1 +0,0 @@
-from metacontroller.metacontroller import MetaController
metacontroller_pytorch-0.0.38/test_babyai_e2e.sh
DELETED

@@ -1,14 +0,0 @@
-#!/bin/bash
-set -e
-
-# 1. Gather trajectories
-echo "Gathering trajectories..."
-uv run gather_babyai_trajs.py --num_seeds 1000 --num_episodes_per_seed 100 --output_dir end_to_end_trajectories --env_id BabyAI-MiniBossLevel-v0
-
-# 2. Behavioral cloning
-echo "Training behavioral cloning model..."
-uv run train_behavior_clone_babyai.py --cloning_epochs 10 --discovery_epochs 10 --batch_size 256 --input_dir end_to_end_trajectories --env_id BabyAI-MiniBossLevel-v0 --checkpoint_path end_to_end_model.pt --use_resnet
-
-# 3. Inference rollouts
-echo "Running inference rollouts..."
-uv run train_babyai.py --weights_path end_to_end_model.pt --env_name BabyAI-MiniBossLevel-v0 --num_episodes 5 --buffer_size 100 --max_timesteps 100
metacontroller_pytorch-0.0.38/train_babyai.py
DELETED

@@ -1,140 +0,0 @@
-# /// script
-# dependencies = [
-#   "fire",
-#   "gymnasium",
-#   "gymnasium[other]",
-#   "memmap-replay-buffer>=0.0.12",
-#   "metacontroller-pytorch",
-#   "minigrid",
-#   "tqdm"
-# ]
-# ///
-
-from fire import Fire
-from tqdm import tqdm
-from shutil import rmtree
-from pathlib import Path
-
-import torch
-from einops import rearrange
-
-import gymnasium as gym
-import minigrid
-from minigrid.wrappers import FullyObsWrapper, SymbolicObsWrapper
-
-from memmap_replay_buffer import ReplayBuffer
-from metacontroller.metacontroller import Transformer
-
-# functions
-
-def exists(v):
-    return v is not None
-
-def default(v, d):
-    return v if exists(v) else d
-
-def divisible_by(num, den):
-    return (num % den) == 0
-
-# main
-
-def main(
-    env_name = 'BabyAI-BossLevel-v0',
-    num_episodes = int(10e6),
-    max_timesteps = 500,
-    buffer_size = 5_000,
-    render_every_eps = 1_000,
-    video_folder = './recordings',
-    seed = None,
-    weights_path = None
-):
-
-    # environment
-
-    env = gym.make(env_name, render_mode = 'rgb_array')
-    env = FullyObsWrapper(env.unwrapped)
-    env = SymbolicObsWrapper(env.unwrapped)
-
-    rmtree(video_folder, ignore_errors = True)
-
-    env = gym.wrappers.RecordVideo(
-        env = env,
-        video_folder = video_folder,
-        name_prefix = 'babyai',
-        episode_trigger = lambda eps_num: divisible_by(eps_num, render_every_eps),
-        disable_logger = True
-    )
-
-    # maybe load model
-
-    model = None
-    if exists(weights_path):
-        weights_path = Path(weights_path)
-        assert weights_path.exists(), f"weights not found at {weights_path}"
-        model = Transformer.init_and_load(str(weights_path), strict = False)
-        model.eval()
-
-    # replay
-
-    replay_buffer = ReplayBuffer(
-        './replay-data',
-        max_episodes = buffer_size,
-        max_timesteps = max_timesteps + 1,
-        fields = dict(
-            action = 'int',
-            state_image = ('float', (7, 7, 3)),
-            state_direction = 'int'
-        ),
-        overwrite = True,
-        circular = True
-    )
-
-    # rollouts
-
-    for _ in tqdm(range(num_episodes)):
-
-        state, *_ = env.reset(seed = seed)
-
-        cache = None
-        past_action_id = None
-
-        for _ in range(max_timesteps):
-
-            if exists(model):
-                # preprocess state
-                # assume state is a dict with 'image'
-                image = state['image']
-                image_tensor = torch.from_numpy(image).float()
-                image_tensor = rearrange(image_tensor, 'h w c -> 1 1 (h w c)')
-
-                if exists(past_action_id) and torch.is_tensor(past_action_id):
-                    past_action_id = past_action_id.long()
-
-                with torch.no_grad():
-                    logits, cache = model(
-                        image_tensor,
-                        past_action_id,
-                        return_cache = True,
-                        return_raw_action_dist = True,
-                        cache = cache
-                    )
-
-                action = model.action_readout.sample(logits)
-                past_action_id = action
-                action = action.squeeze()
-            else:
-                action = torch.randint(0, 7, ())
-
-            next_state, reward, terminated, truncated, *_ = env.step(action.cpu().numpy())
-
-            done = terminated or truncated
-
-            if done:
-                break
-
-            state = next_state
-
-    env.close()
-
-if __name__ == '__main__':
-    Fire(main)
{metacontroller_pytorch-0.0.38 → metacontroller_pytorch-0.0.41}/.github/workflows/python-publish.yml
RENAMED
File without changes

{metacontroller_pytorch-0.0.38 → metacontroller_pytorch-0.0.41}/.github/workflows/test.yml
RENAMED
File without changes

{metacontroller_pytorch-0.0.38 → metacontroller_pytorch-0.0.41}/.gitignore
RENAMED
File without changes

{metacontroller_pytorch-0.0.38 → metacontroller_pytorch-0.0.41}/LICENSE
RENAMED
File without changes

{metacontroller_pytorch-0.0.38 → metacontroller_pytorch-0.0.41}/fig1.png
RENAMED
File without changes

{metacontroller_pytorch-0.0.38 → metacontroller_pytorch-0.0.41}/gather_babyai_trajs.py
RENAMED
File without changes

{metacontroller_pytorch-0.0.38 → metacontroller_pytorch-0.0.41}/metacontroller/metacontroller.py
RENAMED
File without changes

{metacontroller_pytorch-0.0.38 → metacontroller_pytorch-0.0.41}/metacontroller/metacontroller_with_binary_mapper.py
RENAMED
File without changes

{metacontroller_pytorch-0.0.38 → metacontroller_pytorch-0.0.41}/metacontroller/transformer_with_resnet.py
RENAMED
File without changes

{metacontroller_pytorch-0.0.38 → metacontroller_pytorch-0.0.41}/tests/test_metacontroller.py
RENAMED
File without changes