sparrow-parse 0.3.11__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sparrow_parse/__init__.py +1 -1
- sparrow_parse/extractors/vllm_extractor.py +168 -48
- sparrow_parse/processors/table_structure_processor.py +68 -411
- {sparrow_parse-0.3.11.dist-info → sparrow_parse-0.4.0.dist-info}/METADATA +6 -4
- {sparrow_parse-0.3.11.dist-info → sparrow_parse-0.4.0.dist-info}/RECORD +8 -8
- {sparrow_parse-0.3.11.dist-info → sparrow_parse-0.4.0.dist-info}/WHEEL +0 -0
- {sparrow_parse-0.3.11.dist-info → sparrow_parse-0.4.0.dist-info}/entry_points.txt +0 -0
- {sparrow_parse-0.3.11.dist-info → sparrow_parse-0.4.0.dist-info}/top_level.txt +0 -0
sparrow_parse/__init__.py
CHANGED
@@ -1 +1 @@
-__version__ = '0.3.11'
+__version__ = '0.4.0'
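As a quick sanity check after upgrading, the bumped version string can be read back at runtime; a minimal sketch:

```python
# Minimal sketch: confirm the installed sparrow-parse version after the upgrade.
import sparrow_parse

print(sparrow_parse.__version__)  # expected: 0.4.0
```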
sparrow_parse/extractors/vllm_extractor.py
CHANGED
@@ -1,87 +1,207 @@
+import json
+
 from sparrow_parse.vllm.inference_factory import InferenceFactory
 from sparrow_parse.helpers.pdf_optimizer import PDFOptimizer
+from sparrow_parse.processors.table_structure_processor import TableDetector
 from rich import print
 import os
+import tempfile
 import shutil
+from typing import Any, Dict, List, Union


 class VLLMExtractor(object):
     def __init__(self):
         pass

-    def run_inference(self, model_inference_instance, input_data,
+    def run_inference(self, model_inference_instance, input_data, tables_only=False,
                       generic_query=False, debug_dir=None, debug=False, mode=None):
-
+        """
+        Main entry point for processing input data using a model inference instance.
+        Handles generic queries, PDFs, and table extraction.
+        """
         if generic_query:
             input_data[0]["text_input"] = "retrieve document data. return response in JSON format"

         if debug:
-            print("Input …
+            print("Input data:", input_data)

-        # Check if the input file is a PDF
         file_path = input_data[0]["file_path"]
         if self.is_pdf(file_path):
-            return self._process_pdf(model_inference_instance, input_data, debug_dir, mode)
+            return self._process_pdf(model_inference_instance, input_data, tables_only, debug, debug_dir, mode)

-
-        input_data[0]["file_path"] = [file_path]
-        results_array = model_inference_instance.inference(input_data)
-        return results_array, 1
+        return self._process_non_pdf(model_inference_instance, input_data, tables_only, debug, debug_dir)


-    def _process_pdf(self, model_inference_instance, input_data, debug_dir, mode):
-        """
+    def _process_pdf(self, model_inference_instance, input_data, tables_only, debug, debug_dir, mode):
+        """
+        Handles processing and inference for PDF files, including page splitting and optional table extraction.
+        """
         pdf_optimizer = PDFOptimizer()
         num_pages, output_files, temp_dir = pdf_optimizer.split_pdf_to_pages(input_data[0]["file_path"],
-                                                                             debug_dir,
-                                                                             True)
-        # Update file paths for PDF pages
-        input_data[0]["file_path"] = output_files
+                                                                             debug_dir, convert_to_images=True)

-
-        results_array = model_inference_instance.inference(input_data, mode)
+        results = self._process_pages(model_inference_instance, output_files, input_data, tables_only, debug, debug_dir)

         # Clean up temporary directory
         shutil.rmtree(temp_dir, ignore_errors=True)
-        return …
+        return results, num_pages
+
+
+    def _process_non_pdf(self, model_inference_instance, input_data, tables_only, debug, debug_dir):
+        """
+        Handles processing and inference for non-PDF files, with optional table extraction.
+        """
+        file_path = input_data[0]["file_path"]
+        if tables_only:
+            return [self._extract_tables(model_inference_instance, file_path, input_data, debug, debug_dir)], 1
+        else:
+            input_data[0]["file_path"] = [file_path]
+            results = model_inference_instance.inference(input_data)
+            return results, 1
+
+    def _process_pages(self, model_inference_instance, output_files, input_data, tables_only, debug, debug_dir):
+        """
+        Processes individual pages (PDF split) and handles table extraction or inference.
+
+        Args:
+            model_inference_instance: The model inference object.
+            output_files: List of file paths for the split PDF pages.
+            input_data: Input data for inference.
+            tables_only: Whether to only process tables.
+            debug: Debug flag for logging.
+            debug_dir: Directory for saving debug information.
+
+        Returns:
+            List of results from the processing or inference.
+        """
+        results_array = []
+
+        if tables_only:
+            if debug:
+                print(f"Processing {len(output_files)} pages for table extraction.")
+            # Process each page individually for table extraction
+            for i, file_path in enumerate(output_files):
+                tables_result = self._extract_tables(
+                    model_inference_instance, file_path, input_data, debug, debug_dir, page_index=i
+                )
+                results_array.append(tables_result)
+        else:
+            if debug:
+                print(f"Processing {len(output_files)} pages for inference at once.")
+            # Pass all output files to the inference method for processing at once
+            input_data[0]["file_path"] = output_files
+            results = model_inference_instance.inference(input_data)
+            results_array.extend(results)
+
+        return results_array
+
+
+    def _extract_tables(self, model_inference_instance, file_path, input_data, debug, debug_dir, page_index=None):
+        """
+        Detects and processes tables from an input file.
+        """
+        table_detector = TableDetector()
+        cropped_tables = table_detector.detect_tables(file_path, local=False, debug_dir=debug_dir, debug=debug)
+        results_array = []
+        temp_dir = tempfile.mkdtemp()
+
+        for i, table in enumerate(cropped_tables):
+            table_index = f"page_{page_index + 1}_table_{i + 1}" if page_index is not None else f"table_{i + 1}"
+            print(f"Processing {table_index} for document {file_path}")
+
+            output_filename = os.path.join(temp_dir, f"{table_index}.jpg")
+            table.save(output_filename, "JPEG")
+
+            input_data[0]["file_path"] = [output_filename]
+            result = self._run_model_inference(model_inference_instance, input_data)
+            result = self.add_table_info_to_data(result, "table_nr", i + 1)
+            results_array.append(result)
+
+        shutil.rmtree(temp_dir, ignore_errors=True)
+        return json.dumps(results_array, indent=4)
+
+
+    @staticmethod
+    def _run_model_inference(model_inference_instance, input_data):
+        """
+        Runs model inference and handles JSON decoding.
+        """
+        result = model_inference_instance.inference(input_data)[0]
+        try:
+            return json.loads(result) if isinstance(result, str) else result
+        except json.JSONDecodeError:
+            return {"message": "Invalid JSON format in LLM output", "valid": "false"}
+

     @staticmethod
     def is_pdf(file_path):
         """Checks if a file is a PDF based on its extension."""
         return file_path.lower().endswith('.pdf')

+
+    @staticmethod
+    def add_table_info_to_data(data: Union[Dict, List], key: str, message: Any) -> Dict:
+        """
+        Add a key-value pair to a dictionary or wrap a list in a dictionary.
+        If a 'table' key exists, add or update the key-value pair inside it.
+
+        Args:
+            data (Union[Dict, List]): The input data (either a dictionary or list).
+            key (str): The key to add.
+            message (Any): The value to associate with the key.
+
+        Returns:
+            Dict: The modified data.
+        """
+        if isinstance(data, dict):
+            if "table" in data and isinstance(data["table"], list):
+                # Add or update the key-value pair in the existing structure
+                data[key] = message
+            else:
+                # Wrap the dictionary inside a `table` key and include the additional key-value pair
+                data = {"table": [data], key: message}
+        elif isinstance(data, list):
+            # Wrap the list in a dictionary with the additional key-value pair
+            data = {"table": data, key: message}
+        else:
+            raise TypeError("Data must be a dictionary or a list.")
+        return data
+
+
 if __name__ == "__main__":
     # run locally: python -m sparrow_parse.extractors.vllm_extractor

     extractor = VLLMExtractor()

-    # export HF_TOKEN="hf_"
-    config = {
-        …
-    }
-
-    # Use the factory to get the correct instance
-    factory = InferenceFactory(config)
-    model_inference_instance = factory.get_inference_instance()
-
-    input_data = [
-        …
-    ]
-
-    # Now you can run inference without knowing which implementation is used
-    results_array, num_pages = extractor.run_inference(model_inference_instance, input_data,
-        …
-    print(f"…
+    # # export HF_TOKEN="hf_"
+    # config = {
+    #     "method": "mlx",  # Could be 'huggingface', 'mlx' or 'local_gpu'
+    #     "model_name": "mlx-community/Qwen2-VL-72B-Instruct-4bit",
+    #     # "hf_space": "katanaml/sparrow-qwen2-vl-7b",
+    #     # "hf_token": os.getenv('HF_TOKEN'),
+    #     # Additional fields for local GPU inference
+    #     # "device": "cuda", "model_path": "model.pth"
+    # }
+    #
+    # # Use the factory to get the correct instance
+    # factory = InferenceFactory(config)
+    # model_inference_instance = factory.get_inference_instance()
+    #
+    # input_data = [
+    #     {
+    #         "file_path": "/Users/andrejb/Work/katana-git/sparrow/sparrow-ml/llm/data/invoice_1.jpg",
+    #         "text_input": "retrieve document data. return response in JSON format"
+    #     }
+    # ]
+    #
+    # # Now you can run inference without knowing which implementation is used
+    # results_array, num_pages = extractor.run_inference(model_inference_instance, input_data, tables_only=False,
+    #                                                    generic_query=False,
+    #                                                    debug_dir="/Users/andrejb/Work/katana-git/sparrow/sparrow-ml/llm/data/",
+    #                                                    debug=True,
+    #                                                    mode=None)
+    #
+    # for i, result in enumerate(results_array):
+    #     print(f"Result for page {i + 1}:", result)
+    # print(f"Number of pages: {num_pages}")
sparrow_parse/processors/table_structure_processor.py
CHANGED
@@ -1,19 +1,18 @@
 from rich.progress import Progress, SpinnerColumn, TextColumn
 from rich import print
 from transformers import AutoModelForObjectDetection
-from transformers import TableTransformerForObjectDetection
 import torch
 from PIL import Image
 from torchvision import transforms
-from PIL import ImageDraw
 import os
-import numpy as np
-import easyocr


 class TableDetector(object):
+    _model = None  # Static variable to hold the table detection model
+    _device = None  # Static variable to hold the device information
+
     def __init__(self):
-        …
+        pass

     class MaxResize(object):
         def __init__(self, max_size=800):
@@ -27,12 +26,27 @@ class TableDetector(object):

         return resized_image

-    …
+    @classmethod
+    def _initialize_model(cls, invoke_pipeline_step, local):
+        """
+        Static method to initialize the table detection model if not already initialized.
+        """
+        if cls._model is None:
+            # Use invoke_pipeline_step to load the model
+            cls._model, cls._device = invoke_pipeline_step(
+                lambda: cls.load_table_detection_model(),
+                "Loading table detection model...",
+                local
+            )
+            print("Table detection model initialized.")
+
+
+    def detect_tables(self, file_path, local=True, debug_dir=None, debug=False):
+        # Ensure the model is initialized using invoke_pipeline_step
+        self._initialize_model(self.invoke_pipeline_step, local)
+
+        # Use the static model and device
+        model, device = self._model, self._device

         outputs, image = self.invoke_pipeline_step(
             lambda: self.prepare_image(file_path, model, device),
@@ -46,38 +60,17 @@ class TableDetector(object):
             local
         )

-        …
-            lambda: self.…
+        cropped_tables = self.invoke_pipeline_step(
+            lambda: self.crop_tables(file_path, image, objects, debug, debug_dir),
             "Cropping tables from the image...",
             local
         )

-        …
-            lambda: self.load_table_structure_model(device),
-            "Loading table structure recognition model...",
-            local
-        )
-
-        structure_outputs = self.invoke_pipeline_step(
-            lambda: self.get_table_structure(cropped_table, structure_model, device),
-            "Getting table structure from cropped table...",
-            local
-        )
-
-        table_data = self.invoke_pipeline_step(
-            lambda: self.get_structure_cells(structure_model, cropped_table, structure_outputs),
-            "Getting structure cells from cropped table...",
-            local
-        )
-
-        self.invoke_pipeline_step(
-            lambda: self.process_table_structure(table_data, cropped_table, file_path),
-            "Processing structure cells...",
-            local
-        )
+        return cropped_tables


-    …
+    @staticmethod
+    def load_table_detection_model():
         model = AutoModelForObjectDetection.from_pretrained("microsoft/table-transformer-detection", revision="no_timm")

         device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -85,11 +78,6 @@ class TableDetector(object):

         return model, device

-    def load_table_structure_model(self, device):
-        structure_model = TableTransformerForObjectDetection.from_pretrained("microsoft/table-structure-recognition-v1.1-all")
-        structure_model.to(device)
-
-        return structure_model

     def prepare_image(self, file_path, model, device):
         image = Image.open(file_path).convert("RGB")
@@ -115,38 +103,52 @@ class TableDetector(object):
         objects = self.outputs_to_objects(outputs, image.size, id2label)
         return objects

-    …
+
+    def crop_tables(self, file_path, image, objects, debug, debug_dir):
         tokens = []
         detection_class_thresholds = {
             "table": 0.5,
             "table rotated": 0.5,
             "no object": 10
         }
-        crop_padding = …
+        crop_padding = 30

         tables_crops = self.objects_to_crops(image, tokens, objects, detection_class_thresholds, padding=crop_padding)

-        …
+        cropped_tables = []

         if len(tables_crops) == 0:
-            …
+            if debug:
+                print("No tables detected in: ", file_path)
+
+            return None
         elif len(tables_crops) > 1:
             for i, table_crop in enumerate(tables_crops):
+                if debug:
+                    print("Table detected in:", file_path, "-", i + 1)
+
                 cropped_table = table_crop['image'].convert("RGB")
-                …
+                cropped_tables.append(cropped_table)
+
+                if debug_dir:
+                    file_name_table = self.append_filename(file_path, debug_dir, f"cropped_{i + 1}")
+                    cropped_table.save(file_name_table)
         else:
+            if debug:
+                print("Table detected in: ", file_path)
+
             cropped_table = tables_crops[0]['image'].convert("RGB")
+            cropped_tables.append(cropped_table)

-            …
+            if debug_dir:
+                file_name_table = self.append_filename(file_path, debug_dir, "cropped")
+                cropped_table.save(file_name_table)

-        return …
+        return cropped_tables

     # for output bounding box post-processing
-    …
+    @staticmethod
+    def box_cxcywh_to_xyxy(x):
         x_c, y_c, w, h = x.unbind(-1)
         b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
         return torch.stack(b, dim=1)
@@ -216,214 +218,15 @@ class TableDetector(object):

         return table_crops

-    def get_table_structure(self, cropped_table, structure_model, device):
-        structure_transform = transforms.Compose([
-            self.MaxResize(1000),
-            transforms.ToTensor(),
-            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
-        ])
-
-        pixel_values = structure_transform(cropped_table).unsqueeze(0)
-        pixel_values = pixel_values.to(device)
-
-        with torch.no_grad():
-            outputs = structure_model(pixel_values)
-
-        return outputs
-
-    def get_structure_cells(self, structure_model, cropped_table, outputs):
-        structure_id2label = structure_model.config.id2label
-        structure_id2label[len(structure_id2label)] = "no object"
-
-        cells = self.outputs_to_objects(outputs, cropped_table.size, structure_id2label)
-
-        return cells
-
-    def process_table_structure(self, table_data, cropped_table, file_path):
-        cropped_table_raw_visualized = cropped_table.copy()
-        draw_raw = ImageDraw.Draw(cropped_table_raw_visualized)
-        cropped_table_header_visualized = cropped_table.copy()
-        draw_header = ImageDraw.Draw(cropped_table_header_visualized)
-        cropped_table_visualized = cropped_table.copy()
-        draw = ImageDraw.Draw(cropped_table_visualized)
-
-        table_data = [cell for cell in table_data if cell['label'] != 'table spanning cell']
-        table_data = [cell for cell in table_data if cell['label'] != 'table']
-        table_data = [cell for cell in table_data if cell['score'] >= 0.8]
-
-        table_data = self.merge_overlapping_columns(cropped_table, table_data)
-        table_data = self.adjust_overlapping_rows(cropped_table, table_data)
-
-        table_data_filtered = [item for item in table_data if item['label'] == 'table row']
-        # table_data_filtered = table_data
-        for cell in table_data_filtered:
-            draw_raw.rectangle(cell["bbox"], outline="red")
-        file_name_table_grid_raw = self.append_filename(file_path, "table_raw")
-        cropped_table_raw_visualized.save(file_name_table_grid_raw)
-        print("Table raw data:")
-        print(table_data_filtered)
-
-        # table, table column header, table row, table column
-        table_data_header = [cell for cell in table_data if cell['label'] == 'table column header'
-                             or cell['label'] == 'table' or cell['label'] == 'table column']
-        print("Table header data:")
-        print(table_data_header)
-
-        table_data_rows = [cell for cell in table_data if cell['label'] == 'table column'
-                           or cell['label'] == 'table row']
-        table_data_rows = self.remove_overlapping_table_header_rows(table_data_header, table_data_rows)
-        print("Table row data:")
-        print(table_data_rows)
-
-        header_cells = self.get_header_cell_coordinates(table_data_header)
-        if header_cells is not None:
-            print("Header cell coordinates:")
-            print(header_cells)
-
-            header_data = self.do_ocr_with_coordinates(header_cells, cropped_table)
-            print("Header data:")
-            print(header_data)
-
-            for cell_data in header_cells['row0']:
-                draw_header.rectangle(cell_data["cell"], outline="red")
-
-            file_name_table_grid_header = self.append_filename(file_path, "table_grid_header")
-            cropped_table_header_visualized.save(file_name_table_grid_header)
-
-        table_cells = self.get_table_cell_coordinates(table_data_rows)
-        if table_cells is not None:
-            print("Table cell coordinates:")
-            print(table_cells)
-
-            table_data = self.do_ocr_with_coordinates(table_cells, cropped_table)
-            print("Table data:")
-            print(table_data)
-
-            for row_key, row_cells in table_cells.items():
-                for cell_data in row_cells:
-                    draw.rectangle(cell_data["cell"], outline="red")
-
-            file_name_table_grid = self.append_filename(file_path, "table_grid_cells")
-            cropped_table_visualized.save(file_name_table_grid)
-
-    def get_header_cell_coordinates(self, table_data):
-        header_column = None
-        columns = []
-
-        # Separate header and columns
-        for item in table_data:
-            if item['label'] == 'table column header':
-                header_column = item['bbox']
-            elif item['label'] == 'table column':
-                columns.append(item['bbox'])
-
-        if not header_column:
-            return None

-        …
-        cells = []
-
-        # Calculate cell coordinates based on header and column intersections
-        for column in columns:
-            cell_left = column[0]
-            cell_right = column[2]
-            cell_top = header_top
-            cell_bottom = header_bottom
-
-            cells.append({
-                'cell': (cell_left, cell_top, cell_right, cell_bottom)
-            })
-
-        # Sort cells by the left coordinate (cell_left) to order them from left to right
-        cells.sort(key=lambda x: x['cell'][0])
-
-        header_row = {"row0": cells}
-
-        return header_row
-
-    def get_table_cell_coordinates(self, table_data):
-        rows = []
-        columns = []
-
-        # Separate rows and columns
-        for item in table_data:
-            if item['label'] == 'table row':
-                rows.append(item['bbox'])
-            elif item['label'] == 'table column':
-                columns.append(item['bbox'])
-
-        if not rows or not columns:
-            return None
-
-        # Sort rows by the top coordinate to ensure they are processed from top to bottom
-        rows.sort(key=lambda x: x[1])
-
-        row_cells = {}
-
-        # Calculate cell coordinates based on row and column intersections
-        for row_idx, row in enumerate(rows):
-            row_top = row[1]
-            row_bottom = row[3]
-            cells = []
-            for column in columns:
-                cell_left = column[0]
-                cell_right = column[2]
-                cell_top = row_top
-                cell_bottom = row_bottom
-
-                cells.append({
-                    'cell': (cell_left, cell_top, cell_right, cell_bottom)
-                })
-
-            # Sort cells within the row by the left coordinate to ensure they are ordered from left to right
-            cells.sort(key=lambda x: x['cell'][0])
-            row_cells[f'row{row_idx}'] = cells
-
-        return row_cells
-
-    def do_ocr_with_coordinates(self, cell_coordinates, cropped_table):
-        data = {}
-        max_num_columns = 0
-
-        # Iterate over each row in cell_coordinates
-        for row_key in cell_coordinates:
-            row_text = []
-            for cell in cell_coordinates[row_key]:
-                # Crop cell out of image
-                cell_image = cropped_table.crop(cell['cell'])
-                cell_image_np = np.array(cell_image)
-
-                # Apply OCR
-                result = self.reader.readtext(cell_image_np)
-                if result:
-                    text = " ".join([x[1] for x in result])
-                    row_text.append(text)
-                else:
-                    row_text.append("")  # If no text is detected, append an empty string
-
-            if len(row_text) > max_num_columns:
-                max_num_columns = len(row_text)
-
-            data[row_key] = row_text
-
-        print("Max number of columns:", max_num_columns)
-
-        # Pad rows which don't have max_num_columns elements
-        for row_key, row_data in data.items():
-            if len(row_data) < max_num_columns:
-                row_data += [""] * (max_num_columns - len(row_data))
-                data[row_key] = row_data
-
-        return data
-
-    def append_filename(self, file_path, word):
+    @staticmethod
+    def append_filename(file_path, debug_dir, word):
         directory, filename = os.path.split(file_path)
         name, ext = os.path.splitext(filename)
         new_filename = f"{name}_{word}{ext}"
-        return os.path.join(…
+        return os.path.join(debug_dir, new_filename)

+    @staticmethod
     def iob(boxA, boxB):
         # Determine the coordinates of the intersection rectangle
         xA = max(boxA[0], boxB[0])
@@ -443,159 +246,9 @@ class TableDetector(object):

         return iob

-    def remove_overlapping_table_header_rows(self, header_data, row_data, tolerance=1.0):
-        # Function to calculate the Intersection over Union (IoU) of two bounding boxes
-        def calculate_iou(bbox1, bbox2):
-            x1_min, y1_min, x1_max, y1_max = bbox1
-            x2_min, y2_min, x2_max, y2_max = bbox2
-
-            # Determine the coordinates of the intersection rectangle
-            inter_min_x = max(x1_min, x2_min)
-            inter_min_y = max(y1_min, y2_min)
-            inter_max_x = min(x1_max, x2_max)
-            inter_max_y = min(y1_max, y2_max)
-
-            # Compute the area of intersection
-            inter_area = max(0, inter_max_x - inter_min_x) * max(0, inter_max_y - inter_min_y)
-
-            # Compute the area of both bounding boxes
-            bbox1_area = (x1_max - x1_min) * (y1_max - y1_min)
-            bbox2_area = (x2_max - x2_min) * (y2_max - y2_min)
-
-            # Compute the Intersection over Union (IoU)
-            iou = inter_area / float(bbox1_area + bbox2_area - inter_area)
-            return iou
-
-        # Extract the bounding box of the table column header
-        header_bbox = None
-        for item in header_data:
-            if item['label'] == 'table column header':
-                header_bbox = item['bbox']
-                break
-
-        if header_bbox is None:
-            print("No 'table column header' found in header data.")
-            return row_data
-
-        # Initialize a counter for removed rows
-        removed_count = 0
-
-        # Iterate over the table row data and remove rows with overlapping bbox
-        updated_row_data = []
-        for row in row_data:
-            if row['label'] == 'table row':
-                row_bbox = row['bbox']
-                # Check for overlap (IoU > 0) or very similar bounding box
-                iou = calculate_iou(header_bbox, row_bbox)
-                if iou > 0 or np.allclose(row_bbox, header_bbox, atol=tolerance):
-                    removed_count += 1  # Increment the removed counter
-                    continue  # Skip this row as it overlaps or matches the header bbox
-
-            # Add row to the updated list if it doesn't overlap
-            updated_row_data.append(row)
-
-        # Print the number of removed rows
-        print(f"Number of removed rows: {removed_count}")
-
-        return updated_row_data
-
-    def filter_table_columns(self, data):
-        return [item for item in data if item['label'] == 'table column']
-
-    def filter_table_rows(self, data):
-        return [item for item in data if item['label'] == 'table row']
-
-    def extract_text_boundaries(self, image, box):
-        """
-        Extract the start and end coordinates of the text within a bounding box,
-        and translate them back to the original image coordinates.
-
-        Args:
-        - image: The image in which the box is located.
-        - box: The bounding box (x_min, y_min, x_max, y_max).
-        - reader: The EasyOCR reader object.

-        …
-        - text_end: The x-coordinate of the end of the text in the original image.
-        """
-        x_min, y_min, x_max, y_max = box
-        cropped_image = image.crop((x_min, y_min, x_max, y_max))
-        result = self.reader.readtext(np.array(cropped_image))
-
-        if result:
-            text_coordinates = result[0][0]  # Extract the coordinates of the text within the cropped image
-
-            # Translate the coordinates back to the original image coordinates
-            text_start = min(point[0] + x_min for point in text_coordinates)
-            text_end = max(point[0] + x_min for point in text_coordinates)
-
-            return text_start, text_end
-
-        return None, None
-
-    def merge_overlapping_columns(self, image, data, proximity_threshold=20):
-        """
-        Merge only those bounding boxes where the text is split directly by the box line,
-        while keeping other labels intact.
-
-        Args:
-        - image: The image in which the boxes are located.
-        - data: List of dictionary items with bounding boxes and labels.
-        - reader: The EasyOCR reader object.
-        - proximity_threshold: The maximum distance between text boundaries to consider merging.
-
-        Returns:
-        - Updated list of dictionary items with merged bounding boxes and other entries preserved.
-        """
-        table_columns = self.filter_table_columns(data)
-        other_entries = [item for item in data if item['label'] != 'table column']
-        merged_boxes = []
-        table_columns = sorted(table_columns, key=lambda x: x['bbox'][0])  # Sort by x_min
-
-        while table_columns:
-            box_data = table_columns.pop(0)
-            x_min, y_min, x_max, y_max = box_data['bbox']
-
-            to_merge = []
-            for i, other_box_data in enumerate(table_columns):
-                ox_min, oy_min, ox_max, oy_max = other_box_data['bbox']
-
-                # Only consider merging if the boxes are adjacent horizontally
-                if x_min < ox_max and x_max > ox_min:
-                    # Extract text boundaries from both boxes
-                    text_start_1, text_end_1 = self.extract_text_boundaries(image, box_data['bbox'])
-                    text_start_2, text_end_2 = self.extract_text_boundaries(image, other_box_data['bbox'])
-
-                    # Check if the text from one box ends very close to where the text in the next box starts
-                    if text_end_1 is not None and text_start_2 is not None and text_start_2 - text_end_1 <= proximity_threshold:
-                        x_max = max(x_max, ox_max)
-                        y_max = max(y_max, oy_max)
-                        y_min = min(y_min, oy_min)
-                        to_merge.append(i)
-
-            # Merge the boxes
-            for index in sorted(to_merge, reverse=True):
-                table_columns.pop(index)
-
-            merged_boxes.append({
-                'label': box_data['label'],
-                'score': box_data['score'],
-                'bbox': [x_min, y_min, x_max, y_max]
-            })
-
-        # Combine the merged boxes with other entries
-        final_output = merged_boxes + other_entries
-
-        # Sort final output by the y-coordinate to maintain the original order
-        final_output = sorted(final_output, key=lambda x: x['bbox'][1])
-
-        return final_output
-
-    def adjust_overlapping_rows(self, image, data, proximity_threshold=10):
-        return data
-
-    def invoke_pipeline_step(self, task_call, task_description, local):
+    @staticmethod
+    def invoke_pipeline_step(task_call, task_description, local):
         if local:
             with Progress(
                 SpinnerColumn(),
@@ -614,5 +267,9 @@ class TableDetector(object):
 if __name__ == "__main__":
     table_detector = TableDetector()

-    …
-    # table_detector.…
+    # file_path = "/Users/andrejb/Work/katana-git/sparrow/sparrow-ml/llm/data/bonds_table.png"
+    # cropped_tables = table_detector.detect_tables(file_path, local=True, debug_dir="/Users/andrejb/Work/katana-git/sparrow/sparrow-ml/llm/data/", debug=True)
+
+    # for i, cropped_table in enumerate(cropped_tables):
+    #     file_name_table = table_detector.append_filename(file_path, "cropped_" + str(i))
+    #     cropped_table.save(file_name_table)
{sparrow_parse-0.3.11.dist-info → sparrow_parse-0.4.0.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: sparrow-parse
-Version: 0.3.11
+Version: 0.4.0
 Summary: Sparrow Parse is a Python package (part of Sparrow) for parsing and extracting information from documents.
 Home-page: https://github.com/katanaml/sparrow/tree/main/sparrow-data/parse
 Author: Andrej Baranovskij
@@ -21,10 +21,9 @@ Requires-Dist: transformers==4.46.3
 Requires-Dist: sentence-transformers==3.3.1
 Requires-Dist: numpy==2.1.3
 Requires-Dist: pypdf==4.3.0
-Requires-Dist: easyocr==1.7.1
 Requires-Dist: gradio-client
 Requires-Dist: pdf2image
-Requires-Dist: mlx-vlm==0.1.…
+Requires-Dist: mlx-vlm==0.1.4; sys_platform == "darwin" and platform_machine == "arm64"

 # Sparrow Parse

@@ -67,7 +66,8 @@ input_data = [
 ]

 # Now you can run inference without knowing which implementation is used
-results_array, num_pages = extractor.run_inference(model_inference_instance, input_data,
+results_array, num_pages = extractor.run_inference(model_inference_instance, input_data, tables_only=False,
+                                                   generic_query=False,
                                                    debug_dir=None,
                                                    debug=True,
                                                    mode=None)
@@ -77,6 +77,8 @@ for i, result in enumerate(results_array):
 print(f"Number of pages: {num_pages}")
 ```

+Use `tables_only=True` if you want to extract only tables.
+
 Use `mode="static"` if you want to simulate LLM call, without executing LLM backend.

 Method `run_inference` will return results and number of pages processed.
{sparrow_parse-0.3.11.dist-info → sparrow_parse-0.4.0.dist-info}/RECORD
CHANGED
@@ -1,19 +1,19 @@
-sparrow_parse/__init__.py,sha256=…
+sparrow_parse/__init__.py,sha256=DObMj8zITWgJRRICOQXNFEgLDtZ9uQZUVwbNAU-P3oc,21
 sparrow_parse/__main__.py,sha256=Xs1bpJV0n08KWOoQE34FBYn6EBXZA9HIYJKrE4ZdG78,153
 sparrow_parse/extractors/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-sparrow_parse/extractors/vllm_extractor.py,sha256=…
+sparrow_parse/extractors/vllm_extractor.py,sha256=SCqxdr8V_cm0COfs0TelTcBXapVcz2WffhESJ1fry0g,8716
 sparrow_parse/helpers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 sparrow_parse/helpers/pdf_optimizer.py,sha256=GIqQYWtixFeZGCRFXL0lQfQByapCDuQzzRHAkzcPwLE,3302
 sparrow_parse/processors/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-sparrow_parse/processors/table_structure_processor.py,sha256=…
+sparrow_parse/processors/table_structure_processor.py,sha256=PQHHFdQUuTin3Mm2USuUga2n4fGWMLwiBJYq4CVD67o,9775
 sparrow_parse/vllm/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 sparrow_parse/vllm/huggingface_inference.py,sha256=EJnG6PesGKMc_0qGPN8ufE6pSnhAgFu0XjCbaLCNVyM,1980
 sparrow_parse/vllm/inference_base.py,sha256=4mwGoAY63MB4cHZpV0czTkJWEzimmiTzqqzKmLNzgjw,820
 sparrow_parse/vllm/inference_factory.py,sha256=FTM65O-dW2WZchHOrNN7_Q3-FlVoAc65iSptuuUuClM,1166
 sparrow_parse/vllm/local_gpu_inference.py,sha256=aHoJTejb5xrXjWDIGu5RBQWEyRCOBCB04sMvO2Wyvg8,628
 sparrow_parse/vllm/mlx_inference.py,sha256=xR40qwjIR0HvrN8x58oOq6F4r1hEANRB-9kcokUQHHU,4748
-sparrow_parse-0.3.11.dist-info/METADATA,sha256=…
-sparrow_parse-0.3.11.dist-info/WHEEL,sha256=…
-sparrow_parse-0.3.11.dist-info/entry_points.txt,sha256=…
-sparrow_parse-0.3.11.dist-info/top_level.txt,sha256=…
-sparrow_parse-0.3.11.dist-info/RECORD,,
+sparrow_parse-0.4.0.dist-info/METADATA,sha256=IQqfUUKnpA0ystjBmrrpSWw4b1hDYnLO4sqKdoNYEHk,6432
+sparrow_parse-0.4.0.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+sparrow_parse-0.4.0.dist-info/entry_points.txt,sha256=8CrvTVTTcz1YuZ8aRCYNOH15ZOAaYLlcbYX3t28HwJY,54
+sparrow_parse-0.4.0.dist-info/top_level.txt,sha256=n6b-WtT91zKLyCPZTP7wvne8v_yvIahcsz-4sX8I0rY,14
+sparrow_parse-0.4.0.dist-info/RECORD,,
File without changes
|
File without changes
|
File without changes
|