doctra 0.3.2__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38) hide show
  1. doctra/__init__.py +4 -0
  2. doctra/cli/main.py +168 -0
  3. doctra/engines/image_restoration/__init__.py +10 -0
  4. doctra/engines/image_restoration/docres_engine.py +566 -0
  5. doctra/engines/vlm/service.py +0 -12
  6. doctra/parsers/enhanced_pdf_parser.py +370 -0
  7. doctra/parsers/structured_pdf_parser.py +11 -60
  8. doctra/parsers/table_chart_extractor.py +8 -44
  9. doctra/third_party/docres/data/MBD/MBD.py +110 -0
  10. doctra/third_party/docres/data/MBD/MBD_utils.py +291 -0
  11. doctra/third_party/docres/data/MBD/infer.py +151 -0
  12. doctra/third_party/docres/data/MBD/model/deep_lab_model/aspp.py +95 -0
  13. doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/__init__.py +13 -0
  14. doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/drn.py +402 -0
  15. doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/mobilenet.py +151 -0
  16. doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/resnet.py +170 -0
  17. doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/xception.py +288 -0
  18. doctra/third_party/docres/data/MBD/model/deep_lab_model/decoder.py +59 -0
  19. doctra/third_party/docres/data/MBD/model/deep_lab_model/deeplab.py +81 -0
  20. doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/__init__.py +12 -0
  21. doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/batchnorm.py +282 -0
  22. doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/comm.py +129 -0
  23. doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/replicate.py +88 -0
  24. doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/unittest.py +29 -0
  25. doctra/third_party/docres/data/preprocess/crop_merge_image.py +142 -0
  26. doctra/third_party/docres/inference.py +370 -0
  27. doctra/third_party/docres/models/restormer_arch.py +308 -0
  28. doctra/third_party/docres/utils.py +464 -0
  29. doctra/ui/app.py +5 -32
  30. doctra/utils/progress.py +13 -98
  31. doctra/utils/structured_utils.py +45 -49
  32. doctra/version.py +1 -1
  33. {doctra-0.3.2.dist-info → doctra-0.4.0.dist-info}/METADATA +1 -1
  34. doctra-0.4.0.dist-info/RECORD +67 -0
  35. doctra-0.3.2.dist-info/RECORD +0 -44
  36. {doctra-0.3.2.dist-info → doctra-0.4.0.dist-info}/WHEEL +0 -0
  37. {doctra-0.3.2.dist-info → doctra-0.4.0.dist-info}/licenses/LICENSE +0 -0
  38. {doctra-0.3.2.dist-info → doctra-0.4.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,308 @@
1
+ ## Restormer: Efficient Transformer for High-Resolution Image Restoration
2
+ ## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang
3
+ ## https://arxiv.org/abs/2111.09881
4
+
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from pdb import set_trace as stx
10
+ import numbers
11
+
12
+ from einops import rearrange
13
+
14
+
15
+
16
+ ##########################################################################
17
+ ## Layer Norm
18
+
19
def to_3d(x):
    """Collapse the two spatial axes of a 4-D tensor into one token axis.

    Rearranges ``(b, c, h, w)`` into ``(b, h*w, c)``, so that the channel
    axis ends up last.
    """
    pattern = 'b c h w -> b (h w) c'
    return rearrange(x, pattern)
21
+
22
def to_4d(x,h,w):
    """Unfold a ``(b, h*w, c)`` tensor back into ``(b, c, h, w)``.

    ``h`` and ``w`` give the spatial size used to split the merged axis.
    """
    pattern = 'b (h w) c -> b c h w'
    spatial = {'h': h, 'w': w}
    return rearrange(x, pattern, **spatial)
24
+
25
class BiasFree_LayerNorm(nn.Module):
    """Scale-only normalisation over the last axis.

    The input is divided by the biased standard deviation of its last axis
    (with a 1e-5 stabiliser added to the variance) and multiplied by a
    learnable per-feature ``weight``. The mean is NOT subtracted and there is
    no additive bias.
    """

    def __init__(self, normalized_shape):
        super().__init__()

        # A bare integer is shorthand for a one-element shape.
        if isinstance(normalized_shape, numbers.Integral):
            shape = torch.Size((normalized_shape,))
        else:
            shape = torch.Size(normalized_shape)

        # Only a single normalised axis is supported.
        assert len(shape) == 1

        self.weight = nn.Parameter(torch.ones(shape))
        self.normalized_shape = shape

    def forward(self, x):
        variance = x.var(-1, keepdim=True, unbiased=False)
        std = torch.sqrt(variance + 1e-5)
        return x / std * self.weight
40
+
41
class WithBias_LayerNorm(nn.Module):
    """Layer normalisation over the last axis with learnable scale and shift.

    The input is centred by its mean, divided by the biased standard
    deviation (1e-5 added to the variance), then multiplied by ``weight`` and
    offset by ``bias``.
    """

    def __init__(self, normalized_shape):
        super().__init__()

        # A bare integer is shorthand for a one-element shape.
        if isinstance(normalized_shape, numbers.Integral):
            shape = torch.Size((normalized_shape,))
        else:
            shape = torch.Size(normalized_shape)

        # Only a single normalised axis is supported.
        assert len(shape) == 1

        self.weight = nn.Parameter(torch.ones(shape))
        self.bias = nn.Parameter(torch.zeros(shape))
        self.normalized_shape = shape

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        variance = x.var(-1, keepdim=True, unbiased=False)
        normalised = (x - mean) / torch.sqrt(variance + 1e-5)
        return normalised * self.weight + self.bias
58
+
59
+
60
class LayerNorm(nn.Module):
    """Channel-wise layer normalisation for ``(b, c, h, w)`` feature maps.

    The map is flattened to ``(b, h*w, c)``, normalised over the channel axis
    by the selected variant, and folded back to its original layout.

    Args:
        dim: number of channels to normalise over.
        LayerNorm_type: ``'BiasFree'`` selects ``BiasFree_LayerNorm``; any
            other value selects ``WithBias_LayerNorm``.
    """

    def __init__(self, dim, LayerNorm_type):
        super().__init__()
        norm_cls = BiasFree_LayerNorm if LayerNorm_type == 'BiasFree' else WithBias_LayerNorm
        self.body = norm_cls(dim)

    def forward(self, x):
        height, width = x.shape[-2:]
        tokens = to_3d(x)
        normalised = self.body(tokens)
        return to_4d(normalised, height, width)
71
+
72
+
73
+
74
+ ##########################################################################
75
+ ## Gated-Dconv Feed-Forward Network (GDFN)
76
class FeedForward(nn.Module):
    """Gated depthwise-convolution feed-forward network (GDFN).

    A 1x1 convolution expands the input to twice the hidden width, a 3x3
    depthwise convolution mixes each channel spatially, and the result is
    split into two halves: one passes through GELU and gates the other. A
    final 1x1 convolution projects back to ``dim`` channels.

    Args:
        dim: number of input and output channels.
        ffn_expansion_factor: the hidden width is ``int(dim * factor)``.
        bias: whether the convolutions carry a bias term.
    """

    def __init__(self, dim, ffn_expansion_factor, bias):
        super().__init__()

        hidden = int(dim * ffn_expansion_factor)
        gated_width = hidden * 2  # both branches of the gate are produced together

        self.project_in = nn.Conv2d(dim, gated_width, kernel_size=1, bias=bias)
        self.dwconv = nn.Conv2d(
            gated_width, gated_width,
            kernel_size=3, stride=1, padding=1,
            groups=gated_width, bias=bias,
        )
        self.project_out = nn.Conv2d(hidden, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        expanded = self.dwconv(self.project_in(x))
        gate, value = expanded.chunk(2, dim=1)
        return self.project_out(F.gelu(gate) * value)
94
+
95
+
96
+
97
+ ##########################################################################
98
+ ## Multi-DConv Head Transposed Self-Attention (MDTA)
99
class Attention(nn.Module):
    """Multi-DConv head transposed self-attention (MDTA).

    Queries, keys and values come from a 1x1 convolution followed by a 3x3
    depthwise convolution. Each head's queries and keys are L2-normalised
    along the flattened spatial axis, so the attention map is channel-by-
    channel (``c x c`` per head) rather than pixel-by-pixel.

    Args:
        dim: number of input and output channels (split evenly across heads).
        num_heads: number of attention heads.
        bias: whether the convolutions carry a bias term.
    """

    def __init__(self, dim, num_heads, bias):
        super().__init__()
        self.num_heads = num_heads
        # Learnable per-head scale applied to the attention logits.
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))

        qkv_width = dim * 3
        self.qkv = nn.Conv2d(dim, qkv_width, kernel_size=1, bias=bias)
        self.qkv_dwconv = nn.Conv2d(
            qkv_width, qkv_width,
            kernel_size=3, stride=1, padding=1,
            groups=qkv_width, bias=bias,
        )
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        # Unpacking four values also enforces a 4-D input.
        _, _, h, w = x.shape

        q, k, v = self.qkv_dwconv(self.qkv(x)).chunk(3, dim=1)

        # Split channels into heads and flatten the spatial grid.
        split_heads = 'b (head c) h w -> b head c (h w)'
        q = rearrange(q, split_heads, head=self.num_heads)
        k = rearrange(k, split_heads, head=self.num_heads)
        v = rearrange(v, split_heads, head=self.num_heads)

        q = F.normalize(q, dim=-1)
        k = F.normalize(k, dim=-1)

        # (b, head, c, c) similarity between channels, scaled per head.
        scores = (q @ k.transpose(-2, -1)) * self.temperature
        weights = scores.softmax(dim=-1)

        mixed = weights @ v
        merged = rearrange(mixed, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)

        return self.project_out(merged)
133
+
134
+
135
+
136
+ ##########################################################################
137
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention then feed-forward, each residual.

    Args:
        dim: channel count of the feature map.
        num_heads: number of attention heads.
        ffn_expansion_factor: hidden-width multiplier for the feed-forward net.
        bias: whether the convolutions inside the block carry a bias term.
        LayerNorm_type: forwarded to ``LayerNorm`` (``'BiasFree'`` or other).
    """

    def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type):
        super().__init__()

        self.norm1 = LayerNorm(dim, LayerNorm_type)
        self.attn = Attention(dim, num_heads, bias)
        self.norm2 = LayerNorm(dim, LayerNorm_type)
        self.ffn = FeedForward(dim, ffn_expansion_factor, bias)

    def forward(self, x):
        attended = x + self.attn(self.norm1(x))
        return attended + self.ffn(self.norm2(attended))
151
+
152
+
153
+
154
+ ##########################################################################
155
+ ## Overlapped image patch embedding with 3x3 Conv
156
class OverlapPatchEmbed(nn.Module):
    """Embed an image with a single 3x3, stride-1, padding-1 convolution.

    The spatial size is preserved; only the channel count changes from
    ``in_c`` to ``embed_dim``.
    """

    def __init__(self, in_c=3, embed_dim=48, bias=False):
        super().__init__()

        conv_options = dict(kernel_size=3, stride=1, padding=1, bias=bias)
        self.proj = nn.Conv2d(in_c, embed_dim, **conv_options)

    def forward(self, x):
        return self.proj(x)
166
+
167
+
168
+
169
+ ##########################################################################
170
+ ## Resizing modules
171
class Downsample(nn.Module):
    """Halve the spatial resolution and double the channel count.

    A bias-free 3x3 convolution first reduces ``n_feat`` channels to
    ``n_feat // 2``; ``PixelUnshuffle(2)`` then folds every 2x2 spatial patch
    into channels, multiplying the channel count by four.
    """

    def __init__(self, n_feat):
        super().__init__()

        squeeze = nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False)
        self.body = nn.Sequential(squeeze, nn.PixelUnshuffle(2))

    def forward(self, x):
        return self.body(x)
180
+
181
class Upsample(nn.Module):
    """Double the spatial resolution and halve the channel count.

    A bias-free 3x3 convolution first expands ``n_feat`` channels to
    ``n_feat * 2``; ``PixelShuffle(2)`` then spreads every group of four
    channels over a 2x2 spatial patch, dividing the channel count by four.
    """

    def __init__(self, n_feat):
        super().__init__()

        expand = nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False)
        self.body = nn.Sequential(expand, nn.PixelShuffle(2))

    def forward(self, x):
        return self.body(x)
190
+
191
+ ##########################################################################
192
+ ##---------- Restormer -----------------------
193
class Restormer(nn.Module):
    """Four-level encoder-decoder built from ``TransformerBlock`` stacks.

    The encoder halves the spatial size and doubles the channel count three
    times (``dim`` -> ``2*dim`` -> ``4*dim`` -> ``8*dim``). The decoder mirrors
    it, concatenating the matching encoder output at every level. Levels 3 and
    2 reduce the concatenated channels with a 1x1 convolution; level 1 does
    not, so the level-1 decoder, the refinement stage and the output
    convolution all operate on ``2*dim`` channels.

    Args:
        inp_channels: channels of the input image.
        out_channels: channels of the restored output.
        dim: base channel width at level 1.
        num_blocks: number of transformer blocks at levels 1-4.
        num_refinement_blocks: number of blocks in the refinement stage.
        heads: number of attention heads at levels 1-4.
        ffn_expansion_factor: hidden-width multiplier of each feed-forward net.
        bias: whether convolutions inside the blocks carry a bias term.
        LayerNorm_type: ``'WithBias'`` or ``'BiasFree'``.
        dual_pixel_task: when true, a 1x1 ``skip_conv`` adds the level-1
            embedding to the refined features before the output convolution.
            When false, the input image itself is added to the output
            (global residual), which requires
            ``inp_channels == out_channels``.

    Note:
        The three pixel-unshuffle steps need the input height and width to be
        divisible by 8.
    """

    def __init__(self,
        inp_channels=3,
        out_channels=3,
        dim = 48,
        num_blocks = (4,6,6,8),
        num_refinement_blocks = 4,
        heads = (1,2,4,8),
        ffn_expansion_factor = 2.66,
        bias = False,
        LayerNorm_type = 'WithBias',   ## Other option 'BiasFree'
        dual_pixel_task = True        ## True for dual-pixel defocus deblurring only. Also set inp_channels=6
    ):

        super().__init__()

        def stage(width, num_heads, depth):
            # One stack of identical transformer blocks at a fixed width.
            return nn.Sequential(*[
                TransformerBlock(dim=width, num_heads=num_heads,
                                 ffn_expansion_factor=ffn_expansion_factor,
                                 bias=bias, LayerNorm_type=LayerNorm_type)
                for _ in range(depth)
            ])

        dim_l2 = int(dim*2**1)
        dim_l3 = int(dim*2**2)
        dim_l4 = int(dim*2**3)

        # NOTE: attribute names and creation order are kept as-is so that
        # checkpoint keys and seeded initialisation stay unchanged.
        self.patch_embed = OverlapPatchEmbed(inp_channels, dim)

        self.encoder_level1 = stage(dim, heads[0], num_blocks[0])

        self.down1_2 = Downsample(dim)  ## From Level 1 to Level 2
        self.encoder_level2 = stage(dim_l2, heads[1], num_blocks[1])

        self.down2_3 = Downsample(dim_l2)  ## From Level 2 to Level 3
        self.encoder_level3 = stage(dim_l3, heads[2], num_blocks[2])

        self.down3_4 = Downsample(dim_l3)  ## From Level 3 to Level 4
        self.latent = stage(dim_l4, heads[3], num_blocks[3])

        self.up4_3 = Upsample(dim_l4)  ## From Level 4 to Level 3
        self.reduce_chan_level3 = nn.Conv2d(dim_l4, dim_l3, kernel_size=1, bias=bias)
        self.decoder_level3 = stage(dim_l3, heads[2], num_blocks[2])

        self.up3_2 = Upsample(dim_l3)  ## From Level 3 to Level 2
        self.reduce_chan_level2 = nn.Conv2d(dim_l3, dim_l2, kernel_size=1, bias=bias)
        self.decoder_level2 = stage(dim_l2, heads[1], num_blocks[1])

        self.up2_1 = Upsample(dim_l2)  ## From Level 2 to Level 1  (NO 1x1 conv to reduce channels)

        self.decoder_level1 = stage(dim_l2, heads[0], num_blocks[0])

        self.refinement = stage(dim_l2, heads[0], num_refinement_blocks)

        #### For Dual-Pixel Defocus Deblurring Task ####
        self.dual_pixel_task = dual_pixel_task
        if self.dual_pixel_task:
            self.skip_conv = nn.Conv2d(dim, dim_l2, kernel_size=1, bias=bias)
        ###########################

        self.output = nn.Conv2d(dim_l2, out_channels, kernel_size=3, stride=1, padding=1, bias=bias)

    def forward(self, inp_img,task=''):
        """Restore ``inp_img``.

        Args:
            inp_img: input batch with ``inp_channels`` channels.
            task: accepted for interface compatibility; not used here.

        Returns:
            Tensor with ``out_channels`` channels and the input's spatial size.
        """
        # ---- encoder ----
        inp_enc_level1 = self.patch_embed(inp_img)
        out_enc_level1 = self.encoder_level1(inp_enc_level1)

        inp_enc_level2 = self.down1_2(out_enc_level1)
        out_enc_level2 = self.encoder_level2(inp_enc_level2)

        inp_enc_level3 = self.down2_3(out_enc_level2)
        out_enc_level3 = self.encoder_level3(inp_enc_level3)

        inp_enc_level4 = self.down3_4(out_enc_level3)
        latent = self.latent(inp_enc_level4)

        # ---- decoder (each level concatenates the matching encoder output) ----
        inp_dec_level3 = self.up4_3(latent)
        inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 1)
        inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3)
        out_dec_level3 = self.decoder_level3(inp_dec_level3)

        inp_dec_level2 = self.up3_2(out_dec_level3)
        inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 1)
        inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2)
        out_dec_level2 = self.decoder_level2(inp_dec_level2)

        inp_dec_level1 = self.up2_1(out_dec_level2)
        inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 1)
        out_dec_level1 = self.decoder_level1(inp_dec_level1)

        out_dec_level1 = self.refinement(out_dec_level1)

        if self.dual_pixel_task:
            out_dec_level1 = out_dec_level1 + self.skip_conv(inp_enc_level1)
            return self.output(out_dec_level1)

        # skip_conv only exists when dual_pixel_task is set; without it, use
        # the global image residual as in the reference Restormer.
        return self.output(out_dec_level1) + inp_img
281
+
282
+
283
+
284
if __name__ == '__main__':
    # Ad-hoc size/cost report for a reduced-depth model; requires the optional
    # third-party packages `torchtoolbox` and `thop`.
    from torchtoolbox.tools import summary

    settings = dict(
        inp_channels=6,
        out_channels=3,
        dim=48,
        num_blocks=[2, 3, 3, 4],      # full-depth setting: [4,6,6,8]
        num_refinement_blocks=4,
        heads=[1, 2, 4, 8],
        ffn_expansion_factor=2.66,
        bias=False,
        LayerNorm_type='WithBias',    ## Other option 'BiasFree'
        dual_pixel_task=True,         ## True for dual-pixel defocus deblurring only. Also set inp_channels=6
    )
    model = Restormer(**settings)
    print(summary(model, torch.rand((1, 6, 256, 256))))

    from thop import profile
    sample = torch.rand((1, 6, 256, 256))
    op_count, param_count = profile(model, inputs=(sample,))
    # The operation count is doubled before scaling to 1e9 units -- presumably
    # converting multiply-accumulates to FLOPs; confirm against thop's docs.
    giga_ops = op_count * 2 / 10**9
    mega_params = param_count / 10**6
    print(giga_ops, '==============')
    print(mega_params, '==============')