metacontroller-pytorch 0.0.32__tar.gz → 0.0.33__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of metacontroller-pytorch might be problematic. Click here for more details.
- {metacontroller_pytorch-0.0.32 → metacontroller_pytorch-0.0.33}/PKG-INFO +2 -2
- {metacontroller_pytorch-0.0.32 → metacontroller_pytorch-0.0.33}/metacontroller/metacontroller.py +54 -4
- {metacontroller_pytorch-0.0.32 → metacontroller_pytorch-0.0.33}/metacontroller/metacontroller_with_binary_mapper.py +4 -1
- metacontroller_pytorch-0.0.33/metacontroller/transformer_with_resnet.py +194 -0
- {metacontroller_pytorch-0.0.32 → metacontroller_pytorch-0.0.33}/pyproject.toml +2 -2
- {metacontroller_pytorch-0.0.32 → metacontroller_pytorch-0.0.33}/tests/test_metacontroller.py +80 -14
- {metacontroller_pytorch-0.0.32 → metacontroller_pytorch-0.0.33}/train_behavior_clone_babyai.py +2 -2
- metacontroller_pytorch-0.0.32/metacontroller/metacontroller_with_resnet.py +0 -250
- {metacontroller_pytorch-0.0.32 → metacontroller_pytorch-0.0.33}/.github/workflows/python-publish.yml +0 -0
- {metacontroller_pytorch-0.0.32 → metacontroller_pytorch-0.0.33}/.github/workflows/test.yml +0 -0
- {metacontroller_pytorch-0.0.32 → metacontroller_pytorch-0.0.33}/.gitignore +0 -0
- {metacontroller_pytorch-0.0.32 → metacontroller_pytorch-0.0.33}/LICENSE +0 -0
- {metacontroller_pytorch-0.0.32 → metacontroller_pytorch-0.0.33}/README.md +0 -0
- {metacontroller_pytorch-0.0.32 → metacontroller_pytorch-0.0.33}/fig1.png +0 -0
- {metacontroller_pytorch-0.0.32 → metacontroller_pytorch-0.0.33}/gather_babyai_trajs.py +0 -0
- {metacontroller_pytorch-0.0.32 → metacontroller_pytorch-0.0.33}/metacontroller/__init__.py +0 -0
- {metacontroller_pytorch-0.0.32 → metacontroller_pytorch-0.0.33}/test_babyai_e2e.sh +0 -0
- {metacontroller_pytorch-0.0.32 → metacontroller_pytorch-0.0.33}/train_babyai.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: metacontroller-pytorch
|
|
3
|
-
Version: 0.0.32
|
|
3
|
+
Version: 0.0.33
|
|
4
4
|
Summary: Transformer Metacontroller
|
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/metacontroller/
|
|
6
6
|
Project-URL: Repository, https://github.com/lucidrains/metacontroller
|
|
@@ -40,7 +40,7 @@ Requires-Dist: einops>=0.8.1
|
|
|
40
40
|
Requires-Dist: einx>=0.3.0
|
|
41
41
|
Requires-Dist: loguru
|
|
42
42
|
Requires-Dist: memmap-replay-buffer>=0.0.23
|
|
43
|
-
Requires-Dist: torch-einops-utils>=0.0.
|
|
43
|
+
Requires-Dist: torch-einops-utils>=0.0.19
|
|
44
44
|
Requires-Dist: torch>=2.5
|
|
45
45
|
Requires-Dist: vector-quantize-pytorch>=1.27.20
|
|
46
46
|
Requires-Dist: x-evolution>=0.1.23
|
{metacontroller_pytorch-0.0.32 → metacontroller_pytorch-0.0.33}/metacontroller/metacontroller.py
RENAMED
|
@@ -26,7 +26,7 @@ from discrete_continuous_embed_readout import Embed, Readout, EmbedAndReadout
|
|
|
26
26
|
|
|
27
27
|
from assoc_scan import AssocScan
|
|
28
28
|
|
|
29
|
-
from torch_einops_utils import maybe, pad_at_dim, lens_to_mask
|
|
29
|
+
from torch_einops_utils import maybe, pad_at_dim, lens_to_mask, masked_mean, align_dims_left, pad_right_ndim_to
|
|
30
30
|
from torch_einops_utils.save_load import save_load
|
|
31
31
|
|
|
32
32
|
# constants
|
|
@@ -66,6 +66,47 @@ MetaControllerOutput = namedtuple('MetaControllerOutput', (
|
|
|
66
66
|
'switch_loss'
|
|
67
67
|
))
|
|
68
68
|
|
|
69
|
+
def z_score(t, eps = 1e-8):
|
|
70
|
+
return (t - t.mean()) / (t.std() + eps)
|
|
71
|
+
|
|
72
|
+
def policy_loss(
|
|
73
|
+
meta_controller,
|
|
74
|
+
state,
|
|
75
|
+
old_log_probs,
|
|
76
|
+
actions,
|
|
77
|
+
advantages,
|
|
78
|
+
mask,
|
|
79
|
+
episode_lens = None,
|
|
80
|
+
eps_clip = 0.2
|
|
81
|
+
):
|
|
82
|
+
# get new log probs
|
|
83
|
+
|
|
84
|
+
action_dist = meta_controller.get_action_dist_for_internal_rl(state)
|
|
85
|
+
new_log_probs = meta_controller.log_prob(action_dist, actions)
|
|
86
|
+
|
|
87
|
+
# calculate ratio
|
|
88
|
+
|
|
89
|
+
ratio = (new_log_probs - old_log_probs).exp()
|
|
90
|
+
|
|
91
|
+
# align ratio and advantages
|
|
92
|
+
|
|
93
|
+
ratio, advantages = align_dims_left((ratio, advantages))
|
|
94
|
+
|
|
95
|
+
# ppo surrogate loss
|
|
96
|
+
|
|
97
|
+
surr1 = ratio * advantages
|
|
98
|
+
surr2 = ratio.clamp(1 - eps_clip, 1 + eps_clip) * advantages
|
|
99
|
+
|
|
100
|
+
losses = -torch.min(surr1, surr2)
|
|
101
|
+
|
|
102
|
+
# masking
|
|
103
|
+
|
|
104
|
+
if exists(episode_lens):
|
|
105
|
+
mask, episode_mask = align_dims_left((mask, lens_to_mask(episode_lens, losses.shape[1])))
|
|
106
|
+
mask = mask & episode_mask
|
|
107
|
+
|
|
108
|
+
return masked_mean(losses, mask)
|
|
109
|
+
|
|
69
110
|
@save_load()
|
|
70
111
|
class MetaController(Module):
|
|
71
112
|
def __init__(
|
|
@@ -146,6 +187,16 @@ class MetaController(Module):
|
|
|
146
187
|
*self.action_proposer_mean_log_var.parameters()
|
|
147
188
|
]
|
|
148
189
|
|
|
190
|
+
def get_action_dist_for_internal_rl(
|
|
191
|
+
self,
|
|
192
|
+
residual_stream
|
|
193
|
+
):
|
|
194
|
+
meta_embed = self.model_to_meta(residual_stream)
|
|
195
|
+
|
|
196
|
+
proposed_action_hidden, _ = self.action_proposer(meta_embed)
|
|
197
|
+
|
|
198
|
+
return self.action_proposer_mean_log_var(proposed_action_hidden)
|
|
199
|
+
|
|
149
200
|
def log_prob(
|
|
150
201
|
self,
|
|
151
202
|
action_dist,
|
|
@@ -429,9 +480,8 @@ class Transformer(Module):
|
|
|
429
480
|
# maybe return behavior cloning loss
|
|
430
481
|
|
|
431
482
|
if behavioral_cloning:
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
loss_mask = lens_to_mask(episode_lens, state.shape[1])
|
|
483
|
+
|
|
484
|
+
loss_mask = maybe(lens_to_mask)(episode_lens, state.shape[1])
|
|
435
485
|
|
|
436
486
|
state_dist_params = self.state_readout(attended)
|
|
437
487
|
state_clone_loss = self.state_readout.calculate_loss(state_dist_params, target_state, mask = loss_mask)
|
|
@@ -23,7 +23,7 @@ from x_mlps_pytorch import Feedforwards
|
|
|
23
23
|
|
|
24
24
|
from assoc_scan import AssocScan
|
|
25
25
|
|
|
26
|
-
from torch_einops_utils import maybe, pad_at_dim, lens_to_mask, align_dims_left
|
|
26
|
+
from torch_einops_utils import maybe, pad_at_dim, lens_to_mask, masked_mean, align_dims_left, pad_right_ndim_to
|
|
27
27
|
from torch_einops_utils.save_load import save_load
|
|
28
28
|
|
|
29
29
|
from vector_quantize_pytorch import BinaryMapper
|
|
@@ -74,6 +74,9 @@ class MetaControllerWithBinaryMapper(Module):
|
|
|
74
74
|
kl_loss_threshold = 0.
|
|
75
75
|
):
|
|
76
76
|
super().__init__()
|
|
77
|
+
|
|
78
|
+
assert not switch_per_code, 'switch_per_code is not supported for binary mapper'
|
|
79
|
+
|
|
77
80
|
dim_meta = default(dim_meta_controller, dim_model)
|
|
78
81
|
|
|
79
82
|
self.model_to_meta = Linear(dim_model, dim_meta)
|
|
@@ -0,0 +1,194 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from torch import nn, Tensor
|
|
5
|
+
from torch.nn import Module, ModuleList
|
|
6
|
+
from einops import rearrange
|
|
7
|
+
from einops.layers.torch import Rearrange
|
|
8
|
+
|
|
9
|
+
from metacontroller.metacontroller import Transformer
|
|
10
|
+
|
|
11
|
+
from torch_einops_utils import pack_with_inverse
|
|
12
|
+
|
|
13
|
+
# resnet components
|
|
14
|
+
|
|
15
|
+
def exists(v):
|
|
16
|
+
return v is not None
|
|
17
|
+
|
|
18
|
+
class BasicBlock(Module):
|
|
19
|
+
expansion = 1
|
|
20
|
+
|
|
21
|
+
def __init__(
|
|
22
|
+
self,
|
|
23
|
+
dim,
|
|
24
|
+
dim_out,
|
|
25
|
+
stride = 1,
|
|
26
|
+
downsample: Module | None = None
|
|
27
|
+
):
|
|
28
|
+
super().__init__()
|
|
29
|
+
self.conv1 = nn.Conv2d(dim, dim_out, 3, stride = stride, padding = 1, bias = False)
|
|
30
|
+
self.bn1 = nn.BatchNorm2d(dim_out)
|
|
31
|
+
self.relu = nn.ReLU(inplace = True)
|
|
32
|
+
self.conv2 = nn.Conv2d(dim_out, dim_out, 3, padding = 1, bias = False)
|
|
33
|
+
self.bn2 = nn.BatchNorm2d(dim_out)
|
|
34
|
+
self.downsample = downsample
|
|
35
|
+
|
|
36
|
+
def forward(self, x: Tensor) -> Tensor:
|
|
37
|
+
identity = x
|
|
38
|
+
|
|
39
|
+
out = self.conv1(x)
|
|
40
|
+
out = self.bn1(out)
|
|
41
|
+
out = self.relu(out)
|
|
42
|
+
|
|
43
|
+
out = self.conv2(out)
|
|
44
|
+
out = self.bn2(out)
|
|
45
|
+
|
|
46
|
+
if exists(self.downsample):
|
|
47
|
+
identity = self.downsample(x)
|
|
48
|
+
|
|
49
|
+
out += identity
|
|
50
|
+
return self.relu(out)
|
|
51
|
+
|
|
52
|
+
class Bottleneck(Module):
|
|
53
|
+
expansion = 4
|
|
54
|
+
|
|
55
|
+
def __init__(
|
|
56
|
+
self,
|
|
57
|
+
dim,
|
|
58
|
+
dim_out,
|
|
59
|
+
stride = 1,
|
|
60
|
+
downsample: Module | None = None
|
|
61
|
+
):
|
|
62
|
+
super().__init__()
|
|
63
|
+
width = dim_out # simple resnet shortcut
|
|
64
|
+
self.conv1 = nn.Conv2d(dim, width, 1, bias = False)
|
|
65
|
+
self.bn1 = nn.BatchNorm2d(width)
|
|
66
|
+
self.conv2 = nn.Conv2d(width, width, 3, stride = stride, padding = 1, bias = False)
|
|
67
|
+
self.bn2 = nn.BatchNorm2d(width)
|
|
68
|
+
self.conv3 = nn.Conv2d(width, dim_out * self.expansion, 1, bias = False)
|
|
69
|
+
self.bn3 = nn.BatchNorm2d(dim_out * self.expansion)
|
|
70
|
+
self.relu = nn.ReLU(inplace = True)
|
|
71
|
+
self.downsample = downsample
|
|
72
|
+
|
|
73
|
+
def forward(self, x: Tensor) -> Tensor:
|
|
74
|
+
identity = x
|
|
75
|
+
|
|
76
|
+
out = self.conv1(x)
|
|
77
|
+
out = self.bn1(out)
|
|
78
|
+
out = self.relu(out)
|
|
79
|
+
|
|
80
|
+
out = self.conv2(out)
|
|
81
|
+
out = self.bn2(out)
|
|
82
|
+
out = self.relu(out)
|
|
83
|
+
|
|
84
|
+
out = self.conv3(out)
|
|
85
|
+
out = self.bn3(out)
|
|
86
|
+
|
|
87
|
+
if exists(self.downsample):
|
|
88
|
+
identity = self.downsample(x)
|
|
89
|
+
|
|
90
|
+
out += identity
|
|
91
|
+
return self.relu(out)
|
|
92
|
+
|
|
93
|
+
class ResNet(Module):
|
|
94
|
+
def __init__(
|
|
95
|
+
self,
|
|
96
|
+
block: type[BasicBlock | Bottleneck],
|
|
97
|
+
layers: list[int],
|
|
98
|
+
num_classes = 1000,
|
|
99
|
+
channels = 3
|
|
100
|
+
):
|
|
101
|
+
super().__init__()
|
|
102
|
+
self.inplanes = 64
|
|
103
|
+
|
|
104
|
+
self.conv1 = nn.Conv2d(channels, 64, kernel_size = 7, stride = 2, padding = 3, bias = False)
|
|
105
|
+
self.bn1 = nn.BatchNorm2d(64)
|
|
106
|
+
self.relu = nn.ReLU(inplace = True)
|
|
107
|
+
self.maxpool = nn.MaxPool2d(kernel_size = 3, stride = 2, padding = 1)
|
|
108
|
+
|
|
109
|
+
self.layer1 = self._make_layer(block, 64, layers[0])
|
|
110
|
+
self.layer2 = self._make_layer(block, 128, layers[1], stride = 2)
|
|
111
|
+
self.layer3 = self._make_layer(block, 256, layers[2], stride = 2)
|
|
112
|
+
self.layer4 = self._make_layer(block, 512, layers[3], stride = 2)
|
|
113
|
+
|
|
114
|
+
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
|
115
|
+
self.flatten = Rearrange('b c 1 1 -> b c')
|
|
116
|
+
self.fc = nn.Linear(512 * block.expansion, num_classes)
|
|
117
|
+
|
|
118
|
+
def _make_layer(
|
|
119
|
+
self,
|
|
120
|
+
block: type[BasicBlock | Bottleneck],
|
|
121
|
+
planes: int,
|
|
122
|
+
blocks: int,
|
|
123
|
+
stride: int = 1
|
|
124
|
+
) -> nn.Sequential:
|
|
125
|
+
downsample = None
|
|
126
|
+
if stride != 1 or self.inplanes != planes * block.expansion:
|
|
127
|
+
downsample = nn.Sequential(
|
|
128
|
+
nn.Conv2d(self.inplanes, planes * block.expansion, 1, stride = stride, bias = False),
|
|
129
|
+
nn.BatchNorm2d(planes * block.expansion),
|
|
130
|
+
)
|
|
131
|
+
|
|
132
|
+
layers = []
|
|
133
|
+
layers.append(block(self.inplanes, planes, stride, downsample))
|
|
134
|
+
self.inplanes = planes * block.expansion
|
|
135
|
+
for _ in range(1, blocks):
|
|
136
|
+
layers.append(block(self.inplanes, planes))
|
|
137
|
+
|
|
138
|
+
return nn.Sequential(*layers)
|
|
139
|
+
|
|
140
|
+
def forward(self, x: Tensor) -> Tensor:
|
|
141
|
+
x = self.conv1(x)
|
|
142
|
+
x = self.bn1(x)
|
|
143
|
+
x = self.relu(x)
|
|
144
|
+
x = self.maxpool(x)
|
|
145
|
+
|
|
146
|
+
x = self.layer1(x)
|
|
147
|
+
x = self.layer2(x)
|
|
148
|
+
x = self.layer3(x)
|
|
149
|
+
x = self.layer4(x)
|
|
150
|
+
|
|
151
|
+
x = self.avgpool(x)
|
|
152
|
+
x = self.flatten(x)
|
|
153
|
+
x = self.fc(x)
|
|
154
|
+
return x
|
|
155
|
+
|
|
156
|
+
# resnet factory
|
|
157
|
+
|
|
158
|
+
def resnet18(num_classes: any = 1000):
|
|
159
|
+
return ResNet(BasicBlock, [2, 2, 2, 2], num_classes)
|
|
160
|
+
|
|
161
|
+
def resnet34(num_classes: any = 1000):
|
|
162
|
+
return ResNet(BasicBlock, [3, 4, 6, 3], num_classes)
|
|
163
|
+
|
|
164
|
+
def resnet50(num_classes: any = 1000):
|
|
165
|
+
return ResNet(Bottleneck, [3, 4, 6, 3], num_classes)
|
|
166
|
+
|
|
167
|
+
# transformer with resnet
|
|
168
|
+
|
|
169
|
+
class TransformerWithResnet(Transformer):
|
|
170
|
+
def __init__(
|
|
171
|
+
self,
|
|
172
|
+
*,
|
|
173
|
+
resnet_type = 'resnet18',
|
|
174
|
+
**kwargs
|
|
175
|
+
):
|
|
176
|
+
super().__init__(**kwargs)
|
|
177
|
+
resnet_klass = resnet18
|
|
178
|
+
if resnet_type == 'resnet34':
|
|
179
|
+
resnet_klass = resnet34
|
|
180
|
+
elif resnet_type == 'resnet50':
|
|
181
|
+
resnet_klass = resnet50
|
|
182
|
+
|
|
183
|
+
self.resnet_dim = kwargs['state_embed_readout']['num_continuous']
|
|
184
|
+
self.visual_encoder = resnet_klass(num_classes = self.resnet_dim)
|
|
185
|
+
|
|
186
|
+
def visual_encode(self, x: Tensor) -> Tensor:
|
|
187
|
+
if x.shape[-1] == 3:
|
|
188
|
+
x = rearrange(x, '... h w c -> ... c h w')
|
|
189
|
+
|
|
190
|
+
x, inverse = pack_with_inverse(x, '* c h w')
|
|
191
|
+
|
|
192
|
+
h = self.visual_encoder(x)
|
|
193
|
+
|
|
194
|
+
return inverse(h, '* d')
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[project]
|
|
2
2
|
name = "metacontroller-pytorch"
|
|
3
|
-
version = "0.0.32"
|
|
3
|
+
version = "0.0.33"
|
|
4
4
|
description = "Transformer Metacontroller"
|
|
5
5
|
authors = [
|
|
6
6
|
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
|
@@ -31,7 +31,7 @@ dependencies = [
|
|
|
31
31
|
"loguru",
|
|
32
32
|
"memmap-replay-buffer>=0.0.23",
|
|
33
33
|
"torch>=2.5",
|
|
34
|
-
"torch-einops-utils>=0.0.
|
|
34
|
+
"torch-einops-utils>=0.0.19",
|
|
35
35
|
"vector-quantize-pytorch>=1.27.20",
|
|
36
36
|
"x-evolution>=0.1.23",
|
|
37
37
|
"x-mlps-pytorch",
|
{metacontroller_pytorch-0.0.32 → metacontroller_pytorch-0.0.33}/tests/test_metacontroller.py
RENAMED
|
@@ -4,19 +4,30 @@ param = pytest.mark.parametrize
|
|
|
4
4
|
from pathlib import Path
|
|
5
5
|
|
|
6
6
|
import torch
|
|
7
|
-
from
|
|
7
|
+
from torch import cat
|
|
8
|
+
from metacontroller.metacontroller import Transformer, MetaController, policy_loss, z_score
|
|
8
9
|
from metacontroller.metacontroller_with_binary_mapper import MetaControllerWithBinaryMapper
|
|
9
10
|
|
|
10
11
|
from einops import rearrange
|
|
11
12
|
|
|
12
|
-
|
|
13
|
+
# functions
|
|
14
|
+
|
|
15
|
+
def exists(v):
|
|
16
|
+
return v is not None
|
|
17
|
+
|
|
18
|
+
# test
|
|
19
|
+
|
|
20
|
+
@param('use_binary_mapper_variant, switch_per_latent_dim', [
|
|
21
|
+
(False, False),
|
|
22
|
+
(False, True),
|
|
23
|
+
(True, False)
|
|
24
|
+
])
|
|
13
25
|
@param('action_discrete', (False, True))
|
|
14
|
-
@param('switch_per_latent_dim', (False, True))
|
|
15
26
|
@param('variable_length', (False, True))
|
|
16
27
|
def test_metacontroller(
|
|
17
28
|
use_binary_mapper_variant,
|
|
18
|
-
action_discrete,
|
|
19
29
|
switch_per_latent_dim,
|
|
30
|
+
action_discrete,
|
|
20
31
|
variable_length
|
|
21
32
|
):
|
|
22
33
|
|
|
@@ -69,22 +80,77 @@ def test_metacontroller(
|
|
|
69
80
|
|
|
70
81
|
# internal rl - done iteratively
|
|
71
82
|
|
|
72
|
-
|
|
73
|
-
|
|
83
|
+
# simulate grpo
|
|
84
|
+
|
|
85
|
+
all_episodes = []
|
|
86
|
+
all_rewards = []
|
|
87
|
+
|
|
88
|
+
for _ in range(3): # group of 3
|
|
89
|
+
subset_state = state[:1]
|
|
90
|
+
|
|
91
|
+
cache = None
|
|
92
|
+
past_action_id = None
|
|
93
|
+
|
|
94
|
+
states = []
|
|
95
|
+
log_probs = []
|
|
96
|
+
switch_betas = []
|
|
97
|
+
latent_actions = []
|
|
98
|
+
|
|
99
|
+
for one_state in subset_state.unbind(dim = 1):
|
|
100
|
+
one_state = rearrange(one_state, 'b d -> b 1 d')
|
|
101
|
+
|
|
102
|
+
logits, cache = model(one_state, past_action_id, meta_controller = meta_controller, return_cache = True)
|
|
74
103
|
|
|
75
|
-
|
|
76
|
-
one_state = rearrange(one_state, 'b d -> b 1 d')
|
|
104
|
+
past_action_id = model.action_readout.sample(logits)
|
|
77
105
|
|
|
78
|
-
|
|
106
|
+
# get log prob from meta controller latent actions
|
|
79
107
|
|
|
80
|
-
|
|
81
|
-
past_action_id = model.action_readout.sample(logits)
|
|
108
|
+
meta_output = cache.prev_hiddens.meta_controller
|
|
82
109
|
|
|
83
|
-
|
|
110
|
+
old_log_probs = meta_controller.log_prob(meta_output.action_dist, meta_output.actions)
|
|
84
111
|
|
|
85
|
-
|
|
112
|
+
states.append(meta_output.input_residual_stream)
|
|
113
|
+
log_probs.append(old_log_probs)
|
|
114
|
+
switch_betas.append(meta_output.switch_beta)
|
|
115
|
+
latent_actions.append(meta_output.actions)
|
|
116
|
+
|
|
117
|
+
# accumulate across time for the episode data
|
|
118
|
+
|
|
119
|
+
all_episodes.append(dict(
|
|
120
|
+
states = cat(states, dim = 1),
|
|
121
|
+
log_probs = cat(log_probs, dim = 1),
|
|
122
|
+
switch_betas = cat(switch_betas, dim = 1),
|
|
123
|
+
latent_actions = cat(latent_actions, dim = 1)
|
|
124
|
+
))
|
|
125
|
+
|
|
126
|
+
all_rewards.append(torch.randn(1))
|
|
127
|
+
|
|
128
|
+
# calculate advantages using z-score
|
|
129
|
+
|
|
130
|
+
rewards = cat(all_rewards)
|
|
131
|
+
advantages = z_score(rewards)
|
|
132
|
+
|
|
133
|
+
assert advantages.shape == (3,)
|
|
134
|
+
|
|
135
|
+
# simulate a policy loss update over the entire group
|
|
136
|
+
|
|
137
|
+
group_states = cat([e['states'] for e in all_episodes], dim = 0)
|
|
138
|
+
group_log_probs = cat([e['log_probs'] for e in all_episodes], dim = 0)
|
|
139
|
+
group_latent_actions = cat([e['latent_actions'] for e in all_episodes], dim = 0)
|
|
140
|
+
group_switch_betas = cat([e['switch_betas'] for e in all_episodes], dim = 0)
|
|
141
|
+
|
|
142
|
+
if not use_binary_mapper_variant:
|
|
143
|
+
loss = policy_loss(
|
|
144
|
+
meta_controller,
|
|
145
|
+
group_states,
|
|
146
|
+
group_log_probs,
|
|
147
|
+
group_latent_actions,
|
|
148
|
+
advantages,
|
|
149
|
+
group_switch_betas == 1.,
|
|
150
|
+
episode_lens = episode_lens[:1].repeat(3) if exists(episode_lens) else None
|
|
151
|
+
)
|
|
86
152
|
|
|
87
|
-
|
|
153
|
+
loss.backward()
|
|
88
154
|
|
|
89
155
|
# evolutionary strategies over grpo
|
|
90
156
|
|
{metacontroller_pytorch-0.0.32 → metacontroller_pytorch-0.0.33}/train_behavior_clone_babyai.py
RENAMED
|
@@ -26,7 +26,7 @@ from memmap_replay_buffer import ReplayBuffer
|
|
|
26
26
|
from einops import rearrange
|
|
27
27
|
|
|
28
28
|
from metacontroller.metacontroller import Transformer
|
|
29
|
-
from metacontroller.metacontroller_with_resnet import TransformerWithResnetEncoder
|
|
29
|
+
from metacontroller.transformer_with_resnet import TransformerWithResnet
|
|
30
30
|
|
|
31
31
|
import minigrid
|
|
32
32
|
import gymnasium as gym
|
|
@@ -95,7 +95,7 @@ def train(
|
|
|
95
95
|
|
|
96
96
|
# transformer
|
|
97
97
|
|
|
98
|
-
transformer_class = TransformerWithResnetEncoder if use_resnet else Transformer
|
|
98
|
+
transformer_class = TransformerWithResnet if use_resnet else Transformer
|
|
99
99
|
model = transformer_class(
|
|
100
100
|
dim = dim,
|
|
101
101
|
state_embed_readout = dict(num_continuous = state_dim),
|
|
@@ -1,250 +0,0 @@
|
|
|
1
|
-
from typing import Any, List, Type, Union, Optional
|
|
2
|
-
|
|
3
|
-
import torch
|
|
4
|
-
from torch import Tensor
|
|
5
|
-
from torch import nn
|
|
6
|
-
from einops import rearrange
|
|
7
|
-
from metacontroller.metacontroller import Transformer
|
|
8
|
-
|
|
9
|
-
class TransformerWithResnetEncoder(Transformer):
|
|
10
|
-
def __init__(self, **kwargs):
|
|
11
|
-
super().__init__(**kwargs)
|
|
12
|
-
self.resnet_dim = kwargs["state_embed_readout"]["num_continuous"]
|
|
13
|
-
self.visual_encoder = resnet18(out_dim=self.resnet_dim)
|
|
14
|
-
|
|
15
|
-
def visual_encode(self, x: torch.Tensor) -> torch.Tensor:
|
|
16
|
-
b, t = x.shape[:2]
|
|
17
|
-
x = rearrange(x, 'b t h w c -> (b t) c h w')
|
|
18
|
-
h = self.visual_encoder(x)
|
|
19
|
-
h = rearrange(h, '(b t) d -> b t d', b=b, t=t, d = self.resnet_dim)
|
|
20
|
-
return h
|
|
21
|
-
|
|
22
|
-
# resnet components taken from https://github.com/Lornatang/ResNet-PyTorch
|
|
23
|
-
|
|
24
|
-
class _BasicBlock(nn.Module):
|
|
25
|
-
expansion: int = 1
|
|
26
|
-
|
|
27
|
-
def __init__(
|
|
28
|
-
self,
|
|
29
|
-
in_channels: int,
|
|
30
|
-
out_channels: int,
|
|
31
|
-
stride: int,
|
|
32
|
-
downsample: Optional[nn.Module] = None,
|
|
33
|
-
groups: int = 1,
|
|
34
|
-
base_channels: int = 64,
|
|
35
|
-
) -> None:
|
|
36
|
-
super(_BasicBlock, self).__init__()
|
|
37
|
-
self.stride = stride
|
|
38
|
-
self.downsample = downsample
|
|
39
|
-
self.groups = groups
|
|
40
|
-
self.base_channels = base_channels
|
|
41
|
-
|
|
42
|
-
self.conv1 = nn.Conv2d(in_channels, out_channels, (3, 3), (stride, stride), (1, 1), bias=False)
|
|
43
|
-
self.bn1 = nn.BatchNorm2d(out_channels)
|
|
44
|
-
self.relu = nn.ReLU(True)
|
|
45
|
-
self.conv2 = nn.Conv2d(out_channels, out_channels, (3, 3), (1, 1), (1, 1), bias=False)
|
|
46
|
-
self.bn2 = nn.BatchNorm2d(out_channels)
|
|
47
|
-
|
|
48
|
-
def forward(self, x: Tensor) -> Tensor:
|
|
49
|
-
identity = x
|
|
50
|
-
|
|
51
|
-
out = self.conv1(x)
|
|
52
|
-
out = self.bn1(out)
|
|
53
|
-
out = self.relu(out)
|
|
54
|
-
|
|
55
|
-
out = self.conv2(out)
|
|
56
|
-
out = self.bn2(out)
|
|
57
|
-
|
|
58
|
-
if self.downsample is not None:
|
|
59
|
-
identity = self.downsample(x)
|
|
60
|
-
|
|
61
|
-
out = torch.add(out, identity)
|
|
62
|
-
out = self.relu(out)
|
|
63
|
-
|
|
64
|
-
return out
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
class _Bottleneck(nn.Module):
|
|
68
|
-
expansion: int = 4
|
|
69
|
-
|
|
70
|
-
def __init__(
|
|
71
|
-
self,
|
|
72
|
-
in_channels: int,
|
|
73
|
-
out_channels: int,
|
|
74
|
-
stride: int,
|
|
75
|
-
downsample: Optional[nn.Module] = None,
|
|
76
|
-
groups: int = 1,
|
|
77
|
-
base_channels: int = 64,
|
|
78
|
-
) -> None:
|
|
79
|
-
super(_Bottleneck, self).__init__()
|
|
80
|
-
self.stride = stride
|
|
81
|
-
self.downsample = downsample
|
|
82
|
-
self.groups = groups
|
|
83
|
-
self.base_channels = base_channels
|
|
84
|
-
|
|
85
|
-
channels = int(out_channels * (base_channels / 64.0)) * groups
|
|
86
|
-
|
|
87
|
-
self.conv1 = nn.Conv2d(in_channels, channels, (1, 1), (1, 1), (0, 0), bias=False)
|
|
88
|
-
self.bn1 = nn.BatchNorm2d(channels)
|
|
89
|
-
self.conv2 = nn.Conv2d(channels, channels, (3, 3), (stride, stride), (1, 1), groups=groups, bias=False)
|
|
90
|
-
self.bn2 = nn.BatchNorm2d(channels)
|
|
91
|
-
self.conv3 = nn.Conv2d(channels, int(out_channels * self.expansion), (1, 1), (1, 1), (0, 0), bias=False)
|
|
92
|
-
self.bn3 = nn.BatchNorm2d(int(out_channels * self.expansion))
|
|
93
|
-
self.relu = nn.ReLU(True)
|
|
94
|
-
|
|
95
|
-
def forward(self, x: Tensor) -> Tensor:
|
|
96
|
-
identity = x
|
|
97
|
-
|
|
98
|
-
out = self.conv1(x)
|
|
99
|
-
out = self.bn1(out)
|
|
100
|
-
out = self.relu(out)
|
|
101
|
-
|
|
102
|
-
out = self.conv2(out)
|
|
103
|
-
out = self.bn2(out)
|
|
104
|
-
out = self.relu(out)
|
|
105
|
-
|
|
106
|
-
out = self.conv3(out)
|
|
107
|
-
out = self.bn3(out)
|
|
108
|
-
|
|
109
|
-
if self.downsample is not None:
|
|
110
|
-
identity = self.downsample(x)
|
|
111
|
-
|
|
112
|
-
out = torch.add(out, identity)
|
|
113
|
-
out = self.relu(out)
|
|
114
|
-
|
|
115
|
-
return out
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
class ResNet(nn.Module):
|
|
119
|
-
|
|
120
|
-
def __init__(
|
|
121
|
-
self,
|
|
122
|
-
arch_cfg: List[int],
|
|
123
|
-
block: Type[Union[_BasicBlock, _Bottleneck]],
|
|
124
|
-
groups: int = 1,
|
|
125
|
-
channels_per_group: int = 64,
|
|
126
|
-
out_dim: int = 1000,
|
|
127
|
-
) -> None:
|
|
128
|
-
super(ResNet, self).__init__()
|
|
129
|
-
self.in_channels = 64
|
|
130
|
-
self.dilation = 1
|
|
131
|
-
self.groups = groups
|
|
132
|
-
self.base_channels = channels_per_group
|
|
133
|
-
|
|
134
|
-
self.conv1 = nn.Conv2d(3, self.in_channels, (7, 7), (2, 2), (3, 3), bias=False)
|
|
135
|
-
self.bn1 = nn.BatchNorm2d(self.in_channels)
|
|
136
|
-
self.relu = nn.ReLU(True)
|
|
137
|
-
self.maxpool = nn.MaxPool2d((3, 3), (2, 2), (1, 1))
|
|
138
|
-
|
|
139
|
-
self.layer1 = self._make_layer(arch_cfg[0], block, 64, 1)
|
|
140
|
-
self.layer2 = self._make_layer(arch_cfg[1], block, 128, 2)
|
|
141
|
-
self.layer3 = self._make_layer(arch_cfg[2], block, 256, 2)
|
|
142
|
-
self.layer4 = self._make_layer(arch_cfg[3], block, 512, 2)
|
|
143
|
-
|
|
144
|
-
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
|
145
|
-
|
|
146
|
-
self.fc = nn.Linear(512 * block.expansion, out_dim)
|
|
147
|
-
|
|
148
|
-
# Initialize neural network weights
|
|
149
|
-
self._initialize_weights()
|
|
150
|
-
|
|
151
|
-
def _make_layer(
|
|
152
|
-
self,
|
|
153
|
-
repeat_times: int,
|
|
154
|
-
block: Type[Union[_BasicBlock, _Bottleneck]],
|
|
155
|
-
channels: int,
|
|
156
|
-
stride: int = 1,
|
|
157
|
-
) -> nn.Sequential:
|
|
158
|
-
downsample = None
|
|
159
|
-
|
|
160
|
-
if stride != 1 or self.in_channels != channels * block.expansion:
|
|
161
|
-
downsample = nn.Sequential(
|
|
162
|
-
nn.Conv2d(self.in_channels, channels * block.expansion, (1, 1), (stride, stride), (0, 0), bias=False),
|
|
163
|
-
nn.BatchNorm2d(channels * block.expansion),
|
|
164
|
-
)
|
|
165
|
-
|
|
166
|
-
layers = [
|
|
167
|
-
block(
|
|
168
|
-
self.in_channels,
|
|
169
|
-
channels,
|
|
170
|
-
stride,
|
|
171
|
-
downsample,
|
|
172
|
-
self.groups,
|
|
173
|
-
self.base_channels
|
|
174
|
-
)
|
|
175
|
-
]
|
|
176
|
-
self.in_channels = channels * block.expansion
|
|
177
|
-
for _ in range(1, repeat_times):
|
|
178
|
-
layers.append(
|
|
179
|
-
block(
|
|
180
|
-
self.in_channels,
|
|
181
|
-
channels,
|
|
182
|
-
1,
|
|
183
|
-
None,
|
|
184
|
-
self.groups,
|
|
185
|
-
self.base_channels,
|
|
186
|
-
)
|
|
187
|
-
)
|
|
188
|
-
|
|
189
|
-
return nn.Sequential(*layers)
|
|
190
|
-
|
|
191
|
-
def forward(self, x: Tensor) -> Tensor:
|
|
192
|
-
out = self._forward_impl(x)
|
|
193
|
-
|
|
194
|
-
return out
|
|
195
|
-
|
|
196
|
-
# Support torch.script function
|
|
197
|
-
def _forward_impl(self, x: Tensor) -> Tensor:
|
|
198
|
-
out = self.conv1(x)
|
|
199
|
-
out = self.bn1(out)
|
|
200
|
-
out = self.relu(out)
|
|
201
|
-
out = self.maxpool(out)
|
|
202
|
-
|
|
203
|
-
out = self.layer1(out)
|
|
204
|
-
out = self.layer2(out)
|
|
205
|
-
out = self.layer3(out)
|
|
206
|
-
out = self.layer4(out)
|
|
207
|
-
|
|
208
|
-
out = self.avgpool(out)
|
|
209
|
-
out = torch.flatten(out, 1)
|
|
210
|
-
out = self.fc(out)
|
|
211
|
-
|
|
212
|
-
return out
|
|
213
|
-
|
|
214
|
-
def _initialize_weights(self) -> None:
|
|
215
|
-
for module in self.modules():
|
|
216
|
-
if isinstance(module, nn.Conv2d):
|
|
217
|
-
nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
|
|
218
|
-
elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
|
|
219
|
-
nn.init.constant_(module.weight, 1)
|
|
220
|
-
nn.init.constant_(module.bias, 0)
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
def resnet18(**kwargs: Any) -> ResNet:
|
|
224
|
-
model = ResNet([2, 2, 2, 2], _BasicBlock, **kwargs)
|
|
225
|
-
|
|
226
|
-
return model
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
def resnet34(**kwargs: Any) -> ResNet:
|
|
230
|
-
model = ResNet([3, 4, 6, 3], _BasicBlock, **kwargs)
|
|
231
|
-
|
|
232
|
-
return model
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
def resnet50(**kwargs: Any) -> ResNet:
|
|
236
|
-
model = ResNet([3, 4, 6, 3], _Bottleneck, **kwargs)
|
|
237
|
-
|
|
238
|
-
return model
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
def resnet101(**kwargs: Any) -> ResNet:
|
|
242
|
-
model = ResNet([3, 4, 23, 3], _Bottleneck, **kwargs)
|
|
243
|
-
|
|
244
|
-
return model
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
def resnet152(**kwargs: Any) -> ResNet:
|
|
248
|
-
model = ResNet([3, 8, 36, 3], _Bottleneck, **kwargs)
|
|
249
|
-
|
|
250
|
-
return model
|
{metacontroller_pytorch-0.0.32 → metacontroller_pytorch-0.0.33}/.github/workflows/python-publish.yml
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|