metacontroller-pytorch 0.0.31__py3-none-any.whl → 0.0.33__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
metacontroller/metacontroller.py

@@ -26,7 +26,7 @@ from discrete_continuous_embed_readout import Embed, Readout, EmbedAndReadout
 
 from assoc_scan import AssocScan
 
-from torch_einops_utils import maybe, pad_at_dim, lens_to_mask
+from torch_einops_utils import maybe, pad_at_dim, lens_to_mask, masked_mean, align_dims_left, pad_right_ndim_to
 from torch_einops_utils.save_load import save_load
 
 # constants
@@ -66,6 +66,47 @@ MetaControllerOutput = namedtuple('MetaControllerOutput', (
     'switch_loss'
 ))
 
+def z_score(t, eps = 1e-8):
+    return (t - t.mean()) / (t.std() + eps)
+
+def policy_loss(
+    meta_controller,
+    state,
+    old_log_probs,
+    actions,
+    advantages,
+    mask,
+    episode_lens = None,
+    eps_clip = 0.2
+):
+    # get new log probs
+
+    action_dist = meta_controller.get_action_dist_for_internal_rl(state)
+    new_log_probs = meta_controller.log_prob(action_dist, actions)
+
+    # calculate ratio
+
+    ratio = (new_log_probs - old_log_probs).exp()
+
+    # align ratio and advantages
+
+    ratio, advantages = align_dims_left((ratio, advantages))
+
+    # ppo surrogate loss
+
+    surr1 = ratio * advantages
+    surr2 = ratio.clamp(1 - eps_clip, 1 + eps_clip) * advantages
+
+    losses = -torch.min(surr1, surr2)
+
+    # masking
+
+    if exists(episode_lens):
+        mask, episode_mask = align_dims_left((mask, lens_to_mask(episode_lens, losses.shape[1])))
+        mask = mask & episode_mask
+
+    return masked_mean(losses, mask)
+
 @save_load()
 class MetaController(Module):
     def __init__(
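Note on the hunk above: 0.0.33 adds a standalone PPO-style policy loss for the metacontroller's internal RL, plus a `z_score` helper, presumably for advantage normalization. Below is a minimal, self-contained sketch of the clipped surrogate objective that `policy_loss` computes, with hypothetical shapes and random tensors standing in for real rollout data:

    import torch

    # hypothetical rollout data: batch = 2, timesteps = 4
    old_log_probs = torch.randn(2, 4)
    new_log_probs = torch.randn(2, 4)
    advantages = torch.randn(2, 4)
    eps_clip = 0.2

    # ppo clipped surrogate: take the pessimistic of the unclipped and
    # clipped objectives, negated into a loss to be minimized
    ratio = (new_log_probs - old_log_probs).exp()
    surr1 = ratio * advantages
    surr2 = ratio.clamp(1 - eps_clip, 1 + eps_clip) * advantages
    loss = -torch.min(surr1, surr2).mean()

The library version differs only in that it broadcasts `advantages` against `ratio` with `align_dims_left` and averages over valid positions with `masked_mean`.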
@@ -107,7 +148,6 @@ class MetaController(Module):
 
         self.switch_per_latent_dim = switch_per_latent_dim
 
-
         self.dim_latent = dim_latent
         self.switching_unit = GRU(dim_meta + dim_latent, dim_meta)
         self.to_switching_unit_beta = nn.Linear(dim_meta, dim_latent if switch_per_latent_dim else 1, bias = False)
@@ -147,6 +187,23 @@ class MetaController(Module):
            *self.action_proposer_mean_log_var.parameters()
        ]
 
+    def get_action_dist_for_internal_rl(
+        self,
+        residual_stream
+    ):
+        meta_embed = self.model_to_meta(residual_stream)
+
+        proposed_action_hidden, _ = self.action_proposer(meta_embed)
+
+        return self.action_proposer_mean_log_var(proposed_action_hidden)
+
+    def log_prob(
+        self,
+        action_dist,
+        sampled_latent_action
+    ):
+        return self.action_proposer_mean_log_var.log_prob(action_dist, sampled_latent_action)
+
     def forward(
         self,
         residual_stream,
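These two methods factor the action proposer's forward pass out of `forward` so that `policy_loss` can re-score stored actions under the current policy. A hedged usage sketch, assuming `meta_controller`, `states`, and `actions` come from some replay buffer:

    # re-evaluate stored latent actions under the current policy weights
    action_dist = meta_controller.get_action_dist_for_internal_rl(states)
    new_log_probs = meta_controller.log_prob(action_dist, actions)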
@@ -276,6 +333,12 @@ class MetaController(Module):
 
 # main transformer, which is subsumed into the environment after behavioral cloning
 
+Hiddens = namedtuple('Hiddens', (
+    'lower_body',
+    'meta_controller',
+    'upper_body'
+))
+
 TransformerOutput = namedtuple('TransformerOutput', (
     'residual_stream_latent',
     'prev_hiddens'
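The `Hiddens` namedtuple replaces the bare 3-tuple previously returned in `TransformerOutput.prev_hiddens` (see the final hunk of this file below), so callers can unpack by name while positional access keeps working:

    from collections import namedtuple

    Hiddens = namedtuple('Hiddens', ('lower_body', 'meta_controller', 'upper_body'))
    h = Hiddens('a', 'b', 'c')
    assert h.meta_controller == h[1]  # existing index-based callers are unaffected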
@@ -417,9 +480,8 @@ class Transformer(Module):
         # maybe return behavior cloning loss
 
         if behavioral_cloning:
-            loss_mask = None
-            if exists(episode_lens):
-                loss_mask = lens_to_mask(episode_lens, state.shape[1])
+
+            loss_mask = maybe(lens_to_mask)(episode_lens, state.shape[1])
 
             state_dist_params = self.state_readout(attended)
             state_clone_loss = self.state_readout.calculate_loss(state_dist_params, target_state, mask = loss_mask)
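The three-line None guard collapses into `maybe(lens_to_mask)`. A sketch of the `maybe` pattern, assuming torch-einops-utils follows the usual None-passthrough decorator (not verified against that library's source):

    def maybe(fn):
        # returns None when the first argument is None, otherwise calls fn
        def inner(x, *args, **kwargs):
            if x is None:
                return None
            return fn(x, *args, **kwargs)
        return inner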
@@ -441,4 +503,4 @@ class Transformer(Module):
         if return_one:
             return dist_params
 
-        return dist_params, TransformerOutput(residual_stream, (next_lower_hiddens, next_meta_hiddens, next_upper_hiddens))
+        return dist_params, TransformerOutput(residual_stream, Hiddens(next_lower_hiddens, next_meta_hiddens, next_upper_hiddens))
metacontroller/metacontroller_with_binary_mapper.py

@@ -23,7 +23,7 @@ from x_mlps_pytorch import Feedforwards
 
 from assoc_scan import AssocScan
 
-from torch_einops_utils import maybe, pad_at_dim, lens_to_mask, align_dims_left
+from torch_einops_utils import maybe, pad_at_dim, lens_to_mask, masked_mean, align_dims_left, pad_right_ndim_to
 from torch_einops_utils.save_load import save_load
 
 from vector_quantize_pytorch import BinaryMapper
@@ -50,6 +50,9 @@ def default(*args):
 def straight_through(src, tgt):
     return tgt + src - src.detach()
 
+def log(t, eps = 1e-20):
+    return t.clamp_min(eps).log()
+
 # meta controller
 
 @save_load()
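`log` clamps its input away from zero before taking the logarithm, so exact zeros yield a large negative value instead of `-inf`:

    import torch
    print(torch.tensor([0., 0.5]).clamp_min(1e-20).log())  # tensor([-46.0517, -0.6931])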
@@ -71,6 +74,9 @@ class MetaControllerWithBinaryMapper(Module):
         kl_loss_threshold = 0.
     ):
         super().__init__()
+
+        assert not switch_per_code, 'switch_per_code is not supported for binary mapper'
+
         dim_meta = default(dim_meta_controller, dim_model)
 
         self.model_to_meta = Linear(dim_model, dim_meta)
@@ -137,6 +143,23 @@ class MetaControllerWithBinaryMapper(Module):
            *self.proposer_to_binary_logits.parameters()
        ]
 
+    def log_prob(
+        self,
+        action_dist,
+        sampled_latent_action
+    ):
+        action_prob = action_dist.sigmoid()
+        probs = stack((action_prob, 1. - action_prob), dim = -1)
+        log_probs = log(probs)
+
+        indices = sampled_latent_action.argmax(dim = -1)
+        codes = self.binary_mapper.codes[indices].long()
+
+        codes = rearrange(codes, '... -> ... 1')
+        action_log_probs = log_probs.gather(-1, codes)
+
+        return rearrange(action_log_probs, '... 1 -> ...')
+
     def forward(
         self,
         residual_stream,
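This `log_prob` scores a sampled one-hot latent action under the per-bit Bernoulli logits by looking up its binary code and gathering the matching probability. A standalone illustration of the stack-and-gather step, mirroring the stacking order above (shapes are hypothetical):

    import torch

    logits = torch.randn(2, 3)          # per-bit logits
    bits = torch.randint(0, 2, (2, 3))  # binary code of the sampled action

    p = logits.sigmoid()
    probs = torch.stack((p, 1. - p), dim = -1)  # index 0 -> p, index 1 -> 1 - p
    log_probs = probs.clamp_min(1e-20).log()
    picked = log_probs.gather(-1, bits.unsqueeze(-1)).squeeze(-1)  # shape (2, 3)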
metacontroller/transformer_with_resnet.py

@@ -0,0 +1,194 @@
+from __future__ import annotations
+
+import torch
+from torch import nn, Tensor
+from torch.nn import Module, ModuleList
+from einops import rearrange
+from einops.layers.torch import Rearrange
+
+from metacontroller.metacontroller import Transformer
+
+from torch_einops_utils import pack_with_inverse
+
+# resnet components
+
+def exists(v):
+    return v is not None
+
+class BasicBlock(Module):
+    expansion = 1
+
+    def __init__(
+        self,
+        dim,
+        dim_out,
+        stride = 1,
+        downsample: Module | None = None
+    ):
+        super().__init__()
+        self.conv1 = nn.Conv2d(dim, dim_out, 3, stride = stride, padding = 1, bias = False)
+        self.bn1 = nn.BatchNorm2d(dim_out)
+        self.relu = nn.ReLU(inplace = True)
+        self.conv2 = nn.Conv2d(dim_out, dim_out, 3, padding = 1, bias = False)
+        self.bn2 = nn.BatchNorm2d(dim_out)
+        self.downsample = downsample
+
+    def forward(self, x: Tensor) -> Tensor:
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        if exists(self.downsample):
+            identity = self.downsample(x)
+
+        out += identity
+        return self.relu(out)
+
+class Bottleneck(Module):
+    expansion = 4
+
+    def __init__(
+        self,
+        dim,
+        dim_out,
+        stride = 1,
+        downsample: Module | None = None
+    ):
+        super().__init__()
+        width = dim_out # simple resnet shortcut
+        self.conv1 = nn.Conv2d(dim, width, 1, bias = False)
+        self.bn1 = nn.BatchNorm2d(width)
+        self.conv2 = nn.Conv2d(width, width, 3, stride = stride, padding = 1, bias = False)
+        self.bn2 = nn.BatchNorm2d(width)
+        self.conv3 = nn.Conv2d(width, dim_out * self.expansion, 1, bias = False)
+        self.bn3 = nn.BatchNorm2d(dim_out * self.expansion)
+        self.relu = nn.ReLU(inplace = True)
+        self.downsample = downsample
+
+    def forward(self, x: Tensor) -> Tensor:
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.relu(out)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        if exists(self.downsample):
+            identity = self.downsample(x)
+
+        out += identity
+        return self.relu(out)
+
+class ResNet(Module):
+    def __init__(
+        self,
+        block: type[BasicBlock | Bottleneck],
+        layers: list[int],
+        num_classes = 1000,
+        channels = 3
+    ):
+        super().__init__()
+        self.inplanes = 64
+
+        self.conv1 = nn.Conv2d(channels, 64, kernel_size = 7, stride = 2, padding = 3, bias = False)
+        self.bn1 = nn.BatchNorm2d(64)
+        self.relu = nn.ReLU(inplace = True)
+        self.maxpool = nn.MaxPool2d(kernel_size = 3, stride = 2, padding = 1)
+
+        self.layer1 = self._make_layer(block, 64, layers[0])
+        self.layer2 = self._make_layer(block, 128, layers[1], stride = 2)
+        self.layer3 = self._make_layer(block, 256, layers[2], stride = 2)
+        self.layer4 = self._make_layer(block, 512, layers[3], stride = 2)
+
+        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+        self.flatten = Rearrange('b c 1 1 -> b c')
+        self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+    def _make_layer(
+        self,
+        block: type[BasicBlock | Bottleneck],
+        planes: int,
+        blocks: int,
+        stride: int = 1
+    ) -> nn.Sequential:
+        downsample = None
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                nn.Conv2d(self.inplanes, planes * block.expansion, 1, stride = stride, bias = False),
+                nn.BatchNorm2d(planes * block.expansion),
+            )
+
+        layers = []
+        layers.append(block(self.inplanes, planes, stride, downsample))
+        self.inplanes = planes * block.expansion
+        for _ in range(1, blocks):
+            layers.append(block(self.inplanes, planes))
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x: Tensor) -> Tensor:
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+        x = self.maxpool(x)
+
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+
+        x = self.avgpool(x)
+        x = self.flatten(x)
+        x = self.fc(x)
+        return x
+
+# resnet factory
+
+def resnet18(num_classes: int = 1000):
+    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes)
+
+def resnet34(num_classes: int = 1000):
+    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes)
+
+def resnet50(num_classes: int = 1000):
+    return ResNet(Bottleneck, [3, 4, 6, 3], num_classes)
+
+# transformer with resnet
+
+class TransformerWithResnet(Transformer):
+    def __init__(
+        self,
+        *,
+        resnet_type = 'resnet18',
+        **kwargs
+    ):
+        super().__init__(**kwargs)
+        resnet_klass = resnet18
+        if resnet_type == 'resnet34':
+            resnet_klass = resnet34
+        elif resnet_type == 'resnet50':
+            resnet_klass = resnet50
+
+        self.resnet_dim = kwargs['state_embed_readout']['num_continuous']
+        self.visual_encoder = resnet_klass(num_classes = self.resnet_dim)
+
+    def visual_encode(self, x: Tensor) -> Tensor:
+        if x.shape[-1] == 3:
+            x = rearrange(x, '... h w c -> ... c h w')
+
+        x, inverse = pack_with_inverse(x, '* c h w')
+
+        h = self.visual_encoder(x)
+
+        return inverse(h, '* d')
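The rewritten `visual_encode` replaces the hard-coded `b t` reshape of the deleted module (see the removed file below) with `pack_with_inverse`, so any number of leading dims works, and channels-last input is tolerated. A sketch of the same round trip with plain einops `pack`/`unpack`, assuming `pack_with_inverse` wraps these (not verified here):

    import torch
    from einops import pack, unpack

    x = torch.randn(2, 4, 3, 64, 64)           # (batch, time, c, h, w)

    packed, ps = pack([x], '* c h w')          # (8, 3, 64, 64): leading dims flattened
    feats = torch.randn(packed.shape[0], 512)  # stand-in for the resnet features
    out, = unpack(feats, ps, '* d')            # (2, 4, 512): leading dims restored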
metacontroller_pytorch-0.0.33.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: metacontroller-pytorch
-Version: 0.0.31
+Version: 0.0.33
 Summary: Transformer Metacontroller
 Project-URL: Homepage, https://pypi.org/project/metacontroller/
 Project-URL: Repository, https://github.com/lucidrains/metacontroller
@@ -40,7 +40,7 @@ Requires-Dist: einops>=0.8.1
 Requires-Dist: einx>=0.3.0
 Requires-Dist: loguru
 Requires-Dist: memmap-replay-buffer>=0.0.23
-Requires-Dist: torch-einops-utils>=0.0.16
+Requires-Dist: torch-einops-utils>=0.0.19
 Requires-Dist: torch>=2.5
 Requires-Dist: vector-quantize-pytorch>=1.27.20
 Requires-Dist: x-evolution>=0.1.23
metacontroller_pytorch-0.0.33.dist-info/RECORD

@@ -0,0 +1,8 @@
+metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
+metacontroller/metacontroller.py,sha256=B9XHYgVBrcJkEWhUORz--D5AHjcLnvLRVY9SRqVbhdw,16222
+metacontroller/metacontroller_with_binary_mapper.py,sha256=7vGtenScxvDQhkeYUmNnTTbVTJAtIFUVqEoWVGZP2Is,8936
+metacontroller/transformer_with_resnet.py,sha256=R49ycusbq3kEX97WHZ41WY2ONc2mYPOuRUCmaFcBOEo,5546
+metacontroller_pytorch-0.0.33.dist-info/METADATA,sha256=kWHDAnQeEueYWRzqPi5ouXjtSDIG2bnDqOfhznSUOoM,4747
+metacontroller_pytorch-0.0.33.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+metacontroller_pytorch-0.0.33.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+metacontroller_pytorch-0.0.33.dist-info/RECORD,,
@@ -1,250 +0,0 @@
1
- from typing import Any, List, Type, Union, Optional
2
-
3
- import torch
4
- from torch import Tensor
5
- from torch import nn
6
- from einops import rearrange
7
- from metacontroller.metacontroller import Transformer
8
-
9
- class TransformerWithResnetEncoder(Transformer):
10
- def __init__(self, **kwargs):
11
- super().__init__(**kwargs)
12
- self.resnet_dim = kwargs["state_embed_readout"]["num_continuous"]
13
- self.visual_encoder = resnet18(out_dim=self.resnet_dim)
14
-
15
- def visual_encode(self, x: torch.Tensor) -> torch.Tensor:
16
- b, t = x.shape[:2]
17
- x = rearrange(x, 'b t h w c -> (b t) c h w')
18
- h = self.visual_encoder(x)
19
- h = rearrange(h, '(b t) d -> b t d', b=b, t=t, d = self.resnet_dim)
20
- return h
21
-
22
- # resnet components taken from https://github.com/Lornatang/ResNet-PyTorch
23
-
24
- class _BasicBlock(nn.Module):
25
- expansion: int = 1
26
-
27
- def __init__(
28
- self,
29
- in_channels: int,
30
- out_channels: int,
31
- stride: int,
32
- downsample: Optional[nn.Module] = None,
33
- groups: int = 1,
34
- base_channels: int = 64,
35
- ) -> None:
36
- super(_BasicBlock, self).__init__()
37
- self.stride = stride
38
- self.downsample = downsample
39
- self.groups = groups
40
- self.base_channels = base_channels
41
-
42
- self.conv1 = nn.Conv2d(in_channels, out_channels, (3, 3), (stride, stride), (1, 1), bias=False)
43
- self.bn1 = nn.BatchNorm2d(out_channels)
44
- self.relu = nn.ReLU(True)
45
- self.conv2 = nn.Conv2d(out_channels, out_channels, (3, 3), (1, 1), (1, 1), bias=False)
46
- self.bn2 = nn.BatchNorm2d(out_channels)
47
-
48
- def forward(self, x: Tensor) -> Tensor:
49
- identity = x
50
-
51
- out = self.conv1(x)
52
- out = self.bn1(out)
53
- out = self.relu(out)
54
-
55
- out = self.conv2(out)
56
- out = self.bn2(out)
57
-
58
- if self.downsample is not None:
59
- identity = self.downsample(x)
60
-
61
- out = torch.add(out, identity)
62
- out = self.relu(out)
63
-
64
- return out
65
-
66
-
67
- class _Bottleneck(nn.Module):
68
- expansion: int = 4
69
-
70
- def __init__(
71
- self,
72
- in_channels: int,
73
- out_channels: int,
74
- stride: int,
75
- downsample: Optional[nn.Module] = None,
76
- groups: int = 1,
77
- base_channels: int = 64,
78
- ) -> None:
79
- super(_Bottleneck, self).__init__()
80
- self.stride = stride
81
- self.downsample = downsample
82
- self.groups = groups
83
- self.base_channels = base_channels
84
-
85
- channels = int(out_channels * (base_channels / 64.0)) * groups
86
-
87
- self.conv1 = nn.Conv2d(in_channels, channels, (1, 1), (1, 1), (0, 0), bias=False)
88
- self.bn1 = nn.BatchNorm2d(channels)
89
- self.conv2 = nn.Conv2d(channels, channels, (3, 3), (stride, stride), (1, 1), groups=groups, bias=False)
90
- self.bn2 = nn.BatchNorm2d(channels)
91
- self.conv3 = nn.Conv2d(channels, int(out_channels * self.expansion), (1, 1), (1, 1), (0, 0), bias=False)
92
- self.bn3 = nn.BatchNorm2d(int(out_channels * self.expansion))
93
- self.relu = nn.ReLU(True)
94
-
95
- def forward(self, x: Tensor) -> Tensor:
96
- identity = x
97
-
98
- out = self.conv1(x)
99
- out = self.bn1(out)
100
- out = self.relu(out)
101
-
102
- out = self.conv2(out)
103
- out = self.bn2(out)
104
- out = self.relu(out)
105
-
106
- out = self.conv3(out)
107
- out = self.bn3(out)
108
-
109
- if self.downsample is not None:
110
- identity = self.downsample(x)
111
-
112
- out = torch.add(out, identity)
113
- out = self.relu(out)
114
-
115
- return out
116
-
117
-
118
- class ResNet(nn.Module):
119
-
120
- def __init__(
121
- self,
122
- arch_cfg: List[int],
123
- block: Type[Union[_BasicBlock, _Bottleneck]],
124
- groups: int = 1,
125
- channels_per_group: int = 64,
126
- out_dim: int = 1000,
127
- ) -> None:
128
- super(ResNet, self).__init__()
129
- self.in_channels = 64
130
- self.dilation = 1
131
- self.groups = groups
132
- self.base_channels = channels_per_group
133
-
134
- self.conv1 = nn.Conv2d(3, self.in_channels, (7, 7), (2, 2), (3, 3), bias=False)
135
- self.bn1 = nn.BatchNorm2d(self.in_channels)
136
- self.relu = nn.ReLU(True)
137
- self.maxpool = nn.MaxPool2d((3, 3), (2, 2), (1, 1))
138
-
139
- self.layer1 = self._make_layer(arch_cfg[0], block, 64, 1)
140
- self.layer2 = self._make_layer(arch_cfg[1], block, 128, 2)
141
- self.layer3 = self._make_layer(arch_cfg[2], block, 256, 2)
142
- self.layer4 = self._make_layer(arch_cfg[3], block, 512, 2)
143
-
144
- self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
145
-
146
- self.fc = nn.Linear(512 * block.expansion, out_dim)
147
-
148
- # Initialize neural network weights
149
- self._initialize_weights()
150
-
151
- def _make_layer(
152
- self,
153
- repeat_times: int,
154
- block: Type[Union[_BasicBlock, _Bottleneck]],
155
- channels: int,
156
- stride: int = 1,
157
- ) -> nn.Sequential:
158
- downsample = None
159
-
160
- if stride != 1 or self.in_channels != channels * block.expansion:
161
- downsample = nn.Sequential(
162
- nn.Conv2d(self.in_channels, channels * block.expansion, (1, 1), (stride, stride), (0, 0), bias=False),
163
- nn.BatchNorm2d(channels * block.expansion),
164
- )
165
-
166
- layers = [
167
- block(
168
- self.in_channels,
169
- channels,
170
- stride,
171
- downsample,
172
- self.groups,
173
- self.base_channels
174
- )
175
- ]
176
- self.in_channels = channels * block.expansion
177
- for _ in range(1, repeat_times):
178
- layers.append(
179
- block(
180
- self.in_channels,
181
- channels,
182
- 1,
183
- None,
184
- self.groups,
185
- self.base_channels,
186
- )
187
- )
188
-
189
- return nn.Sequential(*layers)
190
-
191
- def forward(self, x: Tensor) -> Tensor:
192
- out = self._forward_impl(x)
193
-
194
- return out
195
-
196
- # Support torch.script function
197
- def _forward_impl(self, x: Tensor) -> Tensor:
198
- out = self.conv1(x)
199
- out = self.bn1(out)
200
- out = self.relu(out)
201
- out = self.maxpool(out)
202
-
203
- out = self.layer1(out)
204
- out = self.layer2(out)
205
- out = self.layer3(out)
206
- out = self.layer4(out)
207
-
208
- out = self.avgpool(out)
209
- out = torch.flatten(out, 1)
210
- out = self.fc(out)
211
-
212
- return out
213
-
214
- def _initialize_weights(self) -> None:
215
- for module in self.modules():
216
- if isinstance(module, nn.Conv2d):
217
- nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
218
- elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
219
- nn.init.constant_(module.weight, 1)
220
- nn.init.constant_(module.bias, 0)
221
-
222
-
223
- def resnet18(**kwargs: Any) -> ResNet:
224
- model = ResNet([2, 2, 2, 2], _BasicBlock, **kwargs)
225
-
226
- return model
227
-
228
-
229
- def resnet34(**kwargs: Any) -> ResNet:
230
- model = ResNet([3, 4, 6, 3], _BasicBlock, **kwargs)
231
-
232
- return model
233
-
234
-
235
- def resnet50(**kwargs: Any) -> ResNet:
236
- model = ResNet([3, 4, 6, 3], _Bottleneck, **kwargs)
237
-
238
- return model
239
-
240
-
241
- def resnet101(**kwargs: Any) -> ResNet:
242
- model = ResNet([3, 4, 23, 3], _Bottleneck, **kwargs)
243
-
244
- return model
245
-
246
-
247
- def resnet152(**kwargs: Any) -> ResNet:
248
- model = ResNet([3, 8, 36, 3], _Bottleneck, **kwargs)
249
-
250
- return model
@@ -1,8 +0,0 @@
1
- metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
2
- metacontroller/metacontroller.py,sha256=lxWgeWFcXxSDm-ygd14DjyEOYIJIALcuLkoRAfEzNtc,14719
3
- metacontroller/metacontroller_with_binary_mapper.py,sha256=BrsQdkhlOyR2O5xAXTLC4p-uKOAbW7wET-lVU0qktws,8242
4
- metacontroller/metacontroller_with_resnet.py,sha256=YKHcazRZrrRParHRH-H_EPvT1-55LHKAs5pM6gwuT20,7394
5
- metacontroller_pytorch-0.0.31.dist-info/METADATA,sha256=mtOtYymI01jBMO7pyaAIJ166B5Mk3khH8CUwUMNLTKw,4747
6
- metacontroller_pytorch-0.0.31.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
7
- metacontroller_pytorch-0.0.31.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
- metacontroller_pytorch-0.0.31.dist-info/RECORD,,