doctra 0.3.2__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff compares publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registries.
- doctra/__init__.py +4 -0
- doctra/cli/main.py +168 -0
- doctra/engines/image_restoration/__init__.py +10 -0
- doctra/engines/image_restoration/docres_engine.py +566 -0
- doctra/engines/vlm/service.py +0 -12
- doctra/parsers/enhanced_pdf_parser.py +370 -0
- doctra/parsers/structured_pdf_parser.py +11 -60
- doctra/parsers/table_chart_extractor.py +8 -44
- doctra/third_party/docres/data/MBD/MBD.py +110 -0
- doctra/third_party/docres/data/MBD/MBD_utils.py +291 -0
- doctra/third_party/docres/data/MBD/infer.py +151 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/aspp.py +95 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/__init__.py +13 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/drn.py +402 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/mobilenet.py +151 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/resnet.py +170 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/xception.py +288 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/decoder.py +59 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/deeplab.py +81 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/__init__.py +12 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/batchnorm.py +282 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/comm.py +129 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/replicate.py +88 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/unittest.py +29 -0
- doctra/third_party/docres/data/preprocess/crop_merge_image.py +142 -0
- doctra/third_party/docres/inference.py +370 -0
- doctra/third_party/docres/models/restormer_arch.py +308 -0
- doctra/third_party/docres/utils.py +464 -0
- doctra/ui/app.py +5 -32
- doctra/utils/progress.py +13 -98
- doctra/utils/structured_utils.py +45 -49
- doctra/version.py +1 -1
- {doctra-0.3.2.dist-info → doctra-0.4.0.dist-info}/METADATA +1 -1
- doctra-0.4.0.dist-info/RECORD +67 -0
- doctra-0.3.2.dist-info/RECORD +0 -44
- {doctra-0.3.2.dist-info → doctra-0.4.0.dist-info}/WHEEL +0 -0
- {doctra-0.3.2.dist-info → doctra-0.4.0.dist-info}/licenses/LICENSE +0 -0
- {doctra-0.3.2.dist-info → doctra-0.4.0.dist-info}/top_level.txt +0 -0
doctra/third_party/docres/data/preprocess/crop_merge_image.py
@@ -0,0 +1,142 @@
+import os
+
+import cv2
+import numpy as np
+# SIZE =256
+# BATCH_SIZE = 32
+# STRIDES = 256
+
+def split_img(img, size_x, size_y, strides):
+    max_y, max_x = img.shape[:2]
+    border_y = 0
+    if max_y % size_y != 0:
+        border_y = size_y - (max_y % size_y)
+        img = cv2.copyMakeBorder(img,border_y,0,0,0,cv2.BORDER_REPLICATE)
+        # img = cv2.copyMakeBorder(img, border_y, 0, 0, 0, cv2.BORDER_CONSTANT, value=[255,255,255])
+    border_x = 0
+    if max_x % size_x != 0:
+        border_x = size_x - (max_x % size_x)
+        # img = cv2.copyMakeBorder(img, 0, 0, border_x, 0, cv2.BORDER_CONSTANT, value=[255,255,255])
+        img = cv2.copyMakeBorder(img,0,0,border_x,0,cv2.BORDER_REPLICATE)
+    # h,w
+    max_y, max_x = img.shape[:2]
+    parts = []
+    curr_y = 0
+    x = 0
+    y = 0
+    # TODO: rewrite with generators.
+    while (curr_y + size_y) <= max_y:
+        curr_x = 0
+        while (curr_x + size_x) <= max_x:
+            parts.append(img[curr_y:curr_y + size_y, curr_x:curr_x + size_x])
+            curr_x += strides
+        y += 1
+        curr_y += strides
+    # parts is a list
+    # (windows_number_x*windows_number_y,SIZE,SIZE,3)
+    # print(max_y,max_x)
+    # print(y,x)
+    # print(np.array(parts).shape)
+    return parts, border_x, border_y, max_x, max_y
+
+
+def combine_imgs(border_x,border_y,imgs, max_y, max_x,size_x, size_y, strides):
+
+    # weighted_img
+
+    index = int(size_x / strides)
+    weight_img = np.ones(shape=(max_y,max_x))
+    weight_img[0:strides] = index
+    weight_img[-strides:] = index
+    weight_img[:,0:strides]=index
+    weight_img[:,-strides:]=index
+
+    # border cells
+    i = 0
+    for j in range(1,index+1):
+        # top-left
+        weight_img[0:strides,i:i+strides] = np.ones(shape=(strides,strides))*j
+        weight_img[i:i+strides,0:strides] = np.ones(shape=(strides,strides))*j
+        # top-right
+        weight_img[i:i+strides,-strides:] = np.ones(shape=(strides,strides))*j
+        if i == 0:
+            weight_img[0:strides,-strides:] = np.ones(shape=(strides,strides))*j
+        else:
+            weight_img[0:strides,-strides-i:-i] = np.ones(shape=(strides,strides))*j
+        # bottom-left
+        weight_img[-strides:,i:i+strides] = np.ones(shape=(strides,strides))*j
+        if i == 0:
+            weight_img[-strides:,0:strides] = np.ones(shape=(strides,strides))*j
+        else:
+            weight_img[-strides-i:-i:,0:strides] = np.ones(shape=(strides,strides))*j
+        # bottom-right
+        if i == 0:
+            weight_img[-strides:,-strides:] = np.ones(shape=(strides,strides))*j
+        else:
+            weight_img[-strides-i:-i,-strides:] = np.ones(shape=(strides,strides))*j
+            weight_img[-strides:,-strides-i:-i] = np.ones(shape=(strides,strides))*j
+
+
+        i += strides
+
+    for i in range(strides,max_y-strides,strides):
+        for j in range(strides,max_x-strides,strides):
+            weight_img[i:i+strides,j:j+strides] = np.ones(shape=(strides,strides))*weight_img[i][0]*weight_img[0][j]
+
+
+    if len(imgs[0].shape)==2:
+        new_img = np.zeros(shape=(max_y,max_x))
+        weight_img = (1 / weight_img)
+    else:
+        new_img = np.zeros(shape=(max_y,max_x,imgs[0].shape[-1]))
+        weight_img = (1 / weight_img).reshape((max_y,max_x,1))
+        weight_img = np.tile(weight_img,(1,1,imgs[0].shape[-1]))
+
+    curr_y = 0
+    x = 0
+    y = 0
+    i = 0
+    # TODO: rewrite with generators.
+    while (curr_y + size_y) <= max_y:
+        curr_x = 0
+        while (curr_x + size_x) <= max_x:
+            new_img[curr_y:curr_y + size_y, curr_x:curr_x + size_x] += weight_img[curr_y:curr_y + size_y, curr_x:curr_x + size_x]*imgs[i]
+            i += 1
+            curr_x += strides
+        y += 1
+        curr_y += strides
+
+
+    new_img = new_img[border_y:, border_x:]
+    # print(border_y,border_x)
+
+    return new_img
+
+
+def stride_integral(img,stride=32):
+    h,w = img.shape[:2]
+
+    if (h%stride)!=0:
+        padding_h = stride - (h%stride)
+        img = cv2.copyMakeBorder(img,padding_h,0,0,0,borderType=cv2.BORDER_REPLICATE)
+    else:
+        padding_h = 0
+
+    if (w%stride)!=0:
+        padding_w = stride - (w%stride)
+        img = cv2.copyMakeBorder(img,0,0,padding_w,0,borderType=cv2.BORDER_REPLICATE)
+    else:
+        padding_w = 0
+
+    return img,padding_h,padding_w
+
+
+def mkdir_s(path: str):
+    """Create directory in specified path, if not exists."""
+    if not os.path.exists(path):
+        os.makedirs(path)
+
+
+if __name__ =='__main__':
+    parts, border_x, border_y, max_x, max_y = split_img(im,512,512,strides=512)
+    result = combine_imgs(border_x,border_y,parts, max_y, max_x,512, 512, 512)
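The split/merge helpers above drive DocRes's tiled processing. A minimal usage sketch follows (illustrative, not part of the wheel); the import path is assumed from the vendored location listed above, and the 512-pixel tile size mirrors the module's own `__main__` block.

# Usage sketch (illustrative, not part of the wheel): split a page image into
# non-overlapping 512x512 tiles and merge them back with the weighted combiner.
# The import path is an assumption based on the vendored location above.
import cv2
from doctra.third_party.docres.data.preprocess.crop_merge_image import split_img, combine_imgs

img = cv2.imread("page.png")  # any BGR document image

# split_img replicate-pads the image so it divides evenly into tiles
parts, border_x, border_y, max_x, max_y = split_img(img, 512, 512, strides=512)

# a restoration model would normally transform each tile here; passing the tiles
# through unchanged makes the round trip reproduce the original image
merged = combine_imgs(border_x, border_y, parts, max_y, max_x, 512, 512, 512)

cv2.imwrite("page_roundtrip.png", merged.astype("uint8"))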
doctra/third_party/docres/inference.py
@@ -0,0 +1,370 @@
+import os
+import cv2
+import glob
+from pathlib import Path
+import utils
+import argparse
+import numpy as np
+
+import torch
+
+from utils import convert_state_dict
+from models import restormer_arch
+from data.preprocess.crop_merge_image import stride_integral
+
+os.sys.path.append('./data/MBD/')
+from data.MBD.infer import net1_net2_infer_single_im
+
+
+def dewarp_prompt(img):
+    mask = net1_net2_infer_single_im(img,'data/MBD/checkpoint/mbd.pkl')
+    base_coord = utils.getBasecoord(256,256)/256
+    img[mask==0]=0
+    mask = cv2.resize(mask,(256,256))/255
+    return img,np.concatenate((base_coord,np.expand_dims(mask,-1)),-1)
+
+def deshadow_prompt(img):
+    h,w = img.shape[:2]
+    # img = cv2.resize(img,(128,128))
+    img = cv2.resize(img,(1024,1024))
+    rgb_planes = cv2.split(img)
+    result_planes = []
+    result_norm_planes = []
+    bg_imgs = []
+    for plane in rgb_planes:
+        dilated_img = cv2.dilate(plane, np.ones((7,7), np.uint8))
+        bg_img = cv2.medianBlur(dilated_img, 21)
+        bg_imgs.append(bg_img)
+        diff_img = 255 - cv2.absdiff(plane, bg_img)
+        norm_img = cv2.normalize(diff_img,None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_8UC1)
+        result_planes.append(diff_img)
+        result_norm_planes.append(norm_img)
+    bg_imgs = cv2.merge(bg_imgs)
+    bg_imgs = cv2.resize(bg_imgs,(w,h))
+    # result = cv2.merge(result_planes)
+    result_norm = cv2.merge(result_norm_planes)
+    result_norm[result_norm==0]=1
+    shadow_map = np.clip(img.astype(float)/result_norm.astype(float)*255,0,255).astype(np.uint8)
+    shadow_map = cv2.resize(shadow_map,(w,h))
+    shadow_map = cv2.cvtColor(shadow_map,cv2.COLOR_BGR2GRAY)
+    shadow_map = cv2.cvtColor(shadow_map,cv2.COLOR_GRAY2BGR)
+    # return shadow_map
+    return bg_imgs
+
+def deblur_prompt(img):
+    x = cv2.Sobel(img,cv2.CV_16S,1,0)
+    y = cv2.Sobel(img,cv2.CV_16S,0,1)
+    absX = cv2.convertScaleAbs(x)   # convert back to uint8
+    absY = cv2.convertScaleAbs(y)
+    high_frequency = cv2.addWeighted(absX,0.5,absY,0.5,0)
+    high_frequency = cv2.cvtColor(high_frequency,cv2.COLOR_BGR2GRAY)
+    high_frequency = cv2.cvtColor(high_frequency,cv2.COLOR_GRAY2BGR)
+    return high_frequency
+
+def appearance_prompt(img):
+    h,w = img.shape[:2]
+    # img = cv2.resize(img,(128,128))
+    img = cv2.resize(img,(1024,1024))
+    rgb_planes = cv2.split(img)
+    result_planes = []
+    result_norm_planes = []
+    for plane in rgb_planes:
+        dilated_img = cv2.dilate(plane, np.ones((7,7), np.uint8))
+        bg_img = cv2.medianBlur(dilated_img, 21)
+        diff_img = 255 - cv2.absdiff(plane, bg_img)
+        norm_img = cv2.normalize(diff_img,None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_8UC1)
+        result_planes.append(diff_img)
+        result_norm_planes.append(norm_img)
+    result_norm = cv2.merge(result_norm_planes)
+    result_norm = cv2.resize(result_norm,(w,h))
+    return result_norm
+
+def binarization_promptv2(img):
+    result,thresh = utils.SauvolaModBinarization(img)
+    thresh = thresh.astype(np.uint8)
+    result[result>155]=255
+    result[result<=155]=0
+
+    x = cv2.Sobel(img,cv2.CV_16S,1,0)
+    y = cv2.Sobel(img,cv2.CV_16S,0,1)
+    absX = cv2.convertScaleAbs(x)   # convert back to uint8
+    absY = cv2.convertScaleAbs(y)
+    high_frequency = cv2.addWeighted(absX,0.5,absY,0.5,0)
+    high_frequency = cv2.cvtColor(high_frequency,cv2.COLOR_BGR2GRAY)
+    return np.concatenate((np.expand_dims(thresh,-1),np.expand_dims(high_frequency,-1),np.expand_dims(result,-1)),-1)
+
+def dewarping(model,im_path):
+    INPUT_SIZE=256
+    im_org = cv2.imread(im_path)
+    im_masked, prompt_org = dewarp_prompt(im_org.copy())
+
+    h,w = im_masked.shape[:2]
+    im_masked = im_masked.copy()
+    im_masked = cv2.resize(im_masked,(INPUT_SIZE,INPUT_SIZE))
+    im_masked = im_masked / 255.0
+    im_masked = torch.from_numpy(im_masked.transpose(2,0,1)).unsqueeze(0)
+    im_masked = im_masked.float().to(DEVICE)
+
+    prompt = torch.from_numpy(prompt_org.transpose(2,0,1)).unsqueeze(0)
+    prompt = prompt.float().to(DEVICE)
+
+    in_im = torch.cat((im_masked,prompt),dim=1)
+
+    # inference
+    base_coord = utils.getBasecoord(INPUT_SIZE,INPUT_SIZE)/INPUT_SIZE
+    model = model.float()
+    with torch.no_grad():
+        pred = model(in_im)
+        pred = pred[0][:2].permute(1,2,0).cpu().numpy()
+        pred = pred+base_coord
+    ## smooth
+    for i in range(15):
+        pred = cv2.blur(pred,(3,3),borderType=cv2.BORDER_REPLICATE)
+    pred = cv2.resize(pred,(w,h))*(w,h)
+    pred = pred.astype(np.float32)
+    out_im = cv2.remap(im_org,pred[:,:,0],pred[:,:,1],cv2.INTER_LINEAR)
+
+    prompt_org = (prompt_org*255).astype(np.uint8)
+    prompt_org = cv2.resize(prompt_org,im_org.shape[:2][::-1])
+
+    return prompt_org[:,:,0],prompt_org[:,:,1],prompt_org[:,:,2],out_im
+
+def appearance(model,im_path):
+    MAX_SIZE=1600
+    # obtain im and prompt
+    im_org = cv2.imread(im_path)
+    h,w = im_org.shape[:2]
+    prompt = appearance_prompt(im_org)
+    in_im = np.concatenate((im_org,prompt),-1)
+
+    # constrain the max resolution
+    if max(w,h) < MAX_SIZE:
+        in_im,padding_h,padding_w = stride_integral(in_im,8)
+    else:
+        in_im = cv2.resize(in_im,(MAX_SIZE,MAX_SIZE))
+
+    # normalize
+    in_im = in_im / 255.0
+    in_im = torch.from_numpy(in_im.transpose(2,0,1)).unsqueeze(0)
+
+    # inference
+    in_im = in_im.half().to(DEVICE)
+    model = model.half()
+    with torch.no_grad():
+        pred = model(in_im)
+        pred = torch.clamp(pred,0,1)
+    pred = pred[0].permute(1,2,0).cpu().numpy()
+    pred = (pred*255).astype(np.uint8)
+
+    if max(w,h) < MAX_SIZE:
+        out_im = pred[padding_h:,padding_w:]
+    else:
+        pred[pred==0] = 1
+        shadow_map = cv2.resize(im_org,(MAX_SIZE,MAX_SIZE)).astype(float)/pred.astype(float)
+        shadow_map = cv2.resize(shadow_map,(w,h))
+        shadow_map[shadow_map==0]=0.00001
+        out_im = np.clip(im_org.astype(float)/shadow_map,0,255).astype(np.uint8)
+
+    return prompt[:,:,0],prompt[:,:,1],prompt[:,:,2],out_im
+
+
+def deshadowing(model,im_path):
+    MAX_SIZE=1600
+    # obtain im and prompt
+    im_org = cv2.imread(im_path)
+    h,w = im_org.shape[:2]
+    prompt = deshadow_prompt(im_org)
+    in_im = np.concatenate((im_org,prompt),-1)
+
+    # constrain the max resolution
+    if max(w,h) < MAX_SIZE:
+        in_im,padding_h,padding_w = stride_integral(in_im,8)
+    else:
+        in_im = cv2.resize(in_im,(MAX_SIZE,MAX_SIZE))
+
+    # normalize
+    in_im = in_im / 255.0
+    in_im = torch.from_numpy(in_im.transpose(2,0,1)).unsqueeze(0)
+
+    # inference
+    in_im = in_im.half().to(DEVICE)
+    model = model.half()
+    with torch.no_grad():
+        pred = model(in_im)
+        pred = torch.clamp(pred,0,1)
+    pred = pred[0].permute(1,2,0).cpu().numpy()
+    pred = (pred*255).astype(np.uint8)
+
+    if max(w,h) < MAX_SIZE:
+        out_im = pred[padding_h:,padding_w:]
+    else:
+        pred[pred==0]=1
+        shadow_map = cv2.resize(im_org,(MAX_SIZE,MAX_SIZE)).astype(float)/pred.astype(float)
+        shadow_map = cv2.resize(shadow_map,(w,h))
+        shadow_map[shadow_map==0]=0.00001
+        out_im = np.clip(im_org.astype(float)/shadow_map,0,255).astype(np.uint8)
+
+    return prompt[:,:,0],prompt[:,:,1],prompt[:,:,2],out_im
+
+
+def deblurring(model,im_path):
+    # setup image
+    im_org = cv2.imread(im_path)
+    in_im,padding_h,padding_w = stride_integral(im_org,8)
+    prompt = deblur_prompt(in_im)
+    in_im = np.concatenate((in_im,prompt),-1)
+    in_im = in_im / 255.0
+    in_im = torch.from_numpy(in_im.transpose(2,0,1)).unsqueeze(0)
+    in_im = in_im.half().to(DEVICE)
+    # inference
+    model.to(DEVICE)
+    model.eval()
+    model = model.half()
+    with torch.no_grad():
+        pred = model(in_im)
+        pred = torch.clamp(pred,0,1)
+        pred = pred[0].permute(1,2,0).cpu().numpy()
+        pred = (pred*255).astype(np.uint8)
+        out_im = pred[padding_h:,padding_w:]
+
+    return prompt[:,:,0],prompt[:,:,1],prompt[:,:,2],out_im
+
+
+
+def binarization(model,im_path):
+    im_org = cv2.imread(im_path)
+    im,padding_h,padding_w = stride_integral(im_org,8)
+    prompt = binarization_promptv2(im)
+    h,w = im.shape[:2]
+    in_im = np.concatenate((im,prompt),-1)
+
+    in_im = in_im / 255.0
+    in_im = torch.from_numpy(in_im.transpose(2,0,1)).unsqueeze(0)
+    in_im = in_im.to(DEVICE)
+    model = model.half()
+    in_im = in_im.half()
+    with torch.no_grad():
+        pred = model(in_im)
+        pred = pred[:,:2,:,:]
+        pred = torch.max(torch.softmax(pred,1),1)[1]
+        pred = pred[0].cpu().numpy()
+        pred = (pred*255).astype(np.uint8)
+        pred = cv2.resize(pred,(w,h))
+    out_im = pred[padding_h:,padding_w:]
+
+    return prompt[:,:,0],prompt[:,:,1],prompt[:,:,2],out_im
+
+def get_args():
+    parser = argparse.ArgumentParser(description='Params')
+    parser.add_argument('--model_path', nargs='?', type=str, default='./checkpoints/docres.pkl',help='Path of the saved checkpoint')
+    parser.add_argument('--im_path', nargs='?', type=str, default='./distorted/',
+                        help='Path of input document image')
+    parser.add_argument('--out_folder', nargs='?', type=str, default='./restorted/',
+                        help='Folder of the output images')
+    parser.add_argument('--task', nargs='?', type=str, default='dewarping',
+                        help='task that need to be executed')
+    parser.add_argument('--save_dtsprompt', nargs='?', type=int, default=0,
+                        help='Width of the input image')
+    args = parser.parse_args()
+    possible_tasks = ['dewarping','deshadowing','appearance','deblurring','binarization','end2end']
+    assert args.task in possible_tasks, 'Unsupported task, task must be one of '+', '.join(possible_tasks)
+    return args
+
+def model_init(args):
+    # prepare model
+    model = restormer_arch.Restormer(
+        inp_channels=6,
+        out_channels=3,
+        dim = 48,
+        num_blocks = [2,3,3,4],
+        num_refinement_blocks = 4,
+        heads = [1,2,4,8],
+        ffn_expansion_factor = 2.66,
+        bias = False,
+        LayerNorm_type = 'WithBias',
+        dual_pixel_task = True
+    )
+
+    if DEVICE.type == 'cpu':
+        state = convert_state_dict(torch.load(args.model_path, map_location='cpu')['model_state'])
+    else:
+        state = convert_state_dict(torch.load(args.model_path, map_location='cuda:0')['model_state'])
+    model.load_state_dict(state)
+
+    model.eval()
+    model = model.to(DEVICE)
+    return model
+
+def inference_one_im(model,im_path,task):
+    if task=='dewarping':
+        prompt1,prompt2,prompt3,restorted = dewarping(model,im_path)
+    elif task=='deshadowing':
+        prompt1,prompt2,prompt3,restorted = deshadowing(model,im_path)
+    elif task=='appearance':
+        prompt1,prompt2,prompt3,restorted = appearance(model,im_path)
+    elif task=='deblurring':
+        prompt1,prompt2,prompt3,restorted = deblurring(model,im_path)
+    elif task=='binarization':
+        prompt1,prompt2,prompt3,restorted = binarization(model,im_path)
+    elif task=='end2end':
+        prompt1,prompt2,prompt3,restorted = dewarping(model,im_path)
+        cv2.imwrite('restorted/step1.jpg',restorted)
+        prompt1,prompt2,prompt3,restorted = deshadowing(model,'restorted/step1.jpg')
+        cv2.imwrite('restorted/step2.jpg',restorted)
+        prompt1,prompt2,prompt3,restorted = appearance(model,'restorted/step2.jpg')
+        # os.remove('restorted/step1.jpg')
+        # os.remove('restorted/step2.jpg')
+
+    return prompt1,prompt2,prompt3,restorted
+
+
+def save_results(
+    img_path: str,
+    out_folder: str,
+    task: str,
+    save_dtsprompt: bool,
+):
+    im_name = os.path.split(img_path)[-1]
+    im_format = '.'+im_name.split('.')[-1]
+    save_path = os.path.join(out_folder, im_name.replace(im_format, '_' + task + im_format))
+    cv2.imwrite(save_path, restorted)
+    if save_dtsprompt:
+        cv2.imwrite(save_path.replace(im_format, '_prompt1' + im_format), prompt1)
+        cv2.imwrite(save_path.replace(im_format, '_prompt2' + im_format), prompt2)
+        cv2.imwrite(save_path.replace(im_format, '_prompt3' + im_format), prompt3)
+
+
+if __name__ == '__main__':
+
+    ## model init
+    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    args = get_args()
+    model = model_init(args)
+
+    img_source = args.im_path
+
+    if Path(img_source).is_dir():
+        img_paths = glob.glob(os.path.join(img_source, '*'))
+        for img_path in img_paths:
+            ## inference
+            prompt1,prompt2,prompt3,restorted = inference_one_im(model,img_path,args.task)
+
+            ## results saving
+            save_results(
+                img_path=img_path,
+                out_folder=args.out_folder,
+                task=args.task,
+                save_dtsprompt=args.save_dtsprompt,
+            )
+
+    else:
+        ## inference
+        prompt1,prompt2,prompt3,restorted = inference_one_im(model,img_source,args.task)
+
+        ## results saving
+        save_results(
+            img_path=img_source,
+            out_folder=args.out_folder,
+            task=args.task,
+            save_dtsprompt=args.save_dtsprompt,
+        )
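The script above is written to be run from the vendored docres directory: it loads `./checkpoints/docres.pkl` and `data/MBD/checkpoint/mbd.pkl` by relative path and writes intermediate files under `restorted/`. A hedged invocation sketch follows; the working directory and image path are assumptions, while the flags and task names come from `get_args()` above.

# Invocation sketch (illustrative, not part of the wheel). Runs inference.py as a
# script so its relative checkpoint paths resolve; cwd and the input image are
# assumed, the flags and task names are taken from get_args() above.
import subprocess
import sys

subprocess.run(
    [
        sys.executable, "inference.py",
        "--model_path", "./checkpoints/docres.pkl",
        "--im_path", "./distorted/page.jpg",
        "--out_folder", "./restorted/",
        "--task", "appearance",       # or dewarping / deshadowing / deblurring / binarization / end2end
        "--save_dtsprompt", "1",      # also write the three DTS prompt channels
    ],
    cwd="doctra/third_party/docres",  # assumed location of the vendored code
    check=True,
)
# save_results() writes the restored image to <out_folder>/<name>_<task>.<ext>.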