sparrow-parse 1.0.5.tar.gz → 1.0.7.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. {sparrow-parse-1.0.5 → sparrow-parse-1.0.7}/PKG-INFO +4 -4
  2. {sparrow-parse-1.0.5 → sparrow-parse-1.0.7}/setup.py +1 -1
  3. sparrow-parse-1.0.7/sparrow_parse/__init__.py +1 -0
  4. {sparrow-parse-1.0.5 → sparrow-parse-1.0.7}/sparrow_parse/extractors/vllm_extractor.py +19 -18
  5. sparrow-parse-1.0.7/sparrow_parse/text_extraction.py +216 -0
  6. {sparrow-parse-1.0.5 → sparrow-parse-1.0.7}/sparrow_parse/vllm/huggingface_inference.py +1 -1
  7. {sparrow-parse-1.0.5 → sparrow-parse-1.0.7}/sparrow_parse/vllm/inference_base.py +1 -1
  8. {sparrow-parse-1.0.5 → sparrow-parse-1.0.7}/sparrow_parse/vllm/local_gpu_inference.py +1 -1
  9. {sparrow-parse-1.0.5 → sparrow-parse-1.0.7}/sparrow_parse/vllm/mlx_inference.py +89 -9
  10. {sparrow-parse-1.0.5 → sparrow-parse-1.0.7}/sparrow_parse.egg-info/PKG-INFO +4 -4
  11. {sparrow-parse-1.0.5 → sparrow-parse-1.0.7}/sparrow_parse.egg-info/requires.txt +3 -3
  12. sparrow-parse-1.0.5/sparrow_parse/__init__.py +0 -1
  13. sparrow-parse-1.0.5/sparrow_parse/text_extraction.py +0 -35
  14. {sparrow-parse-1.0.5 → sparrow-parse-1.0.7}/README.md +0 -0
  15. {sparrow-parse-1.0.5 → sparrow-parse-1.0.7}/setup.cfg +0 -0
  16. {sparrow-parse-1.0.5 → sparrow-parse-1.0.7}/sparrow_parse/__main__.py +0 -0
  17. {sparrow-parse-1.0.5 → sparrow-parse-1.0.7}/sparrow_parse/extractors/__init__.py +0 -0
  18. {sparrow-parse-1.0.5 → sparrow-parse-1.0.7}/sparrow_parse/helpers/__init__.py +0 -0
  19. {sparrow-parse-1.0.5 → sparrow-parse-1.0.7}/sparrow_parse/helpers/image_optimizer.py +0 -0
  20. {sparrow-parse-1.0.5 → sparrow-parse-1.0.7}/sparrow_parse/helpers/pdf_optimizer.py +0 -0
  21. {sparrow-parse-1.0.5 → sparrow-parse-1.0.7}/sparrow_parse/processors/__init__.py +0 -0
  22. {sparrow-parse-1.0.5 → sparrow-parse-1.0.7}/sparrow_parse/processors/table_structure_processor.py +0 -0
  23. {sparrow-parse-1.0.5 → sparrow-parse-1.0.7}/sparrow_parse/vllm/__init__.py +0 -0
  24. {sparrow-parse-1.0.5 → sparrow-parse-1.0.7}/sparrow_parse/vllm/inference_factory.py +0 -0
  25. {sparrow-parse-1.0.5 → sparrow-parse-1.0.7}/sparrow_parse.egg-info/SOURCES.txt +0 -0
  26. {sparrow-parse-1.0.5 → sparrow-parse-1.0.7}/sparrow_parse.egg-info/dependency_links.txt +0 -0
  27. {sparrow-parse-1.0.5 → sparrow-parse-1.0.7}/sparrow_parse.egg-info/entry_points.txt +0 -0
  28. {sparrow-parse-1.0.5 → sparrow-parse-1.0.7}/sparrow_parse.egg-info/top_level.txt +0 -0
{sparrow-parse-1.0.5 → sparrow-parse-1.0.7}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: sparrow-parse
- Version: 1.0.5
+ Version: 1.0.7
  Summary: Sparrow Parse is a Python package (part of Sparrow) for parsing and extracting information from documents.
  Home-page: https://github.com/katanaml/sparrow/tree/main/sparrow-data/parse
  Author: Andrej Baranovskij
@@ -20,11 +20,11 @@ Requires-Dist: torchvision>=0.22.0
  Requires-Dist: torch>=2.7.0
  Requires-Dist: sentence-transformers>=4.1.0
  Requires-Dist: numpy>=2.2.5
- Requires-Dist: pypdf>=5.4.0
+ Requires-Dist: pypdf>=5.5.0
  Requires-Dist: gradio_client>=1.7.2
  Requires-Dist: pdf2image>=1.17.0
- Requires-Dist: mlx>=0.25.1; sys_platform == "darwin" and platform_machine == "arm64"
- Requires-Dist: mlx-vlm==0.1.25; sys_platform == "darwin" and platform_machine == "arm64"
+ Requires-Dist: mlx>=0.25.2; sys_platform == "darwin" and platform_machine == "arm64"
+ Requires-Dist: mlx-vlm==0.1.26; sys_platform == "darwin" and platform_machine == "arm64"

  # Sparrow Parse

{sparrow-parse-1.0.5 → sparrow-parse-1.0.7}/setup.py
@@ -8,7 +8,7 @@ with open("requirements.txt", "r", encoding="utf-8") as fh:

  setup(
  name="sparrow-parse",
- version="1.0.5",
+ version="1.0.7",
  author="Andrej Baranovskij",
  author_email="andrejus.baranovskis@gmail.com",
  description="Sparrow Parse is a Python package (part of Sparrow) for parsing and extracting information from documents.",
sparrow-parse-1.0.7/sparrow_parse/__init__.py
@@ -0,0 +1 @@
+ __version__ = '1.0.7'
{sparrow-parse-1.0.5 → sparrow-parse-1.0.7}/sparrow_parse/extractors/vllm_extractor.py
@@ -14,13 +14,14 @@ class VLLMExtractor(object):
  pass

  def run_inference(self, model_inference_instance, input_data, tables_only=False,
- generic_query=False, crop_size=None, debug_dir=None, debug=False, mode=None):
+ generic_query=False, crop_size=None, apply_annotation=False, debug_dir=None, debug=False, mode=None):
  """
  Main entry point for processing input data using a model inference instance.
  Handles generic queries, PDFs, and table extraction.
  """
  if generic_query:
  input_data[0]["text_input"] = "retrieve document data. return response in JSON format"
+ apply_annotation=False

  if debug:
  print("Input data:", input_data)
@@ -37,12 +38,12 @@ class VLLMExtractor(object):
  # Document data extraction inference (file_path exists and is not None)
  file_path = input_data[0]["file_path"]
  if self.is_pdf(file_path):
- return self._process_pdf(model_inference_instance, input_data, tables_only, crop_size, debug, debug_dir, mode)
+ return self._process_pdf(model_inference_instance, input_data, tables_only, crop_size, apply_annotation, debug, debug_dir, mode)
  else:
- return self._process_non_pdf(model_inference_instance, input_data, tables_only, crop_size, debug, debug_dir)
+ return self._process_non_pdf(model_inference_instance, input_data, tables_only, crop_size, apply_annotation, debug, debug_dir)


- def _process_pdf(self, model_inference_instance, input_data, tables_only, crop_size, debug, debug_dir, mode):
+ def _process_pdf(self, model_inference_instance, input_data, tables_only, crop_size, apply_annotation, debug, debug_dir, mode):
  """
  Handles processing and inference for PDF files, including page splitting and optional table extraction.
  """
@@ -50,21 +51,21 @@ class VLLMExtractor(object):
  num_pages, output_files, temp_dir = pdf_optimizer.split_pdf_to_pages(input_data[0]["file_path"],
  debug_dir, convert_to_images=True)

- results = self._process_pages(model_inference_instance, output_files, input_data, tables_only, crop_size, debug, debug_dir)
+ results = self._process_pages(model_inference_instance, output_files, input_data, tables_only, crop_size, apply_annotation, debug, debug_dir)

  # Clean up temporary directory
  shutil.rmtree(temp_dir, ignore_errors=True)
  return results, num_pages


- def _process_non_pdf(self, model_inference_instance, input_data, tables_only, crop_size, debug, debug_dir):
+ def _process_non_pdf(self, model_inference_instance, input_data, tables_only, crop_size, apply_annotation, debug, debug_dir):
  """
  Handles processing and inference for non-PDF files, with optional table extraction.
  """
  file_path = input_data[0]["file_path"]

  if tables_only:
- return self._extract_tables(model_inference_instance, file_path, input_data, debug, debug_dir), 1
+ return self._extract_tables(model_inference_instance, file_path, input_data, apply_annotation, debug, debug_dir), 1
  else:
  temp_dir = tempfile.mkdtemp()

@@ -77,13 +78,13 @@ class VLLMExtractor(object):

  file_path = input_data[0]["file_path"]
  input_data[0]["file_path"] = [file_path]
- results = model_inference_instance.inference(input_data)
+ results = model_inference_instance.inference(input_data, apply_annotation)

  shutil.rmtree(temp_dir, ignore_errors=True)

  return results, 1

- def _process_pages(self, model_inference_instance, output_files, input_data, tables_only, crop_size, debug, debug_dir):
+ def _process_pages(self, model_inference_instance, output_files, input_data, tables_only, crop_size, apply_annotation, debug, debug_dir):
  """
  Processes individual pages (PDF split) and handles table extraction or inference.

@@ -93,6 +94,7 @@ class VLLMExtractor(object):
  input_data: Input data for inference.
  tables_only: Whether to only process tables.
  crop_size: Size for cropping image borders.
+ apply_annotation: Flag to apply annotations to the output.
  debug: Debug flag for logging.
  debug_dir: Directory for saving debug information.

@@ -106,9 +108,7 @@ class VLLMExtractor(object):
  print(f"Processing {len(output_files)} pages for table extraction.")
  # Process each page individually for table extraction
  for i, file_path in enumerate(output_files):
- tables_result = self._extract_tables(
- model_inference_instance, file_path, input_data, debug, debug_dir, page_index=i
- )
+ tables_result = self._extract_tables( model_inference_instance, file_path, input_data, apply_annotation, debug, debug_dir, page_index=i)
  # Since _extract_tables returns a list with one JSON string, unpack it
  results_array.extend(tables_result) # Unpack the single JSON string
  else:
@@ -141,7 +141,7 @@ class VLLMExtractor(object):
  input_data[0]["file_path"] = output_files

  # Process all files at once
- results = model_inference_instance.inference(input_data)
+ results = model_inference_instance.inference(input_data, apply_annotation)
  results_array.extend(results)

  # Clean up temporary directory
@@ -150,7 +150,7 @@ class VLLMExtractor(object):
  return results_array


- def _extract_tables(self, model_inference_instance, file_path, input_data, debug, debug_dir, page_index=None):
+ def _extract_tables(self, model_inference_instance, file_path, input_data, apply_annotation, debug, debug_dir, page_index=None):
  """
  Detects and processes tables from an input file.
  """
@@ -175,7 +175,7 @@ class VLLMExtractor(object):
  table.save(output_filename, "JPEG")

  input_data[0]["file_path"] = [output_filename]
- result = self._run_model_inference(model_inference_instance, input_data)
+ result = self._run_model_inference(model_inference_instance, input_data, apply_annotation)
  results_array.append(result)

  shutil.rmtree(temp_dir, ignore_errors=True)
@@ -191,11 +191,11 @@ class VLLMExtractor(object):


  @staticmethod
- def _run_model_inference(model_inference_instance, input_data):
+ def _run_model_inference(model_inference_instance, input_data, apply_annotation):
  """
  Runs model inference and handles JSON decoding.
  """
- result = model_inference_instance.inference(input_data)[0]
+ result = model_inference_instance.inference(input_data, apply_annotation)[0]
  try:
  return json.loads(result) if isinstance(result, str) else result
  except json.JSONDecodeError:
@@ -230,7 +230,7 @@ if __name__ == "__main__":
  # input_data = [
  # {
  # "file_path": "sparrow_parse/images/bonds_table.png",
- # "text_input": "retrieve all data. return response in JSON format"
+ # "text_input": "retrieve [{\"instrument_name\":\"str\", \"valuation\":\"int\"}]. return response in JSON format"
  # }
  # ]
  #
@@ -245,6 +245,7 @@ if __name__ == "__main__":
  # results_array, num_pages = extractor.run_inference(model_inference_instance, input_data, tables_only=False,
  # generic_query=False,
  # crop_size=0,
+ # apply_annotation=False,
  # debug_dir="/Users/andrejb/Work/katana-git/sparrow/sparrow-ml/llm/data/",
  # debug=True,
  # mode=None)
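Editor's note, not part of the diff: a minimal sketch of where the new apply_annotation flag plugs into the public API. It assumes the InferenceFactory / get_inference_instance usage from the package README and an MLX backend; adjust names to your environment.

from sparrow_parse.vllm.inference_factory import InferenceFactory
from sparrow_parse.extractors.vllm_extractor import VLLMExtractor

# Assumed factory configuration (mirrors the README); any ModelInference backend works here.
config = {"method": "mlx", "model_name": "mlx-community/Qwen2.5-VL-72B-Instruct-4bit"}
factory = InferenceFactory(config)
model_inference_instance = factory.get_inference_instance()

extractor = VLLMExtractor()
input_data = [
    {
        "file_path": "sparrow_parse/images/bonds_table.png",
        "text_input": "retrieve [{\"instrument_name\":\"str\", \"valuation\":\"int\"}]. return response in JSON format"
    }
]

# apply_annotation=True asks the Qwen path to return per-field value/bbox/confidence objects.
results_array, num_pages = extractor.run_inference(
    model_inference_instance, input_data,
    tables_only=False, generic_query=False, crop_size=0,
    apply_annotation=True, debug=False, mode=None)
print(results_array, num_pages)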
sparrow-parse-1.0.7/sparrow_parse/text_extraction.py
@@ -0,0 +1,216 @@
+ from mlx_vlm import load, apply_chat_template, generate
+ from mlx_vlm.utils import load_image
+ from PIL import ImageDraw, ImageFont
+ import json
+
+
+ # Load model and processor
+ vl_model, vl_processor = load("mlx-community/Mistral-Small-3.1-24B-Instruct-2503-8bit")
+ # vl_model, vl_processor = load("mlx-community/Qwen2.5-VL-72B-Instruct-4bit")
+ vl_config = vl_model.config
+
+ image = load_image("images/bonds_table.png")
+
+ # Qwen
+ # messages = [
+ # {"role": "system", "content": "You are an expert at extracting text from images. Format your response in JSON."},
+ # {"role": "user", "content": "retrieve [{\"instrument_name\":\"str\", \"valuation\":\"int\"}]. return response in JSON format"}
+ # ]
+ # Qwen with bbox
+ # messages = [
+ # {"role": "system", "content": "You are an expert at extracting text from images. For each item in the table, provide separate bounding boxes for each field. All coordinates should be in pixels relative to the original image. Format your response in JSON."},
+ # {"role": "user", "content": "retrieve [{\"instrument_name\":{\"value\":\"str\", \"bbox\":[\"float\", \"float\", \"float\", \"float\"], \"confidence\":\"float\"}, \"valuation\":{\"value\":\"int\", \"bbox\":[\"float\", \"float\", \"float\", \"float\"], \"confidence\":\"float\"}}]. return response in JSON format"}
+ # ]
+ # Qwen with bbox, get all data
+ # messages = [
+ # {"role": "system", "content": "You are an expert at extracting text from images. For each item in the table, provide separate bounding boxes for each field. All coordinates should be in pixels relative to the original image. Format your response in JSON."},
+ # {"role": "user", "content": "retrieve all data. return response in JSON format. For each identified field or data element, include: 1) a descriptive field name as the object key, 2) a nested object with 'value' containing the extracted content, 'bbox' array with [x_min, y_min, x_max, y_max] coordinates in pixels, and 'confidence' score between 0-1. Example structure: [{\"field_name\":{\"value\":\"extracted value\", \"bbox\":[100, 200, 300, 250], \"confidence\":0.95}}]"}
+ # ]
+
+ # Mistral
+ # message = "retrieve all data. return response in JSON format"
+ message = "retrieve [{\"instrument_name\":\"str\", \"valuation\":\"int\"}]. return response in JSON format"
+
+ # Qwen
+ # prompt = apply_chat_template(vl_processor, vl_config, messages)
+ # Mistral
+ prompt = apply_chat_template(vl_processor, vl_config, message)
+
+ # Generate text
+ vl_output, _ = generate(
+ vl_model,
+ vl_processor,
+ prompt,
+ image,
+ max_tokens=4000,
+ temperature=0,
+ verbose=False
+ )
+
+ print(vl_output)
+
+
+ # Comment out below code if non Qwen model is used
+
+ # # Convert to a format we can draw on
+ # img_draw = image.copy()
+ # draw = ImageDraw.Draw(img_draw)
+ #
+ # # Parse the JSON result
+ # results = json.loads(vl_output.strip('```json\n').strip('```'))
+ #
+ # # Predefined solid colors that are highly visible
+ # solid_colors = [
+ # (180, 30, 40), # Dark red
+ # (0, 100, 140), # Dark blue
+ # (30, 120, 40), # Dark green
+ # (140, 60, 160), # Purple
+ # (200, 100, 0), # Orange
+ # (100, 80, 0), # Brown
+ # (0, 100, 100), # Teal
+ # (120, 40, 100) # Magenta
+ # ]
+ #
+ # # Determine unique field keys across all items to assign consistent colors
+ # unique_fields = set()
+ # for item in results:
+ # unique_fields.update(item.keys())
+ #
+ # # Map each unique field to a color
+ # field_color_map = {}
+ # for i, field in enumerate(sorted(unique_fields)):
+ # field_color_map[field] = solid_colors[i % len(solid_colors)]
+ #
+ # # Load font with larger size
+ # font_size = 20
+ # try:
+ # font = ImageFont.truetype("arial.ttf", font_size)
+ # except IOError:
+ # try:
+ # font = ImageFont.truetype("DejaVuSans.ttf", font_size)
+ # except IOError:
+ # try:
+ # font = ImageFont.truetype("Helvetica.ttf", font_size)
+ # except IOError:
+ # font = ImageFont.load_default()
+ #
+ #
+ # # Helper function to measure text width
+ # def get_text_dimensions(text, font):
+ # try:
+ # # Method for newer Pillow versions
+ # left, top, right, bottom = draw.textbbox((0, 0), text, font=font)
+ # return right - left, bottom - top
+ # except AttributeError:
+ # try:
+ # # Alternative method
+ # left, top, right, bottom = font.getbbox(text)
+ # return right - left, bottom - top
+ # except AttributeError:
+ # # Fallback approximation
+ # return len(text) * (font_size // 2), font_size + 2
+ #
+ #
+ # # Draw bounding boxes for each item
+ # for item in results:
+ # # Process each field
+ # for field_name, field_data in item.items():
+ # # Check if this field has the expected structure
+ # if isinstance(field_data, dict) and "bbox" in field_data and "value" in field_data:
+ # bbox = field_data["bbox"]
+ # value = field_data["value"]
+ # confidence = field_data.get("confidence", "N/A")
+ #
+ # # Check if coordinates need to be scaled (normalized 0-1 values)
+ # if all(isinstance(coord, (int, float)) for coord in bbox):
+ # if max(bbox) <= 1.0: # Normalized coordinates
+ # width, height = image.size
+ # bbox = [
+ # bbox[0] * width,
+ # bbox[1] * height,
+ # bbox[2] * width,
+ # bbox[3] * height
+ # ]
+ #
+ # # Get color from the mapping we created
+ # color = field_color_map[field_name]
+ #
+ # # Make sure bbox coordinates are integers
+ # bbox = [int(coord) for coord in bbox]
+ #
+ # # Calculate the bbox width
+ # bbox_width = bbox[2] - bbox[0]
+ #
+ # # Draw rectangle with appropriate thickness
+ # border_thickness = 3
+ # draw.rectangle(
+ # [(bbox[0], bbox[1]), (bbox[2], bbox[3])],
+ # outline=color,
+ # width=border_thickness
+ # )
+ #
+ # # Format the value and confidence
+ # value_str = str(value)
+ # confidence_str = f" [{confidence:.2f}]" if isinstance(confidence, (int, float)) else ""
+ # prefix = f"{field_name}: "
+ #
+ # # First, try with full text without truncation
+ # full_label = prefix + value_str + confidence_str
+ # full_width, text_height = get_text_dimensions(full_label, font)
+ #
+ # # Compare with a reasonable maximum display width
+ # min_display_width = 300 # Reasonable minimum width to display text
+ # max_display_width = max(bbox_width * 1.5, min_display_width)
+ #
+ # # Only truncate if the full text exceeds our maximum display width
+ # if full_width > max_display_width:
+ # # Calculate the space available for the value
+ # prefix_width, _ = get_text_dimensions(prefix, font)
+ # confidence_width, _ = get_text_dimensions(confidence_str, font)
+ # available_value_width = max_display_width - prefix_width - confidence_width
+ #
+ # # Truncate the value to fit
+ # truncated_value = value_str
+ # for i in range(len(value_str) - 1, 3, -1):
+ # truncated_value = value_str[:i] + "..."
+ # temp_width, _ = get_text_dimensions(truncated_value, font)
+ # if temp_width <= available_value_width:
+ # break
+ #
+ # label = prefix + truncated_value + confidence_str
+ # text_width, _ = get_text_dimensions(label, font)
+ # else:
+ # # No truncation needed
+ # label = full_label
+ # text_width = full_width
+ #
+ # # Position for text (above the bounding box)
+ # padding = 6
+ # text_position = (bbox[0], bbox[1] - text_height - (padding * 2))
+ #
+ # # Ensure text doesn't go off the top of the image
+ # if text_position[1] < padding:
+ # # If too close to top, position below the box instead
+ # text_position = (bbox[0], bbox[3] + padding)
+ #
+ # # Add a background rectangle with better contrast
+ # draw.rectangle(
+ # [(text_position[0] - padding, text_position[1] - padding),
+ # (text_position[0] + text_width + padding, text_position[1] + text_height + padding)],
+ # fill=(255, 255, 255, 240),
+ # outline=color,
+ # width=2
+ # )
+ #
+ # # Draw the text
+ # draw.text(
+ # text_position,
+ # label,
+ # fill=color,
+ # font=font
+ # )
+ #
+ # # Save the annotated image
+ # output_path = "images/bonds_table_annotated.png"
+ # img_draw.save(output_path)
+ # print(f"Annotated image saved to {output_path}")
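Editor's note, not part of the diff: the commented drawing block above expects the model to answer with one object per table row, each field wrapped as value/bbox/confidence; boxes given in normalized 0-1 coordinates are rescaled to pixels before drawing. Illustrative shape only (the values below are made up):

example_results = [
    {
        "instrument_name": {"value": "Example Bond 2.5% 2030", "bbox": [74, 132, 310, 158], "confidence": 0.95},
        "valuation": {"value": 125000, "bbox": [356, 132, 438, 158], "confidence": 0.91},
    }
]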
{sparrow-parse-1.0.5 → sparrow-parse-1.0.7}/sparrow_parse/vllm/huggingface_inference.py
@@ -26,7 +26,7 @@ class HuggingFaceInference(ModelInference):
  return output_text


- def inference(self, input_data, mode=None):
+ def inference(self, input_data, apply_annotation=False, mode=None):
  if mode == "static":
  simple_json = self.get_simple_json()
  return [simple_json]
{sparrow-parse-1.0.5 → sparrow-parse-1.0.7}/sparrow_parse/vllm/inference_base.py
@@ -4,7 +4,7 @@ import json

  class ModelInference(ABC):
  @abstractmethod
- def inference(self, input_data, mode=None):
+ def inference(self, input_data, apply_annotation=False, mode=None):
  """This method should be implemented by subclasses."""
  pass

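Editor's note, not part of the diff: with the broadened abstract signature, every backend must now accept apply_annotation even if it ignores it. A minimal hypothetical subclass for illustration:

from sparrow_parse.vllm.inference_base import ModelInference

class EchoInference(ModelInference):
    # Hypothetical backend, shown only to illustrate the updated signature.
    def inference(self, input_data, apply_annotation=False, mode=None):
        # apply_annotation is accepted for interface compatibility and ignored here.
        return [item["text_input"] for item in input_data]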
{sparrow-parse-1.0.5 → sparrow-parse-1.0.7}/sparrow_parse/vllm/local_gpu_inference.py
@@ -8,7 +8,7 @@ class LocalGPUInference(ModelInference):
  self.device = device
  self.model.to(self.device)

- def inference(self, input_data, mode=None):
+ def inference(self, input_data, apply_annotation=False, mode=None):
  self.model.eval() # Set the model to evaluation mode
  with torch.no_grad(): # No need to calculate gradients
  input_tensor = torch.tensor(input_data).to(self.device)
{sparrow-parse-1.0.5 → sparrow-parse-1.0.7}/sparrow_parse/vllm/mlx_inference.py
@@ -3,7 +3,7 @@ from mlx_vlm.prompt_utils import apply_chat_template
  from mlx_vlm.utils import load_image
  from sparrow_parse.vllm.inference_base import ModelInference
  import os
- import json
+ import json, re
  from rich import print


@@ -98,11 +98,12 @@ class MLXInference(ModelInference):
  return image, width, height


- def inference(self, input_data, mode=None):
+ def inference(self, input_data, apply_annotation=False, mode=None):
  """
  Perform inference on input data using the specified model.

  :param input_data: A list of dictionaries containing image file paths and text inputs.
+ :param apply_annotation: Optional flag to apply annotations to the output.
  :param mode: Optional mode for inference ("static" for simple JSON output).
  :return: List of processed model responses.
  """
@@ -125,7 +126,7 @@ class MLXInference(ModelInference):
  else:
  # Image-based inference
  file_paths = self._extract_file_paths(input_data)
- results = self._process_images(model, processor, config, file_paths, input_data)
+ results = self._process_images(model, processor, config, file_paths, input_data, apply_annotation)

  return results

@@ -151,7 +152,7 @@ class MLXInference(ModelInference):
  print("Inference completed successfully")
  return response

- def _process_images(self, model, processor, config, file_paths, input_data):
+ def _process_images(self, model, processor, config, file_paths, input_data, apply_annotation):
  """
  Process images and generate responses for each.

@@ -160,6 +161,7 @@ class MLXInference(ModelInference):
  :param config: Model configuration
  :param file_paths: List of image file paths
  :param input_data: Original input data
+ :param apply_annotation: Flag to apply annotations
  :return: List of processed responses
  """
  results = []
@@ -167,11 +169,11 @@ class MLXInference(ModelInference):
  image, width, height = self.load_image_data(file_path)

  # Prepare messages based on model type
- messages = self._prepare_messages(input_data, file_path)
+ messages = self._prepare_messages(input_data, apply_annotation)

  # Generate and process response
  prompt = apply_chat_template(processor, config, messages)
- response = generate(
+ response, _ = generate(
  model,
  processor,
  prompt,
@@ -186,21 +188,99 @@ class MLXInference(ModelInference):

  return results

- def _prepare_messages(self, input_data, file_path):
+ def transform_query_with_bbox(self, text_input):
+ """
+ Transform JSON schema in text_input to include value, bbox, and confidence.
+ Works with formats like: "retrieve field1, field2. return response in JSON format,
+ by strictly following this JSON schema: [{...}]."
+
+ Args:
+ text_input (str): The input text containing a JSON schema
+
+ Returns:
+ str: Text with transformed JSON including value, bbox, and confidence
+ """
+
+ schema_pattern = r'JSON schema:\s*(\[.*?\]|\{.*?\})'
+ schema_match = re.search(schema_pattern, text_input, re.DOTALL)
+
+ if not schema_match:
+ return text_input # Return original if pattern not found
+
+ # Extract the schema part and its position
+ schema_str = schema_match.group(1).strip()
+ schema_start = schema_match.start(1)
+ schema_end = schema_match.end(1)
+
+ # Parse and transform the JSON
+ try:
+ # Handle single quotes if needed
+ schema_str = schema_str.replace("'", '"')
+
+ json_obj = json.loads(schema_str)
+ transformed_json = self.transform_query_structure(json_obj)
+ transformed_json_str = json.dumps(transformed_json)
+
+ # Rebuild the text by replacing just the schema portion
+ result = text_input[:schema_start] + transformed_json_str + text_input[schema_end:]
+
+ return result
+ except json.JSONDecodeError as e:
+ print(f"Error parsing JSON schema: {e}")
+ return text_input # Return original if parsing fails
+
+
+ def transform_query_structure(self, json_obj):
+ """
+ Transform each field in the JSON structure to include value, bbox, and confidence.
+ Handles both array and object formats recursively.
+ """
+ if isinstance(json_obj, list):
+ # Handle array format
+ return [self.transform_query_structure(item) for item in json_obj]
+ elif isinstance(json_obj, dict):
+ # Handle object format
+ result = {}
+ for key, value in json_obj.items():
+ if isinstance(value, (dict, list)):
+ # Recursively transform nested objects or arrays
+ result[key] = self.transform_query_structure(value)
+ else:
+ # Transform simple value to object with value, bbox, and confidence
+ result[key] = {
+ "value": value,
+ "bbox": ["float", "float", "float", "float"],
+ "confidence": "float"
+ }
+ return result
+ else:
+ # For primitive values, no transformation needed
+ return json_obj
+
+
+ def _prepare_messages(self, input_data, apply_annotation):
  """
  Prepare the appropriate messages based on the model type.

  :param input_data: Original input data
- :param file_path: Current file path being processed
+ :param apply_annotation: Flag to apply annotations
  :return: Properly formatted messages
  """
  if "mistral" in self.model_name.lower():
  return input_data[0]["text_input"]
- else:
+ elif "qwen" in self.model_name.lower():
+ if apply_annotation:
+ system_prompt = {"role": "system", "content": "You are an expert at extracting text from images. "
+ "For each item in the table, provide separate bounding boxes for each field. "
+ "All coordinates should be in pixels relative to the original image. Format your response in JSON."}
+ user_prompt = {"role": "user", "content": self.transform_query_with_bbox(input_data[0]["text_input"])}
+ return [system_prompt, user_prompt]
  return [
  {"role": "system", "content": "You are an expert at extracting structured text from image documents."},
  {"role": "user", "content": input_data[0]["text_input"]},
  ]
+ else:
+ raise ValueError("Unsupported model type. Please use either Mistral or Qwen.")

  @staticmethod
  def _extract_file_paths(input_data):
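Editor's note, not part of the diff: a hand-written example of what the new transform_query_with_bbox does to a Sparrow-style query (field names are illustrative):

query = (
    "retrieve instrument_name, valuation. return response in JSON format, "
    "by strictly following this JSON schema: "
    '[{"instrument_name": "str", "valuation": "int"}].'
)
# MLXInference.transform_query_with_bbox(query) rewrites only the schema portion, yielding roughly:
# retrieve instrument_name, valuation. return response in JSON format, by strictly following this
# JSON schema: [{"instrument_name": {"value": "str", "bbox": ["float", "float", "float", "float"],
# "confidence": "float"}, "valuation": {"value": "int", "bbox": ["float", "float", "float", "float"],
# "confidence": "float"}}].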
{sparrow-parse-1.0.5 → sparrow-parse-1.0.7}/sparrow_parse.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: sparrow-parse
- Version: 1.0.5
+ Version: 1.0.7
  Summary: Sparrow Parse is a Python package (part of Sparrow) for parsing and extracting information from documents.
  Home-page: https://github.com/katanaml/sparrow/tree/main/sparrow-data/parse
  Author: Andrej Baranovskij
@@ -20,11 +20,11 @@ Requires-Dist: torchvision>=0.22.0
  Requires-Dist: torch>=2.7.0
  Requires-Dist: sentence-transformers>=4.1.0
  Requires-Dist: numpy>=2.2.5
- Requires-Dist: pypdf>=5.4.0
+ Requires-Dist: pypdf>=5.5.0
  Requires-Dist: gradio_client>=1.7.2
  Requires-Dist: pdf2image>=1.17.0
- Requires-Dist: mlx>=0.25.1; sys_platform == "darwin" and platform_machine == "arm64"
- Requires-Dist: mlx-vlm==0.1.25; sys_platform == "darwin" and platform_machine == "arm64"
+ Requires-Dist: mlx>=0.25.2; sys_platform == "darwin" and platform_machine == "arm64"
+ Requires-Dist: mlx-vlm==0.1.26; sys_platform == "darwin" and platform_machine == "arm64"

  # Sparrow Parse

{sparrow-parse-1.0.5 → sparrow-parse-1.0.7}/sparrow_parse.egg-info/requires.txt
@@ -4,10 +4,10 @@ torchvision>=0.22.0
  torch>=2.7.0
  sentence-transformers>=4.1.0
  numpy>=2.2.5
- pypdf>=5.4.0
+ pypdf>=5.5.0
  gradio_client>=1.7.2
  pdf2image>=1.17.0

  [:sys_platform == "darwin" and platform_machine == "arm64"]
- mlx>=0.25.1
- mlx-vlm==0.1.25
+ mlx>=0.25.2
+ mlx-vlm==0.1.26
sparrow-parse-1.0.5/sparrow_parse/__init__.py
@@ -1 +0,0 @@
- __version__ = '1.0.5'
sparrow-parse-1.0.5/sparrow_parse/text_extraction.py
@@ -1,35 +0,0 @@
- from mlx_vlm import load, apply_chat_template, generate
- from mlx_vlm.utils import load_image
-
-
- # Load model and processor
- # vl_model, vl_processor = load("mlx-community/Mistral-Small-3.1-24B-Instruct-2503-8bit")
- vl_model, vl_processor = load("mlx-community/Qwen2.5-VL-7B-Instruct-8bit")
- vl_config = vl_model.config
-
- image = load_image("images/bonds_table.png")
-
- messages = [
- {"role": "system", "content": "You are an expert at extracting text from images. Format your response in json."},
- {"role": "user", "content": "retrieve [{\"instrument_name\":\"str\", \"valuation\":\"int\"}]. return response in JSON format"}
- ]
-
- # message = "retrieve all data. return response in JSON format"
- # message = "retrieve [{\"instrument_name\":\"str\", \"valuation\":\"int\"}]. return response in JSON format"
-
- # Apply chat template
- prompt = apply_chat_template(vl_processor, vl_config, messages)
- # prompt = apply_chat_template(vl_processor, vl_config, message)
-
- # Generate text
- vl_output = generate(
- vl_model,
- vl_processor,
- prompt,
- image,
- max_tokens=1000,
- temperature=0,
- verbose=False
- )
-
- print(vl_output)