sparrow-parse 0.3.12__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sparrow_parse/__init__.py CHANGED
@@ -1 +1 @@
- __version__ = '0.3.12'
+ __version__ = '0.4.1'
sparrow_parse/extractors/vllm_extractor.py CHANGED
@@ -1,55 +1,144 @@
+ import json
+
  from sparrow_parse.vllm.inference_factory import InferenceFactory
  from sparrow_parse.helpers.pdf_optimizer import PDFOptimizer
+ from sparrow_parse.processors.table_structure_processor import TableDetector
  from rich import print
  import os
+ import tempfile
  import shutil
+ from typing import Any, Dict, List, Union


  class VLLMExtractor(object):
  def __init__(self):
  pass

- def run_inference(self, model_inference_instance, input_data,
+ def run_inference(self, model_inference_instance, input_data, tables_only=False,
  generic_query=False, debug_dir=None, debug=False, mode=None):
- # Modify input for generic queries
+ """
+ Main entry point for processing input data using a model inference instance.
+ Handles generic queries, PDFs, and table extraction.
+ """
  if generic_query:
  input_data[0]["text_input"] = "retrieve document data. return response in JSON format"

  if debug:
- print("Input Data:", input_data)
+ print("Input data:", input_data)

- # Check if the input file is a PDF
  file_path = input_data[0]["file_path"]
  if self.is_pdf(file_path):
- return self._process_pdf(model_inference_instance, input_data, debug_dir, mode)
+ return self._process_pdf(model_inference_instance, input_data, tables_only, debug, debug_dir, mode)

- # Default processing for non-PDF files
- input_data[0]["file_path"] = [file_path]
- results_array = model_inference_instance.inference(input_data)
- return results_array, 1
+ return self._process_non_pdf(model_inference_instance, input_data, tables_only, debug, debug_dir)


- def _process_pdf(self, model_inference_instance, input_data, debug_dir, mode):
- """Handles processing and inference for PDF files."""
+ def _process_pdf(self, model_inference_instance, input_data, tables_only, debug, debug_dir, mode):
+ """
+ Handles processing and inference for PDF files, including page splitting and optional table extraction.
+ """
  pdf_optimizer = PDFOptimizer()
  num_pages, output_files, temp_dir = pdf_optimizer.split_pdf_to_pages(input_data[0]["file_path"],
- debug_dir,
- True)
- # Update file paths for PDF pages
- input_data[0]["file_path"] = output_files
+ debug_dir, convert_to_images=True)

- # Run inference on PDF pages
- results_array = model_inference_instance.inference(input_data, mode)
+ results = self._process_pages(model_inference_instance, output_files, input_data, tables_only, debug, debug_dir)

  # Clean up temporary directory
  shutil.rmtree(temp_dir, ignore_errors=True)
- return results_array, num_pages
+ return results, num_pages
+
+
+ def _process_non_pdf(self, model_inference_instance, input_data, tables_only, debug, debug_dir):
+ """
+ Handles processing and inference for non-PDF files, with optional table extraction.
+ """
+ file_path = input_data[0]["file_path"]
+ if tables_only:
+ return [self._extract_tables(model_inference_instance, file_path, input_data, debug, debug_dir)], 1
+ else:
+ input_data[0]["file_path"] = [file_path]
+ results = model_inference_instance.inference(input_data)
+ return results, 1
+
+ def _process_pages(self, model_inference_instance, output_files, input_data, tables_only, debug, debug_dir):
+ """
+ Processes individual pages (PDF split) and handles table extraction or inference.
+
+ Args:
+ model_inference_instance: The model inference object.
+ output_files: List of file paths for the split PDF pages.
+ input_data: Input data for inference.
+ tables_only: Whether to only process tables.
+ debug: Debug flag for logging.
+ debug_dir: Directory for saving debug information.
+
+ Returns:
+ List of results from the processing or inference.
+ """
+ results_array = []
+
+ if tables_only:
+ if debug:
+ print(f"Processing {len(output_files)} pages for table extraction.")
+ # Process each page individually for table extraction
+ for i, file_path in enumerate(output_files):
+ tables_result = self._extract_tables(
+ model_inference_instance, file_path, input_data, debug, debug_dir, page_index=i
+ )
+ results_array.append(tables_result)
+ else:
+ if debug:
+ print(f"Processing {len(output_files)} pages for inference at once.")
+ # Pass all output files to the inference method for processing at once
+ input_data[0]["file_path"] = output_files
+ results = model_inference_instance.inference(input_data)
+ results_array.extend(results)
+
+ return results_array
+
+
+ def _extract_tables(self, model_inference_instance, file_path, input_data, debug, debug_dir, page_index=None):
+ """
+ Detects and processes tables from an input file.
+ """
+ table_detector = TableDetector()
+ cropped_tables = table_detector.detect_tables(file_path, local=False, debug_dir=debug_dir, debug=debug)
+ results_array = []
+ temp_dir = tempfile.mkdtemp()
+
+ for i, table in enumerate(cropped_tables):
+ table_index = f"page_{page_index + 1}_table_{i + 1}" if page_index is not None else f"table_{i + 1}"
+ print(f"Processing {table_index} for document {file_path}")
+
+ output_filename = os.path.join(temp_dir, f"{table_index}.jpg")
+ table.save(output_filename, "JPEG")
+
+ input_data[0]["file_path"] = [output_filename]
+ result = self._run_model_inference(model_inference_instance, input_data)
+ results_array.append(result)
+
+ shutil.rmtree(temp_dir, ignore_errors=True)
+ return json.dumps(results_array, indent=4)
+
+
+ @staticmethod
+ def _run_model_inference(model_inference_instance, input_data):
+ """
+ Runs model inference and handles JSON decoding.
+ """
+ result = model_inference_instance.inference(input_data)[0]
+ try:
+ return json.loads(result) if isinstance(result, str) else result
+ except json.JSONDecodeError:
+ return {"message": "Invalid JSON format in LLM output", "valid": "false"}
+

  @staticmethod
  def is_pdf(file_path):
  """Checks if a file is a PDF based on its extension."""
  return file_path.lower().endswith('.pdf')

+
  if __name__ == "__main__":
  # run locally: python -m sparrow_parse.extractors.vllm_extractor

@@ -71,16 +160,17 @@ if __name__ == "__main__":
  #
  # input_data = [
  # {
- # "file_path": "/Users/andrejb/Work/katana-git/sparrow/sparrow-ml/llm/data/oracle_10k_2014_q1_small.pdf",
- # "text_input": "retrieve table, description, latest_amount, previous_amount. return response in JSON format, by strictly following this JSON schema: {\"table\": [{\"description\": \"str\", \"latest_amount\": 0, \"previous_amount\": 0}]}"
+ # "file_path": "/Users/andrejb/Work/katana-git/sparrow/sparrow-ml/llm/data/invoice_1.jpg",
+ # "text_input": "retrieve document data. return response in JSON format"
  # }
  # ]
  #
  # # Now you can run inference without knowing which implementation is used
- # results_array, num_pages = extractor.run_inference(model_inference_instance, input_data, generic_query=False,
- # debug_dir=None,
- # debug=True,
- # mode=None)
+ # results_array, num_pages = extractor.run_inference(model_inference_instance, input_data, tables_only=False,
+ # generic_query=False,
+ # debug_dir="/Users/andrejb/Work/katana-git/sparrow/sparrow-ml/llm/data/",
+ # debug=True,
+ # mode=None)
  #
  # for i, result in enumerate(results_array):
  # print(f"Result for page {i + 1}:", result)
sparrow_parse/processors/table_structure_processor.py CHANGED
@@ -1,19 +1,18 @@
  from rich.progress import Progress, SpinnerColumn, TextColumn
  from rich import print
  from transformers import AutoModelForObjectDetection
- from transformers import TableTransformerForObjectDetection
  import torch
  from PIL import Image
  from torchvision import transforms
- from PIL import ImageDraw
  import os
- import numpy as np
- import easyocr


  class TableDetector(object):
+ _model = None # Static variable to hold the table detection model
+ _device = None # Static variable to hold the device information
+
  def __init__(self):
- self.reader = easyocr.Reader(['en']) # this needs to run only once to load the model into memory
+ pass

  class MaxResize(object):
  def __init__(self, max_size=800):
@@ -27,12 +26,27 @@ class TableDetector(object):

  return resized_image

- def detect_table(self, file_path, options, local=True, debug=False):
- model, device = self.invoke_pipeline_step(
- lambda: self.load_table_detection_model(),
- "Loading table detection model...",
- local
- )
+ @classmethod
+ def _initialize_model(cls, invoke_pipeline_step, local):
+ """
+ Static method to initialize the table detection model if not already initialized.
+ """
+ if cls._model is None:
+ # Use invoke_pipeline_step to load the model
+ cls._model, cls._device = invoke_pipeline_step(
+ lambda: cls.load_table_detection_model(),
+ "Loading table detection model...",
+ local
+ )
+ print("Table detection model initialized.")
+
+
+ def detect_tables(self, file_path, local=True, debug_dir=None, debug=False):
+ # Ensure the model is initialized using invoke_pipeline_step
+ self._initialize_model(self.invoke_pipeline_step, local)
+
+ # Use the static model and device
+ model, device = self._model, self._device

  outputs, image = self.invoke_pipeline_step(
  lambda: self.prepare_image(file_path, model, device),
@@ -46,38 +60,17 @@ class TableDetector(object):
  local
  )

- cropped_table = self.invoke_pipeline_step(
- lambda: self.crop_table(file_path, image, objects),
+ cropped_tables = self.invoke_pipeline_step(
+ lambda: self.crop_tables(file_path, image, objects, debug, debug_dir),
  "Cropping tables from the image...",
  local
  )

- structure_model = self.invoke_pipeline_step(
- lambda: self.load_table_structure_model(device),
- "Loading table structure recognition model...",
- local
- )
-
- structure_outputs = self.invoke_pipeline_step(
- lambda: self.get_table_structure(cropped_table, structure_model, device),
- "Getting table structure from cropped table...",
- local
- )
-
- table_data = self.invoke_pipeline_step(
- lambda: self.get_structure_cells(structure_model, cropped_table, structure_outputs),
- "Getting structure cells from cropped table...",
- local
- )
-
- self.invoke_pipeline_step(
- lambda: self.process_table_structure(table_data, cropped_table, file_path),
- "Processing structure cells...",
- local
- )
+ return cropped_tables


- def load_table_detection_model(self):
+ @staticmethod
+ def load_table_detection_model():
  model = AutoModelForObjectDetection.from_pretrained("microsoft/table-transformer-detection", revision="no_timm")

  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -85,11 +78,6 @@ class TableDetector(object):

  return model, device

- def load_table_structure_model(self, device):
- structure_model = TableTransformerForObjectDetection.from_pretrained("microsoft/table-structure-recognition-v1.1-all")
- structure_model.to(device)
-
- return structure_model

  def prepare_image(self, file_path, model, device):
  image = Image.open(file_path).convert("RGB")
@@ -115,38 +103,52 @@ class TableDetector(object):
  objects = self.outputs_to_objects(outputs, image.size, id2label)
  return objects

- def crop_table(self, file_path, image, objects):
+
+ def crop_tables(self, file_path, image, objects, debug, debug_dir):
  tokens = []
  detection_class_thresholds = {
  "table": 0.5,
  "table rotated": 0.5,
  "no object": 10
  }
- crop_padding = 10
+ crop_padding = 30

  tables_crops = self.objects_to_crops(image, tokens, objects, detection_class_thresholds, padding=crop_padding)

- cropped_table = None
+ cropped_tables = []

  if len(tables_crops) == 0:
- print("No tables detected.")
- return
+ if debug:
+ print("No tables detected in: ", file_path)
+
+ return None
  elif len(tables_crops) > 1:
  for i, table_crop in enumerate(tables_crops):
+ if debug:
+ print("Table detected in:", file_path, "-", i + 1)
+
  cropped_table = table_crop['image'].convert("RGB")
- file_name_table = self.append_filename(file_path, f"table_{i}")
- cropped_table.save(file_name_table)
- break
+ cropped_tables.append(cropped_table)
+
+ if debug_dir:
+ file_name_table = self.append_filename(file_path, debug_dir, f"cropped_{i + 1}")
+ cropped_table.save(file_name_table)
  else:
+ if debug:
+ print("Table detected in: ", file_path)
+
  cropped_table = tables_crops[0]['image'].convert("RGB")
+ cropped_tables.append(cropped_table)

- file_name_table = self.append_filename(file_path, "table")
- cropped_table.save(file_name_table)
+ if debug_dir:
+ file_name_table = self.append_filename(file_path, debug_dir, "cropped")
+ cropped_table.save(file_name_table)

- return cropped_table
+ return cropped_tables

  # for output bounding box post-processing
- def box_cxcywh_to_xyxy(self, x):
+ @staticmethod
+ def box_cxcywh_to_xyxy(x):
  x_c, y_c, w, h = x.unbind(-1)
  b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
  return torch.stack(b, dim=1)
@@ -216,214 +218,15 @@ class TableDetector(object):

  return table_crops

- def get_table_structure(self, cropped_table, structure_model, device):
- structure_transform = transforms.Compose([
- self.MaxResize(1000),
- transforms.ToTensor(),
- transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
- ])
-
- pixel_values = structure_transform(cropped_table).unsqueeze(0)
- pixel_values = pixel_values.to(device)
-
- with torch.no_grad():
- outputs = structure_model(pixel_values)
-
- return outputs
-
- def get_structure_cells(self, structure_model, cropped_table, outputs):
- structure_id2label = structure_model.config.id2label
- structure_id2label[len(structure_id2label)] = "no object"
-
- cells = self.outputs_to_objects(outputs, cropped_table.size, structure_id2label)
-
- return cells
-
- def process_table_structure(self, table_data, cropped_table, file_path):
- cropped_table_raw_visualized = cropped_table.copy()
- draw_raw = ImageDraw.Draw(cropped_table_raw_visualized)
- cropped_table_header_visualized = cropped_table.copy()
- draw_header = ImageDraw.Draw(cropped_table_header_visualized)
- cropped_table_visualized = cropped_table.copy()
- draw = ImageDraw.Draw(cropped_table_visualized)
-
- table_data = [cell for cell in table_data if cell['label'] != 'table spanning cell']
- table_data = [cell for cell in table_data if cell['label'] != 'table']
- table_data = [cell for cell in table_data if cell['score'] >= 0.8]
-
- table_data = self.merge_overlapping_columns(cropped_table, table_data)
- table_data = self.adjust_overlapping_rows(cropped_table, table_data)
-
- table_data_filtered = [item for item in table_data if item['label'] == 'table row']
- # table_data_filtered = table_data
- for cell in table_data_filtered:
- draw_raw.rectangle(cell["bbox"], outline="red")
- file_name_table_grid_raw = self.append_filename(file_path, "table_raw")
- cropped_table_raw_visualized.save(file_name_table_grid_raw)
- print("Table raw data:")
- print(table_data_filtered)
-
- # table, table column header, table row, table column
- table_data_header = [cell for cell in table_data if cell['label'] == 'table column header'
- or cell['label'] == 'table' or cell['label'] == 'table column']
- print("Table header data:")
- print(table_data_header)
-
- table_data_rows = [cell for cell in table_data if cell['label'] == 'table column'
- or cell['label'] == 'table row']
- table_data_rows = self.remove_overlapping_table_header_rows(table_data_header, table_data_rows)
- print("Table row data:")
- print(table_data_rows)
-
- header_cells = self.get_header_cell_coordinates(table_data_header)
- if header_cells is not None:
- print("Header cell coordinates:")
- print(header_cells)
-
- header_data = self.do_ocr_with_coordinates(header_cells, cropped_table)
- print("Header data:")
- print(header_data)
-
- for cell_data in header_cells['row0']:
- draw_header.rectangle(cell_data["cell"], outline="red")
-
- file_name_table_grid_header = self.append_filename(file_path, "table_grid_header")
- cropped_table_header_visualized.save(file_name_table_grid_header)
-
- table_cells = self.get_table_cell_coordinates(table_data_rows)
- if table_cells is not None:
- print("Table cell coordinates:")
- print(table_cells)
-
- table_data = self.do_ocr_with_coordinates(table_cells, cropped_table)
- print("Table data:")
- print(table_data)
-
- for row_key, row_cells in table_cells.items():
- for cell_data in row_cells:
- draw.rectangle(cell_data["cell"], outline="red")
-
- file_name_table_grid = self.append_filename(file_path, "table_grid_cells")
- cropped_table_visualized.save(file_name_table_grid)
-
- def get_header_cell_coordinates(self, table_data):
- header_column = None
- columns = []
-
- # Separate header and columns
- for item in table_data:
- if item['label'] == 'table column header':
- header_column = item['bbox']
- elif item['label'] == 'table column':
- columns.append(item['bbox'])
-
- if not header_column:
- return None

- header_top = header_column[1]
- header_bottom = header_column[3]
-
- cells = []
-
- # Calculate cell coordinates based on header and column intersections
- for column in columns:
- cell_left = column[0]
- cell_right = column[2]
- cell_top = header_top
- cell_bottom = header_bottom
-
- cells.append({
- 'cell': (cell_left, cell_top, cell_right, cell_bottom)
- })
-
- # Sort cells by the left coordinate (cell_left) to order them from left to right
- cells.sort(key=lambda x: x['cell'][0])
-
- header_row = {"row0": cells}
-
- return header_row
-
- def get_table_cell_coordinates(self, table_data):
- rows = []
- columns = []
-
- # Separate rows and columns
- for item in table_data:
- if item['label'] == 'table row':
- rows.append(item['bbox'])
- elif item['label'] == 'table column':
- columns.append(item['bbox'])
-
- if not rows or not columns:
- return None
-
- # Sort rows by the top coordinate to ensure they are processed from top to bottom
- rows.sort(key=lambda x: x[1])
-
- row_cells = {}
-
- # Calculate cell coordinates based on row and column intersections
- for row_idx, row in enumerate(rows):
- row_top = row[1]
- row_bottom = row[3]
- cells = []
- for column in columns:
- cell_left = column[0]
- cell_right = column[2]
- cell_top = row_top
- cell_bottom = row_bottom
-
- cells.append({
- 'cell': (cell_left, cell_top, cell_right, cell_bottom)
- })
-
- # Sort cells within the row by the left coordinate to ensure they are ordered from left to right
- cells.sort(key=lambda x: x['cell'][0])
- row_cells[f'row{row_idx}'] = cells
-
- return row_cells
-
- def do_ocr_with_coordinates(self, cell_coordinates, cropped_table):
- data = {}
- max_num_columns = 0
-
- # Iterate over each row in cell_coordinates
- for row_key in cell_coordinates:
- row_text = []
- for cell in cell_coordinates[row_key]:
- # Crop cell out of image
- cell_image = cropped_table.crop(cell['cell'])
- cell_image_np = np.array(cell_image)
-
- # Apply OCR
- result = self.reader.readtext(cell_image_np)
- if result:
- text = " ".join([x[1] for x in result])
- row_text.append(text)
- else:
- row_text.append("") # If no text is detected, append an empty string
-
- if len(row_text) > max_num_columns:
- max_num_columns = len(row_text)
-
- data[row_key] = row_text
-
- print("Max number of columns:", max_num_columns)
-
- # Pad rows which don't have max_num_columns elements
- for row_key, row_data in data.items():
- if len(row_data) < max_num_columns:
- row_data += [""] * (max_num_columns - len(row_data))
- data[row_key] = row_data
-
- return data
-
- def append_filename(self, file_path, word):
+ @staticmethod
+ def append_filename(file_path, debug_dir, word):
  directory, filename = os.path.split(file_path)
  name, ext = os.path.splitext(filename)
  new_filename = f"{name}_{word}{ext}"
- return os.path.join(directory, new_filename)
+ return os.path.join(debug_dir, new_filename)

+ @staticmethod
  def iob(boxA, boxB):
  # Determine the coordinates of the intersection rectangle
  xA = max(boxA[0], boxB[0])
@@ -443,159 +246,9 @@ class TableDetector(object):

  return iob

- def remove_overlapping_table_header_rows(self, header_data, row_data, tolerance=1.0):
- # Function to calculate the Intersection over Union (IoU) of two bounding boxes
- def calculate_iou(bbox1, bbox2):
- x1_min, y1_min, x1_max, y1_max = bbox1
- x2_min, y2_min, x2_max, y2_max = bbox2
-
- # Determine the coordinates of the intersection rectangle
- inter_min_x = max(x1_min, x2_min)
- inter_min_y = max(y1_min, y2_min)
- inter_max_x = min(x1_max, x2_max)
- inter_max_y = min(y1_max, y2_max)
-
- # Compute the area of intersection
- inter_area = max(0, inter_max_x - inter_min_x) * max(0, inter_max_y - inter_min_y)
-
- # Compute the area of both bounding boxes
- bbox1_area = (x1_max - x1_min) * (y1_max - y1_min)
- bbox2_area = (x2_max - x2_min) * (y2_max - y2_min)
-
- # Compute the Intersection over Union (IoU)
- iou = inter_area / float(bbox1_area + bbox2_area - inter_area)
- return iou
-
- # Extract the bounding box of the table column header
- header_bbox = None
- for item in header_data:
- if item['label'] == 'table column header':
- header_bbox = item['bbox']
- break
-
- if header_bbox is None:
- print("No 'table column header' found in header data.")
- return row_data
-
- # Initialize a counter for removed rows
- removed_count = 0
-
- # Iterate over the table row data and remove rows with overlapping bbox
- updated_row_data = []
- for row in row_data:
- if row['label'] == 'table row':
- row_bbox = row['bbox']
- # Check for overlap (IoU > 0) or very similar bounding box
- iou = calculate_iou(header_bbox, row_bbox)
- if iou > 0 or np.allclose(row_bbox, header_bbox, atol=tolerance):
- removed_count += 1 # Increment the removed counter
- continue # Skip this row as it overlaps or matches the header bbox
-
- # Add row to the updated list if it doesn't overlap
- updated_row_data.append(row)
-
- # Print the number of removed rows
- print(f"Number of removed rows: {removed_count}")
-
- return updated_row_data
-
- def filter_table_columns(self, data):
- return [item for item in data if item['label'] == 'table column']
-
- def filter_table_rows(self, data):
- return [item for item in data if item['label'] == 'table row']
-
- def extract_text_boundaries(self, image, box):
- """
- Extract the start and end coordinates of the text within a bounding box,
- and translate them back to the original image coordinates.
-
- Args:
- - image: The image in which the box is located.
- - box: The bounding box (x_min, y_min, x_max, y_max).
- - reader: The EasyOCR reader object.

- Returns:
- - text_start: The x-coordinate of the start of the text in the original image.
- - text_end: The x-coordinate of the end of the text in the original image.
- """
- x_min, y_min, x_max, y_max = box
- cropped_image = image.crop((x_min, y_min, x_max, y_max))
- result = self.reader.readtext(np.array(cropped_image))
-
- if result:
- text_coordinates = result[0][0] # Extract the coordinates of the text within the cropped image
-
- # Translate the coordinates back to the original image coordinates
- text_start = min(point[0] + x_min for point in text_coordinates)
- text_end = max(point[0] + x_min for point in text_coordinates)
-
- return text_start, text_end
-
- return None, None
-
- def merge_overlapping_columns(self, image, data, proximity_threshold=20):
- """
- Merge only those bounding boxes where the text is split directly by the box line,
- while keeping other labels intact.
-
- Args:
- - image: The image in which the boxes are located.
- - data: List of dictionary items with bounding boxes and labels.
- - reader: The EasyOCR reader object.
- - proximity_threshold: The maximum distance between text boundaries to consider merging.
-
- Returns:
- - Updated list of dictionary items with merged bounding boxes and other entries preserved.
- """
- table_columns = self.filter_table_columns(data)
- other_entries = [item for item in data if item['label'] != 'table column']
- merged_boxes = []
- table_columns = sorted(table_columns, key=lambda x: x['bbox'][0]) # Sort by x_min
-
- while table_columns:
- box_data = table_columns.pop(0)
- x_min, y_min, x_max, y_max = box_data['bbox']
-
- to_merge = []
- for i, other_box_data in enumerate(table_columns):
- ox_min, oy_min, ox_max, oy_max = other_box_data['bbox']
-
- # Only consider merging if the boxes are adjacent horizontally
- if x_min < ox_max and x_max > ox_min:
- # Extract text boundaries from both boxes
- text_start_1, text_end_1 = self.extract_text_boundaries(image, box_data['bbox'])
- text_start_2, text_end_2 = self.extract_text_boundaries(image, other_box_data['bbox'])
-
- # Check if the text from one box ends very close to where the text in the next box starts
- if text_end_1 is not None and text_start_2 is not None and text_start_2 - text_end_1 <= proximity_threshold:
- x_max = max(x_max, ox_max)
- y_max = max(y_max, oy_max)
- y_min = min(y_min, oy_min)
- to_merge.append(i)
-
- # Merge the boxes
- for index in sorted(to_merge, reverse=True):
- table_columns.pop(index)
-
- merged_boxes.append({
- 'label': box_data['label'],
- 'score': box_data['score'],
- 'bbox': [x_min, y_min, x_max, y_max]
- })
-
- # Combine the merged boxes with other entries
- final_output = merged_boxes + other_entries
-
- # Sort final output by the y-coordinate to maintain the original order
- final_output = sorted(final_output, key=lambda x: x['bbox'][1])
-
- return final_output
-
- def adjust_overlapping_rows(self, image, data, proximity_threshold=10):
- return data
-
- def invoke_pipeline_step(self, task_call, task_description, local):
+ @staticmethod
+ def invoke_pipeline_step(task_call, task_description, local):
  if local:
  with Progress(
  SpinnerColumn(),
@@ -614,5 +267,9 @@ class TableDetector(object):
  if __name__ == "__main__":
  table_detector = TableDetector()

- table_detector.detect_table("/Users/andrejb/Documents/work/epik/bankstatement/OCBC_1_1.jpg", None, local=True, debug=False)
- # table_detector.detect_table("/Users/andrejb/infra/shared/katana-git/sparrow/sparrow-ml/llm/data/invoice_1.jpg", None, local=True, debug=False)
+ # file_path = "/Users/andrejb/Work/katana-git/sparrow/sparrow-ml/llm/data/bonds_table.png"
+ # cropped_tables = table_detector.detect_tables(file_path, local=True, debug_dir="/Users/andrejb/Work/katana-git/sparrow/sparrow-ml/llm/data/", debug=True)
+
+ # for i, cropped_table in enumerate(cropped_tables):
+ # file_name_table = table_detector.append_filename(file_path, "cropped_" + str(i))
+ # cropped_table.save(file_name_table)
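For orientation, a minimal sketch of the reworked `TableDetector` API (paths below are placeholders, not part of the package): `detect_tables` returns a list of cropped table images as PIL objects (or `None` when nothing is detected), and the detection model is cached on the class after the first call, so later calls and new instances reuse the same weights:

```python
from sparrow_parse.processors.table_structure_processor import TableDetector

detector = TableDetector()

# First call loads the table detection model and caches it on the class
# (TableDetector._model / TableDetector._device).
cropped_tables = detector.detect_tables("data/bonds_table.png",  # placeholder path
                                        local=True,
                                        debug_dir="data/debug",  # cropped tables are also saved here when set
                                        debug=True)

if cropped_tables:
    for i, table in enumerate(cropped_tables):
        table.save(f"table_{i + 1}.jpg", "JPEG")  # each entry is a PIL image

# A second call, even on a fresh instance, reuses the cached model
TableDetector().detect_tables("data/invoice_1.jpg", local=True)
```

The class-level cache matters because `VLLMExtractor._extract_tables` creates a new `TableDetector` for every page; without it, the table-transformer weights would be reloaded on each page of a multi-page PDF.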
{sparrow_parse-0.3.12.dist-info → sparrow_parse-0.4.1.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: sparrow-parse
- Version: 0.3.12
+ Version: 0.4.1
  Summary: Sparrow Parse is a Python package (part of Sparrow) for parsing and extracting information from documents.
  Home-page: https://github.com/katanaml/sparrow/tree/main/sparrow-data/parse
  Author: Andrej Baranovskij
@@ -21,7 +21,6 @@ Requires-Dist: transformers==4.46.3
  Requires-Dist: sentence-transformers==3.3.1
  Requires-Dist: numpy==2.1.3
  Requires-Dist: pypdf==4.3.0
- Requires-Dist: easyocr==1.7.1
  Requires-Dist: gradio-client
  Requires-Dist: pdf2image
  Requires-Dist: mlx-vlm==0.1.4; sys_platform == "darwin" and platform_machine == "arm64"
@@ -67,7 +66,8 @@ input_data = [
  ]

  # Now you can run inference without knowing which implementation is used
- results_array, num_pages = extractor.run_inference(model_inference_instance, input_data, generic_query=False,
+ results_array, num_pages = extractor.run_inference(model_inference_instance, input_data, tables_only=False,
+ generic_query=False,
  debug_dir=None,
  debug=True,
  mode=None)
@@ -77,6 +77,8 @@ for i, result in enumerate(results_array):
  print(f"Number of pages: {num_pages}")
  ```

+ Use `tables_only=True` if you want to extract only tables.
+
  Use `mode="static"` if you want to simulate LLM call, without executing LLM backend.

  Method `run_inference` will return results and number of pages processed.
{sparrow_parse-0.3.12.dist-info → sparrow_parse-0.4.1.dist-info}/RECORD RENAMED
@@ -1,19 +1,19 @@
- sparrow_parse/__init__.py,sha256=_ZihGlMI7X-0KVVWL_iXdUXrQS9gun7MiVTXSd1Xs4o,22
+ sparrow_parse/__init__.py,sha256=8yPI9dbwQUYqhMtA3RfAi5yJOhZBnz-g8966ssrYXiU,21
  sparrow_parse/__main__.py,sha256=Xs1bpJV0n08KWOoQE34FBYn6EBXZA9HIYJKrE4ZdG78,153
  sparrow_parse/extractors/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- sparrow_parse/extractors/vllm_extractor.py,sha256=E2-Zpfssu2MVkmVjHpjBh9WK3rJ7ywHhMjJX3xZB_H8,3648
+ sparrow_parse/extractors/vllm_extractor.py,sha256=QIg7AMCfw81YHQN6CutF2ipV_DZ3txSGduPIcvQRmiA,7439
  sparrow_parse/helpers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  sparrow_parse/helpers/pdf_optimizer.py,sha256=GIqQYWtixFeZGCRFXL0lQfQByapCDuQzzRHAkzcPwLE,3302
  sparrow_parse/processors/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- sparrow_parse/processors/table_structure_processor.py,sha256=bG_6jx66n_KNdY_O6hrZD1D4DHX5Qy__RYcKHmrSGnc,23894
+ sparrow_parse/processors/table_structure_processor.py,sha256=PQHHFdQUuTin3Mm2USuUga2n4fGWMLwiBJYq4CVD67o,9775
  sparrow_parse/vllm/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  sparrow_parse/vllm/huggingface_inference.py,sha256=EJnG6PesGKMc_0qGPN8ufE6pSnhAgFu0XjCbaLCNVyM,1980
  sparrow_parse/vllm/inference_base.py,sha256=4mwGoAY63MB4cHZpV0czTkJWEzimmiTzqqzKmLNzgjw,820
  sparrow_parse/vllm/inference_factory.py,sha256=FTM65O-dW2WZchHOrNN7_Q3-FlVoAc65iSptuuUuClM,1166
  sparrow_parse/vllm/local_gpu_inference.py,sha256=aHoJTejb5xrXjWDIGu5RBQWEyRCOBCB04sMvO2Wyvg8,628
  sparrow_parse/vllm/mlx_inference.py,sha256=xR40qwjIR0HvrN8x58oOq6F4r1hEANRB-9kcokUQHHU,4748
- sparrow_parse-0.3.12.dist-info/METADATA,sha256=tgTHQSWokNEc0TOry7EpunttLgYXYd6GEx6bzedb25s,6351
- sparrow_parse-0.3.12.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
- sparrow_parse-0.3.12.dist-info/entry_points.txt,sha256=8CrvTVTTcz1YuZ8aRCYNOH15ZOAaYLlcbYX3t28HwJY,54
- sparrow_parse-0.3.12.dist-info/top_level.txt,sha256=n6b-WtT91zKLyCPZTP7wvne8v_yvIahcsz-4sX8I0rY,14
- sparrow_parse-0.3.12.dist-info/RECORD,,
+ sparrow_parse-0.4.1.dist-info/METADATA,sha256=4rmJ1CURKtyTs-ZH1eyHn_VptHosJZwhQFB5Fssr5e0,6432
+ sparrow_parse-0.4.1.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ sparrow_parse-0.4.1.dist-info/entry_points.txt,sha256=8CrvTVTTcz1YuZ8aRCYNOH15ZOAaYLlcbYX3t28HwJY,54
+ sparrow_parse-0.4.1.dist-info/top_level.txt,sha256=n6b-WtT91zKLyCPZTP7wvne8v_yvIahcsz-4sX8I0rY,14
+ sparrow_parse-0.4.1.dist-info/RECORD,,