doctra 0.3.3__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff compares two publicly released versions of the package as published to their public registry. It is provided for informational purposes only.
- doctra/__init__.py +4 -0
- doctra/cli/main.py +170 -9
- doctra/cli/utils.py +2 -3
- doctra/engines/image_restoration/__init__.py +10 -0
- doctra/engines/image_restoration/docres_engine.py +561 -0
- doctra/engines/vlm/outlines_types.py +13 -9
- doctra/engines/vlm/service.py +4 -2
- doctra/exporters/excel_writer.py +89 -0
- doctra/parsers/enhanced_pdf_parser.py +374 -0
- doctra/parsers/structured_pdf_parser.py +6 -0
- doctra/parsers/table_chart_extractor.py +6 -0
- doctra/third_party/docres/data/MBD/MBD.py +110 -0
- doctra/third_party/docres/data/MBD/MBD_utils.py +291 -0
- doctra/third_party/docres/data/MBD/infer.py +151 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/aspp.py +95 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/__init__.py +13 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/drn.py +402 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/mobilenet.py +151 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/resnet.py +170 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/xception.py +288 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/decoder.py +59 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/deeplab.py +81 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/__init__.py +12 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/batchnorm.py +282 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/comm.py +129 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/replicate.py +88 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/unittest.py +29 -0
- doctra/third_party/docres/data/preprocess/crop_merge_image.py +142 -0
- doctra/third_party/docres/inference.py +370 -0
- doctra/third_party/docres/models/restormer_arch.py +308 -0
- doctra/third_party/docres/utils.py +464 -0
- doctra/ui/app.py +8 -14
- doctra/utils/structured_utils.py +5 -2
- doctra/version.py +1 -1
- {doctra-0.3.3.dist-info → doctra-0.4.1.dist-info}/METADATA +1 -1
- doctra-0.4.1.dist-info/RECORD +67 -0
- doctra-0.3.3.dist-info/RECORD +0 -44
- {doctra-0.3.3.dist-info → doctra-0.4.1.dist-info}/WHEEL +0 -0
- {doctra-0.3.3.dist-info → doctra-0.4.1.dist-info}/licenses/LICENSE +0 -0
- {doctra-0.3.3.dist-info → doctra-0.4.1.dist-info}/top_level.txt +0 -0
doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/resnet.py
@@ -0,0 +1,170 @@
+import math
+import torch.nn as nn
+import torch.utils.model_zoo as model_zoo
+from model.deep_lab_model.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
+
+class Bottleneck(nn.Module):
+    expansion = 4
+
+    def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, BatchNorm=None):
+        super(Bottleneck, self).__init__()
+        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
+        self.bn1 = BatchNorm(planes)
+        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
+                               dilation=dilation, padding=dilation, bias=False)
+        self.bn2 = BatchNorm(planes)
+        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
+        self.bn3 = BatchNorm(planes * 4)
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample
+        self.stride = stride
+        self.dilation = dilation
+
+    def forward(self, x):
+        residual = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.relu(out)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+        out = self.relu(out)
+
+        return out
+
+class ResNet(nn.Module):
+
+    def __init__(self, block, layers, output_stride, BatchNorm, pretrained=True):
+        self.inplanes = 64
+        super(ResNet, self).__init__()
+        blocks = [1, 2, 4]
+        if output_stride == 16:
+            strides = [1, 2, 2, 1]
+            dilations = [1, 1, 1, 2]
+        elif output_stride == 8:
+            strides = [1, 2, 1, 1]
+            dilations = [1, 1, 2, 4]
+        else:
+            raise NotImplementedError
+
+        # Modules
+        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
+                               bias=False)
+        self.bn1 = BatchNorm(64)
+        self.relu = nn.ReLU(inplace=True)
+        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+
+        self.layer1 = self._make_layer(block, 64, layers[0], stride=strides[0], dilation=dilations[0], BatchNorm=BatchNorm)
+        self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], dilation=dilations[1], BatchNorm=BatchNorm)
+        self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], dilation=dilations[2], BatchNorm=BatchNorm)
+        self.layer4 = self._make_MG_unit(block, 512, blocks=blocks, stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm)
+        # self.layer4 = self._make_layer(block, 512, layers[3], stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm)
+        self._init_weight()
+
+        # if pretrained:
+        #     self._load_pretrained_model()
+
+    def _make_layer(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None):
+        downsample = None
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                nn.Conv2d(self.inplanes, planes * block.expansion,
+                          kernel_size=1, stride=stride, bias=False),
+                BatchNorm(planes * block.expansion),
+            )
+
+        layers = []
+        layers.append(block(self.inplanes, planes, stride, dilation, downsample, BatchNorm))
+        self.inplanes = planes * block.expansion
+        for i in range(1, blocks):
+            layers.append(block(self.inplanes, planes, dilation=dilation, BatchNorm=BatchNorm))
+
+        return nn.Sequential(*layers)
+
+    def _make_MG_unit(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None):
+        downsample = None
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                nn.Conv2d(self.inplanes, planes * block.expansion,
+                          kernel_size=1, stride=stride, bias=False),
+                BatchNorm(planes * block.expansion),
+            )
+
+        layers = []
+        layers.append(block(self.inplanes, planes, stride, dilation=blocks[0]*dilation,
+                            downsample=downsample, BatchNorm=BatchNorm))
+        self.inplanes = planes * block.expansion
+        for i in range(1, len(blocks)):
+            layers.append(block(self.inplanes, planes, stride=1,
+                                dilation=blocks[i]*dilation, BatchNorm=BatchNorm))
+
+        return nn.Sequential(*layers)
+
+    def forward(self, input):
+        x = self.conv1(input)
+        x = self.bn1(x)
+        x = self.relu(x)
+        x = self.maxpool(x)
+
+        x = self.layer1(x)
+        low_level_feat = x
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+        return x, low_level_feat
+
+    def _init_weight(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                m.weight.data.normal_(0, math.sqrt(2. / n))
+            elif isinstance(m, SynchronizedBatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+            elif isinstance(m, nn.BatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+
+    def _load_pretrained_model(self):
+
+        import urllib.request
+        import ssl
+        ssl._create_default_https_context = ssl._create_unverified_context
+        response = urllib.request.urlopen('https://download.pytorch.org/models/resnet101-5d3b4d8f.pth')
+
+        pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/resnet101-5d3b4d8f.pth')
+        model_dict = {}
+        state_dict = self.state_dict()
+        for k, v in pretrain_dict.items():
+            if k in state_dict:
+                # if 'conv1' in k:
+                #     continue
+                model_dict[k] = v
+        state_dict.update(model_dict)
+        self.load_state_dict(state_dict)
+
+def ResNet101(output_stride, BatchNorm, pretrained=True):
+    """Constructs a ResNet-101 model.
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+    """
+    model = ResNet(Bottleneck, [3, 4, 23, 3], output_stride, BatchNorm, pretrained=pretrained)
+    return model
+
+if __name__ == "__main__":
+    import torch
+    model = ResNet101(BatchNorm=nn.BatchNorm2d, pretrained=True, output_stride=8)
+    input = torch.rand(1, 3, 512, 512)
+    output, low_level_feat = model(input)
+    print(output.size())
+    print(low_level_feat.size())
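The hunk above vendors the dilated ResNet-101 backbone from the DeepLabv3+ family: output_stride (16 or 8) selects the stride/dilation schedule, layer4 is built as a multi-grid unit with dilation multipliers [1, 2, 4], and the output of layer1 is returned separately as the low-level feature map for the decoder. A minimal shape-check sketch, assuming the file above is on the import path (all names come from the hunk itself; note that pretrained has no effect here because the _load_pretrained_model() call is commented out in __init__):

    import torch
    import torch.nn as nn

    model = ResNet101(output_stride=8, BatchNorm=nn.BatchNorm2d, pretrained=False)
    x = torch.rand(1, 3, 512, 512)
    high, low = model(x)
    assert high.shape == (1, 2048, 64, 64)   # 512 / 8: layer3 and layer4 trade stride for dilation
    assert low.shape == (1, 256, 128, 128)   # layer1 output, consumed as the decoder's low-level feature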
doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/xception.py
@@ -0,0 +1,288 @@
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.model_zoo as model_zoo
+from model.deep_lab_model.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
+
+def fixed_padding(inputs, kernel_size, dilation):
+    kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1)
+    pad_total = kernel_size_effective - 1
+    pad_beg = pad_total // 2
+    pad_end = pad_total - pad_beg
+    padded_inputs = F.pad(inputs, (pad_beg, pad_end, pad_beg, pad_end))
+    return padded_inputs
+
+
+class SeparableConv2d(nn.Module):
+    def __init__(self, inplanes, planes, kernel_size=3, stride=1, dilation=1, bias=False, BatchNorm=None):
+        super(SeparableConv2d, self).__init__()
+
+        self.conv1 = nn.Conv2d(inplanes, inplanes, kernel_size, stride, 0, dilation,
+                               groups=inplanes, bias=bias)
+        self.bn = BatchNorm(inplanes)
+        self.pointwise = nn.Conv2d(inplanes, planes, 1, 1, 0, 1, 1, bias=bias)
+
+    def forward(self, x):
+        x = fixed_padding(x, self.conv1.kernel_size[0], dilation=self.conv1.dilation[0])
+        x = self.conv1(x)
+        x = self.bn(x)
+        x = self.pointwise(x)
+        return x
+
+
+class Block(nn.Module):
+    def __init__(self, inplanes, planes, reps, stride=1, dilation=1, BatchNorm=None,
+                 start_with_relu=True, grow_first=True, is_last=False):
+        super(Block, self).__init__()
+
+        if planes != inplanes or stride != 1:
+            self.skip = nn.Conv2d(inplanes, planes, 1, stride=stride, bias=False)
+            self.skipbn = BatchNorm(planes)
+        else:
+            self.skip = None
+
+        self.relu = nn.ReLU(inplace=True)
+        rep = []
+
+        filters = inplanes
+        if grow_first:
+            rep.append(self.relu)
+            rep.append(SeparableConv2d(inplanes, planes, 3, 1, dilation, BatchNorm=BatchNorm))
+            rep.append(BatchNorm(planes))
+            filters = planes
+
+        for i in range(reps - 1):
+            rep.append(self.relu)
+            rep.append(SeparableConv2d(filters, filters, 3, 1, dilation, BatchNorm=BatchNorm))
+            rep.append(BatchNorm(filters))
+
+        if not grow_first:
+            rep.append(self.relu)
+            rep.append(SeparableConv2d(inplanes, planes, 3, 1, dilation, BatchNorm=BatchNorm))
+            rep.append(BatchNorm(planes))
+
+        if stride != 1:
+            rep.append(self.relu)
+            rep.append(SeparableConv2d(planes, planes, 3, 2, BatchNorm=BatchNorm))
+            rep.append(BatchNorm(planes))
+
+        if stride == 1 and is_last:
+            rep.append(self.relu)
+            rep.append(SeparableConv2d(planes, planes, 3, 1, BatchNorm=BatchNorm))
+            rep.append(BatchNorm(planes))
+
+        if not start_with_relu:
+            rep = rep[1:]
+
+        self.rep = nn.Sequential(*rep)
+
+    def forward(self, inp):
+        x = self.rep(inp)
+
+        if self.skip is not None:
+            skip = self.skip(inp)
+            skip = self.skipbn(skip)
+        else:
+            skip = inp
+
+        x = x + skip
+
+        return x
+
+
+class AlignedXception(nn.Module):
+    """
+    Modified Aligned Xception
+    """
+    def __init__(self, output_stride, BatchNorm,
+                 pretrained=True):
+        super(AlignedXception, self).__init__()
+
+        if output_stride == 16:
+            entry_block3_stride = 2
+            middle_block_dilation = 1
+            exit_block_dilations = (1, 2)
+        elif output_stride == 8:
+            entry_block3_stride = 1
+            middle_block_dilation = 2
+            exit_block_dilations = (2, 4)
+        else:
+            raise NotImplementedError
+
+
+        # Entry flow
+        self.conv1 = nn.Conv2d(3, 32, 3, stride=2, padding=1, bias=False)
+        self.bn1 = BatchNorm(32)
+        self.relu = nn.ReLU(inplace=True)
+
+        self.conv2 = nn.Conv2d(32, 64, 3, stride=1, padding=1, bias=False)
+        self.bn2 = BatchNorm(64)
+
+        self.block1 = Block(64, 128, reps=2, stride=2, BatchNorm=BatchNorm, start_with_relu=False)
+        self.block2 = Block(128, 256, reps=2, stride=2, BatchNorm=BatchNorm, start_with_relu=False,
+                            grow_first=True)
+        self.block3 = Block(256, 728, reps=2, stride=entry_block3_stride, BatchNorm=BatchNorm,
+                            start_with_relu=True, grow_first=True, is_last=True)
+
+        # Middle flow
+        self.block4 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
+                            BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
+        self.block5 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
+                            BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
+        self.block6 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
+                            BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
+        self.block7 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
+                            BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
+        self.block8 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
+                            BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
+        self.block9 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
+                            BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
+        self.block10 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
+                             BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
+        self.block11 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
+                             BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
+        self.block12 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
+                             BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
+        self.block13 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
+                             BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
+        self.block14 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
+                             BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
+        self.block15 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
+                             BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
+        self.block16 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
+                             BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
+        self.block17 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
+                             BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
+        self.block18 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
+                             BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
+        self.block19 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
+                             BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
+
+        # Exit flow
+        self.block20 = Block(728, 1024, reps=2, stride=1, dilation=exit_block_dilations[0],
+                             BatchNorm=BatchNorm, start_with_relu=True, grow_first=False, is_last=True)
+
+        self.conv3 = SeparableConv2d(1024, 1536, 3, stride=1, dilation=exit_block_dilations[1], BatchNorm=BatchNorm)
+        self.bn3 = BatchNorm(1536)
+
+        self.conv4 = SeparableConv2d(1536, 1536, 3, stride=1, dilation=exit_block_dilations[1], BatchNorm=BatchNorm)
+        self.bn4 = BatchNorm(1536)
+
+        self.conv5 = SeparableConv2d(1536, 2048, 3, stride=1, dilation=exit_block_dilations[1], BatchNorm=BatchNorm)
+        self.bn5 = BatchNorm(2048)
+
+        # Init weights
+        self._init_weight()
+
+        # Load pretrained model
+        if pretrained:
+            self._load_pretrained_model()
+
+    def forward(self, x):
+        # Entry flow
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+
+        x = self.conv2(x)
+        x = self.bn2(x)
+        x = self.relu(x)
+
+        x = self.block1(x)
+        # add relu here
+        x = self.relu(x)
+        low_level_feat = x
+        x = self.block2(x)
+        x = self.block3(x)
+
+        # Middle flow
+        x = self.block4(x)
+        x = self.block5(x)
+        x = self.block6(x)
+        x = self.block7(x)
+        x = self.block8(x)
+        x = self.block9(x)
+        x = self.block10(x)
+        x = self.block11(x)
+        x = self.block12(x)
+        x = self.block13(x)
+        x = self.block14(x)
+        x = self.block15(x)
+        x = self.block16(x)
+        x = self.block17(x)
+        x = self.block18(x)
+        x = self.block19(x)
+
+        # Exit flow
+        x = self.block20(x)
+        x = self.relu(x)
+        x = self.conv3(x)
+        x = self.bn3(x)
+        x = self.relu(x)
+
+        x = self.conv4(x)
+        x = self.bn4(x)
+        x = self.relu(x)
+
+        x = self.conv5(x)
+        x = self.bn5(x)
+        x = self.relu(x)
+
+        return x, low_level_feat
+
+    def _init_weight(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                m.weight.data.normal_(0, math.sqrt(2. / n))
+            elif isinstance(m, SynchronizedBatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+            elif isinstance(m, nn.BatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+
+
+    def _load_pretrained_model(self):
+        pretrain_dict = model_zoo.load_url('http://data.lip6.fr/cadene/pretrainedmodels/xception-b5690688.pth')
+        model_dict = {}
+        state_dict = self.state_dict()
+
+        for k, v in pretrain_dict.items():
+            if k in state_dict:
+                if 'pointwise' in k:
+                    v = v.unsqueeze(-1).unsqueeze(-1)
+                if k.startswith('block11'):
+                    model_dict[k] = v
+                    model_dict[k.replace('block11', 'block12')] = v
+                    model_dict[k.replace('block11', 'block13')] = v
+                    model_dict[k.replace('block11', 'block14')] = v
+                    model_dict[k.replace('block11', 'block15')] = v
+                    model_dict[k.replace('block11', 'block16')] = v
+                    model_dict[k.replace('block11', 'block17')] = v
+                    model_dict[k.replace('block11', 'block18')] = v
+                    model_dict[k.replace('block11', 'block19')] = v
+                elif k.startswith('block12'):
+                    model_dict[k.replace('block12', 'block20')] = v
+                elif k.startswith('bn3'):
+                    model_dict[k] = v
+                    model_dict[k.replace('bn3', 'bn4')] = v
+                elif k.startswith('conv4'):
+                    model_dict[k.replace('conv4', 'conv5')] = v
+                elif k.startswith('bn4'):
+                    model_dict[k.replace('bn4', 'bn5')] = v
+                else:
+                    model_dict[k] = v
+        state_dict.update(model_dict)
+        self.load_state_dict(state_dict)
+
+
+
+if __name__ == "__main__":
+    import torch
+    model = AlignedXception(BatchNorm=nn.BatchNorm2d, pretrained=True, output_stride=16)
+    input = torch.rand(1, 3, 512, 512)
+    output, low_level_feat = model(input)
+    print(output.size())
+    print(low_level_feat.size())
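This file is the modified Aligned Xception backbone, built from depthwise-separable convolutions in the usual entry/middle/exit flow; low_level_feat is taken after block1 plus an extra ReLU, so it carries 128 channels, which matches the decoder's 'xception' branch below. A quick shape-check sketch in the spirit of the file's own __main__ block, assuming the module is importable (pretrained=False avoids the weight download):

    import torch
    import torch.nn as nn

    backbone = AlignedXception(output_stride=16, BatchNorm=nn.BatchNorm2d, pretrained=False)
    x = torch.rand(1, 3, 512, 512)
    high, low = backbone(x)
    assert high.shape == (1, 2048, 32, 32)   # 512 / 16: block3 is the last strided stage
    assert low.shape == (1, 128, 128, 128)   # 128 channels, as the Decoder expects for 'xception'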
doctra/third_party/docres/data/MBD/model/deep_lab_model/decoder.py
@@ -0,0 +1,59 @@
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from model.deep_lab_model.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
+
+class Decoder(nn.Module):
+    def __init__(self, num_classes, backbone, BatchNorm):
+        super(Decoder, self).__init__()
+        if backbone == 'resnet' or backbone == 'drn':
+            low_level_inplanes = 256
+        elif backbone == 'xception':
+            low_level_inplanes = 128
+        elif backbone == 'mobilenet':
+            low_level_inplanes = 24
+        else:
+            raise NotImplementedError
+
+        self.conv1 = nn.Conv2d(low_level_inplanes, 48, 1, bias=False)
+        self.bn1 = BatchNorm(48)
+        self.relu = nn.ReLU()
+        self.last_conv = nn.Sequential(nn.Conv2d(304, 256, kernel_size=3, stride=1, padding=1, bias=False),
+                                       BatchNorm(256),
+                                       nn.ReLU(),
+                                       nn.Dropout(0.5),
+                                       nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
+                                       BatchNorm(256),
+                                       nn.ReLU(),
+                                       nn.Dropout(0.1),
+                                       nn.Conv2d(256, num_classes, kernel_size=1, stride=1),
+                                       nn.Sigmoid()
+                                       )
+        self._init_weight()
+
+
+    def forward(self, x, low_level_feat):
+        low_level_feat = self.conv1(low_level_feat)
+        low_level_feat = self.bn1(low_level_feat)
+        low_level_feat = self.relu(low_level_feat)
+
+        x = F.interpolate(x, size=low_level_feat.size()[2:], mode='bilinear', align_corners=True)
+        x = torch.cat((x, low_level_feat), dim=1)
+        x = self.last_conv(x)
+
+        return x
+
+    def _init_weight(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                torch.nn.init.kaiming_normal_(m.weight)
+            elif isinstance(m, SynchronizedBatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+            elif isinstance(m, nn.BatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+
+def build_decoder(num_classes, backbone, BatchNorm):
+    return Decoder(num_classes, backbone, BatchNorm)
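The decoder follows the DeepLabv3+ recipe: low-level features are compressed to 48 channels, the high-level input is bilinearly upsampled to the same spatial size and concatenated (hence the hard-coded 304 = 256 ASPP channels + 48), and the head ends in a Sigmoid, so it emits per-pixel probabilities rather than logits. A small sketch under those assumptions, using only names from the hunk:

    import torch
    import torch.nn as nn

    dec = build_decoder(num_classes=1, backbone='resnet', BatchNorm=nn.BatchNorm2d)
    aspp_out = torch.rand(1, 256, 64, 64)   # high-level features; ASPP outputs 256 channels by convention
    low = torch.rand(1, 256, 128, 128)      # layer1 features from the ResNet backbone
    mask = dec(aspp_out, low)
    assert mask.shape == (1, 1, 128, 128)   # upsampled to the low-level resolution; values in [0, 1]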
doctra/third_party/docres/data/MBD/model/deep_lab_model/deeplab.py
@@ -0,0 +1,81 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from model.deep_lab_model.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
+from model.deep_lab_model.aspp import build_aspp
+from model.deep_lab_model.decoder import build_decoder
+from model.deep_lab_model.backbone import build_backbone
+
+class DeepLab(nn.Module):
+    def __init__(self, backbone='resnet', output_stride=16, num_classes=21,
+                 sync_bn=True, freeze_bn=False):
+        super(DeepLab, self).__init__()
+        if backbone == 'drn':
+            output_stride = 8
+
+        if sync_bn == True:
+            BatchNorm = SynchronizedBatchNorm2d
+        else:
+            BatchNorm = nn.BatchNorm2d
+
+        self.backbone = build_backbone(backbone, output_stride, BatchNorm)
+        self.aspp = build_aspp(backbone, output_stride, BatchNorm)
+        self.decoder = build_decoder(num_classes, backbone, BatchNorm)
+
+        self.freeze_bn = freeze_bn
+
+    def forward(self, input):
+        x, low_level_feat = self.backbone(input)
+        x = self.aspp(x)
+        x = self.decoder(x, low_level_feat)
+        x = F.interpolate(x, size=input.size()[2:], mode='bilinear', align_corners=True)
+
+        return x
+
+    def freeze_bn(self):
+        for m in self.modules():
+            if isinstance(m, SynchronizedBatchNorm2d):
+                m.eval()
+            elif isinstance(m, nn.BatchNorm2d):
+                m.eval()
+
+    def get_1x_lr_params(self):
+        modules = [self.backbone]
+        for i in range(len(modules)):
+            for m in modules[i].named_modules():
+                if self.freeze_bn:
+                    if isinstance(m[1], nn.Conv2d):
+                        for p in m[1].parameters():
+                            if p.requires_grad:
+                                yield p
+                else:
+                    if isinstance(m[1], nn.Conv2d) or isinstance(m[1], SynchronizedBatchNorm2d) \
+                            or isinstance(m[1], nn.BatchNorm2d):
+                        for p in m[1].parameters():
+                            if p.requires_grad:
+                                yield p
+
+    def get_10x_lr_params(self):
+        modules = [self.aspp, self.decoder]
+        for i in range(len(modules)):
+            for m in modules[i].named_modules():
+                if self.freeze_bn:
+                    if isinstance(m[1], nn.Conv2d):
+                        for p in m[1].parameters():
+                            if p.requires_grad:
+                                yield p
+                else:
+                    if isinstance(m[1], nn.Conv2d) or isinstance(m[1], SynchronizedBatchNorm2d) \
+                            or isinstance(m[1], nn.BatchNorm2d):
+                        for p in m[1].parameters():
+                            if p.requires_grad:
+                                yield p
+
+if __name__ == "__main__":
+    model = DeepLab(backbone='mobilenet', output_stride=16)
+    model.eval()
+    input = torch.rand(1, 3, 513, 513)
+    output = model(input)
+    print(output.size())
+
+
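Two things are worth flagging in this module. First, self.freeze_bn = freeze_bn in __init__ shadows the freeze_bn method defined below it, so calling model.freeze_bn() raises a TypeError; the flag is only consulted by the two learning-rate parameter generators (this mirrors the upstream pytorch-deeplab-xception code the file is vendored from). Second, those generators support the conventional DeepLab training recipe of running the backbone at a lower learning rate than the ASPP head and decoder. A sketch of that optimizer setup, with an illustrative learning rate that is an assumption rather than a value from the package:

    import torch

    model = DeepLab(backbone='resnet', output_stride=16, num_classes=1, sync_bn=False)
    base_lr = 0.007  # illustrative value, not taken from the diff
    param_groups = [
        {'params': model.get_1x_lr_params(), 'lr': base_lr},        # backbone at 1x
        {'params': model.get_10x_lr_params(), 'lr': base_lr * 10},  # ASPP + decoder at 10x
    ]
    optimizer = torch.optim.SGD(param_groups, momentum=0.9, weight_decay=5e-4)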
doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/__init__.py
@@ -0,0 +1,12 @@
+# -*- coding: utf-8 -*-
+# File   : __init__.py
+# Author : Jiayuan Mao
+# Email  : maojiayuan@gmail.com
+# Date   : 27/01/2018
+#
+# This file is part of Synchronized-BatchNorm-PyTorch.
+# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
+# Distributed under MIT License.
+
+from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
+from .replicate import DataParallelWithCallback, patch_replication_callback