metacontroller-pytorch 0.0.32__py3-none-any.whl → 0.0.33__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- metacontroller/metacontroller.py +54 -4
- metacontroller/metacontroller_with_binary_mapper.py +4 -1
- metacontroller/transformer_with_resnet.py +194 -0
- {metacontroller_pytorch-0.0.32.dist-info → metacontroller_pytorch-0.0.33.dist-info}/METADATA +2 -2
- metacontroller_pytorch-0.0.33.dist-info/RECORD +8 -0
- metacontroller/metacontroller_with_resnet.py +0 -250
- metacontroller_pytorch-0.0.32.dist-info/RECORD +0 -8
- {metacontroller_pytorch-0.0.32.dist-info → metacontroller_pytorch-0.0.33.dist-info}/WHEEL +0 -0
- {metacontroller_pytorch-0.0.32.dist-info → metacontroller_pytorch-0.0.33.dist-info}/licenses/LICENSE +0 -0
metacontroller/metacontroller.py
CHANGED
|
@@ -26,7 +26,7 @@ from discrete_continuous_embed_readout import Embed, Readout, EmbedAndReadout
|
|
|
26
26
|
|
|
27
27
|
from assoc_scan import AssocScan
|
|
28
28
|
|
|
29
|
-
from torch_einops_utils import maybe, pad_at_dim, lens_to_mask
|
|
29
|
+
from torch_einops_utils import maybe, pad_at_dim, lens_to_mask, masked_mean, align_dims_left, pad_right_ndim_to
|
|
30
30
|
from torch_einops_utils.save_load import save_load
|
|
31
31
|
|
|
32
32
|
# constants
|
|
@@ -66,6 +66,47 @@ MetaControllerOutput = namedtuple('MetaControllerOutput', (
|
|
|
66
66
|
'switch_loss'
|
|
67
67
|
))
|
|
68
68
|
|
|
69
|
+
def z_score(t, eps = 1e-8):
|
|
70
|
+
return (t - t.mean()) / (t.std() + eps)
|
|
71
|
+
|
|
72
|
+
def policy_loss(
|
|
73
|
+
meta_controller,
|
|
74
|
+
state,
|
|
75
|
+
old_log_probs,
|
|
76
|
+
actions,
|
|
77
|
+
advantages,
|
|
78
|
+
mask,
|
|
79
|
+
episode_lens = None,
|
|
80
|
+
eps_clip = 0.2
|
|
81
|
+
):
|
|
82
|
+
# get new log probs
|
|
83
|
+
|
|
84
|
+
action_dist = meta_controller.get_action_dist_for_internal_rl(state)
|
|
85
|
+
new_log_probs = meta_controller.log_prob(action_dist, actions)
|
|
86
|
+
|
|
87
|
+
# calculate ratio
|
|
88
|
+
|
|
89
|
+
ratio = (new_log_probs - old_log_probs).exp()
|
|
90
|
+
|
|
91
|
+
# align ratio and advantages
|
|
92
|
+
|
|
93
|
+
ratio, advantages = align_dims_left((ratio, advantages))
|
|
94
|
+
|
|
95
|
+
# ppo surrogate loss
|
|
96
|
+
|
|
97
|
+
surr1 = ratio * advantages
|
|
98
|
+
surr2 = ratio.clamp(1 - eps_clip, 1 + eps_clip) * advantages
|
|
99
|
+
|
|
100
|
+
losses = -torch.min(surr1, surr2)
|
|
101
|
+
|
|
102
|
+
# masking
|
|
103
|
+
|
|
104
|
+
if exists(episode_lens):
|
|
105
|
+
mask, episode_mask = align_dims_left((mask, lens_to_mask(episode_lens, losses.shape[1])))
|
|
106
|
+
mask = mask & episode_mask
|
|
107
|
+
|
|
108
|
+
return masked_mean(losses, mask)
|
|
109
|
+
|
|
69
110
|
@save_load()
|
|
70
111
|
class MetaController(Module):
|
|
71
112
|
def __init__(
|
|
@@ -146,6 +187,16 @@ class MetaController(Module):
|
|
|
146
187
|
*self.action_proposer_mean_log_var.parameters()
|
|
147
188
|
]
|
|
148
189
|
|
|
190
|
+
def get_action_dist_for_internal_rl(
|
|
191
|
+
self,
|
|
192
|
+
residual_stream
|
|
193
|
+
):
|
|
194
|
+
meta_embed = self.model_to_meta(residual_stream)
|
|
195
|
+
|
|
196
|
+
proposed_action_hidden, _ = self.action_proposer(meta_embed)
|
|
197
|
+
|
|
198
|
+
return self.action_proposer_mean_log_var(proposed_action_hidden)
|
|
199
|
+
|
|
149
200
|
def log_prob(
|
|
150
201
|
self,
|
|
151
202
|
action_dist,
|
|
@@ -429,9 +480,8 @@ class Transformer(Module):
|
|
|
429
480
|
# maybe return behavior cloning loss
|
|
430
481
|
|
|
431
482
|
if behavioral_cloning:
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
loss_mask = lens_to_mask(episode_lens, state.shape[1])
|
|
483
|
+
|
|
484
|
+
loss_mask = maybe(lens_to_mask)(episode_lens, state.shape[1])
|
|
435
485
|
|
|
436
486
|
state_dist_params = self.state_readout(attended)
|
|
437
487
|
state_clone_loss = self.state_readout.calculate_loss(state_dist_params, target_state, mask = loss_mask)
|
metacontroller/metacontroller_with_binary_mapper.py
CHANGED
|
@@ -23,7 +23,7 @@ from x_mlps_pytorch import Feedforwards
|
|
|
23
23
|
|
|
24
24
|
from assoc_scan import AssocScan
|
|
25
25
|
|
|
26
|
-
from torch_einops_utils import maybe, pad_at_dim, lens_to_mask, align_dims_left
|
|
26
|
+
from torch_einops_utils import maybe, pad_at_dim, lens_to_mask, masked_mean, align_dims_left, pad_right_ndim_to
|
|
27
27
|
from torch_einops_utils.save_load import save_load
|
|
28
28
|
|
|
29
29
|
from vector_quantize_pytorch import BinaryMapper
|
|
@@ -74,6 +74,9 @@ class MetaControllerWithBinaryMapper(Module):
|
|
|
74
74
|
kl_loss_threshold = 0.
|
|
75
75
|
):
|
|
76
76
|
super().__init__()
|
|
77
|
+
|
|
78
|
+
assert not switch_per_code, 'switch_per_code is not supported for binary mapper'
|
|
79
|
+
|
|
77
80
|
dim_meta = default(dim_meta_controller, dim_model)
|
|
78
81
|
|
|
79
82
|
self.model_to_meta = Linear(dim_model, dim_meta)
|
metacontroller/transformer_with_resnet.py
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from torch import nn, Tensor
|
|
5
|
+
from torch.nn import Module, ModuleList
|
|
6
|
+
from einops import rearrange
|
|
7
|
+
from einops.layers.torch import Rearrange
|
|
8
|
+
|
|
9
|
+
from metacontroller.metacontroller import Transformer
|
|
10
|
+
|
|
11
|
+
from torch_einops_utils import pack_with_inverse
|
|
12
|
+
|
|
13
|
+
# resnet components
|
|
14
|
+
|
|
15
|
+
def exists(v):
|
|
16
|
+
return v is not None
|
|
17
|
+
|
|
18
|
+
class BasicBlock(Module):
|
|
19
|
+
expansion = 1
|
|
20
|
+
|
|
21
|
+
def __init__(
|
|
22
|
+
self,
|
|
23
|
+
dim,
|
|
24
|
+
dim_out,
|
|
25
|
+
stride = 1,
|
|
26
|
+
downsample: Module | None = None
|
|
27
|
+
):
|
|
28
|
+
super().__init__()
|
|
29
|
+
self.conv1 = nn.Conv2d(dim, dim_out, 3, stride = stride, padding = 1, bias = False)
|
|
30
|
+
self.bn1 = nn.BatchNorm2d(dim_out)
|
|
31
|
+
self.relu = nn.ReLU(inplace = True)
|
|
32
|
+
self.conv2 = nn.Conv2d(dim_out, dim_out, 3, padding = 1, bias = False)
|
|
33
|
+
self.bn2 = nn.BatchNorm2d(dim_out)
|
|
34
|
+
self.downsample = downsample
|
|
35
|
+
|
|
36
|
+
def forward(self, x: Tensor) -> Tensor:
|
|
37
|
+
identity = x
|
|
38
|
+
|
|
39
|
+
out = self.conv1(x)
|
|
40
|
+
out = self.bn1(out)
|
|
41
|
+
out = self.relu(out)
|
|
42
|
+
|
|
43
|
+
out = self.conv2(out)
|
|
44
|
+
out = self.bn2(out)
|
|
45
|
+
|
|
46
|
+
if exists(self.downsample):
|
|
47
|
+
identity = self.downsample(x)
|
|
48
|
+
|
|
49
|
+
out += identity
|
|
50
|
+
return self.relu(out)
|
|
51
|
+
|
|
52
|
+
class Bottleneck(Module):
|
|
53
|
+
expansion = 4
|
|
54
|
+
|
|
55
|
+
def __init__(
|
|
56
|
+
self,
|
|
57
|
+
dim,
|
|
58
|
+
dim_out,
|
|
59
|
+
stride = 1,
|
|
60
|
+
downsample: Module | None = None
|
|
61
|
+
):
|
|
62
|
+
super().__init__()
|
|
63
|
+
width = dim_out # simple resnet shortcut
|
|
64
|
+
self.conv1 = nn.Conv2d(dim, width, 1, bias = False)
|
|
65
|
+
self.bn1 = nn.BatchNorm2d(width)
|
|
66
|
+
self.conv2 = nn.Conv2d(width, width, 3, stride = stride, padding = 1, bias = False)
|
|
67
|
+
self.bn2 = nn.BatchNorm2d(width)
|
|
68
|
+
self.conv3 = nn.Conv2d(width, dim_out * self.expansion, 1, bias = False)
|
|
69
|
+
self.bn3 = nn.BatchNorm2d(dim_out * self.expansion)
|
|
70
|
+
self.relu = nn.ReLU(inplace = True)
|
|
71
|
+
self.downsample = downsample
|
|
72
|
+
|
|
73
|
+
def forward(self, x: Tensor) -> Tensor:
|
|
74
|
+
identity = x
|
|
75
|
+
|
|
76
|
+
out = self.conv1(x)
|
|
77
|
+
out = self.bn1(out)
|
|
78
|
+
out = self.relu(out)
|
|
79
|
+
|
|
80
|
+
out = self.conv2(out)
|
|
81
|
+
out = self.bn2(out)
|
|
82
|
+
out = self.relu(out)
|
|
83
|
+
|
|
84
|
+
out = self.conv3(out)
|
|
85
|
+
out = self.bn3(out)
|
|
86
|
+
|
|
87
|
+
if exists(self.downsample):
|
|
88
|
+
identity = self.downsample(x)
|
|
89
|
+
|
|
90
|
+
out += identity
|
|
91
|
+
return self.relu(out)
|
|
92
|
+
|
|
93
|
+
class ResNet(Module):
|
|
94
|
+
def __init__(
|
|
95
|
+
self,
|
|
96
|
+
block: type[BasicBlock | Bottleneck],
|
|
97
|
+
layers: list[int],
|
|
98
|
+
num_classes = 1000,
|
|
99
|
+
channels = 3
|
|
100
|
+
):
|
|
101
|
+
super().__init__()
|
|
102
|
+
self.inplanes = 64
|
|
103
|
+
|
|
104
|
+
self.conv1 = nn.Conv2d(channels, 64, kernel_size = 7, stride = 2, padding = 3, bias = False)
|
|
105
|
+
self.bn1 = nn.BatchNorm2d(64)
|
|
106
|
+
self.relu = nn.ReLU(inplace = True)
|
|
107
|
+
self.maxpool = nn.MaxPool2d(kernel_size = 3, stride = 2, padding = 1)
|
|
108
|
+
|
|
109
|
+
self.layer1 = self._make_layer(block, 64, layers[0])
|
|
110
|
+
self.layer2 = self._make_layer(block, 128, layers[1], stride = 2)
|
|
111
|
+
self.layer3 = self._make_layer(block, 256, layers[2], stride = 2)
|
|
112
|
+
self.layer4 = self._make_layer(block, 512, layers[3], stride = 2)
|
|
113
|
+
|
|
114
|
+
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
|
115
|
+
self.flatten = Rearrange('b c 1 1 -> b c')
|
|
116
|
+
self.fc = nn.Linear(512 * block.expansion, num_classes)
|
|
117
|
+
|
|
118
|
+
def _make_layer(
|
|
119
|
+
self,
|
|
120
|
+
block: type[BasicBlock | Bottleneck],
|
|
121
|
+
planes: int,
|
|
122
|
+
blocks: int,
|
|
123
|
+
stride: int = 1
|
|
124
|
+
) -> nn.Sequential:
|
|
125
|
+
downsample = None
|
|
126
|
+
if stride != 1 or self.inplanes != planes * block.expansion:
|
|
127
|
+
downsample = nn.Sequential(
|
|
128
|
+
nn.Conv2d(self.inplanes, planes * block.expansion, 1, stride = stride, bias = False),
|
|
129
|
+
nn.BatchNorm2d(planes * block.expansion),
|
|
130
|
+
)
|
|
131
|
+
|
|
132
|
+
layers = []
|
|
133
|
+
layers.append(block(self.inplanes, planes, stride, downsample))
|
|
134
|
+
self.inplanes = planes * block.expansion
|
|
135
|
+
for _ in range(1, blocks):
|
|
136
|
+
layers.append(block(self.inplanes, planes))
|
|
137
|
+
|
|
138
|
+
return nn.Sequential(*layers)
|
|
139
|
+
|
|
140
|
+
def forward(self, x: Tensor) -> Tensor:
|
|
141
|
+
x = self.conv1(x)
|
|
142
|
+
x = self.bn1(x)
|
|
143
|
+
x = self.relu(x)
|
|
144
|
+
x = self.maxpool(x)
|
|
145
|
+
|
|
146
|
+
x = self.layer1(x)
|
|
147
|
+
x = self.layer2(x)
|
|
148
|
+
x = self.layer3(x)
|
|
149
|
+
x = self.layer4(x)
|
|
150
|
+
|
|
151
|
+
x = self.avgpool(x)
|
|
152
|
+
x = self.flatten(x)
|
|
153
|
+
x = self.fc(x)
|
|
154
|
+
return x
|
|
155
|
+
|
|
156
|
+
# resnet factory
|
|
157
|
+
|
|
158
|
+
def resnet18(num_classes: any = 1000):
|
|
159
|
+
return ResNet(BasicBlock, [2, 2, 2, 2], num_classes)
|
|
160
|
+
|
|
161
|
+
def resnet34(num_classes: any = 1000):
|
|
162
|
+
return ResNet(BasicBlock, [3, 4, 6, 3], num_classes)
|
|
163
|
+
|
|
164
|
+
def resnet50(num_classes: any = 1000):
|
|
165
|
+
return ResNet(Bottleneck, [3, 4, 6, 3], num_classes)
|
|
166
|
+
|
|
167
|
+
# transformer with resnet
|
|
168
|
+
|
|
169
|
+
class TransformerWithResnet(Transformer):
|
|
170
|
+
def __init__(
|
|
171
|
+
self,
|
|
172
|
+
*,
|
|
173
|
+
resnet_type = 'resnet18',
|
|
174
|
+
**kwargs
|
|
175
|
+
):
|
|
176
|
+
super().__init__(**kwargs)
|
|
177
|
+
resnet_klass = resnet18
|
|
178
|
+
if resnet_type == 'resnet34':
|
|
179
|
+
resnet_klass = resnet34
|
|
180
|
+
elif resnet_type == 'resnet50':
|
|
181
|
+
resnet_klass = resnet50
|
|
182
|
+
|
|
183
|
+
self.resnet_dim = kwargs['state_embed_readout']['num_continuous']
|
|
184
|
+
self.visual_encoder = resnet_klass(num_classes = self.resnet_dim)
|
|
185
|
+
|
|
186
|
+
def visual_encode(self, x: Tensor) -> Tensor:
|
|
187
|
+
if x.shape[-1] == 3:
|
|
188
|
+
x = rearrange(x, '... h w c -> ... c h w')
|
|
189
|
+
|
|
190
|
+
x, inverse = pack_with_inverse(x, '* c h w')
|
|
191
|
+
|
|
192
|
+
h = self.visual_encoder(x)
|
|
193
|
+
|
|
194
|
+
return inverse(h, '* d')
|
{metacontroller_pytorch-0.0.32.dist-info → metacontroller_pytorch-0.0.33.dist-info}/METADATA
RENAMED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: metacontroller-pytorch
|
|
3
|
-
Version: 0.0.32
|
|
3
|
+
Version: 0.0.33
|
|
4
4
|
Summary: Transformer Metacontroller
|
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/metacontroller/
|
|
6
6
|
Project-URL: Repository, https://github.com/lucidrains/metacontroller
|
|
@@ -40,7 +40,7 @@ Requires-Dist: einops>=0.8.1
|
|
|
40
40
|
Requires-Dist: einx>=0.3.0
|
|
41
41
|
Requires-Dist: loguru
|
|
42
42
|
Requires-Dist: memmap-replay-buffer>=0.0.23
|
|
43
|
-
Requires-Dist: torch-einops-utils>=0.0.
|
|
43
|
+
Requires-Dist: torch-einops-utils>=0.0.19
|
|
44
44
|
Requires-Dist: torch>=2.5
|
|
45
45
|
Requires-Dist: vector-quantize-pytorch>=1.27.20
|
|
46
46
|
Requires-Dist: x-evolution>=0.1.23
|
metacontroller_pytorch-0.0.33.dist-info/RECORD
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
|
|
2
|
+
metacontroller/metacontroller.py,sha256=B9XHYgVBrcJkEWhUORz--D5AHjcLnvLRVY9SRqVbhdw,16222
|
|
3
|
+
metacontroller/metacontroller_with_binary_mapper.py,sha256=7vGtenScxvDQhkeYUmNnTTbVTJAtIFUVqEoWVGZP2Is,8936
|
|
4
|
+
metacontroller/transformer_with_resnet.py,sha256=R49ycusbq3kEX97WHZ41WY2ONc2mYPOuRUCmaFcBOEo,5546
|
|
5
|
+
metacontroller_pytorch-0.0.33.dist-info/METADATA,sha256=kWHDAnQeEueYWRzqPi5ouXjtSDIG2bnDqOfhznSUOoM,4747
|
|
6
|
+
metacontroller_pytorch-0.0.33.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
7
|
+
metacontroller_pytorch-0.0.33.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
8
|
+
metacontroller_pytorch-0.0.33.dist-info/RECORD,,
|
metacontroller/metacontroller_with_resnet.py
REMOVED
|
@@ -1,250 +0,0 @@
|
|
|
1
|
-
from typing import Any, List, Type, Union, Optional
|
|
2
|
-
|
|
3
|
-
import torch
|
|
4
|
-
from torch import Tensor
|
|
5
|
-
from torch import nn
|
|
6
|
-
from einops import rearrange
|
|
7
|
-
from metacontroller.metacontroller import Transformer
|
|
8
|
-
|
|
9
|
-
class TransformerWithResnetEncoder(Transformer):
|
|
10
|
-
def __init__(self, **kwargs):
|
|
11
|
-
super().__init__(**kwargs)
|
|
12
|
-
self.resnet_dim = kwargs["state_embed_readout"]["num_continuous"]
|
|
13
|
-
self.visual_encoder = resnet18(out_dim=self.resnet_dim)
|
|
14
|
-
|
|
15
|
-
def visual_encode(self, x: torch.Tensor) -> torch.Tensor:
|
|
16
|
-
b, t = x.shape[:2]
|
|
17
|
-
x = rearrange(x, 'b t h w c -> (b t) c h w')
|
|
18
|
-
h = self.visual_encoder(x)
|
|
19
|
-
h = rearrange(h, '(b t) d -> b t d', b=b, t=t, d = self.resnet_dim)
|
|
20
|
-
return h
|
|
21
|
-
|
|
22
|
-
# resnet components taken from https://github.com/Lornatang/ResNet-PyTorch
|
|
23
|
-
|
|
24
|
-
class _BasicBlock(nn.Module):
|
|
25
|
-
expansion: int = 1
|
|
26
|
-
|
|
27
|
-
def __init__(
|
|
28
|
-
self,
|
|
29
|
-
in_channels: int,
|
|
30
|
-
out_channels: int,
|
|
31
|
-
stride: int,
|
|
32
|
-
downsample: Optional[nn.Module] = None,
|
|
33
|
-
groups: int = 1,
|
|
34
|
-
base_channels: int = 64,
|
|
35
|
-
) -> None:
|
|
36
|
-
super(_BasicBlock, self).__init__()
|
|
37
|
-
self.stride = stride
|
|
38
|
-
self.downsample = downsample
|
|
39
|
-
self.groups = groups
|
|
40
|
-
self.base_channels = base_channels
|
|
41
|
-
|
|
42
|
-
self.conv1 = nn.Conv2d(in_channels, out_channels, (3, 3), (stride, stride), (1, 1), bias=False)
|
|
43
|
-
self.bn1 = nn.BatchNorm2d(out_channels)
|
|
44
|
-
self.relu = nn.ReLU(True)
|
|
45
|
-
self.conv2 = nn.Conv2d(out_channels, out_channels, (3, 3), (1, 1), (1, 1), bias=False)
|
|
46
|
-
self.bn2 = nn.BatchNorm2d(out_channels)
|
|
47
|
-
|
|
48
|
-
def forward(self, x: Tensor) -> Tensor:
|
|
49
|
-
identity = x
|
|
50
|
-
|
|
51
|
-
out = self.conv1(x)
|
|
52
|
-
out = self.bn1(out)
|
|
53
|
-
out = self.relu(out)
|
|
54
|
-
|
|
55
|
-
out = self.conv2(out)
|
|
56
|
-
out = self.bn2(out)
|
|
57
|
-
|
|
58
|
-
if self.downsample is not None:
|
|
59
|
-
identity = self.downsample(x)
|
|
60
|
-
|
|
61
|
-
out = torch.add(out, identity)
|
|
62
|
-
out = self.relu(out)
|
|
63
|
-
|
|
64
|
-
return out
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
class _Bottleneck(nn.Module):
|
|
68
|
-
expansion: int = 4
|
|
69
|
-
|
|
70
|
-
def __init__(
|
|
71
|
-
self,
|
|
72
|
-
in_channels: int,
|
|
73
|
-
out_channels: int,
|
|
74
|
-
stride: int,
|
|
75
|
-
downsample: Optional[nn.Module] = None,
|
|
76
|
-
groups: int = 1,
|
|
77
|
-
base_channels: int = 64,
|
|
78
|
-
) -> None:
|
|
79
|
-
super(_Bottleneck, self).__init__()
|
|
80
|
-
self.stride = stride
|
|
81
|
-
self.downsample = downsample
|
|
82
|
-
self.groups = groups
|
|
83
|
-
self.base_channels = base_channels
|
|
84
|
-
|
|
85
|
-
channels = int(out_channels * (base_channels / 64.0)) * groups
|
|
86
|
-
|
|
87
|
-
self.conv1 = nn.Conv2d(in_channels, channels, (1, 1), (1, 1), (0, 0), bias=False)
|
|
88
|
-
self.bn1 = nn.BatchNorm2d(channels)
|
|
89
|
-
self.conv2 = nn.Conv2d(channels, channels, (3, 3), (stride, stride), (1, 1), groups=groups, bias=False)
|
|
90
|
-
self.bn2 = nn.BatchNorm2d(channels)
|
|
91
|
-
self.conv3 = nn.Conv2d(channels, int(out_channels * self.expansion), (1, 1), (1, 1), (0, 0), bias=False)
|
|
92
|
-
self.bn3 = nn.BatchNorm2d(int(out_channels * self.expansion))
|
|
93
|
-
self.relu = nn.ReLU(True)
|
|
94
|
-
|
|
95
|
-
def forward(self, x: Tensor) -> Tensor:
|
|
96
|
-
identity = x
|
|
97
|
-
|
|
98
|
-
out = self.conv1(x)
|
|
99
|
-
out = self.bn1(out)
|
|
100
|
-
out = self.relu(out)
|
|
101
|
-
|
|
102
|
-
out = self.conv2(out)
|
|
103
|
-
out = self.bn2(out)
|
|
104
|
-
out = self.relu(out)
|
|
105
|
-
|
|
106
|
-
out = self.conv3(out)
|
|
107
|
-
out = self.bn3(out)
|
|
108
|
-
|
|
109
|
-
if self.downsample is not None:
|
|
110
|
-
identity = self.downsample(x)
|
|
111
|
-
|
|
112
|
-
out = torch.add(out, identity)
|
|
113
|
-
out = self.relu(out)
|
|
114
|
-
|
|
115
|
-
return out
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
class ResNet(nn.Module):
|
|
119
|
-
|
|
120
|
-
def __init__(
|
|
121
|
-
self,
|
|
122
|
-
arch_cfg: List[int],
|
|
123
|
-
block: Type[Union[_BasicBlock, _Bottleneck]],
|
|
124
|
-
groups: int = 1,
|
|
125
|
-
channels_per_group: int = 64,
|
|
126
|
-
out_dim: int = 1000,
|
|
127
|
-
) -> None:
|
|
128
|
-
super(ResNet, self).__init__()
|
|
129
|
-
self.in_channels = 64
|
|
130
|
-
self.dilation = 1
|
|
131
|
-
self.groups = groups
|
|
132
|
-
self.base_channels = channels_per_group
|
|
133
|
-
|
|
134
|
-
self.conv1 = nn.Conv2d(3, self.in_channels, (7, 7), (2, 2), (3, 3), bias=False)
|
|
135
|
-
self.bn1 = nn.BatchNorm2d(self.in_channels)
|
|
136
|
-
self.relu = nn.ReLU(True)
|
|
137
|
-
self.maxpool = nn.MaxPool2d((3, 3), (2, 2), (1, 1))
|
|
138
|
-
|
|
139
|
-
self.layer1 = self._make_layer(arch_cfg[0], block, 64, 1)
|
|
140
|
-
self.layer2 = self._make_layer(arch_cfg[1], block, 128, 2)
|
|
141
|
-
self.layer3 = self._make_layer(arch_cfg[2], block, 256, 2)
|
|
142
|
-
self.layer4 = self._make_layer(arch_cfg[3], block, 512, 2)
|
|
143
|
-
|
|
144
|
-
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
|
145
|
-
|
|
146
|
-
self.fc = nn.Linear(512 * block.expansion, out_dim)
|
|
147
|
-
|
|
148
|
-
# Initialize neural network weights
|
|
149
|
-
self._initialize_weights()
|
|
150
|
-
|
|
151
|
-
def _make_layer(
|
|
152
|
-
self,
|
|
153
|
-
repeat_times: int,
|
|
154
|
-
block: Type[Union[_BasicBlock, _Bottleneck]],
|
|
155
|
-
channels: int,
|
|
156
|
-
stride: int = 1,
|
|
157
|
-
) -> nn.Sequential:
|
|
158
|
-
downsample = None
|
|
159
|
-
|
|
160
|
-
if stride != 1 or self.in_channels != channels * block.expansion:
|
|
161
|
-
downsample = nn.Sequential(
|
|
162
|
-
nn.Conv2d(self.in_channels, channels * block.expansion, (1, 1), (stride, stride), (0, 0), bias=False),
|
|
163
|
-
nn.BatchNorm2d(channels * block.expansion),
|
|
164
|
-
)
|
|
165
|
-
|
|
166
|
-
layers = [
|
|
167
|
-
block(
|
|
168
|
-
self.in_channels,
|
|
169
|
-
channels,
|
|
170
|
-
stride,
|
|
171
|
-
downsample,
|
|
172
|
-
self.groups,
|
|
173
|
-
self.base_channels
|
|
174
|
-
)
|
|
175
|
-
]
|
|
176
|
-
self.in_channels = channels * block.expansion
|
|
177
|
-
for _ in range(1, repeat_times):
|
|
178
|
-
layers.append(
|
|
179
|
-
block(
|
|
180
|
-
self.in_channels,
|
|
181
|
-
channels,
|
|
182
|
-
1,
|
|
183
|
-
None,
|
|
184
|
-
self.groups,
|
|
185
|
-
self.base_channels,
|
|
186
|
-
)
|
|
187
|
-
)
|
|
188
|
-
|
|
189
|
-
return nn.Sequential(*layers)
|
|
190
|
-
|
|
191
|
-
def forward(self, x: Tensor) -> Tensor:
|
|
192
|
-
out = self._forward_impl(x)
|
|
193
|
-
|
|
194
|
-
return out
|
|
195
|
-
|
|
196
|
-
# Support torch.script function
|
|
197
|
-
def _forward_impl(self, x: Tensor) -> Tensor:
|
|
198
|
-
out = self.conv1(x)
|
|
199
|
-
out = self.bn1(out)
|
|
200
|
-
out = self.relu(out)
|
|
201
|
-
out = self.maxpool(out)
|
|
202
|
-
|
|
203
|
-
out = self.layer1(out)
|
|
204
|
-
out = self.layer2(out)
|
|
205
|
-
out = self.layer3(out)
|
|
206
|
-
out = self.layer4(out)
|
|
207
|
-
|
|
208
|
-
out = self.avgpool(out)
|
|
209
|
-
out = torch.flatten(out, 1)
|
|
210
|
-
out = self.fc(out)
|
|
211
|
-
|
|
212
|
-
return out
|
|
213
|
-
|
|
214
|
-
def _initialize_weights(self) -> None:
|
|
215
|
-
for module in self.modules():
|
|
216
|
-
if isinstance(module, nn.Conv2d):
|
|
217
|
-
nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
|
|
218
|
-
elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
|
|
219
|
-
nn.init.constant_(module.weight, 1)
|
|
220
|
-
nn.init.constant_(module.bias, 0)
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
def resnet18(**kwargs: Any) -> ResNet:
|
|
224
|
-
model = ResNet([2, 2, 2, 2], _BasicBlock, **kwargs)
|
|
225
|
-
|
|
226
|
-
return model
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
def resnet34(**kwargs: Any) -> ResNet:
|
|
230
|
-
model = ResNet([3, 4, 6, 3], _BasicBlock, **kwargs)
|
|
231
|
-
|
|
232
|
-
return model
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
def resnet50(**kwargs: Any) -> ResNet:
|
|
236
|
-
model = ResNet([3, 4, 6, 3], _Bottleneck, **kwargs)
|
|
237
|
-
|
|
238
|
-
return model
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
def resnet101(**kwargs: Any) -> ResNet:
|
|
242
|
-
model = ResNet([3, 4, 23, 3], _Bottleneck, **kwargs)
|
|
243
|
-
|
|
244
|
-
return model
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
def resnet152(**kwargs: Any) -> ResNet:
|
|
248
|
-
model = ResNet([3, 8, 36, 3], _Bottleneck, **kwargs)
|
|
249
|
-
|
|
250
|
-
return model
|
metacontroller_pytorch-0.0.32.dist-info/RECORD
REMOVED
|
@@ -1,8 +0,0 @@
|
|
|
1
|
-
metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
|
|
2
|
-
metacontroller/metacontroller.py,sha256=somE9gX36c1d9hF2n8Qn4foRY8krHGodvrvulhkIGE8,15006
|
|
3
|
-
metacontroller/metacontroller_with_binary_mapper.py,sha256=CTGK8ruQ3TkioVUwFTHdrbfzubaeuhSdXHfHtaDcwMY,8813
|
|
4
|
-
metacontroller/metacontroller_with_resnet.py,sha256=YKHcazRZrrRParHRH-H_EPvT1-55LHKAs5pM6gwuT20,7394
|
|
5
|
-
metacontroller_pytorch-0.0.32.dist-info/METADATA,sha256=hr08iXm6Mb-rnDu2xPrr9YQ6cwTtX1F79MfBYt54Y94,4747
|
|
6
|
-
metacontroller_pytorch-0.0.32.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
7
|
-
metacontroller_pytorch-0.0.32.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
8
|
-
metacontroller_pytorch-0.0.32.dist-info/RECORD,,
|
{metacontroller_pytorch-0.0.32.dist-info → metacontroller_pytorch-0.0.33.dist-info}/WHEEL
RENAMED
|
File without changes
|
{metacontroller_pytorch-0.0.32.dist-info → metacontroller_pytorch-0.0.33.dist-info}/licenses/LICENSE
RENAMED
|
File without changes
|