metacontroller-pytorch 0.0.31__py3-none-any.whl → 0.0.33__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- metacontroller/metacontroller.py +68 -6
- metacontroller/metacontroller_with_binary_mapper.py +24 -1
- metacontroller/transformer_with_resnet.py +194 -0
- {metacontroller_pytorch-0.0.31.dist-info → metacontroller_pytorch-0.0.33.dist-info}/METADATA +2 -2
- metacontroller_pytorch-0.0.33.dist-info/RECORD +8 -0
- metacontroller/metacontroller_with_resnet.py +0 -250
- metacontroller_pytorch-0.0.31.dist-info/RECORD +0 -8
- {metacontroller_pytorch-0.0.31.dist-info → metacontroller_pytorch-0.0.33.dist-info}/WHEEL +0 -0
- {metacontroller_pytorch-0.0.31.dist-info → metacontroller_pytorch-0.0.33.dist-info}/licenses/LICENSE +0 -0
metacontroller/metacontroller.py
CHANGED
|
@@ -26,7 +26,7 @@ from discrete_continuous_embed_readout import Embed, Readout, EmbedAndReadout
|
|
|
26
26
|
|
|
27
27
|
from assoc_scan import AssocScan
|
|
28
28
|
|
|
29
|
-
from torch_einops_utils import maybe, pad_at_dim, lens_to_mask
|
|
29
|
+
from torch_einops_utils import maybe, pad_at_dim, lens_to_mask, masked_mean, align_dims_left, pad_right_ndim_to
|
|
30
30
|
from torch_einops_utils.save_load import save_load
|
|
31
31
|
|
|
32
32
|
# constants
|
|
@@ -66,6 +66,47 @@ MetaControllerOutput = namedtuple('MetaControllerOutput', (
|
|
|
66
66
|
'switch_loss'
|
|
67
67
|
))
|
|
68
68
|
|
|
69
|
+
def z_score(t, eps = 1e-8):
    """Standardize `t` to zero mean and unit standard deviation.

    Mean and (unbiased) standard deviation are taken over all elements of `t`;
    `eps` is added to the std so a constant input does not divide by zero.

    With fewer than two elements the unbiased std is undefined (torch yields
    nan), which would propagate into any loss built on the result, so the
    centered tensor is returned unscaled in that case.
    """
    centered = t - t.mean()

    if t.numel() < 2:
        return centered

    return centered / (t.std() + eps)
|
|
71
|
+
|
|
72
|
+
def policy_loss(
    meta_controller,
    state,
    old_log_probs,
    actions,
    advantages,
    mask,
    episode_lens = None,
    eps_clip = 0.2
):
    """Clipped PPO surrogate loss for the meta controller's internal RL policy.

    Re-evaluates `actions` under the current policy, obtained from
    `meta_controller.get_action_dist_for_internal_rl(state)`. It then forms the
    importance ratio against `old_log_probs` and averages the negated clipped
    surrogate objective with `masked_mean` over `mask`.

    When `episode_lens` is given, it is turned into a length mask over dim 1 of
    the per-position losses and AND-ed into `mask`.
    """
    dist = meta_controller.get_action_dist_for_internal_rl(state)
    log_probs = meta_controller.log_prob(dist, actions)

    # importance ratio pi_new / pi_old, computed in log space
    ratio = (log_probs - old_log_probs).exp()

    # line ratio and advantages up on their leading dims so they broadcast
    ratio, advantages = align_dims_left((ratio, advantages))

    clipped_ratio = ratio.clamp(1 - eps_clip, 1 + eps_clip)

    # pessimistic (min) of unclipped and clipped objectives, negated for descent
    losses = -torch.min(ratio * advantages, clipped_ratio * advantages)

    if exists(episode_lens):
        within_episode = lens_to_mask(episode_lens, losses.shape[1])
        mask, within_episode = align_dims_left((mask, within_episode))
        mask = mask & within_episode

    return masked_mean(losses, mask)
|
|
109
|
+
|
|
69
110
|
@save_load()
|
|
70
111
|
class MetaController(Module):
|
|
71
112
|
def __init__(
|
|
@@ -107,7 +148,6 @@ class MetaController(Module):
|
|
|
107
148
|
|
|
108
149
|
self.switch_per_latent_dim = switch_per_latent_dim
|
|
109
150
|
|
|
110
|
-
|
|
111
151
|
self.dim_latent = dim_latent
|
|
112
152
|
self.switching_unit = GRU(dim_meta + dim_latent, dim_meta)
|
|
113
153
|
self.to_switching_unit_beta = nn.Linear(dim_meta, dim_latent if switch_per_latent_dim else 1, bias = False)
|
|
@@ -147,6 +187,23 @@ class MetaController(Module):
|
|
|
147
187
|
*self.action_proposer_mean_log_var.parameters()
|
|
148
188
|
]
|
|
149
189
|
|
|
190
|
+
def get_action_dist_for_internal_rl(
    self,
    residual_stream
):
    """Return the action proposer's distribution parameters for internal RL.

    The residual stream is projected with `model_to_meta` and run through
    `action_proposer`, whose second return value is discarded. The proposer
    output is then mapped by `action_proposer_mean_log_var`.
    """
    hidden, _ = self.action_proposer(self.model_to_meta(residual_stream))
    return self.action_proposer_mean_log_var(hidden)
|
|
199
|
+
|
|
200
|
+
def log_prob(
    self,
    action_dist,
    sampled_latent_action
):
    """Log probability of `sampled_latent_action` under `action_dist`.

    Thin delegation to `action_proposer_mean_log_var.log_prob`.
    """
    readout = self.action_proposer_mean_log_var
    return readout.log_prob(action_dist, sampled_latent_action)
|
|
206
|
+
|
|
150
207
|
def forward(
|
|
151
208
|
self,
|
|
152
209
|
residual_stream,
|
|
@@ -276,6 +333,12 @@ class MetaController(Module):
|
|
|
276
333
|
|
|
277
334
|
# main transformer, which is subsumed into the environment after behavioral cloning
|
|
278
335
|
|
|
336
|
+
# next hidden states of the Transformer's three components, in call order:
# lower body, meta controller, upper body
Hiddens = namedtuple('Hiddens', ['lower_body', 'meta_controller', 'upper_body'])
|
|
341
|
+
|
|
279
342
|
TransformerOutput = namedtuple('TransformerOutput', (
|
|
280
343
|
'residual_stream_latent',
|
|
281
344
|
'prev_hiddens'
|
|
@@ -417,9 +480,8 @@ class Transformer(Module):
|
|
|
417
480
|
# maybe return behavior cloning loss
|
|
418
481
|
|
|
419
482
|
if behavioral_cloning:
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
loss_mask = lens_to_mask(episode_lens, state.shape[1])
|
|
483
|
+
|
|
484
|
+
loss_mask = maybe(lens_to_mask)(episode_lens, state.shape[1])
|
|
423
485
|
|
|
424
486
|
state_dist_params = self.state_readout(attended)
|
|
425
487
|
state_clone_loss = self.state_readout.calculate_loss(state_dist_params, target_state, mask = loss_mask)
|
|
@@ -441,4 +503,4 @@ class Transformer(Module):
|
|
|
441
503
|
if return_one:
|
|
442
504
|
return dist_params
|
|
443
505
|
|
|
444
|
-
return dist_params, TransformerOutput(residual_stream, (next_lower_hiddens, next_meta_hiddens, next_upper_hiddens))
|
|
506
|
+
return dist_params, TransformerOutput(residual_stream, Hiddens(next_lower_hiddens, next_meta_hiddens, next_upper_hiddens))
|
|
@@ -23,7 +23,7 @@ from x_mlps_pytorch import Feedforwards
|
|
|
23
23
|
|
|
24
24
|
from assoc_scan import AssocScan
|
|
25
25
|
|
|
26
|
-
from torch_einops_utils import maybe, pad_at_dim, lens_to_mask, align_dims_left
|
|
26
|
+
from torch_einops_utils import maybe, pad_at_dim, lens_to_mask, masked_mean, align_dims_left, pad_right_ndim_to
|
|
27
27
|
from torch_einops_utils.save_load import save_load
|
|
28
28
|
|
|
29
29
|
from vector_quantize_pytorch import BinaryMapper
|
|
@@ -50,6 +50,9 @@ def default(*args):
|
|
|
50
50
|
def straight_through(src, tgt):
    """Straight-through estimator: forward value of `tgt`, gradient routed into `src`.

    `src - src.detach()` is exactly zero in the forward pass and has an
    identity gradient with respect to `src`. Adding that grouped term to
    `tgt` makes the forward output exactly `tgt`. The ungrouped form
    `tgt + src - src.detach()` evaluates `(tgt + src) - src`, which can lose
    `tgt` to floating point rounding when `src` is large relative to it.
    """
    return tgt + (src - src.detach())
|
|
52
52
|
|
|
53
|
+
def log(t, eps = 1e-20):
    """Natural log of `t`, clamped from below at `eps` so zeros give a finite value."""
    safe = t.clamp(min = eps)
    return safe.log()
|
|
55
|
+
|
|
53
56
|
# meta controller
|
|
54
57
|
|
|
55
58
|
@save_load()
|
|
@@ -71,6 +74,9 @@ class MetaControllerWithBinaryMapper(Module):
|
|
|
71
74
|
kl_loss_threshold = 0.
|
|
72
75
|
):
|
|
73
76
|
super().__init__()
|
|
77
|
+
|
|
78
|
+
assert not switch_per_code, 'switch_per_code is not supported for binary mapper'
|
|
79
|
+
|
|
74
80
|
dim_meta = default(dim_meta_controller, dim_model)
|
|
75
81
|
|
|
76
82
|
self.model_to_meta = Linear(dim_model, dim_meta)
|
|
@@ -137,6 +143,23 @@ class MetaControllerWithBinaryMapper(Module):
|
|
|
137
143
|
*self.proposer_to_binary_logits.parameters()
|
|
138
144
|
]
|
|
139
145
|
|
|
146
|
+
def log_prob(
    self,
    action_dist,
    sampled_latent_action
):
    """Per-bit log probability of a sampled binary-mapper action.

    Args:
        action_dist: per-bit logits, last dim = number of bits (presumably the
            output of `proposer_to_binary_logits` — confirm at call sites).
        sampled_latent_action: vector over all codes whose argmax selects a
            row of `self.binary_mapper.codes`.

    Returns:
        One log probability per bit, shaped like the selected codes
        `(*batch, bits)`: log sigmoid(logit) where the selected code's bit is 1,
        log(1 - sigmoid(logit)) where it is 0.

    Assumes the vector-quantize-pytorch BinaryMapper convention that a bit is
    sampled as 1 with probability sigmoid(logit) — verify against the installed
    version. Index 1 of the stacked (bit off, bit on) pair must therefore hold
    the "on" log prob. Stacking (p, 1 - p) would score every bit with the
    probability of its opposite value.
    """
    from torch.nn import functional as F

    # logsigmoid(-x) = log(1 - sigmoid(x)); both stay finite for large |logit|
    log_probs = stack((F.logsigmoid(-action_dist), F.logsigmoid(action_dist)), dim = -1)

    indices = sampled_latent_action.argmax(dim = -1)
    codes = self.binary_mapper.codes[indices].long()

    # pick, for every bit, the log prob of the value that bit takes in the code
    action_log_probs = log_probs.gather(-1, codes.unsqueeze(-1))

    return action_log_probs.squeeze(-1)
|
|
162
|
+
|
|
140
163
|
def forward(
|
|
141
164
|
self,
|
|
142
165
|
residual_stream,
|
|
@@ -0,0 +1,194 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from torch import nn, Tensor
|
|
5
|
+
from torch.nn import Module, ModuleList
|
|
6
|
+
from einops import rearrange
|
|
7
|
+
from einops.layers.torch import Rearrange
|
|
8
|
+
|
|
9
|
+
from metacontroller.metacontroller import Transformer
|
|
10
|
+
|
|
11
|
+
from torch_einops_utils import pack_with_inverse
|
|
12
|
+
|
|
13
|
+
# resnet components
|
|
14
|
+
|
|
15
|
+
def exists(v):
    """Return True when `v` is anything other than None (falsy values count as existing)."""
    return not (v is None)
|
|
17
|
+
|
|
18
|
+
class BasicBlock(Module):
|
|
19
|
+
expansion = 1
|
|
20
|
+
|
|
21
|
+
def __init__(
|
|
22
|
+
self,
|
|
23
|
+
dim,
|
|
24
|
+
dim_out,
|
|
25
|
+
stride = 1,
|
|
26
|
+
downsample: Module | None = None
|
|
27
|
+
):
|
|
28
|
+
super().__init__()
|
|
29
|
+
self.conv1 = nn.Conv2d(dim, dim_out, 3, stride = stride, padding = 1, bias = False)
|
|
30
|
+
self.bn1 = nn.BatchNorm2d(dim_out)
|
|
31
|
+
self.relu = nn.ReLU(inplace = True)
|
|
32
|
+
self.conv2 = nn.Conv2d(dim_out, dim_out, 3, padding = 1, bias = False)
|
|
33
|
+
self.bn2 = nn.BatchNorm2d(dim_out)
|
|
34
|
+
self.downsample = downsample
|
|
35
|
+
|
|
36
|
+
def forward(self, x: Tensor) -> Tensor:
|
|
37
|
+
identity = x
|
|
38
|
+
|
|
39
|
+
out = self.conv1(x)
|
|
40
|
+
out = self.bn1(out)
|
|
41
|
+
out = self.relu(out)
|
|
42
|
+
|
|
43
|
+
out = self.conv2(out)
|
|
44
|
+
out = self.bn2(out)
|
|
45
|
+
|
|
46
|
+
if exists(self.downsample):
|
|
47
|
+
identity = self.downsample(x)
|
|
48
|
+
|
|
49
|
+
out += identity
|
|
50
|
+
return self.relu(out)
|
|
51
|
+
|
|
52
|
+
class Bottleneck(Module):
|
|
53
|
+
expansion = 4
|
|
54
|
+
|
|
55
|
+
def __init__(
|
|
56
|
+
self,
|
|
57
|
+
dim,
|
|
58
|
+
dim_out,
|
|
59
|
+
stride = 1,
|
|
60
|
+
downsample: Module | None = None
|
|
61
|
+
):
|
|
62
|
+
super().__init__()
|
|
63
|
+
width = dim_out # simple resnet shortcut
|
|
64
|
+
self.conv1 = nn.Conv2d(dim, width, 1, bias = False)
|
|
65
|
+
self.bn1 = nn.BatchNorm2d(width)
|
|
66
|
+
self.conv2 = nn.Conv2d(width, width, 3, stride = stride, padding = 1, bias = False)
|
|
67
|
+
self.bn2 = nn.BatchNorm2d(width)
|
|
68
|
+
self.conv3 = nn.Conv2d(width, dim_out * self.expansion, 1, bias = False)
|
|
69
|
+
self.bn3 = nn.BatchNorm2d(dim_out * self.expansion)
|
|
70
|
+
self.relu = nn.ReLU(inplace = True)
|
|
71
|
+
self.downsample = downsample
|
|
72
|
+
|
|
73
|
+
def forward(self, x: Tensor) -> Tensor:
|
|
74
|
+
identity = x
|
|
75
|
+
|
|
76
|
+
out = self.conv1(x)
|
|
77
|
+
out = self.bn1(out)
|
|
78
|
+
out = self.relu(out)
|
|
79
|
+
|
|
80
|
+
out = self.conv2(out)
|
|
81
|
+
out = self.bn2(out)
|
|
82
|
+
out = self.relu(out)
|
|
83
|
+
|
|
84
|
+
out = self.conv3(out)
|
|
85
|
+
out = self.bn3(out)
|
|
86
|
+
|
|
87
|
+
if exists(self.downsample):
|
|
88
|
+
identity = self.downsample(x)
|
|
89
|
+
|
|
90
|
+
out += identity
|
|
91
|
+
return self.relu(out)
|
|
92
|
+
|
|
93
|
+
class ResNet(Module):
    """ResNet backbone ending in a linear head.

    The stem is a 7x7 stride-2 convolution, batch norm, ReLU and a 3x3
    stride-2 max pool. It is followed by four stages of `block` with 64, 128,
    256 and 512 base channels (stages 2-4 open with stride 2), global average
    pooling, and a linear layer with `num_classes` outputs.

    Args:
        block: residual block class (`BasicBlock` or `Bottleneck`); its
            `expansion` scales each stage's output channels.
        layers: number of blocks in each of the four stages.
        num_classes: output features of the final linear layer.
        channels: number of input image channels.
    """

    def __init__(
        self,
        block: type[BasicBlock | Bottleneck],
        layers: list[int],
        num_classes = 1000,
        channels = 3
    ):
        super().__init__()

        # running channel count, advanced by `_make_layer` as stages are built
        self.inplanes = 64

        self.conv1 = nn.Conv2d(channels, 64, kernel_size = 7, stride = 2, padding = 3, bias = False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace = True)
        self.maxpool = nn.MaxPool2d(kernel_size = 3, stride = 2, padding = 1)

        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride = 2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride = 2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride = 2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.flatten = Rearrange('b c 1 1 -> b c')
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(
        self,
        block: type[BasicBlock | Bottleneck],
        planes: int,
        blocks: int,
        stride: int = 1
    ) -> nn.Sequential:
        """Build one stage of `blocks` residual blocks with base width `planes`.

        Only the first block gets `stride`. It also gets a 1x1 conv + batch norm
        projection shortcut when the stride or channel count changes.
        Advances `self.inplanes` to the stage's output channels.
        """
        out_planes = planes * block.expansion
        needs_projection = stride != 1 or self.inplanes != out_planes

        downsample = nn.Sequential(
            nn.Conv2d(self.inplanes, out_planes, 1, stride = stride, bias = False),
            nn.BatchNorm2d(out_planes),
        ) if needs_projection else None

        first = block(self.inplanes, planes, stride, downsample)
        self.inplanes = out_planes

        rest = [block(self.inplanes, planes) for _ in range(1, blocks)]
        return nn.Sequential(first, *rest)

    def forward(self, x: Tensor) -> Tensor:
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)

        x = self.flatten(self.avgpool(x))
        return self.fc(x)
|
|
155
|
+
|
|
156
|
+
# resnet factory
|
|
157
|
+
|
|
158
|
+
def resnet18(num_classes: int = 1000, channels: int = 3):
    """Build a ResNet-18: `BasicBlock` stages with depths [2, 2, 2, 2].

    Args:
        num_classes: output size of the final linear layer.
        channels: number of input image channels (defaults to RGB).
    """
    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes, channels = channels)
|
|
160
|
+
|
|
161
|
+
def resnet34(num_classes: int = 1000, channels: int = 3):
    """Build a ResNet-34: `BasicBlock` stages with depths [3, 4, 6, 3].

    Args:
        num_classes: output size of the final linear layer.
        channels: number of input image channels (defaults to RGB).
    """
    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes, channels = channels)
|
|
163
|
+
|
|
164
|
+
def resnet50(num_classes: int = 1000, channels: int = 3):
    """Build a ResNet-50: `Bottleneck` stages with depths [3, 4, 6, 3].

    Args:
        num_classes: output size of the final linear layer.
        channels: number of input image channels (defaults to RGB).
    """
    return ResNet(Bottleneck, [3, 4, 6, 3], num_classes, channels = channels)
|
|
166
|
+
|
|
167
|
+
# transformer with resnet
|
|
168
|
+
|
|
169
|
+
class TransformerWithResnet(Transformer):
    """`Transformer` whose visual observations are encoded by a ResNet.

    Keyword args:
        resnet_type: one of 'resnet18', 'resnet34', 'resnet50'.
        **kwargs: forwarded to `Transformer`. Must include
            `state_embed_readout['num_continuous']`, which is used as the
            ResNet output dimension.

    Raises:
        ValueError: if `resnet_type` is not one of the supported variants, so
            a typo is not silently replaced by resnet18.
    """

    def __init__(
        self,
        *,
        resnet_type = 'resnet18',
        **kwargs
    ):
        resnet_factories = dict(
            resnet18 = resnet18,
            resnet34 = resnet34,
            resnet50 = resnet50,
        )

        # validate before building the (potentially large) base transformer
        if resnet_type not in resnet_factories:
            raise ValueError(f'unknown resnet_type {resnet_type!r}, expected one of {sorted(resnet_factories)}')

        super().__init__(**kwargs)

        self.resnet_dim = kwargs['state_embed_readout']['num_continuous']
        self.visual_encoder = resnet_factories[resnet_type](num_classes = self.resnet_dim)

    def visual_encode(self, x: Tensor) -> Tensor:
        """Encode images with the ResNet, keeping all leading dims.

        A last dim of size 3 is treated as channels-last and moved before h, w.
        Note that a channels-first input whose width happens to be 3 would be
        misread. Leading dims are packed into one batch dim for the ResNet and
        restored by the inverse returned from `pack_with_inverse` (presumably
        yielding `(..., resnet_dim)` — confirm against torch_einops_utils).
        """
        if x.shape[-1] == 3:
            x = rearrange(x, '... h w c -> ... c h w')

        x, inverse = pack_with_inverse(x, '* c h w')

        h = self.visual_encoder(x)

        return inverse(h, '* d')
|
{metacontroller_pytorch-0.0.31.dist-info → metacontroller_pytorch-0.0.33.dist-info}/METADATA
RENAMED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: metacontroller-pytorch
|
|
3
|
-
Version: 0.0.31
|
|
3
|
+
Version: 0.0.33
|
|
4
4
|
Summary: Transformer Metacontroller
|
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/metacontroller/
|
|
6
6
|
Project-URL: Repository, https://github.com/lucidrains/metacontroller
|
|
@@ -40,7 +40,7 @@ Requires-Dist: einops>=0.8.1
|
|
|
40
40
|
Requires-Dist: einx>=0.3.0
|
|
41
41
|
Requires-Dist: loguru
|
|
42
42
|
Requires-Dist: memmap-replay-buffer>=0.0.23
|
|
43
|
-
Requires-Dist: torch-einops-utils>=0.0.
|
|
43
|
+
Requires-Dist: torch-einops-utils>=0.0.19
|
|
44
44
|
Requires-Dist: torch>=2.5
|
|
45
45
|
Requires-Dist: vector-quantize-pytorch>=1.27.20
|
|
46
46
|
Requires-Dist: x-evolution>=0.1.23
|
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
|
|
2
|
+
metacontroller/metacontroller.py,sha256=B9XHYgVBrcJkEWhUORz--D5AHjcLnvLRVY9SRqVbhdw,16222
|
|
3
|
+
metacontroller/metacontroller_with_binary_mapper.py,sha256=7vGtenScxvDQhkeYUmNnTTbVTJAtIFUVqEoWVGZP2Is,8936
|
|
4
|
+
metacontroller/transformer_with_resnet.py,sha256=R49ycusbq3kEX97WHZ41WY2ONc2mYPOuRUCmaFcBOEo,5546
|
|
5
|
+
metacontroller_pytorch-0.0.33.dist-info/METADATA,sha256=kWHDAnQeEueYWRzqPi5ouXjtSDIG2bnDqOfhznSUOoM,4747
|
|
6
|
+
metacontroller_pytorch-0.0.33.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
7
|
+
metacontroller_pytorch-0.0.33.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
8
|
+
metacontroller_pytorch-0.0.33.dist-info/RECORD,,
|
|
@@ -1,250 +0,0 @@
|
|
|
1
|
-
from typing import Any, List, Type, Union, Optional
|
|
2
|
-
|
|
3
|
-
import torch
|
|
4
|
-
from torch import Tensor
|
|
5
|
-
from torch import nn
|
|
6
|
-
from einops import rearrange
|
|
7
|
-
from metacontroller.metacontroller import Transformer
|
|
8
|
-
|
|
9
|
-
class TransformerWithResnetEncoder(Transformer):
    """`Transformer` with a ResNet-18 visual encoder applied frame by frame.

    The encoder output size comes from
    `kwargs["state_embed_readout"]["num_continuous"]`.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        embed_config = kwargs["state_embed_readout"]
        self.resnet_dim = embed_config["num_continuous"]
        self.visual_encoder = resnet18(out_dim=self.resnet_dim)

    def visual_encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode `(b, t, h, w, c)` frames independently into `(b, t, resnet_dim)`."""
        batch, time = x.shape[:2]
        frames = rearrange(x, 'b t h w c -> (b t) c h w')
        features = self.visual_encoder(frames)
        return rearrange(features, '(b t) d -> b t d', b=batch, t=time, d = self.resnet_dim)
|
|
21
|
-
|
|
22
|
-
# resnet components taken from https://github.com/Lornatang/ResNet-PyTorch
|
|
23
|
-
|
|
24
|
-
class _BasicBlock(nn.Module):
    """Two 3x3 convolutions with batch norm plus an identity/projection shortcut.

    `groups` and `base_channels` are stored on the module but do not affect
    the layers built here.
    """
    expansion: int = 1

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_channels: int = 64,
    ) -> None:
        super().__init__()
        self.stride = stride
        self.downsample = downsample
        self.groups = groups
        self.base_channels = base_channels

        self.conv1 = nn.Conv2d(in_channels, out_channels, (3, 3), (stride, stride), (1, 1), bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, (3, 3), (1, 1), (1, 1), bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x: Tensor) -> Tensor:
        shortcut = x if self.downsample is None else self.downsample(x)

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        return self.relu(torch.add(out, shortcut))
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
class _Bottleneck(nn.Module):
    """1x1 -> 3x3 (grouped, strided) -> 1x1 bottleneck block with a shortcut.

    Inner width is `int(out_channels * base_channels / 64) * groups`, and the
    output has `out_channels * expansion` channels.
    """
    expansion: int = 4

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_channels: int = 64,
    ) -> None:
        super().__init__()
        self.stride = stride
        self.downsample = downsample
        self.groups = groups
        self.base_channels = base_channels

        width = int(out_channels * (base_channels / 64.0)) * groups
        expanded = int(out_channels * self.expansion)

        self.conv1 = nn.Conv2d(in_channels, width, (1, 1), (1, 1), (0, 0), bias=False)
        self.bn1 = nn.BatchNorm2d(width)
        self.conv2 = nn.Conv2d(width, width, (3, 3), (stride, stride), (1, 1), groups=groups, bias=False)
        self.bn2 = nn.BatchNorm2d(width)
        self.conv3 = nn.Conv2d(width, expanded, (1, 1), (1, 1), (0, 0), bias=False)
        self.bn3 = nn.BatchNorm2d(expanded)
        self.relu = nn.ReLU(True)

    def forward(self, x: Tensor) -> Tensor:
        shortcut = x if self.downsample is None else self.downsample(x)

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        return self.relu(torch.add(out, shortcut))
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
class ResNet(nn.Module):
    """ResNet backbone with a linear head of size `out_dim`.

    Adapted from https://github.com/Lornatang/ResNet-PyTorch: a 7x7 stride-2
    stem on 3-channel input, max pool, four stages of `block` (64, 128, 256
    and 512 base channels, depths from `arch_cfg`), global average pooling
    and a linear layer. Weights are re-initialized by `_initialize_weights`.
    """

    def __init__(
        self,
        arch_cfg: List[int],
        block: Type[Union[_BasicBlock, _Bottleneck]],
        groups: int = 1,
        channels_per_group: int = 64,
        out_dim: int = 1000,
    ) -> None:
        super().__init__()
        self.in_channels = 64
        self.dilation = 1
        self.groups = groups
        self.base_channels = channels_per_group

        self.conv1 = nn.Conv2d(3, self.in_channels, (7, 7), (2, 2), (3, 3), bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channels)
        self.relu = nn.ReLU(True)
        self.maxpool = nn.MaxPool2d((3, 3), (2, 2), (1, 1))

        self.layer1 = self._make_layer(arch_cfg[0], block, 64, 1)
        self.layer2 = self._make_layer(arch_cfg[1], block, 128, 2)
        self.layer3 = self._make_layer(arch_cfg[2], block, 256, 2)
        self.layer4 = self._make_layer(arch_cfg[3], block, 512, 2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        self.fc = nn.Linear(512 * block.expansion, out_dim)

        self._initialize_weights()

    def _make_layer(
        self,
        repeat_times: int,
        block: Type[Union[_BasicBlock, _Bottleneck]],
        channels: int,
        stride: int = 1,
    ) -> nn.Sequential:
        """Build a stage of `repeat_times` blocks; only the first gets `stride` and a projection shortcut."""
        expanded = channels * block.expansion

        downsample = None
        if stride != 1 or self.in_channels != expanded:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channels, expanded, (1, 1), (stride, stride), (0, 0), bias=False),
                nn.BatchNorm2d(expanded),
            )

        layers = [block(self.in_channels, channels, stride, downsample, self.groups, self.base_channels)]
        self.in_channels = expanded
        layers.extend(
            block(self.in_channels, channels, 1, None, self.groups, self.base_channels)
            for _ in range(1, repeat_times)
        )

        return nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)

    # kept as a separate method to support torch.script
    def _forward_impl(self, x: Tensor) -> Tensor:
        out = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)

        out = torch.flatten(self.avgpool(out), 1)
        return self.fc(out)

    def _initialize_weights(self) -> None:
        """Kaiming-normal conv weights; norm layers reset to weight 1, bias 0."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
                continue

            if isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
def resnet18(**kwargs: Any) -> ResNet:
    """ResNet-18: `_BasicBlock` stages of depths [2, 2, 2, 2]; `kwargs` go to `ResNet`."""
    return ResNet([2, 2, 2, 2], _BasicBlock, **kwargs)
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
def resnet34(**kwargs: Any) -> ResNet:
    """ResNet-34: `_BasicBlock` stages of depths [3, 4, 6, 3]; `kwargs` go to `ResNet`."""
    return ResNet([3, 4, 6, 3], _BasicBlock, **kwargs)
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
def resnet50(**kwargs: Any) -> ResNet:
    """ResNet-50: `_Bottleneck` stages of depths [3, 4, 6, 3]; `kwargs` go to `ResNet`."""
    return ResNet([3, 4, 6, 3], _Bottleneck, **kwargs)
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
def resnet101(**kwargs: Any) -> ResNet:
    """ResNet-101: `_Bottleneck` stages of depths [3, 4, 23, 3]; `kwargs` go to `ResNet`."""
    return ResNet([3, 4, 23, 3], _Bottleneck, **kwargs)
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
def resnet152(**kwargs: Any) -> ResNet:
    """ResNet-152: `_Bottleneck` stages of depths [3, 8, 36, 3]; `kwargs` go to `ResNet`."""
    return ResNet([3, 8, 36, 3], _Bottleneck, **kwargs)
|
|
@@ -1,8 +0,0 @@
|
|
|
1
|
-
metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
|
|
2
|
-
metacontroller/metacontroller.py,sha256=lxWgeWFcXxSDm-ygd14DjyEOYIJIALcuLkoRAfEzNtc,14719
|
|
3
|
-
metacontroller/metacontroller_with_binary_mapper.py,sha256=BrsQdkhlOyR2O5xAXTLC4p-uKOAbW7wET-lVU0qktws,8242
|
|
4
|
-
metacontroller/metacontroller_with_resnet.py,sha256=YKHcazRZrrRParHRH-H_EPvT1-55LHKAs5pM6gwuT20,7394
|
|
5
|
-
metacontroller_pytorch-0.0.31.dist-info/METADATA,sha256=mtOtYymI01jBMO7pyaAIJ166B5Mk3khH8CUwUMNLTKw,4747
|
|
6
|
-
metacontroller_pytorch-0.0.31.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
7
|
-
metacontroller_pytorch-0.0.31.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
8
|
-
metacontroller_pytorch-0.0.31.dist-info/RECORD,,
|
|
File without changes
|
{metacontroller_pytorch-0.0.31.dist-info → metacontroller_pytorch-0.0.33.dist-info}/licenses/LICENSE
RENAMED
|
File without changes
|