doctra-0.3.3-py3-none-any.whl → doctra-0.4.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. doctra/__init__.py +4 -0
  2. doctra/cli/main.py +170 -9
  3. doctra/cli/utils.py +2 -3
  4. doctra/engines/image_restoration/__init__.py +10 -0
  5. doctra/engines/image_restoration/docres_engine.py +561 -0
  6. doctra/engines/vlm/outlines_types.py +13 -9
  7. doctra/engines/vlm/service.py +4 -2
  8. doctra/exporters/excel_writer.py +89 -0
  9. doctra/parsers/enhanced_pdf_parser.py +374 -0
  10. doctra/parsers/structured_pdf_parser.py +6 -0
  11. doctra/parsers/table_chart_extractor.py +6 -0
  12. doctra/third_party/docres/data/MBD/MBD.py +110 -0
  13. doctra/third_party/docres/data/MBD/MBD_utils.py +291 -0
  14. doctra/third_party/docres/data/MBD/infer.py +151 -0
  15. doctra/third_party/docres/data/MBD/model/deep_lab_model/aspp.py +95 -0
  16. doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/__init__.py +13 -0
  17. doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/drn.py +402 -0
  18. doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/mobilenet.py +151 -0
  19. doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/resnet.py +170 -0
  20. doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/xception.py +288 -0
  21. doctra/third_party/docres/data/MBD/model/deep_lab_model/decoder.py +59 -0
  22. doctra/third_party/docres/data/MBD/model/deep_lab_model/deeplab.py +81 -0
  23. doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/__init__.py +12 -0
  24. doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/batchnorm.py +282 -0
  25. doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/comm.py +129 -0
  26. doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/replicate.py +88 -0
  27. doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/unittest.py +29 -0
  28. doctra/third_party/docres/data/preprocess/crop_merge_image.py +142 -0
  29. doctra/third_party/docres/inference.py +370 -0
  30. doctra/third_party/docres/models/restormer_arch.py +308 -0
  31. doctra/third_party/docres/utils.py +464 -0
  32. doctra/ui/app.py +8 -14
  33. doctra/utils/structured_utils.py +5 -2
  34. doctra/version.py +1 -1
  35. {doctra-0.3.3.dist-info → doctra-0.4.1.dist-info}/METADATA +1 -1
  36. doctra-0.4.1.dist-info/RECORD +67 -0
  37. doctra-0.3.3.dist-info/RECORD +0 -44
  38. {doctra-0.3.3.dist-info → doctra-0.4.1.dist-info}/WHEEL +0 -0
  39. {doctra-0.3.3.dist-info → doctra-0.4.1.dist-info}/licenses/LICENSE +0 -0
  40. {doctra-0.3.3.dist-info → doctra-0.4.1.dist-info}/top_level.txt +0 -0
doctra/third_party/docres/inference.py
@@ -0,0 +1,370 @@
+ import os
+ import cv2
+ import glob
+ from pathlib import Path
+ import utils
+ import argparse
+ import numpy as np
+
+ import torch
+
+ from utils import convert_state_dict
+ from models import restormer_arch
+ from data.preprocess.crop_merge_image import stride_integral
+
+ os.sys.path.append('./data/MBD/')
+ from data.MBD.infer import net1_net2_infer_single_im
+
+
+ def dewarp_prompt(img):
+     mask = net1_net2_infer_single_im(img,'data/MBD/checkpoint/mbd.pkl')
+     base_coord = utils.getBasecoord(256,256)/256
+     img[mask==0]=0
+     mask = cv2.resize(mask,(256,256))/255
+     return img,np.concatenate((base_coord,np.expand_dims(mask,-1)),-1)
+
+ def deshadow_prompt(img):
+     h,w = img.shape[:2]
+     # img = cv2.resize(img,(128,128))
+     img = cv2.resize(img,(1024,1024))
+     rgb_planes = cv2.split(img)
+     result_planes = []
+     result_norm_planes = []
+     bg_imgs = []
+     for plane in rgb_planes:
+         dilated_img = cv2.dilate(plane, np.ones((7,7), np.uint8))
+         bg_img = cv2.medianBlur(dilated_img, 21)
+         bg_imgs.append(bg_img)
+         diff_img = 255 - cv2.absdiff(plane, bg_img)
+         norm_img = cv2.normalize(diff_img,None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_8UC1)
+         result_planes.append(diff_img)
+         result_norm_planes.append(norm_img)
+     bg_imgs = cv2.merge(bg_imgs)
+     bg_imgs = cv2.resize(bg_imgs,(w,h))
+     # result = cv2.merge(result_planes)
+     result_norm = cv2.merge(result_norm_planes)
+     result_norm[result_norm==0]=1
+     shadow_map = np.clip(img.astype(float)/result_norm.astype(float)*255,0,255).astype(np.uint8)
+     shadow_map = cv2.resize(shadow_map,(w,h))
+     shadow_map = cv2.cvtColor(shadow_map,cv2.COLOR_BGR2GRAY)
+     shadow_map = cv2.cvtColor(shadow_map,cv2.COLOR_GRAY2BGR)
+     # return shadow_map
+     return bg_imgs
+
+ def deblur_prompt(img):
+     x = cv2.Sobel(img,cv2.CV_16S,1,0)
+     y = cv2.Sobel(img,cv2.CV_16S,0,1)
+     absX = cv2.convertScaleAbs(x)  # convert back to uint8
+     absY = cv2.convertScaleAbs(y)
+     high_frequency = cv2.addWeighted(absX,0.5,absY,0.5,0)
+     high_frequency = cv2.cvtColor(high_frequency,cv2.COLOR_BGR2GRAY)
+     high_frequency = cv2.cvtColor(high_frequency,cv2.COLOR_GRAY2BGR)
+     return high_frequency
+
+ def appearance_prompt(img):
+     h,w = img.shape[:2]
+     # img = cv2.resize(img,(128,128))
+     img = cv2.resize(img,(1024,1024))
+     rgb_planes = cv2.split(img)
+     result_planes = []
+     result_norm_planes = []
+     for plane in rgb_planes:
+         dilated_img = cv2.dilate(plane, np.ones((7,7), np.uint8))
+         bg_img = cv2.medianBlur(dilated_img, 21)
+         diff_img = 255 - cv2.absdiff(plane, bg_img)
+         norm_img = cv2.normalize(diff_img,None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_8UC1)
+         result_planes.append(diff_img)
+         result_norm_planes.append(norm_img)
+     result_norm = cv2.merge(result_norm_planes)
+     result_norm = cv2.resize(result_norm,(w,h))
+     return result_norm
+
+ def binarization_promptv2(img):
+     result,thresh = utils.SauvolaModBinarization(img)
+     thresh = thresh.astype(np.uint8)
+     result[result>155]=255
+     result[result<=155]=0
+
+     x = cv2.Sobel(img,cv2.CV_16S,1,0)
+     y = cv2.Sobel(img,cv2.CV_16S,0,1)
+     absX = cv2.convertScaleAbs(x)  # convert back to uint8
+     absY = cv2.convertScaleAbs(y)
+     high_frequency = cv2.addWeighted(absX,0.5,absY,0.5,0)
+     high_frequency = cv2.cvtColor(high_frequency,cv2.COLOR_BGR2GRAY)
+     return np.concatenate((np.expand_dims(thresh,-1),np.expand_dims(high_frequency,-1),np.expand_dims(result,-1)),-1)
+
+ def dewarping(model,im_path):
+     INPUT_SIZE=256
+     im_org = cv2.imread(im_path)
+     im_masked, prompt_org = dewarp_prompt(im_org.copy())
+
+     h,w = im_masked.shape[:2]
+     im_masked = im_masked.copy()
+     im_masked = cv2.resize(im_masked,(INPUT_SIZE,INPUT_SIZE))
+     im_masked = im_masked / 255.0
+     im_masked = torch.from_numpy(im_masked.transpose(2,0,1)).unsqueeze(0)
+     im_masked = im_masked.float().to(DEVICE)
+
+     prompt = torch.from_numpy(prompt_org.transpose(2,0,1)).unsqueeze(0)
+     prompt = prompt.float().to(DEVICE)
+
+     in_im = torch.cat((im_masked,prompt),dim=1)
+
+     # inference
+     base_coord = utils.getBasecoord(INPUT_SIZE,INPUT_SIZE)/INPUT_SIZE
+     model = model.float()
+     with torch.no_grad():
+         pred = model(in_im)
+         pred = pred[0][:2].permute(1,2,0).cpu().numpy()
+         pred = pred+base_coord
+     ## smooth
+     for i in range(15):
+         pred = cv2.blur(pred,(3,3),borderType=cv2.BORDER_REPLICATE)
+     pred = cv2.resize(pred,(w,h))*(w,h)
+     pred = pred.astype(np.float32)
+     out_im = cv2.remap(im_org,pred[:,:,0],pred[:,:,1],cv2.INTER_LINEAR)
+
+     prompt_org = (prompt_org*255).astype(np.uint8)
+     prompt_org = cv2.resize(prompt_org,im_org.shape[:2][::-1])
+
+     return prompt_org[:,:,0],prompt_org[:,:,1],prompt_org[:,:,2],out_im
+
+ def appearance(model,im_path):
+     MAX_SIZE=1600
+     # obtain im and prompt
+     im_org = cv2.imread(im_path)
+     h,w = im_org.shape[:2]
+     prompt = appearance_prompt(im_org)
+     in_im = np.concatenate((im_org,prompt),-1)
+
+     # constrain the max resolution
+     if max(w,h) < MAX_SIZE:
+         in_im,padding_h,padding_w = stride_integral(in_im,8)
+     else:
+         in_im = cv2.resize(in_im,(MAX_SIZE,MAX_SIZE))
+
+     # normalize
+     in_im = in_im / 255.0
+     in_im = torch.from_numpy(in_im.transpose(2,0,1)).unsqueeze(0)
+
+     # inference
+     in_im = in_im.half().to(DEVICE)
+     model = model.half()
+     with torch.no_grad():
+         pred = model(in_im)
+         pred = torch.clamp(pred,0,1)
+         pred = pred[0].permute(1,2,0).cpu().numpy()
+         pred = (pred*255).astype(np.uint8)
+
+     if max(w,h) < MAX_SIZE:
+         out_im = pred[padding_h:,padding_w:]
+     else:
+         pred[pred==0] = 1
+         shadow_map = cv2.resize(im_org,(MAX_SIZE,MAX_SIZE)).astype(float)/pred.astype(float)
+         shadow_map = cv2.resize(shadow_map,(w,h))
+         shadow_map[shadow_map==0]=0.00001
+         out_im = np.clip(im_org.astype(float)/shadow_map,0,255).astype(np.uint8)
+
+     return prompt[:,:,0],prompt[:,:,1],prompt[:,:,2],out_im
+
+
+ def deshadowing(model,im_path):
+     MAX_SIZE=1600
+     # obtain im and prompt
+     im_org = cv2.imread(im_path)
+     h,w = im_org.shape[:2]
+     prompt = deshadow_prompt(im_org)
+     in_im = np.concatenate((im_org,prompt),-1)
+
+     # constrain the max resolution
+     if max(w,h) < MAX_SIZE:
+         in_im,padding_h,padding_w = stride_integral(in_im,8)
+     else:
+         in_im = cv2.resize(in_im,(MAX_SIZE,MAX_SIZE))
+
+     # normalize
+     in_im = in_im / 255.0
+     in_im = torch.from_numpy(in_im.transpose(2,0,1)).unsqueeze(0)
+
+     # inference
+     in_im = in_im.half().to(DEVICE)
+     model = model.half()
+     with torch.no_grad():
+         pred = model(in_im)
+         pred = torch.clamp(pred,0,1)
+         pred = pred[0].permute(1,2,0).cpu().numpy()
+         pred = (pred*255).astype(np.uint8)
+
+     if max(w,h) < MAX_SIZE:
+         out_im = pred[padding_h:,padding_w:]
+     else:
+         pred[pred==0]=1
+         shadow_map = cv2.resize(im_org,(MAX_SIZE,MAX_SIZE)).astype(float)/pred.astype(float)
+         shadow_map = cv2.resize(shadow_map,(w,h))
+         shadow_map[shadow_map==0]=0.00001
+         out_im = np.clip(im_org.astype(float)/shadow_map,0,255).astype(np.uint8)
+
+     return prompt[:,:,0],prompt[:,:,1],prompt[:,:,2],out_im
+
+
+ def deblurring(model,im_path):
+     # setup image
+     im_org = cv2.imread(im_path)
+     in_im,padding_h,padding_w = stride_integral(im_org,8)
+     prompt = deblur_prompt(in_im)
+     in_im = np.concatenate((in_im,prompt),-1)
+     in_im = in_im / 255.0
+     in_im = torch.from_numpy(in_im.transpose(2,0,1)).unsqueeze(0)
+     in_im = in_im.half().to(DEVICE)
+     # inference
+     model.to(DEVICE)
+     model.eval()
+     model = model.half()
+     with torch.no_grad():
+         pred = model(in_im)
+         pred = torch.clamp(pred,0,1)
+         pred = pred[0].permute(1,2,0).cpu().numpy()
+         pred = (pred*255).astype(np.uint8)
+         out_im = pred[padding_h:,padding_w:]
+
+     return prompt[:,:,0],prompt[:,:,1],prompt[:,:,2],out_im
+
+
+ def binarization(model,im_path):
+     im_org = cv2.imread(im_path)
+     im,padding_h,padding_w = stride_integral(im_org,8)
+     prompt = binarization_promptv2(im)
+     h,w = im.shape[:2]
+     in_im = np.concatenate((im,prompt),-1)
+
+     in_im = in_im / 255.0
+     in_im = torch.from_numpy(in_im.transpose(2,0,1)).unsqueeze(0)
+     in_im = in_im.to(DEVICE)
+     model = model.half()
+     in_im = in_im.half()
+     with torch.no_grad():
+         pred = model(in_im)
+         pred = pred[:,:2,:,:]
+         pred = torch.max(torch.softmax(pred,1),1)[1]
+         pred = pred[0].cpu().numpy()
+         pred = (pred*255).astype(np.uint8)
+         pred = cv2.resize(pred,(w,h))
+     out_im = pred[padding_h:,padding_w:]
+
+     return prompt[:,:,0],prompt[:,:,1],prompt[:,:,2],out_im
+
+ def get_args():
+     parser = argparse.ArgumentParser(description='Params')
+     parser.add_argument('--model_path', nargs='?', type=str, default='./checkpoints/docres.pkl',
+                         help='Path of the saved checkpoint')
+     parser.add_argument('--im_path', nargs='?', type=str, default='./distorted/',
+                         help='Path of input document image')
+     parser.add_argument('--out_folder', nargs='?', type=str, default='./restorted/',
+                         help='Folder of the output images')
+     parser.add_argument('--task', nargs='?', type=str, default='dewarping',
+                         help='Task that needs to be executed')
+     parser.add_argument('--save_dtsprompt', nargs='?', type=int, default=0,
+                         help='Whether to also save the three DTSPrompt channels (0 or 1)')
+     args = parser.parse_args()
+     possible_tasks = ['dewarping','deshadowing','appearance','deblurring','binarization','end2end']
+     assert args.task in possible_tasks, 'Unsupported task, task must be one of '+', '.join(possible_tasks)
+     return args
+
+ def model_init(args):
+     # prepare model
+     model = restormer_arch.Restormer(
+         inp_channels=6,
+         out_channels=3,
+         dim = 48,
+         num_blocks = [2,3,3,4],
+         num_refinement_blocks = 4,
+         heads = [1,2,4,8],
+         ffn_expansion_factor = 2.66,
+         bias = False,
+         LayerNorm_type = 'WithBias',
+         dual_pixel_task = True
+     )
+
+     if DEVICE.type == 'cpu':
+         state = convert_state_dict(torch.load(args.model_path, map_location='cpu')['model_state'])
+     else:
+         state = convert_state_dict(torch.load(args.model_path, map_location='cuda:0')['model_state'])
+     model.load_state_dict(state)
+
+     model.eval()
+     model = model.to(DEVICE)
+     return model
+
+ def inference_one_im(model,im_path,task):
+     if task=='dewarping':
+         prompt1,prompt2,prompt3,restorted = dewarping(model,im_path)
+     elif task=='deshadowing':
+         prompt1,prompt2,prompt3,restorted = deshadowing(model,im_path)
+     elif task=='appearance':
+         prompt1,prompt2,prompt3,restorted = appearance(model,im_path)
+     elif task=='deblurring':
+         prompt1,prompt2,prompt3,restorted = deblurring(model,im_path)
+     elif task=='binarization':
+         prompt1,prompt2,prompt3,restorted = binarization(model,im_path)
+     elif task=='end2end':
+         # NOTE: the intermediate results are written to the hard-coded 'restorted/' folder
+         prompt1,prompt2,prompt3,restorted = dewarping(model,im_path)
+         cv2.imwrite('restorted/step1.jpg',restorted)
+         prompt1,prompt2,prompt3,restorted = deshadowing(model,'restorted/step1.jpg')
+         cv2.imwrite('restorted/step2.jpg',restorted)
+         prompt1,prompt2,prompt3,restorted = appearance(model,'restorted/step2.jpg')
+         # os.remove('restorted/step1.jpg')
+         # os.remove('restorted/step2.jpg')
+
+     return prompt1,prompt2,prompt3,restorted
+
+
+ def save_results(
+     img_path: str,
+     out_folder: str,
+     task: str,
+     save_dtsprompt: bool,
+ ):
+     # NOTE: prompt1/prompt2/prompt3/restorted are read as module-level
+     # globals assigned in the __main__ block below.
+     im_name = os.path.split(img_path)[-1]
+     im_format = '.'+im_name.split('.')[-1]
+     save_path = os.path.join(out_folder, im_name.replace(im_format, '_' + task + im_format))
+     cv2.imwrite(save_path, restorted)
+     if save_dtsprompt:
+         cv2.imwrite(save_path.replace(im_format, '_prompt1' + im_format), prompt1)
+         cv2.imwrite(save_path.replace(im_format, '_prompt2' + im_format), prompt2)
+         cv2.imwrite(save_path.replace(im_format, '_prompt3' + im_format), prompt3)
+
+
+ if __name__ == '__main__':
+
+     ## model init
+     DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+     args = get_args()
+     model = model_init(args)
+
+     img_source = args.im_path
+
+     if Path(img_source).is_dir():
+         img_paths = glob.glob(os.path.join(img_source, '*'))
+         for img_path in img_paths:
+             ## inference
+             prompt1,prompt2,prompt3,restorted = inference_one_im(model,img_path,args.task)
+
+             ## results saving
+             save_results(
+                 img_path=img_path,
+                 out_folder=args.out_folder,
+                 task=args.task,
+                 save_dtsprompt=args.save_dtsprompt,
+             )
+
+     else:
+         ## inference
+         prompt1,prompt2,prompt3,restorted = inference_one_im(model,img_source,args.task)
+
+         ## results saving
+         save_results(
+             img_path=img_source,
+             out_folder=args.out_folder,
+             task=args.task,
+             save_dtsprompt=args.save_dtsprompt,
+         )
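
For reference, the script above can also be driven programmatically instead of via its CLI. A minimal sketch, assuming it is run from the docres directory (the module appends ./data/MBD/ to sys.path and loads checkpoints from relative paths); the file name sample.jpg and the hand-built Namespace are illustrative, not part of the package:

import argparse
import cv2
import torch
import inference  # doctra/third_party/docres/inference.py

# The task functions read DEVICE as a module-level global (normally set in
# the __main__ block), so a driver has to set it before calling them.
inference.DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

args = argparse.Namespace(model_path='./checkpoints/docres.pkl')
model = inference.model_init(args)

prompt1, prompt2, prompt3, restorted = inference.inference_one_im(model, 'sample.jpg', 'appearance')
cv2.imwrite('sample_appearance.jpg', restorted)

Equivalently, from the shell: python inference.py --im_path sample.jpg --task appearance --save_dtsprompt 1
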
doctra/third_party/docres/models/restormer_arch.py
@@ -0,0 +1,308 @@
+ ## Restormer: Efficient Transformer for High-Resolution Image Restoration
+ ## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang
+ ## https://arxiv.org/abs/2111.09881
+
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from pdb import set_trace as stx
+ import numbers
+
+ from einops import rearrange
+
+
+ ##########################################################################
+ ## Layer Norm
+
+ def to_3d(x):
+     return rearrange(x, 'b c h w -> b (h w) c')
+
+ def to_4d(x,h,w):
+     return rearrange(x, 'b (h w) c -> b c h w',h=h,w=w)
+
+ class BiasFree_LayerNorm(nn.Module):
+     def __init__(self, normalized_shape):
+         super(BiasFree_LayerNorm, self).__init__()
+         if isinstance(normalized_shape, numbers.Integral):
+             normalized_shape = (normalized_shape,)
+         normalized_shape = torch.Size(normalized_shape)
+
+         assert len(normalized_shape) == 1
+
+         self.weight = nn.Parameter(torch.ones(normalized_shape))
+         self.normalized_shape = normalized_shape
+
+     def forward(self, x):
+         sigma = x.var(-1, keepdim=True, unbiased=False)
+         return x / torch.sqrt(sigma+1e-5) * self.weight
+
+ class WithBias_LayerNorm(nn.Module):
+     def __init__(self, normalized_shape):
+         super(WithBias_LayerNorm, self).__init__()
+         if isinstance(normalized_shape, numbers.Integral):
+             normalized_shape = (normalized_shape,)
+         normalized_shape = torch.Size(normalized_shape)
+
+         assert len(normalized_shape) == 1
+
+         self.weight = nn.Parameter(torch.ones(normalized_shape))
+         self.bias = nn.Parameter(torch.zeros(normalized_shape))
+         self.normalized_shape = normalized_shape
+
+     def forward(self, x):
+         mu = x.mean(-1, keepdim=True)
+         sigma = x.var(-1, keepdim=True, unbiased=False)
+         return (x - mu) / torch.sqrt(sigma+1e-5) * self.weight + self.bias
+
+
+ class LayerNorm(nn.Module):
+     def __init__(self, dim, LayerNorm_type):
+         super(LayerNorm, self).__init__()
+         if LayerNorm_type =='BiasFree':
+             self.body = BiasFree_LayerNorm(dim)
+         else:
+             self.body = WithBias_LayerNorm(dim)
+
+     def forward(self, x):
+         h, w = x.shape[-2:]
+         return to_4d(self.body(to_3d(x)), h, w)
+
+
+ ##########################################################################
+ ## Gated-Dconv Feed-Forward Network (GDFN)
+ class FeedForward(nn.Module):
+     def __init__(self, dim, ffn_expansion_factor, bias):
+         super(FeedForward, self).__init__()
+
+         hidden_features = int(dim*ffn_expansion_factor)
+
+         self.project_in = nn.Conv2d(dim, hidden_features*2, kernel_size=1, bias=bias)
+
+         self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3, stride=1, padding=1, groups=hidden_features*2, bias=bias)
+
+         self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)
+
+     def forward(self, x):
+         x = self.project_in(x)
+         x1, x2 = self.dwconv(x).chunk(2, dim=1)
+         x = F.gelu(x1) * x2
+         x = self.project_out(x)
+         return x
+
+
+ ##########################################################################
+ ## Multi-DConv Head Transposed Self-Attention (MDTA)
+ class Attention(nn.Module):
+     def __init__(self, dim, num_heads, bias):
+         super(Attention, self).__init__()
+         self.num_heads = num_heads
+         self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
+
+         self.qkv = nn.Conv2d(dim, dim*3, kernel_size=1, bias=bias)
+         self.qkv_dwconv = nn.Conv2d(dim*3, dim*3, kernel_size=3, stride=1, padding=1, groups=dim*3, bias=bias)
+         self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
+
+     def forward(self, x):
+         b,c,h,w = x.shape
+
+         qkv = self.qkv_dwconv(self.qkv(x))
+         q,k,v = qkv.chunk(3, dim=1)
+
+         q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
+         k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
+         v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
+
+         q = torch.nn.functional.normalize(q, dim=-1)
+         k = torch.nn.functional.normalize(k, dim=-1)
+
+         attn = (q @ k.transpose(-2, -1)) * self.temperature
+         attn = attn.softmax(dim=-1)
+
+         out = (attn @ v)
+
+         out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)
+
+         out = self.project_out(out)
+         return out
+
+
+ ##########################################################################
+ class TransformerBlock(nn.Module):
+     def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type):
+         super(TransformerBlock, self).__init__()
+
+         self.norm1 = LayerNorm(dim, LayerNorm_type)
+         self.attn = Attention(dim, num_heads, bias)
+         self.norm2 = LayerNorm(dim, LayerNorm_type)
+         self.ffn = FeedForward(dim, ffn_expansion_factor, bias)
+
+     def forward(self, x):
+         x = x + self.attn(self.norm1(x))
+         x = x + self.ffn(self.norm2(x))
+
+         return x
+
+
+ ##########################################################################
+ ## Overlapped image patch embedding with 3x3 Conv
+ class OverlapPatchEmbed(nn.Module):
+     def __init__(self, in_c=3, embed_dim=48, bias=False):
+         super(OverlapPatchEmbed, self).__init__()
+
+         self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias)
+
+     def forward(self, x):
+         x = self.proj(x)
+
+         return x
+
+
+ ##########################################################################
+ ## Resizing modules
+ class Downsample(nn.Module):
+     def __init__(self, n_feat):
+         super(Downsample, self).__init__()
+
+         self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat//2, kernel_size=3, stride=1, padding=1, bias=False),
+                                   nn.PixelUnshuffle(2))
+
+     def forward(self, x):
+         return self.body(x)
+
+ class Upsample(nn.Module):
+     def __init__(self, n_feat):
+         super(Upsample, self).__init__()
+
+         self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat*2, kernel_size=3, stride=1, padding=1, bias=False),
+                                   nn.PixelShuffle(2))
+
+     def forward(self, x):
+         return self.body(x)
+
+ ##########################################################################
+ ##---------- Restormer -----------------------
+ class Restormer(nn.Module):
+     def __init__(self,
+                  inp_channels=3,
+                  out_channels=3,
+                  dim = 48,
+                  num_blocks = [4,6,6,8],
+                  num_refinement_blocks = 4,
+                  heads = [1,2,4,8],
+                  ffn_expansion_factor = 2.66,
+                  bias = False,
+                  LayerNorm_type = 'WithBias',  ## Other option 'BiasFree'
+                  dual_pixel_task = True        ## True for dual-pixel defocus deblurring only. Also set inp_channels=6
+                  ):
+
+         super(Restormer, self).__init__()
+
+         self.patch_embed = OverlapPatchEmbed(inp_channels, dim)
+
+         self.encoder_level1 = nn.Sequential(*[TransformerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])
+
+         self.down1_2 = Downsample(dim)  ## From Level 1 to Level 2
+         self.encoder_level2 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])])
+
+         self.down2_3 = Downsample(int(dim*2**1))  ## From Level 2 to Level 3
+         self.encoder_level3 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])])
+
+         self.down3_4 = Downsample(int(dim*2**2))  ## From Level 3 to Level 4
+         self.latent = nn.Sequential(*[TransformerBlock(dim=int(dim*2**3), num_heads=heads[3], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[3])])
+
+         self.up4_3 = Upsample(int(dim*2**3))  ## From Level 4 to Level 3
+         self.reduce_chan_level3 = nn.Conv2d(int(dim*2**3), int(dim*2**2), kernel_size=1, bias=bias)
+         self.decoder_level3 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])])
+
+         self.up3_2 = Upsample(int(dim*2**2))  ## From Level 3 to Level 2
+         self.reduce_chan_level2 = nn.Conv2d(int(dim*2**2), int(dim*2**1), kernel_size=1, bias=bias)
+         self.decoder_level2 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])])
+
+         self.up2_1 = Upsample(int(dim*2**1))  ## From Level 2 to Level 1 (NO 1x1 conv to reduce channels)
+
+         self.decoder_level1 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])
+
+         self.refinement = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_refinement_blocks)])
+
+         #### For Dual-Pixel Defocus Deblurring Task ####
+         self.dual_pixel_task = dual_pixel_task
+         if self.dual_pixel_task:
+             self.skip_conv = nn.Conv2d(dim, int(dim*2**1), kernel_size=1, bias=bias)
+         ###########################
+
+         self.output = nn.Conv2d(int(dim*2**1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias)
+
+     def forward(self, inp_img,task=''):
+
+         inp_enc_level1 = self.patch_embed(inp_img)
+         out_enc_level1 = self.encoder_level1(inp_enc_level1)
+
+         inp_enc_level2 = self.down1_2(out_enc_level1)
+         out_enc_level2 = self.encoder_level2(inp_enc_level2)
+
+         inp_enc_level3 = self.down2_3(out_enc_level2)
+         out_enc_level3 = self.encoder_level3(inp_enc_level3)
+
+         inp_enc_level4 = self.down3_4(out_enc_level3)
+         latent = self.latent(inp_enc_level4)
+
+         inp_dec_level3 = self.up4_3(latent)
+         inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 1)
+         inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3)
+         out_dec_level3 = self.decoder_level3(inp_dec_level3)
+
+         inp_dec_level2 = self.up3_2(out_dec_level3)
+         inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 1)
+         inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2)
+         out_dec_level2 = self.decoder_level2(inp_dec_level2)
+
+         inp_dec_level1 = self.up2_1(out_dec_level2)
+         inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 1)
+         out_dec_level1 = self.decoder_level1(inp_dec_level1)
+
+         out_dec_level1 = self.refinement(out_dec_level1)
+
+         out_dec_level1 = out_dec_level1 + self.skip_conv(inp_enc_level1)
+         out_dec_level1 = self.output(out_dec_level1)
+
+         return out_dec_level1
+
+
+ if __name__ == '__main__':
+     from torchtoolbox.tools import summary
+     model = Restormer(
+         inp_channels=6,
+         out_channels=3,
+         dim = 48,
+         # num_blocks = [4,6,6,8],
+         num_blocks = [2,3,3,4],
+         num_refinement_blocks = 4,
+         heads = [1,2,4,8],
+         ffn_expansion_factor = 2.66,
+         bias = False,
+         LayerNorm_type = 'WithBias',  ## Other option 'BiasFree'
+         dual_pixel_task = True        ## True for dual-pixel defocus deblurring only. Also set inp_channels=6
+     )
+     # model = Restormer(num_blocks=[4, 6, 6, 8], num_heads=[1, 2, 4, 8], channels=[48, 96, 192, 384], num_refinement=4, expansion_factor=2.66)
+     print(summary(model,torch.rand((1, 6, 256, 256))))
+
+     from thop import profile
+     input = torch.rand((1, 6, 256, 256))
+     gflops,params = profile(model,inputs=(input,))  # thop returns (MACs, parameter count)
+     gflops = gflops*2 / 10**9  # MACs -> GFLOPs
+     params = params / 10**6    # parameters in millions
+     print(gflops,'==============')
+     print(params,'==============')
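
As a quick sanity check of the configuration that model_init in inference.py loads (inp_channels=6, i.e. the BGR image concatenated with its 3-channel DTSPrompt), the sketch below runs one forward pass; it is illustrative and not part of the package. Input height and width must be multiples of 8 because the encoder downsamples three times, which is why inference.py pads inputs with stride_integral(..., 8):

import torch
from models import restormer_arch  # import path as used in inference.py

model = restormer_arch.Restormer(
    inp_channels=6, out_channels=3, dim=48,
    num_blocks=[2, 3, 3, 4], num_refinement_blocks=4,
    heads=[1, 2, 4, 8], ffn_expansion_factor=2.66,
    bias=False, LayerNorm_type='WithBias', dual_pixel_task=True,
)
model.eval()
with torch.no_grad():
    out = model(torch.rand(1, 6, 256, 256))  # spatial dims divisible by 8
print(out.shape)  # torch.Size([1, 3, 256, 256])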