sparrow-parse 0.3.2__py3-none-any.whl → 0.3.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sparrow_parse/__init__.py CHANGED
@@ -1 +1 @@
1
- __version__ = '0.3.2'
1
+ __version__ = '0.3.3'
@@ -3,9 +3,8 @@ from sentence_transformers import SentenceTransformer, util
3
3
  from bs4 import BeautifulSoup
4
4
  import json
5
5
  from rich.progress import Progress, SpinnerColumn, TextColumn
6
- from .extractor_helper import merge_html_table_headers
7
- from .extractor_helper import clean_html_table_header_names
8
- import re
6
+ from sparrow_parse.helpers.html_extractor_helper import merge_html_table_headers
7
+ from sparrow_parse.helpers.html_extractor_helper import clean_html_table_header_names
9
8
 
10
9
 
11
10
  class HTMLExtractor(object):
@@ -221,8 +220,8 @@ class HTMLExtractor(object):
221
220
 
222
221
 
223
222
  if __name__ == "__main__":
224
- # to run for debugging, navigate to sparrow_parse and run the following command:
225
- # python -m extractor.html_extractor
223
+ # to run for debugging, navigate above sparrow_parse and run the following command:
224
+ # python -m sparrow_parse.extractors.html_extractor
226
225
 
227
226
  # with open('data/invoice_1_table.txt', 'r') as file:
228
227
  # file_content = file.read()
@@ -233,7 +232,7 @@ if __name__ == "__main__":
233
232
 
234
233
  extractor = HTMLExtractor()
235
234
 
236
- # answer, targets_unprocessed = extractor.read_data(
235
+ # answer, targets_unprocessed = extractor.read_data(
237
236
  # # ['description', 'qty', 'net_price', 'net_worth', 'vat', 'gross_worth'],
238
237
  # ['transaction_date', 'value_date', 'description', 'cheque', 'withdrawal', 'deposit', 'balance',
239
238
  # 'deposits', 'account_number', 'od_limit', 'currency_balance', 'sgd_balance', 'maturity_date'],
@@ -0,0 +1,46 @@
1
+ from sparrow_parse.vllm.inference_factory import InferenceFactory
2
+ from rich import print
3
+ import os
4
+
5
+
6
class VLLMExtractor(object):
    """Front-end that forwards document queries to a vision-LLM inference backend."""

    def __init__(self):
        pass

    def run_inference(self, model_inference_instance, input_data, generic_query=False, debug=False):
        """Send ``input_data`` to ``model_inference_instance`` and return its result.

        Args:
            model_inference_instance: Backend object exposing ``inference(input_data)``.
            input_data: List of request dicts; the first entry's ``"text_input"`` is the query.
            generic_query: When True, the first entry's query is replaced by a generic
                "retrieve document data" prompt. The caller's list and dicts are not modified.
            debug: When True, print the data that is sent to the backend.

        Returns:
            Whatever ``model_inference_instance.inference`` returns.
        """
        if generic_query:
            # Build a modified copy rather than overwriting the caller's request in place.
            first_request = dict(input_data[0])
            first_request["text_input"] = "retrieve document data. return response in JSON format"
            input_data = [first_request, *input_data[1:]]

        if debug:
            print("Input Data:", input_data)

        result = model_inference_instance.inference(input_data)

        return result
20
+
21
if __name__ == "__main__":
    import sys

    extractor = VLLMExtractor()

    # export HF_TOKEN="hf_"
    config = {
        "method": "huggingface",  # Could be 'huggingface' or 'local_gpu'
        "hf_space": "katanaml/sparrow-qwen2-vl-7b",
        "hf_token": os.getenv('HF_TOKEN'),
        # Additional fields for local GPU inference
        # "device": "cuda", "model_path": "model.pth"
    }

    # Use the factory to get the correct instance
    factory = InferenceFactory(config)
    model_inference_instance = factory.get_inference_instance()

    # The image to query may be given as the first command-line argument;
    # without it the author's sample file is used.
    image_path = sys.argv[1] if len(sys.argv) > 1 else "/Users/andrejb/Documents/work/epik/bankstatement/bonds_table.png"

    input_data = [
        {
            "image": image_path,
            "text_input": "retrieve financial instruments data. return response in JSON format"
        }
    ]

    # Now you can run inference without knowing which implementation is used
    result = extractor.run_inference(model_inference_instance, input_data, generic_query=False, debug=True)
    print("Inference Result:", result)
File without changes
File without changes
@@ -0,0 +1,618 @@
1
+ from rich.progress import Progress, SpinnerColumn, TextColumn
2
+ from rich import print
3
+ from transformers import AutoModelForObjectDetection
4
+ from transformers import TableTransformerForObjectDetection
5
+ import torch
6
+ from PIL import Image
7
+ from torchvision import transforms
8
+ from PIL import ImageDraw
9
+ import os
10
+ import numpy as np
11
+ import easyocr
12
+
13
+
14
class TableDetector(object):
    """Detect a table in a document image, recover its cell grid and OCR every cell.

    ``detect_table`` runs the pipeline:
    1. The Table Transformer detection model locates tables.
    2. The first table is cropped.
    3. The Table Transformer structure-recognition model finds rows, columns and the column header.
    4. EasyOCR reads the text inside each grid cell.

    Visualisation images are written into the directory of the input image, with
    suffixed file names (see ``append_filename``).
    """

    def __init__(self):
        # The EasyOCR model is loaded once here and reused by every OCR call.
        self.reader = easyocr.Reader(['en'])

    class MaxResize(object):
        """Callable transform that scales a PIL image so its longest side equals ``max_size``."""

        def __init__(self, max_size=800):
            self.max_size = max_size

        def __call__(self, image):
            width, height = image.size
            scale = self.max_size / max(width, height)
            return image.resize((int(round(scale * width)), int(round(scale * height))))

    def detect_table(self, file_path, options, local=True, debug=False):
        """Run the whole table-extraction pipeline on the image at ``file_path``.

        Args:
            file_path: Path of the input image; output images are written next to it.
            options: Currently unused; kept for interface compatibility.
            local: When True, each step shows a rich spinner; otherwise its name is printed.
            debug: Currently unused; kept for interface compatibility.

        Returns:
            The dict returned by ``process_table_structure``, or None when no table was
            detected in the image.
        """
        model, device = self.invoke_pipeline_step(
            lambda: self.load_table_detection_model(),
            "Loading table detection model...",
            local
        )

        outputs, image = self.invoke_pipeline_step(
            lambda: self.prepare_image(file_path, model, device),
            "Preparing image for table detection...",
            local
        )

        objects = self.invoke_pipeline_step(
            lambda: self.identify_tables(model, outputs, image),
            "Identifying tables in the image...",
            local
        )

        cropped_table = self.invoke_pipeline_step(
            lambda: self.crop_table(file_path, image, objects),
            "Cropping tables from the image...",
            local
        )

        if cropped_table is None:
            # crop_table already reported that no table was found; the structure model
            # cannot run on a missing image, so stop instead of crashing further down.
            return None

        structure_model = self.invoke_pipeline_step(
            lambda: self.load_table_structure_model(device),
            "Loading table structure recognition model...",
            local
        )

        structure_outputs = self.invoke_pipeline_step(
            lambda: self.get_table_structure(cropped_table, structure_model, device),
            "Getting table structure from cropped table...",
            local
        )

        table_data = self.invoke_pipeline_step(
            lambda: self.get_structure_cells(structure_model, cropped_table, structure_outputs),
            "Getting structure cells from cropped table...",
            local
        )

        return self.invoke_pipeline_step(
            lambda: self.process_table_structure(table_data, cropped_table, file_path),
            "Processing structure cells...",
            local
        )

    def load_table_detection_model(self):
        """Load the table detection model onto CUDA when available, else CPU; return (model, device)."""
        model = AutoModelForObjectDetection.from_pretrained("microsoft/table-transformer-detection", revision="no_timm")

        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device)

        return model, device

    def load_table_structure_model(self, device):
        """Load the table structure-recognition model onto ``device`` and return it."""
        structure_model = TableTransformerForObjectDetection.from_pretrained("microsoft/table-structure-recognition-v1.1-all")
        structure_model.to(device)

        return structure_model

    def prepare_image(self, file_path, model, device):
        """Open the image as RGB, run the detection model on it and return (outputs, image)."""
        with Image.open(file_path) as source_image:
            image = source_image.convert("RGB")

        detection_transform = transforms.Compose([
            self.MaxResize(800),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

        pixel_values = detection_transform(image).unsqueeze(0)
        pixel_values = pixel_values.to(device)

        with torch.no_grad():
            outputs = model(pixel_values)

        return outputs, image

    def identify_tables(self, model, outputs, image):
        """Convert detection-model outputs into labelled boxes in ``image`` pixel coordinates."""
        # Work on a copy so the model's config is not mutated on every call.
        id2label = dict(model.config.id2label)
        id2label[len(id2label)] = "no object"

        return self.outputs_to_objects(outputs, image.size, id2label)

    def crop_table(self, file_path, image, objects):
        """Crop the first confidently detected table, save it, and return it as an RGB image.

        The crop is saved with a ``_table`` suffix. When several tables were detected,
        only the first is used, and it is additionally saved with a ``_table_0`` suffix.

        Returns:
            The cropped PIL image, or None when no table passed the score thresholds.
        """
        tokens = []
        detection_class_thresholds = {
            "table": 0.5,
            "table rotated": 0.5,
            "no object": 10
        }
        crop_padding = 10

        tables_crops = self.objects_to_crops(image, tokens, objects, detection_class_thresholds, padding=crop_padding)

        if not tables_crops:
            print("No tables detected.")
            return None

        cropped_table = tables_crops[0]['image'].convert("RGB")
        if len(tables_crops) > 1:
            cropped_table.save(self.append_filename(file_path, "table_0"))

        cropped_table.save(self.append_filename(file_path, "table"))

        return cropped_table

    # for output bounding box post-processing
    def box_cxcywh_to_xyxy(self, x):
        """Convert boxes from (center_x, center_y, width, height) to (x_min, y_min, x_max, y_max).

        The last dimension of ``x`` must have size 4; any leading dimensions are kept.
        """
        x_c, y_c, w, h = x.unbind(-1)
        b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
        return torch.stack(b, dim=-1)

    def rescale_bboxes(self, out_bbox, size):
        """Scale normalised cxcywh boxes to xyxy pixel boxes for an image of ``size`` = (w, h)."""
        img_w, img_h = size
        b = self.box_cxcywh_to_xyxy(out_bbox)
        b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
        return b

    def outputs_to_objects(self, outputs, img_size, id2label):
        """Turn model outputs for the first batch item into a list of label/score/bbox dicts.

        Predictions whose arg-max class is ``'no object'`` are dropped.
        """
        m = outputs.logits.softmax(-1).max(-1)
        pred_labels = list(m.indices.detach().cpu().numpy())[0]
        pred_scores = list(m.values.detach().cpu().numpy())[0]
        pred_bboxes = outputs['pred_boxes'].detach().cpu()[0]
        pred_bboxes = [elem.tolist() for elem in self.rescale_bboxes(pred_bboxes, img_size)]

        objects = []
        for label, score, bbox in zip(pred_labels, pred_scores, pred_bboxes):
            class_label = id2label[int(label)]
            if class_label != 'no object':
                objects.append({'label': class_label, 'score': float(score),
                                'bbox': [float(elem) for elem in bbox]})

        return objects

    def objects_to_crops(self, img, tokens, objects, class_thresholds, padding=10):
        """Crop every sufficiently confident detected table out of ``img``.

        Args:
            img: PIL image that the detections refer to.
            tokens: Word tokens (dicts with a ``'bbox'``) to carry into each crop. They are
                re-based in place. The pipeline currently passes an empty list.
            objects: Detections as produced by ``outputs_to_objects``.
            class_thresholds: Minimum score per label; lower-scoring detections are skipped.
            padding: Pixels added on every side of a table box before cropping.

        Returns:
            A list of dicts with the cropped ``'image'`` and its ``'tokens'``.
        """
        table_crops = []
        for obj in objects:
            if obj['score'] < class_thresholds[obj['label']]:
                continue

            bbox = obj['bbox']
            bbox = [bbox[0] - padding, bbox[1] - padding, bbox[2] + padding, bbox[3] + padding]

            cropped_img = img.crop(bbox)

            # Keep tokens that lie mostly inside the table and shift them into crop coordinates.
            table_tokens = [token for token in tokens if self.iob(token['bbox'], bbox) >= 0.5]
            for token in table_tokens:
                token['bbox'] = [token['bbox'][0] - bbox[0],
                                 token['bbox'][1] - bbox[1],
                                 token['bbox'][2] - bbox[0],
                                 token['bbox'][3] - bbox[1]]

            # If table is predicted to be rotated, rotate cropped image and tokens/words:
            if obj['label'] == 'table rotated':
                cropped_img = cropped_img.rotate(270, expand=True)
                for token in table_tokens:
                    token_bbox = token['bbox']
                    token['bbox'] = [cropped_img.size[0] - token_bbox[3] - 1,
                                     token_bbox[0],
                                     cropped_img.size[0] - token_bbox[1] - 1,
                                     token_bbox[2]]

            table_crops.append({'image': cropped_img, 'tokens': table_tokens})

        return table_crops

    def get_table_structure(self, cropped_table, structure_model, device):
        """Run the structure-recognition model on the cropped table image and return its outputs."""
        structure_transform = transforms.Compose([
            self.MaxResize(1000),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

        pixel_values = structure_transform(cropped_table).unsqueeze(0)
        pixel_values = pixel_values.to(device)

        with torch.no_grad():
            outputs = structure_model(pixel_values)

        return outputs

    def get_structure_cells(self, structure_model, cropped_table, outputs):
        """Convert structure-model outputs into labelled boxes in cropped-table pixel coordinates."""
        # Work on a copy so the model's config is not mutated on every call.
        structure_id2label = dict(structure_model.config.id2label)
        structure_id2label[len(structure_id2label)] = "no object"

        return self.outputs_to_objects(outputs, cropped_table.size, structure_id2label)

    def process_table_structure(self, table_data, cropped_table, file_path):
        """Build header and body cell grids, OCR them, and save grid visualisations.

        Args:
            table_data: Structure cells from ``get_structure_cells``.
            cropped_table: PIL image of the cropped table.
            file_path: Original image path, used to derive output file names.

        Returns:
            A dict with two keys:
            - ``"header"``: OCR result of the header row, or None if no column header was found.
            - ``"table"``: OCR result of the body rows, or None if no row/column grid was found.
            Each OCR result maps ``'rowN'`` to a list of cell strings.
        """
        cropped_table_raw_visualized = cropped_table.copy()
        draw_raw = ImageDraw.Draw(cropped_table_raw_visualized)
        cropped_table_header_visualized = cropped_table.copy()
        draw_header = ImageDraw.Draw(cropped_table_header_visualized)
        cropped_table_visualized = cropped_table.copy()
        draw = ImageDraw.Draw(cropped_table_visualized)

        # Drop spanning cells and the whole-table box; keep only confident predictions.
        table_data = [cell for cell in table_data
                      if cell['label'] not in ('table spanning cell', 'table') and cell['score'] >= 0.8]

        table_data = self.merge_overlapping_columns(cropped_table, table_data)
        table_data = self.adjust_overlapping_rows(cropped_table, table_data)

        # Visualise the detected rows as-is.
        table_data_filtered = self.filter_table_rows(table_data)
        for cell in table_data_filtered:
            draw_raw.rectangle(cell["bbox"], outline="red")
        cropped_table_raw_visualized.save(self.append_filename(file_path, "table_raw"))
        print("Table raw data:")
        print(table_data_filtered)

        # The header grid is formed by the column header box crossed with the columns.
        table_data_header = [cell for cell in table_data
                             if cell['label'] in ('table column header', 'table column')]
        print("Table header data:")
        print(table_data_header)

        # The body grid is rows crossed with columns, minus rows overlapping the header.
        table_data_rows = [cell for cell in table_data
                           if cell['label'] in ('table column', 'table row')]
        table_data_rows = self.remove_overlapping_table_header_rows(table_data_header, table_data_rows)
        print("Table row data:")
        print(table_data_rows)

        header_data = None
        header_cells = self.get_header_cell_coordinates(table_data_header)
        if header_cells is not None:
            print("Header cell coordinates:")
            print(header_cells)

            header_data = self.do_ocr_with_coordinates(header_cells, cropped_table)
            print("Header data:")
            print(header_data)

            for cell_data in header_cells['row0']:
                draw_header.rectangle(cell_data["cell"], outline="red")

            cropped_table_header_visualized.save(self.append_filename(file_path, "table_grid_header"))

        body_data = None
        table_cells = self.get_table_cell_coordinates(table_data_rows)
        if table_cells is not None:
            print("Table cell coordinates:")
            print(table_cells)

            body_data = self.do_ocr_with_coordinates(table_cells, cropped_table)
            print("Table data:")
            print(body_data)

            for row_cells in table_cells.values():
                for cell_data in row_cells:
                    draw.rectangle(cell_data["cell"], outline="red")

            cropped_table_visualized.save(self.append_filename(file_path, "table_grid_cells"))

        return {"header": header_data, "table": body_data}

    def get_header_cell_coordinates(self, table_data):
        """Intersect the column header box with every column to get header cells.

        Returns:
            ``{"row0": [{'cell': (left, top, right, bottom)}, ...]}`` sorted left to right,
            or None when no ``'table column header'`` entry exists. If several header
            entries exist, the last one is used.
        """
        header_column = None
        columns = []

        for item in table_data:
            if item['label'] == 'table column header':
                header_column = item['bbox']
            elif item['label'] == 'table column':
                columns.append(item['bbox'])

        if not header_column:
            return None

        header_top, header_bottom = header_column[1], header_column[3]

        cells = [{'cell': (column[0], header_top, column[2], header_bottom)} for column in columns]
        cells.sort(key=lambda x: x['cell'][0])

        return {"row0": cells}

    def get_table_cell_coordinates(self, table_data):
        """Intersect every row with every column to get the body cell grid.

        Returns:
            ``{'row0': [...], 'row1': [...], ...}`` with rows ordered top to bottom and cells
            left to right, or None when there are no rows or no columns.
        """
        rows = [item['bbox'] for item in table_data if item['label'] == 'table row']
        columns = [item['bbox'] for item in table_data if item['label'] == 'table column']

        if not rows or not columns:
            return None

        rows.sort(key=lambda x: x[1])

        row_cells = {}
        for row_idx, row in enumerate(rows):
            row_top, row_bottom = row[1], row[3]
            cells = [{'cell': (column[0], row_top, column[2], row_bottom)} for column in columns]
            cells.sort(key=lambda x: x['cell'][0])
            row_cells[f'row{row_idx}'] = cells

        return row_cells

    def do_ocr_with_coordinates(self, cell_coordinates, cropped_table):
        """OCR every cell and return ``{row_key: [cell_text, ...]}``.

        A cell without detected text yields ``""``. Shorter rows are padded with ``""``
        up to the longest row.
        """
        data = {}

        for row_key, row_cells in cell_coordinates.items():
            row_text = []
            for cell in row_cells:
                cell_image_np = np.array(cropped_table.crop(cell['cell']))
                result = self.reader.readtext(cell_image_np)
                # Each readtext entry holds the fragment's text at index 1; join all fragments.
                row_text.append(" ".join(x[1] for x in result) if result else "")
            data[row_key] = row_text

        max_num_columns = max((len(row_text) for row_text in data.values()), default=0)
        print("Max number of columns:", max_num_columns)

        for row_data in data.values():
            if len(row_data) < max_num_columns:
                row_data += [""] * (max_num_columns - len(row_data))

        return data

    def append_filename(self, file_path, word):
        """Return ``file_path`` with ``_<word>`` inserted before the extension."""
        directory, filename = os.path.split(file_path)
        name, ext = os.path.splitext(filename)
        return os.path.join(directory, f"{name}_{word}{ext}")

    @staticmethod
    def iob(boxA, boxB):
        """Return the share of ``boxA``'s area covered by ``boxB`` (intersection over box A).

        Boxes are (x_min, y_min, x_max, y_max) with inclusive pixel coordinates, hence the +1.
        Declared static so both ``self.iob(a, b)`` and ``TableDetector.iob(a, b)`` work.
        """
        xA = max(boxA[0], boxB[0])
        yA = max(boxA[1], boxB[1])
        xB = min(boxA[2], boxB[2])
        yB = min(boxA[3], boxB[3])

        interArea = max(0, xB - xA + 1) * max(0, yB - yA + 1)
        boxAArea = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1)

        return interArea / float(boxAArea)

    def remove_overlapping_table_header_rows(self, header_data, row_data, tolerance=1.0):
        """Drop ``'table row'`` entries that overlap or nearly coincide with the column header.

        Args:
            header_data: Cells possibly containing a ``'table column header'`` entry.
            row_data: Row and column cells; only ``'table row'`` entries can be removed.
            tolerance: Absolute per-coordinate tolerance for treating boxes as identical.

        Returns:
            ``row_data`` unchanged if there is no header, otherwise a filtered copy.
        """
        def calculate_iou(bbox1, bbox2):
            # Intersection over Union of two (x_min, y_min, x_max, y_max) boxes.
            x1_min, y1_min, x1_max, y1_max = bbox1
            x2_min, y2_min, x2_max, y2_max = bbox2

            inter_w = max(0, min(x1_max, x2_max) - max(x1_min, x2_min))
            inter_h = max(0, min(y1_max, y2_max) - max(y1_min, y2_min))
            inter_area = inter_w * inter_h

            bbox1_area = (x1_max - x1_min) * (y1_max - y1_min)
            bbox2_area = (x2_max - x2_min) * (y2_max - y2_min)
            union_area = bbox1_area + bbox2_area - inter_area
            if union_area <= 0:
                # Degenerate (zero-area) boxes: report no overlap instead of dividing by zero.
                return 0.0
            return inter_area / float(union_area)

        header_bbox = next((item['bbox'] for item in header_data
                            if item['label'] == 'table column header'), None)

        if header_bbox is None:
            print("No 'table column header' found in header data.")
            return row_data

        removed_count = 0
        updated_row_data = []
        for row in row_data:
            if row['label'] == 'table row':
                row_bbox = row['bbox']
                iou = calculate_iou(header_bbox, row_bbox)
                if iou > 0 or np.allclose(row_bbox, header_bbox, atol=tolerance):
                    removed_count += 1
                    continue

            updated_row_data.append(row)

        print(f"Number of removed rows: {removed_count}")

        return updated_row_data

    def filter_table_columns(self, data):
        """Return only the ``'table column'`` entries of ``data``."""
        return [item for item in data if item['label'] == 'table column']

    def filter_table_rows(self, data):
        """Return only the ``'table row'`` entries of ``data``."""
        return [item for item in data if item['label'] == 'table row']

    def extract_text_boundaries(self, image, box):
        """Return the horizontal extent of the first OCR'd text fragment inside ``box``.

        Args:
            image: PIL image in which the box is located.
            box: Bounding box (x_min, y_min, x_max, y_max) in ``image`` coordinates.

        Returns:
            ``(text_start, text_end)`` x-coordinates in ``image`` coordinates, or
            ``(None, None)`` when OCR finds no text. Only the first fragment is used.
        """
        x_min, y_min, x_max, y_max = box
        cropped_image = image.crop((x_min, y_min, x_max, y_max))
        result = self.reader.readtext(np.array(cropped_image))

        if result:
            text_coordinates = result[0][0]

            # Translate the coordinates back to the original image coordinates
            text_start = min(point[0] + x_min for point in text_coordinates)
            text_end = max(point[0] + x_min for point in text_coordinates)

            return text_start, text_end

        return None, None

    def merge_overlapping_columns(self, image, data, proximity_threshold=20):
        """Merge horizontally overlapping columns whose text appears split between them.

        Columns are processed left to right. A later column is absorbed into the current
        one when the boxes overlap horizontally and the OCR'd text of the later column
        starts at most ``proximity_threshold`` pixels after the current column's text ends.

        Args:
            image: PIL image in which the boxes are located.
            data: List of dicts with ``'label'``, ``'score'`` and ``'bbox'``.
            proximity_threshold: Maximum gap between text boundaries to still merge.

        Returns:
            Merged columns plus all non-column entries, sorted by top y-coordinate.
        """
        table_columns = sorted(self.filter_table_columns(data), key=lambda x: x['bbox'][0])
        other_entries = [item for item in data if item['label'] != 'table column']
        merged_boxes = []

        while table_columns:
            box_data = table_columns.pop(0)
            x_min, y_min, x_max, y_max = box_data['bbox']

            # OCR of the current column's original box is loop-invariant; run it at most once.
            current_text_bounds = None
            to_merge = []
            for i, other_box_data in enumerate(table_columns):
                ox_min, oy_min, ox_max, oy_max = other_box_data['bbox']

                # Only consider boxes overlapping the (possibly already widened) box horizontally.
                if x_min < ox_max and x_max > ox_min:
                    if current_text_bounds is None:
                        current_text_bounds = self.extract_text_boundaries(image, box_data['bbox'])
                    text_end_1 = current_text_bounds[1]
                    text_start_2, _ = self.extract_text_boundaries(image, other_box_data['bbox'])

                    if text_end_1 is not None and text_start_2 is not None and text_start_2 - text_end_1 <= proximity_threshold:
                        x_max = max(x_max, ox_max)
                        y_max = max(y_max, oy_max)
                        y_min = min(y_min, oy_min)
                        to_merge.append(i)

            for index in sorted(to_merge, reverse=True):
                table_columns.pop(index)

            merged_boxes.append({
                'label': box_data['label'],
                'score': box_data['score'],
                'bbox': [x_min, y_min, x_max, y_max]
            })

        # sorted() is stable: entries sharing a top y keep merged-columns-first order.
        return sorted(merged_boxes + other_entries, key=lambda x: x['bbox'][1])

    def adjust_overlapping_rows(self, image, data, proximity_threshold=10):
        """Placeholder for row adjustment; currently returns ``data`` unchanged."""
        return data

    def invoke_pipeline_step(self, task_call, task_description, local):
        """Run ``task_call()`` and return its result, reporting ``task_description``.

        With ``local`` True the step is shown with a rich spinner; otherwise the
        description is printed before the call.
        """
        if local:
            with Progress(
                SpinnerColumn(),
                TextColumn("[progress.description]{task.description}"),
                transient=False,
            ) as progress:
                progress.add_task(description=task_description, total=None)
                ret = task_call()
        else:
            print(task_description)
            ret = task_call()

        return ret
612
+
613
+
614
if __name__ == "__main__":
    import sys

    table_detector = TableDetector()

    # The image to process may be given as the first command-line argument;
    # without it the author's sample scan is used.
    image_path = sys.argv[1] if len(sys.argv) > 1 else "/Users/andrejb/Documents/work/epik/bankstatement/OCBC_1_1.jpg"

    table_detector.detect_table(image_path, None, local=True, debug=False)
File without changes
@@ -0,0 +1,36 @@
1
+ from gradio_client import Client, handle_file
2
+ from sparrow_parse.vllm.inference_base import ModelInference
3
+ import json
4
+
5
+
6
class HuggingFaceInference(ModelInference):
    """Inference backend that queries a Hugging Face Space through ``gradio_client``."""

    def __init__(self, hf_space, hf_token):
        # hf_space: Space identifier passed to gradio_client.Client (e.g. "katanaml/sparrow-qwen2-vl-7b").
        # hf_token: credential forwarded to the Client as-is.
        self.hf_space = hf_space
        self.hf_token = hf_token

    def process_response(self, output_text):
        """Turn the Space's raw text answer into pretty-printed JSON when possible.

        Surrounding ``[``, ``]`` and ``'`` characters and a ```` ```json ```` fence are
        stripped first. The result is parsed with apostrophes intact, so string values
        such as "O'Brien" survive. Only if that fails is every single quote removed and
        the parse retried (the previous behaviour).

        Returns:
            The JSON re-serialised with ``indent=2``, or ``output_text`` unchanged when it
            cannot be parsed.
        """
        json_string = output_text.strip("[]'")
        json_string = json_string.replace("```json\n", "").replace("\n```", "")

        candidates = [json_string]
        if "'" in json_string:
            candidates.append(json_string.replace("'", ""))

        last_error = None
        for candidate in candidates:
            try:
                formatted_json = json.loads(candidate)
            except json.JSONDecodeError as e:
                last_error = e
                continue
            return json.dumps(formatted_json, indent=2)

        print("Failed to parse JSON:", last_error)
        return output_text

    def inference(self, input_data):
        """Send the first request's image and prompt to the Space and return the processed answer.

        Args:
            input_data: List of request dicts. Only ``input_data[0]`` is used; it must
                contain ``"image"`` (passed to gradio's ``handle_file``) and ``"text_input"``.
        """
        client = Client(self.hf_space, hf_token=self.hf_token)

        result = client.predict(
            image=handle_file(input_data[0]["image"]),
            text_input=input_data[0]["text_input"],
            api_name="/run_inference"
        )

        return self.process_response(result)
@@ -0,0 +1,7 @@
1
+ from abc import ABC, abstractmethod
2
+
3
class ModelInference(ABC):
    """Abstract interface shared by all inference backends.

    Concrete backends must override ``inference``; because it is an abstract method,
    a subclass that does not override it cannot be instantiated.
    """

    @abstractmethod
    def inference(self, input_data):
        """Run the model on ``input_data`` and return its result (implemented by subclasses)."""
        return None
@@ -0,0 +1,22 @@
1
+ from sparrow_parse.vllm.huggingface_inference import HuggingFaceInference
2
+ from sparrow_parse.vllm.local_gpu_inference import LocalGPUInference
3
+
4
+
5
class InferenceFactory:
    """Create the inference backend selected by a configuration dict."""

    def __init__(self, config):
        # config: dict with a "method" key plus backend-specific settings.
        self.config = config

    def get_inference_instance(self):
        """Return the backend for ``config["method"]``.

        - ``"huggingface"``: requires ``"hf_space"``. ``"hf_token"`` is optional; None is
          passed when the key is absent.
        - ``"local_gpu"``: uses ``"device"`` (default ``"cuda"``). Model loading is not
          implemented yet, so this currently raises NotImplementedError.

        Raises:
            KeyError: if ``"method"`` (or ``"hf_space"`` for huggingface) is missing.
            ValueError: for an unrecognised method.
            NotImplementedError: for ``"local_gpu"``, from ``_load_local_model``.
        """
        method = self.config["method"]
        if method == "huggingface":
            return HuggingFaceInference(hf_space=self.config["hf_space"], hf_token=self.config.get("hf_token"))
        if method == "local_gpu":
            model = self._load_local_model()  # Replace with actual model loading logic
            return LocalGPUInference(model=model, device=self.config.get("device", "cuda"))
        raise ValueError(f"Unknown method: {method}")

    def _load_local_model(self):
        """Load the model for local GPU inference (not implemented yet)."""
        # Example: Load a PyTorch model (replace with actual loading code)
        # model = torch.load('model.pth')
        # return model
        raise NotImplementedError("Model loading logic not implemented")
@@ -0,0 +1,16 @@
1
+ import torch
2
+ from sparrow_parse.vllm.inference_base import ModelInference
3
+
4
+
5
class LocalGPUInference(ModelInference):
    """Run an already loaded PyTorch model on a chosen device."""

    def __init__(self, model, device='cuda'):
        # The model is moved to the target device once, at construction time.
        self.model = model
        self.device = device
        self.model.to(self.device)

    def inference(self, input_data):
        """Run ``self.model`` on ``input_data`` in eval/no-grad mode and return a NumPy array.

        ``torch.as_tensor`` accepts lists, NumPy arrays and existing tensors. Unlike
        ``torch.tensor``, it does not copy-construct (and warn) when given a tensor.
        """
        self.model.eval()  # Set the model to evaluation mode
        with torch.no_grad():  # No need to calculate gradients
            input_tensor = torch.as_tensor(input_data).to(self.device)
            output = self.model(input_tensor)
        return output.cpu().numpy()  # Convert the output back to NumPy
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: sparrow-parse
3
- Version: 0.3.2
3
+ Version: 0.3.3
4
4
  Summary: Sparrow Parse is a Python package for parsing and extracting information from documents.
5
5
  Home-page: https://github.com/katanaml/sparrow/tree/main/sparrow-data/parse
6
6
  Author: Andrej Baranovskij
@@ -25,6 +25,8 @@ Requires-Dist: transformers ==4.41.2
25
25
  Requires-Dist: sentence-transformers ==3.0.1
26
26
  Requires-Dist: numpy ==1.26.4
27
27
  Requires-Dist: pypdf ==4.3.0
28
+ Requires-Dist: easyocr ==1.7.1
29
+ Requires-Dist: gradio-client
28
30
 
29
31
  # Sparrow Parse
30
32
 
@@ -0,0 +1,23 @@
1
+ sparrow_parse/__init__.py,sha256=JDRpXqOC0txw4_CqkfpSl89CczeXGgyjX4XSSLChyQg,21
2
+ sparrow_parse/__main__.py,sha256=Xs1bpJV0n08KWOoQE34FBYn6EBXZA9HIYJKrE4ZdG78,153
3
+ sparrow_parse/temp.py,sha256=gy4_mtNW_KfXn9br_suu6jHx7JKYLKs9pIOBynh_JWY,1134
4
+ sparrow_parse/extractors/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
5
+ sparrow_parse/extractors/html_extractor.py,sha256=qe9Oz7J-GiIE8G1kIDMOeh96xe6P59Gyh5SjgV3v2c8,9977
6
+ sparrow_parse/extractors/vllm_extractor.py,sha256=Qwmf-SW4z_UstiiynX5TkyovlkokVhLuzcbUVZ16TXM,1540
7
+ sparrow_parse/helpers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
8
+ sparrow_parse/helpers/html_extractor_helper.py,sha256=n9M9NyZfesiCCj3ET9WoyqRcWIFJ4k-jyQlUAarKIhE,13658
9
+ sparrow_parse/helpers/pdf_optimizer.py,sha256=KI_EweGt9Y_rDH1uCpYD5wKCW3rdjSFFhoVtiPBxX8k,3013
10
+ sparrow_parse/processors/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
+ sparrow_parse/processors/markdown_processor.py,sha256=dC2WUdA-v2psh7oytruftxYkXdQi72FoEYxF30ROuO0,4506
12
+ sparrow_parse/processors/table_structure_processor.py,sha256=bG_6jx66n_KNdY_O6hrZD1D4DHX5Qy__RYcKHmrSGnc,23894
13
+ sparrow_parse/processors/unstructured_processor.py,sha256=oonkB5ALaV1pVs0a-xr8yAf-kirIabmtugHMnnEILqo,6770
14
+ sparrow_parse/vllm/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
15
+ sparrow_parse/vllm/huggingface_inference.py,sha256=Q2Ju65LDzbO-8RWW7cXzrR-pbZ1zKuPVODlKOTWKg_E,1114
16
+ sparrow_parse/vllm/inference_base.py,sha256=W0N2khehGdF1XHzZACG3I1UZaydHMk6BZgWNvaJD4Ck,197
17
+ sparrow_parse/vllm/inference_factory.py,sha256=r04e95uPWG5l8Q23yeDqKmvFxLyF991aA2m0hfBTNn8,993
18
+ sparrow_parse/vllm/local_gpu_inference.py,sha256=I_uWYiFAQhRrykOKbVz69NzftDxuemDKtAye4kWhtnU,617
19
+ sparrow_parse-0.3.3.dist-info/METADATA,sha256=qFl4MsoV6lF_OqgtcfBqDRpTHX8MUJh0jeGgNr77o8w,6482
20
+ sparrow_parse-0.3.3.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
21
+ sparrow_parse-0.3.3.dist-info/entry_points.txt,sha256=8CrvTVTTcz1YuZ8aRCYNOH15ZOAaYLlcbYX3t28HwJY,54
22
+ sparrow_parse-0.3.3.dist-info/top_level.txt,sha256=n6b-WtT91zKLyCPZTP7wvne8v_yvIahcsz-4sX8I0rY,14
23
+ sparrow_parse-0.3.3.dist-info/RECORD,,
@@ -1,14 +0,0 @@
1
- sparrow_parse/__init__.py,sha256=64UBVh2KX7E-WVG4ZyY1dUiW9jGXZloWZk1N9nEUC2k,21
2
- sparrow_parse/__main__.py,sha256=Xs1bpJV0n08KWOoQE34FBYn6EBXZA9HIYJKrE4ZdG78,153
3
- sparrow_parse/temp.py,sha256=gy4_mtNW_KfXn9br_suu6jHx7JKYLKs9pIOBynh_JWY,1134
4
- sparrow_parse/extractor/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
5
- sparrow_parse/extractor/extractor_helper.py,sha256=n9M9NyZfesiCCj3ET9WoyqRcWIFJ4k-jyQlUAarKIhE,13658
6
- sparrow_parse/extractor/html_extractor.py,sha256=Y9c17epY6esn1lNGhOVpzgRuolFJUUZAfZ3G9fKcArU,9916
7
- sparrow_parse/extractor/markdown_processor.py,sha256=dC2WUdA-v2psh7oytruftxYkXdQi72FoEYxF30ROuO0,4506
8
- sparrow_parse/extractor/pdf_optimizer.py,sha256=KI_EweGt9Y_rDH1uCpYD5wKCW3rdjSFFhoVtiPBxX8k,3013
9
- sparrow_parse/extractor/unstructured_processor.py,sha256=oonkB5ALaV1pVs0a-xr8yAf-kirIabmtugHMnnEILqo,6770
10
- sparrow_parse-0.3.2.dist-info/METADATA,sha256=BA_M_vHGpbJuXvivXHJLCIejtdGHFatOrUVJve1USXY,6422
11
- sparrow_parse-0.3.2.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
12
- sparrow_parse-0.3.2.dist-info/entry_points.txt,sha256=8CrvTVTTcz1YuZ8aRCYNOH15ZOAaYLlcbYX3t28HwJY,54
13
- sparrow_parse-0.3.2.dist-info/top_level.txt,sha256=n6b-WtT91zKLyCPZTP7wvne8v_yvIahcsz-4sX8I0rY,14
14
- sparrow_parse-0.3.2.dist-info/RECORD,,
File without changes
File without changes