doctra 0.3.3__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. doctra/__init__.py +4 -0
  2. doctra/cli/main.py +170 -9
  3. doctra/cli/utils.py +2 -3
  4. doctra/engines/image_restoration/__init__.py +10 -0
  5. doctra/engines/image_restoration/docres_engine.py +561 -0
  6. doctra/engines/vlm/outlines_types.py +13 -9
  7. doctra/engines/vlm/service.py +4 -2
  8. doctra/exporters/excel_writer.py +89 -0
  9. doctra/parsers/enhanced_pdf_parser.py +374 -0
  10. doctra/parsers/structured_pdf_parser.py +6 -0
  11. doctra/parsers/table_chart_extractor.py +6 -0
  12. doctra/third_party/docres/data/MBD/MBD.py +110 -0
  13. doctra/third_party/docres/data/MBD/MBD_utils.py +291 -0
  14. doctra/third_party/docres/data/MBD/infer.py +151 -0
  15. doctra/third_party/docres/data/MBD/model/deep_lab_model/aspp.py +95 -0
  16. doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/__init__.py +13 -0
  17. doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/drn.py +402 -0
  18. doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/mobilenet.py +151 -0
  19. doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/resnet.py +170 -0
  20. doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/xception.py +288 -0
  21. doctra/third_party/docres/data/MBD/model/deep_lab_model/decoder.py +59 -0
  22. doctra/third_party/docres/data/MBD/model/deep_lab_model/deeplab.py +81 -0
  23. doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/__init__.py +12 -0
  24. doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/batchnorm.py +282 -0
  25. doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/comm.py +129 -0
  26. doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/replicate.py +88 -0
  27. doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/unittest.py +29 -0
  28. doctra/third_party/docres/data/preprocess/crop_merge_image.py +142 -0
  29. doctra/third_party/docres/inference.py +370 -0
  30. doctra/third_party/docres/models/restormer_arch.py +308 -0
  31. doctra/third_party/docres/utils.py +464 -0
  32. doctra/ui/app.py +8 -14
  33. doctra/utils/structured_utils.py +5 -2
  34. doctra/version.py +1 -1
  35. {doctra-0.3.3.dist-info → doctra-0.4.1.dist-info}/METADATA +1 -1
  36. doctra-0.4.1.dist-info/RECORD +67 -0
  37. doctra-0.3.3.dist-info/RECORD +0 -44
  38. {doctra-0.3.3.dist-info → doctra-0.4.1.dist-info}/WHEEL +0 -0
  39. {doctra-0.3.3.dist-info → doctra-0.4.1.dist-info}/licenses/LICENSE +0 -0
  40. {doctra-0.3.3.dist-info → doctra-0.4.1.dist-info}/top_level.txt +0 -0
doctra/third_party/docres/data/MBD/MBD_utils.py
@@ -0,0 +1,291 @@
+ import cv2
+ import numpy as np
+ import copy
+ import torch
+ import torch
+ import itertools
+ import torch.nn as nn
+ from torch.autograd import Function, Variable
+
+ def reorder(myPoints):
+     myPoints = myPoints.reshape((4, 2))
+     myPointsNew = np.zeros((4, 1, 2), dtype=np.int32)
+     add = myPoints.sum(1)
+     myPointsNew[0] = myPoints[np.argmin(add)]
+     myPointsNew[3] = myPoints[np.argmax(add)]
+     diff = np.diff(myPoints, axis=1)
+     myPointsNew[1] = myPoints[np.argmin(diff)]
+     myPointsNew[2] = myPoints[np.argmax(diff)]
+     return myPointsNew
+
+
+ def findMiddle(corners, mask, points=[0.25, 0.5, 0.75]):
+     num_middle_points = len(points)
+     top = [np.array([])] * num_middle_points
+     bottom = [np.array([])] * num_middle_points
+     left = [np.array([])] * num_middle_points
+     right = [np.array([])] * num_middle_points
+
+     center_top = []
+     center_bottom = []
+     center_left = []
+     center_right = []
+
+     center = (int((corners[0][0][1] + corners[3][0][1]) / 2), int((corners[0][0][0] + corners[3][0][0]) / 2))
+     for ratio in points:
+         center_top.append((center[0], int(corners[0][0][0] * (1 - ratio) + corners[1][0][0] * ratio)))
+         center_bottom.append((center[0], int(corners[2][0][0] * (1 - ratio) + corners[3][0][0] * ratio)))
+         center_left.append((int(corners[0][0][1] * (1 - ratio) + corners[2][0][1] * ratio), center[1]))
+         center_right.append((int(corners[1][0][1] * (1 - ratio) + corners[3][0][1] * ratio), center[1]))
+
+     for i in range(0, center[0], 1):
+         for j in range(num_middle_points):
+             if top[j].size == 0:
+                 if mask[i, center_top[j][1]] == 255:
+                     top[j] = np.asarray([center_top[j][1], i])
+                     top[j] = top[j].reshape(1, 2)
+
+     for i in range(mask.shape[0] - 1, center[0], -1):
+         for j in range(num_middle_points):
+             if bottom[j].size == 0:
+                 if mask[i, center_bottom[j][1]] == 255:
+                     bottom[j] = np.asarray([center_bottom[j][1], i])
+                     bottom[j] = bottom[j].reshape(1, 2)
+
+     for i in range(mask.shape[1] - 1, center[1], -1):
+         for j in range(num_middle_points):
+             if right[j].size == 0:
+                 if mask[center_right[j][0], i] == 255:
+                     right[j] = np.asarray([i, center_right[j][0]])
+                     right[j] = right[j].reshape(1, 2)
+
+     for i in range(0, center[1]):
+         for j in range(num_middle_points):
+             if left[j].size == 0:
+                 if mask[center_left[j][0], i] == 255:
+                     left[j] = np.asarray([i, center_left[j][0]])
+                     left[j] = left[j].reshape(1, 2)
+
+     return np.asarray(top + bottom + left + right)
+
+ def DP_algorithmv1(contours):
+     biggest = np.array([])
+     max_area = 0
+     step = 0.001
+     count = 0
+     # while biggest.size==0:
+     while True:
+         for i in contours:
+             # print(i.shape)
+             area = cv2.contourArea(i)
+             # print(area, cv2.arcLength(i, True))
+             if area > cv2.arcLength(i, True) * 10:
+                 peri = cv2.arcLength(i, True)
+                 approx = cv2.approxPolyDP(i, (0.01 + step * count) * peri, True)
+                 if area > max_area and len(approx) == 4:
+                     max_area = area
+                     biggest_contours = i
+                     biggest = approx
+                     break
+         if abs(max_area - cv2.contourArea(biggest)) / max_area > 0.3:
+             biggest = np.array([])
+         count += 1
+         if count > 200:
+             break
+     temp = biggest[0]
+     return biggest, max_area, biggest_contours
+
+ def DP_algorithm(contours):
+     biggest = np.array([])
+     max_area = 0
+     step = 0.001
+     count = 0
+
+     ### largest contour
+     for i in contours:
+         area = cv2.contourArea(i)
+         if area > max_area:
+             max_area = area
+             biggest_contours = i
+     peri = cv2.arcLength(biggest_contours, True)
+
+     ### find four corners
+     while True:
+         approx = cv2.approxPolyDP(biggest_contours, (0.01 + step * count) * peri, True)
+         if len(approx) == 4:
+             biggest = approx
+             break
+         # if abs(max_area - cv2.contourArea(biggest))/max_area > 0.2:
+         # if abs(max_area - cv2.contourArea(biggest))/max_area > 0.4:
+         #     biggest = np.array([])
+         count += 1
+         if count > 200:
+             break
+     return biggest, max_area, biggest_contours
+
+ def drawRectangle(img, biggest, color, thickness):
+     cv2.line(img, (biggest[0][0][0], biggest[0][0][1]), (biggest[1][0][0], biggest[1][0][1]), color, thickness)
+     cv2.line(img, (biggest[0][0][0], biggest[0][0][1]), (biggest[2][0][0], biggest[2][0][1]), color, thickness)
+     cv2.line(img, (biggest[3][0][0], biggest[3][0][1]), (biggest[2][0][0], biggest[2][0][1]), color, thickness)
+     cv2.line(img, (biggest[3][0][0], biggest[3][0][1]), (biggest[1][0][0], biggest[1][0][1]), color, thickness)
+     return img
+
+ def minAreaRect(contours, img):
+     # biggest = np.array([])
+     max_area = 0
+     for i in contours:
+         area = cv2.contourArea(i)
+         if area > max_area:
+             peri = cv2.arcLength(i, True)
+             rect = cv2.minAreaRect(i)
+             points = cv2.boxPoints(rect)
+             max_area = area
+     return points
+
+ def cropRectangle(img, biggest):
+     # print(biggest)
+     w = np.abs(biggest[0][0][0] - biggest[1][0][0])
+     h = np.abs(biggest[0][0][1] - biggest[2][0][1])
+     new_img = np.zeros((w, h, img.shape[-1]), dtype=np.uint8)
+     new_img = img[biggest[0][0][1]:biggest[0][0][1] + h, biggest[0][0][0]:biggest[0][0][0] + w]
+     return new_img
+
+ def cvimg2torch(img, min=0, max=1):
+     '''
+     input:
+         img -> ndarray uint8 HxWxC
+     return:
+         tensor -> torch.tensor BxCxHxW
+     '''
+     if len(img.shape) == 2:
+         img = np.expand_dims(img, axis=-1)
+     img = img.astype(float) / 255.0
+     img = img.transpose(2, 0, 1)  # HWC -> CHW
+     img = np.expand_dims(img, 0)  # CHW -> BCHW
+     img = torch.from_numpy(img).float()
+     return img
+
+ def torch2cvimg(tensor, min=0, max=1):
+     '''
+     input:
+         tensor -> torch.tensor BxCxHxW, C can be 1 or 3
+     return:
+         im -> list of ndarray uint8 HxWxC
+     '''
+     im_list = []
+     for i in range(tensor.shape[0]):
+         im = tensor.detach().cpu().data.numpy()[i]
+         im = im.transpose(1, 2, 0)
+         im = np.clip(im, min, max)
+         im = ((im - min) / (max - min) * 255).astype(np.uint8)
+         im_list.append(im)
+     return im_list
+
+
+ class TPSGridGen(nn.Module):
+     def __init__(self, target_height, target_width, target_control_points):
+         '''
+         target_control_points -> torch.tensor num_point x 2, values in -1~1
+         source_control_points -> torch.tensor batch_size x num_point x 2, values in -1~1
+         return:
+             grid -> batch_size x hw x 2, values in -1~1
+         '''
+         super(TPSGridGen, self).__init__()
+         assert target_control_points.ndimension() == 2
+         assert target_control_points.size(1) == 2
+         N = target_control_points.size(0)
+         self.num_points = N
+         target_control_points = target_control_points.float()
+
+         # create padded kernel matrix
+         forward_kernel = torch.zeros(N + 3, N + 3)
+         target_control_partial_repr = self.compute_partial_repr(target_control_points, target_control_points)
+         forward_kernel[:N, :N].copy_(target_control_partial_repr)
+         forward_kernel[:N, -3].fill_(1)
+         forward_kernel[-3, :N].fill_(1)
+         forward_kernel[:N, -2:].copy_(target_control_points)
+         forward_kernel[-2:, :N].copy_(target_control_points.transpose(0, 1))
+         # compute inverse matrix
+         inverse_kernel = torch.inverse(forward_kernel)
+
+         # create target coordinate matrix
+         HW = target_height * target_width
+         target_coordinate = list(itertools.product(range(target_height), range(target_width)))
+         target_coordinate = torch.Tensor(target_coordinate)  # HW x 2
+         Y, X = target_coordinate.split(1, dim=1)
+         Y = Y * 2 / (target_height - 1) - 1
+         X = X * 2 / (target_width - 1) - 1
+         target_coordinate = torch.cat([X, Y], dim=1)  # convert from (y, x) to (x, y)
+         target_coordinate_partial_repr = self.compute_partial_repr(target_coordinate.to(target_control_points.device), target_control_points)
+         target_coordinate_repr = torch.cat([
+             target_coordinate_partial_repr, torch.ones(HW, 1), target_coordinate
+         ], dim=1)
+
+         # register precomputed matrices
+         self.register_buffer('inverse_kernel', inverse_kernel)
+         self.register_buffer('padding_matrix', torch.zeros(3, 2))
+         self.register_buffer('target_coordinate_repr', target_coordinate_repr)
+
+     def forward(self, source_control_points):
+         assert source_control_points.ndimension() == 3
+         assert source_control_points.size(1) == self.num_points
+         assert source_control_points.size(2) == 2
+         batch_size = source_control_points.size(0)
+
+         Y = torch.cat([source_control_points, Variable(self.padding_matrix.expand(batch_size, 3, 2))], 1)
+         mapping_matrix = torch.matmul(Variable(self.inverse_kernel), Y)
+         source_coordinate = torch.matmul(Variable(self.target_coordinate_repr), mapping_matrix)
+         return source_coordinate
+
+     # phi(x1, x2) = r^2 * log(r), where r = ||x1 - x2||_2
+     def compute_partial_repr(self, input_points, control_points):
+         N = input_points.size(0)
+         M = control_points.size(0)
+         pairwise_diff = input_points.view(N, 1, 2) - control_points.view(1, M, 2)
+         # original implementation, very slow
+         # pairwise_dist = torch.sum(pairwise_diff ** 2, dim = 2) # square of distance
+         pairwise_diff_square = pairwise_diff * pairwise_diff
+         pairwise_dist = pairwise_diff_square[:, :, 0] + pairwise_diff_square[:, :, 1]
+         repr_matrix = 0.5 * pairwise_dist * torch.log(pairwise_dist)
+         # fix numerical error for 0 * log(0): substitute all NaN with 0
+         mask = repr_matrix != repr_matrix
+         repr_matrix.masked_fill_(mask, 0)
+         return repr_matrix
+
+
+ ### decide whether further processing is needed
+ # point_area = cv2.contourArea(np.concatenate((biggest_angle[0].reshape(1,1,2),middle[0:3],biggest_angle[1].reshape(1,1,2),middle[9:12],biggest_angle[3].reshape(1,1,2),middle[3:6][::-1],biggest_angle[2].reshape(1,1,2),middle[6:9][::-1]),axis=0))
+ #### minimum-area bounding rectangle
+ # rect = cv2.minAreaRect(contour)  # returns (center (x, y), (width, height), rotation angle)
+ # box = cv2.boxPoints(rect)  # cv2.boxPoints(rect) for OpenCV 3.x: the 4 corner points of the min-area rect
+ # box = np.int0(box)
+ # box = box.reshape((4,1,2))
+ # minrect_area = cv2.contourArea(box)
+ # print(abs(minrect_area-point_area)/point_area)
+ #### IoU of the four corner points
+ # biggest_box = np.concatenate((biggest_angle[0,:,:].reshape(1,1,2),biggest_angle[2,:,:].reshape(1,1,2),biggest_angle[3,:,:].reshape(1,1,2),biggest_angle[1,:,:].reshape(1,1,2)),axis=0)
+ # biggest_mask = np.zeros_like(mask)
+ # # corner_area = cv2.contourArea(biggest_box)
+ # cv2.drawContours(biggest_mask, [biggest_box], -1, color=255, thickness=-1)
+
+ # smooth = 1e-5
+ # biggest_mask_ = biggest_mask > 50
+ # mask_ = mask > 50
+ # intersection = (biggest_mask_ & mask_).sum()
+ # union = (biggest_mask_ | mask_).sum()
+ # iou = (intersection + smooth) / (union + smooth)
+ # if iou > 0.975:
+ #     skip = True
+ # else:
+ #     skip = False
+ # print(iou)
+ # cv2.imshow('mask', cv2.resize(mask, (512, 512)))
+ # cv2.imshow('biggest_mask', cv2.resize(biggest_mask, (512, 512)))
+ # cv2.waitKey(0)
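The TPSGridGen module above is the core of the mask-based dewarping: it precomputes the thin-plate-spline system for a fixed set of target control points, and its forward pass maps a batch of source control points to a dense sampling grid for torch.nn.functional.grid_sample. A minimal, hypothetical usage sketch (not part of the package; the four corner control points, sizes, and random image are made up for illustration):

    import torch
    import torch.nn.functional as F

    H, W = 256, 256
    # target control points in [-1, 1], one per image corner
    target_points = torch.tensor([[-1., -1.], [1., -1.], [-1., 1.], [1., 1.]])
    tps = TPSGridGen(H, W, target_points)

    # source control points: where each target corner samples from (slightly perturbed)
    source_points = torch.tensor([[[-0.9, -0.8], [0.95, -1.], [-1., 0.9], [0.85, 1.]]])  # B x N x 2
    grid = tps(source_points).view(1, H, W, 2)  # B x HW x 2 -> B x H x W x 2

    img = torch.rand(1, 3, H, W)  # stand-in for cvimg2torch(image)
    warped = F.grid_sample(img, grid, align_corners=True)

align_corners=True matches the (size - 1) normalization used when TPSGridGen builds its target coordinates.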
doctra/third_party/docres/data/MBD/infer.py
@@ -0,0 +1,151 @@
+ import torch
+ import argparse
+ import numpy as np
+ import torch.nn.functional as F
+ import glob
+ import cv2
+ from tqdm import tqdm
+
+ import time
+ import os
+ from model.deep_lab_model.deeplab import *
+ from MBD import mask_base_dewarper
+ import time
+
+ from utils import cvimg2torch, torch2cvimg
+
+
+ def net1_net2_infer(model, img_paths, args):
+     ### validate on the real datasets
+     seg_model = model
+     seg_model.eval()
+     for img_path in tqdm(img_paths):
+         if os.path.exists(img_path.replace('_origin', '_capture')):
+             continue
+         t1 = time.time()
+         ### segmentation mask prediction
+         img_org = cv2.imread(img_path)
+         h_org, w_org = img_org.shape[:2]
+         img = cv2.resize(img_org, (448, 448))
+         img = cv2.GaussianBlur(img, (15, 15), 0, 0)
+         img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+         img = cvimg2torch(img)
+
+         with torch.no_grad():
+             pred = seg_model(img.cuda())
+             mask_pred = pred[:, 0, :, :].unsqueeze(1)
+             mask_pred = F.interpolate(mask_pred, (h_org, w_org))
+             mask_pred = mask_pred.squeeze(0).squeeze(0).cpu().numpy()
+             mask_pred = (mask_pred * 255).astype(np.uint8)
+         kernel = np.ones((3, 3))
+         mask_pred = cv2.dilate(mask_pred, kernel, iterations=3)
+         mask_pred = cv2.erode(mask_pred, kernel, iterations=3)
+         mask_pred[mask_pred > 100] = 255
+         mask_pred[mask_pred < 100] = 0
+         ### TPS transform based on the mask
+         # dewarp, grid = mask_base_dewarper(img_org, mask_pred)
+         try:
+             dewarp, grid = mask_base_dewarper(img_org, mask_pred)
+         except:
+             print('fail')
+             grid = np.meshgrid(np.arange(w_org), np.arange(h_org)) / np.array([w_org, h_org]).reshape(2, 1, 1)
+             grid = torch.from_numpy((grid - 0.5) * 2).float().unsqueeze(0).permute(0, 2, 3, 1)
+             dewarp = torch2cvimg(F.grid_sample(cvimg2torch(img_org), grid))[0]
+             grid = grid[0].numpy()
+         # cv2.imshow('in', cv2.resize(img_org, (512, 512)))
+         # cv2.imshow('out', cv2.resize(dewarp, (512, 512)))
+         # cv2.waitKey(0)
+         cv2.imwrite(img_path.replace('_origin', '_capture'), dewarp)
+         cv2.imwrite(img_path.replace('_origin', '_mask_new'), mask_pred)
+
+         grid0 = cv2.resize(grid[:, :, 0], (128, 128))
+         grid1 = cv2.resize(grid[:, :, 1], (128, 128))
+         grid = np.stack((grid0, grid1), axis=-1)
+         np.save(img_path.replace('_origin', '_grid1'), grid)
+
+
+ def net1_net2_infer_single_im(img, model_path):
+     seg_model = DeepLab(num_classes=1,
+                         backbone='resnet',
+                         output_stride=16,
+                         sync_bn=None,
+                         freeze_bn=False)
+     seg_model = torch.nn.DataParallel(seg_model, device_ids=range(torch.cuda.device_count()))
+     seg_model.cuda()
+     checkpoint = torch.load(model_path)
+     seg_model.load_state_dict(checkpoint['model_state'])
+     ### validate on the real datasets
+     seg_model.eval()
+     ### segmentation mask prediction
+     img_org = img
+     h_org, w_org = img_org.shape[:2]
+     img = cv2.resize(img_org, (448, 448))
+     img = cv2.GaussianBlur(img, (15, 15), 0, 0)
+     img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+     img = cvimg2torch(img)
+
+     with torch.no_grad():
+         # from torchtoolbox.tools import summary
+         # print(summary(seg_model, torch.rand((1, 3, 448, 448)).cuda()))  # 59.4M 135.6G
+         pred = seg_model(img.cuda())
+         mask_pred = pred[:, 0, :, :].unsqueeze(1)
+         mask_pred = F.interpolate(mask_pred, (h_org, w_org))
+         mask_pred = mask_pred.squeeze(0).squeeze(0).cpu().numpy()
+         mask_pred = (mask_pred * 255).astype(np.uint8)
+     kernel = np.ones((3, 3))
+     mask_pred = cv2.dilate(mask_pred, kernel, iterations=3)
+     mask_pred = cv2.erode(mask_pred, kernel, iterations=3)
+     mask_pred[mask_pred > 100] = 255
+     mask_pred[mask_pred < 100] = 0
+     ### TPS transform based on the mask (disabled here; only the mask is returned)
+     # dewarp, grid = mask_base_dewarper(img_org, mask_pred)
+     # try:
+     #     dewarp, grid = mask_base_dewarper(img_org, mask_pred)
+     # except:
+     #     print('fail')
+     #     grid = np.meshgrid(np.arange(w_org), np.arange(h_org)) / np.array([w_org, h_org]).reshape(2, 1, 1)
+     #     grid = torch.from_numpy((grid - 0.5) * 2).float().unsqueeze(0).permute(0, 2, 3, 1)
+     #     dewarp = torch2cvimg(F.grid_sample(cvimg2torch(img_org), grid))[0]
+     #     grid = grid[0].numpy()
+     # cv2.imshow('in', cv2.resize(img_org, (512, 512)))
+     # cv2.imshow('out', cv2.resize(dewarp, (512, 512)))
+     # cv2.waitKey(0)
+     # cv2.imwrite(img_path.replace('_origin', '_capture'), dewarp)
+     # cv2.imwrite(img_path.replace('_origin', '_mask_new'), mask_pred)
+
+     # grid0 = cv2.resize(grid[:, :, 0], (128, 128))
+     # grid1 = cv2.resize(grid[:, :, 1], (128, 128))
+     # grid = np.stack((grid0, grid1), axis=-1)
+     # np.save(img_path.replace('_origin', '_grid1'), grid)
+     return mask_pred
+
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser(description='Hyperparams')
+     parser.add_argument('--img_folder', nargs='?', type=str, default='./all_data',
+                         help='Data path to load data')
+     parser.add_argument('--img_rows', nargs='?', type=int, default=448,
+                         help='Height of the input image')
+     parser.add_argument('--img_cols', nargs='?', type=int, default=448,
+                         help='Width of the input image')
+     parser.add_argument('--seg_model_path', nargs='?', type=str, default='checkpoints/mbd.pkl',
+                         help='Path to previously saved model to restart from')
+     args = parser.parse_args()
+
+     seg_model = DeepLab(num_classes=1,
+                         backbone='resnet',
+                         output_stride=16,
+                         sync_bn=None,
+                         freeze_bn=False)
+     seg_model = torch.nn.DataParallel(seg_model, device_ids=range(torch.cuda.device_count()))
+     seg_model.cuda()
+     checkpoint = torch.load(args.seg_model_path)
+     seg_model.load_state_dict(checkpoint['model_state'])
+
+     im_paths = glob.glob(os.path.join(args.img_folder, '*_origin.*'))
+
+     net1_net2_infer(seg_model, im_paths, args)
+
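Taken together, infer.py loads the DeepLab segmentation model, predicts a document mask per image, and dewarps it via mask_base_dewarper, falling back to an identity sampling grid when dewarping fails. Judging from the argparse defaults above, a standalone run would look like the following (the paths are just the script's defaults, not verified against any packaged checkpoint):

    python infer.py --img_folder ./all_data --seg_model_path checkpoints/mbd.pkl

Input images are discovered with the *_origin.* glob, and outputs are written alongside them with _capture, _mask_new, and _grid1 suffixes.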
doctra/third_party/docres/data/MBD/model/deep_lab_model/aspp.py
@@ -0,0 +1,95 @@
+ import math
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from model.deep_lab_model.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
+
+ class _ASPPModule(nn.Module):
+     def __init__(self, inplanes, planes, kernel_size, padding, dilation, BatchNorm):
+         super(_ASPPModule, self).__init__()
+         self.atrous_conv = nn.Conv2d(inplanes, planes, kernel_size=kernel_size,
+                                      stride=1, padding=padding, dilation=dilation, bias=False)
+         self.bn = BatchNorm(planes)
+         self.relu = nn.ReLU()
+
+         self._init_weight()
+
+     def forward(self, x):
+         x = self.atrous_conv(x)
+         x = self.bn(x)
+
+         return self.relu(x)
+
+     def _init_weight(self):
+         for m in self.modules():
+             if isinstance(m, nn.Conv2d):
+                 torch.nn.init.kaiming_normal_(m.weight)
+             elif isinstance(m, SynchronizedBatchNorm2d):
+                 m.weight.data.fill_(1)
+                 m.bias.data.zero_()
+             elif isinstance(m, nn.BatchNorm2d):
+                 m.weight.data.fill_(1)
+                 m.bias.data.zero_()
+
+ class ASPP(nn.Module):
+     def __init__(self, backbone, output_stride, BatchNorm):
+         super(ASPP, self).__init__()
+         if backbone == 'drn':
+             inplanes = 512
+         elif backbone == 'mobilenet':
+             inplanes = 320
+         else:
+             inplanes = 2048
+         if output_stride == 16:
+             dilations = [1, 6, 12, 18]
+         elif output_stride == 8:
+             dilations = [1, 12, 24, 36]
+         else:
+             raise NotImplementedError
+
+         self.aspp1 = _ASPPModule(inplanes, 256, 1, padding=0, dilation=dilations[0], BatchNorm=BatchNorm)
+         self.aspp2 = _ASPPModule(inplanes, 256, 3, padding=dilations[1], dilation=dilations[1], BatchNorm=BatchNorm)
+         self.aspp3 = _ASPPModule(inplanes, 256, 3, padding=dilations[2], dilation=dilations[2], BatchNorm=BatchNorm)
+         self.aspp4 = _ASPPModule(inplanes, 256, 3, padding=dilations[3], dilation=dilations[3], BatchNorm=BatchNorm)
+
+         self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
+                                              nn.Conv2d(inplanes, 256, 1, stride=1, bias=False),
+                                              BatchNorm(256),
+                                              nn.ReLU())
+         self.conv1 = nn.Conv2d(1280, 256, 1, bias=False)
+         self.bn1 = BatchNorm(256)
+         self.relu = nn.ReLU()
+         self.dropout = nn.Dropout(0.5)
+         self._init_weight()
+
+     def forward(self, x):
+         x1 = self.aspp1(x)
+         x2 = self.aspp2(x)
+         x3 = self.aspp3(x)
+         x4 = self.aspp4(x)
+         x5 = self.global_avg_pool(x)
+         x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)
+         x = torch.cat((x1, x2, x3, x4, x5), dim=1)
+
+         x = self.conv1(x)
+         x = self.bn1(x)
+         x = self.relu(x)
+
+         return self.dropout(x)
+
+     def _init_weight(self):
+         for m in self.modules():
+             if isinstance(m, nn.Conv2d):
+                 # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                 # m.weight.data.normal_(0, math.sqrt(2. / n))
+                 torch.nn.init.kaiming_normal_(m.weight)
+             elif isinstance(m, SynchronizedBatchNorm2d):
+                 m.weight.data.fill_(1)
+                 m.bias.data.zero_()
+             elif isinstance(m, nn.BatchNorm2d):
+                 m.weight.data.fill_(1)
+                 m.bias.data.zero_()
+
+
+ def build_aspp(backbone, output_stride, BatchNorm):
+     return ASPP(backbone, output_stride, BatchNorm)
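For orientation, ASPP fuses four parallel atrous convolutions (dilations 1/6/12/18 at output stride 16) with a global-average-pooling branch, then projects the 1280-channel concatenation back to 256 channels. A hypothetical smoke test (not part of the package; the 28x28 feature map assumes a 448x448 input through a stride-16 ResNet backbone):

    import torch
    import torch.nn as nn

    aspp = build_aspp(backbone='resnet', output_stride=16, BatchNorm=nn.BatchNorm2d)
    feat = torch.rand(1, 2048, 28, 28)  # ResNet features at output_stride 16
    out = aspp(feat)
    print(out.shape)  # torch.Size([1, 256, 28, 28])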
doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/__init__.py
@@ -0,0 +1,13 @@
+ from model.deep_lab_model.backbone import resnet, xception, drn, mobilenet
+
+ def build_backbone(backbone, output_stride, BatchNorm):
+     if backbone == 'resnet':
+         return resnet.ResNet101(output_stride, BatchNorm)
+     elif backbone == 'xception':
+         return xception.AlignedXception(output_stride, BatchNorm)
+     elif backbone == 'drn':
+         return drn.drn_d_54(BatchNorm)
+     elif backbone == 'mobilenet':
+         return mobilenet.MobileNetV2(output_stride, BatchNorm)
+     else:
+         raise NotImplementedError
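build_backbone is a plain dispatch over the four supported encoders; DeepLab selects one with its backbone string ('resnet' in infer.py above). A hypothetical sketch (module path per the vendored DocRes layout; depending on the vendored ResNet implementation, construction may attempt to download ImageNet weights):

    import torch.nn as nn
    from model.deep_lab_model.backbone import build_backbone

    encoder = build_backbone('resnet', output_stride=16, BatchNorm=nn.BatchNorm2d)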