doctra 0.3.3__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doctra/__init__.py +4 -0
- doctra/cli/main.py +170 -9
- doctra/cli/utils.py +2 -3
- doctra/engines/image_restoration/__init__.py +10 -0
- doctra/engines/image_restoration/docres_engine.py +561 -0
- doctra/engines/vlm/outlines_types.py +13 -9
- doctra/engines/vlm/service.py +4 -2
- doctra/exporters/excel_writer.py +89 -0
- doctra/parsers/enhanced_pdf_parser.py +374 -0
- doctra/parsers/structured_pdf_parser.py +6 -0
- doctra/parsers/table_chart_extractor.py +6 -0
- doctra/third_party/docres/data/MBD/MBD.py +110 -0
- doctra/third_party/docres/data/MBD/MBD_utils.py +291 -0
- doctra/third_party/docres/data/MBD/infer.py +151 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/aspp.py +95 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/__init__.py +13 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/drn.py +402 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/mobilenet.py +151 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/resnet.py +170 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/xception.py +288 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/decoder.py +59 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/deeplab.py +81 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/__init__.py +12 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/batchnorm.py +282 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/comm.py +129 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/replicate.py +88 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/unittest.py +29 -0
- doctra/third_party/docres/data/preprocess/crop_merge_image.py +142 -0
- doctra/third_party/docres/inference.py +370 -0
- doctra/third_party/docres/models/restormer_arch.py +308 -0
- doctra/third_party/docres/utils.py +464 -0
- doctra/ui/app.py +8 -14
- doctra/utils/structured_utils.py +5 -2
- doctra/version.py +1 -1
- {doctra-0.3.3.dist-info → doctra-0.4.1.dist-info}/METADATA +1 -1
- doctra-0.4.1.dist-info/RECORD +67 -0
- doctra-0.3.3.dist-info/RECORD +0 -44
- {doctra-0.3.3.dist-info → doctra-0.4.1.dist-info}/WHEEL +0 -0
- {doctra-0.3.3.dist-info → doctra-0.4.1.dist-info}/licenses/LICENSE +0 -0
- {doctra-0.3.3.dist-info → doctra-0.4.1.dist-info}/top_level.txt +0 -0
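The headline change in 0.4.1 is document image restoration: a vendored copy of DocRes under doctra/third_party/docres/ (the MBD model that produces the document mask for dewarping, the Restormer backbone, and the inference script), a new doctra/engines/image_restoration/ package that wraps it, and a new doctra/parsers/enhanced_pdf_parser.py. The two largest vendored files are reproduced below. To check this file list against the published wheels yourself, a short script along the following lines works; it is only a sketch and assumes nothing about doctra beyond the package name on PyPI.

```python
# Sketch: download both wheels and list the files added in 0.4.1 (requires pip and network access).
import glob
import subprocess
import zipfile

for v in ("0.3.3", "0.4.1"):
    subprocess.run(
        ["pip", "download", f"doctra=={v}", "--no-deps", "--only-binary", ":all:", "-d", f"wheels/{v}"],
        check=True,
    )

old, new = (zipfile.ZipFile(glob.glob(f"wheels/{v}/*.whl")[0]) for v in ("0.3.3", "0.4.1"))
added = sorted(set(new.namelist()) - set(old.namelist()))
print(f"{len(added)} files added")  # the doctra/third_party/docres/ tree should dominate the list
for name in added[:10]:
    print(name)
```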
doctra/third_party/docres/inference.py
@@ -0,0 +1,370 @@
+import os
+import cv2
+import glob
+from pathlib import Path
+import utils
+import argparse
+import numpy as np
+
+import torch
+
+from utils import convert_state_dict
+from models import restormer_arch
+from data.preprocess.crop_merge_image import stride_integral
+
+os.sys.path.append('./data/MBD/')
+from data.MBD.infer import net1_net2_infer_single_im
+
+
+def dewarp_prompt(img):
+    mask = net1_net2_infer_single_im(img,'data/MBD/checkpoint/mbd.pkl')
+    base_coord = utils.getBasecoord(256,256)/256
+    img[mask==0]=0
+    mask = cv2.resize(mask,(256,256))/255
+    return img,np.concatenate((base_coord,np.expand_dims(mask,-1)),-1)
+
+def deshadow_prompt(img):
+    h,w = img.shape[:2]
+    # img = cv2.resize(img,(128,128))
+    img = cv2.resize(img,(1024,1024))
+    rgb_planes = cv2.split(img)
+    result_planes = []
+    result_norm_planes = []
+    bg_imgs = []
+    for plane in rgb_planes:
+        dilated_img = cv2.dilate(plane, np.ones((7,7), np.uint8))
+        bg_img = cv2.medianBlur(dilated_img, 21)
+        bg_imgs.append(bg_img)
+        diff_img = 255 - cv2.absdiff(plane, bg_img)
+        norm_img = cv2.normalize(diff_img,None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_8UC1)
+        result_planes.append(diff_img)
+        result_norm_planes.append(norm_img)
+    bg_imgs = cv2.merge(bg_imgs)
+    bg_imgs = cv2.resize(bg_imgs,(w,h))
+    # result = cv2.merge(result_planes)
+    result_norm = cv2.merge(result_norm_planes)
+    result_norm[result_norm==0]=1
+    shadow_map = np.clip(img.astype(float)/result_norm.astype(float)*255,0,255).astype(np.uint8)
+    shadow_map = cv2.resize(shadow_map,(w,h))
+    shadow_map = cv2.cvtColor(shadow_map,cv2.COLOR_BGR2GRAY)
+    shadow_map = cv2.cvtColor(shadow_map,cv2.COLOR_GRAY2BGR)
+    # return shadow_map
+    return bg_imgs
+
+def deblur_prompt(img):
+    x = cv2.Sobel(img,cv2.CV_16S,1,0)
+    y = cv2.Sobel(img,cv2.CV_16S,0,1)
+    absX = cv2.convertScaleAbs(x)   # convert back to uint8
+    absY = cv2.convertScaleAbs(y)
+    high_frequency = cv2.addWeighted(absX,0.5,absY,0.5,0)
+    high_frequency = cv2.cvtColor(high_frequency,cv2.COLOR_BGR2GRAY)
+    high_frequency = cv2.cvtColor(high_frequency,cv2.COLOR_GRAY2BGR)
+    return high_frequency
+
+def appearance_prompt(img):
+    h,w = img.shape[:2]
+    # img = cv2.resize(img,(128,128))
+    img = cv2.resize(img,(1024,1024))
+    rgb_planes = cv2.split(img)
+    result_planes = []
+    result_norm_planes = []
+    for plane in rgb_planes:
+        dilated_img = cv2.dilate(plane, np.ones((7,7), np.uint8))
+        bg_img = cv2.medianBlur(dilated_img, 21)
+        diff_img = 255 - cv2.absdiff(plane, bg_img)
+        norm_img = cv2.normalize(diff_img,None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_8UC1)
+        result_planes.append(diff_img)
+        result_norm_planes.append(norm_img)
+    result_norm = cv2.merge(result_norm_planes)
+    result_norm = cv2.resize(result_norm,(w,h))
+    return result_norm
+
+def binarization_promptv2(img):
+    result,thresh = utils.SauvolaModBinarization(img)
+    thresh = thresh.astype(np.uint8)
+    result[result>155]=255
+    result[result<=155]=0
+
+    x = cv2.Sobel(img,cv2.CV_16S,1,0)
+    y = cv2.Sobel(img,cv2.CV_16S,0,1)
+    absX = cv2.convertScaleAbs(x)   # convert back to uint8
+    absY = cv2.convertScaleAbs(y)
+    high_frequency = cv2.addWeighted(absX,0.5,absY,0.5,0)
+    high_frequency = cv2.cvtColor(high_frequency,cv2.COLOR_BGR2GRAY)
+    return np.concatenate((np.expand_dims(thresh,-1),np.expand_dims(high_frequency,-1),np.expand_dims(result,-1)),-1)
+
+def dewarping(model,im_path):
+    INPUT_SIZE=256
+    im_org = cv2.imread(im_path)
+    im_masked, prompt_org = dewarp_prompt(im_org.copy())
+
+    h,w = im_masked.shape[:2]
+    im_masked = im_masked.copy()
+    im_masked = cv2.resize(im_masked,(INPUT_SIZE,INPUT_SIZE))
+    im_masked = im_masked / 255.0
+    im_masked = torch.from_numpy(im_masked.transpose(2,0,1)).unsqueeze(0)
+    im_masked = im_masked.float().to(DEVICE)
+
+    prompt = torch.from_numpy(prompt_org.transpose(2,0,1)).unsqueeze(0)
+    prompt = prompt.float().to(DEVICE)
+
+    in_im = torch.cat((im_masked,prompt),dim=1)
+
+    # inference
+    base_coord = utils.getBasecoord(INPUT_SIZE,INPUT_SIZE)/INPUT_SIZE
+    model = model.float()
+    with torch.no_grad():
+        pred = model(in_im)
+        pred = pred[0][:2].permute(1,2,0).cpu().numpy()
+        pred = pred+base_coord
+        ## smooth
+        for i in range(15):
+            pred = cv2.blur(pred,(3,3),borderType=cv2.BORDER_REPLICATE)
+        pred = cv2.resize(pred,(w,h))*(w,h)
+        pred = pred.astype(np.float32)
+        out_im = cv2.remap(im_org,pred[:,:,0],pred[:,:,1],cv2.INTER_LINEAR)
+
+    prompt_org = (prompt_org*255).astype(np.uint8)
+    prompt_org = cv2.resize(prompt_org,im_org.shape[:2][::-1])
+
+    return prompt_org[:,:,0],prompt_org[:,:,1],prompt_org[:,:,2],out_im
+
+def appearance(model,im_path):
+    MAX_SIZE=1600
+    # obtain im and prompt
+    im_org = cv2.imread(im_path)
+    h,w = im_org.shape[:2]
+    prompt = appearance_prompt(im_org)
+    in_im = np.concatenate((im_org,prompt),-1)
+
+    # constrain the max resolution
+    if max(w,h) < MAX_SIZE:
+        in_im,padding_h,padding_w = stride_integral(in_im,8)
+    else:
+        in_im = cv2.resize(in_im,(MAX_SIZE,MAX_SIZE))
+
+    # normalize
+    in_im = in_im / 255.0
+    in_im = torch.from_numpy(in_im.transpose(2,0,1)).unsqueeze(0)
+
+    # inference
+    in_im = in_im.half().to(DEVICE)
+    model = model.half()
+    with torch.no_grad():
+        pred = model(in_im)
+        pred = torch.clamp(pred,0,1)
+    pred = pred[0].permute(1,2,0).cpu().numpy()
+    pred = (pred*255).astype(np.uint8)
+
+    if max(w,h) < MAX_SIZE:
+        out_im = pred[padding_h:,padding_w:]
+    else:
+        pred[pred==0] = 1
+        shadow_map = cv2.resize(im_org,(MAX_SIZE,MAX_SIZE)).astype(float)/pred.astype(float)
+        shadow_map = cv2.resize(shadow_map,(w,h))
+        shadow_map[shadow_map==0]=0.00001
+        out_im = np.clip(im_org.astype(float)/shadow_map,0,255).astype(np.uint8)
+
+    return prompt[:,:,0],prompt[:,:,1],prompt[:,:,2],out_im
+
+
+def deshadowing(model,im_path):
+    MAX_SIZE=1600
+    # obtain im and prompt
+    im_org = cv2.imread(im_path)
+    h,w = im_org.shape[:2]
+    prompt = deshadow_prompt(im_org)
+    in_im = np.concatenate((im_org,prompt),-1)
+
+    # constrain the max resolution
+    if max(w,h) < MAX_SIZE:
+        in_im,padding_h,padding_w = stride_integral(in_im,8)
+    else:
+        in_im = cv2.resize(in_im,(MAX_SIZE,MAX_SIZE))
+
+    # normalize
+    in_im = in_im / 255.0
+    in_im = torch.from_numpy(in_im.transpose(2,0,1)).unsqueeze(0)
+
+    # inference
+    in_im = in_im.half().to(DEVICE)
+    model = model.half()
+    with torch.no_grad():
+        pred = model(in_im)
+        pred = torch.clamp(pred,0,1)
+    pred = pred[0].permute(1,2,0).cpu().numpy()
+    pred = (pred*255).astype(np.uint8)
+
+    if max(w,h) < MAX_SIZE:
+        out_im = pred[padding_h:,padding_w:]
+    else:
+        pred[pred==0]=1
+        shadow_map = cv2.resize(im_org,(MAX_SIZE,MAX_SIZE)).astype(float)/pred.astype(float)
+        shadow_map = cv2.resize(shadow_map,(w,h))
+        shadow_map[shadow_map==0]=0.00001
+        out_im = np.clip(im_org.astype(float)/shadow_map,0,255).astype(np.uint8)
+
+    return prompt[:,:,0],prompt[:,:,1],prompt[:,:,2],out_im
+
+
+def deblurring(model,im_path):
+    # setup image
+    im_org = cv2.imread(im_path)
+    in_im,padding_h,padding_w = stride_integral(im_org,8)
+    prompt = deblur_prompt(in_im)
+    in_im = np.concatenate((in_im,prompt),-1)
+    in_im = in_im / 255.0
+    in_im = torch.from_numpy(in_im.transpose(2,0,1)).unsqueeze(0)
+    in_im = in_im.half().to(DEVICE)
+    # inference
+    model.to(DEVICE)
+    model.eval()
+    model = model.half()
+    with torch.no_grad():
+        pred = model(in_im)
+        pred = torch.clamp(pred,0,1)
+        pred = pred[0].permute(1,2,0).cpu().numpy()
+        pred = (pred*255).astype(np.uint8)
+        out_im = pred[padding_h:,padding_w:]
+
+    return prompt[:,:,0],prompt[:,:,1],prompt[:,:,2],out_im
+
+
+
+
+def binarization(model,im_path):
+    im_org = cv2.imread(im_path)
+    im,padding_h,padding_w = stride_integral(im_org,8)
+    prompt = binarization_promptv2(im)
+    h,w = im.shape[:2]
+    in_im = np.concatenate((im,prompt),-1)
+
+    in_im = in_im / 255.0
+    in_im = torch.from_numpy(in_im.transpose(2,0,1)).unsqueeze(0)
+    in_im = in_im.to(DEVICE)
+    model = model.half()
+    in_im = in_im.half()
+    with torch.no_grad():
+        pred = model(in_im)
+        pred = pred[:,:2,:,:]
+        pred = torch.max(torch.softmax(pred,1),1)[1]
+        pred = pred[0].cpu().numpy()
+        pred = (pred*255).astype(np.uint8)
+        pred = cv2.resize(pred,(w,h))
+    out_im = pred[padding_h:,padding_w:]
+
+    return prompt[:,:,0],prompt[:,:,1],prompt[:,:,2],out_im
+
+def get_args():
+    parser = argparse.ArgumentParser(description='Params')
+    parser.add_argument('--model_path', nargs='?', type=str, default='./checkpoints/docres.pkl',help='Path of the saved checkpoint')
+    parser.add_argument('--im_path', nargs='?', type=str, default='./distorted/',
+                        help='Path of input document image')
+    parser.add_argument('--out_folder', nargs='?', type=str, default='./restorted/',
+                        help='Folder of the output images')
+    parser.add_argument('--task', nargs='?', type=str, default='dewarping',
+                        help='task that need to be executed')
+    parser.add_argument('--save_dtsprompt', nargs='?', type=int, default=0,
+                        help='Width of the input image')
+    args = parser.parse_args()
+    possible_tasks = ['dewarping','deshadowing','appearance','deblurring','binarization','end2end']
+    assert args.task in possible_tasks, 'Unsupported task, task must be one of '+', '.join(possible_tasks)
+    return args
+
+def model_init(args):
+    # prepare model
+    model = restormer_arch.Restormer(
+        inp_channels=6,
+        out_channels=3,
+        dim = 48,
+        num_blocks = [2,3,3,4],
+        num_refinement_blocks = 4,
+        heads = [1,2,4,8],
+        ffn_expansion_factor = 2.66,
+        bias = False,
+        LayerNorm_type = 'WithBias',
+        dual_pixel_task = True
+    )
+
+    if DEVICE.type == 'cpu':
+        state = convert_state_dict(torch.load(args.model_path, map_location='cpu')['model_state'])
+    else:
+        state = convert_state_dict(torch.load(args.model_path, map_location='cuda:0')['model_state'])
+    model.load_state_dict(state)
+
+    model.eval()
+    model = model.to(DEVICE)
+    return model
+
+def inference_one_im(model,im_path,task):
+    if task=='dewarping':
+        prompt1,prompt2,prompt3,restorted = dewarping(model,im_path)
+    elif task=='deshadowing':
+        prompt1,prompt2,prompt3,restorted = deshadowing(model,im_path)
+    elif task=='appearance':
+        prompt1,prompt2,prompt3,restorted = appearance(model,im_path)
+    elif task=='deblurring':
+        prompt1,prompt2,prompt3,restorted = deblurring(model,im_path)
+    elif task=='binarization':
+        prompt1,prompt2,prompt3,restorted = binarization(model,im_path)
+    elif task=='end2end':
+        prompt1,prompt2,prompt3,restorted = dewarping(model,im_path)
+        cv2.imwrite('restorted/step1.jpg',restorted)
+        prompt1,prompt2,prompt3,restorted = deshadowing(model,'restorted/step1.jpg')
+        cv2.imwrite('restorted/step2.jpg',restorted)
+        prompt1,prompt2,prompt3,restorted = appearance(model,'restorted/step2.jpg')
+        # os.remove('restorted/step1.jpg')
+        # os.remove('restorted/step2.jpg')
+
+    return prompt1,prompt2,prompt3,restorted
+
+
+def save_results(
+    img_path: str,
+    out_folder: str,
+    task: str,
+    save_dtsprompt: bool,
+):
+    im_name = os.path.split(img_path)[-1]
+    im_format = '.'+im_name.split('.')[-1]
+    save_path = os.path.join(out_folder, im_name.replace(im_format, '_' + task + im_format))
+    cv2.imwrite(save_path, restorted)
+    if save_dtsprompt:
+        cv2.imwrite(save_path.replace(im_format, '_prompt1' + im_format), prompt1)
+        cv2.imwrite(save_path.replace(im_format, '_prompt2' + im_format), prompt2)
+        cv2.imwrite(save_path.replace(im_format, '_prompt3' + im_format), prompt3)
+
+
+if __name__ == '__main__':
+
+    ## model init
+    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    args = get_args()
+    model = model_init(args)
+
+    img_source = args.im_path
+
+    if Path(img_source).is_dir():
+        img_paths = glob.glob(os.path.join(img_source, '*'))
+        for img_path in img_paths:
+            ## inference
+            prompt1,prompt2,prompt3,restorted = inference_one_im(model,img_path,args.task)
+
+            ## results saving
+            save_results(
+                img_path=img_path,
+                out_folder=args.out_folder,
+                task=args.task,
+                save_dtsprompt=args.save_dtsprompt,
+            )
+
+    else:
+        ## inference
+        prompt1,prompt2,prompt3,restorted = inference_one_im(model,img_source,args.task)
+
+        ## results saving
+        save_results(
+            img_path=img_source,
+            out_folder=args.out_folder,
+            task=args.task,
+            save_dtsprompt=args.save_dtsprompt,
+        )
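inference.py is the vendored DocRes command-line entry point. get_args() exposes --model_path (default ./checkpoints/docres.pkl), --im_path, --out_folder, --task (dewarping, deshadowing, appearance, deblurring, binarization, or end2end) and --save_dtsprompt, and the __main__ block picks the device, loads the Restormer checkpoint, and runs inference_one_im() over a single file or every file in a folder. Note that DEVICE and the restorted/prompt variables read by save_results() are module-level globals created only under __main__, so the script is written to be run as a program rather than imported; doctra's own wrapper (doctra/engines/image_restoration/docres_engine.py) is presumably the intended programmatic entry point. A hedged sketch of driving the vendored script directly, where the working directory, sample paths, and downloaded checkpoints are all assumptions:

```python
# Sketch only: invoke the vendored DocRes CLI with the flags defined in get_args().
# Assumes the current directory is doctra/third_party/docres/, the checkpoint sits at
# the default ./checkpoints/docres.pkl, and the ./restorted/ output folder already exists.
import subprocess

subprocess.run(
    [
        "python", "inference.py",
        "--im_path", "./distorted/page.jpg",   # placeholder input image
        "--out_folder", "./restorted/",        # output folder (spelling as in the script)
        "--task", "appearance",                # any task from the list above
        "--save_dtsprompt", "1",               # also write the three DTSPrompt channels
    ],
    check=True,
)
```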
doctra/third_party/docres/models/restormer_arch.py
@@ -0,0 +1,308 @@
+## Restormer: Efficient Transformer for High-Resolution Image Restoration
+## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang
+## https://arxiv.org/abs/2111.09881
+
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from pdb import set_trace as stx
+import numbers
+
+from einops import rearrange
+
+
+
+##########################################################################
+## Layer Norm
+
+def to_3d(x):
+    return rearrange(x, 'b c h w -> b (h w) c')
+
+def to_4d(x,h,w):
+    return rearrange(x, 'b (h w) c -> b c h w',h=h,w=w)
+
+class BiasFree_LayerNorm(nn.Module):
+    def __init__(self, normalized_shape):
+        super(BiasFree_LayerNorm, self).__init__()
+        if isinstance(normalized_shape, numbers.Integral):
+            normalized_shape = (normalized_shape,)
+        normalized_shape = torch.Size(normalized_shape)
+
+        assert len(normalized_shape) == 1
+
+        self.weight = nn.Parameter(torch.ones(normalized_shape))
+        self.normalized_shape = normalized_shape
+
+    def forward(self, x):
+        sigma = x.var(-1, keepdim=True, unbiased=False)
+        return x / torch.sqrt(sigma+1e-5) * self.weight
+
+class WithBias_LayerNorm(nn.Module):
+    def __init__(self, normalized_shape):
+        super(WithBias_LayerNorm, self).__init__()
+        if isinstance(normalized_shape, numbers.Integral):
+            normalized_shape = (normalized_shape,)
+        normalized_shape = torch.Size(normalized_shape)
+
+        assert len(normalized_shape) == 1
+
+        self.weight = nn.Parameter(torch.ones(normalized_shape))
+        self.bias = nn.Parameter(torch.zeros(normalized_shape))
+        self.normalized_shape = normalized_shape
+
+    def forward(self, x):
+        mu = x.mean(-1, keepdim=True)
+        sigma = x.var(-1, keepdim=True, unbiased=False)
+        return (x - mu) / torch.sqrt(sigma+1e-5) * self.weight + self.bias
+
+
+class LayerNorm(nn.Module):
+    def __init__(self, dim, LayerNorm_type):
+        super(LayerNorm, self).__init__()
+        if LayerNorm_type =='BiasFree':
+            self.body = BiasFree_LayerNorm(dim)
+        else:
+            self.body = WithBias_LayerNorm(dim)
+
+    def forward(self, x):
+        h, w = x.shape[-2:]
+        return to_4d(self.body(to_3d(x)), h, w)
+
+
+
+##########################################################################
+## Gated-Dconv Feed-Forward Network (GDFN)
+class FeedForward(nn.Module):
+    def __init__(self, dim, ffn_expansion_factor, bias):
+        super(FeedForward, self).__init__()
+
+        hidden_features = int(dim*ffn_expansion_factor)
+
+        self.project_in = nn.Conv2d(dim, hidden_features*2, kernel_size=1, bias=bias)
+
+        self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3, stride=1, padding=1, groups=hidden_features*2, bias=bias)
+
+        self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)
+
+    def forward(self, x):
+        x = self.project_in(x)
+        x1, x2 = self.dwconv(x).chunk(2, dim=1)
+        x = F.gelu(x1) * x2
+        x = self.project_out(x)
+        return x
+
+
+
+##########################################################################
+## Multi-DConv Head Transposed Self-Attention (MDTA)
+class Attention(nn.Module):
+    def __init__(self, dim, num_heads, bias):
+        super(Attention, self).__init__()
+        self.num_heads = num_heads
+        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
+
+        self.qkv = nn.Conv2d(dim, dim*3, kernel_size=1, bias=bias)
+        self.qkv_dwconv = nn.Conv2d(dim*3, dim*3, kernel_size=3, stride=1, padding=1, groups=dim*3, bias=bias)
+        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
+
+
+
+    def forward(self, x):
+        b,c,h,w = x.shape
+
+        qkv = self.qkv_dwconv(self.qkv(x))
+        q,k,v = qkv.chunk(3, dim=1)
+
+        q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
+        k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
+        v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
+
+        q = torch.nn.functional.normalize(q, dim=-1)
+        k = torch.nn.functional.normalize(k, dim=-1)
+
+        attn = (q @ k.transpose(-2, -1)) * self.temperature
+        attn = attn.softmax(dim=-1)
+
+        out = (attn @ v)
+
+        out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)
+
+        out = self.project_out(out)
+        return out
+
+
+
+##########################################################################
+class TransformerBlock(nn.Module):
+    def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type):
+        super(TransformerBlock, self).__init__()
+
+        self.norm1 = LayerNorm(dim, LayerNorm_type)
+        self.attn = Attention(dim, num_heads, bias)
+        self.norm2 = LayerNorm(dim, LayerNorm_type)
+        self.ffn = FeedForward(dim, ffn_expansion_factor, bias)
+
+    def forward(self, x):
+        x = x + self.attn(self.norm1(x))
+        x = x + self.ffn(self.norm2(x))
+
+        return x
+
+
+
+##########################################################################
+## Overlapped image patch embedding with 3x3 Conv
+class OverlapPatchEmbed(nn.Module):
+    def __init__(self, in_c=3, embed_dim=48, bias=False):
+        super(OverlapPatchEmbed, self).__init__()
+
+        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias)
+
+    def forward(self, x):
+        x = self.proj(x)
+
+        return x
+
+
+
+##########################################################################
+## Resizing modules
+class Downsample(nn.Module):
+    def __init__(self, n_feat):
+        super(Downsample, self).__init__()
+
+        self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat//2, kernel_size=3, stride=1, padding=1, bias=False),
+                                  nn.PixelUnshuffle(2))
+
+    def forward(self, x):
+        return self.body(x)
+
+class Upsample(nn.Module):
+    def __init__(self, n_feat):
+        super(Upsample, self).__init__()
+
+        self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat*2, kernel_size=3, stride=1, padding=1, bias=False),
+                                  nn.PixelShuffle(2))
+
+    def forward(self, x):
+        return self.body(x)
+
+##########################################################################
+##---------- Restormer -----------------------
+class Restormer(nn.Module):
+    def __init__(self,
+        inp_channels=3,
+        out_channels=3,
+        dim = 48,
+        num_blocks = [4,6,6,8],
+        num_refinement_blocks = 4,
+        heads = [1,2,4,8],
+        ffn_expansion_factor = 2.66,
+        bias = False,
+        LayerNorm_type = 'WithBias',   ## Other option 'BiasFree'
+        dual_pixel_task = True         ## True for dual-pixel defocus deblurring only. Also set inp_channels=6
+    ):
+
+        super(Restormer, self).__init__()
+
+        self.patch_embed = OverlapPatchEmbed(inp_channels, dim)
+
+        self.encoder_level1 = nn.Sequential(*[TransformerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])
+
+        self.down1_2 = Downsample(dim) ## From Level 1 to Level 2
+        self.encoder_level2 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])])
+
+        self.down2_3 = Downsample(int(dim*2**1)) ## From Level 2 to Level 3
+        self.encoder_level3 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])])
+
+        self.down3_4 = Downsample(int(dim*2**2)) ## From Level 3 to Level 4
+        self.latent = nn.Sequential(*[TransformerBlock(dim=int(dim*2**3), num_heads=heads[3], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[3])])
+
+        self.up4_3 = Upsample(int(dim*2**3)) ## From Level 4 to Level 3
+        self.reduce_chan_level3 = nn.Conv2d(int(dim*2**3), int(dim*2**2), kernel_size=1, bias=bias)
+        self.decoder_level3 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])])
+
+
+        self.up3_2 = Upsample(int(dim*2**2)) ## From Level 3 to Level 2
+        self.reduce_chan_level2 = nn.Conv2d(int(dim*2**2), int(dim*2**1), kernel_size=1, bias=bias)
+        self.decoder_level2 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])])
+
+        self.up2_1 = Upsample(int(dim*2**1))  ## From Level 2 to Level 1  (NO 1x1 conv to reduce channels)
+
+        self.decoder_level1 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])
+
+        self.refinement = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_refinement_blocks)])
+
+        #### For Dual-Pixel Defocus Deblurring Task ####
+        self.dual_pixel_task = dual_pixel_task
+        if self.dual_pixel_task:
+            self.skip_conv = nn.Conv2d(dim, int(dim*2**1), kernel_size=1, bias=bias)
+        ###########################
+
+
+        self.output = nn.Conv2d(int(dim*2**1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias)
+
+    def forward(self, inp_img,task=''):
+
+        inp_enc_level1 = self.patch_embed(inp_img)
+        out_enc_level1 = self.encoder_level1(inp_enc_level1)
+
+        inp_enc_level2 = self.down1_2(out_enc_level1)
+        out_enc_level2 = self.encoder_level2(inp_enc_level2)
+
+        inp_enc_level3 = self.down2_3(out_enc_level2)
+        out_enc_level3 = self.encoder_level3(inp_enc_level3)
+
+        inp_enc_level4 = self.down3_4(out_enc_level3)
+        latent = self.latent(inp_enc_level4)
+
+
+        inp_dec_level3 = self.up4_3(latent)
+        inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 1)
+        inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3)
+        out_dec_level3 = self.decoder_level3(inp_dec_level3)
+
+        inp_dec_level2 = self.up3_2(out_dec_level3)
+        inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 1)
+        inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2)
+        out_dec_level2 = self.decoder_level2(inp_dec_level2)
+
+        inp_dec_level1 = self.up2_1(out_dec_level2)
+        inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 1)
+        out_dec_level1 = self.decoder_level1(inp_dec_level1)
+
+        out_dec_level1 = self.refinement(out_dec_level1)
+
+        out_dec_level1 = out_dec_level1 + self.skip_conv(inp_enc_level1)
+        out_dec_level1 = self.output(out_dec_level1)
+
+        return out_dec_level1
+
+
+
+if __name__ == '__main__':
+    from torchtoolbox.tools import summary
+    model = Restormer(
+        inp_channels=6,
+        out_channels=3,
+        dim = 48,
+        # num_blocks = [4,6,6,8],
+        num_blocks = [2,3,3,4],
+        num_refinement_blocks = 4,
+        heads = [1,2,4,8],
+        ffn_expansion_factor = 2.66,
+        bias = False,
+        LayerNorm_type = 'WithBias',   ## Other option 'BiasFree'
+        dual_pixel_task = True         ## True for dual-pixel defocus deblurring only. Also set inp_channels=6
+    )
+    # model = Restormer(num_blocks=[4, 6, 6, 8], num_heads=[1, 2, 4, 8], channels=[48, 96, 192, 384], num_refinement=4, expansion_factor=2.66)
+    print(summary(model,torch.rand((1, 6, 256, 256))))
+
+    from thop import profile
+    input = torch.rand((1, 6, 256, 256))
+    gflops,params = profile(model,inputs=(input,))
+    gflops = gflops*2 / 10**9
+    params = params / 10**6
+    print(gflops,'==============')
+    print(params,'==============')
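restormer_arch.py is the Restormer backbone that every task above runs through: a four-level encoder/decoder of TransformerBlocks (the channel-wise MDTA attention plus the gated-Dconv feed-forward), PixelUnshuffle/PixelShuffle resizing modules, and a dual-pixel skip connection from the patch embedding. model_init() in inference.py instantiates it with inp_channels=6 (three image channels plus three prompt channels), dim=48 and num_blocks=[2,3,3,4], the reduced configuration used in the __main__ block above rather than the class default [4,6,6,8]. A minimal forward-pass sketch with that configuration, using random weights and input; the import path assumes you are inside the vendored doctra/third_party/docres/ directory, and H and W must be divisible by 8 because of the three downsampling stages:

```python
# Sketch: build the DocRes-sized Restormer and check the output shape.
import torch
from models import restormer_arch  # resolves when run from doctra/third_party/docres/

model = restormer_arch.Restormer(
    inp_channels=6, out_channels=3, dim=48,
    num_blocks=[2, 3, 3, 4], num_refinement_blocks=4, heads=[1, 2, 4, 8],
    ffn_expansion_factor=2.66, bias=False,
    LayerNorm_type='WithBias', dual_pixel_task=True,
).eval()

x = torch.rand(1, 6, 256, 256)  # image (3 ch) + DTSPrompt (3 ch); 256 is divisible by 8
with torch.no_grad():
    y = model(x)
print(y.shape)  # torch.Size([1, 3, 256, 256]): restored image at the input resolution
```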