doctra 0.3.2__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- doctra/__init__.py +4 -0
- doctra/cli/main.py +168 -0
- doctra/engines/image_restoration/__init__.py +10 -0
- doctra/engines/image_restoration/docres_engine.py +566 -0
- doctra/engines/vlm/service.py +0 -12
- doctra/parsers/enhanced_pdf_parser.py +370 -0
- doctra/parsers/structured_pdf_parser.py +11 -60
- doctra/parsers/table_chart_extractor.py +8 -44
- doctra/third_party/docres/data/MBD/MBD.py +110 -0
- doctra/third_party/docres/data/MBD/MBD_utils.py +291 -0
- doctra/third_party/docres/data/MBD/infer.py +151 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/aspp.py +95 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/__init__.py +13 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/drn.py +402 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/mobilenet.py +151 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/resnet.py +170 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/xception.py +288 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/decoder.py +59 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/deeplab.py +81 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/__init__.py +12 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/batchnorm.py +282 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/comm.py +129 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/replicate.py +88 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/unittest.py +29 -0
- doctra/third_party/docres/data/preprocess/crop_merge_image.py +142 -0
- doctra/third_party/docres/inference.py +370 -0
- doctra/third_party/docres/models/restormer_arch.py +308 -0
- doctra/third_party/docres/utils.py +464 -0
- doctra/ui/app.py +5 -32
- doctra/utils/progress.py +13 -98
- doctra/utils/structured_utils.py +45 -49
- doctra/version.py +1 -1
- {doctra-0.3.2.dist-info → doctra-0.4.0.dist-info}/METADATA +1 -1
- doctra-0.4.0.dist-info/RECORD +67 -0
- doctra-0.3.2.dist-info/RECORD +0 -44
- {doctra-0.3.2.dist-info → doctra-0.4.0.dist-info}/WHEEL +0 -0
- {doctra-0.3.2.dist-info → doctra-0.4.0.dist-info}/licenses/LICENSE +0 -0
- {doctra-0.3.2.dist-info → doctra-0.4.0.dist-info}/top_level.txt +0 -0
The hunks below are selected new files from the vendored DocRes/MBD sources, reproduced as unified diffs.

doctra/third_party/docres/data/MBD/infer.py (+151 lines):

```diff
@@ -0,0 +1,151 @@
+import torch
+import argparse
+import numpy as np
+import torch.nn.functional as F
+import glob
+import cv2
+from tqdm import tqdm
+
+import time
+import os
+from model.deep_lab_model.deeplab import *
+from MBD import mask_base_dewarper
+import time
+
+from utils import cvimg2torch,torch2cvimg
+
+
+
+def net1_net2_infer(model,img_paths,args):
+
+    ### validate on the real datasets
+    seg_model=model
+    seg_model.eval()
+    for img_path in tqdm(img_paths):
+        if os.path.exists(img_path.replace('_origin','_capture')):
+            continue
+        t1 = time.time()
+        ### segmentation mask predict
+        img_org = cv2.imread(img_path)
+        h_org,w_org = img_org.shape[:2]
+        img = cv2.resize(img_org,(448, 448))
+        img = cv2.GaussianBlur(img,(15,15),0,0)
+        img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB)
+        img = cvimg2torch(img)
+
+        with torch.no_grad():
+            pred = seg_model(img.cuda())
+            mask_pred = pred[:,0,:,:].unsqueeze(1)
+            mask_pred = F.interpolate(mask_pred,(h_org,w_org))
+            mask_pred = mask_pred.squeeze(0).squeeze(0).cpu().numpy()
+            mask_pred = (mask_pred*255).astype(np.uint8)
+            kernel = np.ones((3,3))
+            mask_pred = cv2.dilate(mask_pred,kernel,iterations=3)
+            mask_pred = cv2.erode(mask_pred,kernel,iterations=3)
+            mask_pred[mask_pred>100] = 255
+            mask_pred[mask_pred<100] = 0
+        ### tps transform base on the mask
+        # dewarp, grid = mask_base_dewarper(img_org,mask_pred)
+        try:
+            dewarp, grid = mask_base_dewarper(img_org,mask_pred)
+        except:
+            print('fail')
+            grid = np.meshgrid(np.arange(w_org),np.arange(h_org))/np.array([w_org,h_org]).reshape(2,1,1)
+            grid = torch.from_numpy((grid-0.5)*2).float().unsqueeze(0).permute(0,2,3,1)
+            dewarp = torch2cvimg(F.grid_sample(cvimg2torch(img_org),grid))[0]
+            grid = grid[0].numpy()
+        # cv2.imshow('in',cv2.resize(img_org,(512,512)))
+        # cv2.imshow('out',cv2.resize(dewarp,(512,512)))
+        # cv2.waitKey(0)
+        cv2.imwrite(img_path.replace('_origin','_capture'),dewarp)
+        cv2.imwrite(img_path.replace('_origin','_mask_new'),mask_pred)
+
+        grid0 = cv2.resize(grid[:,:,0],(128,128))
+        grid1 = cv2.resize(grid[:,:,1],(128,128))
+        grid = np.stack((grid0,grid1),axis=-1)
+        np.save(img_path.replace('_origin','_grid1'),grid)
+
+
+def net1_net2_infer_single_im(img,model_path):
+    seg_model = DeepLab(num_classes=1,
+                        backbone='resnet',
+                        output_stride=16,
+                        sync_bn=None,
+                        freeze_bn=False)
+    seg_model = torch.nn.DataParallel(seg_model, device_ids=range(torch.cuda.device_count()))
+    seg_model.cuda()
+    checkpoint = torch.load(model_path)
+    seg_model.load_state_dict(checkpoint['model_state'])
+    ### validate on the real datasets
+    seg_model.eval()
+    ### segmentation mask predict
+    img_org = img
+    h_org,w_org = img_org.shape[:2]
+    img = cv2.resize(img_org,(448, 448))
+    img = cv2.GaussianBlur(img,(15,15),0,0)
+    img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB)
+    img = cvimg2torch(img)
+
+    with torch.no_grad():
+        # from torchtoolbox.tools import summary
+        # print(summary(seg_model,torch.rand((1, 3, 448, 448)).cuda())) 59.4M 135.6G
+
+        pred = seg_model(img.cuda())
+        mask_pred = pred[:,0,:,:].unsqueeze(1)
+        mask_pred = F.interpolate(mask_pred,(h_org,w_org))
+        mask_pred = mask_pred.squeeze(0).squeeze(0).cpu().numpy()
+        mask_pred = (mask_pred*255).astype(np.uint8)
+        kernel = np.ones((3,3))
+        mask_pred = cv2.dilate(mask_pred,kernel,iterations=3)
+        mask_pred = cv2.erode(mask_pred,kernel,iterations=3)
+        mask_pred[mask_pred>100] = 255
+        mask_pred[mask_pred<100] = 0
+    ### tps transform base on the mask
+    # dewarp, grid = mask_base_dewarper(img_org,mask_pred)
+    # try:
+    #     dewarp, grid = mask_base_dewarper(img_org,mask_pred)
+    # except:
+    #     print('fail')
+    #     grid = np.meshgrid(np.arange(w_org),np.arange(h_org))/np.array([w_org,h_org]).reshape(2,1,1)
+    #     grid = torch.from_numpy((grid-0.5)*2).float().unsqueeze(0).permute(0,2,3,1)
+    #     dewarp = torch2cvimg(F.grid_sample(cvimg2torch(img_org),grid))[0]
+    #     grid = grid[0].numpy()
+    # cv2.imshow('in',cv2.resize(img_org,(512,512)))
+    # cv2.imshow('out',cv2.resize(dewarp,(512,512)))
+    # cv2.waitKey(0)
+    # cv2.imwrite(img_path.replace('_origin','_capture'),dewarp)
+    # cv2.imwrite(img_path.replace('_origin','_mask_new'),mask_pred)
+
+    # grid0 = cv2.resize(grid[:,:,0],(128,128))
+    # grid1 = cv2.resize(grid[:,:,1],(128,128))
+    # grid = np.stack((grid0,grid1),axis=-1)
+    # np.save(img_path.replace('_origin','_grid1'),grid)
+    return mask_pred
+
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='Hyperparams')
+    parser.add_argument('--img_folder', nargs='?', type=str, default='./all_data',help='Data path to load data')
+    parser.add_argument('--img_rows', nargs='?', type=int, default=448,
+                        help='Height of the input image')
+    parser.add_argument('--img_cols', nargs='?', type=int, default=448,
+                        help='Width of the input image')
+    parser.add_argument('--seg_model_path', nargs='?', type=str, default='checkpoints/mbd.pkl',
+                        help='Path to previous saved model to restart from')
+    args = parser.parse_args()
+
+    seg_model = DeepLab(num_classes=1,
+                        backbone='resnet',
+                        output_stride=16,
+                        sync_bn=None,
+                        freeze_bn=False)
+    seg_model = torch.nn.DataParallel(seg_model, device_ids=range(torch.cuda.device_count()))
+    seg_model.cuda()
+    checkpoint = torch.load(args.seg_model_path)
+    seg_model.load_state_dict(checkpoint['model_state'])
+
+    im_paths = glob.glob(os.path.join(args.img_folder,'*_origin.*'))
+
+    net1_net2_infer(seg_model,im_paths,args)
+
```
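Of the two functions above, `net1_net2_infer_single_im` is the one the rest of the DocRes pipeline can call directly: it loads the MBD segmentation checkpoint, predicts a document mask for a single BGR image, and returns it thresholded to {0, 255} at the original resolution. A minimal usage sketch (assumes a CUDA device, the vendored DocRes/MBD directories on `sys.path`, and an `mbd.pkl` checkpoint at an illustrative path):

```python
# Sketch only: the checkpoint path and image file below are illustrative.
import cv2
from infer import net1_net2_infer_single_im  # vendored MBD module above

img = cv2.imread('page_origin.png')                       # BGR, any resolution
mask = net1_net2_infer_single_im(img, 'checkpoints/mbd.pkl')

# mask is a uint8 array at the input's height/width, values in {0, 255}
cv2.imwrite('page_mask.png', mask)
```

Note that the batch path (`net1_net2_infer`) falls back to an identity sampling grid through `F.grid_sample` when `mask_base_dewarper` raises, so a failed dewarp still writes an (unwarped) `_capture` image.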
doctra/third_party/docres/data/MBD/model/deep_lab_model/aspp.py (+95 lines):

```diff
@@ -0,0 +1,95 @@
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from model.deep_lab_model.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
+
+class _ASPPModule(nn.Module):
+    def __init__(self, inplanes, planes, kernel_size, padding, dilation, BatchNorm):
+        super(_ASPPModule, self).__init__()
+        self.atrous_conv = nn.Conv2d(inplanes, planes, kernel_size=kernel_size,
+                                     stride=1, padding=padding, dilation=dilation, bias=False)
+        self.bn = BatchNorm(planes)
+        self.relu = nn.ReLU()
+
+        self._init_weight()
+
+    def forward(self, x):
+        x = self.atrous_conv(x)
+        x = self.bn(x)
+
+        return self.relu(x)
+
+    def _init_weight(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                torch.nn.init.kaiming_normal_(m.weight)
+            elif isinstance(m, SynchronizedBatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+            elif isinstance(m, nn.BatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+
+class ASPP(nn.Module):
+    def __init__(self, backbone, output_stride, BatchNorm):
+        super(ASPP, self).__init__()
+        if backbone == 'drn':
+            inplanes = 512
+        elif backbone == 'mobilenet':
+            inplanes = 320
+        else:
+            inplanes = 2048
+        if output_stride == 16:
+            dilations = [1, 6, 12, 18]
+        elif output_stride == 8:
+            dilations = [1, 12, 24, 36]
+        else:
+            raise NotImplementedError
+
+        self.aspp1 = _ASPPModule(inplanes, 256, 1, padding=0, dilation=dilations[0], BatchNorm=BatchNorm)
+        self.aspp2 = _ASPPModule(inplanes, 256, 3, padding=dilations[1], dilation=dilations[1], BatchNorm=BatchNorm)
+        self.aspp3 = _ASPPModule(inplanes, 256, 3, padding=dilations[2], dilation=dilations[2], BatchNorm=BatchNorm)
+        self.aspp4 = _ASPPModule(inplanes, 256, 3, padding=dilations[3], dilation=dilations[3], BatchNorm=BatchNorm)
+
+        self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
+                                             nn.Conv2d(inplanes, 256, 1, stride=1, bias=False),
+                                             BatchNorm(256),
+                                             nn.ReLU())
+        self.conv1 = nn.Conv2d(1280, 256, 1, bias=False)
+        self.bn1 = BatchNorm(256)
+        self.relu = nn.ReLU()
+        self.dropout = nn.Dropout(0.5)
+        self._init_weight()
+
+    def forward(self, x):
+        x1 = self.aspp1(x)
+        x2 = self.aspp2(x)
+        x3 = self.aspp3(x)
+        x4 = self.aspp4(x)
+        x5 = self.global_avg_pool(x)
+        x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)
+        x = torch.cat((x1, x2, x3, x4, x5), dim=1)
+
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+
+        return self.dropout(x)
+
+    def _init_weight(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                # m.weight.data.normal_(0, math.sqrt(2. / n))
+                torch.nn.init.kaiming_normal_(m.weight)
+            elif isinstance(m, SynchronizedBatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+            elif isinstance(m, nn.BatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+
+
+def build_aspp(backbone, output_stride, BatchNorm):
+    return ASPP(backbone, output_stride, BatchNorm)
```
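`build_aspp` is this file's only entry point. The module takes the backbone's top feature map (2048 channels for ResNet or Xception, 512 for DRN, 320 for MobileNet), runs four parallel atrous branches plus global average pooling, and always returns 256 channels at the input's spatial size. A quick shape check, as a sketch (plain `nn.BatchNorm2d` standing in for the synchronized variant, and the vendored MBD directory assumed on `sys.path`):

```python
# Sketch: exercising the vendored ASPP head standalone.
import torch
import torch.nn as nn
from model.deep_lab_model.aspp import build_aspp

aspp = build_aspp(backbone='resnet', output_stride=16, BatchNorm=nn.BatchNorm2d).eval()
feat = torch.rand(1, 2048, 32, 32)      # a ResNet backbone emits 2048 channels
with torch.no_grad():
    out = aspp(feat)
print(out.shape)                        # torch.Size([1, 256, 32, 32])
```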
doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/__init__.py (+13 lines):

```diff
@@ -0,0 +1,13 @@
+from model.deep_lab_model.backbone import resnet, xception, drn, mobilenet
+
+def build_backbone(backbone, output_stride, BatchNorm):
+    if backbone == 'resnet':
+        return resnet.ResNet101(output_stride, BatchNorm)
+    elif backbone == 'xception':
+        return xception.AlignedXception(output_stride, BatchNorm)
+    elif backbone == 'drn':
+        return drn.drn_d_54(BatchNorm)
+    elif backbone == 'mobilenet':
+        return mobilenet.MobileNetV2(output_stride, BatchNorm)
+    else:
+        raise NotImplementedError
```
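`build_backbone` is a plain string dispatch over the four vendored backbones. Note that the `drn` branch drops `output_stride` (DRN's dilation schedule fixes its effective stride), while the other three forward it, and that `drn_d_54` in the next hunk defaults to `pretrained=True`, so the first call may download weights from the model zoo. A hedged sketch:

```python
# Sketch: assumes the vendored MBD directory is on sys.path so the
# 'model.' imports resolve; first use may download pretrained weights.
import torch.nn as nn
from model.deep_lab_model.backbone import build_backbone

backbone = build_backbone('drn', output_stride=16, BatchNorm=nn.BatchNorm2d)
# 'drn' ignores output_stride; 'resnet' / 'xception' / 'mobilenet' honor it.
```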
doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/drn.py (+402 lines):

```diff
@@ -0,0 +1,402 @@
+import torch.nn as nn
+import math
+import torch.utils.model_zoo as model_zoo
+from model.deep_lab_model.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
+
+webroot = 'http://dl.yf.io/drn/'
+
+model_urls = {
+    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
+    'drn-c-26': webroot + 'drn_c_26-ddedf421.pth',
+    'drn-c-42': webroot + 'drn_c_42-9d336e8c.pth',
+    'drn-c-58': webroot + 'drn_c_58-0a53a92c.pth',
+    'drn-d-22': webroot + 'drn_d_22-4bd2f8ea.pth',
+    'drn-d-38': webroot + 'drn_d_38-eebb45f0.pth',
+    'drn-d-54': webroot + 'drn_d_54-0e0534ff.pth',
+    'drn-d-105': webroot + 'drn_d_105-12b40979.pth'
+}
+
+
+def conv3x3(in_planes, out_planes, stride=1, padding=1, dilation=1):
+    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+                     padding=padding, bias=False, dilation=dilation)
+
+
+class BasicBlock(nn.Module):
+    expansion = 1
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None,
+                 dilation=(1, 1), residual=True, BatchNorm=None):
+        super(BasicBlock, self).__init__()
+        self.conv1 = conv3x3(inplanes, planes, stride,
+                             padding=dilation[0], dilation=dilation[0])
+        self.bn1 = BatchNorm(planes)
+        self.relu = nn.ReLU(inplace=True)
+        self.conv2 = conv3x3(planes, planes,
+                             padding=dilation[1], dilation=dilation[1])
+        self.bn2 = BatchNorm(planes)
+        self.downsample = downsample
+        self.stride = stride
+        self.residual = residual
+
+    def forward(self, x):
+        residual = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+        if self.residual:
+            out += residual
+        out = self.relu(out)
+
+        return out
+
+
+class Bottleneck(nn.Module):
+    expansion = 4
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None,
+                 dilation=(1, 1), residual=True, BatchNorm=None):
+        super(Bottleneck, self).__init__()
+        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
+        self.bn1 = BatchNorm(planes)
+        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
+                               padding=dilation[1], bias=False,
+                               dilation=dilation[1])
+        self.bn2 = BatchNorm(planes)
+        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
+        self.bn3 = BatchNorm(planes * 4)
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        residual = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.relu(out)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+        out = self.relu(out)
+
+        return out
+
+
+class DRN(nn.Module):
+
+    def __init__(self, block, layers, arch='D',
+                 channels=(16, 32, 64, 128, 256, 512, 512, 512),
+                 BatchNorm=None):
+        super(DRN, self).__init__()
+        self.inplanes = channels[0]
+        self.out_dim = channels[-1]
+        self.arch = arch
+
+        if arch == 'C':
+            self.conv1 = nn.Conv2d(3, channels[0], kernel_size=7, stride=1,
+                                   padding=3, bias=False)
+            self.bn1 = BatchNorm(channels[0])
+            self.relu = nn.ReLU(inplace=True)
+
+            self.layer1 = self._make_layer(
+                BasicBlock, channels[0], layers[0], stride=1, BatchNorm=BatchNorm)
+            self.layer2 = self._make_layer(
+                BasicBlock, channels[1], layers[1], stride=2, BatchNorm=BatchNorm)
+
+        elif arch == 'D':
+            self.layer0 = nn.Sequential(
+                nn.Conv2d(3, channels[0], kernel_size=7, stride=1, padding=3,
+                          bias=False),
+                BatchNorm(channels[0]),
+                nn.ReLU(inplace=True)
+            )
+
+            self.layer1 = self._make_conv_layers(
+                channels[0], layers[0], stride=1, BatchNorm=BatchNorm)
+            self.layer2 = self._make_conv_layers(
+                channels[1], layers[1], stride=2, BatchNorm=BatchNorm)
+
+        self.layer3 = self._make_layer(block, channels[2], layers[2], stride=2, BatchNorm=BatchNorm)
+        self.layer4 = self._make_layer(block, channels[3], layers[3], stride=2, BatchNorm=BatchNorm)
+        self.layer5 = self._make_layer(block, channels[4], layers[4],
+                                       dilation=2, new_level=False, BatchNorm=BatchNorm)
+        self.layer6 = None if layers[5] == 0 else \
+            self._make_layer(block, channels[5], layers[5], dilation=4,
+                             new_level=False, BatchNorm=BatchNorm)
+
+        if arch == 'C':
+            self.layer7 = None if layers[6] == 0 else \
+                self._make_layer(BasicBlock, channels[6], layers[6], dilation=2,
+                                 new_level=False, residual=False, BatchNorm=BatchNorm)
+            self.layer8 = None if layers[7] == 0 else \
+                self._make_layer(BasicBlock, channels[7], layers[7], dilation=1,
+                                 new_level=False, residual=False, BatchNorm=BatchNorm)
+        elif arch == 'D':
+            self.layer7 = None if layers[6] == 0 else \
+                self._make_conv_layers(channels[6], layers[6], dilation=2, BatchNorm=BatchNorm)
+            self.layer8 = None if layers[7] == 0 else \
+                self._make_conv_layers(channels[7], layers[7], dilation=1, BatchNorm=BatchNorm)
+
+        self._init_weight()
+
+    def _init_weight(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                m.weight.data.normal_(0, math.sqrt(2. / n))
+            elif isinstance(m, SynchronizedBatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+            elif isinstance(m, nn.BatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+
+
+    def _make_layer(self, block, planes, blocks, stride=1, dilation=1,
+                    new_level=True, residual=True, BatchNorm=None):
+        assert dilation == 1 or dilation % 2 == 0
+        downsample = None
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                nn.Conv2d(self.inplanes, planes * block.expansion,
+                          kernel_size=1, stride=stride, bias=False),
+                BatchNorm(planes * block.expansion),
+            )
+
+        layers = list()
+        layers.append(block(
+            self.inplanes, planes, stride, downsample,
+            dilation=(1, 1) if dilation == 1 else (
+                dilation // 2 if new_level else dilation, dilation),
+            residual=residual, BatchNorm=BatchNorm))
+        self.inplanes = planes * block.expansion
+        for i in range(1, blocks):
+            layers.append(block(self.inplanes, planes, residual=residual,
+                                dilation=(dilation, dilation), BatchNorm=BatchNorm))
+
+        return nn.Sequential(*layers)
+
+    def _make_conv_layers(self, channels, convs, stride=1, dilation=1, BatchNorm=None):
+        modules = []
+        for i in range(convs):
+            modules.extend([
+                nn.Conv2d(self.inplanes, channels, kernel_size=3,
+                          stride=stride if i == 0 else 1,
+                          padding=dilation, bias=False, dilation=dilation),
+                BatchNorm(channels),
+                nn.ReLU(inplace=True)])
+            self.inplanes = channels
+        return nn.Sequential(*modules)
+
+    def forward(self, x):
+        if self.arch == 'C':
+            x = self.conv1(x)
+            x = self.bn1(x)
+            x = self.relu(x)
+        elif self.arch == 'D':
+            x = self.layer0(x)
+
+        x = self.layer1(x)
+        x = self.layer2(x)
+
+        x = self.layer3(x)
+        low_level_feat = x
+
+        x = self.layer4(x)
+        x = self.layer5(x)
+
+        if self.layer6 is not None:
+            x = self.layer6(x)
+
+        if self.layer7 is not None:
+            x = self.layer7(x)
+
+        if self.layer8 is not None:
+            x = self.layer8(x)
+
+        return x, low_level_feat
+
+
+class DRN_A(nn.Module):
+
+    def __init__(self, block, layers, BatchNorm=None):
+        self.inplanes = 64
+        super(DRN_A, self).__init__()
+        self.out_dim = 512 * block.expansion
+        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
+                               bias=False)
+        self.bn1 = BatchNorm(64)
+        self.relu = nn.ReLU(inplace=True)
+        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+        self.layer1 = self._make_layer(block, 64, layers[0], BatchNorm=BatchNorm)
+        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, BatchNorm=BatchNorm)
+        self.layer3 = self._make_layer(block, 256, layers[2], stride=1,
+                                       dilation=2, BatchNorm=BatchNorm)
+        self.layer4 = self._make_layer(block, 512, layers[3], stride=1,
+                                       dilation=4, BatchNorm=BatchNorm)
+
+        self._init_weight()
+
+    def _init_weight(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                m.weight.data.normal_(0, math.sqrt(2. / n))
+            elif isinstance(m, SynchronizedBatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+            elif isinstance(m, nn.BatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+
+    def _make_layer(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None):
+        downsample = None
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                nn.Conv2d(self.inplanes, planes * block.expansion,
+                          kernel_size=1, stride=stride, bias=False),
+                BatchNorm(planes * block.expansion),
+            )
+
+        layers = []
+        layers.append(block(self.inplanes, planes, stride, downsample, BatchNorm=BatchNorm))
+        self.inplanes = planes * block.expansion
+        for i in range(1, blocks):
+            layers.append(block(self.inplanes, planes,
+                                dilation=(dilation, dilation, ), BatchNorm=BatchNorm))
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+        x = self.maxpool(x)
+
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+
+        return x
+
+def drn_a_50(BatchNorm, pretrained=True):
+    model = DRN_A(Bottleneck, [3, 4, 6, 3], BatchNorm=BatchNorm)
+    if pretrained:
+        model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
+    return model
+
+
+def drn_c_26(BatchNorm, pretrained=True):
+    model = DRN(BasicBlock, [1, 1, 2, 2, 2, 2, 1, 1], arch='C', BatchNorm=BatchNorm)
+    if pretrained:
+        pretrained = model_zoo.load_url(model_urls['drn-c-26'])
+        del pretrained['fc.weight']
+        del pretrained['fc.bias']
+        model.load_state_dict(pretrained)
+    return model
+
+
+def drn_c_42(BatchNorm, pretrained=True):
+    model = DRN(BasicBlock, [1, 1, 3, 4, 6, 3, 1, 1], arch='C', BatchNorm=BatchNorm)
+    if pretrained:
+        pretrained = model_zoo.load_url(model_urls['drn-c-42'])
+        del pretrained['fc.weight']
+        del pretrained['fc.bias']
+        model.load_state_dict(pretrained)
+    return model
+
+
+def drn_c_58(BatchNorm, pretrained=True):
+    model = DRN(Bottleneck, [1, 1, 3, 4, 6, 3, 1, 1], arch='C', BatchNorm=BatchNorm)
+    if pretrained:
+        pretrained = model_zoo.load_url(model_urls['drn-c-58'])
+        del pretrained['fc.weight']
+        del pretrained['fc.bias']
+        model.load_state_dict(pretrained)
+    return model
+
+
+def drn_d_22(BatchNorm, pretrained=True):
+    model = DRN(BasicBlock, [1, 1, 2, 2, 2, 2, 1, 1], arch='D', BatchNorm=BatchNorm)
+    if pretrained:
+        pretrained = model_zoo.load_url(model_urls['drn-d-22'])
+        del pretrained['fc.weight']
+        del pretrained['fc.bias']
+        model.load_state_dict(pretrained)
+    return model
+
+
+def drn_d_24(BatchNorm, pretrained=True):
+    model = DRN(BasicBlock, [1, 1, 2, 2, 2, 2, 2, 2], arch='D', BatchNorm=BatchNorm)
+    if pretrained:
+        pretrained = model_zoo.load_url(model_urls['drn-d-24'])
+        del pretrained['fc.weight']
+        del pretrained['fc.bias']
+        model.load_state_dict(pretrained)
+    return model
+
+
+def drn_d_38(BatchNorm, pretrained=True):
+    model = DRN(BasicBlock, [1, 1, 3, 4, 6, 3, 1, 1], arch='D', BatchNorm=BatchNorm)
+    if pretrained:
+        pretrained = model_zoo.load_url(model_urls['drn-d-38'])
+        del pretrained['fc.weight']
+        del pretrained['fc.bias']
+        model.load_state_dict(pretrained)
+    return model
+
+
+def drn_d_40(BatchNorm, pretrained=True):
+    model = DRN(BasicBlock, [1, 1, 3, 4, 6, 3, 2, 2], arch='D', BatchNorm=BatchNorm)
+    if pretrained:
+        pretrained = model_zoo.load_url(model_urls['drn-d-40'])
+        del pretrained['fc.weight']
+        del pretrained['fc.bias']
+        model.load_state_dict(pretrained)
+    return model
+
+
+def drn_d_54(BatchNorm, pretrained=True):
+    model = DRN(Bottleneck, [1, 1, 3, 4, 6, 3, 1, 1], arch='D', BatchNorm=BatchNorm)
+    if pretrained:
+        pretrained = model_zoo.load_url(model_urls['drn-d-54'])
+        del pretrained['fc.weight']
+        del pretrained['fc.bias']
+        model.load_state_dict(pretrained)
+    return model
+
+
+def drn_d_105(BatchNorm, pretrained=True):
+    model = DRN(Bottleneck, [1, 1, 3, 4, 23, 3, 1, 1], arch='D', BatchNorm=BatchNorm)
+    if pretrained:
+        pretrained = model_zoo.load_url(model_urls['drn-d-105'])
+        del pretrained['fc.weight']
+        del pretrained['fc.bias']
+        model.load_state_dict(pretrained)
+    return model
+
+if __name__ == "__main__":
+    import torch
+    model = drn_a_50(BatchNorm=nn.BatchNorm2d, pretrained=True)
+    input = torch.rand(1, 3, 512, 512)
+    output, low_level_feat = model(input)
+    print(output.size())
+    print(low_level_feat.size())
```
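To sanity-check the DRN variant that `build_backbone` actually selects without hitting the model-zoo URLs, pass `pretrained=False`; `drn_d_54` returns both the stride-8 feature map and the stride-4 low-level features that a DeepLab decoder consumes. A sketch (the expected shapes follow from the stride and channel schedule above):

```python
# Sketch: random-weight shape check for drn_d_54 (pretrained=False skips
# the download). Assumes the vendored MBD directory is on sys.path.
import torch
import torch.nn as nn
from model.deep_lab_model.backbone.drn import drn_d_54

model = drn_d_54(BatchNorm=nn.BatchNorm2d, pretrained=False).eval()
x = torch.rand(1, 3, 512, 512)
with torch.no_grad():
    out, low_level_feat = model(x)
print(out.shape)             # expected: torch.Size([1, 512, 64, 64])   (stride 8)
print(low_level_feat.shape)  # expected: torch.Size([1, 256, 128, 128]) (stride 4)
```

One reviewer note on this file: `drn_d_24` and `drn_d_40` look up `model_urls['drn-d-24']` and `model_urls['drn-d-40']`, which are absent from the URL table, so those two variants raise `KeyError` when `pretrained=True`; neither is reachable from `build_backbone`.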