doctra 0.3.2__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doctra/__init__.py +4 -0
- doctra/cli/main.py +168 -0
- doctra/engines/image_restoration/__init__.py +10 -0
- doctra/engines/image_restoration/docres_engine.py +566 -0
- doctra/engines/vlm/service.py +0 -12
- doctra/parsers/enhanced_pdf_parser.py +370 -0
- doctra/parsers/structured_pdf_parser.py +11 -60
- doctra/parsers/table_chart_extractor.py +8 -44
- doctra/third_party/docres/data/MBD/MBD.py +110 -0
- doctra/third_party/docres/data/MBD/MBD_utils.py +291 -0
- doctra/third_party/docres/data/MBD/infer.py +151 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/aspp.py +95 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/__init__.py +13 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/drn.py +402 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/mobilenet.py +151 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/resnet.py +170 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/xception.py +288 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/decoder.py +59 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/deeplab.py +81 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/__init__.py +12 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/batchnorm.py +282 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/comm.py +129 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/replicate.py +88 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/unittest.py +29 -0
- doctra/third_party/docres/data/preprocess/crop_merge_image.py +142 -0
- doctra/third_party/docres/inference.py +370 -0
- doctra/third_party/docres/models/restormer_arch.py +308 -0
- doctra/third_party/docres/utils.py +464 -0
- doctra/ui/app.py +5 -32
- doctra/utils/progress.py +13 -98
- doctra/utils/structured_utils.py +45 -49
- doctra/version.py +1 -1
- {doctra-0.3.2.dist-info → doctra-0.4.0.dist-info}/METADATA +1 -1
- doctra-0.4.0.dist-info/RECORD +67 -0
- doctra-0.3.2.dist-info/RECORD +0 -44
- {doctra-0.3.2.dist-info → doctra-0.4.0.dist-info}/WHEEL +0 -0
- {doctra-0.3.2.dist-info → doctra-0.4.0.dist-info}/licenses/LICENSE +0 -0
- {doctra-0.3.2.dist-info → doctra-0.4.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,308 @@
|
|
1
|
+
## Restormer: Efficient Transformer for High-Resolution Image Restoration
|
2
|
+
## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang
|
3
|
+
## https://arxiv.org/abs/2111.09881
|
4
|
+
|
5
|
+
|
6
|
+
import torch
|
7
|
+
import torch.nn as nn
|
8
|
+
import torch.nn.functional as F
|
9
|
+
from pdb import set_trace as stx
|
10
|
+
import numbers
|
11
|
+
|
12
|
+
from einops import rearrange
|
13
|
+
|
14
|
+
|
15
|
+
|
16
|
+
##########################################################################
|
17
|
+
## Layer Norm
|
18
|
+
|
19
|
+
def to_3d(x):
    """Collapse the two spatial axes of a 4-D tensor into a single token axis.

    The einops pattern maps ``(b, c, h, w)`` to ``(b, h*w, c)``, i.e. channels
    end up on the last axis.
    """
    pattern = 'b c h w -> b (h w) c'
    return rearrange(x, pattern)
|
21
|
+
|
22
|
+
def to_4d(x, h, w):
    """Unfold the token axis of a ``(b, h*w, c)`` tensor back into ``(b, c, h, w)``.

    ``h`` and ``w`` give the spatial extent used to split the token axis.
    """
    pattern = 'b (h w) c -> b c h w'
    return rearrange(x, pattern, h=h, w=w)
|
24
|
+
|
25
|
+
class BiasFree_LayerNorm(nn.Module):
    """Scale-only normalisation over the last axis.

    The input is divided by the square root of its biased variance (plus 1e-5)
    along the last dimension and multiplied by a learnable per-feature weight.
    The mean is not subtracted and there is no bias term.
    """

    def __init__(self, normalized_shape):
        super().__init__()

        # A bare integer is shorthand for a one-element shape.
        shape = normalized_shape
        if isinstance(shape, numbers.Integral):
            shape = (shape,)
        shape = torch.Size(shape)

        # Only a single normalised axis is supported.
        assert len(shape) == 1

        self.weight = nn.Parameter(torch.ones(shape))
        self.normalized_shape = shape

    def forward(self, x):
        variance = x.var(dim=-1, keepdim=True, unbiased=False)
        denom = torch.sqrt(variance + 1e-5)
        return x / denom * self.weight
|
40
|
+
|
41
|
+
class WithBias_LayerNorm(nn.Module):
    """Layer normalisation over the last axis with learnable weight and bias.

    The input is centred by its mean, divided by the square root of its biased
    variance (plus 1e-5) along the last dimension, then scaled by ``weight``
    and shifted by ``bias``.
    """

    def __init__(self, normalized_shape):
        super().__init__()

        # A bare integer is shorthand for a one-element shape.
        shape = normalized_shape
        if isinstance(shape, numbers.Integral):
            shape = (shape,)
        shape = torch.Size(shape)

        # Only a single normalised axis is supported.
        assert len(shape) == 1

        self.weight = nn.Parameter(torch.ones(shape))
        self.bias = nn.Parameter(torch.zeros(shape))
        self.normalized_shape = shape

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        variance = x.var(dim=-1, keepdim=True, unbiased=False)
        centred = x - mean
        return centred / torch.sqrt(variance + 1e-5) * self.weight + self.bias
|
58
|
+
|
59
|
+
|
60
|
+
class LayerNorm(nn.Module):
    """Apply a last-axis layer norm to a 4-D feature map.

    ``LayerNorm_type == 'BiasFree'`` selects ``BiasFree_LayerNorm``; any other
    value selects ``WithBias_LayerNorm``. The input is reshaped with ``to_3d``
    before normalisation and restored with ``to_4d`` afterwards.
    """

    def __init__(self, dim, LayerNorm_type):
        super().__init__()
        norm_cls = BiasFree_LayerNorm if LayerNorm_type == 'BiasFree' else WithBias_LayerNorm
        self.body = norm_cls(dim)

    def forward(self, x):
        height, width = x.shape[-2:]
        tokens = to_3d(x)
        normed = self.body(tokens)
        return to_4d(normed, height, width)
|
71
|
+
|
72
|
+
|
73
|
+
|
74
|
+
##########################################################################
|
75
|
+
## Gated-Dconv Feed-Forward Network (GDFN)
|
76
|
+
class FeedForward(nn.Module):
    """Gated feed-forward block built from 1x1 and depthwise 3x3 convolutions.

    The input is expanded to ``2 * int(dim * ffn_expansion_factor)`` channels,
    passed through a depthwise 3x3 convolution, split into two halves, and the
    GELU of the first half gates the second before projecting back to ``dim``.
    """

    def __init__(self, dim, ffn_expansion_factor, bias):
        super().__init__()

        hidden = int(dim * ffn_expansion_factor)
        doubled = hidden * 2

        self.project_in = nn.Conv2d(dim, doubled, kernel_size=1, bias=bias)
        # groups == channels makes this a depthwise convolution.
        self.dwconv = nn.Conv2d(doubled, doubled, kernel_size=3, stride=1,
                                padding=1, groups=doubled, bias=bias)
        self.project_out = nn.Conv2d(hidden, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        expanded = self.dwconv(self.project_in(x))
        gate, value = expanded.chunk(2, dim=1)
        return self.project_out(F.gelu(gate) * value)
|
94
|
+
|
95
|
+
|
96
|
+
|
97
|
+
##########################################################################
|
98
|
+
## Multi-DConv Head Transposed Self-Attention (MDTA)
|
99
|
+
class Attention(nn.Module):
    """Multi-head self-attention computed across channels.

    Queries, keys and values come from a 1x1 convolution followed by a
    depthwise 3x3 convolution. After splitting channels into heads, the
    attention map is formed between per-head channel vectors (each flattened
    over the spatial positions), scaled by a learnable per-head temperature.
    """

    def __init__(self, dim, num_heads, bias):
        super().__init__()
        self.num_heads = num_heads
        # One learnable scale per head, broadcast over the attention map.
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))

        triple = dim * 3
        self.qkv = nn.Conv2d(dim, triple, kernel_size=1, bias=bias)
        self.qkv_dwconv = nn.Conv2d(triple, triple, kernel_size=3, stride=1,
                                    padding=1, groups=triple, bias=bias)
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        # Unpacking four values also enforces that the input is 4-D.
        _, _, h, w = x.shape
        heads = self.num_heads

        q, k, v = self.qkv_dwconv(self.qkv(x)).chunk(3, dim=1)

        # Split channels into heads and flatten the spatial positions.
        split = 'b (head c) h w -> b head c (h w)'
        q = rearrange(q, split, head=heads)
        k = rearrange(k, split, head=heads)
        v = rearrange(v, split, head=heads)

        # L2-normalise along the flattened spatial axis before the dot product.
        q = F.normalize(q, dim=-1)
        k = F.normalize(k, dim=-1)

        scores = torch.matmul(q, k.transpose(-2, -1)) * self.temperature
        weights = scores.softmax(dim=-1)

        mixed = torch.matmul(weights, v)
        mixed = rearrange(mixed, 'b head c (h w) -> b (head c) h w',
                          head=heads, h=h, w=w)

        return self.project_out(mixed)
|
133
|
+
|
134
|
+
|
135
|
+
|
136
|
+
##########################################################################
|
137
|
+
class TransformerBlock(nn.Module):
    """Pre-norm residual block: attention followed by a feed-forward network.

    Each sub-layer is applied to a normalised copy of its input and the result
    is added back to that input.
    """

    def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type):
        super().__init__()

        self.norm1 = LayerNorm(dim, LayerNorm_type)
        self.attn = Attention(dim, num_heads, bias)
        self.norm2 = LayerNorm(dim, LayerNorm_type)
        self.ffn = FeedForward(dim, ffn_expansion_factor, bias)

    def forward(self, x):
        attended = x + self.attn(self.norm1(x))
        return attended + self.ffn(self.norm2(attended))
|
151
|
+
|
152
|
+
|
153
|
+
|
154
|
+
##########################################################################
|
155
|
+
## Overlapped image patch embedding with 3x3 Conv
|
156
|
+
class OverlapPatchEmbed(nn.Module):
    """Project an image to ``embed_dim`` feature channels with one 3x3 convolution.

    Stride 1 and padding 1 keep the spatial size unchanged.
    """

    def __init__(self, in_c=3, embed_dim=48, bias=False):
        super().__init__()
        self.proj = nn.Conv2d(
            in_c,
            embed_dim,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=bias,
        )

    def forward(self, x):
        return self.proj(x)
|
166
|
+
|
167
|
+
|
168
|
+
|
169
|
+
##########################################################################
|
170
|
+
## Resizing modules
|
171
|
+
class Downsample(nn.Module):
    """Halve the spatial size and double the channel count.

    A bias-free 3x3 convolution first reduces ``n_feat`` to ``n_feat // 2``
    channels; ``PixelUnshuffle(2)`` then folds each 2x2 spatial patch into the
    channel axis, multiplying channels by four.
    """

    def __init__(self, n_feat):
        super().__init__()
        squeeze = nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1,
                            padding=1, bias=False)
        fold = nn.PixelUnshuffle(2)
        self.body = nn.Sequential(squeeze, fold)

    def forward(self, x):
        return self.body(x)
|
180
|
+
|
181
|
+
class Upsample(nn.Module):
    """Double the spatial size and halve the channel count.

    A bias-free 3x3 convolution first expands ``n_feat`` to ``n_feat * 2``
    channels; ``PixelShuffle(2)`` then spreads groups of four channels over
    2x2 spatial patches, dividing channels by four.
    """

    def __init__(self, n_feat):
        super().__init__()
        expand = nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=1,
                           padding=1, bias=False)
        unfold = nn.PixelShuffle(2)
        self.body = nn.Sequential(expand, unfold)

    def forward(self, x):
        return self.body(x)
|
190
|
+
|
191
|
+
##########################################################################
|
192
|
+
##---------- Restormer -----------------------
|
193
|
+
class Restormer(nn.Module):
    """Four-level encoder/decoder built from ``TransformerBlock`` stacks.

    Feature widths are ``dim``, ``2*dim``, ``4*dim`` and ``8*dim`` on levels
    1-4. Each decoder level concatenates the upsampled features with the
    encoder output of the same level; levels 3 and 2 then reduce the channel
    count with a 1x1 convolution, while level 1 keeps the concatenated
    ``2*dim`` channels.

    Args:
        inp_channels: Number of channels of the input image.
        out_channels: Number of channels of the produced image.
        dim: Feature width at level 1.
        num_blocks: Number of transformer blocks on levels 1-4 (only indexed,
            never mutated).
        num_refinement_blocks: Number of blocks in the refinement stage.
        heads: Attention heads on levels 1-4 (only indexed, never mutated).
        ffn_expansion_factor: Hidden-width multiplier of the feed-forward net.
        bias: Whether the convolutions inside the blocks use a bias.
        LayerNorm_type: ``'BiasFree'`` or anything else for the with-bias norm.
        dual_pixel_task: When true, a 1x1 ``skip_conv`` carries the patch
            embedding to the output stage. When false, the input image itself
            is added to the output instead, which needs ``inp_channels`` and
            ``out_channels`` to be compatible for addition.
    """

    def __init__(self,
        inp_channels=3,
        out_channels=3,
        dim = 48,
        num_blocks = [4,6,6,8],
        num_refinement_blocks = 4,
        heads = [1,2,4,8],
        ffn_expansion_factor = 2.66,
        bias = False,
        LayerNorm_type = 'WithBias',   ## Other option 'BiasFree'
        dual_pixel_task = True        ## True for dual-pixel defocus deblurring only. Also set inp_channels=6
    ):

        super().__init__()

        def make_stage(width, n_heads, depth):
            # A stack of `depth` identical transformer blocks at a fixed width.
            return nn.Sequential(*[
                TransformerBlock(dim=width, num_heads=n_heads,
                                 ffn_expansion_factor=ffn_expansion_factor,
                                 bias=bias, LayerNorm_type=LayerNorm_type)
                for _ in range(depth)
            ])

        width2 = int(dim * 2 ** 1)
        width3 = int(dim * 2 ** 2)
        width4 = int(dim * 2 ** 3)

        # NOTE: sub-modules are created in the same order and under the same
        # attribute names as before so checkpoints and seeded initialisation
        # are unaffected.
        self.patch_embed = OverlapPatchEmbed(inp_channels, dim)

        self.encoder_level1 = make_stage(dim, heads[0], num_blocks[0])

        self.down1_2 = Downsample(dim)  ## From Level 1 to Level 2
        self.encoder_level2 = make_stage(width2, heads[1], num_blocks[1])

        self.down2_3 = Downsample(width2)  ## From Level 2 to Level 3
        self.encoder_level3 = make_stage(width3, heads[2], num_blocks[2])

        self.down3_4 = Downsample(width3)  ## From Level 3 to Level 4
        self.latent = make_stage(width4, heads[3], num_blocks[3])

        self.up4_3 = Upsample(width4)  ## From Level 4 to Level 3
        self.reduce_chan_level3 = nn.Conv2d(width4, width3, kernel_size=1, bias=bias)
        self.decoder_level3 = make_stage(width3, heads[2], num_blocks[2])

        self.up3_2 = Upsample(width3)  ## From Level 3 to Level 2
        self.reduce_chan_level2 = nn.Conv2d(width3, width2, kernel_size=1, bias=bias)
        self.decoder_level2 = make_stage(width2, heads[1], num_blocks[1])

        self.up2_1 = Upsample(width2)  ## From Level 2 to Level 1 (NO 1x1 conv to reduce channels)

        # Level 1 of the decoder keeps the concatenated 2*dim channels.
        self.decoder_level1 = make_stage(width2, heads[0], num_blocks[0])

        self.refinement = make_stage(width2, heads[0], num_refinement_blocks)

        #### For Dual-Pixel Defocus Deblurring Task ####
        self.dual_pixel_task = dual_pixel_task
        if self.dual_pixel_task:
            # Lifts the dim-channel patch embedding to the 2*dim decoder width.
            self.skip_conv = nn.Conv2d(dim, width2, kernel_size=1, bias=bias)
        ###########################

        self.output = nn.Conv2d(width2, out_channels, kernel_size=3, stride=1, padding=1, bias=bias)

    def forward(self, inp_img, task=''):
        """Run the network on ``inp_img`` and return the restored image.

        ``task`` is accepted for call-site compatibility and is not used here.
        Height and width must be divisible by 8 because the encoder applies
        ``PixelUnshuffle(2)`` three times.
        """
        inp_enc_level1 = self.patch_embed(inp_img)
        out_enc_level1 = self.encoder_level1(inp_enc_level1)

        inp_enc_level2 = self.down1_2(out_enc_level1)
        out_enc_level2 = self.encoder_level2(inp_enc_level2)

        inp_enc_level3 = self.down2_3(out_enc_level2)
        out_enc_level3 = self.encoder_level3(inp_enc_level3)

        inp_enc_level4 = self.down3_4(out_enc_level3)
        latent = self.latent(inp_enc_level4)

        inp_dec_level3 = self.up4_3(latent)
        inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 1)
        inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3)
        out_dec_level3 = self.decoder_level3(inp_dec_level3)

        inp_dec_level2 = self.up3_2(out_dec_level3)
        inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 1)
        inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2)
        out_dec_level2 = self.decoder_level2(inp_dec_level2)

        inp_dec_level1 = self.up2_1(out_dec_level2)
        inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 1)
        out_dec_level1 = self.decoder_level1(inp_dec_level1)

        out_dec_level1 = self.refinement(out_dec_level1)

        if self.dual_pixel_task:
            out_dec_level1 = out_dec_level1 + self.skip_conv(inp_enc_level1)
            out_dec_level1 = self.output(out_dec_level1)
        else:
            # skip_conv does not exist in this configuration; fall back to the
            # global residual used by the published Restormer reference code.
            out_dec_level1 = self.output(out_dec_level1) + inp_img

        return out_dec_level1
|
281
|
+
|
282
|
+
|
283
|
+
|
284
|
+
if __name__ == '__main__':
    # Manual smoke test: builds a reduced-depth model, prints a layer summary
    # and then a cost estimate. Requires the optional `torchtoolbox` and `thop`
    # packages.
    from torchtoolbox.tools import summary

    config = dict(
        inp_channels=6,
        out_channels=3,
        dim=48,
        num_blocks=[2, 3, 3, 4],
        num_refinement_blocks=4,
        heads=[1, 2, 4, 8],
        ffn_expansion_factor=2.66,
        bias=False,
        LayerNorm_type='WithBias',  ## Other option 'BiasFree'
        dual_pixel_task=True,  ## True for dual-pixel defocus deblurring only. Also set inp_channels=6
    )
    model = Restormer(**config)
    print(summary(model, torch.rand((1, 6, 256, 256))))

    from thop import profile

    sample = torch.rand((1, 6, 256, 256))
    # NOTE(review): the first value from `profile` is doubled before scaling
    # to 1e9, so it is presumably a multiply-accumulate count - confirm
    # against thop's documentation.
    ops, n_params = profile(model, inputs=(sample,))
    gflops = ops * 2 / 10**9
    params_millions = n_params / 10**6
    print(gflops, '==============')
    print(params_millions, '==============')
|