xfmr-zem 0.2.4__py3-none-any.whl → 0.2.6__py3-none-any.whl
This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as published.
- xfmr_zem/cli.py +32 -3
- xfmr_zem/client.py +59 -8
- xfmr_zem/server.py +21 -4
- xfmr_zem/servers/data_juicer/server.py +1 -1
- xfmr_zem/servers/instruction_gen/server.py +1 -1
- xfmr_zem/servers/io/server.py +1 -1
- xfmr_zem/servers/llm/parameters.yml +10 -0
- xfmr_zem/servers/nemo_curator/server.py +1 -1
- xfmr_zem/servers/ocr/deepdoc_vietocr/__init__.py +90 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/implementations.py +1286 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/layout_recognizer.py +562 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/ocr.py +512 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/onnx/.gitattributes +35 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/onnx/README.md +5 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/onnx/ocr.res +6623 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/operators.py +725 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/phases.py +191 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/pipeline.py +561 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/postprocess.py +370 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/recognizer.py +436 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/table_structure_recognizer.py +569 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/utils/__init__.py +81 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/utils/file_utils.py +246 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/__init__.py +0 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/config/base.yml +58 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/config/vgg-seq2seq.yml +38 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/model/__init__.py +0 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/model/backbone/cnn.py +25 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/model/backbone/vgg.py +51 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/model/seqmodel/seq2seq.py +175 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/model/transformerocr.py +29 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/model/vocab.py +36 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/tool/config.py +37 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/tool/translate.py +111 -0
- xfmr_zem/servers/ocr/engines.py +242 -0
- xfmr_zem/servers/ocr/install_models.py +63 -0
- xfmr_zem/servers/ocr/parameters.yml +4 -0
- xfmr_zem/servers/ocr/server.py +102 -0
- xfmr_zem/servers/profiler/parameters.yml +4 -0
- xfmr_zem/servers/sinks/parameters.yml +6 -0
- xfmr_zem/servers/unstructured/parameters.yml +6 -0
- xfmr_zem/servers/unstructured/server.py +62 -0
- xfmr_zem/zenml_wrapper.py +20 -7
- {xfmr_zem-0.2.4.dist-info → xfmr_zem-0.2.6.dist-info}/METADATA +20 -1
- xfmr_zem-0.2.6.dist-info/RECORD +58 -0
- xfmr_zem-0.2.4.dist-info/RECORD +0 -23
- /xfmr_zem/servers/data_juicer/{parameter.yaml → parameters.yml} +0 -0
- /xfmr_zem/servers/instruction_gen/{parameter.yaml → parameters.yml} +0 -0
- /xfmr_zem/servers/io/{parameter.yaml → parameters.yml} +0 -0
- /xfmr_zem/servers/nemo_curator/{parameter.yaml → parameters.yml} +0 -0
- {xfmr_zem-0.2.4.dist-info → xfmr_zem-0.2.6.dist-info}/WHEEL +0 -0
- {xfmr_zem-0.2.4.dist-info → xfmr_zem-0.2.6.dist-info}/entry_points.txt +0 -0
- {xfmr_zem-0.2.4.dist-info → xfmr_zem-0.2.6.dist-info}/licenses/LICENSE +0 -0
xfmr_zem/servers/ocr/deepdoc_vietocr/pipeline.py
@@ -0,0 +1,561 @@
+"""
+Document Processing Pipeline
+
+The Pipeline class orchestrates all of the phases needed to process a document.
+It allows easy experimentation by swapping phase implementations.
+
+Usage Example:
+    from deepdoc_vietocr.pipeline import DocumentPipeline
+    from deepdoc_vietocr.implementations import (
+        DocLayoutYOLOAnalyzer,
+        PaddleOCRTextDetector,
+        VietOCRRecognizer,
+        VietnameseTextPostProcessor,
+        SmartMarkdownReconstruction
+    )
+
+    # Create pipeline with custom phases
+    pipeline = DocumentPipeline(
+        layout_analyzer=DocLayoutYOLOAnalyzer(),
+        text_detector=PaddleOCRTextDetector(),
+        text_recognizer=VietOCRRecognizer(),
+        post_processor=VietnameseTextPostProcessor(),
+        reconstructor=SmartMarkdownReconstruction()
+    )
+
+    # Process image
+    from PIL import Image
+    img = Image.open("document.jpg")
+    markdown = pipeline.process(img)
+"""
+
+import os
+import time
+import logging
+from typing import Optional, List, Tuple, Any
+from concurrent.futures import ThreadPoolExecutor, as_completed
+
+import numpy as np
+from PIL import Image, ImageDraw
+
+from .phases import (
+    LayoutAnalysisPhase,
+    TextDetectionPhase,
+    TextRecognitionPhase,
+    PostProcessingPhase,
+    DocumentReconstructionPhase,
+    NoOpPostProcessing,
+    SimpleMarkdownReconstruction,
+)
+from . import LayoutRecognizer, TableStructureRecognizer
+
+
+class DocumentPipeline:
+    """
+    Main pipeline orchestrator for document processing.
+
+    Coordinates the phases that turn a document image into markdown.
+    Each phase can be swapped independently to try other implementations.
+    """
+
+    def __init__(
+        self,
+        layout_analyzer: LayoutAnalysisPhase,
+        text_detector: TextDetectionPhase,
+        text_recognizer: TextRecognitionPhase,
+        post_processor: Optional[PostProcessingPhase] = None,
+        reconstructor: Optional[DocumentReconstructionPhase] = None,
+        threshold: float = 0.5,
+        max_workers: int = 4,
+    ):
+        """
+        Initialize the pipeline with phase implementations.
+
+        Args:
+            layout_analyzer: Layout analysis phase
+            text_detector: Text box detection phase
+            text_recognizer: Text recognition phase
+            post_processor: Post-processing phase (optional, default: NoOp)
+            reconstructor: Document reconstruction phase (optional, default: Simple)
+            threshold: Detection threshold
+            max_workers: Number of parallel workers for table processing
+        """
+        self.layout_analyzer = layout_analyzer
+        self.text_detector = text_detector
+        self.text_recognizer = text_recognizer
+        self.post_processor = post_processor or NoOpPostProcessing()
+        self.reconstructor = reconstructor or SimpleMarkdownReconstruction()
+
+        self.threshold = threshold
+        self.max_workers = max_workers
+
+        # Initialize table recognizer (specialized component)
+        self.table_recognizer = TableStructureRecognizer()
+
+        logging.info("✓ DocumentPipeline initialized with custom phases")
+
+    def process(
+        self,
+        image: Image.Image,
+        img_name: str = "page",
+        figure_save_dir: str = "figures",
+        image_link_prefix: str = "figures",
+        debug_path: Optional[str] = None,
+    ) -> str:
+        """
+        Process a document image into markdown.
+
+        Args:
+            image: PIL Image to process
+            img_name: Base name for figures
+            figure_save_dir: Directory to save figures in
+            image_link_prefix: Prefix for image links in the markdown
+            debug_path: Path to save the debug visualization (optional)
+
+        Returns:
+            Markdown string
+        """
+        start_time = time.time()
+
+        # Ensure figure directory exists
+        os.makedirs(figure_save_dir, exist_ok=True)
+
+        # Prepare visualization if debug requested
+        vis_draw = None
+        vis_img = None
+        if debug_path:
+            vis_img = image.copy()
+            vis_draw = ImageDraw.Draw(vis_img)
+
+        # ====================================================================
+        # PHASE 1: Layout Analysis
+        # ====================================================================
+        logging.info("Phase 1: Running Layout Analysis...")
+        layouts = self.layout_analyzer.analyze(image, threshold=self.threshold)
+        logging.info(f"✓ Detected {len(layouts)} layout regions")
+
+        # Collection for final output
+        region_and_pos = []
+
+        # Create mask for leftover detection
+        mask = Image.new("1", image.size, 0)
+        draw = ImageDraw.Draw(mask)
+
+        # Collect regions by type for batch processing
+        text_regions_batch = []  # [(bbox, label, y_pos, region_index), ...]
+        table_regions_batch = []  # [(region, y_pos, region_index), ...]
+
+        # ====================================================================
+        # Process Each Layout Region
+        # ====================================================================
+        for i, region in enumerate(layouts):
+            bbox = region["bbox"]
+            label = region["type"]
+            score = region["score"]
+            y_pos = bbox[1]
+
+            # Draw debug visualization
+            if vis_draw:
+                color = "red" if label in ["table", "figure", "equation"] else "blue"
+                vis_draw.rectangle(bbox, outline=color, width=3)
+                vis_draw.text((bbox[0], bbox[1]), f"{label} ({score:.2f})", fill=color)
+
+            # Mark region as processed
+            draw.rectangle(bbox, fill=1)
+
+            # Route to appropriate handler
+            if label == "table":
+                table_regions_batch.append((region, y_pos, i))
+
+            elif label == "figure":
+                # Save figure image
+                fig_filename = f"{img_name}_fig_{i}.jpg"
+                fig_path = os.path.join(figure_save_dir, fig_filename)
+                md_link_path = f"{image_link_prefix}/{fig_filename}" if image_link_prefix else fig_filename
+
+                try:
+                    crop_img = image.crop(bbox)
+                    crop_img.save(fig_path)
+                    region_and_pos.append((y_pos, f"![]({md_link_path})", bbox))
+                except Exception as e:
+                    logging.error(f"Failed to save figure {fig_path}: {e}")
+
+                # Also run OCR on figure region (might contain misclassified text)
+                text_regions_batch.append((bbox, "figure_text", y_pos + 1, i))
+
+            else:
+                # Text-based regions: title, text, header, footer, captions, equation
+                text_regions_batch.append((bbox, label, y_pos, i))
+
+        # ====================================================================
+        # PHASE 2+3: Batch Text Processing (Detection + Recognition)
+        # ====================================================================
+        if text_regions_batch:
+            logging.info(f"Phase 2+3: Processing {len(text_regions_batch)} text regions...")
+            try:
+                text_results = self._process_text_regions_batch(
+                    image, text_regions_batch
+                )
+                region_and_pos.extend(text_results)
+                logging.info(f"✓ Processed {len(text_results)} text regions")
+            except Exception as e:
+                logging.error(f"Error in batch text processing: {e}", exc_info=True)
+
+        # ====================================================================
+        # Table Processing (Parallel)
+        # ====================================================================
+        if table_regions_batch:
+            logging.info(f"Processing {len(table_regions_batch)} table regions...")
+            try:
+                table_results = self._process_tables_parallel(image, table_regions_batch)
+                for y_pos, markdown in table_results:
+                    if markdown and markdown.strip():
+                        region_and_pos.append((y_pos, markdown, None))
+                logging.info(f"✓ Processed {len(table_results)} tables")
+            except Exception as e:
+                logging.error(f"Error in table processing: {e}", exc_info=True)
+
+        # ====================================================================
+        # Leftover OCR (Undetected Areas)
+        # ====================================================================
+        inv_mask = mask.point(lambda p: 1 - p)
+        if inv_mask.getbbox():
+            leftover_text = self._process_leftover_regions(
+                image, inv_mask, layouts, vis_draw
+            )
+            if leftover_text:
+                lx0, ly0, lx1, ly1 = inv_mask.getbbox()
+                region_and_pos.append((ly0, leftover_text, None))
+
+        # ====================================================================
+        # Save Debug Visualization
+        # ====================================================================
+        if debug_path and vis_img:
+            try:
+                vis_img.save(debug_path)
+                logging.info(f"✓ Saved debug visualization to: {debug_path}")
+            except Exception as e:
+                logging.error(f"Failed to save debug image: {e}")
+
+        # ====================================================================
+        # PHASE 5: Document Reconstruction
+        # ====================================================================
+        logging.info("Phase 5: Reconstructing document...")
+        markdown = self.reconstructor.reconstruct(region_and_pos, output_format="markdown")
+
+        elapsed = time.time() - start_time
+        logging.info(f"✓ Pipeline completed in {elapsed:.2f} seconds")
+
+        return markdown
+
+    def _process_text_regions_batch(
+        self,
+        image: Image.Image,
+        text_regions_batch: List[Tuple[List[int], str, int, int]]
+    ) -> List[Tuple[int, str, List[int]]]:
+        """
+        Process a batch of text regions with text detection + recognition.
+
+        Args:
+            image: PIL Image
+            text_regions_batch: List of (bbox, label, y_pos, index) tuples
+
+        Returns:
+            List of (y_pos, formatted_text, bbox) tuples
+        """
+        results = []
+
+        # Crop all regions
+        cropped_images = [np.array(image.crop(bbox)) for bbox, _, _, _ in text_regions_batch]
+
+        # Batch OCR: Detect + Recognize
+        try:
+            batch_ocr_results = self._batch_ocr(cropped_images)
+
+            # PHASE 4: Post-process each result
+            for (bbox, label, y_pos, idx), ocr_results in zip(text_regions_batch, batch_ocr_results):
+                # Extract text from OCR results
+                texts = []
+                for _, (text, confidence) in ocr_results:
+                    if text:
+                        # Apply post-processing
+                        processed_text = self.post_processor.process(
+                            text, confidence, metadata={"label": label, "bbox": bbox}
+                        )
+                        texts.append(processed_text)
+
+                text_content = "\n".join(texts)
+                if not text_content.strip():
+                    continue
+
+                # Apply formatting based on label
+                formatted_text = self._apply_formatting(text_content, label)
+                results.append((y_pos, formatted_text, bbox))
+
+        except Exception as e:
+            logging.error(f"Batch OCR processing failed: {e}", exc_info=True)
+
+        return results
+
+    def _batch_ocr(self, image_list: List[np.ndarray]) -> List[List[Tuple[Any, Tuple[str, float]]]]:
+        """
+        Batch OCR processing: detect + recognize all images.
+
+        Strategy:
+        1. Detect text boxes in all images
+        2. Collect ALL text boxes from all images
+        3. Batch recognize ALL boxes at once
+        4. Map results back to images
+
+        Returns:
+            List of OCR results, one per image
+        """
+        if not image_list:
+            return []
+
+        all_boxes_info = []  # [(image_idx, box, img_crop), ...]
+        results = [[] for _ in range(len(image_list))]
+
+        # Phase 2: Detect boxes in all images
+        for img_idx, img_array in enumerate(image_list):
+            dt_boxes, _ = self.text_detector.detect(img_array)
+
+            if dt_boxes is None:
+                continue
+
+            # Sort boxes for reading order
+            dt_boxes = self._sorted_boxes(dt_boxes)
+
+            # Crop all detected boxes
+            for box in dt_boxes:
+                img_crop = self._get_rotate_crop_image(img_array, box)
+                all_boxes_info.append((img_idx, box, img_crop))
+
+        # Phase 3: Batch recognize ALL boxes at once
+        if all_boxes_info:
+            all_crops = [crop for _, _, crop in all_boxes_info]
+            rec_results, _ = self.text_recognizer.recognize(all_crops)
+
+            # Map results back to original images
+            for (img_idx, box, _), (text, score) in zip(all_boxes_info, rec_results):
+                # Filter low confidence results
+                if score >= 0.5:  # TODO: Make this configurable
+                    results[img_idx].append((box.tolist() if hasattr(box, 'tolist') else box, (text, score)))
+
+        return results
+
+    def _process_tables_parallel(
+        self,
+        image: Image.Image,
+        table_regions_batch: List[Tuple[dict, int, int]]
+    ) -> List[Tuple[int, str]]:
+        """Process multiple tables in parallel."""
+        def process_single_table(args):
+            region, y_pos, idx = args
+            try:
+                markdown = self._extract_table_markdown(image, region)
+                return (y_pos, markdown)
+            except Exception as e:
+                logging.error(f"Error processing table {idx}: {e}")
+                return (y_pos, "")
+
+        # Single table: no parallelization needed
+        if len(table_regions_batch) == 1:
+            return [process_single_table(table_regions_batch[0])]
+
+        # Parallel processing
+        results = []
+        with ThreadPoolExecutor(max_workers=min(self.max_workers, len(table_regions_batch))) as executor:
+            future_to_table = {
+                executor.submit(process_single_table, table_data): table_data
+                for table_data in table_regions_batch
+            }
+
+            for future in as_completed(future_to_table):
+                try:
+                    result = future.result()
+                    results.append(result)
+                except Exception as e:
+                    table_data = future_to_table[future]
+                    logging.error(f"Table processing failed for region {table_data[2]}: {e}")
+
+        return results
+
+    def _extract_table_markdown(self, image: Image.Image, table_region: dict) -> str:
+        """Extract table as markdown (uses specialized table recognizer)."""
+        import re
+
+        bbox = table_region["bbox"]
+        table_img = image.crop(bbox)
+
+        tb_cpns = self.table_recognizer([table_img])[0]
+
+        # Run OCR on table image
+        table_ocr_results = self._batch_ocr([np.array(table_img)])
+        boxes = table_ocr_results[0] if table_ocr_results else []
+
+        # Sort and clean up boxes
+        boxes = LayoutRecognizer.sort_Y_firstly(
+            [{
+                "x0": b[0][0], "x1": b[1][0],
+                "top": b[0][1], "text": t[0],
+                "bottom": b[-1][1],
+                "layout_type": "table",
+                "page_number": 0
+            } for b, t in boxes if b[0][0] <= b[1][0] and b[0][1] <= b[-1][1]],
+            np.mean([b[-1][1] - b[0][1] for b, _ in boxes]) / 3 if boxes else 10
+        )
+
+        if not boxes:
+            return ""
+
+        def gather(kwd, fzy=10, ption=0.6):
+            eles = LayoutRecognizer.sort_Y_firstly(
+                [r for r in tb_cpns if re.match(kwd, r["label"])], fzy)
+            eles = LayoutRecognizer.layouts_cleanup(boxes, eles, 5, ption)
+            return LayoutRecognizer.sort_Y_firstly(eles, 0)
+
+        headers = gather(r".*header$")
+        rows = gather(r".* (row|header)")
+        spans = gather(r".*spanning")
+        clmns = sorted([r for r in tb_cpns if re.match(r"table column$", r["label"])], key=lambda x: x["x0"])
+        clmns = LayoutRecognizer.layouts_cleanup(boxes, clmns, 5, 0.5)
+
+        for b in boxes:
+            self._map_cell_to_structure(b, rows, headers, clmns, spans)
+
+        return TableStructureRecognizer.construct_table(boxes, markdown=True)
+
+    def _process_leftover_regions(
+        self,
+        image: Image.Image,
+        inv_mask: Image.Image,
+        layouts: List[dict],
+        vis_draw: Optional[ImageDraw.ImageDraw]
+    ) -> Optional[str]:
+        """Process leftover areas not detected by layout analysis."""
+        lx0, ly0, lx1, ly1 = inv_mask.getbbox()
+        leftover_area = (lx1 - lx0) * (ly1 - ly0)
+        total_area = image.size[0] * image.size[1]
+        leftover_ratio = leftover_area / total_area if total_area > 0 else 0
+
+        # Skip if too small
+        if leftover_ratio < 0.03:
+            logging.info(f"Skipping leftover OCR: only {leftover_ratio:.1%} of image area")
+            return None
+
+        # Create leftover image
+        white_bg = Image.new("RGB", image.size, (255, 255, 255))
+        leftover_img = Image.composite(image, white_bg, inv_mask)
+        leftover_crop = leftover_img.crop((lx0, ly0, lx1, ly1))
+
+        # Check if mostly white
+        leftover_array = np.array(leftover_crop)
+        white_pixels = np.sum(np.all(leftover_array > 240, axis=-1))
+        total_pixels = leftover_array.shape[0] * leftover_array.shape[1]
+        white_ratio = white_pixels / total_pixels if total_pixels > 0 else 1.0
+
+        force_ocr = len(layouts) == 0
+        if white_ratio > 0.99 and not force_ocr:
+            logging.info(f"Skipping leftover OCR: {white_ratio:.1%} white pixels")
+            return None
+
+        # Draw debug
+        if vis_draw:
+            vis_draw.rectangle((lx0, ly0, lx1, ly1), outline="green", width=2)
+            vis_draw.text((lx0, ly0), "leftover", fill="green")
+
+        # Run OCR on leftover
+        try:
+            ocr_results = self._batch_ocr([np.array(leftover_crop)])[0]
+            texts = [t[0] for _, t in ocr_results if t and t[0]]
+            leftover_text = "\n".join(texts)
+            return leftover_text if leftover_text.strip() else None
+        except Exception as e:
+            logging.error(f"Error processing leftovers: {e}")
+            return None
+
+    # Helper methods
+    def _apply_formatting(self, text: str, label: str) -> str:
+        """Apply markdown formatting based on region label."""
+        if label == "title":
+            return f"# {text}"
+        elif label in ["header", "footer"]:
+            return f"_{text}_"
+        elif label in ["figure caption", "table caption"]:
+            return f"*{text}*"
+        elif label == "equation":
+            return f"$$ {text} $$"
+        return text
+
+    def _map_cell_to_structure(self, b, rows, headers, clmns, spans):
+        """Map cell to table structure (helper for table extraction)."""
+        ii = LayoutRecognizer.find_overlapped_with_threashold(b, rows, thr=0.3)
+        if ii is not None:
+            b["R"] = ii
+            b["R_top"] = rows[ii]["top"]
+            b["R_bott"] = rows[ii]["bottom"]
+
+        ii = LayoutRecognizer.find_overlapped_with_threashold(b, headers, thr=0.3)
+        if ii is not None:
+            b["H_top"] = headers[ii]["top"]
+            b["H_bott"] = headers[ii]["bottom"]
+            b["H_left"] = headers[ii]["x0"]
+            b["H_right"] = headers[ii]["x1"]
+            b["H"] = ii
+
+        ii = LayoutRecognizer.find_horizontally_tightest_fit(b, clmns)
+        if ii is not None:
+            b["C"] = ii
+            b["C_left"] = clmns[ii]["x0"]
+            b["C_right"] = clmns[ii]["x1"]
+
+        ii = LayoutRecognizer.find_overlapped_with_threashold(b, spans, thr=0.3)
+        if ii is not None:
+            b["H_top"] = spans[ii]["top"]
+            b["H_bott"] = spans[ii]["bottom"]
+            b["H_left"] = spans[ii]["x0"]
+            b["H_right"] = spans[ii]["x1"]
+            b["SP"] = ii
+
+    def _sorted_boxes(self, dt_boxes):
+        """Sort text boxes by position (from the OCR class)."""
+        num_boxes = len(dt_boxes)
+        sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
+        _boxes = list(sorted_boxes)
+
+        for i in range(num_boxes - 1):
+            for j in range(i, -1, -1):
+                if abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10 and \
+                        (_boxes[j + 1][0][0] < _boxes[j][0][0]):
+                    tmp = _boxes[j]
+                    _boxes[j] = _boxes[j + 1]
+                    _boxes[j + 1] = tmp
+                else:
+                    break
+        return _boxes
+
+    def _get_rotate_crop_image(self, img, points):
+        """Rotate and crop image based on points (from the OCR class)."""
+        import cv2
+        img_crop_width = int(
+            max(
+                np.linalg.norm(points[0] - points[1]),
+                np.linalg.norm(points[2] - points[3])))
+        img_crop_height = int(
+            max(
+                np.linalg.norm(points[0] - points[3]),
+                np.linalg.norm(points[1] - points[2])))
+        pts_std = np.float32([[0, 0], [img_crop_width, 0],
+                              [img_crop_width, img_crop_height],
+                              [0, img_crop_height]])
+        M = cv2.getPerspectiveTransform(points.astype(np.float32), pts_std)
+        dst_img = cv2.warpPerspective(
+            img,
+            M, (img_crop_width, img_crop_height),
+            borderMode=cv2.BORDER_REPLICATE,
+            flags=cv2.INTER_CUBIC)
+        dst_img_height, dst_img_width = dst_img.shape[0:2]
+        if dst_img_height * 1.0 / dst_img_width >= 1.5:
+            dst_img = np.rot90(dst_img)
+        return dst_img
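The new OCR pipeline above is organized around swappable phase interfaces (`LayoutAnalysisPhase`, `TextDetectionPhase`, `TextRecognitionPhase`, `PostProcessingPhase`, `DocumentReconstructionPhase` from `phases.py`). As a rough illustration of that design, the sketch below plugs a custom post-processor into `DocumentPipeline`. It assumes `PostProcessingPhase` can be subclassed and that its `process(text, confidence, metadata)` signature matches the call site in `_process_text_regions_batch`; check `xfmr_zem/servers/ocr/deepdoc_vietocr/phases.py` for the actual interface. `StripSoftHyphens` is a hypothetical name, not part of the package.

```python
from deepdoc_vietocr.phases import PostProcessingPhase
from deepdoc_vietocr.pipeline import DocumentPipeline
from deepdoc_vietocr.implementations import (
    DocLayoutYOLOAnalyzer,
    PaddleOCRTextDetector,
    VietOCRRecognizer,
)


class StripSoftHyphens(PostProcessingPhase):
    """Hypothetical post-processor: drops soft hyphens that OCR sometimes emits."""

    def process(self, text, confidence, metadata=None):
        # metadata carries {"label": ..., "bbox": ...} from the pipeline
        return text.replace("\u00ad", "")


# post_processor/reconstructor fall back to NoOpPostProcessing /
# SimpleMarkdownReconstruction when omitted (see __init__ above).
pipeline = DocumentPipeline(
    layout_analyzer=DocLayoutYOLOAnalyzer(),
    text_detector=PaddleOCRTextDetector(),
    text_recognizer=VietOCRRecognizer(),
    post_processor=StripSoftHyphens(),
)
```

Since `process()` only ever calls `self.post_processor.process(...)`, any object exposing that method should work, and the same swap applies to the detector, recognizer, and reconstructor phases.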