doctra-0.3.3-py3-none-any.whl → doctra-0.4.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. doctra/__init__.py +4 -0
  2. doctra/cli/main.py +170 -9
  3. doctra/cli/utils.py +2 -3
  4. doctra/engines/image_restoration/__init__.py +10 -0
  5. doctra/engines/image_restoration/docres_engine.py +561 -0
  6. doctra/engines/vlm/outlines_types.py +13 -9
  7. doctra/engines/vlm/service.py +4 -2
  8. doctra/exporters/excel_writer.py +89 -0
  9. doctra/parsers/enhanced_pdf_parser.py +374 -0
  10. doctra/parsers/structured_pdf_parser.py +6 -0
  11. doctra/parsers/table_chart_extractor.py +6 -0
  12. doctra/third_party/docres/data/MBD/MBD.py +110 -0
  13. doctra/third_party/docres/data/MBD/MBD_utils.py +291 -0
  14. doctra/third_party/docres/data/MBD/infer.py +151 -0
  15. doctra/third_party/docres/data/MBD/model/deep_lab_model/aspp.py +95 -0
  16. doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/__init__.py +13 -0
  17. doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/drn.py +402 -0
  18. doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/mobilenet.py +151 -0
  19. doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/resnet.py +170 -0
  20. doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/xception.py +288 -0
  21. doctra/third_party/docres/data/MBD/model/deep_lab_model/decoder.py +59 -0
  22. doctra/third_party/docres/data/MBD/model/deep_lab_model/deeplab.py +81 -0
  23. doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/__init__.py +12 -0
  24. doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/batchnorm.py +282 -0
  25. doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/comm.py +129 -0
  26. doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/replicate.py +88 -0
  27. doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/unittest.py +29 -0
  28. doctra/third_party/docres/data/preprocess/crop_merge_image.py +142 -0
  29. doctra/third_party/docres/inference.py +370 -0
  30. doctra/third_party/docres/models/restormer_arch.py +308 -0
  31. doctra/third_party/docres/utils.py +464 -0
  32. doctra/ui/app.py +8 -14
  33. doctra/utils/structured_utils.py +5 -2
  34. doctra/version.py +1 -1
  35. {doctra-0.3.3.dist-info → doctra-0.4.1.dist-info}/METADATA +1 -1
  36. doctra-0.4.1.dist-info/RECORD +67 -0
  37. doctra-0.3.3.dist-info/RECORD +0 -44
  38. {doctra-0.3.3.dist-info → doctra-0.4.1.dist-info}/WHEEL +0 -0
  39. {doctra-0.3.3.dist-info → doctra-0.4.1.dist-info}/licenses/LICENSE +0 -0
  40. {doctra-0.3.3.dist-info → doctra-0.4.1.dist-info}/top_level.txt +0 -0
doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/drn.py
@@ -0,0 +1,402 @@
+ import torch.nn as nn
+ import math
+ import torch.utils.model_zoo as model_zoo
+ from model.deep_lab_model.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
+
+ webroot = 'http://dl.yf.io/drn/'
+
+ model_urls = {
+     'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
+     'drn-c-26': webroot + 'drn_c_26-ddedf421.pth',
+     'drn-c-42': webroot + 'drn_c_42-9d336e8c.pth',
+     'drn-c-58': webroot + 'drn_c_58-0a53a92c.pth',
+     'drn-d-22': webroot + 'drn_d_22-4bd2f8ea.pth',
+     'drn-d-38': webroot + 'drn_d_38-eebb45f0.pth',
+     'drn-d-54': webroot + 'drn_d_54-0e0534ff.pth',
+     'drn-d-105': webroot + 'drn_d_105-12b40979.pth'
+ }
+
+
+ def conv3x3(in_planes, out_planes, stride=1, padding=1, dilation=1):
+     return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+                      padding=padding, bias=False, dilation=dilation)
+
+
+ class BasicBlock(nn.Module):
+     expansion = 1
+
+     def __init__(self, inplanes, planes, stride=1, downsample=None,
+                  dilation=(1, 1), residual=True, BatchNorm=None):
+         super(BasicBlock, self).__init__()
+         self.conv1 = conv3x3(inplanes, planes, stride,
+                              padding=dilation[0], dilation=dilation[0])
+         self.bn1 = BatchNorm(planes)
+         self.relu = nn.ReLU(inplace=True)
+         self.conv2 = conv3x3(planes, planes,
+                              padding=dilation[1], dilation=dilation[1])
+         self.bn2 = BatchNorm(planes)
+         self.downsample = downsample
+         self.stride = stride
+         self.residual = residual
+
+     def forward(self, x):
+         residual = x
+
+         out = self.conv1(x)
+         out = self.bn1(out)
+         out = self.relu(out)
+
+         out = self.conv2(out)
+         out = self.bn2(out)
+
+         if self.downsample is not None:
+             residual = self.downsample(x)
+         if self.residual:
+             out += residual
+         out = self.relu(out)
+
+         return out
+
+
+ class Bottleneck(nn.Module):
+     expansion = 4
+
+     def __init__(self, inplanes, planes, stride=1, downsample=None,
+                  dilation=(1, 1), residual=True, BatchNorm=None):
+         super(Bottleneck, self).__init__()
+         self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
+         self.bn1 = BatchNorm(planes)
+         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
+                                padding=dilation[1], bias=False,
+                                dilation=dilation[1])
+         self.bn2 = BatchNorm(planes)
+         self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
+         self.bn3 = BatchNorm(planes * 4)
+         self.relu = nn.ReLU(inplace=True)
+         self.downsample = downsample
+         self.stride = stride
+
+     def forward(self, x):
+         residual = x
+
+         out = self.conv1(x)
+         out = self.bn1(out)
+         out = self.relu(out)
+
+         out = self.conv2(out)
+         out = self.bn2(out)
+         out = self.relu(out)
+
+         out = self.conv3(out)
+         out = self.bn3(out)
+
+         if self.downsample is not None:
+             residual = self.downsample(x)
+
+         out += residual
+         out = self.relu(out)
+
+         return out
+
+
+ class DRN(nn.Module):
+
+     def __init__(self, block, layers, arch='D',
+                  channels=(16, 32, 64, 128, 256, 512, 512, 512),
+                  BatchNorm=None):
+         super(DRN, self).__init__()
+         self.inplanes = channels[0]
+         self.out_dim = channels[-1]
+         self.arch = arch
+
+         if arch == 'C':
+             self.conv1 = nn.Conv2d(3, channels[0], kernel_size=7, stride=1,
+                                    padding=3, bias=False)
+             self.bn1 = BatchNorm(channels[0])
+             self.relu = nn.ReLU(inplace=True)
+
+             self.layer1 = self._make_layer(
+                 BasicBlock, channels[0], layers[0], stride=1, BatchNorm=BatchNorm)
+             self.layer2 = self._make_layer(
+                 BasicBlock, channels[1], layers[1], stride=2, BatchNorm=BatchNorm)
+
+         elif arch == 'D':
+             self.layer0 = nn.Sequential(
+                 nn.Conv2d(3, channels[0], kernel_size=7, stride=1, padding=3,
+                           bias=False),
+                 BatchNorm(channels[0]),
+                 nn.ReLU(inplace=True)
+             )
+
+             self.layer1 = self._make_conv_layers(
+                 channels[0], layers[0], stride=1, BatchNorm=BatchNorm)
+             self.layer2 = self._make_conv_layers(
+                 channels[1], layers[1], stride=2, BatchNorm=BatchNorm)
+
+         self.layer3 = self._make_layer(block, channels[2], layers[2], stride=2, BatchNorm=BatchNorm)
+         self.layer4 = self._make_layer(block, channels[3], layers[3], stride=2, BatchNorm=BatchNorm)
+         self.layer5 = self._make_layer(block, channels[4], layers[4],
+                                        dilation=2, new_level=False, BatchNorm=BatchNorm)
+         self.layer6 = None if layers[5] == 0 else \
+             self._make_layer(block, channels[5], layers[5], dilation=4,
+                              new_level=False, BatchNorm=BatchNorm)
+
+         if arch == 'C':
+             self.layer7 = None if layers[6] == 0 else \
+                 self._make_layer(BasicBlock, channels[6], layers[6], dilation=2,
+                                  new_level=False, residual=False, BatchNorm=BatchNorm)
+             self.layer8 = None if layers[7] == 0 else \
+                 self._make_layer(BasicBlock, channels[7], layers[7], dilation=1,
+                                  new_level=False, residual=False, BatchNorm=BatchNorm)
+         elif arch == 'D':
+             self.layer7 = None if layers[6] == 0 else \
+                 self._make_conv_layers(channels[6], layers[6], dilation=2, BatchNorm=BatchNorm)
+             self.layer8 = None if layers[7] == 0 else \
+                 self._make_conv_layers(channels[7], layers[7], dilation=1, BatchNorm=BatchNorm)
+
+         self._init_weight()
+
+     def _init_weight(self):
+         for m in self.modules():
+             if isinstance(m, nn.Conv2d):
+                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                 m.weight.data.normal_(0, math.sqrt(2. / n))
+             elif isinstance(m, SynchronizedBatchNorm2d):
+                 m.weight.data.fill_(1)
+                 m.bias.data.zero_()
+             elif isinstance(m, nn.BatchNorm2d):
+                 m.weight.data.fill_(1)
+                 m.bias.data.zero_()
+
+
+     def _make_layer(self, block, planes, blocks, stride=1, dilation=1,
+                     new_level=True, residual=True, BatchNorm=None):
+         assert dilation == 1 or dilation % 2 == 0
+         downsample = None
+         if stride != 1 or self.inplanes != planes * block.expansion:
+             downsample = nn.Sequential(
+                 nn.Conv2d(self.inplanes, planes * block.expansion,
+                           kernel_size=1, stride=stride, bias=False),
+                 BatchNorm(planes * block.expansion),
+             )
+
+         layers = list()
+         layers.append(block(
+             self.inplanes, planes, stride, downsample,
+             dilation=(1, 1) if dilation == 1 else (
+                 dilation // 2 if new_level else dilation, dilation),
+             residual=residual, BatchNorm=BatchNorm))
+         self.inplanes = planes * block.expansion
+         for i in range(1, blocks):
+             layers.append(block(self.inplanes, planes, residual=residual,
+                                 dilation=(dilation, dilation), BatchNorm=BatchNorm))
+
+         return nn.Sequential(*layers)
+
+     def _make_conv_layers(self, channels, convs, stride=1, dilation=1, BatchNorm=None):
+         modules = []
+         for i in range(convs):
+             modules.extend([
+                 nn.Conv2d(self.inplanes, channels, kernel_size=3,
+                           stride=stride if i == 0 else 1,
+                           padding=dilation, bias=False, dilation=dilation),
+                 BatchNorm(channels),
+                 nn.ReLU(inplace=True)])
+             self.inplanes = channels
+         return nn.Sequential(*modules)
+
+     def forward(self, x):
+         if self.arch == 'C':
+             x = self.conv1(x)
+             x = self.bn1(x)
+             x = self.relu(x)
+         elif self.arch == 'D':
+             x = self.layer0(x)
+
+         x = self.layer1(x)
+         x = self.layer2(x)
+
+         x = self.layer3(x)
+         low_level_feat = x
+
+         x = self.layer4(x)
+         x = self.layer5(x)
+
+         if self.layer6 is not None:
+             x = self.layer6(x)
+
+         if self.layer7 is not None:
+             x = self.layer7(x)
+
+         if self.layer8 is not None:
+             x = self.layer8(x)
+
+         return x, low_level_feat
+
+
+ class DRN_A(nn.Module):
+
+     def __init__(self, block, layers, BatchNorm=None):
+         self.inplanes = 64
+         super(DRN_A, self).__init__()
+         self.out_dim = 512 * block.expansion
+         self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
+                                bias=False)
+         self.bn1 = BatchNorm(64)
+         self.relu = nn.ReLU(inplace=True)
+         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+         self.layer1 = self._make_layer(block, 64, layers[0], BatchNorm=BatchNorm)
+         self.layer2 = self._make_layer(block, 128, layers[1], stride=2, BatchNorm=BatchNorm)
+         self.layer3 = self._make_layer(block, 256, layers[2], stride=1,
+                                        dilation=2, BatchNorm=BatchNorm)
+         self.layer4 = self._make_layer(block, 512, layers[3], stride=1,
+                                        dilation=4, BatchNorm=BatchNorm)
+
+         self._init_weight()
+
+     def _init_weight(self):
+         for m in self.modules():
+             if isinstance(m, nn.Conv2d):
+                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                 m.weight.data.normal_(0, math.sqrt(2. / n))
+             elif isinstance(m, SynchronizedBatchNorm2d):
+                 m.weight.data.fill_(1)
+                 m.bias.data.zero_()
+             elif isinstance(m, nn.BatchNorm2d):
+                 m.weight.data.fill_(1)
+                 m.bias.data.zero_()
+
+     def _make_layer(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None):
+         downsample = None
+         if stride != 1 or self.inplanes != planes * block.expansion:
+             downsample = nn.Sequential(
+                 nn.Conv2d(self.inplanes, planes * block.expansion,
+                           kernel_size=1, stride=stride, bias=False),
+                 BatchNorm(planes * block.expansion),
+             )
+
+         layers = []
+         layers.append(block(self.inplanes, planes, stride, downsample, BatchNorm=BatchNorm))
+         self.inplanes = planes * block.expansion
+         for i in range(1, blocks):
+             layers.append(block(self.inplanes, planes,
+                                 dilation=(dilation, dilation, ), BatchNorm=BatchNorm))
+
+         return nn.Sequential(*layers)
+
+     def forward(self, x):
+         x = self.conv1(x)
+         x = self.bn1(x)
+         x = self.relu(x)
+         x = self.maxpool(x)
+
+         x = self.layer1(x)
+         x = self.layer2(x)
+         x = self.layer3(x)
+         x = self.layer4(x)
+
+         return x
+
+ def drn_a_50(BatchNorm, pretrained=True):
+     model = DRN_A(Bottleneck, [3, 4, 6, 3], BatchNorm=BatchNorm)
+     if pretrained:
+         model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
+     return model
+
+
+ def drn_c_26(BatchNorm, pretrained=True):
+     model = DRN(BasicBlock, [1, 1, 2, 2, 2, 2, 1, 1], arch='C', BatchNorm=BatchNorm)
+     if pretrained:
+         pretrained = model_zoo.load_url(model_urls['drn-c-26'])
+         del pretrained['fc.weight']
+         del pretrained['fc.bias']
+         model.load_state_dict(pretrained)
+     return model
+
+
+ def drn_c_42(BatchNorm, pretrained=True):
+     model = DRN(BasicBlock, [1, 1, 3, 4, 6, 3, 1, 1], arch='C', BatchNorm=BatchNorm)
+     if pretrained:
+         pretrained = model_zoo.load_url(model_urls['drn-c-42'])
+         del pretrained['fc.weight']
+         del pretrained['fc.bias']
+         model.load_state_dict(pretrained)
+     return model
+
+
+ def drn_c_58(BatchNorm, pretrained=True):
+     model = DRN(Bottleneck, [1, 1, 3, 4, 6, 3, 1, 1], arch='C', BatchNorm=BatchNorm)
+     if pretrained:
+         pretrained = model_zoo.load_url(model_urls['drn-c-58'])
+         del pretrained['fc.weight']
+         del pretrained['fc.bias']
+         model.load_state_dict(pretrained)
+     return model
+
+
+ def drn_d_22(BatchNorm, pretrained=True):
+     model = DRN(BasicBlock, [1, 1, 2, 2, 2, 2, 1, 1], arch='D', BatchNorm=BatchNorm)
+     if pretrained:
+         pretrained = model_zoo.load_url(model_urls['drn-d-22'])
+         del pretrained['fc.weight']
+         del pretrained['fc.bias']
+         model.load_state_dict(pretrained)
+     return model
+
+
+ def drn_d_24(BatchNorm, pretrained=True):
+     model = DRN(BasicBlock, [1, 1, 2, 2, 2, 2, 2, 2], arch='D', BatchNorm=BatchNorm)
+     if pretrained:
+         pretrained = model_zoo.load_url(model_urls['drn-d-24'])
+         del pretrained['fc.weight']
+         del pretrained['fc.bias']
+         model.load_state_dict(pretrained)
+     return model
+
+
+ def drn_d_38(BatchNorm, pretrained=True):
+     model = DRN(BasicBlock, [1, 1, 3, 4, 6, 3, 1, 1], arch='D', BatchNorm=BatchNorm)
+     if pretrained:
+         pretrained = model_zoo.load_url(model_urls['drn-d-38'])
+         del pretrained['fc.weight']
+         del pretrained['fc.bias']
+         model.load_state_dict(pretrained)
+     return model
+
+
+ def drn_d_40(BatchNorm, pretrained=True):
+     model = DRN(BasicBlock, [1, 1, 3, 4, 6, 3, 2, 2], arch='D', BatchNorm=BatchNorm)
+     if pretrained:
+         pretrained = model_zoo.load_url(model_urls['drn-d-40'])
+         del pretrained['fc.weight']
+         del pretrained['fc.bias']
+         model.load_state_dict(pretrained)
+     return model
+
+
+ def drn_d_54(BatchNorm, pretrained=True):
+     model = DRN(Bottleneck, [1, 1, 3, 4, 6, 3, 1, 1], arch='D', BatchNorm=BatchNorm)
+     if pretrained:
+         pretrained = model_zoo.load_url(model_urls['drn-d-54'])
+         del pretrained['fc.weight']
+         del pretrained['fc.bias']
+         model.load_state_dict(pretrained)
+     return model
+
+
+ def drn_d_105(BatchNorm, pretrained=True):
+     model = DRN(Bottleneck, [1, 1, 3, 4, 23, 3, 1, 1], arch='D', BatchNorm=BatchNorm)
+     if pretrained:
+         pretrained = model_zoo.load_url(model_urls['drn-d-105'])
+         del pretrained['fc.weight']
+         del pretrained['fc.bias']
+         model.load_state_dict(pretrained)
+     return model
+
+ if __name__ == "__main__":
+     import torch
+     model = drn_a_50(BatchNorm=nn.BatchNorm2d, pretrained=True)
+     input = torch.rand(1, 3, 512, 512)
+     output, low_level_feat = model(input)
+     print(output.size())
+     print(low_level_feat.size())
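
Note on the drn.py hunk above: it vendors the standard dilated residual network (DRN) backbone shipped alongside the DeepLab model under data/MBD/model/deep_lab_model. The sketch below is a minimal smoke test, not part of the diff; the sys.path entry is an assumption about where the MBD directory sits (the file imports through the top-level `model` package, so it is not importable via the `doctra` package path as-is), and `pretrained=False` avoids any weight download. Note that only variants listed in `model_urls` can load pretrained weights from this file (drn_d_24 and drn_d_40, for example, have no entry).

# Minimal smoke test (assumed layout: the MBD directory is placed on sys.path).
import sys
sys.path.insert(0, "doctra/third_party/docres/data/MBD")  # assumption; adjust to your checkout

import torch
import torch.nn as nn
from model.deep_lab_model.backbone import drn

# pretrained=False: skip the dl.yf.io download.
model = drn.drn_d_54(BatchNorm=nn.BatchNorm2d, pretrained=False)
model.eval()

x = torch.rand(1, 3, 512, 512)
with torch.no_grad():
    out, low_level_feat = model(x)

# Arch 'D' downsamples only in layer2-layer4 (stride 2 each) and then switches to dilation,
# so the output stride is 8 and the low-level features (taken after layer3) sit at stride 4.
print(out.shape)             # expected: torch.Size([1, 512, 64, 64])
print(low_level_feat.shape)  # expected: torch.Size([1, 256, 128, 128])

The (x, low_level_feat) pair returned here is the interface the accompanying deeplab.py and decoder.py modules are built around.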
doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/mobilenet.py
@@ -0,0 +1,151 @@
+ import torch
+ import torch.nn.functional as F
+ import torch.nn as nn
+ import math
+ from model.deep_lab_model.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
+ import torch.utils.model_zoo as model_zoo
+
+ def conv_bn(inp, oup, stride, BatchNorm):
+     return nn.Sequential(
+         nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
+         BatchNorm(oup),
+         nn.ReLU6(inplace=True)
+     )
+
+
+ def fixed_padding(inputs, kernel_size, dilation):
+     kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1)
+     pad_total = kernel_size_effective - 1
+     pad_beg = pad_total // 2
+     pad_end = pad_total - pad_beg
+     padded_inputs = F.pad(inputs, (pad_beg, pad_end, pad_beg, pad_end))
+     return padded_inputs
+
+
+ class InvertedResidual(nn.Module):
+     def __init__(self, inp, oup, stride, dilation, expand_ratio, BatchNorm):
+         super(InvertedResidual, self).__init__()
+         self.stride = stride
+         assert stride in [1, 2]
+
+         hidden_dim = round(inp * expand_ratio)
+         self.use_res_connect = self.stride == 1 and inp == oup
+         self.kernel_size = 3
+         self.dilation = dilation
+
+         if expand_ratio == 1:
+             self.conv = nn.Sequential(
+                 # dw
+                 nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 0, dilation, groups=hidden_dim, bias=False),
+                 BatchNorm(hidden_dim),
+                 nn.ReLU6(inplace=True),
+                 # pw-linear
+                 nn.Conv2d(hidden_dim, oup, 1, 1, 0, 1, 1, bias=False),
+                 BatchNorm(oup),
+             )
+         else:
+             self.conv = nn.Sequential(
+                 # pw
+                 nn.Conv2d(inp, hidden_dim, 1, 1, 0, 1, bias=False),
+                 BatchNorm(hidden_dim),
+                 nn.ReLU6(inplace=True),
+                 # dw
+                 nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 0, dilation, groups=hidden_dim, bias=False),
+                 BatchNorm(hidden_dim),
+                 nn.ReLU6(inplace=True),
+                 # pw-linear
+                 nn.Conv2d(hidden_dim, oup, 1, 1, 0, 1, bias=False),
+                 BatchNorm(oup),
+             )
+
+     def forward(self, x):
+         x_pad = fixed_padding(x, self.kernel_size, dilation=self.dilation)
+         if self.use_res_connect:
+             x = x + self.conv(x_pad)
+         else:
+             x = self.conv(x_pad)
+         return x
+
+
+ class MobileNetV2(nn.Module):
+     def __init__(self, output_stride=8, BatchNorm=None, width_mult=1., pretrained=True):
+         super(MobileNetV2, self).__init__()
+         block = InvertedResidual
+         input_channel = 32
+         current_stride = 1
+         rate = 1
+         interverted_residual_setting = [
+             # t, c, n, s
+             [1, 16, 1, 1],
+             [6, 24, 2, 2],
+             [6, 32, 3, 2],
+             [6, 64, 4, 2],
+             [6, 96, 3, 1],
+             [6, 160, 3, 2],
+             [6, 320, 1, 1],
+         ]
+
+         # building first layer
+         input_channel = int(input_channel * width_mult)
+         self.features = [conv_bn(3, input_channel, 2, BatchNorm)]
+         current_stride *= 2
+         # building inverted residual blocks
+         for t, c, n, s in interverted_residual_setting:
+             if current_stride == output_stride:
+                 stride = 1
+                 dilation = rate
+                 rate *= s
+             else:
+                 stride = s
+                 dilation = 1
+                 current_stride *= s
+             output_channel = int(c * width_mult)
+             for i in range(n):
+                 if i == 0:
+                     self.features.append(block(input_channel, output_channel, stride, dilation, t, BatchNorm))
+                 else:
+                     self.features.append(block(input_channel, output_channel, 1, dilation, t, BatchNorm))
+                 input_channel = output_channel
+         self.features = nn.Sequential(*self.features)
+         self._initialize_weights()
+
+         if pretrained:
+             self._load_pretrained_model()
+
+         self.low_level_features = self.features[0:4]
+         self.high_level_features = self.features[4:]
+
+     def forward(self, x):
+         low_level_feat = self.low_level_features(x)
+         x = self.high_level_features(low_level_feat)
+         return x, low_level_feat
+
+     def _load_pretrained_model(self):
+         pretrain_dict = model_zoo.load_url('http://jeff95.me/models/mobilenet_v2-6a65762b.pth')
+         model_dict = {}
+         state_dict = self.state_dict()
+         for k, v in pretrain_dict.items():
+             if k in state_dict:
+                 model_dict[k] = v
+         state_dict.update(model_dict)
+         self.load_state_dict(state_dict)
+
+     def _initialize_weights(self):
+         for m in self.modules():
+             if isinstance(m, nn.Conv2d):
+                 # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                 # m.weight.data.normal_(0, math.sqrt(2. / n))
+                 torch.nn.init.kaiming_normal_(m.weight)
+             elif isinstance(m, SynchronizedBatchNorm2d):
+                 m.weight.data.fill_(1)
+                 m.bias.data.zero_()
+             elif isinstance(m, nn.BatchNorm2d):
+                 m.weight.data.fill_(1)
+                 m.bias.data.zero_()
+
+ if __name__ == "__main__":
+     input = torch.rand(1, 3, 512, 512)
+     model = MobileNetV2(output_stride=16, BatchNorm=nn.BatchNorm2d)
+     output, low_level_feat = model(input)
+     print(output.size())
+     print(low_level_feat.size())
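
Note on the mobilenet.py hunk: this is the lightweight backbone option; `output_stride` decides where spatial downsampling stops and dilation takes over in the inverted-residual stack, while `low_level_features` is always `features[0:4]` (stride 4, 24 channels at width_mult=1). Below is a hedged sketch under the same sys.path assumption as the drn.py example, with `pretrained=False` so the jeff95.me checkpoint is not fetched:

# Compare output strides of the vendored MobileNetV2 backbone (same path assumption as above).
import torch
import torch.nn as nn
from model.deep_lab_model.backbone import mobilenet

x = torch.rand(1, 3, 512, 512)
for os_ in (8, 16):
    net = mobilenet.MobileNetV2(output_stride=os_, BatchNorm=nn.BatchNorm2d, pretrained=False)
    net.eval()
    with torch.no_grad():
        out, low = net(x)
    print(os_, tuple(out.shape), tuple(low.shape))
    # expected: 8  -> (1, 320, 64, 64),  (1, 24, 128, 128)
    #           16 -> (1, 320, 32, 32),  (1, 24, 128, 128)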