sparrow-parse 1.0.4a0__tar.gz → 1.0.6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. {sparrow-parse-1.0.4a0 → sparrow-parse-1.0.6}/PKG-INFO +4 -2
  2. {sparrow-parse-1.0.4a0 → sparrow-parse-1.0.6}/setup.py +1 -1
  3. sparrow-parse-1.0.6/sparrow_parse/__init__.py +1 -0
  4. {sparrow-parse-1.0.4a0 → sparrow-parse-1.0.6}/sparrow_parse/extractors/vllm_extractor.py +19 -18
  5. sparrow-parse-1.0.6/sparrow_parse/text_extraction.py +216 -0
  6. {sparrow-parse-1.0.4a0 → sparrow-parse-1.0.6}/sparrow_parse/vllm/huggingface_inference.py +1 -1
  7. {sparrow-parse-1.0.4a0 → sparrow-parse-1.0.6}/sparrow_parse/vllm/inference_base.py +1 -1
  8. {sparrow-parse-1.0.4a0 → sparrow-parse-1.0.6}/sparrow_parse/vllm/inference_factory.py +2 -3
  9. {sparrow-parse-1.0.4a0 → sparrow-parse-1.0.6}/sparrow_parse/vllm/local_gpu_inference.py +1 -1
  10. sparrow-parse-1.0.6/sparrow_parse/vllm/mlx_inference.py +302 -0
  11. {sparrow-parse-1.0.4a0 → sparrow-parse-1.0.6}/sparrow_parse.egg-info/PKG-INFO +4 -2
  12. {sparrow-parse-1.0.4a0 → sparrow-parse-1.0.6}/sparrow_parse.egg-info/requires.txt +5 -1
  13. sparrow-parse-1.0.4a0/sparrow_parse/__init__.py +0 -1
  14. sparrow-parse-1.0.4a0/sparrow_parse/text_extraction.py +0 -35
  15. sparrow-parse-1.0.4a0/sparrow_parse/vllm/mlx_inference.py +0 -217
  16. {sparrow-parse-1.0.4a0 → sparrow-parse-1.0.6}/README.md +0 -0
  17. {sparrow-parse-1.0.4a0 → sparrow-parse-1.0.6}/setup.cfg +0 -0
  18. {sparrow-parse-1.0.4a0 → sparrow-parse-1.0.6}/sparrow_parse/__main__.py +0 -0
  19. {sparrow-parse-1.0.4a0 → sparrow-parse-1.0.6}/sparrow_parse/extractors/__init__.py +0 -0
  20. {sparrow-parse-1.0.4a0 → sparrow-parse-1.0.6}/sparrow_parse/helpers/__init__.py +0 -0
  21. {sparrow-parse-1.0.4a0 → sparrow-parse-1.0.6}/sparrow_parse/helpers/image_optimizer.py +0 -0
  22. {sparrow-parse-1.0.4a0 → sparrow-parse-1.0.6}/sparrow_parse/helpers/pdf_optimizer.py +0 -0
  23. {sparrow-parse-1.0.4a0 → sparrow-parse-1.0.6}/sparrow_parse/processors/__init__.py +0 -0
  24. {sparrow-parse-1.0.4a0 → sparrow-parse-1.0.6}/sparrow_parse/processors/table_structure_processor.py +0 -0
  25. {sparrow-parse-1.0.4a0 → sparrow-parse-1.0.6}/sparrow_parse/vllm/__init__.py +0 -0
  26. {sparrow-parse-1.0.4a0 → sparrow-parse-1.0.6}/sparrow_parse.egg-info/SOURCES.txt +0 -0
  27. {sparrow-parse-1.0.4a0 → sparrow-parse-1.0.6}/sparrow_parse.egg-info/dependency_links.txt +0 -0
  28. {sparrow-parse-1.0.4a0 → sparrow-parse-1.0.6}/sparrow_parse.egg-info/entry_points.txt +0 -0
  29. {sparrow-parse-1.0.4a0 → sparrow-parse-1.0.6}/sparrow_parse.egg-info/top_level.txt +0 -0
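
The functional change in this release is the new apply_annotation flag, threaded from VLLMExtractor.run_inference through every backend's inference() method, together with a re-enabled MLX backend and Apple-Silicon-only mlx/mlx-vlm dependencies. The sketch below is assembled from the commented-out __main__ example in vllm_extractor.py further down in this diff; the model name and image path are copied from that example and are illustrative, not an official quick-start.

    # Sketch based on the commented example in vllm_extractor.py (below); assumes sparrow-parse 1.0.6
    # is installed on Apple Silicon so the MLX backend and its mlx/mlx-vlm extras are available.
    from sparrow_parse.extractors.vllm_extractor import VLLMExtractor
    from sparrow_parse.vllm.mlx_inference import MLXInference

    extractor = VLLMExtractor()
    model_inference_instance = MLXInference(model_name="mlx-community/Qwen2.5-VL-72B-Instruct-4bit")

    input_data = [
        {
            "file_path": "sparrow_parse/images/bonds_table.png",
            "text_input": "retrieve [{\"instrument_name\":\"str\", \"valuation\":\"int\"}]. return response in JSON format"
        }
    ]

    # apply_annotation=True (new in 1.0.6) asks Qwen models for value/bbox/confidence per field.
    results_array, num_pages = extractor.run_inference(
        model_inference_instance, input_data,
        tables_only=False, generic_query=False, crop_size=0,
        apply_annotation=True, debug=False, mode=None
    )
    print(results_array, num_pages)
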

{sparrow-parse-1.0.4a0 → sparrow-parse-1.0.6}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: sparrow-parse
- Version: 1.0.4a0
+ Version: 1.0.6
  Summary: Sparrow Parse is a Python package (part of Sparrow) for parsing and extracting information from documents.
  Home-page: https://github.com/katanaml/sparrow/tree/main/sparrow-data/parse
  Author: Andrej Baranovskij
@@ -20,9 +20,11 @@ Requires-Dist: torchvision>=0.22.0
  Requires-Dist: torch>=2.7.0
  Requires-Dist: sentence-transformers>=4.1.0
  Requires-Dist: numpy>=2.2.5
- Requires-Dist: pypdf>=5.4.0
+ Requires-Dist: pypdf>=5.5.0
  Requires-Dist: gradio_client>=1.7.2
  Requires-Dist: pdf2image>=1.17.0
+ Requires-Dist: mlx>=0.25.2; sys_platform == "darwin" and platform_machine == "arm64"
+ Requires-Dist: mlx-vlm==0.1.26; sys_platform == "darwin" and platform_machine == "arm64"

  # Sparrow Parse


{sparrow-parse-1.0.4a0 → sparrow-parse-1.0.6}/setup.py
@@ -8,7 +8,7 @@ with open("requirements.txt", "r", encoding="utf-8") as fh:

  setup(
  name="sparrow-parse",
- version="1.0.4a",
+ version="1.0.6",
  author="Andrej Baranovskij",
  author_email="andrejus.baranovskis@gmail.com",
  description="Sparrow Parse is a Python package (part of Sparrow) for parsing and extracting information from documents.",

sparrow-parse-1.0.6/sparrow_parse/__init__.py
@@ -0,0 +1 @@
+ __version__ = '1.0.6'

{sparrow-parse-1.0.4a0 → sparrow-parse-1.0.6}/sparrow_parse/extractors/vllm_extractor.py
@@ -14,13 +14,14 @@ class VLLMExtractor(object):
  pass

  def run_inference(self, model_inference_instance, input_data, tables_only=False,
- generic_query=False, crop_size=None, debug_dir=None, debug=False, mode=None):
+ generic_query=False, crop_size=None, apply_annotation=False, debug_dir=None, debug=False, mode=None):
  """
  Main entry point for processing input data using a model inference instance.
  Handles generic queries, PDFs, and table extraction.
  """
  if generic_query:
  input_data[0]["text_input"] = "retrieve document data. return response in JSON format"
+ apply_annotation=False

  if debug:
  print("Input data:", input_data)
@@ -37,12 +38,12 @@ class VLLMExtractor(object):
  # Document data extraction inference (file_path exists and is not None)
  file_path = input_data[0]["file_path"]
  if self.is_pdf(file_path):
- return self._process_pdf(model_inference_instance, input_data, tables_only, crop_size, debug, debug_dir, mode)
+ return self._process_pdf(model_inference_instance, input_data, tables_only, crop_size, apply_annotation, debug, debug_dir, mode)
  else:
- return self._process_non_pdf(model_inference_instance, input_data, tables_only, crop_size, debug, debug_dir)
+ return self._process_non_pdf(model_inference_instance, input_data, tables_only, crop_size, apply_annotation, debug, debug_dir)


- def _process_pdf(self, model_inference_instance, input_data, tables_only, crop_size, debug, debug_dir, mode):
+ def _process_pdf(self, model_inference_instance, input_data, tables_only, crop_size, apply_annotation, debug, debug_dir, mode):
  """
  Handles processing and inference for PDF files, including page splitting and optional table extraction.
  """
@@ -50,21 +51,21 @@ class VLLMExtractor(object):
  num_pages, output_files, temp_dir = pdf_optimizer.split_pdf_to_pages(input_data[0]["file_path"],
  debug_dir, convert_to_images=True)

- results = self._process_pages(model_inference_instance, output_files, input_data, tables_only, crop_size, debug, debug_dir)
+ results = self._process_pages(model_inference_instance, output_files, input_data, tables_only, crop_size, apply_annotation, debug, debug_dir)

  # Clean up temporary directory
  shutil.rmtree(temp_dir, ignore_errors=True)
  return results, num_pages


- def _process_non_pdf(self, model_inference_instance, input_data, tables_only, crop_size, debug, debug_dir):
+ def _process_non_pdf(self, model_inference_instance, input_data, tables_only, crop_size, apply_annotation, debug, debug_dir):
  """
  Handles processing and inference for non-PDF files, with optional table extraction.
  """
  file_path = input_data[0]["file_path"]

  if tables_only:
- return self._extract_tables(model_inference_instance, file_path, input_data, debug, debug_dir), 1
+ return self._extract_tables(model_inference_instance, file_path, input_data, apply_annotation, debug, debug_dir), 1
  else:
  temp_dir = tempfile.mkdtemp()

@@ -77,13 +78,13 @@ class VLLMExtractor(object):

  file_path = input_data[0]["file_path"]
  input_data[0]["file_path"] = [file_path]
- results = model_inference_instance.inference(input_data)
+ results = model_inference_instance.inference(input_data, apply_annotation)

  shutil.rmtree(temp_dir, ignore_errors=True)

  return results, 1

- def _process_pages(self, model_inference_instance, output_files, input_data, tables_only, crop_size, debug, debug_dir):
+ def _process_pages(self, model_inference_instance, output_files, input_data, tables_only, crop_size, apply_annotation, debug, debug_dir):
  """
  Processes individual pages (PDF split) and handles table extraction or inference.

@@ -93,6 +94,7 @@ class VLLMExtractor(object):
  input_data: Input data for inference.
  tables_only: Whether to only process tables.
  crop_size: Size for cropping image borders.
+ apply_annotation: Flag to apply annotations to the output.
  debug: Debug flag for logging.
  debug_dir: Directory for saving debug information.

@@ -106,9 +108,7 @@ class VLLMExtractor(object):
  print(f"Processing {len(output_files)} pages for table extraction.")
  # Process each page individually for table extraction
  for i, file_path in enumerate(output_files):
- tables_result = self._extract_tables(
- model_inference_instance, file_path, input_data, debug, debug_dir, page_index=i
- )
+ tables_result = self._extract_tables( model_inference_instance, file_path, input_data, apply_annotation, debug, debug_dir, page_index=i)
  # Since _extract_tables returns a list with one JSON string, unpack it
  results_array.extend(tables_result) # Unpack the single JSON string
  else:
@@ -141,7 +141,7 @@ class VLLMExtractor(object):
  input_data[0]["file_path"] = output_files

  # Process all files at once
- results = model_inference_instance.inference(input_data)
+ results = model_inference_instance.inference(input_data, apply_annotation)
  results_array.extend(results)

  # Clean up temporary directory
@@ -150,7 +150,7 @@ class VLLMExtractor(object):
  return results_array


- def _extract_tables(self, model_inference_instance, file_path, input_data, debug, debug_dir, page_index=None):
+ def _extract_tables(self, model_inference_instance, file_path, input_data, apply_annotation, debug, debug_dir, page_index=None):
  """
  Detects and processes tables from an input file.
  """
@@ -175,7 +175,7 @@ class VLLMExtractor(object):
  table.save(output_filename, "JPEG")

  input_data[0]["file_path"] = [output_filename]
- result = self._run_model_inference(model_inference_instance, input_data)
+ result = self._run_model_inference(model_inference_instance, input_data, apply_annotation)
  results_array.append(result)

  shutil.rmtree(temp_dir, ignore_errors=True)
@@ -191,11 +191,11 @@ class VLLMExtractor(object):


  @staticmethod
- def _run_model_inference(model_inference_instance, input_data):
+ def _run_model_inference(model_inference_instance, input_data, apply_annotation):
  """
  Runs model inference and handles JSON decoding.
  """
- result = model_inference_instance.inference(input_data)[0]
+ result = model_inference_instance.inference(input_data, apply_annotation)[0]
  try:
  return json.loads(result) if isinstance(result, str) else result
  except json.JSONDecodeError:
@@ -230,7 +230,7 @@ if __name__ == "__main__":
  # input_data = [
  # {
  # "file_path": "sparrow_parse/images/bonds_table.png",
- # "text_input": "retrieve all data. return response in JSON format"
+ # "text_input": "retrieve [{\"instrument_name\":\"str\", \"valuation\":\"int\"}]. return response in JSON format"
  # }
  # ]
  #
@@ -245,6 +245,7 @@ if __name__ == "__main__":
  # results_array, num_pages = extractor.run_inference(model_inference_instance, input_data, tables_only=False,
  # generic_query=False,
  # crop_size=0,
+ # apply_annotation=False,
  # debug_dir="/Users/andrejb/Work/katana-git/sparrow/sparrow-ml/llm/data/",
  # debug=True,
  # mode=None)

sparrow-parse-1.0.6/sparrow_parse/text_extraction.py
@@ -0,0 +1,216 @@
+ from mlx_vlm import load, apply_chat_template, generate
+ from mlx_vlm.utils import load_image
+ from PIL import ImageDraw, ImageFont
+ import json
+
+
+ # Load model and processor
+ vl_model, vl_processor = load("mlx-community/Mistral-Small-3.1-24B-Instruct-2503-8bit")
+ # vl_model, vl_processor = load("mlx-community/Qwen2.5-VL-72B-Instruct-4bit")
+ vl_config = vl_model.config
+
+ image = load_image("images/bonds_table.png")
+
+ # Qwen
+ # messages = [
+ # {"role": "system", "content": "You are an expert at extracting text from images. Format your response in JSON."},
+ # {"role": "user", "content": "retrieve [{\"instrument_name\":\"str\", \"valuation\":\"int\"}]. return response in JSON format"}
+ # ]
+ # Qwen with bbox
+ # messages = [
+ # {"role": "system", "content": "You are an expert at extracting text from images. For each item in the table, provide separate bounding boxes for each field. All coordinates should be in pixels relative to the original image. Format your response in JSON."},
+ # {"role": "user", "content": "retrieve [{\"instrument_name\":{\"value\":\"str\", \"bbox\":[\"float\", \"float\", \"float\", \"float\"], \"confidence\":\"float\"}, \"valuation\":{\"value\":\"int\", \"bbox\":[\"float\", \"float\", \"float\", \"float\"], \"confidence\":\"float\"}}]. return response in JSON format"}
+ # ]
+ # Qwen with bbox, get all data
+ # messages = [
+ # {"role": "system", "content": "You are an expert at extracting text from images. For each item in the table, provide separate bounding boxes for each field. All coordinates should be in pixels relative to the original image. Format your response in JSON."},
+ # {"role": "user", "content": "retrieve all data. return response in JSON format. For each identified field or data element, include: 1) a descriptive field name as the object key, 2) a nested object with 'value' containing the extracted content, 'bbox' array with [x_min, y_min, x_max, y_max] coordinates in pixels, and 'confidence' score between 0-1. Example structure: [{\"field_name\":{\"value\":\"extracted value\", \"bbox\":[100, 200, 300, 250], \"confidence\":0.95}}]"}
+ # ]
+
+ # Mistral
+ # message = "retrieve all data. return response in JSON format"
+ message = "retrieve [{\"instrument_name\":\"str\", \"valuation\":\"int\"}]. return response in JSON format"
+
+ # Qwen
+ # prompt = apply_chat_template(vl_processor, vl_config, messages)
+ # Mistral
+ prompt = apply_chat_template(vl_processor, vl_config, message)
+
+ # Generate text
+ vl_output, _ = generate(
+ vl_model,
+ vl_processor,
+ prompt,
+ image,
+ max_tokens=4000,
+ temperature=0,
+ verbose=False
+ )
+
+ print(vl_output)
+
+
+ # Comment out below code if non Qwen model is used
+
+ # # Convert to a format we can draw on
+ # img_draw = image.copy()
+ # draw = ImageDraw.Draw(img_draw)
+ #
+ # # Parse the JSON result
+ # results = json.loads(vl_output.strip('```json\n').strip('```'))
+ #
+ # # Predefined solid colors that are highly visible
+ # solid_colors = [
+ # (180, 30, 40), # Dark red
+ # (0, 100, 140), # Dark blue
+ # (30, 120, 40), # Dark green
+ # (140, 60, 160), # Purple
+ # (200, 100, 0), # Orange
+ # (100, 80, 0), # Brown
+ # (0, 100, 100), # Teal
+ # (120, 40, 100) # Magenta
+ # ]
+ #
+ # # Determine unique field keys across all items to assign consistent colors
+ # unique_fields = set()
+ # for item in results:
+ # unique_fields.update(item.keys())
+ #
+ # # Map each unique field to a color
+ # field_color_map = {}
+ # for i, field in enumerate(sorted(unique_fields)):
+ # field_color_map[field] = solid_colors[i % len(solid_colors)]
+ #
+ # # Load font with larger size
+ # font_size = 20
+ # try:
+ # font = ImageFont.truetype("arial.ttf", font_size)
+ # except IOError:
+ # try:
+ # font = ImageFont.truetype("DejaVuSans.ttf", font_size)
+ # except IOError:
+ # try:
+ # font = ImageFont.truetype("Helvetica.ttf", font_size)
+ # except IOError:
+ # font = ImageFont.load_default()
+ #
+ #
+ # # Helper function to measure text width
+ # def get_text_dimensions(text, font):
+ # try:
+ # # Method for newer Pillow versions
+ # left, top, right, bottom = draw.textbbox((0, 0), text, font=font)
+ # return right - left, bottom - top
+ # except AttributeError:
+ # try:
+ # # Alternative method
+ # left, top, right, bottom = font.getbbox(text)
+ # return right - left, bottom - top
+ # except AttributeError:
+ # # Fallback approximation
+ # return len(text) * (font_size // 2), font_size + 2
+ #
+ #
+ # # Draw bounding boxes for each item
+ # for item in results:
+ # # Process each field
+ # for field_name, field_data in item.items():
+ # # Check if this field has the expected structure
+ # if isinstance(field_data, dict) and "bbox" in field_data and "value" in field_data:
+ # bbox = field_data["bbox"]
+ # value = field_data["value"]
+ # confidence = field_data.get("confidence", "N/A")
+ #
+ # # Check if coordinates need to be scaled (normalized 0-1 values)
+ # if all(isinstance(coord, (int, float)) for coord in bbox):
+ # if max(bbox) <= 1.0: # Normalized coordinates
+ # width, height = image.size
+ # bbox = [
+ # bbox[0] * width,
+ # bbox[1] * height,
+ # bbox[2] * width,
+ # bbox[3] * height
+ # ]
+ #
+ # # Get color from the mapping we created
+ # color = field_color_map[field_name]
+ #
+ # # Make sure bbox coordinates are integers
+ # bbox = [int(coord) for coord in bbox]
+ #
+ # # Calculate the bbox width
+ # bbox_width = bbox[2] - bbox[0]
+ #
+ # # Draw rectangle with appropriate thickness
+ # border_thickness = 3
+ # draw.rectangle(
+ # [(bbox[0], bbox[1]), (bbox[2], bbox[3])],
+ # outline=color,
+ # width=border_thickness
+ # )
+ #
+ # # Format the value and confidence
+ # value_str = str(value)
+ # confidence_str = f" [{confidence:.2f}]" if isinstance(confidence, (int, float)) else ""
+ # prefix = f"{field_name}: "
+ #
+ # # First, try with full text without truncation
+ # full_label = prefix + value_str + confidence_str
+ # full_width, text_height = get_text_dimensions(full_label, font)
+ #
+ # # Compare with a reasonable maximum display width
+ # min_display_width = 300 # Reasonable minimum width to display text
+ # max_display_width = max(bbox_width * 1.5, min_display_width)
+ #
+ # # Only truncate if the full text exceeds our maximum display width
+ # if full_width > max_display_width:
+ # # Calculate the space available for the value
+ # prefix_width, _ = get_text_dimensions(prefix, font)
+ # confidence_width, _ = get_text_dimensions(confidence_str, font)
+ # available_value_width = max_display_width - prefix_width - confidence_width
+ #
+ # # Truncate the value to fit
+ # truncated_value = value_str
+ # for i in range(len(value_str) - 1, 3, -1):
+ # truncated_value = value_str[:i] + "..."
+ # temp_width, _ = get_text_dimensions(truncated_value, font)
+ # if temp_width <= available_value_width:
+ # break
+ #
+ # label = prefix + truncated_value + confidence_str
+ # text_width, _ = get_text_dimensions(label, font)
+ # else:
+ # # No truncation needed
+ # label = full_label
+ # text_width = full_width
+ #
+ # # Position for text (above the bounding box)
+ # padding = 6
+ # text_position = (bbox[0], bbox[1] - text_height - (padding * 2))
+ #
+ # # Ensure text doesn't go off the top of the image
+ # if text_position[1] < padding:
+ # # If too close to top, position below the box instead
+ # text_position = (bbox[0], bbox[3] + padding)
+ #
+ # # Add a background rectangle with better contrast
+ # draw.rectangle(
+ # [(text_position[0] - padding, text_position[1] - padding),
+ # (text_position[0] + text_width + padding, text_position[1] + text_height + padding)],
+ # fill=(255, 255, 255, 240),
+ # outline=color,
+ # width=2
+ # )
+ #
+ # # Draw the text
+ # draw.text(
+ # text_position,
+ # label,
+ # fill=color,
+ # font=font
+ # )
+ #
+ # # Save the annotated image
+ # output_path = "images/bonds_table_annotated.png"
+ # img_draw.save(output_path)
+ # print(f"Annotated image saved to {output_path}")

{sparrow-parse-1.0.4a0 → sparrow-parse-1.0.6}/sparrow_parse/vllm/huggingface_inference.py
@@ -26,7 +26,7 @@ class HuggingFaceInference(ModelInference):
  return output_text


- def inference(self, input_data, mode=None):
+ def inference(self, input_data, apply_annotation=False, mode=None):
  if mode == "static":
  simple_json = self.get_simple_json()
  return [simple_json]

{sparrow-parse-1.0.4a0 → sparrow-parse-1.0.6}/sparrow_parse/vllm/inference_base.py
@@ -4,7 +4,7 @@ import json

  class ModelInference(ABC):
  @abstractmethod
- def inference(self, input_data, mode=None):
+ def inference(self, input_data, apply_annotation=False, mode=None):
  """This method should be implemented by subclasses."""
  pass


{sparrow-parse-1.0.4a0 → sparrow-parse-1.0.6}/sparrow_parse/vllm/inference_factory.py
@@ -1,6 +1,6 @@
  from sparrow_parse.vllm.huggingface_inference import HuggingFaceInference
  from sparrow_parse.vllm.local_gpu_inference import LocalGPUInference
- # from sparrow_parse.vllm.mlx_inference import MLXInference
+ from sparrow_parse.vllm.mlx_inference import MLXInference


  class InferenceFactory:
@@ -14,8 +14,7 @@ class InferenceFactory:
  model = self._load_local_model() # Replace with actual model loading logic
  return LocalGPUInference(model=model, device=self.config.get("device", "cuda"))
  elif self.config["method"] == "mlx":
- # return MLXInference(model_name=self.config["model_name"])
- return None
+ return MLXInference(model_name=self.config["model_name"])
  else:
  raise ValueError(f"Unknown method: {self.config['method']}")


{sparrow-parse-1.0.4a0 → sparrow-parse-1.0.6}/sparrow_parse/vllm/local_gpu_inference.py
@@ -8,7 +8,7 @@ class LocalGPUInference(ModelInference):
  self.device = device
  self.model.to(self.device)

- def inference(self, input_data, mode=None):
+ def inference(self, input_data, apply_annotation=False, mode=None):
  self.model.eval() # Set the model to evaluation mode
  with torch.no_grad(): # No need to calculate gradients
  input_tensor = torch.tensor(input_data).to(self.device)
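
With the factory change above, the "mlx" method now returns a live MLXInference instead of None, and every backend's inference() accepts the new apply_annotation argument. A rough config-driven setup might look like the following; the InferenceFactory constructor and accessor are not visible in these hunks, so InferenceFactory(config) and get_inference_instance() are assumptions based on the self.config usage shown above.

    from sparrow_parse.vllm.inference_factory import InferenceFactory

    config = {
        "method": "mlx",  # in 1.0.6 this routes to MLXInference (previously returned None)
        "model_name": "mlx-community/Qwen2.5-VL-72B-Instruct-4bit",
    }
    factory = InferenceFactory(config)                            # assumed constructor signature
    model_inference_instance = factory.get_inference_instance()   # assumed accessor name
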

sparrow-parse-1.0.6/sparrow_parse/vllm/mlx_inference.py
@@ -0,0 +1,302 @@
+ from mlx_vlm import load, generate
+ from mlx_vlm.prompt_utils import apply_chat_template
+ from mlx_vlm.utils import load_image
+ from sparrow_parse.vllm.inference_base import ModelInference
+ import os
+ import json, re
+ from rich import print
+
+
+ class MLXInference(ModelInference):
+ """
+ A class for performing inference using the MLX model.
+ Handles image preprocessing, response formatting, and model interaction.
+ """
+
+ def __init__(self, model_name):
+ """
+ Initialize the inference class with the given model name.
+
+ :param model_name: Name of the model to load.
+ """
+ self.model_name = model_name
+ print(f"MLXInference initialized for model: {model_name}")
+
+
+ @staticmethod
+ def _load_model_and_processor(model_name):
+ """
+ Load the model and processor for inference.
+
+ :param model_name: Name of the model to load.
+ :return: Tuple containing the loaded model and processor.
+ """
+ model, processor = load(model_name)
+ print(f"Loaded model: {model_name}")
+ return model, processor
+
+
+ def process_response(self, output_text):
+ """
+ Process and clean the model's raw output to format as JSON.
+ """
+ try:
+ # Check if we have markdown code block markers
+ if "```" in output_text:
+ # Handle markdown-formatted output
+ json_start = output_text.find("```json")
+ if json_start != -1:
+ # Extract content between ```json and ```
+ content = output_text[json_start + 7:]
+ json_end = content.rfind("```")
+ if json_end != -1:
+ content = content[:json_end].strip()
+ formatted_json = json.loads(content)
+ return json.dumps(formatted_json, indent=2)
+
+ # Handle raw JSON (no markdown formatting)
+ # First try to find JSON array or object patterns
+ for pattern in [r'\[\s*\{.*\}\s*\]', r'\{.*\}']:
+ import re
+ matches = re.search(pattern, output_text, re.DOTALL)
+ if matches:
+ potential_json = matches.group(0)
+ try:
+ formatted_json = json.loads(potential_json)
+ return json.dumps(formatted_json, indent=2)
+ except:
+ pass
+
+ # Last resort: try to parse the whole text as JSON
+ formatted_json = json.loads(output_text.strip())
+ return json.dumps(formatted_json, indent=2)
+
+ except Exception as e:
+ print(f"Failed to parse JSON: {e}")
+ return output_text
+
+
+ def load_image_data(self, image_filepath, max_width=1250, max_height=1750):
+ """
+ Load and resize image while maintaining its aspect ratio.
+
+ :param image_filepath: Path to the image file.
+ :param max_width: Maximum allowed width of the image.
+ :param max_height: Maximum allowed height of the image.
+ :return: Tuple containing the image object and its new dimensions.
+ """
+ image = load_image(image_filepath) # Assuming load_image is defined elsewhere
+ width, height = image.size
+
+ # Calculate new dimensions while maintaining the aspect ratio
+ if width > max_width or height > max_height:
+ aspect_ratio = width / height
+ new_width = min(max_width, int(max_height * aspect_ratio))
+ new_height = min(max_height, int(max_width / aspect_ratio))
+ return image, new_width, new_height
+
+ return image, width, height
+
+
+ def inference(self, input_data, apply_annotation=False, mode=None):
+ """
+ Perform inference on input data using the specified model.
+
+ :param input_data: A list of dictionaries containing image file paths and text inputs.
+ :param apply_annotation: Optional flag to apply annotations to the output.
+ :param mode: Optional mode for inference ("static" for simple JSON output).
+ :return: List of processed model responses.
+ """
+ # Handle static mode
+ if mode == "static":
+ return [self.get_simple_json()]
+
+ # Load the model and processor
+ model, processor = self._load_model_and_processor(self.model_name)
+ config = model.config
+
+ # Determine if we're doing text-only or image-based inference
+ is_text_only = input_data[0].get("file_path") is None
+
+ if is_text_only:
+ # Text-only inference
+ messages = input_data[0]["text_input"]
+ response = self._generate_text_response(model, processor, config, messages)
+ results = [response]
+ else:
+ # Image-based inference
+ file_paths = self._extract_file_paths(input_data)
+ results = self._process_images(model, processor, config, file_paths, input_data, apply_annotation)
+
+ return results
+
+ def _generate_text_response(self, model, processor, config, messages):
+ """
+ Generate a text response for text-only inputs.
+
+ :param model: The loaded model
+ :param processor: The loaded processor
+ :param config: Model configuration
+ :param messages: Input messages
+ :return: Generated response
+ """
+ prompt = apply_chat_template(processor, config, messages)
+ response = generate(
+ model,
+ processor,
+ prompt,
+ max_tokens=4000,
+ temperature=0.0,
+ verbose=False
+ )
+ print("Inference completed successfully")
+ return response
+
+ def _process_images(self, model, processor, config, file_paths, input_data, apply_annotation):
+ """
+ Process images and generate responses for each.
+
+ :param model: The loaded model
+ :param processor: The loaded processor
+ :param config: Model configuration
+ :param file_paths: List of image file paths
+ :param input_data: Original input data
+ :param apply_annotation: Flag to apply annotations
+ :return: List of processed responses
+ """
+ results = []
+ for file_path in file_paths:
+ image, width, height = self.load_image_data(file_path)
+
+ # Prepare messages based on model type
+ messages = self._prepare_messages(input_data, apply_annotation)
+
+ # Generate and process response
+ prompt = apply_chat_template(processor, config, messages)
+ response, _ = generate(
+ model,
+ processor,
+ prompt,
+ image,
+ resize_shape=(width, height),
+ max_tokens=4000,
+ temperature=0.0,
+ verbose=False
+ )
+ results.append(self.process_response(response))
+ print(f"Inference completed successfully for: {file_path}")
+
+ return results
+
+
+ def transform_query_with_bbox(self, text_input):
+ """
+ Transform JSON schema in text_input to include value, bbox, and confidence.
+ Works with both array and object JSON structures.
+
+ Args:
+ text_input (str): The input text containing a JSON schema
+
+ Returns:
+ str: Text with transformed JSON including value, bbox, and confidence
+ """
+ # Split text into parts - find the JSON portion between "retrieve" and "return response"
+ retrieve_pattern = r'retrieve\s+'
+ return_pattern = r'\.\s+return\s+response'
+
+ retrieve_match = re.search(retrieve_pattern, text_input)
+ return_match = re.search(return_pattern, text_input)
+
+ if not retrieve_match or not return_match:
+ return text_input # Return original if pattern not found
+
+ json_start = retrieve_match.end()
+ json_end = return_match.start()
+
+ prefix = text_input[:json_start]
+ json_str = text_input[json_start:json_end].strip()
+ suffix = text_input[json_end:]
+
+ # Parse and transform the JSON
+ try:
+ # Handle single quotes if needed
+ json_str = json_str.replace("'", '"')
+
+ json_obj = json.loads(json_str)
+ transformed_json = self.transform_query_structure(json_obj)
+ transformed_json_str = json.dumps(transformed_json)
+
+ # Rebuild the text
+ result = prefix + transformed_json_str + suffix
+
+ return result
+ except json.JSONDecodeError as e:
+ print(f"Error parsing JSON: {e}")
+ return text_input # Return original if parsing fails
+
+
+ def transform_query_structure(self, json_obj):
+ """
+ Transform each field in the JSON structure to include value, bbox, and confidence.
+ Handles both array and object formats recursively.
+ """
+ if isinstance(json_obj, list):
+ # Handle array format
+ return [self.transform_query_structure(item) for item in json_obj]
+ elif isinstance(json_obj, dict):
+ # Handle object format
+ result = {}
+ for key, value in json_obj.items():
+ if isinstance(value, (dict, list)):
+ # Recursively transform nested objects or arrays
+ result[key] = self.transform_query_structure(value)
+ else:
+ # Transform simple value to object with value, bbox, and confidence
+ result[key] = {
+ "value": value,
+ "bbox": ["float", "float", "float", "float"],
+ "confidence": "float"
+ }
+ return result
+ else:
+ # For primitive values, no transformation needed
+ return json_obj
+
+
+ def _prepare_messages(self, input_data, apply_annotation):
+ """
+ Prepare the appropriate messages based on the model type.
+
+ :param input_data: Original input data
+ :param apply_annotation: Flag to apply annotations
+ :return: Properly formatted messages
+ """
+ if "mistral" in self.model_name.lower():
+ return input_data[0]["text_input"]
+ elif "qwen" in self.model_name.lower():
+ if apply_annotation:
+ system_prompt = {"role": "system", "content": "You are an expert at extracting text from images. "
+ "For each item in the table, provide separate bounding boxes for each field. "
+ "All coordinates should be in pixels relative to the original image. Format your response in JSON."}
+ user_prompt = {"role": "user", "content": self.transform_query_with_bbox(input_data[0]["text_input"])}
+ return [system_prompt, user_prompt]
+ return [
+ {"role": "system", "content": "You are an expert at extracting structured text from image documents."},
+ {"role": "user", "content": input_data[0]["text_input"]},
+ ]
+ else:
+ raise ValueError("Unsupported model type. Please use either Mistral or Qwen.")
+
+ @staticmethod
+ def _extract_file_paths(input_data):
+ """
+ Extract and resolve absolute file paths from input data.
+
+ :param input_data: List of dictionaries containing image file paths.
+ :return: List of absolute file paths.
+ """
+ return [
+ os.path.abspath(file_path)
+ for data in input_data
+ for file_path in data.get("file_path", [])
+ ]
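
When apply_annotation is set and a Qwen model is selected, _prepare_messages rewrites the user query through transform_query_with_bbox, expanding each schema field into a value/bbox/confidence object before the prompt is built. A small illustration of that rewrite (a sketch, assuming the package and its dependencies are installed; no model weights are downloaded because __init__ only stores the model name):

    from sparrow_parse.vllm.mlx_inference import MLXInference

    m = MLXInference("mlx-community/Qwen2.5-VL-72B-Instruct-4bit")
    query = "retrieve [{\"instrument_name\":\"str\", \"valuation\":\"int\"}]. return response in JSON format"
    print(m.transform_query_with_bbox(query))
    # retrieve [{"instrument_name": {"value": "str", "bbox": ["float", "float", "float", "float"],
    # "confidence": "float"}, "valuation": {"value": "int", "bbox": ["float", "float", "float", "float"],
    # "confidence": "float"}}]. return response in JSON format
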

{sparrow-parse-1.0.4a0 → sparrow-parse-1.0.6}/sparrow_parse.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: sparrow-parse
- Version: 1.0.4a0
+ Version: 1.0.6
  Summary: Sparrow Parse is a Python package (part of Sparrow) for parsing and extracting information from documents.
  Home-page: https://github.com/katanaml/sparrow/tree/main/sparrow-data/parse
  Author: Andrej Baranovskij
@@ -20,9 +20,11 @@ Requires-Dist: torchvision>=0.22.0
  Requires-Dist: torch>=2.7.0
  Requires-Dist: sentence-transformers>=4.1.0
  Requires-Dist: numpy>=2.2.5
- Requires-Dist: pypdf>=5.4.0
+ Requires-Dist: pypdf>=5.5.0
  Requires-Dist: gradio_client>=1.7.2
  Requires-Dist: pdf2image>=1.17.0
+ Requires-Dist: mlx>=0.25.2; sys_platform == "darwin" and platform_machine == "arm64"
+ Requires-Dist: mlx-vlm==0.1.26; sys_platform == "darwin" and platform_machine == "arm64"

  # Sparrow Parse


{sparrow-parse-1.0.4a0 → sparrow-parse-1.0.6}/sparrow_parse.egg-info/requires.txt
@@ -4,6 +4,10 @@ torchvision>=0.22.0
  torch>=2.7.0
  sentence-transformers>=4.1.0
  numpy>=2.2.5
- pypdf>=5.4.0
+ pypdf>=5.5.0
  gradio_client>=1.7.2
  pdf2image>=1.17.0
+
+ [:sys_platform == "darwin" and platform_machine == "arm64"]
+ mlx>=0.25.2
+ mlx-vlm==0.1.26
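
The bracketed section is setuptools' requires.txt encoding of an environment marker: the two MLX packages are installed only on macOS/arm64 (Apple Silicon). The same constraint appears as PEP 508 markers in the PKG-INFO hunks above; written directly in a setup() call it would look roughly like this (illustrative only, since the actual setup.py reads requirements.txt, as its hunk shows):

    install_requires = [
        'mlx>=0.25.2; sys_platform == "darwin" and platform_machine == "arm64"',
        'mlx-vlm==0.1.26; sys_platform == "darwin" and platform_machine == "arm64"',
    ]
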

sparrow-parse-1.0.4a0/sparrow_parse/__init__.py
@@ -1 +0,0 @@
- __version__ = '1.0.4a'

sparrow-parse-1.0.4a0/sparrow_parse/text_extraction.py
@@ -1,35 +0,0 @@
- from mlx_vlm import load, apply_chat_template, generate
- from mlx_vlm.utils import load_image
-
-
- # Load model and processor
- # vl_model, vl_processor = load("mlx-community/Mistral-Small-3.1-24B-Instruct-2503-8bit")
- vl_model, vl_processor = load("mlx-community/Qwen2.5-VL-7B-Instruct-8bit")
- vl_config = vl_model.config
-
- image = load_image("images/bonds_table.png")
-
- messages = [
- {"role": "system", "content": "You are an expert at extracting text from images. Format your response in json."},
- {"role": "user", "content": "retrieve [{\"instrument_name\":\"str\", \"valuation\":\"int\"}]. return response in JSON format"}
- ]
-
- # message = "retrieve all data. return response in JSON format"
- # message = "retrieve [{\"instrument_name\":\"str\", \"valuation\":\"int\"}]. return response in JSON format"
-
- # Apply chat template
- prompt = apply_chat_template(vl_processor, vl_config, messages)
- # prompt = apply_chat_template(vl_processor, vl_config, message)
-
- # Generate text
- vl_output = generate(
- vl_model,
- vl_processor,
- prompt,
- image,
- max_tokens=1000,
- temperature=0,
- verbose=False
- )
-
- print(vl_output)

sparrow-parse-1.0.4a0/sparrow_parse/vllm/mlx_inference.py
@@ -1,217 +0,0 @@
- # from mlx_vlm import load, generate
- # from mlx_vlm.prompt_utils import apply_chat_template
- # from mlx_vlm.utils import load_image
- # from sparrow_parse.vllm.inference_base import ModelInference
- # import os
- # import json
- # from rich import print
- #
- #
- # class MLXInference(ModelInference):
- # """
- # A class for performing inference using the MLX model.
- # Handles image preprocessing, response formatting, and model interaction.
- # """
- #
- # def __init__(self, model_name):
- # """
- # Initialize the inference class with the given model name.
- #
- # :param model_name: Name of the model to load.
- # """
- # self.model_name = model_name
- # print(f"MLXInference initialized for model: {model_name}")
- #
- #
- # @staticmethod
- # def _load_model_and_processor(model_name):
- # """
- # Load the model and processor for inference.
- #
- # :param model_name: Name of the model to load.
- # :return: Tuple containing the loaded model and processor.
- # """
- # model, processor = load(model_name)
- # print(f"Loaded model: {model_name}")
- # return model, processor
- #
- #
- # def process_response(self, output_text):
- # """
- # Process and clean the model's raw output to format as JSON.
- # """
- # try:
- # # Check if we have markdown code block markers
- # if "```" in output_text:
- # # Handle markdown-formatted output
- # json_start = output_text.find("```json")
- # if json_start != -1:
- # # Extract content between ```json and ```
- # content = output_text[json_start + 7:]
- # json_end = content.rfind("```")
- # if json_end != -1:
- # content = content[:json_end].strip()
- # formatted_json = json.loads(content)
- # return json.dumps(formatted_json, indent=2)
- #
- # # Handle raw JSON (no markdown formatting)
- # # First try to find JSON array or object patterns
- # for pattern in [r'\[\s*\{.*\}\s*\]', r'\{.*\}']:
- # import re
- # matches = re.search(pattern, output_text, re.DOTALL)
- # if matches:
- # potential_json = matches.group(0)
- # try:
- # formatted_json = json.loads(potential_json)
- # return json.dumps(formatted_json, indent=2)
- # except:
- # pass
- #
- # # Last resort: try to parse the whole text as JSON
- # formatted_json = json.loads(output_text.strip())
- # return json.dumps(formatted_json, indent=2)
- #
- # except Exception as e:
- # print(f"Failed to parse JSON: {e}")
- # return output_text
- #
- #
- # def load_image_data(self, image_filepath, max_width=1250, max_height=1750):
- # """
- # Load and resize image while maintaining its aspect ratio.
- #
- # :param image_filepath: Path to the image file.
- # :param max_width: Maximum allowed width of the image.
- # :param max_height: Maximum allowed height of the image.
- # :return: Tuple containing the image object and its new dimensions.
- # """
- # image = load_image(image_filepath) # Assuming load_image is defined elsewhere
- # width, height = image.size
- #
- # # Calculate new dimensions while maintaining the aspect ratio
- # if width > max_width or height > max_height:
- # aspect_ratio = width / height
- # new_width = min(max_width, int(max_height * aspect_ratio))
- # new_height = min(max_height, int(max_width / aspect_ratio))
- # return image, new_width, new_height
- #
- # return image, width, height
- #
- #
- # def inference(self, input_data, mode=None):
- # """
- # Perform inference on input data using the specified model.
- #
- # :param input_data: A list of dictionaries containing image file paths and text inputs.
- # :param mode: Optional mode for inference ("static" for simple JSON output).
- # :return: List of processed model responses.
- # """
- # # Handle static mode
- # if mode == "static":
- # return [self.get_simple_json()]
- #
- # # Load the model and processor
- # model, processor = self._load_model_and_processor(self.model_name)
- # config = model.config
- #
- # # Determine if we're doing text-only or image-based inference
- # is_text_only = input_data[0].get("file_path") is None
- #
- # if is_text_only:
- # # Text-only inference
- # messages = input_data[0]["text_input"]
- # response = self._generate_text_response(model, processor, config, messages)
- # results = [response]
- # else:
- # # Image-based inference
- # file_paths = self._extract_file_paths(input_data)
- # results = self._process_images(model, processor, config, file_paths, input_data)
- #
- # return results
- #
- # def _generate_text_response(self, model, processor, config, messages):
- # """
- # Generate a text response for text-only inputs.
- #
- # :param model: The loaded model
- # :param processor: The loaded processor
- # :param config: Model configuration
- # :param messages: Input messages
- # :return: Generated response
- # """
- # prompt = apply_chat_template(processor, config, messages)
- # response = generate(
- # model,
- # processor,
- # prompt,
- # max_tokens=4000,
- # temperature=0.0,
- # verbose=False
- # )
- # print("Inference completed successfully")
- # return response
- #
- # def _process_images(self, model, processor, config, file_paths, input_data):
- # """
- # Process images and generate responses for each.
- #
- # :param model: The loaded model
- # :param processor: The loaded processor
- # :param config: Model configuration
- # :param file_paths: List of image file paths
- # :param input_data: Original input data
- # :return: List of processed responses
- # """
- # results = []
- # for file_path in file_paths:
- # image, width, height = self.load_image_data(file_path)
- #
- # # Prepare messages based on model type
- # messages = self._prepare_messages(input_data, file_path)
- #
- # # Generate and process response
- # prompt = apply_chat_template(processor, config, messages)
- # response = generate(
- # model,
- # processor,
- # prompt,
- # image,
- # resize_shape=(width, height),
- # max_tokens=4000,
- # temperature=0.0,
- # verbose=False
- # )
- # results.append(self.process_response(response))
- # print(f"Inference completed successfully for: {file_path}")
- #
- # return results
- #
- # def _prepare_messages(self, input_data, file_path):
- # """
- # Prepare the appropriate messages based on the model type.
- #
- # :param input_data: Original input data
- # :param file_path: Current file path being processed
- # :return: Properly formatted messages
- # """
- # if "mistral" in self.model_name.lower():
- # return input_data[0]["text_input"]
- # else:
- # return [
- # {"role": "system", "content": "You are an expert at extracting structured text from image documents."},
- # {"role": "user", "content": input_data[0]["text_input"]},
- # ]
- #
- # @staticmethod
- # def _extract_file_paths(input_data):
- # """
- # Extract and resolve absolute file paths from input data.
- #
- # :param input_data: List of dictionaries containing image file paths.
- # :return: List of absolute file paths.
- # """
- # return [
- # os.path.abspath(file_path)
- # for data in input_data
- # for file_path in data.get("file_path", [])
- # ]