doctra 0.3.2__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. doctra/__init__.py +4 -0
  2. doctra/cli/main.py +168 -0
  3. doctra/engines/image_restoration/__init__.py +10 -0
  4. doctra/engines/image_restoration/docres_engine.py +566 -0
  5. doctra/engines/vlm/service.py +0 -12
  6. doctra/parsers/enhanced_pdf_parser.py +370 -0
  7. doctra/parsers/structured_pdf_parser.py +11 -60
  8. doctra/parsers/table_chart_extractor.py +8 -44
  9. doctra/third_party/docres/data/MBD/MBD.py +110 -0
  10. doctra/third_party/docres/data/MBD/MBD_utils.py +291 -0
  11. doctra/third_party/docres/data/MBD/infer.py +151 -0
  12. doctra/third_party/docres/data/MBD/model/deep_lab_model/aspp.py +95 -0
  13. doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/__init__.py +13 -0
  14. doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/drn.py +402 -0
  15. doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/mobilenet.py +151 -0
  16. doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/resnet.py +170 -0
  17. doctra/third_party/docres/data/MBD/model/deep_lab_model/backbone/xception.py +288 -0
  18. doctra/third_party/docres/data/MBD/model/deep_lab_model/decoder.py +59 -0
  19. doctra/third_party/docres/data/MBD/model/deep_lab_model/deeplab.py +81 -0
  20. doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/__init__.py +12 -0
  21. doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/batchnorm.py +282 -0
  22. doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/comm.py +129 -0
  23. doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/replicate.py +88 -0
  24. doctra/third_party/docres/data/MBD/model/deep_lab_model/sync_batchnorm/unittest.py +29 -0
  25. doctra/third_party/docres/data/preprocess/crop_merge_image.py +142 -0
  26. doctra/third_party/docres/inference.py +370 -0
  27. doctra/third_party/docres/models/restormer_arch.py +308 -0
  28. doctra/third_party/docres/utils.py +464 -0
  29. doctra/ui/app.py +5 -32
  30. doctra/utils/progress.py +13 -98
  31. doctra/utils/structured_utils.py +45 -49
  32. doctra/version.py +1 -1
  33. {doctra-0.3.2.dist-info → doctra-0.4.0.dist-info}/METADATA +1 -1
  34. doctra-0.4.0.dist-info/RECORD +67 -0
  35. doctra-0.3.2.dist-info/RECORD +0 -44
  36. {doctra-0.3.2.dist-info → doctra-0.4.0.dist-info}/WHEEL +0 -0
  37. {doctra-0.3.2.dist-info → doctra-0.4.0.dist-info}/licenses/LICENSE +0 -0
  38. {doctra-0.3.2.dist-info → doctra-0.4.0.dist-info}/top_level.txt +0 -0
doctra/parsers/table_chart_extractor.py
@@ -61,22 +61,17 @@ class ChartTablePDFParser:
     ):
         """
         Initialize the ChartTablePDFParser with extraction configuration.
-
-        Sets up the layout detection engine and optionally the VLM service
-        for structured data extraction.
 
-        :param extract_charts: Whether to extract charts from the document
-        :param extract_tables: Whether to extract tables from the document
-        :param use_vlm: Whether to use VLM for structured data extraction
-        :param vlm_provider: VLM provider to use ("gemini", "openai", "anthropic", or "openrouter")
+        :param extract_charts: Whether to extract charts from the document (default: True)
+        :param extract_tables: Whether to extract tables from the document (default: True)
+        :param use_vlm: Whether to use VLM for structured data extraction (default: False)
+        :param vlm_provider: VLM provider to use ("gemini", "openai", "anthropic", or "openrouter", default: "gemini")
         :param vlm_model: Model name to use (defaults to provider-specific defaults)
-        :param vlm_api_key: API key for VLM provider
-        :param layout_model_name: Layout detection model name
-        :param dpi: DPI for PDF rendering
-        :param min_score: Minimum confidence score for layout detection
-        :raises ValueError: If neither extract_charts nor extract_tables is True
+        :param vlm_api_key: API key for VLM provider (required if use_vlm is True)
+        :param layout_model_name: Layout detection model name (default: "PP-DocLayout_plus-L")
+        :param dpi: DPI for PDF rendering (default: 200)
+        :param min_score: Minimum confidence score for layout detection (default: 0.0)
         """
-        # Validation
         if not extract_charts and not extract_tables:
             raise ValueError("At least one of extract_charts or extract_tables must be True")
 
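The updated docstring now records each parameter's default. A minimal usage sketch based on those documented defaults; the import path is inferred from the file name doctra/parsers/table_chart_extractor.py, and the API key value is a placeholder:

```python
from doctra.parsers.table_chart_extractor import ChartTablePDFParser

# Defaults per the docstring: charts and tables on, VLM off,
# layout model "PP-DocLayout_plus-L", DPI 200, min_score 0.0.
parser = ChartTablePDFParser()

# VLM-backed extraction requires an API key for the chosen provider.
vlm_parser = ChartTablePDFParser(
    extract_tables=True,
    extract_charts=False,
    use_vlm=True,
    vlm_provider="gemini",
    vlm_api_key="YOUR_API_KEY",  # placeholder; required when use_vlm=True
)
```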
@@ -98,21 +93,15 @@ class ChartTablePDFParser:
     def parse(self, pdf_path: str, output_base_dir: str = "outputs") -> None:
         """
         Parse a PDF document and extract charts and/or tables.
-
-        Processes the PDF through layout detection, extracts the specified
-        element types, saves cropped images, and optionally converts them
-        to structured data using VLM.
 
         :param pdf_path: Path to the input PDF file
         :param output_base_dir: Base directory for output files (default: "outputs")
         :return: None
         """
-        # Create output directory structure: outputs/<filename>/structured_parsing/
         pdf_name = Path(pdf_path).stem
         out_dir = os.path.join(output_base_dir, pdf_name, "structured_parsing")
         os.makedirs(out_dir, exist_ok=True)
 
-        # Create subdirectories based on what we're extracting
        charts_dir = None
        tables_dir = None
 
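Per the docstring and the directory logic above, parse() writes everything under outputs/<pdf-stem>/structured_parsing/. A short sketch (the PDF name is a placeholder; the subdirectory and file names follow from the hunks below):

```python
parser = ChartTablePDFParser()   # defaults: charts and tables, no VLM
parser.parse("report.pdf")
# Expected layout, following the code above:
#   outputs/report/structured_parsing/charts/chart_001.png, chart_002.png, ...
#   outputs/report/structured_parsing/tables/table_001.png, table_002.png, ...
```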
@@ -129,24 +118,20 @@ class ChartTablePDFParser:
         )
         pil_pages = [im for (im, _, _) in render_pdf_to_images(pdf_path, dpi=self.dpi)]
 
-        # Determine which labels to extract
         target_labels = []
         if self.extract_charts:
             target_labels.append("chart")
         if self.extract_tables:
             target_labels.append("table")
 
-        # Count items for progress bars
         chart_count = sum(sum(1 for b in p.boxes if b.label == "chart") for p in pages) if self.extract_charts else 0
         table_count = sum(sum(1 for b in p.boxes if b.label == "table") for p in pages) if self.extract_tables else 0
 
-        # Prepare output content
         if self.use_vlm:
             md_lines: List[str] = ["# Extracted Charts and Tables\n"]
             structured_items: List[Dict[str, Any]] = []
             vlm_items: List[Dict[str, Any]] = []
 
-        # Progress bar descriptions
         charts_desc = "Charts (VLM → table)" if self.use_vlm else "Charts (cropped)"
         tables_desc = "Tables (VLM → table)" if self.use_vlm else "Tables (cropped)"
 
@@ -154,11 +139,9 @@
         table_counter = 1
 
         with ExitStack() as stack:
-            # Enhanced environment detection
            is_notebook = "ipykernel" in sys.modules or "jupyter" in sys.modules
            is_terminal = hasattr(sys.stdout, 'isatty') and sys.stdout.isatty()
 
-            # Use appropriate progress bars based on environment
            if is_notebook:
                charts_bar = stack.enter_context(
                    create_notebook_friendly_bar(total=chart_count, desc=charts_desc)) if chart_count else None
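The ExitStack idiom above lets each progress bar be created only when there is matching work, while still guaranteeing cleanup on exit. A stand-alone sketch of the same pattern, using tqdm as a stand-in for the project's create_notebook_friendly_bar helper:

```python
from contextlib import ExitStack
from tqdm import tqdm

chart_count, table_count = 3, 0
with ExitStack() as stack:
    # Bars are entered only when their count is non-zero; the stack
    # closes whichever bars were actually opened.
    charts_bar = stack.enter_context(tqdm(total=chart_count, desc="Charts")) if chart_count else None
    tables_bar = stack.enter_context(tqdm(total=table_count, desc="Tables")) if table_count else None
    for _ in range(chart_count):
        if charts_bar:
            charts_bar.update(1)
```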
@@ -174,23 +157,19 @@
                page_num = p.page_index
                page_img: Image.Image = pil_pages[page_num - 1]
 
-                # Only process selected item types
                target_items = [box for box in p.boxes if box.label in target_labels]
 
                if target_items and self.use_vlm:
                    md_lines.append(f"\n## Page {page_num}\n")
 
                for box in sorted(target_items, key=reading_order_key):
-                    # Handle charts
                    if box.label == "chart" and self.extract_charts:
                        chart_filename = f"chart_{chart_counter:03d}.png"
                        chart_path = os.path.join(charts_dir, chart_filename)
 
-                        # Save image
                        cropped_img = page_img.crop((box.x1, box.y1, box.x2, box.y2))
                        cropped_img.save(chart_path)
 
-                        # Handle VLM processing if enabled
                        if self.use_vlm and self.vlm:
                            rel_path = os.path.join("charts", chart_filename)
                            wrote_table = False
@@ -227,16 +206,13 @@
                        if charts_bar:
                            charts_bar.update(1)
 
-                    # Handle tables
                    elif box.label == "table" and self.extract_tables:
                        table_filename = f"table_{table_counter:03d}.png"
                        table_path = os.path.join(tables_dir, table_filename)
 
-                        # Save image
                        cropped_img = page_img.crop((box.x1, box.y1, box.x2, box.y2))
                        cropped_img.save(table_path)
 
-                        # Handle VLM processing if enabled
                        if self.use_vlm and self.vlm:
                            rel_path = os.path.join("tables", table_filename)
                            wrote_table = False
@@ -273,19 +249,11 @@
                        if tables_bar:
                            tables_bar.update(1)
 
-        # Write outputs only if VLM is used
-        md_path = None
        excel_path = None
 
        if self.use_vlm:
-            # Write markdown file
-            md_path = os.path.join(out_dir, "charts.md")
-            with open(md_path, 'w', encoding='utf-8') as f:
-                f.write('\n'.join(md_lines))
 
-            # Write Excel file if we have structured data
            if structured_items:
-                # Determine Excel filename based on extraction target
                if self.extract_charts and self.extract_tables:
                    excel_filename = "parsed_tables_charts.xlsx"
                elif self.extract_charts:
@@ -299,23 +267,19 @@
                excel_path = os.path.join(out_dir, excel_filename)
                write_structured_excel(excel_path, structured_items)
 
-                # Also create HTML version
                html_filename = excel_filename.replace('.xlsx', '.html')
                html_path = os.path.join(out_dir, html_filename)
                write_structured_html(html_path, structured_items)
 
-            # Write VLM items mapping for UI linkage
            if 'vlm_items' in locals() and vlm_items:
                with open(os.path.join(out_dir, "vlm_items.json"), 'w', encoding='utf-8') as jf:
                    json.dump(vlm_items, jf, ensure_ascii=False, indent=2)
 
-        # Print results
        extraction_types = []
        if self.extract_charts:
            extraction_types.append("charts")
        if self.extract_tables:
            extraction_types.append("tables")
 
-        # Print completion message with output directory
        print(f"✅ Parsing completed successfully!")
        print(f"📁 Output directory: {out_dir}")
doctra/third_party/docres/data/MBD/MBD.py (new file)
@@ -0,0 +1,110 @@
+import cv2
+import numpy as np
+import MBD_utils
+import torch
+import torch.nn.functional as F
+
+
+def mask_base_dewarper(image,mask):
+    '''
+    input:
+        image -> ndarray HxWx3 uint8
+        mask -> ndarray HxW uint8
+    return
+        dewarped -> ndarray HxWx3 uint8
+        grid (optional) -> ndarray HxWx2 -1~1
+    '''
+
+    ## get contours
+    # _, contours, hierarchy = cv2.findContours(mask,cv2.RETR_EXTERNAL,cv2.CHAIN_APPROX_NONE) ## cv2.__version__ == 3.x
+    contours,hierarchy = cv2.findContours(mask,cv2.RETR_EXTERNAL,method=cv2.CHAIN_APPROX_SIMPLE) ## cv2.__version__ == 4.x
+
+    ## get biggest contours and four corners based on Douglas-Peucker algorithm
+    four_corners, maxArea, contour = MBD_utils.DP_algorithm(contours)
+    four_corners = MBD_utils.reorder(four_corners)
+
+    ## reserve biggest contours and remove other noisy contours
+    new_mask = np.zeros_like(mask)
+    new_mask = cv2.drawContours(new_mask,[contour],-1,255,cv2.FILLED)
+
+    ## obtain middle points
+    # ratios = [0.25,0.5,0.75] # ratios = [0.125,0.25,0.375,0.5,0.625,0.75,0.875]
+    ratios = [0.25,0.5,0.75]
+    # ratios = [0.0625,0.125,0.1875,0.25,0.3125,0.375,0.4475,0.5,0.5625,0.625,0.06875,0.75,0.8125,0.875,0.9375]
+    middle = MBD_utils.findMiddle(corners=four_corners,mask=new_mask,points=ratios)
+
+    ## all points
+    source_points = np.concatenate((four_corners,middle),axis=0) ## all_point = four_corners(topleft,topright,bottom)+top+bottom+left+right
+
+    ## target points
+    h,w = image.shape[:2]
+    padding = 0
+    target_points = [[padding, padding],[w-padding, padding], [padding, h-padding],[w-padding, h-padding]]
+    for ratio in ratios:
+        target_points.append([int((w-2*padding)*ratio)+padding,padding])
+    for ratio in ratios:
+        target_points.append([int((w-2*padding)*ratio)+padding,h-padding])
+    for ratio in ratios:
+        target_points.append([padding,int((h-2*padding)*ratio)+padding])
+    for ratio in ratios:
+        target_points.append([w-padding,int((h-2*padding)*ratio)+padding])
+
+    ## dewarp based on cv2
+    # pts1 = np.float32(source_points)
+    # pts2 = np.float32(target_points)
+    # tps = cv2.createThinPlateSplineShapeTransformer()
+    # matches = []
+    # N = pts1.shape[0]
+    # for i in range(0,N):
+    #     matches.append(cv2.DMatch(i,i,0))
+    # pts1 = pts1.reshape(1,-1,2)
+    # pts2 = pts2.reshape(1,-1,2)
+    # tps.estimateTransformation(pts2,pts1,matches)
+    # dewarped = tps.warpImage(image)
+
+    ## dewarp based on generated grid
+    source_points = source_points.reshape(-1,2)/np.array([image.shape[:2][::-1]]).reshape(1,2)
+    source_points = torch.from_numpy(source_points).float().cuda()
+    source_points = source_points.unsqueeze(0)
+    source_points = (source_points-0.5)*2
+    target_points = np.asarray(target_points).reshape(-1,2)/np.array([image.shape[:2][::-1]]).reshape(1,2)
+    target_points = torch.from_numpy(target_points).float()
+    target_points = (target_points-0.5)*2
+
+    model = MBD_utils.TPSGridGen(target_height=256,target_width=256,target_control_points=target_points)
+    model = model.cuda()
+    grid = model(source_points).view(-1,256,256,2).permute(0,3,1,2)
+    grid = F.interpolate(grid,(h,w),mode='bilinear').permute(0,2,3,1)
+    dewarped = MBD_utils.torch2cvimg(F.grid_sample(MBD_utils.cvimg2torch(image).cuda(),grid))[0]
+    return dewarped,grid[0].cpu().numpy()
+
+def mask_base_cropper(image,mask):
+    '''
+    input:
+        image -> ndarray HxWx3 uint8
+        mask -> ndarray HxW uint8
+    return
+        dewarped -> ndarray HxWx3 uint8
+        grid (optional) -> ndarray HxWx2 -1~1
+    '''
+
+    ## get contours
+    _, contours, hierarchy = cv2.findContours(mask,cv2.RETR_EXTERNAL,cv2.CHAIN_APPROX_NONE) ## cv2.__version__ == 3.x
+    # contours,hierarchy = cv2.findContours(mask,cv2.RETR_EXTERNAL,method=cv2.CHAIN_APPROX_SIMPLE) ## cv2.__version__ == 4.x
+
+    ## get biggest contours and four corners based on Douglas-Peucker algorithm
+    four_corners, maxArea, contour = MBD_utils.DP_algorithm(contours)
+    four_corners = MBD_utils.reorder(four_corners)
+
+    ## reserve biggest contours and remove other noisy contours
+    new_mask = np.zeros_like(mask)
+    new_mask = cv2.drawContours(new_mask,[contour],-1,255,cv2.FILLED)
+
+    ## minimum-area bounding rectangle
+    rect = cv2.minAreaRect(contour) # returns (center (x, y), (width, height), rotation angle) of the minimum-area rectangle
+    box = cv2.boxPoints(rect) # cv2.boxPoints(rect) for OpenCV 3.x: get the 4 corner points of the minimum-area rectangle
+    box = np.int0(box)
+    box = box.reshape((4,1,2))
+
+
+
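A hedged usage sketch for mask_base_dewarper: MBD.py imports MBD_utils by bare module name, so the MBD directory must be on sys.path, and the .cuda() calls above require a CUDA-capable GPU. File names below are placeholders:

```python
import sys
sys.path.append("doctra/third_party/docres/data/MBD")  # so `import MBD_utils` resolves

import cv2
from MBD import mask_base_dewarper

image = cv2.imread("warped_page.png")                     # HxWx3 uint8 (BGR)
mask = cv2.imread("page_mask.png", cv2.IMREAD_GRAYSCALE)  # HxW uint8, 0 or 255
dewarped, grid = mask_base_dewarper(image, mask)          # grid: HxWx2 in [-1, 1]
cv2.imwrite("dewarped_page.png", dewarped)
```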
doctra/third_party/docres/data/MBD/MBD_utils.py (new file)
@@ -0,0 +1,291 @@
+import cv2
+import numpy as np
+import copy
+import torch
+import torch
+import itertools
+import torch.nn as nn
+from torch.autograd import Function, Variable
+
+def reorder(myPoints):
+    myPoints = myPoints.reshape((4, 2))
+    myPointsNew = np.zeros((4, 1, 2), dtype=np.int32)
+    add = myPoints.sum(1)
+    myPointsNew[0] = myPoints[np.argmin(add)]
+    myPointsNew[3] = myPoints[np.argmax(add)]
+    diff = np.diff(myPoints, axis=1)
+    myPointsNew[1] = myPoints[np.argmin(diff)]
+    myPointsNew[2] = myPoints[np.argmax(diff)]
+    return myPointsNew
+
+
+def findMiddle(corners,mask,points=[0.25,0.5,0.75]):
+    num_middle_points = len(points)
+    top = [np.array([])]*num_middle_points
+    bottom = [np.array([])]*num_middle_points
+    left = [np.array([])]*num_middle_points
+    right = [np.array([])]*num_middle_points
+
+    center_top = []
+    center_bottom = []
+    center_left = []
+    center_right = []
+
+    center = (int((corners[0][0][1]+corners[3][0][1])/2),int((corners[0][0][0]+corners[3][0][0])/2))
+    for ratio in points:
+
+        center_top.append( (center[0],int(corners[0][0][0]*(1-ratio)+corners[1][0][0]*ratio)) )
+
+        center_bottom.append( (center[0],int(corners[2][0][0]*(1-ratio)+corners[3][0][0]*ratio)) )
+
+        center_left.append( (int(corners[0][0][1]*(1-ratio)+corners[2][0][1]*ratio),center[1]) )
+
+        center_right.append( (int(corners[1][0][1]*(1-ratio)+corners[3][0][1]*ratio),center[1]) )
+
+    for i in range(0,center[0],1):
+        for j in range(num_middle_points):
+            if top[j].size==0:
+                if mask[i,center_top[j][1]]==255:
+                    top[j] = np.asarray([center_top[j][1],i])
+                    top[j] = top[j].reshape(1,2)
+
+    for i in range(mask.shape[0]-1,center[0],-1):
+        for j in range(num_middle_points):
+            if bottom[j].size==0:
+                if mask[i,center_bottom[j][1]]==255:
+                    bottom[j] = np.asarray([center_bottom[j][1],i])
+                    bottom[j] = bottom[j].reshape(1,2)
+
+    for i in range(mask.shape[1]-1,center[1],-1):
+        for j in range(num_middle_points):
+            if right[j].size==0:
+                if mask[center_right[j][0],i]==255:
+                    right[j] = np.asarray([i,center_right[j][0]])
+                    right[j] = right[j].reshape(1,2)
+
+    for i in range(0,center[1]):
+        for j in range(num_middle_points):
+            if left[j].size==0:
+                if mask[center_left[j][0],i]==255:
+                    left[j] = np.asarray([i,center_left[j][0]])
+                    left[j] = left[j].reshape(1,2)
+
+    return np.asarray(top+bottom+left+right)
+
+def DP_algorithmv1(contours):
+    biggest = np.array([])
+    max_area = 0
+    step = 0.001
+    count = 0
+    # while biggest.size==0:
+    while True:
+        for i in contours:
+            # print(i.shape)
+            area = cv2.contourArea(i)
+            # print(area,cv2.arcLength(i, True))
+            if area > cv2.arcLength(i, True)*10:
+                peri = cv2.arcLength(i, True)
+                approx = cv2.approxPolyDP(i, (0.01+step*count) * peri, True)
+                if area > max_area and len(approx) == 4:
+                    max_area = area
+                    biggest_contours = i
+                    biggest = approx
+                    break
+        if abs(max_area - cv2.contourArea(biggest))/max_area > 0.3:
+            biggest = np.array([])
+        count += 1
+        if count > 200:
+            break
+    temp = biggest[0]
+    return biggest,max_area, biggest_contours
+
+def DP_algorithm(contours):
+    biggest = np.array([])
+    max_area = 0
+    step = 0.001
+    count = 0
+
+    ### largest contours
+    for i in contours:
+        area = cv2.contourArea(i)
+        if area > max_area:
+            max_area = area
+            biggest_contours = i
+    peri = cv2.arcLength(biggest_contours, True)
+
+    ### find four corners
+    while True:
+        approx = cv2.approxPolyDP(biggest_contours, (0.01+step*count) * peri, True)
+        if len(approx) == 4:
+            biggest = approx
+            break
+        # if abs(max_area - cv2.contourArea(biggest))/max_area > 0.2:
+        # if abs(max_area - cv2.contourArea(biggest))/max_area > 0.4:
+        #     biggest = np.array([])
+        count += 1
+        if count > 200:
+            break
+    return biggest,max_area, biggest_contours
+
+def drawRectangle(img,biggest,color,thickness):
+    cv2.line(img, (biggest[0][0][0], biggest[0][0][1]), (biggest[1][0][0], biggest[1][0][1]), color, thickness)
+    cv2.line(img, (biggest[0][0][0], biggest[0][0][1]), (biggest[2][0][0], biggest[2][0][1]), color, thickness)
+    cv2.line(img, (biggest[3][0][0], biggest[3][0][1]), (biggest[2][0][0], biggest[2][0][1]), color, thickness)
+    cv2.line(img, (biggest[3][0][0], biggest[3][0][1]), (biggest[1][0][0], biggest[1][0][1]), color, thickness)
+    return img
+
+def minAreaRect(contours,img):
+    # biggest = np.array([])
+    max_area = 0
+    for i in contours:
+        area = cv2.contourArea(i)
+        if area > max_area:
+            peri = cv2.arcLength(i, True)
+            rect = cv2.minAreaRect(i)
+            points = cv2.boxPoints(rect)
+            max_area = area
+    return points
+
+def cropRectangle(img,biggest):
+    # print(biggest)
+    w = np.abs(biggest[0][0][0] - biggest[1][0][0])
+    h = np.abs(biggest[0][0][1] - biggest[2][0][1])
+    new_img = np.zeros((w,h,img.shape[-1]),dtype=np.uint8)
+    new_img = img[biggest[0][0][1]:biggest[0][0][1]+h,biggest[0][0][0]:biggest[0][0][0]+w]
+    return new_img
+
+def cvimg2torch(img,min=0,max=1):
+    '''
+    input:
+        im -> ndarray uint8 HxWxC
+    return
+        tensor -> torch.tensor BxCxHxW
+    '''
+    if len(img.shape)==2:
+        img = np.expand_dims(img,axis=-1)
+    img = img.astype(float) / 255.0
+    img = img.transpose(2, 0, 1) # HWC -> CHW
+    img = np.expand_dims(img, 0)
+    img = torch.from_numpy(img).float()
+    return img
+
+def torch2cvimg(tensor,min=0,max=1):
+    '''
+    input:
+        tensor -> torch.tensor BxCxHxW C can be 1,3
+    return
+        im -> ndarray uint8 HxWxC
+    '''
+    im_list = []
+    for i in range(tensor.shape[0]):
+        im = tensor.detach().cpu().data.numpy()[i]
+        im = im.transpose(1,2,0)
+        im = np.clip(im,min,max)
+        im = ((im-min)/(max-min)*255).astype(np.uint8)
+        im_list.append(im)
+    return im_list
+
+
+
+class TPSGridGen(nn.Module):
+    def __init__(self, target_height, target_width, target_control_points):
+        '''
+        target_control_points -> torch.tensor num_pointx2 -1~1
+        source_control_points -> torch.tensor batch_size x num_point x 2 -1~1
+        return:
+            grid -> batch_size x hw x 2 -1~1
+        '''
+        super(TPSGridGen, self).__init__()
+        assert target_control_points.ndimension() == 2
+        assert target_control_points.size(1) == 2
+        N = target_control_points.size(0)
+        self.num_points = N
+        target_control_points = target_control_points.float()
+
+        # create padded kernel matrix
+        forward_kernel = torch.zeros(N + 3, N + 3)
+        target_control_partial_repr = self.compute_partial_repr(target_control_points, target_control_points)
+        forward_kernel[:N, :N].copy_(target_control_partial_repr)
+        forward_kernel[:N, -3].fill_(1)
+        forward_kernel[-3, :N].fill_(1)
+        forward_kernel[:N, -2:].copy_(target_control_points)
+        forward_kernel[-2:, :N].copy_(target_control_points.transpose(0, 1))
+        # compute inverse matrix
+        inverse_kernel = torch.inverse(forward_kernel)
+
+        # create target coordinate matrix
+        HW = target_height * target_width
+        target_coordinate = list(itertools.product(range(target_height), range(target_width)))
+        target_coordinate = torch.Tensor(target_coordinate) # HW x 2
+        Y, X = target_coordinate.split(1, dim = 1)
+        Y = Y * 2 / (target_height - 1) - 1
+        X = X * 2 / (target_width - 1) - 1
+        target_coordinate = torch.cat([X, Y], dim = 1) # convert from (y, x) to (x, y)
+        target_coordinate_partial_repr = self.compute_partial_repr(target_coordinate.to(target_control_points.device), target_control_points)
+        target_coordinate_repr = torch.cat([
+            target_coordinate_partial_repr, torch.ones(HW, 1), target_coordinate
+        ], dim = 1)
+
+        # register precomputed matrices
+        self.register_buffer('inverse_kernel', inverse_kernel)
+        self.register_buffer('padding_matrix', torch.zeros(3, 2))
+        self.register_buffer('target_coordinate_repr', target_coordinate_repr)
+
+    def forward(self, source_control_points):
+        assert source_control_points.ndimension() == 3
+        assert source_control_points.size(1) == self.num_points
+        assert source_control_points.size(2) == 2
+        batch_size = source_control_points.size(0)
+
+        Y = torch.cat([source_control_points, Variable(self.padding_matrix.expand(batch_size, 3, 2))], 1)
+        mapping_matrix = torch.matmul(Variable(self.inverse_kernel), Y)
+        source_coordinate = torch.matmul(Variable(self.target_coordinate_repr), mapping_matrix)
+        return source_coordinate
+
+    # phi(x1, x2) = r^2 * log(r), where r = ||x1 - x2||_2
+    def compute_partial_repr(self, input_points, control_points):
+        N = input_points.size(0)
+        M = control_points.size(0)
+        pairwise_diff = input_points.view(N, 1, 2) - control_points.view(1, M, 2)
+        # original implementation, very slow
+        # pairwise_dist = torch.sum(pairwise_diff ** 2, dim = 2) # square of distance
+        pairwise_diff_square = pairwise_diff * pairwise_diff
+        pairwise_dist = pairwise_diff_square[:, :, 0] + pairwise_diff_square[:, :, 1]
+        repr_matrix = 0.5 * pairwise_dist * torch.log(pairwise_dist)
+        # fix numerical error for 0 * log(0), substitute all nan with 0
+        mask = repr_matrix != repr_matrix
+        repr_matrix.masked_fill_(mask, 0)
+        return repr_matrix
+
+
+
+
+
+### decide whether further processing is needed
+# point_area = cv2.contourArea(np.concatenate((biggest_angle[0].reshape(1,1,2),middle[0:3],biggest_angle[1].reshape(1,1,2),middle[9:12],biggest_angle[3].reshape(1,1,2),middle[3:6][::-1],biggest_angle[2].reshape(1,1,2),middle[6:9][::-1]),axis=0))
+#### minimum-area bounding rectangle
+# rect = cv2.minAreaRect(contour) # returns (center (x, y), (width, height), rotation angle) of the minimum-area rectangle
+# box = cv2.boxPoints(rect) # cv2.boxPoints(rect) for OpenCV 3.x: get the 4 corner points of the minimum-area rectangle
+# box = np.int0(box)
+# box = box.reshape((4,1,2))
+# minrect_area = cv2.contourArea(box)
+# print(abs(minrect_area-point_area)/point_area)
+#### four-corner IOU
+# biggest_box = np.concatenate((biggest_angle[0,:,:].reshape(1,1,2),biggest_angle[2,:,:].reshape(1,1,2),biggest_angle[3,:,:].reshape(1,1,2),biggest_angle[1,:,:].reshape(1,1,2)),axis=0)
+# biggest_mask = np.zeros_like(mask)
+# # corner_area = cv2.contourArea(biggest_box)
+# cv2.drawContours(biggest_mask,[biggest_box], -1, color=255, thickness=-1)
+
+# smooth = 1e-5
+# biggest_mask_ = biggest_mask > 50
+# mask_ = mask > 50
+# intersection = (biggest_mask_ & mask_).sum()
+# union = (biggest_mask_ | mask_).sum()
+# iou = (intersection + smooth) / (union + smooth)
+# if iou > 0.975:
+#     skip = True
+# else:
+#     skip = False
+# print(iou)
+# cv2.imshow('mask',cv2.resize(mask,(512,512)))
+# cv2.imshow('biggest_mask',cv2.resize(biggest_mask,(512,512)))
+# cv2.waitKey(0)
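One step in compute_partial_repr is easy to misread: pairwise_dist holds the squared distance r^2, so 0.5 * r^2 * log(r^2) equals r^2 * log(r), the standard thin-plate-spline radial basis. A tiny self-contained check with an arbitrary value:

```python
import math

r2 = 2.0                                    # squared distance r^2
computed = 0.5 * r2 * math.log(r2)          # what compute_partial_repr evaluates
tps_kernel = r2 * math.log(math.sqrt(r2))   # r^2 * log(r)
assert abs(computed - tps_kernel) < 1e-12
```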