metacontroller-pytorch 0.0.32__tar.gz → 0.0.33__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of metacontroller-pytorch might be problematic. More details are available on the original diff page.

Files changed (18):
  1. {metacontroller_pytorch-0.0.32 → metacontroller_pytorch-0.0.33}/PKG-INFO +2 -2
  2. {metacontroller_pytorch-0.0.32 → metacontroller_pytorch-0.0.33}/metacontroller/metacontroller.py +54 -4
  3. {metacontroller_pytorch-0.0.32 → metacontroller_pytorch-0.0.33}/metacontroller/metacontroller_with_binary_mapper.py +4 -1
  4. metacontroller_pytorch-0.0.33/metacontroller/transformer_with_resnet.py +194 -0
  5. {metacontroller_pytorch-0.0.32 → metacontroller_pytorch-0.0.33}/pyproject.toml +2 -2
  6. {metacontroller_pytorch-0.0.32 → metacontroller_pytorch-0.0.33}/tests/test_metacontroller.py +80 -14
  7. {metacontroller_pytorch-0.0.32 → metacontroller_pytorch-0.0.33}/train_behavior_clone_babyai.py +2 -2
  8. metacontroller_pytorch-0.0.32/metacontroller/metacontroller_with_resnet.py +0 -250
  9. {metacontroller_pytorch-0.0.32 → metacontroller_pytorch-0.0.33}/.github/workflows/python-publish.yml +0 -0
  10. {metacontroller_pytorch-0.0.32 → metacontroller_pytorch-0.0.33}/.github/workflows/test.yml +0 -0
  11. {metacontroller_pytorch-0.0.32 → metacontroller_pytorch-0.0.33}/.gitignore +0 -0
  12. {metacontroller_pytorch-0.0.32 → metacontroller_pytorch-0.0.33}/LICENSE +0 -0
  13. {metacontroller_pytorch-0.0.32 → metacontroller_pytorch-0.0.33}/README.md +0 -0
  14. {metacontroller_pytorch-0.0.32 → metacontroller_pytorch-0.0.33}/fig1.png +0 -0
  15. {metacontroller_pytorch-0.0.32 → metacontroller_pytorch-0.0.33}/gather_babyai_trajs.py +0 -0
  16. {metacontroller_pytorch-0.0.32 → metacontroller_pytorch-0.0.33}/metacontroller/__init__.py +0 -0
  17. {metacontroller_pytorch-0.0.32 → metacontroller_pytorch-0.0.33}/test_babyai_e2e.sh +0 -0
  18. {metacontroller_pytorch-0.0.32 → metacontroller_pytorch-0.0.33}/train_babyai.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: metacontroller-pytorch
3
- Version: 0.0.32
3
+ Version: 0.0.33
4
4
  Summary: Transformer Metacontroller
5
5
  Project-URL: Homepage, https://pypi.org/project/metacontroller/
6
6
  Project-URL: Repository, https://github.com/lucidrains/metacontroller
@@ -40,7 +40,7 @@ Requires-Dist: einops>=0.8.1
40
40
  Requires-Dist: einx>=0.3.0
41
41
  Requires-Dist: loguru
42
42
  Requires-Dist: memmap-replay-buffer>=0.0.23
43
- Requires-Dist: torch-einops-utils>=0.0.16
43
+ Requires-Dist: torch-einops-utils>=0.0.19
44
44
  Requires-Dist: torch>=2.5
45
45
  Requires-Dist: vector-quantize-pytorch>=1.27.20
46
46
  Requires-Dist: x-evolution>=0.1.23
@@ -26,7 +26,7 @@ from discrete_continuous_embed_readout import Embed, Readout, EmbedAndReadout
26
26
 
27
27
  from assoc_scan import AssocScan
28
28
 
29
- from torch_einops_utils import maybe, pad_at_dim, lens_to_mask
29
+ from torch_einops_utils import maybe, pad_at_dim, lens_to_mask, masked_mean, align_dims_left, pad_right_ndim_to
30
30
  from torch_einops_utils.save_load import save_load
31
31
 
32
32
  # constants
@@ -66,6 +66,47 @@ MetaControllerOutput = namedtuple('MetaControllerOutput', (
66
66
  'switch_loss'
67
67
  ))
68
68
 
69
+ def z_score(t, eps = 1e-8):
70
+ return (t - t.mean()) / (t.std() + eps)
71
+
72
+ def policy_loss(
73
+ meta_controller,
74
+ state,
75
+ old_log_probs,
76
+ actions,
77
+ advantages,
78
+ mask,
79
+ episode_lens = None,
80
+ eps_clip = 0.2
81
+ ):
82
+ # get new log probs
83
+
84
+ action_dist = meta_controller.get_action_dist_for_internal_rl(state)
85
+ new_log_probs = meta_controller.log_prob(action_dist, actions)
86
+
87
+ # calculate ratio
88
+
89
+ ratio = (new_log_probs - old_log_probs).exp()
90
+
91
+ # align ratio and advantages
92
+
93
+ ratio, advantages = align_dims_left((ratio, advantages))
94
+
95
+ # ppo surrogate loss
96
+
97
+ surr1 = ratio * advantages
98
+ surr2 = ratio.clamp(1 - eps_clip, 1 + eps_clip) * advantages
99
+
100
+ losses = -torch.min(surr1, surr2)
101
+
102
+ # masking
103
+
104
+ if exists(episode_lens):
105
+ mask, episode_mask = align_dims_left((mask, lens_to_mask(episode_lens, losses.shape[1])))
106
+ mask = mask & episode_mask
107
+
108
+ return masked_mean(losses, mask)
109
+
69
110
  @save_load()
70
111
  class MetaController(Module):
71
112
  def __init__(
@@ -146,6 +187,16 @@ class MetaController(Module):
146
187
  *self.action_proposer_mean_log_var.parameters()
147
188
  ]
148
189
 
190
+ def get_action_dist_for_internal_rl(
191
+ self,
192
+ residual_stream
193
+ ):
194
+ meta_embed = self.model_to_meta(residual_stream)
195
+
196
+ proposed_action_hidden, _ = self.action_proposer(meta_embed)
197
+
198
+ return self.action_proposer_mean_log_var(proposed_action_hidden)
199
+
149
200
  def log_prob(
150
201
  self,
151
202
  action_dist,
@@ -429,9 +480,8 @@ class Transformer(Module):
429
480
  # maybe return behavior cloning loss
430
481
 
431
482
  if behavioral_cloning:
432
- loss_mask = None
433
- if exists(episode_lens):
434
- loss_mask = lens_to_mask(episode_lens, state.shape[1])
483
+
484
+ loss_mask = maybe(lens_to_mask)(episode_lens, state.shape[1])
435
485
 
436
486
  state_dist_params = self.state_readout(attended)
437
487
  state_clone_loss = self.state_readout.calculate_loss(state_dist_params, target_state, mask = loss_mask)
@@ -23,7 +23,7 @@ from x_mlps_pytorch import Feedforwards
23
23
 
24
24
  from assoc_scan import AssocScan
25
25
 
26
- from torch_einops_utils import maybe, pad_at_dim, lens_to_mask, align_dims_left
26
+ from torch_einops_utils import maybe, pad_at_dim, lens_to_mask, masked_mean, align_dims_left, pad_right_ndim_to
27
27
  from torch_einops_utils.save_load import save_load
28
28
 
29
29
  from vector_quantize_pytorch import BinaryMapper
@@ -74,6 +74,9 @@ class MetaControllerWithBinaryMapper(Module):
74
74
  kl_loss_threshold = 0.
75
75
  ):
76
76
  super().__init__()
77
+
78
+ assert not switch_per_code, 'switch_per_code is not supported for binary mapper'
79
+
77
80
  dim_meta = default(dim_meta_controller, dim_model)
78
81
 
79
82
  self.model_to_meta = Linear(dim_model, dim_meta)
@@ -0,0 +1,194 @@
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import nn, Tensor
5
+ from torch.nn import Module, ModuleList
6
+ from einops import rearrange
7
+ from einops.layers.torch import Rearrange
8
+
9
+ from metacontroller.metacontroller import Transformer
10
+
11
+ from torch_einops_utils import pack_with_inverse
12
+
13
+ # resnet components
14
+
15
+ def exists(v):
16
+ return v is not None
17
+
18
+ class BasicBlock(Module):
19
+ expansion = 1
20
+
21
+ def __init__(
22
+ self,
23
+ dim,
24
+ dim_out,
25
+ stride = 1,
26
+ downsample: Module | None = None
27
+ ):
28
+ super().__init__()
29
+ self.conv1 = nn.Conv2d(dim, dim_out, 3, stride = stride, padding = 1, bias = False)
30
+ self.bn1 = nn.BatchNorm2d(dim_out)
31
+ self.relu = nn.ReLU(inplace = True)
32
+ self.conv2 = nn.Conv2d(dim_out, dim_out, 3, padding = 1, bias = False)
33
+ self.bn2 = nn.BatchNorm2d(dim_out)
34
+ self.downsample = downsample
35
+
36
+ def forward(self, x: Tensor) -> Tensor:
37
+ identity = x
38
+
39
+ out = self.conv1(x)
40
+ out = self.bn1(out)
41
+ out = self.relu(out)
42
+
43
+ out = self.conv2(out)
44
+ out = self.bn2(out)
45
+
46
+ if exists(self.downsample):
47
+ identity = self.downsample(x)
48
+
49
+ out += identity
50
+ return self.relu(out)
51
+
52
+ class Bottleneck(Module):
53
+ expansion = 4
54
+
55
+ def __init__(
56
+ self,
57
+ dim,
58
+ dim_out,
59
+ stride = 1,
60
+ downsample: Module | None = None
61
+ ):
62
+ super().__init__()
63
+ width = dim_out # simple resnet shortcut
64
+ self.conv1 = nn.Conv2d(dim, width, 1, bias = False)
65
+ self.bn1 = nn.BatchNorm2d(width)
66
+ self.conv2 = nn.Conv2d(width, width, 3, stride = stride, padding = 1, bias = False)
67
+ self.bn2 = nn.BatchNorm2d(width)
68
+ self.conv3 = nn.Conv2d(width, dim_out * self.expansion, 1, bias = False)
69
+ self.bn3 = nn.BatchNorm2d(dim_out * self.expansion)
70
+ self.relu = nn.ReLU(inplace = True)
71
+ self.downsample = downsample
72
+
73
+ def forward(self, x: Tensor) -> Tensor:
74
+ identity = x
75
+
76
+ out = self.conv1(x)
77
+ out = self.bn1(out)
78
+ out = self.relu(out)
79
+
80
+ out = self.conv2(out)
81
+ out = self.bn2(out)
82
+ out = self.relu(out)
83
+
84
+ out = self.conv3(out)
85
+ out = self.bn3(out)
86
+
87
+ if exists(self.downsample):
88
+ identity = self.downsample(x)
89
+
90
+ out += identity
91
+ return self.relu(out)
92
+
93
+ class ResNet(Module):
94
+ def __init__(
95
+ self,
96
+ block: type[BasicBlock | Bottleneck],
97
+ layers: list[int],
98
+ num_classes = 1000,
99
+ channels = 3
100
+ ):
101
+ super().__init__()
102
+ self.inplanes = 64
103
+
104
+ self.conv1 = nn.Conv2d(channels, 64, kernel_size = 7, stride = 2, padding = 3, bias = False)
105
+ self.bn1 = nn.BatchNorm2d(64)
106
+ self.relu = nn.ReLU(inplace = True)
107
+ self.maxpool = nn.MaxPool2d(kernel_size = 3, stride = 2, padding = 1)
108
+
109
+ self.layer1 = self._make_layer(block, 64, layers[0])
110
+ self.layer2 = self._make_layer(block, 128, layers[1], stride = 2)
111
+ self.layer3 = self._make_layer(block, 256, layers[2], stride = 2)
112
+ self.layer4 = self._make_layer(block, 512, layers[3], stride = 2)
113
+
114
+ self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
115
+ self.flatten = Rearrange('b c 1 1 -> b c')
116
+ self.fc = nn.Linear(512 * block.expansion, num_classes)
117
+
118
+ def _make_layer(
119
+ self,
120
+ block: type[BasicBlock | Bottleneck],
121
+ planes: int,
122
+ blocks: int,
123
+ stride: int = 1
124
+ ) -> nn.Sequential:
125
+ downsample = None
126
+ if stride != 1 or self.inplanes != planes * block.expansion:
127
+ downsample = nn.Sequential(
128
+ nn.Conv2d(self.inplanes, planes * block.expansion, 1, stride = stride, bias = False),
129
+ nn.BatchNorm2d(planes * block.expansion),
130
+ )
131
+
132
+ layers = []
133
+ layers.append(block(self.inplanes, planes, stride, downsample))
134
+ self.inplanes = planes * block.expansion
135
+ for _ in range(1, blocks):
136
+ layers.append(block(self.inplanes, planes))
137
+
138
+ return nn.Sequential(*layers)
139
+
140
+ def forward(self, x: Tensor) -> Tensor:
141
+ x = self.conv1(x)
142
+ x = self.bn1(x)
143
+ x = self.relu(x)
144
+ x = self.maxpool(x)
145
+
146
+ x = self.layer1(x)
147
+ x = self.layer2(x)
148
+ x = self.layer3(x)
149
+ x = self.layer4(x)
150
+
151
+ x = self.avgpool(x)
152
+ x = self.flatten(x)
153
+ x = self.fc(x)
154
+ return x
155
+
156
+ # resnet factory
157
+
158
+ def resnet18(num_classes: any = 1000):
159
+ return ResNet(BasicBlock, [2, 2, 2, 2], num_classes)
160
+
161
+ def resnet34(num_classes: any = 1000):
162
+ return ResNet(BasicBlock, [3, 4, 6, 3], num_classes)
163
+
164
+ def resnet50(num_classes: any = 1000):
165
+ return ResNet(Bottleneck, [3, 4, 6, 3], num_classes)
166
+
167
+ # transformer with resnet
168
+
169
+ class TransformerWithResnet(Transformer):
170
+ def __init__(
171
+ self,
172
+ *,
173
+ resnet_type = 'resnet18',
174
+ **kwargs
175
+ ):
176
+ super().__init__(**kwargs)
177
+ resnet_klass = resnet18
178
+ if resnet_type == 'resnet34':
179
+ resnet_klass = resnet34
180
+ elif resnet_type == 'resnet50':
181
+ resnet_klass = resnet50
182
+
183
+ self.resnet_dim = kwargs['state_embed_readout']['num_continuous']
184
+ self.visual_encoder = resnet_klass(num_classes = self.resnet_dim)
185
+
186
+ def visual_encode(self, x: Tensor) -> Tensor:
187
+ if x.shape[-1] == 3:
188
+ x = rearrange(x, '... h w c -> ... c h w')
189
+
190
+ x, inverse = pack_with_inverse(x, '* c h w')
191
+
192
+ h = self.visual_encoder(x)
193
+
194
+ return inverse(h, '* d')
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "metacontroller-pytorch"
3
- version = "0.0.32"
3
+ version = "0.0.33"
4
4
  description = "Transformer Metacontroller"
5
5
  authors = [
6
6
  { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -31,7 +31,7 @@ dependencies = [
31
31
  "loguru",
32
32
  "memmap-replay-buffer>=0.0.23",
33
33
  "torch>=2.5",
34
- "torch-einops-utils>=0.0.16",
34
+ "torch-einops-utils>=0.0.19",
35
35
  "vector-quantize-pytorch>=1.27.20",
36
36
  "x-evolution>=0.1.23",
37
37
  "x-mlps-pytorch",
@@ -4,19 +4,30 @@ param = pytest.mark.parametrize
4
4
  from pathlib import Path
5
5
 
6
6
  import torch
7
- from metacontroller.metacontroller import Transformer, MetaController
7
+ from torch import cat
8
+ from metacontroller.metacontroller import Transformer, MetaController, policy_loss, z_score
8
9
  from metacontroller.metacontroller_with_binary_mapper import MetaControllerWithBinaryMapper
9
10
 
10
11
  from einops import rearrange
11
12
 
12
- @param('use_binary_mapper_variant', (False, True))
13
+ # functions
14
+
15
+ def exists(v):
16
+ return v is not None
17
+
18
+ # test
19
+
20
+ @param('use_binary_mapper_variant, switch_per_latent_dim', [
21
+ (False, False),
22
+ (False, True),
23
+ (True, False)
24
+ ])
13
25
  @param('action_discrete', (False, True))
14
- @param('switch_per_latent_dim', (False, True))
15
26
  @param('variable_length', (False, True))
16
27
  def test_metacontroller(
17
28
  use_binary_mapper_variant,
18
- action_discrete,
19
29
  switch_per_latent_dim,
30
+ action_discrete,
20
31
  variable_length
21
32
  ):
22
33
 
@@ -69,22 +80,77 @@ def test_metacontroller(
69
80
 
70
81
  # internal rl - done iteratively
71
82
 
72
- cache = None
73
- past_action_id = None
83
+ # simulate grpo
84
+
85
+ all_episodes = []
86
+ all_rewards = []
87
+
88
+ for _ in range(3): # group of 3
89
+ subset_state = state[:1]
90
+
91
+ cache = None
92
+ past_action_id = None
93
+
94
+ states = []
95
+ log_probs = []
96
+ switch_betas = []
97
+ latent_actions = []
98
+
99
+ for one_state in subset_state.unbind(dim = 1):
100
+ one_state = rearrange(one_state, 'b d -> b 1 d')
101
+
102
+ logits, cache = model(one_state, past_action_id, meta_controller = meta_controller, return_cache = True)
74
103
 
75
- for one_state in state.unbind(dim = 1):
76
- one_state = rearrange(one_state, 'b d -> b 1 d')
104
+ past_action_id = model.action_readout.sample(logits)
77
105
 
78
- logits, cache = model(one_state, past_action_id, meta_controller = meta_controller, return_cache = True)
106
+ # get log prob from meta controller latent actions
79
107
 
80
- assert logits.shape == (2, 1, *assert_shape)
81
- past_action_id = model.action_readout.sample(logits)
108
+ meta_output = cache.prev_hiddens.meta_controller
82
109
 
83
- # get log prob from meta controller latent actions
110
+ old_log_probs = meta_controller.log_prob(meta_output.action_dist, meta_output.actions)
84
111
 
85
- meta_controller_hidden = cache.prev_hiddens.meta_controller
112
+ states.append(meta_output.input_residual_stream)
113
+ log_probs.append(old_log_probs)
114
+ switch_betas.append(meta_output.switch_beta)
115
+ latent_actions.append(meta_output.actions)
116
+
117
+ # accumulate across time for the episode data
118
+
119
+ all_episodes.append(dict(
120
+ states = cat(states, dim = 1),
121
+ log_probs = cat(log_probs, dim = 1),
122
+ switch_betas = cat(switch_betas, dim = 1),
123
+ latent_actions = cat(latent_actions, dim = 1)
124
+ ))
125
+
126
+ all_rewards.append(torch.randn(1))
127
+
128
+ # calculate advantages using z-score
129
+
130
+ rewards = cat(all_rewards)
131
+ advantages = z_score(rewards)
132
+
133
+ assert advantages.shape == (3,)
134
+
135
+ # simulate a policy loss update over the entire group
136
+
137
+ group_states = cat([e['states'] for e in all_episodes], dim = 0)
138
+ group_log_probs = cat([e['log_probs'] for e in all_episodes], dim = 0)
139
+ group_latent_actions = cat([e['latent_actions'] for e in all_episodes], dim = 0)
140
+ group_switch_betas = cat([e['switch_betas'] for e in all_episodes], dim = 0)
141
+
142
+ if not use_binary_mapper_variant:
143
+ loss = policy_loss(
144
+ meta_controller,
145
+ group_states,
146
+ group_log_probs,
147
+ group_latent_actions,
148
+ advantages,
149
+ group_switch_betas == 1.,
150
+ episode_lens = episode_lens[:1].repeat(3) if exists(episode_lens) else None
151
+ )
86
152
 
87
- old_log_probs = meta_controller.log_prob(meta_controller_hidden.action_dist, meta_controller_hidden.actions)
153
+ loss.backward()
88
154
 
89
155
  # evolutionary strategies over grpo
90
156
 
@@ -26,7 +26,7 @@ from memmap_replay_buffer import ReplayBuffer
26
26
  from einops import rearrange
27
27
 
28
28
  from metacontroller.metacontroller import Transformer
29
- from metacontroller.metacontroller_with_resnet import TransformerWithResnetEncoder
29
+ from metacontroller.transformer_with_resnet import TransformerWithResnet
30
30
 
31
31
  import minigrid
32
32
  import gymnasium as gym
@@ -95,7 +95,7 @@ def train(
95
95
 
96
96
  # transformer
97
97
 
98
- transformer_class = TransformerWithResnetEncoder if use_resnet else Transformer
98
+ transformer_class = TransformerWithResnet if use_resnet else Transformer
99
99
  model = transformer_class(
100
100
  dim = dim,
101
101
  state_embed_readout = dict(num_continuous = state_dim),
@@ -1,250 +0,0 @@
1
- from typing import Any, List, Type, Union, Optional
2
-
3
- import torch
4
- from torch import Tensor
5
- from torch import nn
6
- from einops import rearrange
7
- from metacontroller.metacontroller import Transformer
8
-
9
- class TransformerWithResnetEncoder(Transformer):
10
- def __init__(self, **kwargs):
11
- super().__init__(**kwargs)
12
- self.resnet_dim = kwargs["state_embed_readout"]["num_continuous"]
13
- self.visual_encoder = resnet18(out_dim=self.resnet_dim)
14
-
15
- def visual_encode(self, x: torch.Tensor) -> torch.Tensor:
16
- b, t = x.shape[:2]
17
- x = rearrange(x, 'b t h w c -> (b t) c h w')
18
- h = self.visual_encoder(x)
19
- h = rearrange(h, '(b t) d -> b t d', b=b, t=t, d = self.resnet_dim)
20
- return h
21
-
22
- # resnet components taken from https://github.com/Lornatang/ResNet-PyTorch
23
-
24
- class _BasicBlock(nn.Module):
25
- expansion: int = 1
26
-
27
- def __init__(
28
- self,
29
- in_channels: int,
30
- out_channels: int,
31
- stride: int,
32
- downsample: Optional[nn.Module] = None,
33
- groups: int = 1,
34
- base_channels: int = 64,
35
- ) -> None:
36
- super(_BasicBlock, self).__init__()
37
- self.stride = stride
38
- self.downsample = downsample
39
- self.groups = groups
40
- self.base_channels = base_channels
41
-
42
- self.conv1 = nn.Conv2d(in_channels, out_channels, (3, 3), (stride, stride), (1, 1), bias=False)
43
- self.bn1 = nn.BatchNorm2d(out_channels)
44
- self.relu = nn.ReLU(True)
45
- self.conv2 = nn.Conv2d(out_channels, out_channels, (3, 3), (1, 1), (1, 1), bias=False)
46
- self.bn2 = nn.BatchNorm2d(out_channels)
47
-
48
- def forward(self, x: Tensor) -> Tensor:
49
- identity = x
50
-
51
- out = self.conv1(x)
52
- out = self.bn1(out)
53
- out = self.relu(out)
54
-
55
- out = self.conv2(out)
56
- out = self.bn2(out)
57
-
58
- if self.downsample is not None:
59
- identity = self.downsample(x)
60
-
61
- out = torch.add(out, identity)
62
- out = self.relu(out)
63
-
64
- return out
65
-
66
-
67
- class _Bottleneck(nn.Module):
68
- expansion: int = 4
69
-
70
- def __init__(
71
- self,
72
- in_channels: int,
73
- out_channels: int,
74
- stride: int,
75
- downsample: Optional[nn.Module] = None,
76
- groups: int = 1,
77
- base_channels: int = 64,
78
- ) -> None:
79
- super(_Bottleneck, self).__init__()
80
- self.stride = stride
81
- self.downsample = downsample
82
- self.groups = groups
83
- self.base_channels = base_channels
84
-
85
- channels = int(out_channels * (base_channels / 64.0)) * groups
86
-
87
- self.conv1 = nn.Conv2d(in_channels, channels, (1, 1), (1, 1), (0, 0), bias=False)
88
- self.bn1 = nn.BatchNorm2d(channels)
89
- self.conv2 = nn.Conv2d(channels, channels, (3, 3), (stride, stride), (1, 1), groups=groups, bias=False)
90
- self.bn2 = nn.BatchNorm2d(channels)
91
- self.conv3 = nn.Conv2d(channels, int(out_channels * self.expansion), (1, 1), (1, 1), (0, 0), bias=False)
92
- self.bn3 = nn.BatchNorm2d(int(out_channels * self.expansion))
93
- self.relu = nn.ReLU(True)
94
-
95
- def forward(self, x: Tensor) -> Tensor:
96
- identity = x
97
-
98
- out = self.conv1(x)
99
- out = self.bn1(out)
100
- out = self.relu(out)
101
-
102
- out = self.conv2(out)
103
- out = self.bn2(out)
104
- out = self.relu(out)
105
-
106
- out = self.conv3(out)
107
- out = self.bn3(out)
108
-
109
- if self.downsample is not None:
110
- identity = self.downsample(x)
111
-
112
- out = torch.add(out, identity)
113
- out = self.relu(out)
114
-
115
- return out
116
-
117
-
118
- class ResNet(nn.Module):
119
-
120
- def __init__(
121
- self,
122
- arch_cfg: List[int],
123
- block: Type[Union[_BasicBlock, _Bottleneck]],
124
- groups: int = 1,
125
- channels_per_group: int = 64,
126
- out_dim: int = 1000,
127
- ) -> None:
128
- super(ResNet, self).__init__()
129
- self.in_channels = 64
130
- self.dilation = 1
131
- self.groups = groups
132
- self.base_channels = channels_per_group
133
-
134
- self.conv1 = nn.Conv2d(3, self.in_channels, (7, 7), (2, 2), (3, 3), bias=False)
135
- self.bn1 = nn.BatchNorm2d(self.in_channels)
136
- self.relu = nn.ReLU(True)
137
- self.maxpool = nn.MaxPool2d((3, 3), (2, 2), (1, 1))
138
-
139
- self.layer1 = self._make_layer(arch_cfg[0], block, 64, 1)
140
- self.layer2 = self._make_layer(arch_cfg[1], block, 128, 2)
141
- self.layer3 = self._make_layer(arch_cfg[2], block, 256, 2)
142
- self.layer4 = self._make_layer(arch_cfg[3], block, 512, 2)
143
-
144
- self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
145
-
146
- self.fc = nn.Linear(512 * block.expansion, out_dim)
147
-
148
- # Initialize neural network weights
149
- self._initialize_weights()
150
-
151
- def _make_layer(
152
- self,
153
- repeat_times: int,
154
- block: Type[Union[_BasicBlock, _Bottleneck]],
155
- channels: int,
156
- stride: int = 1,
157
- ) -> nn.Sequential:
158
- downsample = None
159
-
160
- if stride != 1 or self.in_channels != channels * block.expansion:
161
- downsample = nn.Sequential(
162
- nn.Conv2d(self.in_channels, channels * block.expansion, (1, 1), (stride, stride), (0, 0), bias=False),
163
- nn.BatchNorm2d(channels * block.expansion),
164
- )
165
-
166
- layers = [
167
- block(
168
- self.in_channels,
169
- channels,
170
- stride,
171
- downsample,
172
- self.groups,
173
- self.base_channels
174
- )
175
- ]
176
- self.in_channels = channels * block.expansion
177
- for _ in range(1, repeat_times):
178
- layers.append(
179
- block(
180
- self.in_channels,
181
- channels,
182
- 1,
183
- None,
184
- self.groups,
185
- self.base_channels,
186
- )
187
- )
188
-
189
- return nn.Sequential(*layers)
190
-
191
- def forward(self, x: Tensor) -> Tensor:
192
- out = self._forward_impl(x)
193
-
194
- return out
195
-
196
- # Support torch.script function
197
- def _forward_impl(self, x: Tensor) -> Tensor:
198
- out = self.conv1(x)
199
- out = self.bn1(out)
200
- out = self.relu(out)
201
- out = self.maxpool(out)
202
-
203
- out = self.layer1(out)
204
- out = self.layer2(out)
205
- out = self.layer3(out)
206
- out = self.layer4(out)
207
-
208
- out = self.avgpool(out)
209
- out = torch.flatten(out, 1)
210
- out = self.fc(out)
211
-
212
- return out
213
-
214
- def _initialize_weights(self) -> None:
215
- for module in self.modules():
216
- if isinstance(module, nn.Conv2d):
217
- nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
218
- elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
219
- nn.init.constant_(module.weight, 1)
220
- nn.init.constant_(module.bias, 0)
221
-
222
-
223
- def resnet18(**kwargs: Any) -> ResNet:
224
- model = ResNet([2, 2, 2, 2], _BasicBlock, **kwargs)
225
-
226
- return model
227
-
228
-
229
- def resnet34(**kwargs: Any) -> ResNet:
230
- model = ResNet([3, 4, 6, 3], _BasicBlock, **kwargs)
231
-
232
- return model
233
-
234
-
235
- def resnet50(**kwargs: Any) -> ResNet:
236
- model = ResNet([3, 4, 6, 3], _Bottleneck, **kwargs)
237
-
238
- return model
239
-
240
-
241
- def resnet101(**kwargs: Any) -> ResNet:
242
- model = ResNet([3, 4, 23, 3], _Bottleneck, **kwargs)
243
-
244
- return model
245
-
246
-
247
- def resnet152(**kwargs: Any) -> ResNet:
248
- model = ResNet([3, 8, 36, 3], _Bottleneck, **kwargs)
249
-
250
- return model