doctra 0.3.3__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doctra/__init__.py +4 -0
- doctra/cli/main.py +170 -9
- doctra/cli/utils.py +2 -3
- doctra/engines/image_restoration/__init__.py +10 -0
- doctra/engines/image_restoration/docres_engine.py +561 -0
- doctra/engines/vlm/outlines_types.py +13 -9
- doctra/engines/vlm/service.py +4 -2
- doctra/exporters/excel_writer.py +89 -0
- doctra/parsers/enhanced_pdf_parser.py +374 -0
- doctra/parsers/structured_pdf_parser.py +6 -0
- doctra/parsers/table_chart_extractor.py +6 -0
- doctra/third_party/docres/data/MBD/MBD.py +110 -0
- doctra/third_party/docres/data/MBD/MBD_utils.py +291 -0
- doctra/third_party/docres/data/MBD/infer.py +151 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/aspp.py +95 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/__init__.py +13 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/drn.py +402 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/mobilenet.py +151 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/resnet.py +170 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/xception.py +288 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/decoder.py +59 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/deeplab.py +81 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/__init__.py +12 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/batchnorm.py +282 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/comm.py +129 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/replicate.py +88 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/unittest.py +29 -0
- doctra/third_party/docres/data/preprocess/crop_merge_image.py +142 -0
- doctra/third_party/docres/inference.py +370 -0
- doctra/third_party/docres/models/restormer_arch.py +308 -0
- doctra/third_party/docres/utils.py +464 -0
- doctra/ui/app.py +8 -14
- doctra/utils/structured_utils.py +5 -2
- doctra/version.py +1 -1
- {doctra-0.3.3.dist-info → doctra-0.4.1.dist-info}/METADATA +1 -1
- doctra-0.4.1.dist-info/RECORD +67 -0
- doctra-0.3.3.dist-info/RECORD +0 -44
- {doctra-0.3.3.dist-info → doctra-0.4.1.dist-info}/WHEEL +0 -0
- {doctra-0.3.3.dist-info → doctra-0.4.1.dist-info}/licenses/LICENSE +0 -0
- {doctra-0.3.3.dist-info → doctra-0.4.1.dist-info}/top_level.txt +0 -0
doctra/third_party/docres/data/MBD/MBD_utils.py
@@ -0,0 +1,291 @@
```python
import cv2
import numpy as np
import copy
import torch
import itertools
import torch.nn as nn
from torch.autograd import Function, Variable


def reorder(myPoints):
    # order 4 corner points as: top-left, top-right, bottom-left, bottom-right
    myPoints = myPoints.reshape((4, 2))
    myPointsNew = np.zeros((4, 1, 2), dtype=np.int32)
    add = myPoints.sum(1)
    myPointsNew[0] = myPoints[np.argmin(add)]
    myPointsNew[3] = myPoints[np.argmax(add)]
    diff = np.diff(myPoints, axis=1)
    myPointsNew[1] = myPoints[np.argmin(diff)]
    myPointsNew[2] = myPoints[np.argmax(diff)]
    return myPointsNew


def findMiddle(corners, mask, points=[0.25, 0.5, 0.75]):
    # sample intermediate boundary points along each edge of the document mask
    num_middle_points = len(points)
    top = [np.array([])] * num_middle_points
    bottom = [np.array([])] * num_middle_points
    left = [np.array([])] * num_middle_points
    right = [np.array([])] * num_middle_points

    center_top = []
    center_bottom = []
    center_left = []
    center_right = []

    center = (int((corners[0][0][1] + corners[3][0][1]) / 2), int((corners[0][0][0] + corners[3][0][0]) / 2))
    for ratio in points:
        center_top.append((center[0], int(corners[0][0][0] * (1 - ratio) + corners[1][0][0] * ratio)))
        center_bottom.append((center[0], int(corners[2][0][0] * (1 - ratio) + corners[3][0][0] * ratio)))
        center_left.append((int(corners[0][0][1] * (1 - ratio) + corners[2][0][1] * ratio), center[1]))
        center_right.append((int(corners[1][0][1] * (1 - ratio) + corners[3][0][1] * ratio), center[1]))

    # scan inward from each border until the mask foreground (255) is hit
    for i in range(0, center[0], 1):
        for j in range(num_middle_points):
            if top[j].size == 0:
                if mask[i, center_top[j][1]] == 255:
                    top[j] = np.asarray([center_top[j][1], i])
                    top[j] = top[j].reshape(1, 2)

    for i in range(mask.shape[0] - 1, center[0], -1):
        for j in range(num_middle_points):
            if bottom[j].size == 0:
                if mask[i, center_bottom[j][1]] == 255:
                    bottom[j] = np.asarray([center_bottom[j][1], i])
                    bottom[j] = bottom[j].reshape(1, 2)

    for i in range(mask.shape[1] - 1, center[1], -1):
        for j in range(num_middle_points):
            if right[j].size == 0:
                if mask[center_right[j][0], i] == 255:
                    right[j] = np.asarray([i, center_right[j][0]])
                    right[j] = right[j].reshape(1, 2)

    for i in range(0, center[1]):
        for j in range(num_middle_points):
            if left[j].size == 0:
                if mask[center_left[j][0], i] == 255:
                    left[j] = np.asarray([i, center_left[j][0]])
                    left[j] = left[j].reshape(1, 2)

    return np.asarray(top + bottom + left + right)


def DP_algorithmv1(contours):
    # Douglas-Peucker: relax the approximation tolerance until a quadrilateral is found
    biggest = np.array([])
    max_area = 0
    step = 0.001
    count = 0
    while True:
        for i in contours:
            area = cv2.contourArea(i)
            if area > cv2.arcLength(i, True) * 10:
                peri = cv2.arcLength(i, True)
                approx = cv2.approxPolyDP(i, (0.01 + step * count) * peri, True)
                if area > max_area and len(approx) == 4:
                    max_area = area
                    biggest_contours = i
                    biggest = approx
                    break
        # discard the quad if it loses more than 30% of the contour area
        if biggest.size > 0 and abs(max_area - cv2.contourArea(biggest)) / max_area > 0.3:
            biggest = np.array([])
        count += 1
        if count > 200:
            break
    return biggest, max_area, biggest_contours


def DP_algorithm(contours):
    biggest = np.array([])
    max_area = 0
    step = 0.001
    count = 0

    ### largest contour
    for i in contours:
        area = cv2.contourArea(i)
        if area > max_area:
            max_area = area
            biggest_contours = i
    peri = cv2.arcLength(biggest_contours, True)

    ### find four corners
    while True:
        approx = cv2.approxPolyDP(biggest_contours, (0.01 + step * count) * peri, True)
        if len(approx) == 4:
            biggest = approx
            break
        count += 1
        if count > 200:
            break
    return biggest, max_area, biggest_contours


def drawRectangle(img, biggest, color, thickness):
    cv2.line(img, (biggest[0][0][0], biggest[0][0][1]), (biggest[1][0][0], biggest[1][0][1]), color, thickness)
    cv2.line(img, (biggest[0][0][0], biggest[0][0][1]), (biggest[2][0][0], biggest[2][0][1]), color, thickness)
    cv2.line(img, (biggest[3][0][0], biggest[3][0][1]), (biggest[2][0][0], biggest[2][0][1]), color, thickness)
    cv2.line(img, (biggest[3][0][0], biggest[3][0][1]), (biggest[1][0][0], biggest[1][0][1]), color, thickness)
    return img


def minAreaRect(contours, img):
    # corner points of the minimum-area rectangle around the largest contour
    max_area = 0
    for i in contours:
        area = cv2.contourArea(i)
        if area > max_area:
            rect = cv2.minAreaRect(i)
            points = cv2.boxPoints(rect)
            max_area = area
    return points


def cropRectangle(img, biggest):
    w = np.abs(biggest[0][0][0] - biggest[1][0][0])
    h = np.abs(biggest[0][0][1] - biggest[2][0][1])
    new_img = img[biggest[0][0][1]:biggest[0][0][1] + h, biggest[0][0][0]:biggest[0][0][0] + w]
    return new_img


def cvimg2torch(img, min=0, max=1):
    '''
    input:
        img -> ndarray uint8 HxWxC
    return:
        tensor -> torch.tensor BxCxHxW
    '''
    if len(img.shape) == 2:
        img = np.expand_dims(img, axis=-1)
    img = img.astype(float) / 255.0
    img = img.transpose(2, 0, 1)  # HWC -> CHW
    img = np.expand_dims(img, 0)  # CHW -> 1CHW
    img = torch.from_numpy(img).float()
    return img


def torch2cvimg(tensor, min=0, max=1):
    '''
    input:
        tensor -> torch.tensor BxCxHxW, C can be 1 or 3
    return:
        im -> list of ndarray uint8 HxWxC
    '''
    im_list = []
    for i in range(tensor.shape[0]):
        im = tensor.detach().cpu().data.numpy()[i]
        im = im.transpose(1, 2, 0)
        im = np.clip(im, min, max)
        im = ((im - min) / (max - min) * 255).astype(np.uint8)
        im_list.append(im)
    return im_list


class TPSGridGen(nn.Module):

    def __init__(self, target_height, target_width, target_control_points):
        '''
        target_control_points -> torch.tensor num_points x 2, in -1~1
        source_control_points -> torch.tensor batch_size x num_points x 2, in -1~1
        return:
            grid -> batch_size x hw x 2, in -1~1
        '''
        super(TPSGridGen, self).__init__()
        assert target_control_points.ndimension() == 2
        assert target_control_points.size(1) == 2
        N = target_control_points.size(0)
        self.num_points = N
        target_control_points = target_control_points.float()

        # create padded kernel matrix
        forward_kernel = torch.zeros(N + 3, N + 3)
        target_control_partial_repr = self.compute_partial_repr(target_control_points, target_control_points)
        forward_kernel[:N, :N].copy_(target_control_partial_repr)
        forward_kernel[:N, -3].fill_(1)
        forward_kernel[-3, :N].fill_(1)
        forward_kernel[:N, -2:].copy_(target_control_points)
        forward_kernel[-2:, :N].copy_(target_control_points.transpose(0, 1))
        # compute inverse matrix
        inverse_kernel = torch.inverse(forward_kernel)

        # create target coordinate matrix
        HW = target_height * target_width
        target_coordinate = list(itertools.product(range(target_height), range(target_width)))
        target_coordinate = torch.Tensor(target_coordinate)  # HW x 2
        Y, X = target_coordinate.split(1, dim=1)
        Y = Y * 2 / (target_height - 1) - 1
        X = X * 2 / (target_width - 1) - 1
        target_coordinate = torch.cat([X, Y], dim=1)  # convert from (y, x) to (x, y)
        target_coordinate_partial_repr = self.compute_partial_repr(target_coordinate.to(target_control_points.device), target_control_points)
        target_coordinate_repr = torch.cat([
            target_coordinate_partial_repr, torch.ones(HW, 1), target_coordinate
        ], dim=1)

        # register precomputed matrices
        self.register_buffer('inverse_kernel', inverse_kernel)
        self.register_buffer('padding_matrix', torch.zeros(3, 2))
        self.register_buffer('target_coordinate_repr', target_coordinate_repr)

    def forward(self, source_control_points):
        assert source_control_points.ndimension() == 3
        assert source_control_points.size(1) == self.num_points
        assert source_control_points.size(2) == 2
        batch_size = source_control_points.size(0)

        Y = torch.cat([source_control_points, Variable(self.padding_matrix.expand(batch_size, 3, 2))], 1)
        mapping_matrix = torch.matmul(Variable(self.inverse_kernel), Y)
        source_coordinate = torch.matmul(Variable(self.target_coordinate_repr), mapping_matrix)
        return source_coordinate

    # phi(x1, x2) = r^2 * log(r), where r = ||x1 - x2||_2
    def compute_partial_repr(self, input_points, control_points):
        N = input_points.size(0)
        M = control_points.size(0)
        pairwise_diff = input_points.view(N, 1, 2) - control_points.view(1, M, 2)
        pairwise_diff_square = pairwise_diff * pairwise_diff
        pairwise_dist = pairwise_diff_square[:, :, 0] + pairwise_diff_square[:, :, 1]
        repr_matrix = 0.5 * pairwise_dist * torch.log(pairwise_dist)
        # fix numerical error for 0 * log(0): substitute all nan with 0
        mask = repr_matrix != repr_matrix
        repr_matrix.masked_fill_(mask, 0)
        return repr_matrix


### decide whether further processing is needed (upstream scratch, kept commented out)
# point_area = cv2.contourArea(np.concatenate((biggest_angle[0].reshape(1,1,2), middle[0:3],
#     biggest_angle[1].reshape(1,1,2), middle[9:12], biggest_angle[3].reshape(1,1,2),
#     middle[3:6][::-1], biggest_angle[2].reshape(1,1,2), middle[6:9][::-1]), axis=0))
#### minimum-area bounding rectangle
# rect = cv2.minAreaRect(contour)  # returns (center (x, y), (width, height), rotation angle)
# box = cv2.boxPoints(rect)        # the rect's 4 corner points (cv2.boxPoints for OpenCV 3.x)
# box = np.int0(box).reshape((4, 1, 2))
# minrect_area = cv2.contourArea(box)
#### IoU between the four-corner quad and the mask
# biggest_box = np.concatenate((biggest_angle[0,:,:].reshape(1,1,2), biggest_angle[2,:,:].reshape(1,1,2),
#     biggest_angle[3,:,:].reshape(1,1,2), biggest_angle[1,:,:].reshape(1,1,2)), axis=0)
# biggest_mask = np.zeros_like(mask)
# cv2.drawContours(biggest_mask, [biggest_box], -1, color=255, thickness=-1)
# smooth = 1e-5
# intersection = ((biggest_mask > 50) & (mask > 50)).sum()
# union = ((biggest_mask > 50) | (mask > 50)).sum()
# iou = (intersection + smooth) / (union + smooth)
# skip = iou > 0.975  # mask is already nearly a quad -> skip further dewarping
```
doctra/third_party/docres/data/MBD/infer.py
@@ -0,0 +1,151 @@
```python
import torch
import argparse
import numpy as np
import torch.nn.functional as F
import glob
import cv2
from tqdm import tqdm

import time
import os
from model.deep_lab_model.deeplab import *
from MBD import mask_base_dewarper

from utils import cvimg2torch, torch2cvimg


def net1_net2_infer(model, img_paths, args):
    ### validate on the real datasets
    seg_model = model
    seg_model.eval()
    for img_path in tqdm(img_paths):
        if os.path.exists(img_path.replace('_origin', '_capture')):
            continue
        t1 = time.time()
        ### segmentation mask prediction
        img_org = cv2.imread(img_path)
        h_org, w_org = img_org.shape[:2]
        img = cv2.resize(img_org, (448, 448))
        img = cv2.GaussianBlur(img, (15, 15), 0, 0)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cvimg2torch(img)

        with torch.no_grad():
            pred = seg_model(img.cuda())
            mask_pred = pred[:, 0, :, :].unsqueeze(1)
            mask_pred = F.interpolate(mask_pred, (h_org, w_org))
            mask_pred = mask_pred.squeeze(0).squeeze(0).cpu().numpy()
        mask_pred = (mask_pred * 255).astype(np.uint8)
        kernel = np.ones((3, 3))
        mask_pred = cv2.dilate(mask_pred, kernel, iterations=3)
        mask_pred = cv2.erode(mask_pred, kernel, iterations=3)
        mask_pred[mask_pred > 100] = 255
        mask_pred[mask_pred < 100] = 0

        ### TPS transform based on the mask
        try:
            dewarp, grid = mask_base_dewarper(img_org, mask_pred)
        except Exception:
            # dewarping failed: fall back to an identity sampling grid
            print('fail')
            grid = np.meshgrid(np.arange(w_org), np.arange(h_org)) / np.array([w_org, h_org]).reshape(2, 1, 1)
            grid = torch.from_numpy((grid - 0.5) * 2).float().unsqueeze(0).permute(0, 2, 3, 1)
            dewarp = torch2cvimg(F.grid_sample(cvimg2torch(img_org), grid))[0]
            grid = grid[0].numpy()
        cv2.imwrite(img_path.replace('_origin', '_capture'), dewarp)
        cv2.imwrite(img_path.replace('_origin', '_mask_new'), mask_pred)

        # store a 128x128 version of the sampling grid alongside the outputs
        grid0 = cv2.resize(grid[:, :, 0], (128, 128))
        grid1 = cv2.resize(grid[:, :, 1], (128, 128))
        grid = np.stack((grid0, grid1), axis=-1)
        np.save(img_path.replace('_origin', '_grid1'), grid)


def net1_net2_infer_single_im(img, model_path):
    seg_model = DeepLab(num_classes=1,
                        backbone='resnet',
                        output_stride=16,
                        sync_bn=None,
                        freeze_bn=False)
    seg_model = torch.nn.DataParallel(seg_model, device_ids=range(torch.cuda.device_count()))
    seg_model.cuda()
    checkpoint = torch.load(model_path)
    seg_model.load_state_dict(checkpoint['model_state'])
    seg_model.eval()
    ### segmentation mask prediction
    img_org = img
    h_org, w_org = img_org.shape[:2]
    img = cv2.resize(img_org, (448, 448))
    img = cv2.GaussianBlur(img, (15, 15), 0, 0)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cvimg2torch(img)

    with torch.no_grad():
        pred = seg_model(img.cuda())
        mask_pred = pred[:, 0, :, :].unsqueeze(1)
        mask_pred = F.interpolate(mask_pred, (h_org, w_org))
        mask_pred = mask_pred.squeeze(0).squeeze(0).cpu().numpy()
    mask_pred = (mask_pred * 255).astype(np.uint8)
    kernel = np.ones((3, 3))
    mask_pred = cv2.dilate(mask_pred, kernel, iterations=3)
    mask_pred = cv2.erode(mask_pred, kernel, iterations=3)
    mask_pred[mask_pred > 100] = 255
    mask_pred[mask_pred < 100] = 0
    return mask_pred


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Hyperparams')
    parser.add_argument('--img_folder', nargs='?', type=str, default='./all_data',
                        help='Data path to load data')
    parser.add_argument('--img_rows', nargs='?', type=int, default=448,
                        help='Height of the input image')
    parser.add_argument('--img_cols', nargs='?', type=int, default=448,
                        help='Width of the input image')
    parser.add_argument('--seg_model_path', nargs='?', type=str, default='checkpoints/mbd.pkl',
                        help='Path to a previously saved model to restart from')
    args = parser.parse_args()

    seg_model = DeepLab(num_classes=1,
                        backbone='resnet',
                        output_stride=16,
                        sync_bn=None,
                        freeze_bn=False)
    seg_model = torch.nn.DataParallel(seg_model, device_ids=range(torch.cuda.device_count()))
    seg_model.cuda()
    checkpoint = torch.load(args.seg_model_path)
    seg_model.load_state_dict(checkpoint['model_state'])

    im_paths = glob.glob(os.path.join(args.img_folder, '*_origin.*'))

    net1_net2_infer(seg_model, im_paths, args)
```
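
A hedged sketch of driving the single-image entry point end to end. It assumes a CUDA device, the `checkpoints/mbd.pkl` weights, placeholder file names, and that you run from the `MBD` directory; the only API it leans on beyond that is the `mask_base_dewarper(img, mask) -> (dewarp, grid)` shape visible in the code above.

```python
import cv2

from infer import net1_net2_infer_single_im  # this file
from MBD import mask_base_dewarper

img = cv2.imread('page_origin.jpg')                           # placeholder input path
mask = net1_net2_infer_single_im(img, 'checkpoints/mbd.pkl')  # 0/255 uint8 mask, same HxW as img
dewarped, grid = mask_base_dewarper(img, mask)                # TPS dewarp driven by the mask
cv2.imwrite('page_capture.jpg', dewarped)
```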
doctra/third_party/docres/data/MBD/model/deep_lab_model/aspp.py
@@ -0,0 +1,95 @@
```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from model.deep_lab_model.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d


class _ASPPModule(nn.Module):
    def __init__(self, inplanes, planes, kernel_size, padding, dilation, BatchNorm):
        super(_ASPPModule, self).__init__()
        self.atrous_conv = nn.Conv2d(inplanes, planes, kernel_size=kernel_size,
                                     stride=1, padding=padding, dilation=dilation, bias=False)
        self.bn = BatchNorm(planes)
        self.relu = nn.ReLU()

        self._init_weight()

    def forward(self, x):
        x = self.atrous_conv(x)
        x = self.bn(x)

        return self.relu(x)

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, SynchronizedBatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()


class ASPP(nn.Module):
    def __init__(self, backbone, output_stride, BatchNorm):
        super(ASPP, self).__init__()
        if backbone == 'drn':
            inplanes = 512
        elif backbone == 'mobilenet':
            inplanes = 320
        else:
            inplanes = 2048
        if output_stride == 16:
            dilations = [1, 6, 12, 18]
        elif output_stride == 8:
            dilations = [1, 12, 24, 36]
        else:
            raise NotImplementedError

        self.aspp1 = _ASPPModule(inplanes, 256, 1, padding=0, dilation=dilations[0], BatchNorm=BatchNorm)
        self.aspp2 = _ASPPModule(inplanes, 256, 3, padding=dilations[1], dilation=dilations[1], BatchNorm=BatchNorm)
        self.aspp3 = _ASPPModule(inplanes, 256, 3, padding=dilations[2], dilation=dilations[2], BatchNorm=BatchNorm)
        self.aspp4 = _ASPPModule(inplanes, 256, 3, padding=dilations[3], dilation=dilations[3], BatchNorm=BatchNorm)

        self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
                                             nn.Conv2d(inplanes, 256, 1, stride=1, bias=False),
                                             BatchNorm(256),
                                             nn.ReLU())
        self.conv1 = nn.Conv2d(1280, 256, 1, bias=False)
        self.bn1 = BatchNorm(256)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)
        self._init_weight()

    def forward(self, x):
        x1 = self.aspp1(x)
        x2 = self.aspp2(x)
        x3 = self.aspp3(x)
        x4 = self.aspp4(x)
        x5 = self.global_avg_pool(x)
        x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)
        x = torch.cat((x1, x2, x3, x4, x5), dim=1)

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        return self.dropout(x)

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                # m.weight.data.normal_(0, math.sqrt(2. / n))
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, SynchronizedBatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()


def build_aspp(backbone, output_stride, BatchNorm):
    return ASPP(backbone, output_stride, BatchNorm)
```
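
A quick shape check (my own snippet, assuming the vendored import path): with the ResNet branch, ASPP expects 2048 input channels and always emits 256, at the same spatial size as its input.

```python
import torch
import torch.nn as nn

from model.deep_lab_model.aspp import build_aspp  # assumed import path

aspp = build_aspp(backbone='resnet', output_stride=16, BatchNorm=nn.BatchNorm2d).eval()
feat = torch.randn(1, 2048, 28, 28)  # e.g. a 448x448 image at stride 16
with torch.no_grad():
    out = aspp(feat)
print(out.shape)  # torch.Size([1, 256, 28, 28])
```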
doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/__init__.py
@@ -0,0 +1,13 @@
```python
from model.deep_lab_model.backbone import resnet, xception, drn, mobilenet

def build_backbone(backbone, output_stride, BatchNorm):
    if backbone == 'resnet':
        return resnet.ResNet101(output_stride, BatchNorm)
    elif backbone == 'xception':
        return xception.AlignedXception(output_stride, BatchNorm)
    elif backbone == 'drn':
        return drn.drn_d_54(BatchNorm)
    elif backbone == 'mobilenet':
        return mobilenet.MobileNetV2(output_stride, BatchNorm)
    else:
        raise NotImplementedError
```
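
And a one-line usage sketch for the dispatcher. Hedged: depending on how the vendored backbone modules are configured, the constructors may attempt to download ImageNet-pretrained weights on first use.

```python
import torch.nn as nn

from model.deep_lab_model.backbone import build_backbone  # this package

# valid names: 'resnet', 'xception', 'drn', 'mobilenet'; anything else raises NotImplementedError
net = build_backbone('mobilenet', output_stride=16, BatchNorm=nn.BatchNorm2d)
```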