metacontroller-pytorch 0.0.28__tar.gz → 0.0.29__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of metacontroller-pytorch might be problematic.
- {metacontroller_pytorch-0.0.28 → metacontroller_pytorch-0.0.29}/PKG-INFO +5 -1
- {metacontroller_pytorch-0.0.28 → metacontroller_pytorch-0.0.29}/README.md +4 -0
- {metacontroller_pytorch-0.0.28 → metacontroller_pytorch-0.0.29}/gather_babyai_trajs.py +43 -2
- metacontroller_pytorch-0.0.29/metacontroller/metacontroller_with_resnet.py +250 -0
- {metacontroller_pytorch-0.0.28 → metacontroller_pytorch-0.0.29}/pyproject.toml +1 -1
- metacontroller_pytorch-0.0.29/test_babyai_e2e.sh +14 -0
- {metacontroller_pytorch-0.0.28 → metacontroller_pytorch-0.0.29}/train_behavior_clone_babyai.py +25 -18
- metacontroller_pytorch-0.0.28/test_babyai_e2e.sh +0 -14
- {metacontroller_pytorch-0.0.28 → metacontroller_pytorch-0.0.29}/.github/workflows/python-publish.yml +0 -0
- {metacontroller_pytorch-0.0.28 → metacontroller_pytorch-0.0.29}/.github/workflows/test.yml +0 -0
- {metacontroller_pytorch-0.0.28 → metacontroller_pytorch-0.0.29}/.gitignore +0 -0
- {metacontroller_pytorch-0.0.28 → metacontroller_pytorch-0.0.29}/LICENSE +0 -0
- {metacontroller_pytorch-0.0.28 → metacontroller_pytorch-0.0.29}/fig1.png +0 -0
- {metacontroller_pytorch-0.0.28 → metacontroller_pytorch-0.0.29}/metacontroller/__init__.py +0 -0
- {metacontroller_pytorch-0.0.28 → metacontroller_pytorch-0.0.29}/metacontroller/metacontroller.py +0 -0
- {metacontroller_pytorch-0.0.28 → metacontroller_pytorch-0.0.29}/metacontroller/metacontroller_with_binary_mapper.py +0 -0
- {metacontroller_pytorch-0.0.28 → metacontroller_pytorch-0.0.29}/tests/test_metacontroller.py +0 -0
- {metacontroller_pytorch-0.0.28 → metacontroller_pytorch-0.0.29}/train_babyai.py +0 -0
{metacontroller_pytorch-0.0.28 → metacontroller_pytorch-0.0.29}/PKG-INFO
RENAMED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: metacontroller-pytorch
-Version: 0.0.28
+Version: 0.0.29
 Summary: Transformer Metacontroller
 Project-URL: Homepage, https://pypi.org/project/metacontroller/
 Project-URL: Repository, https://github.com/lucidrains/metacontroller
@@ -67,6 +67,8 @@ $ pip install metacontroller-pytorch
 
 - [Pranoy](https://github.com/pranoyr) for submitting a pull request for fixing the previous latent action not being included in the inputs to the switching unit
 
+- [Diego Calanzone](https://github.com/ddidacus) for proposing testing on BabyAI gridworld task, and submitting the [pull request](https://github.com/lucidrains/metacontroller/pull/3) for behavior cloning and discovery phase training for it!
+
 ## Citations
 
 ```bibtex

@@ -103,3 +105,5 @@ $ pip install metacontroller-pytorch
     url = {https://arxiv.org/abs/2510.17558},
 }
 ```
+
+*Life can only be understood backwards; but it must be lived forwards* - Søren Kierkegaard
{metacontroller_pytorch-0.0.28 → metacontroller_pytorch-0.0.29}/README.md
RENAMED

@@ -14,6 +14,8 @@ $ pip install metacontroller-pytorch
 
 - [Pranoy](https://github.com/pranoyr) for submitting a pull request for fixing the previous latent action not being included in the inputs to the switching unit
 
+- [Diego Calanzone](https://github.com/ddidacus) for proposing testing on BabyAI gridworld task, and submitting the [pull request](https://github.com/lucidrains/metacontroller/pull/3) for behavior cloning and discovery phase training for it!
+
 ## Citations
 
 ```bibtex

@@ -50,3 +52,5 @@ $ pip install metacontroller-pytorch
     url = {https://arxiv.org/abs/2510.17558},
 }
 ```
+
+*Life can only be understood backwards; but it must be lived forwards* - Søren Kierkegaard
{metacontroller_pytorch-0.0.28 → metacontroller_pytorch-0.0.29}/gather_babyai_trajs.py
RENAMED

@@ -31,8 +31,13 @@ import gymnasium as gym
 from minigrid.utils.baby_ai_bot import BabyAIBot
 from minigrid.wrappers import FullyObsWrapper, SymbolicObsWrapper
 
+from gymnasium import spaces
+from gymnasium.core import ObservationWrapper
+from minigrid.core.constants import OBJECT_TO_IDX
+
 from memmap_replay_buffer import ReplayBuffer
 
+
 # helpers
 
 def exists(val):
@@ -41,6 +46,40 @@ def exists(val):
 def sample(prob):
     return random.random() < prob
 
+# wrapper, necessarily modified to allow for both rgb obs (policy) and symbolic obs (bot)
+
+class RGBImgPartialObsWrapper(ObservationWrapper):
+    """
+    Wrapper to use partially observable RGB image as observation.
+    This can be used to have the agent to solve the gridworld in pixel space.
+    """
+    def __init__(self, env, tile_size=1):
+        super().__init__(env)
+
+        # Rendering attributes for observations
+        self.tile_size = tile_size
+
+        symbolic_image_space = self.observation_space["image"]
+
+        obs_shape = env.observation_space.spaces["image"].shape
+        new_image_space = spaces.Box(
+            low=0,
+            high=255,
+            shape=(obs_shape[0] * tile_size, obs_shape[1] * tile_size, 3),
+            dtype="uint8",
+        )
+
+        self.observation_space = spaces.Dict(
+            {**self.observation_space.spaces, "image": symbolic_image_space, "rgb_image": new_image_space}
+        )
+
+    def observation(self, obs):
+        rgb_img_partial = self.unwrapped.get_frame(
+            tile_size=self.tile_size, agent_pov=True
+        )
+
+        return {**obs, "rgb_image": rgb_img_partial}
+
 # agent
 
 class BabyAIBotEpsilonGreedy:
@@ -72,6 +111,7 @@ def collect_single_episode(env_id, seed, num_steps, random_action_prob, state_sh
     env = gym.make(env_id, render_mode="rgb_array", highlight=False)
     env = FullyObsWrapper(env.unwrapped)
     env = SymbolicObsWrapper(env.unwrapped)
+    env = RGBImgPartialObsWrapper(env.unwrapped)
 
     try:
         state_obs, _ = env.reset(seed=seed)
@@ -88,7 +128,7 @@ def collect_single_episode(env_id, seed, num_steps, random_action_prob, state_sh
             env.close()
             return None, None, False, 0
 
-        episode_state[_step] = state_obs["image"]
+        episode_state[_step] = state_obs["rgb_image"] / 255. # normalized to 0 to 1
         episode_action[_step] = action
 
         state_obs, reward, terminated, truncated, info = env.step(action)
@@ -127,7 +167,8 @@ def collect_demonstrations(
     temp_env = gym.make(env_id)
     temp_env = FullyObsWrapper(temp_env.unwrapped)
     temp_env = SymbolicObsWrapper(temp_env.unwrapped)
-    state_shape = temp_env.observation_space['image'].shape
+    temp_env = RGBImgPartialObsWrapper(temp_env.unwrapped)
+    state_shape = temp_env.observation_space['rgb_image'].shape
     temp_env.close()
 
     logger.info(f"Detected state shape: {state_shape} for env {env_id}")
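For orientation, here is a minimal sketch of how the wrapper stack added above is wired in `collect_single_episode`, so a single env exposes both the symbolic grid (consumed by the `BabyAIBot` demonstrator) and RGB pixels (stored as the policy state). It assumes the 0.0.29 `gather_babyai_trajs.py` is importable from the working directory; the env id and seed are illustrative.

```python
import gymnasium as gym
import minigrid
from minigrid.wrappers import FullyObsWrapper, SymbolicObsWrapper

# assumed import: RGBImgPartialObsWrapper is defined in gather_babyai_trajs.py in this release
from gather_babyai_trajs import RGBImgPartialObsWrapper

minigrid.register_minigrid_envs()

# mirror the wrapping order used in collect_single_episode
env = gym.make("BabyAI-MiniBossLevel-v0", render_mode="rgb_array", highlight=False)
env = FullyObsWrapper(env.unwrapped)
env = SymbolicObsWrapper(env.unwrapped)
env = RGBImgPartialObsWrapper(env.unwrapped)

obs, _ = env.reset(seed=0)
print(obs["image"].shape)      # grid observation, used by the bot
print(obs["rgb_image"].shape)  # RGB frame, divided by 255. before being stored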
metacontroller_pytorch-0.0.29/metacontroller/metacontroller_with_resnet.py
ADDED

@@ -0,0 +1,250 @@
+from typing import Any, List, Type, Union, Optional
+
+import torch
+from torch import Tensor
+from torch import nn
+from einops import rearrange
+from metacontroller.metacontroller import Transformer
+
+class TransformerWithResnetEncoder(Transformer):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.resnet_dim = kwargs["state_embed_readout"]["num_continuous"]
+        self.visual_encoder = resnet18(out_dim=self.resnet_dim)
+
+    def visual_encode(self, x: torch.Tensor) -> torch.Tensor:
+        b, t = x.shape[:2]
+        x = rearrange(x, 'b t h w c -> (b t) c h w')
+        h = self.visual_encoder(x)
+        h = rearrange(h, '(b t) d -> b t d', b=b, t=t, d = self.resnet_dim)
+        return h
+
+# resnet components taken from https://github.com/Lornatang/ResNet-PyTorch
+
+class _BasicBlock(nn.Module):
+    expansion: int = 1
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        stride: int,
+        downsample: Optional[nn.Module] = None,
+        groups: int = 1,
+        base_channels: int = 64,
+    ) -> None:
+        super(_BasicBlock, self).__init__()
+        self.stride = stride
+        self.downsample = downsample
+        self.groups = groups
+        self.base_channels = base_channels
+
+        self.conv1 = nn.Conv2d(in_channels, out_channels, (3, 3), (stride, stride), (1, 1), bias=False)
+        self.bn1 = nn.BatchNorm2d(out_channels)
+        self.relu = nn.ReLU(True)
+        self.conv2 = nn.Conv2d(out_channels, out_channels, (3, 3), (1, 1), (1, 1), bias=False)
+        self.bn2 = nn.BatchNorm2d(out_channels)
+
+    def forward(self, x: Tensor) -> Tensor:
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out = torch.add(out, identity)
+        out = self.relu(out)
+
+        return out
+
+
+class _Bottleneck(nn.Module):
+    expansion: int = 4
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        stride: int,
+        downsample: Optional[nn.Module] = None,
+        groups: int = 1,
+        base_channels: int = 64,
+    ) -> None:
+        super(_Bottleneck, self).__init__()
+        self.stride = stride
+        self.downsample = downsample
+        self.groups = groups
+        self.base_channels = base_channels
+
+        channels = int(out_channels * (base_channels / 64.0)) * groups
+
+        self.conv1 = nn.Conv2d(in_channels, channels, (1, 1), (1, 1), (0, 0), bias=False)
+        self.bn1 = nn.BatchNorm2d(channels)
+        self.conv2 = nn.Conv2d(channels, channels, (3, 3), (stride, stride), (1, 1), groups=groups, bias=False)
+        self.bn2 = nn.BatchNorm2d(channels)
+        self.conv3 = nn.Conv2d(channels, int(out_channels * self.expansion), (1, 1), (1, 1), (0, 0), bias=False)
+        self.bn3 = nn.BatchNorm2d(int(out_channels * self.expansion))
+        self.relu = nn.ReLU(True)
+
+    def forward(self, x: Tensor) -> Tensor:
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.relu(out)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out = torch.add(out, identity)
+        out = self.relu(out)
+
+        return out
+
+
+class ResNet(nn.Module):
+
+    def __init__(
+        self,
+        arch_cfg: List[int],
+        block: Type[Union[_BasicBlock, _Bottleneck]],
+        groups: int = 1,
+        channels_per_group: int = 64,
+        out_dim: int = 1000,
+    ) -> None:
+        super(ResNet, self).__init__()
+        self.in_channels = 64
+        self.dilation = 1
+        self.groups = groups
+        self.base_channels = channels_per_group
+
+        self.conv1 = nn.Conv2d(3, self.in_channels, (7, 7), (2, 2), (3, 3), bias=False)
+        self.bn1 = nn.BatchNorm2d(self.in_channels)
+        self.relu = nn.ReLU(True)
+        self.maxpool = nn.MaxPool2d((3, 3), (2, 2), (1, 1))
+
+        self.layer1 = self._make_layer(arch_cfg[0], block, 64, 1)
+        self.layer2 = self._make_layer(arch_cfg[1], block, 128, 2)
+        self.layer3 = self._make_layer(arch_cfg[2], block, 256, 2)
+        self.layer4 = self._make_layer(arch_cfg[3], block, 512, 2)
+
+        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+
+        self.fc = nn.Linear(512 * block.expansion, out_dim)
+
+        # Initialize neural network weights
+        self._initialize_weights()
+
+    def _make_layer(
+        self,
+        repeat_times: int,
+        block: Type[Union[_BasicBlock, _Bottleneck]],
+        channels: int,
+        stride: int = 1,
+    ) -> nn.Sequential:
+        downsample = None
+
+        if stride != 1 or self.in_channels != channels * block.expansion:
+            downsample = nn.Sequential(
+                nn.Conv2d(self.in_channels, channels * block.expansion, (1, 1), (stride, stride), (0, 0), bias=False),
+                nn.BatchNorm2d(channels * block.expansion),
+            )
+
+        layers = [
+            block(
+                self.in_channels,
+                channels,
+                stride,
+                downsample,
+                self.groups,
+                self.base_channels
+            )
+        ]
+        self.in_channels = channels * block.expansion
+        for _ in range(1, repeat_times):
+            layers.append(
+                block(
+                    self.in_channels,
+                    channels,
+                    1,
+                    None,
+                    self.groups,
+                    self.base_channels,
+                )
+            )
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x: Tensor) -> Tensor:
+        out = self._forward_impl(x)
+
+        return out
+
+    # Support torch.script function
+    def _forward_impl(self, x: Tensor) -> Tensor:
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+        out = self.maxpool(out)
+
+        out = self.layer1(out)
+        out = self.layer2(out)
+        out = self.layer3(out)
+        out = self.layer4(out)
+
+        out = self.avgpool(out)
+        out = torch.flatten(out, 1)
+        out = self.fc(out)
+
+        return out
+
+    def _initialize_weights(self) -> None:
+        for module in self.modules():
+            if isinstance(module, nn.Conv2d):
+                nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
+            elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
+                nn.init.constant_(module.weight, 1)
+                nn.init.constant_(module.bias, 0)
+
+
+def resnet18(**kwargs: Any) -> ResNet:
+    model = ResNet([2, 2, 2, 2], _BasicBlock, **kwargs)
+
+    return model
+
+
+def resnet34(**kwargs: Any) -> ResNet:
+    model = ResNet([3, 4, 6, 3], _BasicBlock, **kwargs)
+
+    return model
+
+
+def resnet50(**kwargs: Any) -> ResNet:
+    model = ResNet([3, 4, 6, 3], _Bottleneck, **kwargs)
+
+    return model
+
+
+def resnet101(**kwargs: Any) -> ResNet:
+    model = ResNet([3, 4, 23, 3], _Bottleneck, **kwargs)
+
+    return model
+
+
+def resnet152(**kwargs: Any) -> ResNet:
+    model = ResNet([3, 8, 36, 3], _Bottleneck, **kwargs)
+
+    return model
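A quick sketch of the tensor bookkeeping `visual_encode` performs, runnable once 0.0.29 is installed; the batch size, sequence length, and 7×7 frame size are illustrative assumptions (a tile_size=1 BabyAI partial view), and `out_dim = 256` matches the `state_dim` the training script uses with `--use_resnet`.

```python
import torch
from einops import rearrange
from metacontroller.metacontroller_with_resnet import resnet18

encoder = resnet18(out_dim = 256)

# (batch, time, height, width, channel) frames, already normalized to [0, 1]
frames = torch.rand(2, 4, 7, 7, 3)

x = rearrange(frames, 'b t h w c -> (b t) c h w')  # fold time into batch for the 2d convnet
h = encoder(x)                                     # ((b t), 256)
h = rearrange(h, '(b t) d -> b t d', b = 2)        # back to one embedding per timestep

print(h.shape)  # torch.Size([2, 4, 256])
```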
metacontroller_pytorch-0.0.29/test_babyai_e2e.sh
ADDED

@@ -0,0 +1,14 @@
+#!/bin/bash
+set -e
+
+# 1. Gather trajectories
+echo "Gathering trajectories..."
+uv run gather_babyai_trajs.py --num_seeds 1000 --num_episodes_per_seed 100 --output_dir end_to_end_trajectories --env_id BabyAI-MiniBossLevel-v0
+
+# 2. Behavioral cloning
+echo "Training behavioral cloning model..."
+uv run train_behavior_clone_babyai.py --cloning_epochs 10 --discovery_epochs 10 --batch_size 256 --input_dir end_to_end_trajectories --env_id BabyAI-MiniBossLevel-v0 --checkpoint_path end_to_end_model.pt --use_resnet
+
+# 3. Inference rollouts
+echo "Running inference rollouts..."
+uv run train_babyai.py --weights_path end_to_end_model.pt --env_name BabyAI-MiniBossLevel-v0 --num_episodes 5 --buffer_size 100 --max_timesteps 100
{metacontroller_pytorch-0.0.28 → metacontroller_pytorch-0.0.29}/train_behavior_clone_babyai.py
RENAMED
@@ -26,6 +26,7 @@ from memmap_replay_buffer import ReplayBuffer
 from einops import rearrange
 
 from metacontroller.metacontroller import Transformer
+from metacontroller.metacontroller_with_resnet import TransformerWithResnetEncoder
 
 import minigrid
 import gymnasium as gym
@@ -33,7 +34,8 @@ import gymnasium as gym
 def train(
     input_dir: str = "babyai-minibosslevel-trajectories",
     env_id: str = "BabyAI-MiniBossLevel-v0",
-    epochs: int = 10,
+    cloning_epochs: int = 10,
+    discovery_epochs: int = 10,
     batch_size: int = 32,
     lr: float = 1e-4,
     dim: int = 512,
@@ -44,7 +46,8 @@ def train(
     wandb_project: str = "metacontroller-babyai-bc",
     checkpoint_path: str = "transformer_bc.pt",
     state_loss_weight: float = 1.,
-    action_loss_weight: float = 1
+    action_loss_weight: float = 1.,
+    use_resnet: bool = False
 ):
     # accelerator
 
@@ -54,7 +57,8 @@
     accelerator.init_trackers(
         wandb_project,
         config = {
-            "epochs": epochs,
+            "cloning_epochs": cloning_epochs,
+            "discovery_epochs": discovery_epochs,
             "batch_size": batch_size,
             "lr": lr,
             "dim": dim,
@@ -78,12 +82,8 @@
     # state shape and action dimension
     # state: (B, T, H, W, C) or (B, T, D)
     state_shape = replay_buffer.shapes['state']
-    state_dim = int(torch.tensor(state_shape).prod().item())
-
-    # state shape and action dimension
-    # state: (B, T, H, W, C) or (B, T, D)
-    state_shape = replay_buffer.shapes['state']
-    state_dim = int(torch.tensor(state_shape).prod().item())
+    if use_resnet: state_dim = 256
+    else: state_dim = int(torch.tensor(state_shape).prod().item())
 
     # deduce num_actions from the environment
     minigrid.register_minigrid_envs()
@@ -94,8 +94,9 @@
     accelerator.print(f"Detected state_dim: {state_dim}, num_actions: {num_actions} from env: {env_id}")
 
     # transformer
-
-    model = Transformer(
+
+    transformer_class = TransformerWithResnetEncoder if use_resnet else Transformer
+    model = transformer_class(
         dim = dim,
         state_embed_readout = dict(num_continuous = state_dim),
         action_embed_readout = dict(num_discrete = num_actions),
@@ -112,13 +113,13 @@
     model, optim, dataloader = accelerator.prepare(model, optim, dataloader)
 
     # training
-
-    for epoch in range(epochs):
+    for epoch in range(cloning_epochs + discovery_epochs):
         model.train()
         total_state_loss = 0.
         total_action_loss = 0.
 
         progress_bar = tqdm(dataloader, desc = f"Epoch {epoch}", disable = not accelerator.is_local_main_process)
+        is_discovering = (epoch >= cloning_epochs) # discovery phase is BC with metacontroller tuning
 
         for batch in progress_bar:
             # batch is a NamedTuple (e.g. MemoryMappedBatch)
@@ -128,15 +129,20 @@
             actions = batch['action'].long()
             episode_lens = batch.get('_lens')
 
-            # flatten state: (B, T, 7, 7, 3) -> (B, T, 147)
-            states = rearrange(states, 'b t ... -> b t (...)')
-
+            # use resnet18 to embed visual observations
+            if use_resnet:
+                states = model.visual_encode(states)
+            else: # flatten state: (B, T, 7, 7, 3) -> (B, T, 147)
+                states = rearrange(states, 'b t ... -> b t (...)')
 
             with accelerator.accumulate(model):
-                state_loss, action_loss = model(states, actions, episode_lens = episode_lens)
+                state_loss, action_loss = model(states, actions, episode_lens = episode_lens, discovery_phase=is_discovering)
                 loss = state_loss * state_loss_weight + action_loss * action_loss_weight
 
                 accelerator.backward(loss)
+
+                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
+
                 optim.step()
                 optim.zero_grad()
 
@@ -148,7 +154,8 @@
             accelerator.log({
                 "state_loss": state_loss.item(),
                 "action_loss": action_loss.item(),
-                "total_loss": loss.item()
+                "total_loss": loss.item(),
+                "grad_norm": grad_norm.item()
             })
 
             progress_bar.set_postfix(
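The training loop above now runs two phases back to back: plain behavior cloning for `cloning_epochs`, then the same objective with `discovery_phase=True` for `discovery_epochs`, with gradients clipped to max-norm 1.0 throughout. A self-contained sketch of just the phase schedule, using the new defaults:

```python
cloning_epochs, discovery_epochs = 10, 10  # new defaults in train()

# is_discovering stays False for the cloning epochs, then flips to True and stays there
schedule = [epoch >= cloning_epochs for epoch in range(cloning_epochs + discovery_epochs)]

assert schedule == [False] * cloning_epochs + [True] * discovery_epochs
```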
metacontroller_pytorch-0.0.28/test_babyai_e2e.sh
REMOVED

@@ -1,14 +0,0 @@
-#!/bin/bash
-set -e
-
-# 1. Gather trajectories
-echo "Gathering trajectories..."
-uv run gather_babyai_trajs.py --num_seeds 10 --num_episodes_per_seed 10 --output_dir end_to_end_trajectories --env_id BabyAI-MiniBossLevel-v0
-
-# 2. Behavioral cloning
-echo "Training behavioral cloning model..."
-uv run train_behavior_clone_babyai.py --epochs 1 --batch_size 16 --input_dir end_to_end_trajectories --env_id BabyAI-MiniBossLevel-v0 --checkpoint_path end_to_end_model.pt
-
-# 3. Inference rollouts
-echo "Running inference rollouts..."
-uv run train_babyai.py --weights_path end_to_end_model.pt --env_name BabyAI-MiniBossLevel-v0 --num_episodes 5 --buffer_size 100 --max_timesteps 100
{metacontroller_pytorch-0.0.28 → metacontroller_pytorch-0.0.29}/.github/workflows/python-publish.yml
RENAMED
File without changes

{metacontroller_pytorch-0.0.28 → metacontroller_pytorch-0.0.29}/.github/workflows/test.yml
RENAMED
File without changes

{metacontroller_pytorch-0.0.28 → metacontroller_pytorch-0.0.29}/.gitignore
RENAMED
File without changes

{metacontroller_pytorch-0.0.28 → metacontroller_pytorch-0.0.29}/LICENSE
RENAMED
File without changes

{metacontroller_pytorch-0.0.28 → metacontroller_pytorch-0.0.29}/fig1.png
RENAMED
File without changes

{metacontroller_pytorch-0.0.28 → metacontroller_pytorch-0.0.29}/metacontroller/__init__.py
RENAMED
File without changes

{metacontroller_pytorch-0.0.28 → metacontroller_pytorch-0.0.29}/metacontroller/metacontroller.py
RENAMED
File without changes

{metacontroller_pytorch-0.0.28 → metacontroller_pytorch-0.0.29}/metacontroller/metacontroller_with_binary_mapper.py
RENAMED
File without changes

{metacontroller_pytorch-0.0.28 → metacontroller_pytorch-0.0.29}/tests/test_metacontroller.py
RENAMED
File without changes

{metacontroller_pytorch-0.0.28 → metacontroller_pytorch-0.0.29}/train_babyai.py
RENAMED
File without changes