doctra 0.3.3__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. doctra/__init__.py +4 -0
  2. doctra/cli/main.py +170 -9
  3. doctra/cli/utils.py +2 -3
  4. doctra/engines/image_restoration/__init__.py +10 -0
  5. doctra/engines/image_restoration/docres_engine.py +561 -0
  6. doctra/engines/vlm/outlines_types.py +13 -9
  7. doctra/engines/vlm/service.py +4 -2
  8. doctra/exporters/excel_writer.py +89 -0
  9. doctra/parsers/enhanced_pdf_parser.py +374 -0
  10. doctra/parsers/structured_pdf_parser.py +6 -0
  11. doctra/parsers/table_chart_extractor.py +6 -0
  12. doctra/third_party/docres/data/MBD/MBD.py +110 -0
  13. doctra/third_party/docres/data/MBD/MBD_utils.py +291 -0
  14. doctra/third_party/docres/data/MBD/infer.py +151 -0
  15. doctra/third_party/docres/data/MBD/model/deep_lab_model/aspp.py +95 -0
  16. doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/__init__.py +13 -0
  17. doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/drn.py +402 -0
  18. doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/mobilenet.py +151 -0
  19. doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/resnet.py +170 -0
  20. doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/xception.py +288 -0
  21. doctra/third_party/docres/data/MBD/model/deep_lab_model/decoder.py +59 -0
  22. doctra/third_party/docres/data/MBD/model/deep_lab_model/deeplab.py +81 -0
  23. doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/__init__.py +12 -0
  24. doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/batchnorm.py +282 -0
  25. doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/comm.py +129 -0
  26. doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/replicate.py +88 -0
  27. doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/unittest.py +29 -0
  28. doctra/third_party/docres/data/preprocess/crop_merge_image.py +142 -0
  29. doctra/third_party/docres/inference.py +370 -0
  30. doctra/third_party/docres/models/restormer_arch.py +308 -0
  31. doctra/third_party/docres/utils.py +464 -0
  32. doctra/ui/app.py +8 -14
  33. doctra/utils/structured_utils.py +5 -2
  34. doctra/version.py +1 -1
  35. {doctra-0.3.3.dist-info → doctra-0.4.1.dist-info}/METADATA +1 -1
  36. doctra-0.4.1.dist-info/RECORD +67 -0
  37. doctra-0.3.3.dist-info/RECORD +0 -44
  38. {doctra-0.3.3.dist-info → doctra-0.4.1.dist-info}/WHEEL +0 -0
  39. {doctra-0.3.3.dist-info → doctra-0.4.1.dist-info}/licenses/LICENSE +0 -0
  40. {doctra-0.3.3.dist-info → doctra-0.4.1.dist-info}/top_level.txt +0 -0
doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/resnet.py
@@ -0,0 +1,170 @@
+ import math
+ import torch.nn as nn
+ import torch.utils.model_zoo as model_zoo
+ from model.deep_lab_model.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
+
+ class Bottleneck(nn.Module):
+     expansion = 4
+
+     def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, BatchNorm=None):
+         super(Bottleneck, self).__init__()
+         self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
+         self.bn1 = BatchNorm(planes)
+         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
+                                dilation=dilation, padding=dilation, bias=False)
+         self.bn2 = BatchNorm(planes)
+         self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
+         self.bn3 = BatchNorm(planes * 4)
+         self.relu = nn.ReLU(inplace=True)
+         self.downsample = downsample
+         self.stride = stride
+         self.dilation = dilation
+
+     def forward(self, x):
+         residual = x
+
+         out = self.conv1(x)
+         out = self.bn1(out)
+         out = self.relu(out)
+
+         out = self.conv2(out)
+         out = self.bn2(out)
+         out = self.relu(out)
+
+         out = self.conv3(out)
+         out = self.bn3(out)
+
+         if self.downsample is not None:
+             residual = self.downsample(x)
+
+         out += residual
+         out = self.relu(out)
+
+         return out
+
+ class ResNet(nn.Module):
+
+     def __init__(self, block, layers, output_stride, BatchNorm, pretrained=True):
+         self.inplanes = 64
+         super(ResNet, self).__init__()
+         blocks = [1, 2, 4]
+         if output_stride == 16:
+             strides = [1, 2, 2, 1]
+             dilations = [1, 1, 1, 2]
+         elif output_stride == 8:
+             strides = [1, 2, 1, 1]
+             dilations = [1, 1, 2, 4]
+         else:
+             raise NotImplementedError
+
+         # Modules
+         self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
+                                bias=False)
+         self.bn1 = BatchNorm(64)
+         self.relu = nn.ReLU(inplace=True)
+         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+
+         self.layer1 = self._make_layer(block, 64, layers[0], stride=strides[0], dilation=dilations[0], BatchNorm=BatchNorm)
+         self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], dilation=dilations[1], BatchNorm=BatchNorm)
+         self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], dilation=dilations[2], BatchNorm=BatchNorm)
+         self.layer4 = self._make_MG_unit(block, 512, blocks=blocks, stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm)
+         # self.layer4 = self._make_layer(block, 512, layers[3], stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm)
+         self._init_weight()
+
+         # if pretrained:
+         #     self._load_pretrained_model()
+
+     def _make_layer(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None):
+         downsample = None
+         if stride != 1 or self.inplanes != planes * block.expansion:
+             downsample = nn.Sequential(
+                 nn.Conv2d(self.inplanes, planes * block.expansion,
+                           kernel_size=1, stride=stride, bias=False),
+                 BatchNorm(planes * block.expansion),
+             )
+
+         layers = []
+         layers.append(block(self.inplanes, planes, stride, dilation, downsample, BatchNorm))
+         self.inplanes = planes * block.expansion
+         for i in range(1, blocks):
+             layers.append(block(self.inplanes, planes, dilation=dilation, BatchNorm=BatchNorm))
+
+         return nn.Sequential(*layers)
+
+     def _make_MG_unit(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None):
+         downsample = None
+         if stride != 1 or self.inplanes != planes * block.expansion:
+             downsample = nn.Sequential(
+                 nn.Conv2d(self.inplanes, planes * block.expansion,
+                           kernel_size=1, stride=stride, bias=False),
+                 BatchNorm(planes * block.expansion),
+             )
+
+         layers = []
+         layers.append(block(self.inplanes, planes, stride, dilation=blocks[0]*dilation,
+                             downsample=downsample, BatchNorm=BatchNorm))
+         self.inplanes = planes * block.expansion
+         for i in range(1, len(blocks)):
+             layers.append(block(self.inplanes, planes, stride=1,
+                                 dilation=blocks[i]*dilation, BatchNorm=BatchNorm))
+
+         return nn.Sequential(*layers)
+
+     def forward(self, input):
+         x = self.conv1(input)
+         x = self.bn1(x)
+         x = self.relu(x)
+         x = self.maxpool(x)
+
+         x = self.layer1(x)
+         low_level_feat = x
+         x = self.layer2(x)
+         x = self.layer3(x)
+         x = self.layer4(x)
+         return x, low_level_feat
+
+     def _init_weight(self):
+         for m in self.modules():
+             if isinstance(m, nn.Conv2d):
+                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                 m.weight.data.normal_(0, math.sqrt(2. / n))
+             elif isinstance(m, SynchronizedBatchNorm2d):
+                 m.weight.data.fill_(1)
+                 m.bias.data.zero_()
+             elif isinstance(m, nn.BatchNorm2d):
+                 m.weight.data.fill_(1)
+                 m.bias.data.zero_()
+
+     def _load_pretrained_model(self):
+
+         import urllib.request
+         import ssl
+         ssl._create_default_https_context = ssl._create_unverified_context
+         response = urllib.request.urlopen('https://download.pytorch.org/models/resnet101-5d3b4d8f.pth')
+
+         pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/resnet101-5d3b4d8f.pth')
+         model_dict = {}
+         state_dict = self.state_dict()
+         for k, v in pretrain_dict.items():
+             if k in state_dict:
+                 # if 'conv1' in k:
+                 #     continue
+                 model_dict[k] = v
+         state_dict.update(model_dict)
+         self.load_state_dict(state_dict)
+
+ def ResNet101(output_stride, BatchNorm, pretrained=True):
+     """Constructs a ResNet-101 model.
+     Args:
+         pretrained (bool): If True, returns a model pre-trained on ImageNet
+     """
+     model = ResNet(Bottleneck, [3, 4, 23, 3], output_stride, BatchNorm, pretrained=pretrained)
+     return model
+
+ if __name__ == "__main__":
+     import torch
+     model = ResNet101(BatchNorm=nn.BatchNorm2d, pretrained=True, output_stride=8)
+     input = torch.rand(1, 3, 512, 512)
+     output, low_level_feat = model(input)
+     print(output.size())
+     print(low_level_feat.size())
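
A minimal shape-check sketch for the backbone above (not part of the diff; it assumes the vendored package root is on sys.path so the model.deep_lab_model imports resolve). Note that in this vendored copy the pretrained flag is effectively a no-op, since the _load_pretrained_model() call in __init__ is commented out.

    import torch
    import torch.nn as nn
    from model.deep_lab_model.backbone.resnet import ResNet101

    model = ResNet101(output_stride=16, BatchNorm=nn.BatchNorm2d, pretrained=False)
    x = torch.rand(2, 3, 512, 512)
    feat, low_level_feat = model(x)
    # conv1 (stride 2) + maxpool (stride 2) + layer2 and layer3 (stride 2 each)
    # give stride 16; layer4 trades its stride for dilation, so 512 / 16 = 32.
    assert feat.shape == (2, 2048, 32, 32)
    # low_level_feat is tapped after layer1: stride 4, 64 * expansion(4) = 256 channels.
    assert low_level_feat.shape == (2, 256, 128, 128)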
doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/xception.py
@@ -0,0 +1,288 @@
+ import math
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.utils.model_zoo as model_zoo
+ from model.deep_lab_model.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
+
+ def fixed_padding(inputs, kernel_size, dilation):
+     kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1)
+     pad_total = kernel_size_effective - 1
+     pad_beg = pad_total // 2
+     pad_end = pad_total - pad_beg
+     padded_inputs = F.pad(inputs, (pad_beg, pad_end, pad_beg, pad_end))
+     return padded_inputs
+
+
+ class SeparableConv2d(nn.Module):
+     def __init__(self, inplanes, planes, kernel_size=3, stride=1, dilation=1, bias=False, BatchNorm=None):
+         super(SeparableConv2d, self).__init__()
+
+         self.conv1 = nn.Conv2d(inplanes, inplanes, kernel_size, stride, 0, dilation,
+                                groups=inplanes, bias=bias)
+         self.bn = BatchNorm(inplanes)
+         self.pointwise = nn.Conv2d(inplanes, planes, 1, 1, 0, 1, 1, bias=bias)
+
+     def forward(self, x):
+         x = fixed_padding(x, self.conv1.kernel_size[0], dilation=self.conv1.dilation[0])
+         x = self.conv1(x)
+         x = self.bn(x)
+         x = self.pointwise(x)
+         return x
+
+
+ class Block(nn.Module):
+     def __init__(self, inplanes, planes, reps, stride=1, dilation=1, BatchNorm=None,
+                  start_with_relu=True, grow_first=True, is_last=False):
+         super(Block, self).__init__()
+
+         if planes != inplanes or stride != 1:
+             self.skip = nn.Conv2d(inplanes, planes, 1, stride=stride, bias=False)
+             self.skipbn = BatchNorm(planes)
+         else:
+             self.skip = None
+
+         self.relu = nn.ReLU(inplace=True)
+         rep = []
+
+         filters = inplanes
+         if grow_first:
+             rep.append(self.relu)
+             rep.append(SeparableConv2d(inplanes, planes, 3, 1, dilation, BatchNorm=BatchNorm))
+             rep.append(BatchNorm(planes))
+             filters = planes
+
+         for i in range(reps - 1):
+             rep.append(self.relu)
+             rep.append(SeparableConv2d(filters, filters, 3, 1, dilation, BatchNorm=BatchNorm))
+             rep.append(BatchNorm(filters))
+
+         if not grow_first:
+             rep.append(self.relu)
+             rep.append(SeparableConv2d(inplanes, planes, 3, 1, dilation, BatchNorm=BatchNorm))
+             rep.append(BatchNorm(planes))
+
+         if stride != 1:
+             rep.append(self.relu)
+             rep.append(SeparableConv2d(planes, planes, 3, 2, BatchNorm=BatchNorm))
+             rep.append(BatchNorm(planes))
+
+         if stride == 1 and is_last:
+             rep.append(self.relu)
+             rep.append(SeparableConv2d(planes, planes, 3, 1, BatchNorm=BatchNorm))
+             rep.append(BatchNorm(planes))
+
+         if not start_with_relu:
+             rep = rep[1:]
+
+         self.rep = nn.Sequential(*rep)
+
+     def forward(self, inp):
+         x = self.rep(inp)
+
+         if self.skip is not None:
+             skip = self.skip(inp)
+             skip = self.skipbn(skip)
+         else:
+             skip = inp
+
+         x = x + skip
+
+         return x
+
+
+ class AlignedXception(nn.Module):
+     """
+     Modified Aligned Xception
+     """
+     def __init__(self, output_stride, BatchNorm,
+                  pretrained=True):
+         super(AlignedXception, self).__init__()
+
+         if output_stride == 16:
+             entry_block3_stride = 2
+             middle_block_dilation = 1
+             exit_block_dilations = (1, 2)
+         elif output_stride == 8:
+             entry_block3_stride = 1
+             middle_block_dilation = 2
+             exit_block_dilations = (2, 4)
+         else:
+             raise NotImplementedError
+
+
+         # Entry flow
+         self.conv1 = nn.Conv2d(3, 32, 3, stride=2, padding=1, bias=False)
+         self.bn1 = BatchNorm(32)
+         self.relu = nn.ReLU(inplace=True)
+
+         self.conv2 = nn.Conv2d(32, 64, 3, stride=1, padding=1, bias=False)
+         self.bn2 = BatchNorm(64)
+
+         self.block1 = Block(64, 128, reps=2, stride=2, BatchNorm=BatchNorm, start_with_relu=False)
+         self.block2 = Block(128, 256, reps=2, stride=2, BatchNorm=BatchNorm, start_with_relu=False,
+                             grow_first=True)
+         self.block3 = Block(256, 728, reps=2, stride=entry_block3_stride, BatchNorm=BatchNorm,
+                             start_with_relu=True, grow_first=True, is_last=True)
+
+         # Middle flow
+         self.block4 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
+                             BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
+         self.block5 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
+                             BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
+         self.block6 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
+                             BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
+         self.block7 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
+                             BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
+         self.block8 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
+                             BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
+         self.block9 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
+                             BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
+         self.block10 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
+                              BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
+         self.block11 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
+                              BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
+         self.block12 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
+                              BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
+         self.block13 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
+                              BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
+         self.block14 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
+                              BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
+         self.block15 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
+                              BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
+         self.block16 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
+                              BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
+         self.block17 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
+                              BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
+         self.block18 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
+                              BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
+         self.block19 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
+                              BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
+
+         # Exit flow
+         self.block20 = Block(728, 1024, reps=2, stride=1, dilation=exit_block_dilations[0],
+                              BatchNorm=BatchNorm, start_with_relu=True, grow_first=False, is_last=True)
+
+         self.conv3 = SeparableConv2d(1024, 1536, 3, stride=1, dilation=exit_block_dilations[1], BatchNorm=BatchNorm)
+         self.bn3 = BatchNorm(1536)
+
+         self.conv4 = SeparableConv2d(1536, 1536, 3, stride=1, dilation=exit_block_dilations[1], BatchNorm=BatchNorm)
+         self.bn4 = BatchNorm(1536)
+
+         self.conv5 = SeparableConv2d(1536, 2048, 3, stride=1, dilation=exit_block_dilations[1], BatchNorm=BatchNorm)
+         self.bn5 = BatchNorm(2048)
+
+         # Init weights
+         self._init_weight()
+
+         # Load pretrained model
+         if pretrained:
+             self._load_pretrained_model()
+
+     def forward(self, x):
+         # Entry flow
+         x = self.conv1(x)
+         x = self.bn1(x)
+         x = self.relu(x)
+
+         x = self.conv2(x)
+         x = self.bn2(x)
+         x = self.relu(x)
+
+         x = self.block1(x)
+         # add relu here
+         x = self.relu(x)
+         low_level_feat = x
+         x = self.block2(x)
+         x = self.block3(x)
+
+         # Middle flow
+         x = self.block4(x)
+         x = self.block5(x)
+         x = self.block6(x)
+         x = self.block7(x)
+         x = self.block8(x)
+         x = self.block9(x)
+         x = self.block10(x)
+         x = self.block11(x)
+         x = self.block12(x)
+         x = self.block13(x)
+         x = self.block14(x)
+         x = self.block15(x)
+         x = self.block16(x)
+         x = self.block17(x)
+         x = self.block18(x)
+         x = self.block19(x)
+
+         # Exit flow
+         x = self.block20(x)
+         x = self.relu(x)
+         x = self.conv3(x)
+         x = self.bn3(x)
+         x = self.relu(x)
+
+         x = self.conv4(x)
+         x = self.bn4(x)
+         x = self.relu(x)
+
+         x = self.conv5(x)
+         x = self.bn5(x)
+         x = self.relu(x)
+
+         return x, low_level_feat
+
+     def _init_weight(self):
+         for m in self.modules():
+             if isinstance(m, nn.Conv2d):
+                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                 m.weight.data.normal_(0, math.sqrt(2. / n))
+             elif isinstance(m, SynchronizedBatchNorm2d):
+                 m.weight.data.fill_(1)
+                 m.bias.data.zero_()
+             elif isinstance(m, nn.BatchNorm2d):
+                 m.weight.data.fill_(1)
+                 m.bias.data.zero_()
+
+
+     def _load_pretrained_model(self):
+         pretrain_dict = model_zoo.load_url('http://data.lip6.fr/cadene/pretrainedmodels/xception-b5690688.pth')
+         model_dict = {}
+         state_dict = self.state_dict()
+
+         for k, v in pretrain_dict.items():
+             if k in state_dict:
+                 if 'pointwise' in k:
+                     v = v.unsqueeze(-1).unsqueeze(-1)
+                 if k.startswith('block11'):
+                     model_dict[k] = v
+                     model_dict[k.replace('block11', 'block12')] = v
+                     model_dict[k.replace('block11', 'block13')] = v
+                     model_dict[k.replace('block11', 'block14')] = v
+                     model_dict[k.replace('block11', 'block15')] = v
+                     model_dict[k.replace('block11', 'block16')] = v
+                     model_dict[k.replace('block11', 'block17')] = v
+                     model_dict[k.replace('block11', 'block18')] = v
+                     model_dict[k.replace('block11', 'block19')] = v
+                 elif k.startswith('block12'):
+                     model_dict[k.replace('block12', 'block20')] = v
+                 elif k.startswith('bn3'):
+                     model_dict[k] = v
+                     model_dict[k.replace('bn3', 'bn4')] = v
+                 elif k.startswith('conv4'):
+                     model_dict[k.replace('conv4', 'conv5')] = v
+                 elif k.startswith('bn4'):
+                     model_dict[k.replace('bn4', 'bn5')] = v
+                 else:
+                     model_dict[k] = v
+         state_dict.update(model_dict)
+         self.load_state_dict(state_dict)
+
+
+
+ if __name__ == "__main__":
+     import torch
+     model = AlignedXception(BatchNorm=nn.BatchNorm2d, pretrained=True, output_stride=16)
+     input = torch.rand(1, 3, 512, 512)
+     output, low_level_feat = model(input)
+     print(output.size())
+     print(low_level_feat.size())
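
The same check for the Xception backbone (again a sketch, not part of the diff; pretrained=False skips the checkpoint download in _load_pretrained_model). With output_stride=16, the low-level features leave the entry flow after block1 at stride 4 with 128 channels, which is why the decoder below uses low_level_inplanes = 128 for this backbone.

    import torch
    import torch.nn as nn
    from model.deep_lab_model.backbone.xception import AlignedXception

    model = AlignedXception(output_stride=16, BatchNorm=nn.BatchNorm2d, pretrained=False)
    x = torch.rand(1, 3, 512, 512)
    feat, low_level_feat = model(x)
    print(feat.shape)            # torch.Size([1, 2048, 32, 32]), stride 16
    print(low_level_feat.shape)  # torch.Size([1, 128, 128, 128]), stride 4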
doctra/third_party/docres/data/MBD/model/deep_lab_model/decoder.py
@@ -0,0 +1,59 @@
+ import math
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from model.deep_lab_model.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
+
+ class Decoder(nn.Module):
+     def __init__(self, num_classes, backbone, BatchNorm):
+         super(Decoder, self).__init__()
+         if backbone == 'resnet' or backbone == 'drn':
+             low_level_inplanes = 256
+         elif backbone == 'xception':
+             low_level_inplanes = 128
+         elif backbone == 'mobilenet':
+             low_level_inplanes = 24
+         else:
+             raise NotImplementedError
+
+         self.conv1 = nn.Conv2d(low_level_inplanes, 48, 1, bias=False)
+         self.bn1 = BatchNorm(48)
+         self.relu = nn.ReLU()
+         self.last_conv = nn.Sequential(nn.Conv2d(304, 256, kernel_size=3, stride=1, padding=1, bias=False),
+                                        BatchNorm(256),
+                                        nn.ReLU(),
+                                        nn.Dropout(0.5),
+                                        nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
+                                        BatchNorm(256),
+                                        nn.ReLU(),
+                                        nn.Dropout(0.1),
+                                        nn.Conv2d(256, num_classes, kernel_size=1, stride=1),
+                                        nn.Sigmoid()
+                                        )
+         self._init_weight()
+
+
+     def forward(self, x, low_level_feat):
+         low_level_feat = self.conv1(low_level_feat)
+         low_level_feat = self.bn1(low_level_feat)
+         low_level_feat = self.relu(low_level_feat)
+
+         x = F.interpolate(x, size=low_level_feat.size()[2:], mode='bilinear', align_corners=True)
+         x = torch.cat((x, low_level_feat), dim=1)
+         x = self.last_conv(x)
+
+         return x
+
+     def _init_weight(self):
+         for m in self.modules():
+             if isinstance(m, nn.Conv2d):
+                 torch.nn.init.kaiming_normal_(m.weight)
+             elif isinstance(m, SynchronizedBatchNorm2d):
+                 m.weight.data.fill_(1)
+                 m.bias.data.zero_()
+             elif isinstance(m, nn.BatchNorm2d):
+                 m.weight.data.fill_(1)
+                 m.bias.data.zero_()
+
+ def build_decoder(num_classes, backbone, BatchNorm):
+     return Decoder(num_classes, backbone, BatchNorm)
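
The hard-coded Conv2d(304, ...) above is the DeepLabV3+ channel budget: the ASPP head emits 256 channels, the low-level features are projected to 48, and 256 + 48 = 304 after concatenation. A minimal sketch (not part of the diff), with shapes matching the ResNet example above:

    import torch
    import torch.nn as nn
    from model.deep_lab_model.decoder import build_decoder

    decoder = build_decoder(num_classes=1, backbone='resnet', BatchNorm=nn.BatchNorm2d)
    aspp_out = torch.rand(1, 256, 32, 32)     # 256-channel ASPP output at stride 16
    low_level = torch.rand(1, 256, 128, 128)  # stride-4 layer1 features
    mask = decoder(aspp_out, low_level)       # upsample, concat to 304 ch, predict
    print(mask.shape)  # torch.Size([1, 1, 128, 128]); values in (0, 1) from the Sigmoid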
doctra/third_party/docres/data/MBD/model/deep_lab_model/deeplab.py
@@ -0,0 +1,81 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from model.deep_lab_model.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
+ from model.deep_lab_model.aspp import build_aspp
+ from model.deep_lab_model.decoder import build_decoder
+ from model.deep_lab_model.backbone import build_backbone
+
+ class DeepLab(nn.Module):
+     def __init__(self, backbone='resnet', output_stride=16, num_classes=21,
+                  sync_bn=True, freeze_bn=False):
+         super(DeepLab, self).__init__()
+         if backbone == 'drn':
+             output_stride = 8
+
+         if sync_bn == True:
+             BatchNorm = SynchronizedBatchNorm2d
+         else:
+             BatchNorm = nn.BatchNorm2d
+
+         self.backbone = build_backbone(backbone, output_stride, BatchNorm)
+         self.aspp = build_aspp(backbone, output_stride, BatchNorm)
+         self.decoder = build_decoder(num_classes, backbone, BatchNorm)
+
+         self.freeze_bn = freeze_bn
+
+     def forward(self, input):
+         x, low_level_feat = self.backbone(input)
+         x = self.aspp(x)
+         x = self.decoder(x, low_level_feat)
+         x = F.interpolate(x, size=input.size()[2:], mode='bilinear', align_corners=True)
+
+         return x
+
+     def freeze_bn(self):
+         for m in self.modules():
+             if isinstance(m, SynchronizedBatchNorm2d):
+                 m.eval()
+             elif isinstance(m, nn.BatchNorm2d):
+                 m.eval()
+
+     def get_1x_lr_params(self):
+         modules = [self.backbone]
+         for i in range(len(modules)):
+             for m in modules[i].named_modules():
+                 if self.freeze_bn:
+                     if isinstance(m[1], nn.Conv2d):
+                         for p in m[1].parameters():
+                             if p.requires_grad:
+                                 yield p
+                 else:
+                     if isinstance(m[1], nn.Conv2d) or isinstance(m[1], SynchronizedBatchNorm2d) \
+                             or isinstance(m[1], nn.BatchNorm2d):
+                         for p in m[1].parameters():
+                             if p.requires_grad:
+                                 yield p
+
+     def get_10x_lr_params(self):
+         modules = [self.aspp, self.decoder]
+         for i in range(len(modules)):
+             for m in modules[i].named_modules():
+                 if self.freeze_bn:
+                     if isinstance(m[1], nn.Conv2d):
+                         for p in m[1].parameters():
+                             if p.requires_grad:
+                                 yield p
+                 else:
+                     if isinstance(m[1], nn.Conv2d) or isinstance(m[1], SynchronizedBatchNorm2d) \
+                             or isinstance(m[1], nn.BatchNorm2d):
+                         for p in m[1].parameters():
+                             if p.requires_grad:
+                                 yield p
+
+ if __name__ == "__main__":
+     model = DeepLab(backbone='mobilenet', output_stride=16)
+     model.eval()
+     input = torch.rand(1, 3, 513, 513)
+     output = model(input)
+     print(output.size())
+
+
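
Two remarks on the module above. First, __init__ assigns self.freeze_bn = freeze_bn, which shadows the freeze_bn() method of the same name; the parameter generators then read it as a plain flag. Second, those generators exist to drive the usual DeepLab two-group optimizer: backbone at the base learning rate, ASPP plus decoder at ten times that. A hedged sketch (the 0.007 value is illustrative, not taken from this package):

    import torch
    from model.deep_lab_model.deeplab import DeepLab

    model = DeepLab(backbone='resnet', output_stride=16, num_classes=1, sync_bn=False)
    optimizer = torch.optim.SGD(
        [
            {'params': model.get_1x_lr_params(), 'lr': 0.007},        # backbone
            {'params': model.get_10x_lr_params(), 'lr': 0.007 * 10},  # ASPP + decoder
        ],
        lr=0.007,
        momentum=0.9,
    )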
doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/__init__.py
@@ -0,0 +1,12 @@
+ # -*- coding: utf-8 -*-
+ # File   : __init__.py
+ # Author : Jiayuan Mao
+ # Email  : maojiayuan@gmail.com
+ # Date   : 27/01/2018
+ #
+ # This file is part of Synchronized-BatchNorm-PyTorch.
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
+ # Distributed under MIT License.
+
+ from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
+ from .replicate import DataParallelWithCallback, patch_replication_callback
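
These two exports cover the module's intended use as documented upstream (vacancy/Synchronized-BatchNorm-PyTorch): build the network with SynchronizedBatchNorm2d, wrap it in nn.DataParallel, then patch the replication callback so batch statistics are shared across GPU replicas. A minimal sketch of that pattern, assuming at least two CUDA devices:

    import torch.nn as nn
    from model.deep_lab_model.sync_batchnorm import (
        SynchronizedBatchNorm2d,
        patch_replication_callback,
    )

    net = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), SynchronizedBatchNorm2d(16))
    net = nn.DataParallel(net, device_ids=[0, 1]).cuda()
    patch_replication_callback(net)  # sync BN statistics across the replicas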