doctra 0.3.3__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff shows the changes between package versions that have been publicly released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
- doctra/__init__.py +4 -0
- doctra/cli/main.py +170 -9
- doctra/cli/utils.py +2 -3
- doctra/engines/image_restoration/__init__.py +10 -0
- doctra/engines/image_restoration/docres_engine.py +561 -0
- doctra/engines/vlm/outlines_types.py +13 -9
- doctra/engines/vlm/service.py +4 -2
- doctra/exporters/excel_writer.py +89 -0
- doctra/parsers/enhanced_pdf_parser.py +374 -0
- doctra/parsers/structured_pdf_parser.py +6 -0
- doctra/parsers/table_chart_extractor.py +6 -0
- doctra/third_party/docres/data/MBD/MBD.py +110 -0
- doctra/third_party/docres/data/MBD/MBD_utils.py +291 -0
- doctra/third_party/docres/data/MBD/infer.py +151 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/aspp.py +95 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/__init__.py +13 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/drn.py +402 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/mobilenet.py +151 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/resnet.py +170 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/xception.py +288 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/decoder.py +59 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/deeplab.py +81 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/__init__.py +12 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/batchnorm.py +282 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/comm.py +129 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/replicate.py +88 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/unittest.py +29 -0
- doctra/third_party/docres/data/preprocess/crop_merge_image.py +142 -0
- doctra/third_party/docres/inference.py +370 -0
- doctra/third_party/docres/models/restormer_arch.py +308 -0
- doctra/third_party/docres/utils.py +464 -0
- doctra/ui/app.py +8 -14
- doctra/utils/structured_utils.py +5 -2
- doctra/version.py +1 -1
- {doctra-0.3.3.dist-info → doctra-0.4.1.dist-info}/METADATA +1 -1
- doctra-0.4.1.dist-info/RECORD +67 -0
- doctra-0.3.3.dist-info/RECORD +0 -44
- {doctra-0.3.3.dist-info → doctra-0.4.1.dist-info}/WHEEL +0 -0
- {doctra-0.3.3.dist-info → doctra-0.4.1.dist-info}/licenses/LICENSE +0 -0
- {doctra-0.3.3.dist-info → doctra-0.4.1.dist-info}/top_level.txt +0 -0
doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/drn.py
@@ -0,0 +1,402 @@
import torch.nn as nn
import math
import torch.utils.model_zoo as model_zoo
from model.deep_lab_model.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d

webroot = 'http://dl.yf.io/drn/'

model_urls = {
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'drn-c-26': webroot + 'drn_c_26-ddedf421.pth',
    'drn-c-42': webroot + 'drn_c_42-9d336e8c.pth',
    'drn-c-58': webroot + 'drn_c_58-0a53a92c.pth',
    'drn-d-22': webroot + 'drn_d_22-4bd2f8ea.pth',
    'drn-d-38': webroot + 'drn_d_38-eebb45f0.pth',
    'drn-d-54': webroot + 'drn_d_54-0e0534ff.pth',
    'drn-d-105': webroot + 'drn_d_105-12b40979.pth'
}


def conv3x3(in_planes, out_planes, stride=1, padding=1, dilation=1):
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=padding, bias=False, dilation=dilation)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None,
                 dilation=(1, 1), residual=True, BatchNorm=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride,
                             padding=dilation[0], dilation=dilation[0])
        self.bn1 = BatchNorm(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes,
                             padding=dilation[1], dilation=dilation[1])
        self.bn2 = BatchNorm(planes)
        self.downsample = downsample
        self.stride = stride
        self.residual = residual

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)
        if self.residual:
            out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None,
                 dilation=(1, 1), residual=True, BatchNorm=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = BatchNorm(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=dilation[1], bias=False,
                               dilation=dilation[1])
        self.bn2 = BatchNorm(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = BatchNorm(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class DRN(nn.Module):

    def __init__(self, block, layers, arch='D',
                 channels=(16, 32, 64, 128, 256, 512, 512, 512),
                 BatchNorm=None):
        super(DRN, self).__init__()
        self.inplanes = channels[0]
        self.out_dim = channels[-1]
        self.arch = arch

        if arch == 'C':
            self.conv1 = nn.Conv2d(3, channels[0], kernel_size=7, stride=1,
                                   padding=3, bias=False)
            self.bn1 = BatchNorm(channels[0])
            self.relu = nn.ReLU(inplace=True)

            self.layer1 = self._make_layer(
                BasicBlock, channels[0], layers[0], stride=1, BatchNorm=BatchNorm)
            self.layer2 = self._make_layer(
                BasicBlock, channels[1], layers[1], stride=2, BatchNorm=BatchNorm)

        elif arch == 'D':
            self.layer0 = nn.Sequential(
                nn.Conv2d(3, channels[0], kernel_size=7, stride=1, padding=3,
                          bias=False),
                BatchNorm(channels[0]),
                nn.ReLU(inplace=True)
            )

            self.layer1 = self._make_conv_layers(
                channels[0], layers[0], stride=1, BatchNorm=BatchNorm)
            self.layer2 = self._make_conv_layers(
                channels[1], layers[1], stride=2, BatchNorm=BatchNorm)

        self.layer3 = self._make_layer(block, channels[2], layers[2], stride=2, BatchNorm=BatchNorm)
        self.layer4 = self._make_layer(block, channels[3], layers[3], stride=2, BatchNorm=BatchNorm)
        self.layer5 = self._make_layer(block, channels[4], layers[4],
                                       dilation=2, new_level=False, BatchNorm=BatchNorm)
        self.layer6 = None if layers[5] == 0 else \
            self._make_layer(block, channels[5], layers[5], dilation=4,
                             new_level=False, BatchNorm=BatchNorm)

        if arch == 'C':
            self.layer7 = None if layers[6] == 0 else \
                self._make_layer(BasicBlock, channels[6], layers[6], dilation=2,
                                 new_level=False, residual=False, BatchNorm=BatchNorm)
            self.layer8 = None if layers[7] == 0 else \
                self._make_layer(BasicBlock, channels[7], layers[7], dilation=1,
                                 new_level=False, residual=False, BatchNorm=BatchNorm)
        elif arch == 'D':
            self.layer7 = None if layers[6] == 0 else \
                self._make_conv_layers(channels[6], layers[6], dilation=2, BatchNorm=BatchNorm)
            self.layer8 = None if layers[7] == 0 else \
                self._make_conv_layers(channels[7], layers[7], dilation=1, BatchNorm=BatchNorm)

        self._init_weight()

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, SynchronizedBatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()


    def _make_layer(self, block, planes, blocks, stride=1, dilation=1,
                    new_level=True, residual=True, BatchNorm=None):
        assert dilation == 1 or dilation % 2 == 0
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                BatchNorm(planes * block.expansion),
            )

        layers = list()
        layers.append(block(
            self.inplanes, planes, stride, downsample,
            dilation=(1, 1) if dilation == 1 else (
                dilation // 2 if new_level else dilation, dilation),
            residual=residual, BatchNorm=BatchNorm))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, residual=residual,
                                dilation=(dilation, dilation), BatchNorm=BatchNorm))

        return nn.Sequential(*layers)

    def _make_conv_layers(self, channels, convs, stride=1, dilation=1, BatchNorm=None):
        modules = []
        for i in range(convs):
            modules.extend([
                nn.Conv2d(self.inplanes, channels, kernel_size=3,
                          stride=stride if i == 0 else 1,
                          padding=dilation, bias=False, dilation=dilation),
                BatchNorm(channels),
                nn.ReLU(inplace=True)])
            self.inplanes = channels
        return nn.Sequential(*modules)

    def forward(self, x):
        if self.arch == 'C':
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.relu(x)
        elif self.arch == 'D':
            x = self.layer0(x)

        x = self.layer1(x)
        x = self.layer2(x)

        x = self.layer3(x)
        low_level_feat = x

        x = self.layer4(x)
        x = self.layer5(x)

        if self.layer6 is not None:
            x = self.layer6(x)

        if self.layer7 is not None:
            x = self.layer7(x)

        if self.layer8 is not None:
            x = self.layer8(x)

        return x, low_level_feat


class DRN_A(nn.Module):

    def __init__(self, block, layers, BatchNorm=None):
        self.inplanes = 64
        super(DRN_A, self).__init__()
        self.out_dim = 512 * block.expansion
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = BatchNorm(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0], BatchNorm=BatchNorm)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, BatchNorm=BatchNorm)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=1,
                                       dilation=2, BatchNorm=BatchNorm)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=1,
                                       dilation=4, BatchNorm=BatchNorm)

        self._init_weight()

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, SynchronizedBatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                BatchNorm(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, BatchNorm=BatchNorm))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes,
                                dilation=(dilation, dilation, ), BatchNorm=BatchNorm))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        return x

def drn_a_50(BatchNorm, pretrained=True):
    model = DRN_A(Bottleneck, [3, 4, 6, 3], BatchNorm=BatchNorm)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
    return model


def drn_c_26(BatchNorm, pretrained=True):
    model = DRN(BasicBlock, [1, 1, 2, 2, 2, 2, 1, 1], arch='C', BatchNorm=BatchNorm)
    if pretrained:
        pretrained = model_zoo.load_url(model_urls['drn-c-26'])
        del pretrained['fc.weight']
        del pretrained['fc.bias']
        model.load_state_dict(pretrained)
    return model


def drn_c_42(BatchNorm, pretrained=True):
    model = DRN(BasicBlock, [1, 1, 3, 4, 6, 3, 1, 1], arch='C', BatchNorm=BatchNorm)
    if pretrained:
        pretrained = model_zoo.load_url(model_urls['drn-c-42'])
        del pretrained['fc.weight']
        del pretrained['fc.bias']
        model.load_state_dict(pretrained)
    return model


def drn_c_58(BatchNorm, pretrained=True):
    model = DRN(Bottleneck, [1, 1, 3, 4, 6, 3, 1, 1], arch='C', BatchNorm=BatchNorm)
    if pretrained:
        pretrained = model_zoo.load_url(model_urls['drn-c-58'])
        del pretrained['fc.weight']
        del pretrained['fc.bias']
        model.load_state_dict(pretrained)
    return model


def drn_d_22(BatchNorm, pretrained=True):
    model = DRN(BasicBlock, [1, 1, 2, 2, 2, 2, 1, 1], arch='D', BatchNorm=BatchNorm)
    if pretrained:
        pretrained = model_zoo.load_url(model_urls['drn-d-22'])
        del pretrained['fc.weight']
        del pretrained['fc.bias']
        model.load_state_dict(pretrained)
    return model


def drn_d_24(BatchNorm, pretrained=True):
    model = DRN(BasicBlock, [1, 1, 2, 2, 2, 2, 2, 2], arch='D', BatchNorm=BatchNorm)
    if pretrained:
        pretrained = model_zoo.load_url(model_urls['drn-d-24'])
        del pretrained['fc.weight']
        del pretrained['fc.bias']
        model.load_state_dict(pretrained)
    return model


def drn_d_38(BatchNorm, pretrained=True):
    model = DRN(BasicBlock, [1, 1, 3, 4, 6, 3, 1, 1], arch='D', BatchNorm=BatchNorm)
    if pretrained:
        pretrained = model_zoo.load_url(model_urls['drn-d-38'])
        del pretrained['fc.weight']
        del pretrained['fc.bias']
        model.load_state_dict(pretrained)
    return model


def drn_d_40(BatchNorm, pretrained=True):
    model = DRN(BasicBlock, [1, 1, 3, 4, 6, 3, 2, 2], arch='D', BatchNorm=BatchNorm)
    if pretrained:
        pretrained = model_zoo.load_url(model_urls['drn-d-40'])
        del pretrained['fc.weight']
        del pretrained['fc.bias']
        model.load_state_dict(pretrained)
    return model


def drn_d_54(BatchNorm, pretrained=True):
    model = DRN(Bottleneck, [1, 1, 3, 4, 6, 3, 1, 1], arch='D', BatchNorm=BatchNorm)
    if pretrained:
        pretrained = model_zoo.load_url(model_urls['drn-d-54'])
        del pretrained['fc.weight']
        del pretrained['fc.bias']
        model.load_state_dict(pretrained)
    return model


def drn_d_105(BatchNorm, pretrained=True):
    model = DRN(Bottleneck, [1, 1, 3, 4, 23, 3, 1, 1], arch='D', BatchNorm=BatchNorm)
    if pretrained:
        pretrained = model_zoo.load_url(model_urls['drn-d-105'])
        del pretrained['fc.weight']
        del pretrained['fc.bias']
        model.load_state_dict(pretrained)
    return model

if __name__ == "__main__":
    import torch
    model = drn_a_50(BatchNorm=nn.BatchNorm2d, pretrained=True)
    input = torch.rand(1, 3, 512, 512)
    output, low_level_feat = model(input)
    print(output.size())
    print(low_level_feat.size())
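For readers evaluating the vendored backbone, here is a minimal usage sketch. It is not part of the diff: it assumes the vendored MBD directory (doctra/third_party/docres/data/MBD) is on sys.path so the module's own model.deep_lab_model.* imports resolve, and it passes pretrained=False to skip the checkpoint download.

    import torch
    import torch.nn as nn
    from model.deep_lab_model.backbone import drn

    # Build the dilated residual network variant used by the DeepLab backbone factories.
    backbone = drn.drn_d_54(BatchNorm=nn.BatchNorm2d, pretrained=False)
    backbone.eval()

    with torch.no_grad():
        x = torch.rand(1, 3, 512, 512)
        out, low_level_feat = backbone(x)

    # DRN keeps an output stride of 8, so a 512x512 input yields
    # out: [1, 512, 64, 64] and low_level_feat (taken after layer3): [1, 256, 128, 128].
    print(out.shape, low_level_feat.shape)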
doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/mobilenet.py
@@ -0,0 +1,151 @@
import torch
import torch.nn.functional as F
import torch.nn as nn
import math
from model.deep_lab_model.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
import torch.utils.model_zoo as model_zoo

def conv_bn(inp, oup, stride, BatchNorm):
    return nn.Sequential(
        nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
        BatchNorm(oup),
        nn.ReLU6(inplace=True)
    )


def fixed_padding(inputs, kernel_size, dilation):
    kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1)
    pad_total = kernel_size_effective - 1
    pad_beg = pad_total // 2
    pad_end = pad_total - pad_beg
    padded_inputs = F.pad(inputs, (pad_beg, pad_end, pad_beg, pad_end))
    return padded_inputs


class InvertedResidual(nn.Module):
    def __init__(self, inp, oup, stride, dilation, expand_ratio, BatchNorm):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2]

        hidden_dim = round(inp * expand_ratio)
        self.use_res_connect = self.stride == 1 and inp == oup
        self.kernel_size = 3
        self.dilation = dilation

        if expand_ratio == 1:
            self.conv = nn.Sequential(
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 0, dilation, groups=hidden_dim, bias=False),
                BatchNorm(hidden_dim),
                nn.ReLU6(inplace=True),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, 1, 1, bias=False),
                BatchNorm(oup),
            )
        else:
            self.conv = nn.Sequential(
                # pw
                nn.Conv2d(inp, hidden_dim, 1, 1, 0, 1, bias=False),
                BatchNorm(hidden_dim),
                nn.ReLU6(inplace=True),
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 0, dilation, groups=hidden_dim, bias=False),
                BatchNorm(hidden_dim),
                nn.ReLU6(inplace=True),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, 1, bias=False),
                BatchNorm(oup),
            )

    def forward(self, x):
        x_pad = fixed_padding(x, self.kernel_size, dilation=self.dilation)
        if self.use_res_connect:
            x = x + self.conv(x_pad)
        else:
            x = self.conv(x_pad)
        return x


class MobileNetV2(nn.Module):
    def __init__(self, output_stride=8, BatchNorm=None, width_mult=1., pretrained=True):
        super(MobileNetV2, self).__init__()
        block = InvertedResidual
        input_channel = 32
        current_stride = 1
        rate = 1
        interverted_residual_setting = [
            # t, c, n, s
            [1, 16, 1, 1],
            [6, 24, 2, 2],
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            [6, 96, 3, 1],
            [6, 160, 3, 2],
            [6, 320, 1, 1],
        ]

        # building first layer
        input_channel = int(input_channel * width_mult)
        self.features = [conv_bn(3, input_channel, 2, BatchNorm)]
        current_stride *= 2
        # building inverted residual blocks
        for t, c, n, s in interverted_residual_setting:
            if current_stride == output_stride:
                stride = 1
                dilation = rate
                rate *= s
            else:
                stride = s
                dilation = 1
                current_stride *= s
            output_channel = int(c * width_mult)
            for i in range(n):
                if i == 0:
                    self.features.append(block(input_channel, output_channel, stride, dilation, t, BatchNorm))
                else:
                    self.features.append(block(input_channel, output_channel, 1, dilation, t, BatchNorm))
                input_channel = output_channel
        self.features = nn.Sequential(*self.features)
        self._initialize_weights()

        if pretrained:
            self._load_pretrained_model()

        self.low_level_features = self.features[0:4]
        self.high_level_features = self.features[4:]

    def forward(self, x):
        low_level_feat = self.low_level_features(x)
        x = self.high_level_features(low_level_feat)
        return x, low_level_feat

    def _load_pretrained_model(self):
        pretrain_dict = model_zoo.load_url('http://jeff95.me/models/mobilenet_v2-6a65762b.pth')
        model_dict = {}
        state_dict = self.state_dict()
        for k, v in pretrain_dict.items():
            if k in state_dict:
                model_dict[k] = v
        state_dict.update(model_dict)
        self.load_state_dict(state_dict)

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                # m.weight.data.normal_(0, math.sqrt(2. / n))
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, SynchronizedBatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

if __name__ == "__main__":
    input = torch.rand(1, 3, 512, 512)
    model = MobileNetV2(output_stride=16, BatchNorm=nn.BatchNorm2d)
    output, low_level_feat = model(input)
    print(output.size())
    print(low_level_feat.size())
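The output_stride argument is the interesting part of this backbone: once the cumulative stride of the stem and earlier stages reaches output_stride, later stages keep stride 1 and grow the dilation rate instead, so the receptive field increases without further downsampling. A minimal sketch of the effect follows; it is not part of the diff, assumes the vendored MBD directory is on sys.path, and passes pretrained=False to avoid downloading the MobileNetV2 checkpoint.

    import torch
    import torch.nn as nn
    from model.deep_lab_model.backbone.mobilenet import MobileNetV2

    x = torch.rand(1, 3, 512, 512)
    for output_stride in (8, 16):
        model = MobileNetV2(output_stride=output_stride, BatchNorm=nn.BatchNorm2d, pretrained=False)
        model.eval()
        with torch.no_grad():
            out, low_level_feat = model(x)
        # low_level_feat comes from features[0:4] (stride 4): [1, 24, 128, 128] here;
        # out has 320 channels at 512/output_stride resolution: 64x64 for 8, 32x32 for 16.
        print(output_stride, out.shape, low_level_feat.shape)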