doctra 0.3.2__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doctra/__init__.py +4 -0
- doctra/cli/main.py +168 -0
- doctra/engines/image_restoration/__init__.py +10 -0
- doctra/engines/image_restoration/docres_engine.py +566 -0
- doctra/engines/vlm/service.py +0 -12
- doctra/parsers/enhanced_pdf_parser.py +370 -0
- doctra/parsers/structured_pdf_parser.py +11 -60
- doctra/parsers/table_chart_extractor.py +8 -44
- doctra/third_party/docres/data/MBD/MBD.py +110 -0
- doctra/third_party/docres/data/MBD/MBD_utils.py +291 -0
- doctra/third_party/docres/data/MBD/infer.py +151 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/aspp.py +95 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/__init__.py +13 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/drn.py +402 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/mobilenet.py +151 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/resnet.py +170 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/xception.py +288 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/decoder.py +59 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/deeplab.py +81 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/__init__.py +12 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/batchnorm.py +282 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/comm.py +129 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/replicate.py +88 -0
- doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/unittest.py +29 -0
- doctra/third_party/docres/data/preprocess/crop_merge_image.py +142 -0
- doctra/third_party/docres/inference.py +370 -0
- doctra/third_party/docres/models/restormer_arch.py +308 -0
- doctra/third_party/docres/utils.py +464 -0
- doctra/ui/app.py +5 -32
- doctra/utils/progress.py +13 -98
- doctra/utils/structured_utils.py +45 -49
- doctra/version.py +1 -1
- {doctra-0.3.2.dist-info → doctra-0.4.0.dist-info}/METADATA +1 -1
- doctra-0.4.0.dist-info/RECORD +67 -0
- doctra-0.3.2.dist-info/RECORD +0 -44
- {doctra-0.3.2.dist-info → doctra-0.4.0.dist-info}/WHEEL +0 -0
- {doctra-0.3.2.dist-info → doctra-0.4.0.dist-info}/licenses/LICENSE +0 -0
- {doctra-0.3.2.dist-info → doctra-0.4.0.dist-info}/top_level.txt +0 -0
doctra/parsers/table_chart_extractor.py
@@ -61,22 +61,17 @@ class ChartTablePDFParser:
     ):
         """
         Initialize the ChartTablePDFParser with extraction configuration.
-
-        Sets up the layout detection engine and optionally the VLM service
-        for structured data extraction.
 
-        :param extract_charts: Whether to extract charts from the document
-        :param extract_tables: Whether to extract tables from the document
-        :param use_vlm: Whether to use VLM for structured data extraction
-        :param vlm_provider: VLM provider to use ("gemini", "openai", "anthropic", or "openrouter")
+        :param extract_charts: Whether to extract charts from the document (default: True)
+        :param extract_tables: Whether to extract tables from the document (default: True)
+        :param use_vlm: Whether to use VLM for structured data extraction (default: False)
+        :param vlm_provider: VLM provider to use ("gemini", "openai", "anthropic", or "openrouter", default: "gemini")
         :param vlm_model: Model name to use (defaults to provider-specific defaults)
-        :param vlm_api_key: API key for VLM provider
-        :param layout_model_name: Layout detection model name
-        :param dpi: DPI for PDF rendering
-        :param min_score: Minimum confidence score for layout detection
-        :raises ValueError: If neither extract_charts nor extract_tables is True
+        :param vlm_api_key: API key for VLM provider (required if use_vlm is True)
+        :param layout_model_name: Layout detection model name (default: "PP-DocLayout_plus-L")
+        :param dpi: DPI for PDF rendering (default: 200)
+        :param min_score: Minimum confidence score for layout detection (default: 0.0)
         """
-        # Validation
         if not extract_charts and not extract_tables:
             raise ValueError("At least one of extract_charts or extract_tables must be True")
 
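The rewritten docstring now states each default inline. A minimal construction sketch based on those documented defaults (the import path is an assumption inferred from the file list above, which places ChartTablePDFParser in doctra/parsers/table_chart_extractor.py):

import doctra  # assumed package; import path inferred, not shown in this diff
from doctra.parsers.table_chart_extractor import ChartTablePDFParser

# Defaults per the updated docstring: charts and tables both on, VLM off,
# "PP-DocLayout_plus-L" layout model, 200 DPI, min_score 0.0.
parser = ChartTablePDFParser(extract_charts=True, extract_tables=True, use_vlm=False)

# Disabling both targets trips the validation retained in the hunk above:
# ChartTablePDFParser(extract_charts=False, extract_tables=False)  # -> ValueError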
@@ -98,21 +93,15 @@ class ChartTablePDFParser:
     def parse(self, pdf_path: str, output_base_dir: str = "outputs") -> None:
         """
         Parse a PDF document and extract charts and/or tables.
-
-        Processes the PDF through layout detection, extracts the specified
-        element types, saves cropped images, and optionally converts them
-        to structured data using VLM.
 
         :param pdf_path: Path to the input PDF file
         :param output_base_dir: Base directory for output files (default: "outputs")
         :return: None
         """
-        # Create output directory structure: outputs/<filename>/structured_parsing/
         pdf_name = Path(pdf_path).stem
         out_dir = os.path.join(output_base_dir, pdf_name, "structured_parsing")
         os.makedirs(out_dir, exist_ok=True)
 
-        # Create subdirectories based on what we're extracting
         charts_dir = None
         tables_dir = None
 
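The retained lines fix the output root at outputs/<pdf_name>/structured_parsing/. A usage sketch continuing the example above ("report.pdf" is a placeholder path):

parser.parse("report.pdf", output_base_dir="outputs")
# Per the code in this hunk, outputs land under:
#   outputs/report/structured_parsing/
#     charts/   cropped chart images (when extract_charts is True)
#     tables/   cropped table images (when extract_tables is True)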
@@ -129,24 +118,20 @@ class ChartTablePDFParser:
         )
         pil_pages = [im for (im, _, _) in render_pdf_to_images(pdf_path, dpi=self.dpi)]
 
-        # Determine which labels to extract
         target_labels = []
         if self.extract_charts:
             target_labels.append("chart")
         if self.extract_tables:
             target_labels.append("table")
 
-        # Count items for progress bars
         chart_count = sum(sum(1 for b in p.boxes if b.label == "chart") for p in pages) if self.extract_charts else 0
         table_count = sum(sum(1 for b in p.boxes if b.label == "table") for p in pages) if self.extract_tables else 0
 
-        # Prepare output content
         if self.use_vlm:
             md_lines: List[str] = ["# Extracted Charts and Tables\n"]
             structured_items: List[Dict[str, Any]] = []
             vlm_items: List[Dict[str, Any]] = []
 
-        # Progress bar descriptions
         charts_desc = "Charts (VLM → table)" if self.use_vlm else "Charts (cropped)"
         tables_desc = "Tables (VLM → table)" if self.use_vlm else "Tables (cropped)"
 
@@ -154,11 +139,9 @@ class ChartTablePDFParser:
         table_counter = 1
 
         with ExitStack() as stack:
-            # Enhanced environment detection
             is_notebook = "ipykernel" in sys.modules or "jupyter" in sys.modules
             is_terminal = hasattr(sys.stdout, 'isatty') and sys.stdout.isatty()
 
-            # Use appropriate progress bars based on environment
             if is_notebook:
                 charts_bar = stack.enter_context(
                     create_notebook_friendly_bar(total=chart_count, desc=charts_desc)) if chart_count else None
@@ -174,23 +157,19 @@ class ChartTablePDFParser:
                 page_num = p.page_index
                 page_img: Image.Image = pil_pages[page_num - 1]
 
-                # Only process selected item types
                 target_items = [box for box in p.boxes if box.label in target_labels]
 
                 if target_items and self.use_vlm:
                     md_lines.append(f"\n## Page {page_num}\n")
 
                 for box in sorted(target_items, key=reading_order_key):
-                    # Handle charts
                     if box.label == "chart" and self.extract_charts:
                         chart_filename = f"chart_{chart_counter:03d}.png"
                         chart_path = os.path.join(charts_dir, chart_filename)
 
-                        # Save image
                         cropped_img = page_img.crop((box.x1, box.y1, box.x2, box.y2))
                         cropped_img.save(chart_path)
 
-                        # Handle VLM processing if enabled
                         if self.use_vlm and self.vlm:
                             rel_path = os.path.join("charts", chart_filename)
                             wrote_table = False
@@ -227,16 +206,13 @@ class ChartTablePDFParser:
                             if charts_bar:
                                 charts_bar.update(1)
 
-                    # Handle tables
                     elif box.label == "table" and self.extract_tables:
                         table_filename = f"table_{table_counter:03d}.png"
                         table_path = os.path.join(tables_dir, table_filename)
 
-                        # Save image
                         cropped_img = page_img.crop((box.x1, box.y1, box.x2, box.y2))
                         cropped_img.save(table_path)
 
-                        # Handle VLM processing if enabled
                         if self.use_vlm and self.vlm:
                             rel_path = os.path.join("tables", table_filename)
                             wrote_table = False
@@ -273,19 +249,11 @@ class ChartTablePDFParser:
                             if tables_bar:
                                 tables_bar.update(1)
 
-        # Write outputs only if VLM is used
-        md_path = None
         excel_path = None
 
         if self.use_vlm:
-            # Write markdown file
-            md_path = os.path.join(out_dir, "charts.md")
-            with open(md_path, 'w', encoding='utf-8') as f:
-                f.write('\n'.join(md_lines))
 
-            # Write Excel file if we have structured data
             if structured_items:
-                # Determine Excel filename based on extraction target
                 if self.extract_charts and self.extract_tables:
                     excel_filename = "parsed_tables_charts.xlsx"
                 elif self.extract_charts:
@@ -299,23 +267,19 @@ class ChartTablePDFParser:
                 excel_path = os.path.join(out_dir, excel_filename)
                 write_structured_excel(excel_path, structured_items)
 
-                # Also create HTML version
                 html_filename = excel_filename.replace('.xlsx', '.html')
                 html_path = os.path.join(out_dir, html_filename)
                 write_structured_html(html_path, structured_items)
 
-            # Write VLM items mapping for UI linkage
             if 'vlm_items' in locals() and vlm_items:
                 with open(os.path.join(out_dir, "vlm_items.json"), 'w', encoding='utf-8') as jf:
                     json.dump(vlm_items, jf, ensure_ascii=False, indent=2)
 
-        # Print results
         extraction_types = []
         if self.extract_charts:
             extraction_types.append("charts")
         if self.extract_tables:
             extraction_types.append("tables")
 
-        # Print completion message with output directory
         print(f"✅ Parsing completed successfully!")
         print(f"📁 Output directory: {out_dir}")
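When use_vlm is True, the hunks above show an Excel workbook, an HTML twin, and a vlm_items.json mapping written beside the crops. A sketch of reading them back (pandas and openpyxl are assumptions, not dependencies shown in this diff; the directory continues the placeholder from the earlier example):

import json
import pandas as pd

out_dir = "outputs/report/structured_parsing"
# Workbook name per the code above when both charts and tables are extracted.
sheets = pd.read_excel(f"{out_dir}/parsed_tables_charts.xlsx", sheet_name=None)
with open(f"{out_dir}/vlm_items.json", encoding="utf-8") as jf:
    vlm_items = json.load(jf)  # the per-item mapping written for UI linkage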
doctra/third_party/docres/data/MBD/MBD.py (new file)
@@ -0,0 +1,110 @@
+import cv2
+import numpy as np
+import MBD_utils
+import torch
+import torch.nn.functional as F
+
+
+def mask_base_dewarper(image,mask):
+    '''
+    input:
+        image -> ndarray HxWx3 uint8
+        mask  -> ndarray HxW uint8
+    return
+        dewarped -> ndarray HxWx3 uint8
+        grid (optional) -> ndarray HxWx2 -1~1
+    '''
+
+    ## get contours
+    # _, contours, hierarchy = cv2.findContours(mask,cv2.RETR_EXTERNAL,cv2.CHAIN_APPROX_NONE) ## cv2.__version__ == 3.x
+    contours,hierarchy = cv2.findContours(mask,cv2.RETR_EXTERNAL,method=cv2.CHAIN_APPROX_SIMPLE) ## cv2.__version__ == 4.x
+
+    ## get biggest contours and four corners based on Douglas-Peucker algorithm
+    four_corners, maxArea, contour = MBD_utils.DP_algorithm(contours)
+    four_corners = MBD_utils.reorder(four_corners)
+
+    ## reserve biggest contours and remove other noisy contours
+    new_mask = np.zeros_like(mask)
+    new_mask = cv2.drawContours(new_mask,[contour],-1,255,cv2.FILLED)
+
+    ## obtain middle points
+    # ratios = [0.125,0.25,0.375,0.5,0.625,0.75,0.875]
+    ratios = [0.25,0.5,0.75]
+    # ratios = [0.0625,0.125,0.1875,0.25,0.3125,0.375,0.4475,0.5,0.5625,0.625,0.06875,0.75,0.8125,0.875,0.9375]
+    middle = MBD_utils.findMiddle(corners=four_corners,mask=new_mask,points=ratios)
+
+    ## all points
+    source_points = np.concatenate((four_corners,middle),axis=0) ## all_point = four_corners(topleft,topright,bottom)+top+bottom+left+right
+
+    ## target points
+    h,w = image.shape[:2]
+    padding = 0
+    target_points = [[padding, padding],[w-padding, padding], [padding, h-padding],[w-padding, h-padding]]
+    for ratio in ratios:
+        target_points.append([int((w-2*padding)*ratio)+padding,padding])
+    for ratio in ratios:
+        target_points.append([int((w-2*padding)*ratio)+padding,h-padding])
+    for ratio in ratios:
+        target_points.append([padding,int((h-2*padding)*ratio)+padding])
+    for ratio in ratios:
+        target_points.append([w-padding,int((h-2*padding)*ratio)+padding])
+
+    ## dewarp based on cv2
+    # pts1 = np.float32(source_points)
+    # pts2 = np.float32(target_points)
+    # tps = cv2.createThinPlateSplineShapeTransformer()
+    # matches = []
+    # N = pts1.shape[0]
+    # for i in range(0,N):
+    #     matches.append(cv2.DMatch(i,i,0))
+    # pts1 = pts1.reshape(1,-1,2)
+    # pts2 = pts2.reshape(1,-1,2)
+    # tps.estimateTransformation(pts2,pts1,matches)
+    # dewarped = tps.warpImage(image)
+
+    ## dewarp based on generated grid
+    source_points = source_points.reshape(-1,2)/np.array([image.shape[:2][::-1]]).reshape(1,2)
+    source_points = torch.from_numpy(source_points).float().cuda()
+    source_points = source_points.unsqueeze(0)
+    source_points = (source_points-0.5)*2
+    target_points = np.asarray(target_points).reshape(-1,2)/np.array([image.shape[:2][::-1]]).reshape(1,2)
+    target_points = torch.from_numpy(target_points).float()
+    target_points = (target_points-0.5)*2
+
+    model = MBD_utils.TPSGridGen(target_height=256,target_width=256,target_control_points=target_points)
+    model = model.cuda()
+    grid = model(source_points).view(-1,256,256,2).permute(0,3,1,2)
+    grid = F.interpolate(grid,(h,w),mode='bilinear').permute(0,2,3,1)
+    dewarped = MBD_utils.torch2cvimg(F.grid_sample(MBD_utils.cvimg2torch(image).cuda(),grid))[0]
+    return dewarped,grid[0].cpu().numpy()
+
+def mask_base_cropper(image,mask):
+    '''
+    input:
+        image -> ndarray HxWx3 uint8
+        mask  -> ndarray HxW uint8
+    return
+        dewarped -> ndarray HxWx3 uint8
+        grid (optional) -> ndarray HxWx2 -1~1
+    '''
+
+    ## get contours
+    _, contours, hierarchy = cv2.findContours(mask,cv2.RETR_EXTERNAL,cv2.CHAIN_APPROX_NONE) ## cv2.__version__ == 3.x
+    # contours,hierarchy = cv2.findContours(mask,cv2.RETR_EXTERNAL,method=cv2.CHAIN_APPROX_SIMPLE) ## cv2.__version__ == 4.x
+
+    ## get biggest contours and four corners based on Douglas-Peucker algorithm
+    four_corners, maxArea, contour = MBD_utils.DP_algorithm(contours)
+    four_corners = MBD_utils.reorder(four_corners)
+
+    ## reserve biggest contours and remove other noisy contours
+    new_mask = np.zeros_like(mask)
+    new_mask = cv2.drawContours(new_mask,[contour],-1,255,cv2.FILLED)
+
+    ## minimum-area bounding rectangle
+    rect = cv2.minAreaRect(contour)  # (center (x, y), (width, height), rotation angle)
+    box = cv2.boxPoints(rect)  # the 4 corner points of the rectangle (cv2.boxPoints(rect) for OpenCV 3.x)
+    box = np.int0(box)
+    box = box.reshape((4,1,2))
+
+
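mask_base_dewarper fits a thin-plate spline from 16 control points (the 4 corners plus 3 midpoints per edge, per the ratios above) sampled off the mask boundary, then resamples the page through the resulting grid. A call sketch (file paths are placeholders; as written the function moves tensors to CUDA, so a GPU build of PyTorch is required):

import cv2

image = cv2.imread("warped_page.jpg")                     # HxWx3 uint8
mask = cv2.imread("page_mask.png", cv2.IMREAD_GRAYSCALE)  # HxW uint8, 0/255
dewarped, grid = mask_base_dewarper(image, mask)
cv2.imwrite("dewarped_page.jpg", dewarped)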
doctra/third_party/docres/data/MBD/MBD_utils.py (new file)
@@ -0,0 +1,291 @@
+import cv2
+import numpy as np
+import copy
+import torch
+import torch
+import itertools
+import torch.nn as nn
+from torch.autograd import Function, Variable
+
+def reorder(myPoints):
+    myPoints = myPoints.reshape((4, 2))
+    myPointsNew = np.zeros((4, 1, 2), dtype=np.int32)
+    add = myPoints.sum(1)
+    myPointsNew[0] = myPoints[np.argmin(add)]
+    myPointsNew[3] = myPoints[np.argmax(add)]
+    diff = np.diff(myPoints, axis=1)
+    myPointsNew[1] = myPoints[np.argmin(diff)]
+    myPointsNew[2] = myPoints[np.argmax(diff)]
+    return myPointsNew
+
+
+def findMiddle(corners,mask,points=[0.25,0.5,0.75]):
+    num_middle_points = len(points)
+    top = [np.array([])]*num_middle_points
+    bottom = [np.array([])]*num_middle_points
+    left = [np.array([])]*num_middle_points
+    right = [np.array([])]*num_middle_points
+
+    center_top = []
+    center_bottom = []
+    center_left = []
+    center_right = []
+
+    center = (int((corners[0][0][1]+corners[3][0][1])/2),int((corners[0][0][0]+corners[3][0][0])/2))
+    for ratio in points:
+
+        center_top.append( (center[0],int(corners[0][0][0]*(1-ratio)+corners[1][0][0]*ratio)) )
+
+        center_bottom.append( (center[0],int(corners[2][0][0]*(1-ratio)+corners[3][0][0]*ratio)) )
+
+        center_left.append( (int(corners[0][0][1]*(1-ratio)+corners[2][0][1]*ratio),center[1]) )
+
+        center_right.append( (int(corners[1][0][1]*(1-ratio)+corners[3][0][1]*ratio),center[1]) )
+
+    for i in range(0,center[0],1):
+        for j in range(num_middle_points):
+            if top[j].size==0:
+                if mask[i,center_top[j][1]]==255:
+                    top[j] = np.asarray([center_top[j][1],i])
+                    top[j] = top[j].reshape(1,2)
+
+    for i in range(mask.shape[0]-1,center[0],-1):
+        for j in range(num_middle_points):
+            if bottom[j].size==0:
+                if mask[i,center_bottom[j][1]]==255:
+                    bottom[j] = np.asarray([center_bottom[j][1],i])
+                    bottom[j] = bottom[j].reshape(1,2)
+
+    for i in range(mask.shape[1]-1,center[1],-1):
+        for j in range(num_middle_points):
+            if right[j].size==0:
+                if mask[center_right[j][0],i]==255:
+                    right[j] = np.asarray([i,center_right[j][0]])
+                    right[j] = right[j].reshape(1,2)
+
+    for i in range(0,center[1]):
+        for j in range(num_middle_points):
+            if left[j].size==0:
+                if mask[center_left[j][0],i]==255:
+                    left[j] = np.asarray([i,center_left[j][0]])
+                    left[j] = left[j].reshape(1,2)
+
+    return np.asarray(top+bottom+left+right)
+
+def DP_algorithmv1(contours):
+    biggest = np.array([])
+    max_area = 0
+    step = 0.001
+    count = 0
+    # while biggest.size==0:
+    while True:
+        for i in contours:
+            # print(i.shape)
+            area = cv2.contourArea(i)
+            # print(area,cv2.arcLength(i, True))
+            if area > cv2.arcLength(i, True)*10:
+                peri = cv2.arcLength(i, True)
+                approx = cv2.approxPolyDP(i, (0.01+step*count) * peri, True)
+                if area > max_area and len(approx) == 4:
+                    max_area = area
+                    biggest_contours = i
+                    biggest = approx
+        break
+        if abs(max_area - cv2.contourArea(biggest))/max_area > 0.3:
+            biggest = np.array([])
+        count += 1
+        if count > 200:
+            break
+    temp = biggest[0]
+    return biggest,max_area, biggest_contours
+
+def DP_algorithm(contours):
+    biggest = np.array([])
+    max_area = 0
+    step = 0.001
+    count = 0
+
+    ### largest contours
+    for i in contours:
+        area = cv2.contourArea(i)
+        if area > max_area:
+            max_area = area
+            biggest_contours = i
+    peri = cv2.arcLength(biggest_contours, True)
+
+    ### find four corners
+    while True:
+        approx = cv2.approxPolyDP(biggest_contours, (0.01+step*count) * peri, True)
+        if len(approx) == 4:
+            biggest = approx
+            break
+        # if abs(max_area - cv2.contourArea(biggest))/max_area > 0.2:
+        # if abs(max_area - cv2.contourArea(biggest))/max_area > 0.4:
+        #     biggest = np.array([])
+        count += 1
+        if count > 200:
+            break
+    return biggest,max_area, biggest_contours
+
+def drawRectangle(img,biggest,color,thickness):
+    cv2.line(img, (biggest[0][0][0], biggest[0][0][1]), (biggest[1][0][0], biggest[1][0][1]), color, thickness)
+    cv2.line(img, (biggest[0][0][0], biggest[0][0][1]), (biggest[2][0][0], biggest[2][0][1]), color, thickness)
+    cv2.line(img, (biggest[3][0][0], biggest[3][0][1]), (biggest[2][0][0], biggest[2][0][1]), color, thickness)
+    cv2.line(img, (biggest[3][0][0], biggest[3][0][1]), (biggest[1][0][0], biggest[1][0][1]), color, thickness)
+    return img
+
+def minAreaRect(contours,img):
+    # biggest = np.array([])
+    max_area = 0
+    for i in contours:
+        area = cv2.contourArea(i)
+        if area > max_area:
+            peri = cv2.arcLength(i, True)
+            rect = cv2.minAreaRect(i)
+            points = cv2.boxPoints(rect)
+            max_area = area
+    return points
+
+def cropRectangle(img,biggest):
+    # print(biggest)
+    w = np.abs(biggest[0][0][0] - biggest[1][0][0])
+    h = np.abs(biggest[0][0][1] - biggest[2][0][1])
+    new_img = np.zeros((w,h,img.shape[-1]),dtype=np.uint8)
+    new_img = img[biggest[0][0][1]:biggest[0][0][1]+h,biggest[0][0][0]:biggest[0][0][0]+w]
+    return new_img
+
+def cvimg2torch(img,min=0,max=1):
+    '''
+    input:
+        im -> ndarray uint8 HxWxC
+    return
+        tensor -> torch.tensor BxCxHxW
+    '''
+    if len(img.shape)==2:
+        img = np.expand_dims(img,axis=-1)
+    img = img.astype(float) / 255.0
+    img = img.transpose(2, 0, 1)  # HWC -> CHW
+    img = np.expand_dims(img, 0)
+    img = torch.from_numpy(img).float()
+    return img
+
+def torch2cvimg(tensor,min=0,max=1):
+    '''
+    input:
+        tensor -> torch.tensor BxCxHxW, C can be 1 or 3
+    return
+        im -> ndarray uint8 HxWxC
+    '''
+    im_list = []
+    for i in range(tensor.shape[0]):
+        im = tensor.detach().cpu().data.numpy()[i]
+        im = im.transpose(1,2,0)
+        im = np.clip(im,min,max)
+        im = ((im-min)/(max-min)*255).astype(np.uint8)
+        im_list.append(im)
+    return im_list
+
+
+
+class TPSGridGen(nn.Module):
+    def __init__(self, target_height, target_width, target_control_points):
+        '''
+        target_control_points -> torch.tensor num_point x 2, -1~1
+        source_control_points -> torch.tensor batch_size x num_point x 2, -1~1
+        return:
+            grid -> batch_size x hw x 2, -1~1
+        '''
+        super(TPSGridGen, self).__init__()
+        assert target_control_points.ndimension() == 2
+        assert target_control_points.size(1) == 2
+        N = target_control_points.size(0)
+        self.num_points = N
+        target_control_points = target_control_points.float()
+
+        # create padded kernel matrix
+        forward_kernel = torch.zeros(N + 3, N + 3)
+        target_control_partial_repr = self.compute_partial_repr(target_control_points, target_control_points)
+        forward_kernel[:N, :N].copy_(target_control_partial_repr)
+        forward_kernel[:N, -3].fill_(1)
+        forward_kernel[-3, :N].fill_(1)
+        forward_kernel[:N, -2:].copy_(target_control_points)
+        forward_kernel[-2:, :N].copy_(target_control_points.transpose(0, 1))
+        # compute inverse matrix
+        inverse_kernel = torch.inverse(forward_kernel)
+
+        # create target coordinate matrix
+        HW = target_height * target_width
+        target_coordinate = list(itertools.product(range(target_height), range(target_width)))
+        target_coordinate = torch.Tensor(target_coordinate)  # HW x 2
+        Y, X = target_coordinate.split(1, dim = 1)
+        Y = Y * 2 / (target_height - 1) - 1
+        X = X * 2 / (target_width - 1) - 1
+        target_coordinate = torch.cat([X, Y], dim = 1)  # convert from (y, x) to (x, y)
+        target_coordinate_partial_repr = self.compute_partial_repr(target_coordinate.to(target_control_points.device), target_control_points)
+        target_coordinate_repr = torch.cat([
+            target_coordinate_partial_repr, torch.ones(HW, 1), target_coordinate
+        ], dim = 1)
+
+        # register precomputed matrices
+        self.register_buffer('inverse_kernel', inverse_kernel)
+        self.register_buffer('padding_matrix', torch.zeros(3, 2))
+        self.register_buffer('target_coordinate_repr', target_coordinate_repr)
+
+    def forward(self, source_control_points):
+        assert source_control_points.ndimension() == 3
+        assert source_control_points.size(1) == self.num_points
+        assert source_control_points.size(2) == 2
+        batch_size = source_control_points.size(0)
+
+        Y = torch.cat([source_control_points, Variable(self.padding_matrix.expand(batch_size, 3, 2))], 1)
+        mapping_matrix = torch.matmul(Variable(self.inverse_kernel), Y)
+        source_coordinate = torch.matmul(Variable(self.target_coordinate_repr), mapping_matrix)
+        return source_coordinate
+
+    # phi(x1, x2) = r^2 * log(r), where r = ||x1 - x2||_2
+    def compute_partial_repr(self, input_points, control_points):
+        N = input_points.size(0)
+        M = control_points.size(0)
+        pairwise_diff = input_points.view(N, 1, 2) - control_points.view(1, M, 2)
+        # original implementation, very slow
+        # pairwise_dist = torch.sum(pairwise_diff ** 2, dim = 2)  # square of distance
+        pairwise_diff_square = pairwise_diff * pairwise_diff
+        pairwise_dist = pairwise_diff_square[:, :, 0] + pairwise_diff_square[:, :, 1]
+        repr_matrix = 0.5 * pairwise_dist * torch.log(pairwise_dist)
+        # fix numerical error for 0 * log(0), substitute all nan with 0
+        mask = repr_matrix != repr_matrix
+        repr_matrix.masked_fill_(mask, 0)
+        return repr_matrix
+
+
+
+
+### decide whether to process further
+# point_area = cv2.contourArea(np.concatenate((biggest_angle[0].reshape(1,1,2),middle[0:3],biggest_angle[1].reshape(1,1,2),middle[9:12],biggest_angle[3].reshape(1,1,2),middle[3:6][::-1],biggest_angle[2].reshape(1,1,2),middle[6:9][::-1]),axis=0))
+#### minimum-area bounding rectangle
+# rect = cv2.minAreaRect(contour)  # (center (x, y), (width, height), rotation angle)
+# box = cv2.boxPoints(rect)  # the 4 corner points (cv2.boxPoints(rect) for OpenCV 3.x)
+# box = np.int0(box)
+# box = box.reshape((4,1,2))
+# minrect_area = cv2.contourArea(box)
+# print(abs(minrect_area-point_area)/point_area)
+#### IoU of the four corner points
+# biggest_box = np.concatenate((biggest_angle[0,:,:].reshape(1,1,2),biggest_angle[2,:,:].reshape(1,1,2),biggest_angle[3,:,:].reshape(1,1,2),biggest_angle[1,:,:].reshape(1,1,2)),axis=0)
+# biggest_mask = np.zeros_like(mask)
+# # corner_area = cv2.contourArea(biggest_box)
+# cv2.drawContours(biggest_mask,[biggest_box], -1, color=255, thickness=-1)
+
+# smooth = 1e-5
+# biggest_mask_ = biggest_mask > 50
+# mask_ = mask > 50
+# intersection = (biggest_mask_ & mask_).sum()
+# union = (biggest_mask_ | mask_).sum()
+# iou = (intersection + smooth) / (union + smooth)
+# if iou > 0.975:
+#     skip = True
+# else:
+#     skip = False
+# print(iou)
+# cv2.imshow('mask',cv2.resize(mask,(512,512)))
+# cv2.imshow('biggest_mask',cv2.resize(biggest_mask,(512,512)))
+# cv2.waitKey(0)
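TPSGridGen precomputes the thin-plate-spline kernel for a fixed set of target control points; forward() then maps a batch of source control points to a flattened sampling grid usable with F.grid_sample. A minimal CPU sketch of that contract (the corner coordinates and random image are illustrative only):

import torch
import torch.nn.functional as F

# Four target corners in normalized -1~1 coordinates, matching the class docstring.
target_pts = torch.tensor([[-1., -1.], [1., -1.], [-1., 1.], [1., 1.]])
tps = TPSGridGen(target_height=64, target_width=64, target_control_points=target_pts)

# Pull the source corners slightly inward to produce a mild warp.
source_pts = (target_pts * 0.9).unsqueeze(0)   # batch_size x num_point x 2
grid = tps(source_pts).view(1, 64, 64, 2)      # batch_size x H x W x 2

img = torch.rand(1, 3, 64, 64)
warped = F.grid_sample(img, grid, align_corners=True)  # resample through the TPS grid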