doctra 0.3.2__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff compares the contents of publicly released versions of this package as they appear in their public registry. It is provided for informational purposes only.
Files changed (38)
  1. doctra/__init__.py +4 -0
  2. doctra/cli/main.py +168 -0
  3. doctra/engines/image_restoration/__init__.py +10 -0
  4. doctra/engines/image_restoration/docres_engine.py +566 -0
  5. doctra/engines/vlm/service.py +0 -12
  6. doctra/parsers/enhanced_pdf_parser.py +370 -0
  7. doctra/parsers/structured_pdf_parser.py +11 -60
  8. doctra/parsers/table_chart_extractor.py +8 -44
  9. doctra/third_party/docres/data/MBD/MBD.py +110 -0
  10. doctra/third_party/docres/data/MBD/MBD_utils.py +291 -0
  11. doctra/third_party/docres/data/MBD/infer.py +151 -0
  12. doctra/third_party/docres/data/MBD/model/deep_lab_model/aspp.py +95 -0
  13. doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/__init__.py +13 -0
  14. doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/drn.py +402 -0
  15. doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/mobilenet.py +151 -0
  16. doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/resnet.py +170 -0
  17. doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/xception.py +288 -0
  18. doctra/third_party/docres/data/MBD/model/deep_lab_model/decoder.py +59 -0
  19. doctra/third_party/docres/data/MBD/model/deep_lab_model/deeplab.py +81 -0
  20. doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/__init__.py +12 -0
  21. doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/batchnorm.py +282 -0
  22. doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/comm.py +129 -0
  23. doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/replicate.py +88 -0
  24. doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/unittest.py +29 -0
  25. doctra/third_party/docres/data/preprocess/crop_merge_image.py +142 -0
  26. doctra/third_party/docres/inference.py +370 -0
  27. doctra/third_party/docres/models/restormer_arch.py +308 -0
  28. doctra/third_party/docres/utils.py +464 -0
  29. doctra/ui/app.py +5 -32
  30. doctra/utils/progress.py +13 -98
  31. doctra/utils/structured_utils.py +45 -49
  32. doctra/version.py +1 -1
  33. {doctra-0.3.2.dist-info → doctra-0.4.0.dist-info}/METADATA +1 -1
  34. doctra-0.4.0.dist-info/RECORD +67 -0
  35. doctra-0.3.2.dist-info/RECORD +0 -44
  36. {doctra-0.3.2.dist-info → doctra-0.4.0.dist-info}/WHEEL +0 -0
  37. {doctra-0.3.2.dist-info → doctra-0.4.0.dist-info}/licenses/LICENSE +0 -0
  38. {doctra-0.3.2.dist-info → doctra-0.4.0.dist-info}/top_level.txt +0 -0
doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/mobilenet.py
@@ -0,0 +1,151 @@
+import torch
+import torch.nn.functional as F
+import torch.nn as nn
+import math
+from model.deep_lab_model.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
+import torch.utils.model_zoo as model_zoo
+
+def conv_bn(inp, oup, stride, BatchNorm):
+    return nn.Sequential(
+        nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
+        BatchNorm(oup),
+        nn.ReLU6(inplace=True)
+    )
+
+
+def fixed_padding(inputs, kernel_size, dilation):
+    kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1)
+    pad_total = kernel_size_effective - 1
+    pad_beg = pad_total // 2
+    pad_end = pad_total - pad_beg
+    padded_inputs = F.pad(inputs, (pad_beg, pad_end, pad_beg, pad_end))
+    return padded_inputs
+
+
+class InvertedResidual(nn.Module):
+    def __init__(self, inp, oup, stride, dilation, expand_ratio, BatchNorm):
+        super(InvertedResidual, self).__init__()
+        self.stride = stride
+        assert stride in [1, 2]
+
+        hidden_dim = round(inp * expand_ratio)
+        self.use_res_connect = self.stride == 1 and inp == oup
+        self.kernel_size = 3
+        self.dilation = dilation
+
+        if expand_ratio == 1:
+            self.conv = nn.Sequential(
+                # dw
+                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 0, dilation, groups=hidden_dim, bias=False),
+                BatchNorm(hidden_dim),
+                nn.ReLU6(inplace=True),
+                # pw-linear
+                nn.Conv2d(hidden_dim, oup, 1, 1, 0, 1, 1, bias=False),
+                BatchNorm(oup),
+            )
+        else:
+            self.conv = nn.Sequential(
+                # pw
+                nn.Conv2d(inp, hidden_dim, 1, 1, 0, 1, bias=False),
+                BatchNorm(hidden_dim),
+                nn.ReLU6(inplace=True),
+                # dw
+                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 0, dilation, groups=hidden_dim, bias=False),
+                BatchNorm(hidden_dim),
+                nn.ReLU6(inplace=True),
+                # pw-linear
+                nn.Conv2d(hidden_dim, oup, 1, 1, 0, 1, bias=False),
+                BatchNorm(oup),
+            )
+
+    def forward(self, x):
+        x_pad = fixed_padding(x, self.kernel_size, dilation=self.dilation)
+        if self.use_res_connect:
+            x = x + self.conv(x_pad)
+        else:
+            x = self.conv(x_pad)
+        return x
+
+
+class MobileNetV2(nn.Module):
+    def __init__(self, output_stride=8, BatchNorm=None, width_mult=1., pretrained=True):
+        super(MobileNetV2, self).__init__()
+        block = InvertedResidual
+        input_channel = 32
+        current_stride = 1
+        rate = 1
+        interverted_residual_setting = [
+            # t, c, n, s
+            [1, 16, 1, 1],
+            [6, 24, 2, 2],
+            [6, 32, 3, 2],
+            [6, 64, 4, 2],
+            [6, 96, 3, 1],
+            [6, 160, 3, 2],
+            [6, 320, 1, 1],
+        ]
+
+        # building first layer
+        input_channel = int(input_channel * width_mult)
+        self.features = [conv_bn(3, input_channel, 2, BatchNorm)]
+        current_stride *= 2
+        # building inverted residual blocks
+        for t, c, n, s in interverted_residual_setting:
+            if current_stride == output_stride:
+                stride = 1
+                dilation = rate
+                rate *= s
+            else:
+                stride = s
+                dilation = 1
+                current_stride *= s
+            output_channel = int(c * width_mult)
+            for i in range(n):
+                if i == 0:
+                    self.features.append(block(input_channel, output_channel, stride, dilation, t, BatchNorm))
+                else:
+                    self.features.append(block(input_channel, output_channel, 1, dilation, t, BatchNorm))
+                input_channel = output_channel
+        self.features = nn.Sequential(*self.features)
+        self._initialize_weights()
+
+        if pretrained:
+            self._load_pretrained_model()
+
+        self.low_level_features = self.features[0:4]
+        self.high_level_features = self.features[4:]
+
+    def forward(self, x):
+        low_level_feat = self.low_level_features(x)
+        x = self.high_level_features(low_level_feat)
+        return x, low_level_feat
+
+    def _load_pretrained_model(self):
+        pretrain_dict = model_zoo.load_url('http://jeff95.me/models/mobilenet_v2-6a65762b.pth')
+        model_dict = {}
+        state_dict = self.state_dict()
+        for k, v in pretrain_dict.items():
+            if k in state_dict:
+                model_dict[k] = v
+        state_dict.update(model_dict)
+        self.load_state_dict(state_dict)
+
+    def _initialize_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                # m.weight.data.normal_(0, math.sqrt(2. / n))
+                torch.nn.init.kaiming_normal_(m.weight)
+            elif isinstance(m, SynchronizedBatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+            elif isinstance(m, nn.BatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+
+if __name__ == "__main__":
+    input = torch.rand(1, 3, 512, 512)
+    model = MobileNetV2(output_stride=16, BatchNorm=nn.BatchNorm2d)
+    output, low_level_feat = model(input)
+    print(output.size())
+    print(low_level_feat.size())
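For orientation, a minimal shape check of the MobileNetV2 backbone above. This is a sketch, not part of the package: it assumes the doctra/third_party/docres/data/MBD directory is on sys.path so the file's own model.deep_lab_model imports resolve, and it passes pretrained=False to skip the checkpoint download.

import torch
import torch.nn as nn

# Assumption: run from doctra/third_party/docres/data/MBD so that the
# `model.deep_lab_model...` import inside mobilenet.py resolves.
from model.deep_lab_model.backbone.mobilenet import MobileNetV2

# pretrained=False avoids the model_zoo download during a quick shape check.
model = MobileNetV2(output_stride=16, BatchNorm=nn.BatchNorm2d, pretrained=False)
model.eval()

with torch.no_grad():
    high, low = model(torch.rand(1, 3, 512, 512))

# features[0:4] end at stride 4 (24 channels); with output_stride=16 the
# later stages switch from striding to dilation, so the high-level map
# stays at 1/16 resolution.
print(high.shape)  # expected: torch.Size([1, 320, 32, 32])
print(low.shape)   # expected: torch.Size([1, 24, 128, 128])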
doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/resnet.py
@@ -0,0 +1,170 @@
+import math
+import torch.nn as nn
+import torch.utils.model_zoo as model_zoo
+from model.deep_lab_model.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
+
+class Bottleneck(nn.Module):
+    expansion = 4
+
+    def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, BatchNorm=None):
+        super(Bottleneck, self).__init__()
+        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
+        self.bn1 = BatchNorm(planes)
+        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
+                               dilation=dilation, padding=dilation, bias=False)
+        self.bn2 = BatchNorm(planes)
+        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
+        self.bn3 = BatchNorm(planes * 4)
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample
+        self.stride = stride
+        self.dilation = dilation
+
+    def forward(self, x):
+        residual = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.relu(out)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+        out = self.relu(out)
+
+        return out
+
+class ResNet(nn.Module):
+
+    def __init__(self, block, layers, output_stride, BatchNorm, pretrained=True):
+        self.inplanes = 64
+        super(ResNet, self).__init__()
+        blocks = [1, 2, 4]
+        if output_stride == 16:
+            strides = [1, 2, 2, 1]
+            dilations = [1, 1, 1, 2]
+        elif output_stride == 8:
+            strides = [1, 2, 1, 1]
+            dilations = [1, 1, 2, 4]
+        else:
+            raise NotImplementedError
+
+        # Modules
+        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
+                               bias=False)
+        self.bn1 = BatchNorm(64)
+        self.relu = nn.ReLU(inplace=True)
+        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+
+        self.layer1 = self._make_layer(block, 64, layers[0], stride=strides[0], dilation=dilations[0], BatchNorm=BatchNorm)
+        self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], dilation=dilations[1], BatchNorm=BatchNorm)
+        self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], dilation=dilations[2], BatchNorm=BatchNorm)
+        self.layer4 = self._make_MG_unit(block, 512, blocks=blocks, stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm)
+        # self.layer4 = self._make_layer(block, 512, layers[3], stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm)
+        self._init_weight()
+
+        # if pretrained:
+        #     self._load_pretrained_model()
+
+    def _make_layer(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None):
+        downsample = None
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                nn.Conv2d(self.inplanes, planes * block.expansion,
+                          kernel_size=1, stride=stride, bias=False),
+                BatchNorm(planes * block.expansion),
+            )
+
+        layers = []
+        layers.append(block(self.inplanes, planes, stride, dilation, downsample, BatchNorm))
+        self.inplanes = planes * block.expansion
+        for i in range(1, blocks):
+            layers.append(block(self.inplanes, planes, dilation=dilation, BatchNorm=BatchNorm))
+
+        return nn.Sequential(*layers)
+
+    def _make_MG_unit(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None):
+        downsample = None
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                nn.Conv2d(self.inplanes, planes * block.expansion,
+                          kernel_size=1, stride=stride, bias=False),
+                BatchNorm(planes * block.expansion),
+            )
+
+        layers = []
+        layers.append(block(self.inplanes, planes, stride, dilation=blocks[0]*dilation,
+                            downsample=downsample, BatchNorm=BatchNorm))
+        self.inplanes = planes * block.expansion
+        for i in range(1, len(blocks)):
+            layers.append(block(self.inplanes, planes, stride=1,
+                                dilation=blocks[i]*dilation, BatchNorm=BatchNorm))
+
+        return nn.Sequential(*layers)
+
+    def forward(self, input):
+        x = self.conv1(input)
+        x = self.bn1(x)
+        x = self.relu(x)
+        x = self.maxpool(x)
+
+        x = self.layer1(x)
+        low_level_feat = x
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+        return x, low_level_feat
+
+    def _init_weight(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                m.weight.data.normal_(0, math.sqrt(2. / n))
+            elif isinstance(m, SynchronizedBatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+            elif isinstance(m, nn.BatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+
+    def _load_pretrained_model(self):
+
+        import urllib.request
+        import ssl
+        ssl._create_default_https_context = ssl._create_unverified_context
+        response = urllib.request.urlopen('https://download.pytorch.org/models/resnet101-5d3b4d8f.pth')
+
+        pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/resnet101-5d3b4d8f.pth')
+        model_dict = {}
+        state_dict = self.state_dict()
+        for k, v in pretrain_dict.items():
+            if k in state_dict:
+                # if 'conv1' in k:
+                #     continue
+                model_dict[k] = v
+        state_dict.update(model_dict)
+        self.load_state_dict(state_dict)
+
+def ResNet101(output_stride, BatchNorm, pretrained=True):
+    """Constructs a ResNet-101 model.
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+    """
+    model = ResNet(Bottleneck, [3, 4, 23, 3], output_stride, BatchNorm, pretrained=pretrained)
+    return model
+
+if __name__ == "__main__":
+    import torch
+    model = ResNet101(BatchNorm=nn.BatchNorm2d, pretrained=True, output_stride=8)
+    input = torch.rand(1, 3, 512, 512)
+    output, low_level_feat = model(input)
+    print(output.size())
+    print(low_level_feat.size())
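Two things worth noting in resnet.py as shipped: the call to _load_pretrained_model() is commented out in ResNet.__init__, so the pretrained flag is currently a no-op, and the function itself opens an unverified-SSL urlopen whose response is never used. A shape-check sketch under the same sys.path assumption as the mobilenet example:

import torch
import torch.nn as nn

# Assumption: the MBD directory is on sys.path, as above.
from model.deep_lab_model.backbone.resnet import ResNet101

# pretrained=True downloads nothing here, since the loading call is
# commented out in ResNet.__init__.
model = ResNet101(output_stride=8, BatchNorm=nn.BatchNorm2d, pretrained=True)
model.eval()

with torch.no_grad():
    high, low = model(torch.rand(1, 3, 512, 512))

# With output_stride=8, layer3 and layer4 keep stride 1 and dilate instead;
# the MG unit gives layer4 dilations blocks*dilation = [4, 8, 16].
print(high.shape)  # expected: torch.Size([1, 2048, 64, 64])
print(low.shape)   # expected: torch.Size([1, 256, 128, 128])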
doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/xception.py
@@ -0,0 +1,288 @@
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.model_zoo as model_zoo
+from model.deep_lab_model.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
+
+def fixed_padding(inputs, kernel_size, dilation):
+    kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1)
+    pad_total = kernel_size_effective - 1
+    pad_beg = pad_total // 2
+    pad_end = pad_total - pad_beg
+    padded_inputs = F.pad(inputs, (pad_beg, pad_end, pad_beg, pad_end))
+    return padded_inputs
+
+
+class SeparableConv2d(nn.Module):
+    def __init__(self, inplanes, planes, kernel_size=3, stride=1, dilation=1, bias=False, BatchNorm=None):
+        super(SeparableConv2d, self).__init__()
+
+        self.conv1 = nn.Conv2d(inplanes, inplanes, kernel_size, stride, 0, dilation,
+                               groups=inplanes, bias=bias)
+        self.bn = BatchNorm(inplanes)
+        self.pointwise = nn.Conv2d(inplanes, planes, 1, 1, 0, 1, 1, bias=bias)
+
+    def forward(self, x):
+        x = fixed_padding(x, self.conv1.kernel_size[0], dilation=self.conv1.dilation[0])
+        x = self.conv1(x)
+        x = self.bn(x)
+        x = self.pointwise(x)
+        return x
+
+
+class Block(nn.Module):
+    def __init__(self, inplanes, planes, reps, stride=1, dilation=1, BatchNorm=None,
+                 start_with_relu=True, grow_first=True, is_last=False):
+        super(Block, self).__init__()
+
+        if planes != inplanes or stride != 1:
+            self.skip = nn.Conv2d(inplanes, planes, 1, stride=stride, bias=False)
+            self.skipbn = BatchNorm(planes)
+        else:
+            self.skip = None
+
+        self.relu = nn.ReLU(inplace=True)
+        rep = []
+
+        filters = inplanes
+        if grow_first:
+            rep.append(self.relu)
+            rep.append(SeparableConv2d(inplanes, planes, 3, 1, dilation, BatchNorm=BatchNorm))
+            rep.append(BatchNorm(planes))
+            filters = planes
+
+        for i in range(reps - 1):
+            rep.append(self.relu)
+            rep.append(SeparableConv2d(filters, filters, 3, 1, dilation, BatchNorm=BatchNorm))
+            rep.append(BatchNorm(filters))
+
+        if not grow_first:
+            rep.append(self.relu)
+            rep.append(SeparableConv2d(inplanes, planes, 3, 1, dilation, BatchNorm=BatchNorm))
+            rep.append(BatchNorm(planes))
+
+        if stride != 1:
+            rep.append(self.relu)
+            rep.append(SeparableConv2d(planes, planes, 3, 2, BatchNorm=BatchNorm))
+            rep.append(BatchNorm(planes))
+
+        if stride == 1 and is_last:
+            rep.append(self.relu)
+            rep.append(SeparableConv2d(planes, planes, 3, 1, BatchNorm=BatchNorm))
+            rep.append(BatchNorm(planes))
+
+        if not start_with_relu:
+            rep = rep[1:]
+
+        self.rep = nn.Sequential(*rep)
+
+    def forward(self, inp):
+        x = self.rep(inp)
+
+        if self.skip is not None:
+            skip = self.skip(inp)
+            skip = self.skipbn(skip)
+        else:
+            skip = inp
+
+        x = x + skip
+
+        return x
+
+
+class AlignedXception(nn.Module):
+    """
+    Modified Alighed Xception
+    """
+    def __init__(self, output_stride, BatchNorm,
+                 pretrained=True):
+        super(AlignedXception, self).__init__()
+
+        if output_stride == 16:
+            entry_block3_stride = 2
+            middle_block_dilation = 1
+            exit_block_dilations = (1, 2)
+        elif output_stride == 8:
+            entry_block3_stride = 1
+            middle_block_dilation = 2
+            exit_block_dilations = (2, 4)
+        else:
+            raise NotImplementedError
+
+
+        # Entry flow
+        self.conv1 = nn.Conv2d(3, 32, 3, stride=2, padding=1, bias=False)
+        self.bn1 = BatchNorm(32)
+        self.relu = nn.ReLU(inplace=True)
+
+        self.conv2 = nn.Conv2d(32, 64, 3, stride=1, padding=1, bias=False)
+        self.bn2 = BatchNorm(64)
+
+        self.block1 = Block(64, 128, reps=2, stride=2, BatchNorm=BatchNorm, start_with_relu=False)
+        self.block2 = Block(128, 256, reps=2, stride=2, BatchNorm=BatchNorm, start_with_relu=False,
+                            grow_first=True)
+        self.block3 = Block(256, 728, reps=2, stride=entry_block3_stride, BatchNorm=BatchNorm,
+                            start_with_relu=True, grow_first=True, is_last=True)
+
+        # Middle flow
+        self.block4 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
+                            BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
+        self.block5 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
+                            BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
+        self.block6 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
+                            BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
+        self.block7 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
+                            BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
+        self.block8 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
+                            BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
+        self.block9 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
+                            BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
+        self.block10 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
+                             BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
+        self.block11 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
+                             BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
+        self.block12 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
+                             BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
+        self.block13 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
+                             BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
+        self.block14 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
+                             BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
+        self.block15 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
+                             BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
+        self.block16 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
+                             BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
+        self.block17 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
+                             BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
+        self.block18 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
+                             BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
+        self.block19 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
+                             BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
+
+        # Exit flow
+        self.block20 = Block(728, 1024, reps=2, stride=1, dilation=exit_block_dilations[0],
+                             BatchNorm=BatchNorm, start_with_relu=True, grow_first=False, is_last=True)
+
+        self.conv3 = SeparableConv2d(1024, 1536, 3, stride=1, dilation=exit_block_dilations[1], BatchNorm=BatchNorm)
+        self.bn3 = BatchNorm(1536)
+
+        self.conv4 = SeparableConv2d(1536, 1536, 3, stride=1, dilation=exit_block_dilations[1], BatchNorm=BatchNorm)
+        self.bn4 = BatchNorm(1536)
+
+        self.conv5 = SeparableConv2d(1536, 2048, 3, stride=1, dilation=exit_block_dilations[1], BatchNorm=BatchNorm)
+        self.bn5 = BatchNorm(2048)
+
+        # Init weights
+        self._init_weight()
+
+        # Load pretrained model
+        if pretrained:
+            self._load_pretrained_model()
+
+    def forward(self, x):
+        # Entry flow
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+
+        x = self.conv2(x)
+        x = self.bn2(x)
+        x = self.relu(x)
+
+        x = self.block1(x)
+        # add relu here
+        x = self.relu(x)
+        low_level_feat = x
+        x = self.block2(x)
+        x = self.block3(x)
+
+        # Middle flow
+        x = self.block4(x)
+        x = self.block5(x)
+        x = self.block6(x)
+        x = self.block7(x)
+        x = self.block8(x)
+        x = self.block9(x)
+        x = self.block10(x)
+        x = self.block11(x)
+        x = self.block12(x)
+        x = self.block13(x)
+        x = self.block14(x)
+        x = self.block15(x)
+        x = self.block16(x)
+        x = self.block17(x)
+        x = self.block18(x)
+        x = self.block19(x)
+
+        # Exit flow
+        x = self.block20(x)
+        x = self.relu(x)
+        x = self.conv3(x)
+        x = self.bn3(x)
+        x = self.relu(x)
+
+        x = self.conv4(x)
+        x = self.bn4(x)
+        x = self.relu(x)
+
+        x = self.conv5(x)
+        x = self.bn5(x)
+        x = self.relu(x)
+
+        return x, low_level_feat
+
+    def _init_weight(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                m.weight.data.normal_(0, math.sqrt(2. / n))
+            elif isinstance(m, SynchronizedBatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+            elif isinstance(m, nn.BatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+
+
+    def _load_pretrained_model(self):
+        pretrain_dict = model_zoo.load_url('http://data.lip6.fr/cadene/pretrainedmodels/xception-b5690688.pth')
+        model_dict = {}
+        state_dict = self.state_dict()
+
+        for k, v in pretrain_dict.items():
+            if k in state_dict:
+                if 'pointwise' in k:
+                    v = v.unsqueeze(-1).unsqueeze(-1)
+                if k.startswith('block11'):
+                    model_dict[k] = v
+                    model_dict[k.replace('block11', 'block12')] = v
+                    model_dict[k.replace('block11', 'block13')] = v
+                    model_dict[k.replace('block11', 'block14')] = v
+                    model_dict[k.replace('block11', 'block15')] = v
+                    model_dict[k.replace('block11', 'block16')] = v
+                    model_dict[k.replace('block11', 'block17')] = v
+                    model_dict[k.replace('block11', 'block18')] = v
+                    model_dict[k.replace('block11', 'block19')] = v
+                elif k.startswith('block12'):
+                    model_dict[k.replace('block12', 'block20')] = v
+                elif k.startswith('bn3'):
+                    model_dict[k] = v
+                    model_dict[k.replace('bn3', 'bn4')] = v
+                elif k.startswith('conv4'):
+                    model_dict[k.replace('conv4', 'conv5')] = v
+                elif k.startswith('bn4'):
+                    model_dict[k.replace('bn4', 'bn5')] = v
+                else:
+                    model_dict[k] = v
+        state_dict.update(model_dict)
+        self.load_state_dict(state_dict)
+
+
+
+if __name__ == "__main__":
+    import torch
+    model = AlignedXception(BatchNorm=nn.BatchNorm2d, pretrained=True, output_stride=16)
+    input = torch.rand(1, 3, 512, 512)
+    output, low_level_feat = model(input)
+    print(output.size())
+    print(low_level_feat.size())
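The notable part of xception.py is _load_pretrained_model: ImageNet Xception has 8 middle-flow blocks while this aligned variant has 16, so block11's pretrained weights seed block12 through block19 and block12's seed the exit-flow block20. The unsqueeze pair implies the legacy Cadene checkpoint stores pointwise (1x1) weights as 2-D (out, in) tensors, which nn.Conv2d needs as 4-D. A small standalone illustration of both remappings (the key name is hypothetical):

import torch

# nn.Conv2d keeps 1x1 weights as (out, in, 1, 1); the loader's two
# unsqueeze calls lift a 2-D checkpoint tensor into that layout.
v = torch.randn(128, 64)
v = v.unsqueeze(-1).unsqueeze(-1)
print(v.shape)  # torch.Size([128, 64, 1, 1])

# block11 weights fan out to the extra middle-flow blocks.
k = 'block11.rep.1.conv1.weight'  # hypothetical key, for illustration only
print([k.replace('block11', 'block%d' % i) for i in range(12, 20)])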
doctra/third_party/docres/data/MBD/model/deep_lab_model/decoder.py
@@ -0,0 +1,59 @@
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from model.deep_lab_model.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
+
+class Decoder(nn.Module):
+    def __init__(self, num_classes, backbone, BatchNorm):
+        super(Decoder, self).__init__()
+        if backbone == 'resnet' or backbone == 'drn':
+            low_level_inplanes = 256
+        elif backbone == 'xception':
+            low_level_inplanes = 128
+        elif backbone == 'mobilenet':
+            low_level_inplanes = 24
+        else:
+            raise NotImplementedError
+
+        self.conv1 = nn.Conv2d(low_level_inplanes, 48, 1, bias=False)
+        self.bn1 = BatchNorm(48)
+        self.relu = nn.ReLU()
+        self.last_conv = nn.Sequential(nn.Conv2d(304, 256, kernel_size=3, stride=1, padding=1, bias=False),
+                                       BatchNorm(256),
+                                       nn.ReLU(),
+                                       nn.Dropout(0.5),
+                                       nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
+                                       BatchNorm(256),
+                                       nn.ReLU(),
+                                       nn.Dropout(0.1),
+                                       nn.Conv2d(256, num_classes, kernel_size=1, stride=1),
+                                       nn.Sigmoid()
+                                       )
+        self._init_weight()
+
+
+    def forward(self, x, low_level_feat):
+        low_level_feat = self.conv1(low_level_feat)
+        low_level_feat = self.bn1(low_level_feat)
+        low_level_feat = self.relu(low_level_feat)
+
+        x = F.interpolate(x, size=low_level_feat.size()[2:], mode='bilinear', align_corners=True)
+        x = torch.cat((x, low_level_feat), dim=1)
+        x = self.last_conv(x)
+
+        return x
+
+    def _init_weight(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                torch.nn.init.kaiming_normal_(m.weight)
+            elif isinstance(m, SynchronizedBatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+            elif isinstance(m, nn.BatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+
+def build_decoder(num_classes, backbone, BatchNorm):
+    return Decoder(num_classes, backbone, BatchNorm)
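decoder.py pins down the channel contract of the whole model: last_conv expects 304 = 256 + 48 channels, i.e. a 256-channel ASPP output (aspp.py is also added in this diff) concatenated with the low-level features after conv1 reduces them to 48. Unlike the stock DeepLab decoder, this one ends in a Sigmoid, so the head emits [0, 1] maps rather than logits, presumably for MBD's mask prediction. A sketch with dummy tensors standing in for the real backbone and ASPP outputs, under the same sys.path assumption as the earlier examples:

import torch
import torch.nn as nn

# Assumption: the MBD directory is on sys.path.
from model.deep_lab_model.decoder import build_decoder

decoder = build_decoder(num_classes=1, backbone='mobilenet', BatchNorm=nn.BatchNorm2d)
decoder.eval()

with torch.no_grad():
    aspp_out = torch.rand(1, 256, 32, 32)    # stand-in for the ASPP output
    low_level = torch.rand(1, 24, 128, 128)  # mobilenet low-level features
    out = decoder(aspp_out, low_level)

# x is bilinearly upsampled to the low-level resolution before the concat;
# the Sigmoid bounds each of the num_classes maps to [0, 1].
print(out.shape)  # expected: torch.Size([1, 1, 128, 128])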