metacontroller-pytorch 0.0.31.tar.gz → 0.0.33.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of metacontroller-pytorch has been flagged as potentially problematic; see the registry's advisory page for details.

Files changed (18):
  1. {metacontroller_pytorch-0.0.31 → metacontroller_pytorch-0.0.33}/PKG-INFO +2 -2
  2. {metacontroller_pytorch-0.0.31 → metacontroller_pytorch-0.0.33}/metacontroller/metacontroller.py +68 -6
  3. {metacontroller_pytorch-0.0.31 → metacontroller_pytorch-0.0.33}/metacontroller/metacontroller_with_binary_mapper.py +24 -1
  4. metacontroller_pytorch-0.0.33/metacontroller/transformer_with_resnet.py +194 -0
  5. {metacontroller_pytorch-0.0.31 → metacontroller_pytorch-0.0.33}/pyproject.toml +2 -2
  6. {metacontroller_pytorch-0.0.31 → metacontroller_pytorch-0.0.33}/tests/test_metacontroller.py +83 -11
  7. {metacontroller_pytorch-0.0.31 → metacontroller_pytorch-0.0.33}/train_behavior_clone_babyai.py +2 -2
  8. metacontroller_pytorch-0.0.31/metacontroller/metacontroller_with_resnet.py +0 -250
  9. {metacontroller_pytorch-0.0.31 → metacontroller_pytorch-0.0.33}/.github/workflows/python-publish.yml +0 -0
  10. {metacontroller_pytorch-0.0.31 → metacontroller_pytorch-0.0.33}/.github/workflows/test.yml +0 -0
  11. {metacontroller_pytorch-0.0.31 → metacontroller_pytorch-0.0.33}/.gitignore +0 -0
  12. {metacontroller_pytorch-0.0.31 → metacontroller_pytorch-0.0.33}/LICENSE +0 -0
  13. {metacontroller_pytorch-0.0.31 → metacontroller_pytorch-0.0.33}/README.md +0 -0
  14. {metacontroller_pytorch-0.0.31 → metacontroller_pytorch-0.0.33}/fig1.png +0 -0
  15. {metacontroller_pytorch-0.0.31 → metacontroller_pytorch-0.0.33}/gather_babyai_trajs.py +0 -0
  16. {metacontroller_pytorch-0.0.31 → metacontroller_pytorch-0.0.33}/metacontroller/__init__.py +0 -0
  17. {metacontroller_pytorch-0.0.31 → metacontroller_pytorch-0.0.33}/test_babyai_e2e.sh +0 -0
  18. {metacontroller_pytorch-0.0.31 → metacontroller_pytorch-0.0.33}/train_babyai.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: metacontroller-pytorch
3
- Version: 0.0.31
3
+ Version: 0.0.33
4
4
  Summary: Transformer Metacontroller
5
5
  Project-URL: Homepage, https://pypi.org/project/metacontroller/
6
6
  Project-URL: Repository, https://github.com/lucidrains/metacontroller
@@ -40,7 +40,7 @@ Requires-Dist: einops>=0.8.1
40
40
  Requires-Dist: einx>=0.3.0
41
41
  Requires-Dist: loguru
42
42
  Requires-Dist: memmap-replay-buffer>=0.0.23
43
- Requires-Dist: torch-einops-utils>=0.0.16
43
+ Requires-Dist: torch-einops-utils>=0.0.19
44
44
  Requires-Dist: torch>=2.5
45
45
  Requires-Dist: vector-quantize-pytorch>=1.27.20
46
46
  Requires-Dist: x-evolution>=0.1.23
@@ -26,7 +26,7 @@ from discrete_continuous_embed_readout import Embed, Readout, EmbedAndReadout
26
26
 
27
27
  from assoc_scan import AssocScan
28
28
 
29
- from torch_einops_utils import maybe, pad_at_dim, lens_to_mask
29
+ from torch_einops_utils import maybe, pad_at_dim, lens_to_mask, masked_mean, align_dims_left, pad_right_ndim_to
30
30
  from torch_einops_utils.save_load import save_load
31
31
 
32
32
  # constants
@@ -66,6 +66,47 @@ MetaControllerOutput = namedtuple('MetaControllerOutput', (
66
66
  'switch_loss'
67
67
  ))
68
68
 
69
def z_score(t, eps = 1e-8):
    """Standardize `t` to zero mean and unit (unbiased) standard deviation.

    The statistics are taken over every element of `t`. `eps` keeps the
    division finite when all elements are equal.

    With fewer than two elements the unbiased standard deviation is undefined
    (NaN), which would poison every downstream advantage. A zero tensor of the
    same shape is returned instead: a group of one carries no relative signal.
    """
    if t.numel() < 2:
        return t.new_zeros(t.shape)

    return (t - t.mean()) / (t.std() + eps)
71
+
72
def policy_loss(
    meta_controller,
    state,
    old_log_probs,
    actions,
    advantages,
    mask,
    episode_lens = None,
    eps_clip = 0.2
):
    """PPO-style clipped surrogate loss over the meta controller's latent actions.

    Args:
        meta_controller: module providing `get_action_dist_for_internal_rl`
            and `log_prob`.
        state: residual stream passed to `get_action_dist_for_internal_rl`.
        old_log_probs: log probs of `actions` under the policy that collected
            them. Detached here, so they act as constants.
        actions: latent actions taken during the rollout.
        advantages: advantages, left-aligned against the ratio before use.
            Also detached.
        mask: boolean mask of positions contributing to the loss.
        episode_lens: optional per-sample lengths. Positions at or beyond a
            sample's length (along dim 1) are masked out as well.
        eps_clip: clipping range for the probability ratio.

    Returns:
        `masked_mean` of the negated clipped surrogate objective.
    """
    # get new log probs

    action_dist = meta_controller.get_action_dist_for_internal_rl(state)
    new_log_probs = meta_controller.log_prob(action_dist, actions)

    # the old log probs and advantages must be constants. If the old log probs
    # were produced by the same parameters (an on-policy update), their
    # gradient cancels the new log probs' gradient inside `new - old`, and the
    # ratio would carry no policy gradient at all

    old_log_probs = old_log_probs.detach()
    advantages = advantages.detach()

    # calculate ratio

    ratio = (new_log_probs - old_log_probs).exp()

    # align ratio and advantages

    ratio, advantages = align_dims_left((ratio, advantages))

    # ppo surrogate loss

    surr1 = ratio * advantages
    surr2 = ratio.clamp(1 - eps_clip, 1 + eps_clip) * advantages

    losses = -torch.min(surr1, surr2)

    # masking

    if exists(episode_lens):
        mask, episode_mask = align_dims_left((mask, lens_to_mask(episode_lens, losses.shape[1])))
        mask = mask & episode_mask

    return masked_mean(losses, mask)
109
+
69
110
  @save_load()
70
111
  class MetaController(Module):
71
112
  def __init__(
@@ -107,7 +148,6 @@ class MetaController(Module):
107
148
 
108
149
  self.switch_per_latent_dim = switch_per_latent_dim
109
150
 
110
-
111
151
  self.dim_latent = dim_latent
112
152
  self.switching_unit = GRU(dim_meta + dim_latent, dim_meta)
113
153
  self.to_switching_unit_beta = nn.Linear(dim_meta, dim_latent if switch_per_latent_dim else 1, bias = False)
@@ -147,6 +187,23 @@ class MetaController(Module):
147
187
  *self.action_proposer_mean_log_var.parameters()
148
188
  ]
149
189
 
190
def get_action_dist_for_internal_rl(
    self,
    residual_stream
):
    """Return the action proposer's distribution parameters for `residual_stream`.

    The stream is projected with `model_to_meta` and run through
    `action_proposer`; its second output is ignored. The resulting hidden
    state is read out by `action_proposer_mean_log_var`.
    """
    hidden, _ = self.action_proposer(self.model_to_meta(residual_stream))
    return self.action_proposer_mean_log_var(hidden)
199
+
200
def log_prob(
    self,
    action_dist,
    sampled_latent_action
):
    """Score `sampled_latent_action` under `action_dist`.

    Delegates to the `log_prob` of the `action_proposer_mean_log_var` readout.
    """
    readout = self.action_proposer_mean_log_var
    return readout.log_prob(action_dist, sampled_latent_action)
206
+
150
207
  def forward(
151
208
  self,
152
209
  residual_stream,
@@ -276,6 +333,12 @@ class MetaController(Module):
276
333
 
277
334
  # main transformer, which is subsumed into the environment after behavioral cloning
278
335
 
336
# next hidden states of the three transformer sections, packaged as the
# `prev_hiddens` field of TransformerOutput: lower body, meta controller,
# upper body, in that order

Hiddens = namedtuple('Hiddens', ['lower_body', 'meta_controller', 'upper_body'])
341
+
279
342
  TransformerOutput = namedtuple('TransformerOutput', (
280
343
  'residual_stream_latent',
281
344
  'prev_hiddens'
@@ -417,9 +480,8 @@ class Transformer(Module):
417
480
  # maybe return behavior cloning loss
418
481
 
419
482
  if behavioral_cloning:
420
- loss_mask = None
421
- if exists(episode_lens):
422
- loss_mask = lens_to_mask(episode_lens, state.shape[1])
483
+
484
+ loss_mask = maybe(lens_to_mask)(episode_lens, state.shape[1])
423
485
 
424
486
  state_dist_params = self.state_readout(attended)
425
487
  state_clone_loss = self.state_readout.calculate_loss(state_dist_params, target_state, mask = loss_mask)
@@ -441,4 +503,4 @@ class Transformer(Module):
441
503
  if return_one:
442
504
  return dist_params
443
505
 
444
- return dist_params, TransformerOutput(residual_stream, (next_lower_hiddens, next_meta_hiddens, next_upper_hiddens))
506
+ return dist_params, TransformerOutput(residual_stream, Hiddens(next_lower_hiddens, next_meta_hiddens, next_upper_hiddens))
@@ -23,7 +23,7 @@ from x_mlps_pytorch import Feedforwards
23
23
 
24
24
  from assoc_scan import AssocScan
25
25
 
26
- from torch_einops_utils import maybe, pad_at_dim, lens_to_mask, align_dims_left
26
+ from torch_einops_utils import maybe, pad_at_dim, lens_to_mask, masked_mean, align_dims_left, pad_right_ndim_to
27
27
  from torch_einops_utils.save_load import save_load
28
28
 
29
29
  from vector_quantize_pytorch import BinaryMapper
@@ -50,6 +50,9 @@ def default(*args):
50
50
  def straight_through(src, tgt):
51
51
  return tgt + src - src.detach()
52
52
 
53
def log(t, eps = 1e-20):
    """Natural log of `t` after flooring it at `eps`, so zeros never produce -inf."""
    floored = t.clamp_min(eps)
    return floored.log()
55
+
53
56
  # meta controller
54
57
 
55
58
  @save_load()
@@ -71,6 +74,9 @@ class MetaControllerWithBinaryMapper(Module):
71
74
  kl_loss_threshold = 0.
72
75
  ):
73
76
  super().__init__()
77
+
78
+ assert not switch_per_code, 'switch_per_code is not supported for binary mapper'
79
+
74
80
  dim_meta = default(dim_meta_controller, dim_model)
75
81
 
76
82
  self.model_to_meta = Linear(dim_model, dim_meta)
@@ -137,6 +143,23 @@ class MetaControllerWithBinaryMapper(Module):
137
143
  *self.proposer_to_binary_logits.parameters()
138
144
  ]
139
145
 
146
def log_prob(
    self,
    action_dist,
    sampled_latent_action
):
    """Per-bit log probabilities of a sampled binary latent action.

    `action_dist` holds per-bit logits, and `sigmoid(logit)` is taken as the
    probability of the bit being on. `sampled_latent_action` is reduced with
    argmax over its last dim to a code index. That index is looked up in
    `self.binary_mapper.codes` to recover the bit pattern.

    Returns the log probability of each observed bit (not summed over bits).
    """
    action_prob = action_dist.sigmoid()

    # NOTE(review): assumes the binary mapper samples a bit as 1 with
    # probability sigmoid(logit), and that `codes` holds those 0/1 bits -
    # confirm against vector_quantize_pytorch's BinaryMapper. Under that
    # convention index 0 must hold 1 - p and index 1 must hold p. The reverse
    # order gives every bit the log prob of the opposite outcome.

    probs = stack((1. - action_prob, action_prob), dim = -1)
    log_probs = log(probs)

    indices = sampled_latent_action.argmax(dim = -1)
    codes = self.binary_mapper.codes[indices].long()

    codes = rearrange(codes, '... -> ... 1')
    action_log_probs = log_probs.gather(-1, codes)

    return rearrange(action_log_probs, '... 1 -> ...')
162
+
140
163
  def forward(
141
164
  self,
142
165
  residual_stream,
@@ -0,0 +1,194 @@
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import nn, Tensor
5
+ from torch.nn import Module, ModuleList
6
+ from einops import rearrange
7
+ from einops.layers.torch import Rearrange
8
+
9
+ from metacontroller.metacontroller import Transformer
10
+
11
+ from torch_einops_utils import pack_with_inverse
12
+
13
+ # resnet components
14
+
15
def exists(v):
    """True for any value other than None (falsy values such as 0 or '' count)."""
    return not (v is None)
17
+
18
class BasicBlock(Module):
    """Residual block of two 3x3 convs, each followed by batch norm.

    `stride` is applied by the first conv. The output has `dim_out` channels
    (`expansion` is 1). When the stride or channel count changes, pass a
    matching `downsample` module to project the shortcut.
    """

    expansion = 1

    def __init__(
        self,
        dim,
        dim_out,
        stride = 1,
        downsample: Module | None = None
    ):
        super().__init__()

        # attribute names are state_dict keys and creation order fixes the
        # parameter init order - keep both stable
        self.conv1 = nn.Conv2d(dim, dim_out, 3, stride = stride, padding = 1, bias = False)
        self.bn1 = nn.BatchNorm2d(dim_out)
        self.relu = nn.ReLU(inplace = True)
        self.conv2 = nn.Conv2d(dim_out, dim_out, 3, padding = 1, bias = False)
        self.bn2 = nn.BatchNorm2d(dim_out)
        self.downsample = downsample

    def forward(self, x: Tensor) -> Tensor:
        shortcut = self.downsample(x) if exists(self.downsample) else x

        h = self.relu(self.bn1(self.conv1(x)))
        h = self.bn2(self.conv2(h))

        h += shortcut
        return self.relu(h)
51
+
52
class Bottleneck(Module):
    """Residual bottleneck block: 1x1 reduce, 3x3, 1x1 expand, each with batch norm.

    The inner width equals `dim_out`, and the block outputs `dim_out * expansion`
    (= 4 * dim_out) channels. `stride` is applied by the 3x3 conv. When the
    stride or channel count changes, pass a matching `downsample` module to
    project the shortcut.
    """

    expansion = 4

    def __init__(
        self,
        dim,
        dim_out,
        stride = 1,
        downsample: Module | None = None
    ):
        super().__init__()

        # inner (bottleneck) width of the 1x1 -> 3x3 path; the final 1x1 conv
        # expands it to dim_out * expansion channels
        inner = dim_out
        expanded = dim_out * self.expansion

        # attribute names are state_dict keys and creation order fixes the
        # parameter init order - keep both stable
        self.conv1 = nn.Conv2d(dim, inner, 1, bias = False)
        self.bn1 = nn.BatchNorm2d(inner)
        self.conv2 = nn.Conv2d(inner, inner, 3, stride = stride, padding = 1, bias = False)
        self.bn2 = nn.BatchNorm2d(inner)
        self.conv3 = nn.Conv2d(inner, expanded, 1, bias = False)
        self.bn3 = nn.BatchNorm2d(expanded)
        self.relu = nn.ReLU(inplace = True)
        self.downsample = downsample

    def forward(self, x: Tensor) -> Tensor:
        shortcut = self.downsample(x) if exists(self.downsample) else x

        h = self.relu(self.bn1(self.conv1(x)))
        h = self.relu(self.bn2(self.conv2(h)))
        h = self.bn3(self.conv3(h))

        h += shortcut
        return self.relu(h)
92
+
93
class ResNet(Module):
    """ResNet backbone mapping an image batch to `num_classes` features per image.

    Stem: 7x7 stride-2 conv, batch norm, ReLU and a 3x3 stride-2 max pool.
    Body: four stages of `block`s at 64, 128, 256 and 512 planes, with
    `layers[i]` blocks per stage and stride 2 on the first block of stages 2-4.
    Head: global average pooling and a linear layer.
    Input is (batch, channels, height, width).
    """

    def __init__(
        self,
        block: type[BasicBlock | Bottleneck],
        layers: list[int],
        num_classes = 1000,
        channels = 3
    ):
        super().__init__()

        # running channel count consumed (and advanced) by _make_layer
        self.inplanes = 64

        self.conv1 = nn.Conv2d(channels, 64, kernel_size = 7, stride = 2, padding = 3, bias = False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace = True)
        self.maxpool = nn.MaxPool2d(kernel_size = 3, stride = 2, padding = 1)

        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride = 2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride = 2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride = 2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.flatten = Rearrange('b c 1 1 -> b c')
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(
        self,
        block: type[BasicBlock | Bottleneck],
        planes: int,
        blocks: int,
        stride: int = 1
    ) -> nn.Sequential:
        # one stage: the first block may change stride/width, the rest keep it
        out_channels = planes * block.expansion
        needs_projection = stride != 1 or self.inplanes != out_channels

        downsample = nn.Sequential(
            nn.Conv2d(self.inplanes, out_channels, 1, stride = stride, bias = False),
            nn.BatchNorm2d(out_channels),
        ) if needs_projection else None

        first = block(self.inplanes, planes, stride, downsample)
        self.inplanes = out_channels
        rest = [block(self.inplanes, planes) for _ in range(blocks - 1)]

        return nn.Sequential(first, *rest)

    def forward(self, x: Tensor) -> Tensor:
        stem = (self.conv1, self.bn1, self.relu, self.maxpool)
        stages = (self.layer1, self.layer2, self.layer3, self.layer4)
        head = (self.avgpool, self.flatten, self.fc)

        for fn in (*stem, *stages, *head):
            x = fn(x)

        return x
155
+
156
+ # resnet factory
157
+
158
def resnet18(num_classes: int = 1000, channels: int = 3) -> ResNet:
    """ResNet-18: BasicBlocks with [2, 2, 2, 2] blocks per stage.

    `num_classes` sets the output width. `channels` is the number of input
    image channels.
    """
    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes, channels = channels)
160
+
161
def resnet34(num_classes: int = 1000, channels: int = 3) -> ResNet:
    """ResNet-34: BasicBlocks with [3, 4, 6, 3] blocks per stage.

    `num_classes` sets the output width. `channels` is the number of input
    image channels.
    """
    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes, channels = channels)
163
+
164
def resnet50(num_classes: int = 1000, channels: int = 3) -> ResNet:
    """ResNet-50: Bottleneck blocks with [3, 4, 6, 3] blocks per stage.

    `num_classes` sets the output width. `channels` is the number of input
    image channels.
    """
    return ResNet(Bottleneck, [3, 4, 6, 3], num_classes, channels = channels)
166
+
167
+ # transformer with resnet
168
+
169
class TransformerWithResnet(Transformer):
    """Transformer that additionally owns a ResNet visual encoder.

    The encoder's output width is `state_embed_readout['num_continuous']`, so
    that key must be present in the keyword arguments forwarded to Transformer.

    Args:
        resnet_type: one of 'resnet18', 'resnet34', 'resnet50'.

    Raises:
        ValueError: if `resnet_type` is not one of the supported names.
    """

    def __init__(
        self,
        *,
        resnet_type = 'resnet18',
        **kwargs
    ):
        resnet_factories = dict(
            resnet18 = resnet18,
            resnet34 = resnet34,
            resnet50 = resnet50
        )

        # reject unknown names up front instead of silently building a
        # different architecture than the caller asked for
        if resnet_type not in resnet_factories:
            raise ValueError(f'unknown resnet_type {resnet_type!r}, expected one of {tuple(resnet_factories)}')

        super().__init__(**kwargs)

        resnet_klass = resnet_factories[resnet_type]

        self.resnet_dim = kwargs['state_embed_readout']['num_continuous']
        self.visual_encoder = resnet_klass(num_classes = self.resnet_dim)

    def visual_encode(self, x: Tensor) -> Tensor:
        """Encode images with the ResNet, keeping arbitrary leading dims.

        Inputs whose last dim is 3 are treated as channels-last and moved to
        channels-first. All leading dims are packed into one batch dim for the
        ResNet, then unpacked with the returned inverse. The output's last dim
        has `resnet_dim` features.
        """
        # NOTE(review): channels-last detection is a heuristic - a
        # channels-first image whose width happens to be 3 would be permuted
        if x.shape[-1] == 3:
            x = rearrange(x, '... h w c -> ... c h w')

        x, inverse = pack_with_inverse(x, '* c h w')

        h = self.visual_encoder(x)

        return inverse(h, '* d')
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "metacontroller-pytorch"
3
- version = "0.0.31"
3
+ version = "0.0.33"
4
4
  description = "Transformer Metacontroller"
5
5
  authors = [
6
6
  { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -31,7 +31,7 @@ dependencies = [
31
31
  "loguru",
32
32
  "memmap-replay-buffer>=0.0.23",
33
33
  "torch>=2.5",
34
- "torch-einops-utils>=0.0.16",
34
+ "torch-einops-utils>=0.0.19",
35
35
  "vector-quantize-pytorch>=1.27.20",
36
36
  "x-evolution>=0.1.23",
37
37
  "x-mlps-pytorch",
@@ -4,19 +4,30 @@ param = pytest.mark.parametrize
4
4
  from pathlib import Path
5
5
 
6
6
  import torch
7
- from metacontroller.metacontroller import Transformer, MetaController
7
+ from torch import cat
8
+ from metacontroller.metacontroller import Transformer, MetaController, policy_loss, z_score
8
9
  from metacontroller.metacontroller_with_binary_mapper import MetaControllerWithBinaryMapper
9
10
 
10
11
  from einops import rearrange
11
12
 
12
- @param('use_binary_mapper_variant', (False, True))
13
+ # functions
14
+
15
+ def exists(v):
16
+ return v is not None
17
+
18
+ # test
19
+
20
+ @param('use_binary_mapper_variant, switch_per_latent_dim', [
21
+ (False, False),
22
+ (False, True),
23
+ (True, False)
24
+ ])
13
25
  @param('action_discrete', (False, True))
14
- @param('switch_per_latent_dim', (False, True))
15
26
  @param('variable_length', (False, True))
16
27
  def test_metacontroller(
17
28
  use_binary_mapper_variant,
18
- action_discrete,
19
29
  switch_per_latent_dim,
30
+ action_discrete,
20
31
  variable_length
21
32
  ):
22
33
 
@@ -69,16 +80,77 @@ def test_metacontroller(
69
80
 
70
81
  # internal rl - done iteratively
71
82
 
72
- cache = None
73
- past_action_id = None
83
+ # simulate grpo
84
+
85
+ all_episodes = []
86
+ all_rewards = []
87
+
88
+ for _ in range(3): # group of 3
89
+ subset_state = state[:1]
90
+
91
+ cache = None
92
+ past_action_id = None
93
+
94
+ states = []
95
+ log_probs = []
96
+ switch_betas = []
97
+ latent_actions = []
98
+
99
+ for one_state in subset_state.unbind(dim = 1):
100
+ one_state = rearrange(one_state, 'b d -> b 1 d')
101
+
102
+ logits, cache = model(one_state, past_action_id, meta_controller = meta_controller, return_cache = True)
103
+
104
+ past_action_id = model.action_readout.sample(logits)
105
+
106
+ # get log prob from meta controller latent actions
74
107
 
75
- for one_state in state.unbind(dim = 1):
76
- one_state = rearrange(one_state, 'b d -> b 1 d')
108
+ meta_output = cache.prev_hiddens.meta_controller
77
109
 
78
- logits, cache = model(one_state, past_action_id, meta_controller = meta_controller, return_cache = True)
110
+ old_log_probs = meta_controller.log_prob(meta_output.action_dist, meta_output.actions)
111
+
112
+ states.append(meta_output.input_residual_stream)
113
+ log_probs.append(old_log_probs)
114
+ switch_betas.append(meta_output.switch_beta)
115
+ latent_actions.append(meta_output.actions)
116
+
117
+ # accumulate across time for the episode data
118
+
119
+ all_episodes.append(dict(
120
+ states = cat(states, dim = 1),
121
+ log_probs = cat(log_probs, dim = 1),
122
+ switch_betas = cat(switch_betas, dim = 1),
123
+ latent_actions = cat(latent_actions, dim = 1)
124
+ ))
125
+
126
+ all_rewards.append(torch.randn(1))
127
+
128
+ # calculate advantages using z-score
129
+
130
+ rewards = cat(all_rewards)
131
+ advantages = z_score(rewards)
132
+
133
+ assert advantages.shape == (3,)
134
+
135
+ # simulate a policy loss update over the entire group
136
+
137
+ group_states = cat([e['states'] for e in all_episodes], dim = 0)
138
+ group_log_probs = cat([e['log_probs'] for e in all_episodes], dim = 0)
139
+ group_latent_actions = cat([e['latent_actions'] for e in all_episodes], dim = 0)
140
+ group_switch_betas = cat([e['switch_betas'] for e in all_episodes], dim = 0)
141
+
142
+ if not use_binary_mapper_variant:
143
+ loss = policy_loss(
144
+ meta_controller,
145
+ group_states,
146
+ group_log_probs,
147
+ group_latent_actions,
148
+ advantages,
149
+ group_switch_betas == 1.,
150
+ episode_lens = episode_lens[:1].repeat(3) if exists(episode_lens) else None
151
+ )
79
152
 
80
- assert logits.shape == (2, 1, *assert_shape)
81
- past_action_id = model.action_readout.sample(logits)
153
+ loss.backward()
82
154
 
83
155
  # evolutionary strategies over grpo
84
156
 
@@ -26,7 +26,7 @@ from memmap_replay_buffer import ReplayBuffer
26
26
  from einops import rearrange
27
27
 
28
28
  from metacontroller.metacontroller import Transformer
29
- from metacontroller.metacontroller_with_resnet import TransformerWithResnetEncoder
29
+ from metacontroller.transformer_with_resnet import TransformerWithResnet
30
30
 
31
31
  import minigrid
32
32
  import gymnasium as gym
@@ -95,7 +95,7 @@ def train(
95
95
 
96
96
  # transformer
97
97
 
98
- transformer_class = TransformerWithResnetEncoder if use_resnet else Transformer
98
+ transformer_class = TransformerWithResnet if use_resnet else Transformer
99
99
  model = transformer_class(
100
100
  dim = dim,
101
101
  state_embed_readout = dict(num_continuous = state_dim),
@@ -1,250 +0,0 @@
1
- from typing import Any, List, Type, Union, Optional
2
-
3
- import torch
4
- from torch import Tensor
5
- from torch import nn
6
- from einops import rearrange
7
- from metacontroller.metacontroller import Transformer
8
-
9
- class TransformerWithResnetEncoder(Transformer):
10
- def __init__(self, **kwargs):
11
- super().__init__(**kwargs)
12
- self.resnet_dim = kwargs["state_embed_readout"]["num_continuous"]
13
- self.visual_encoder = resnet18(out_dim=self.resnet_dim)
14
-
15
- def visual_encode(self, x: torch.Tensor) -> torch.Tensor:
16
- b, t = x.shape[:2]
17
- x = rearrange(x, 'b t h w c -> (b t) c h w')
18
- h = self.visual_encoder(x)
19
- h = rearrange(h, '(b t) d -> b t d', b=b, t=t, d = self.resnet_dim)
20
- return h
21
-
22
- # resnet components taken from https://github.com/Lornatang/ResNet-PyTorch
23
-
24
- class _BasicBlock(nn.Module):
25
- expansion: int = 1
26
-
27
- def __init__(
28
- self,
29
- in_channels: int,
30
- out_channels: int,
31
- stride: int,
32
- downsample: Optional[nn.Module] = None,
33
- groups: int = 1,
34
- base_channels: int = 64,
35
- ) -> None:
36
- super(_BasicBlock, self).__init__()
37
- self.stride = stride
38
- self.downsample = downsample
39
- self.groups = groups
40
- self.base_channels = base_channels
41
-
42
- self.conv1 = nn.Conv2d(in_channels, out_channels, (3, 3), (stride, stride), (1, 1), bias=False)
43
- self.bn1 = nn.BatchNorm2d(out_channels)
44
- self.relu = nn.ReLU(True)
45
- self.conv2 = nn.Conv2d(out_channels, out_channels, (3, 3), (1, 1), (1, 1), bias=False)
46
- self.bn2 = nn.BatchNorm2d(out_channels)
47
-
48
- def forward(self, x: Tensor) -> Tensor:
49
- identity = x
50
-
51
- out = self.conv1(x)
52
- out = self.bn1(out)
53
- out = self.relu(out)
54
-
55
- out = self.conv2(out)
56
- out = self.bn2(out)
57
-
58
- if self.downsample is not None:
59
- identity = self.downsample(x)
60
-
61
- out = torch.add(out, identity)
62
- out = self.relu(out)
63
-
64
- return out
65
-
66
-
67
- class _Bottleneck(nn.Module):
68
- expansion: int = 4
69
-
70
- def __init__(
71
- self,
72
- in_channels: int,
73
- out_channels: int,
74
- stride: int,
75
- downsample: Optional[nn.Module] = None,
76
- groups: int = 1,
77
- base_channels: int = 64,
78
- ) -> None:
79
- super(_Bottleneck, self).__init__()
80
- self.stride = stride
81
- self.downsample = downsample
82
- self.groups = groups
83
- self.base_channels = base_channels
84
-
85
- channels = int(out_channels * (base_channels / 64.0)) * groups
86
-
87
- self.conv1 = nn.Conv2d(in_channels, channels, (1, 1), (1, 1), (0, 0), bias=False)
88
- self.bn1 = nn.BatchNorm2d(channels)
89
- self.conv2 = nn.Conv2d(channels, channels, (3, 3), (stride, stride), (1, 1), groups=groups, bias=False)
90
- self.bn2 = nn.BatchNorm2d(channels)
91
- self.conv3 = nn.Conv2d(channels, int(out_channels * self.expansion), (1, 1), (1, 1), (0, 0), bias=False)
92
- self.bn3 = nn.BatchNorm2d(int(out_channels * self.expansion))
93
- self.relu = nn.ReLU(True)
94
-
95
- def forward(self, x: Tensor) -> Tensor:
96
- identity = x
97
-
98
- out = self.conv1(x)
99
- out = self.bn1(out)
100
- out = self.relu(out)
101
-
102
- out = self.conv2(out)
103
- out = self.bn2(out)
104
- out = self.relu(out)
105
-
106
- out = self.conv3(out)
107
- out = self.bn3(out)
108
-
109
- if self.downsample is not None:
110
- identity = self.downsample(x)
111
-
112
- out = torch.add(out, identity)
113
- out = self.relu(out)
114
-
115
- return out
116
-
117
-
118
- class ResNet(nn.Module):
119
-
120
- def __init__(
121
- self,
122
- arch_cfg: List[int],
123
- block: Type[Union[_BasicBlock, _Bottleneck]],
124
- groups: int = 1,
125
- channels_per_group: int = 64,
126
- out_dim: int = 1000,
127
- ) -> None:
128
- super(ResNet, self).__init__()
129
- self.in_channels = 64
130
- self.dilation = 1
131
- self.groups = groups
132
- self.base_channels = channels_per_group
133
-
134
- self.conv1 = nn.Conv2d(3, self.in_channels, (7, 7), (2, 2), (3, 3), bias=False)
135
- self.bn1 = nn.BatchNorm2d(self.in_channels)
136
- self.relu = nn.ReLU(True)
137
- self.maxpool = nn.MaxPool2d((3, 3), (2, 2), (1, 1))
138
-
139
- self.layer1 = self._make_layer(arch_cfg[0], block, 64, 1)
140
- self.layer2 = self._make_layer(arch_cfg[1], block, 128, 2)
141
- self.layer3 = self._make_layer(arch_cfg[2], block, 256, 2)
142
- self.layer4 = self._make_layer(arch_cfg[3], block, 512, 2)
143
-
144
- self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
145
-
146
- self.fc = nn.Linear(512 * block.expansion, out_dim)
147
-
148
- # Initialize neural network weights
149
- self._initialize_weights()
150
-
151
- def _make_layer(
152
- self,
153
- repeat_times: int,
154
- block: Type[Union[_BasicBlock, _Bottleneck]],
155
- channels: int,
156
- stride: int = 1,
157
- ) -> nn.Sequential:
158
- downsample = None
159
-
160
- if stride != 1 or self.in_channels != channels * block.expansion:
161
- downsample = nn.Sequential(
162
- nn.Conv2d(self.in_channels, channels * block.expansion, (1, 1), (stride, stride), (0, 0), bias=False),
163
- nn.BatchNorm2d(channels * block.expansion),
164
- )
165
-
166
- layers = [
167
- block(
168
- self.in_channels,
169
- channels,
170
- stride,
171
- downsample,
172
- self.groups,
173
- self.base_channels
174
- )
175
- ]
176
- self.in_channels = channels * block.expansion
177
- for _ in range(1, repeat_times):
178
- layers.append(
179
- block(
180
- self.in_channels,
181
- channels,
182
- 1,
183
- None,
184
- self.groups,
185
- self.base_channels,
186
- )
187
- )
188
-
189
- return nn.Sequential(*layers)
190
-
191
- def forward(self, x: Tensor) -> Tensor:
192
- out = self._forward_impl(x)
193
-
194
- return out
195
-
196
- # Support torch.script function
197
- def _forward_impl(self, x: Tensor) -> Tensor:
198
- out = self.conv1(x)
199
- out = self.bn1(out)
200
- out = self.relu(out)
201
- out = self.maxpool(out)
202
-
203
- out = self.layer1(out)
204
- out = self.layer2(out)
205
- out = self.layer3(out)
206
- out = self.layer4(out)
207
-
208
- out = self.avgpool(out)
209
- out = torch.flatten(out, 1)
210
- out = self.fc(out)
211
-
212
- return out
213
-
214
- def _initialize_weights(self) -> None:
215
- for module in self.modules():
216
- if isinstance(module, nn.Conv2d):
217
- nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
218
- elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
219
- nn.init.constant_(module.weight, 1)
220
- nn.init.constant_(module.bias, 0)
221
-
222
-
223
- def resnet18(**kwargs: Any) -> ResNet:
224
- model = ResNet([2, 2, 2, 2], _BasicBlock, **kwargs)
225
-
226
- return model
227
-
228
-
229
- def resnet34(**kwargs: Any) -> ResNet:
230
- model = ResNet([3, 4, 6, 3], _BasicBlock, **kwargs)
231
-
232
- return model
233
-
234
-
235
- def resnet50(**kwargs: Any) -> ResNet:
236
- model = ResNet([3, 4, 6, 3], _Bottleneck, **kwargs)
237
-
238
- return model
239
-
240
-
241
- def resnet101(**kwargs: Any) -> ResNet:
242
- model = ResNet([3, 4, 23, 3], _Bottleneck, **kwargs)
243
-
244
- return model
245
-
246
-
247
- def resnet152(**kwargs: Any) -> ResNet:
248
- model = ResNet([3, 8, 36, 3], _Bottleneck, **kwargs)
249
-
250
- return model