vlm4ocr 0.2.0__tar.gz → 0.3.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: vlm4ocr
3
- Version: 0.2.0
3
+ Version: 0.3.1
4
4
  Summary: Python package and Web App for OCR with vision language models.
5
5
  License: MIT
6
6
  Author: Enshuo (David) Hsu
@@ -10,6 +10,8 @@ Classifier: Programming Language :: Python :: 3
10
10
  Classifier: Programming Language :: Python :: 3.11
11
11
  Classifier: Programming Language :: Python :: 3.12
12
12
  Provides-Extra: tesseract
13
+ Requires-Dist: colorama (>=0.4.4)
14
+ Requires-Dist: json-repair (>=0.30.0)
13
15
  Requires-Dist: pdf2image (>=1.16.0)
14
16
  Requires-Dist: pillow (>=10.0.0)
15
17
  Requires-Dist: pytesseract (>=0.3.13) ; extra == "tesseract"
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "vlm4ocr"
3
- version = "0.2.0"
3
+ version = "0.3.1"
4
4
  description = "Python package and Web App for OCR with vision language models."
5
5
  authors = ["Enshuo (David) Hsu"]
6
6
  license = "MIT"
@@ -15,7 +15,9 @@ exclude = [
15
15
  [tool.poetry.dependencies]
16
16
  python = "^3.11"
17
17
  pdf2image = ">=1.16.0"
18
+ colorama = ">=0.4.4"
18
19
  pillow = ">=10.0.0"
20
+ json-repair = ">=0.30.0"
19
21
  pytesseract = { version = ">=0.3.13", optional = true }
20
22
 
21
23
  [tool.poetry.scripts]
@@ -0,0 +1,15 @@
1
+ from .ocr_engines import OCREngine
2
+ from .vlm_engines import BasicVLMConfig, ReasoningVLMConfig, OpenAIReasoningVLMConfig, OllamaVLMEngine, OpenAICompatibleVLMEngine, VLLMVLMEngine, OpenRouterVLMEngine, OpenAIVLMEngine, AzureOpenAIVLMEngine
3
+
4
+ __all__ = [
5
+ "BasicVLMConfig",
6
+ "ReasoningVLMConfig",
7
+ "OpenAIReasoningVLMConfig",
8
+ "OCREngine",
9
+ "OllamaVLMEngine",
10
+ "OpenAICompatibleVLMEngine",
11
+ "VLLMVLMEngine",
12
+ "OpenRouterVLMEngine",
13
+ "OpenAIVLMEngine",
14
+ "AzureOpenAIVLMEngine"
15
+ ]
@@ -0,0 +1 @@
1
+ You are a helpful assistant that can convert scanned documents into JSON format. Your output is accurate and well-formatted, starting with ```json and ending with ```. You will only output the JSON text without any additional explanations or comments. The JSON should include all text, tables, and lists with appropriate keys and values. You will ignore images, icons, or anything that can not be converted into text.
@@ -4,18 +4,9 @@ import sys
4
4
  import logging
5
5
  import asyncio
6
6
  import time
7
-
8
- # Attempt to import from the local package structure
9
- try:
10
- from .ocr_engines import OCREngine
11
- from .vlm_engines import OpenAIVLMEngine, AzureOpenAIVLMEngine, OllamaVLMEngine, BasicVLMConfig
12
- from .data_types import OCRResult
13
- except ImportError:
14
- # Fallback for when the package is installed
15
- from vlm4ocr.ocr_engines import OCREngine
16
- from vlm4ocr.vlm_engines import OpenAIVLMEngine, AzureOpenAIVLMEngine, OllamaVLMEngine, BasicVLMConfig
17
- from vlm4ocr.data_types import OCRResult
18
-
7
+ from .ocr_engines import OCREngine
8
+ from .vlm_engines import OpenAICompatibleVLMEngine, OpenAIVLMEngine, AzureOpenAIVLMEngine, OllamaVLMEngine, BasicVLMConfig
9
+ from .data_types import OCRResult
19
10
  import tqdm.asyncio
20
11
 
21
12
  # --- Global logger setup (console) ---
@@ -208,7 +199,7 @@ def main():
208
199
  vlm_engine_instance = OpenAIVLMEngine(model=args.model, api_key=args.api_key, config=config)
209
200
  elif args.vlm_engine == "openai_compatible":
210
201
  if not args.base_url: parser.error("--base_url is required for openai_compatible.")
211
- vlm_engine_instance = OpenAIVLMEngine(model=args.model, api_key=args.api_key, base_url=args.base_url, config=config)
202
+ vlm_engine_instance = OpenAICompatibleVLMEngine(model=args.model, api_key=args.api_key, base_url=args.base_url, config=config)
212
203
  elif args.vlm_engine == "azure_openai":
213
204
  if not args.azure_api_key: parser.error("--azure_api_key (or AZURE_OPENAI_API_KEY) is required.")
214
205
  if not args.azure_endpoint: parser.error("--azure_endpoint (or AZURE_OPENAI_ENDPOINT) is required.")
@@ -1,9 +1,9 @@
1
1
  import os
2
- from typing import List, Literal
2
+ from typing import List, Dict, Literal
3
3
  from dataclasses import dataclass, field
4
4
  from vlm4ocr.utils import get_default_page_delimiter
5
5
 
6
- OutputMode = Literal["markdown", "HTML", "text"]
6
+ OutputMode = Literal["markdown", "HTML", "text", "JSON"]
7
7
 
8
8
  @dataclass
9
9
  class OCRResult:
@@ -24,6 +24,7 @@ class OCRResult:
24
24
  pages: List[dict] = field(default_factory=list)
25
25
  filename: str = field(init=False)
26
26
  status: str = field(init=False, default="processing")
27
+ messages_log: List[List[Dict[str,str]]] = field(default_factory=list)
27
28
 
28
29
  def __post_init__(self):
29
30
  """
@@ -33,8 +34,8 @@ class OCRResult:
33
34
  self.filename = os.path.basename(self.input_dir)
34
35
 
35
36
  # output_mode validation
36
- if self.output_mode not in ["markdown", "HTML", "text"]:
37
- raise ValueError("output_mode must be 'markdown', 'HTML', or 'text'")
37
+ if self.output_mode not in ["markdown", "HTML", "text", "JSON"]:
38
+ raise ValueError("output_mode must be 'markdown', 'HTML', 'text', or 'JSON'")
38
39
 
39
40
  # pages validation
40
41
  if not isinstance(self.pages, list):
@@ -67,10 +68,6 @@ class OCRResult:
67
68
  }
68
69
  self.pages.append(page)
69
70
 
70
-
71
- def __len__(self):
72
- return len(self.pages)
73
-
74
71
  def get_page(self, idx):
75
72
  if not isinstance(idx, int):
76
73
  raise ValueError("Index must be an integer")
@@ -78,6 +75,21 @@ class OCRResult:
78
75
  raise IndexError(f"Index out of range. The OCRResult has {len(self.pages)} pages, but index {idx} was requested.")
79
76
 
80
77
  return self.pages[idx]
78
+
79
+ def clear_messages_log(self):
80
+ self.messages_log = []
81
+
82
+ def add_messages_to_log(self, messages: List[Dict[str,str]]):
83
+ if not isinstance(messages, list):
84
+ raise ValueError("messages must be a list of dict")
85
+
86
+ self.messages_log.extend(messages)
87
+
88
+ def get_messages_log(self) -> List[List[Dict[str,str]]]:
89
+ return self.messages_log.copy()
90
+
91
+ def __len__(self):
92
+ return len(self.pages)
81
93
 
82
94
  def __iter__(self):
83
95
  return iter(self.pages)
@@ -3,10 +3,10 @@ from typing import Tuple, List, Dict, Union, Generator, AsyncGenerator, Iterable
3
3
  import importlib
4
4
  import asyncio
5
5
  from colorama import Fore, Style
6
- from PIL import Image
7
- from vlm4ocr.utils import DataLoader, PDFDataLoader, TIFFDataLoader, ImageDataLoader, ImageProcessor, clean_markdown, get_default_page_delimiter
6
+ import json
7
+ from vlm4ocr.utils import DataLoader, PDFDataLoader, TIFFDataLoader, ImageDataLoader, ImageProcessor, clean_markdown, extract_json, get_default_page_delimiter
8
8
  from vlm4ocr.data_types import OCRResult
9
- from vlm4ocr.vlm_engines import VLMEngine
9
+ from vlm4ocr.vlm_engines import VLMEngine, MessagesLogger
10
10
 
11
11
  SUPPORTED_IMAGE_EXTS = ['.pdf', '.tif', '.tiff', '.png', '.jpg', '.jpeg', '.bmp', '.gif', '.webp']
12
12
 
@@ -21,7 +21,7 @@ class OCREngine:
21
21
  inference_engine : InferenceEngine
22
22
  The inference engine to use for OCR.
23
23
  output_mode : str, Optional
24
- The output format. Must be 'markdown', 'HTML', or 'text'.
24
+ The output format. Must be 'markdown', 'HTML', 'text', or 'JSON'.
25
25
  system_prompt : str, Optional
26
26
  Custom system prompt. We recommend use a default system prompt by leaving this blank.
27
27
  user_prompt : str, Optional
@@ -33,8 +33,8 @@ class OCREngine:
33
33
  self.vlm_engine = vlm_engine
34
34
 
35
35
  # Check output mode
36
- if output_mode not in ["markdown", "HTML", "text"]:
37
- raise ValueError("output_mode must be 'markdown', 'HTML', or 'text'")
36
+ if output_mode not in ["markdown", "HTML", "text", "JSON"]:
37
+ raise ValueError("output_mode must be 'markdown', 'HTML', 'text', or 'JSON'.")
38
38
  self.output_mode = output_mode
39
39
 
40
40
  # System prompt
@@ -49,6 +49,9 @@ class OCREngine:
49
49
  if isinstance(user_prompt, str) and user_prompt:
50
50
  self.user_prompt = user_prompt
51
51
  else:
52
+ if self.output_mode == "JSON":
53
+ raise ValueError("user_prompt must be provided when output_mode is 'JSON' to define the JSON structure.")
54
+
52
55
  prompt_template_path = importlib.resources.files('vlm4ocr.assets.default_prompt_templates').joinpath(f'ocr_{self.output_mode}_user_prompt.txt')
53
56
  with prompt_template_path.open('r', encoding='utf-8') as f:
54
57
  self.user_prompt = f.read()
@@ -123,7 +126,8 @@ class OCREngine:
123
126
  stream=True
124
127
  )
125
128
  for chunk in response_stream:
126
- yield {"type": "ocr_chunk", "data": chunk}
129
+ if chunk["type"] == "response":
130
+ yield {"type": "ocr_chunk", "data": chunk["data"]}
127
131
 
128
132
  if i < len(images) - 1:
129
133
  yield {"type": "page_delimiter", "data": get_default_page_delimiter(self.output_mode)}
@@ -154,7 +158,8 @@ class OCREngine:
154
158
  stream=True
155
159
  )
156
160
  for chunk in response_stream:
157
- yield {"type": "ocr_chunk", "data": chunk}
161
+ if chunk["type"] == "response":
162
+ yield {"type": "ocr_chunk", "data": chunk["data"]}
158
163
 
159
164
 
160
165
  def sequential_ocr(self, file_paths: Union[str, Iterable[str]], rotate_correction:bool=False,
@@ -268,18 +273,32 @@ class OCREngine:
268
273
 
269
274
  try:
270
275
  messages = self.vlm_engine.get_ocr_messages(self.system_prompt, self.user_prompt, image)
276
+ # Define a messages logger to capture messages
277
+ messages_logger = MessagesLogger()
278
+ # Generate response
271
279
  response = self.vlm_engine.chat(
272
280
  messages,
273
281
  verbose=verbose,
274
- stream=False
282
+ stream=False,
283
+ messages_logger=messages_logger
275
284
  )
285
+ ocr_text = response["response"]
276
286
  # Clean the response if output mode is markdown
277
287
  if self.output_mode == "markdown":
278
- response = clean_markdown(response)
288
+ ocr_text = clean_markdown(ocr_text)
289
+
290
+ # Parse the response if output mode is JSON
291
+ elif self.output_mode == "JSON":
292
+ json_list = extract_json(ocr_text)
293
+ # Serialize the JSON list to a string
294
+ ocr_text = json.dumps(json_list, indent=4)
279
295
 
280
296
  # Add the page to the OCR result
281
- ocr_result.add_page(text=response,
297
+ ocr_result.add_page(text=ocr_text,
282
298
  image_processing_status=image_processing_status)
299
+
300
+ # Add messages log to the OCR result
301
+ ocr_result.add_messages_to_log(messages_logger.get_messages_log())
283
302
 
284
303
  except Exception as page_e:
285
304
  ocr_result.status = "error"
@@ -378,6 +397,7 @@ class OCREngine:
378
397
  filename = os.path.basename(file_path)
379
398
  file_ext = os.path.splitext(file_path)[1].lower()
380
399
  result = OCRResult(input_dir=file_path, output_mode=self.output_mode)
400
+ messages_logger = MessagesLogger()
381
401
  # check file extension
382
402
  if file_ext not in SUPPORTED_IMAGE_EXTS:
383
403
  result.status = "error"
@@ -407,7 +427,8 @@ class OCREngine:
407
427
  data_loader=data_loader,
408
428
  page_index=page_index,
409
429
  rotate_correction=rotate_correction,
410
- max_dimension_pixels=max_dimension_pixels
430
+ max_dimension_pixels=max_dimension_pixels,
431
+ messages_logger=messages_logger
411
432
  )
412
433
  page_processing_tasks.append(task)
413
434
 
@@ -419,14 +440,17 @@ class OCREngine:
419
440
  except Exception as e:
420
441
  result.status = "error"
421
442
  result.add_page(text=f"Error during OCR for {filename}: {str(e)}", image_processing_status={})
443
+ result.add_messages_to_log(messages_logger.get_messages_log())
422
444
  return result
423
445
 
424
446
  # Set status to success if no errors occurred
425
447
  result.status = "success"
448
+ result.add_messages_to_log(messages_logger.get_messages_log())
426
449
  return result
427
450
 
428
451
  async def _ocr_page_with_semaphore(self, vlm_call_semaphore: asyncio.Semaphore, data_loader: DataLoader,
429
- page_index:int, rotate_correction:bool=False, max_dimension_pixels:int=None) -> Tuple[str, Dict[str, str]]:
452
+ page_index:int, rotate_correction:bool=False, max_dimension_pixels:int=None,
453
+ messages_logger:MessagesLogger=None) -> Tuple[str, Dict[str, str]]:
430
454
  """
431
455
  This internal method takes a semaphore and OCR a single image/page using the VLM inference engine.
432
456
 
@@ -467,9 +491,19 @@ class OCREngine:
467
491
  }
468
492
 
469
493
  messages = self.vlm_engine.get_ocr_messages(self.system_prompt, self.user_prompt, image)
470
- ocr_text = await self.vlm_engine.chat_async(
494
+ response = await self.vlm_engine.chat_async(
471
495
  messages,
496
+ messages_logger=messages_logger
472
497
  )
498
+ ocr_text = response["response"]
499
+ # Clean the OCR text if output mode is markdown
473
500
  if self.output_mode == "markdown":
474
501
  ocr_text = clean_markdown(ocr_text)
502
+
503
+ # Parse the response if output mode is JSON
504
+ elif self.output_mode == "JSON":
505
+ json_list = extract_json(ocr_text)
506
+ # Serialize the JSON list to a string
507
+ ocr_text = json.dumps(json_list, indent=4)
508
+
475
509
  return ocr_text, image_processing_status
@@ -2,11 +2,14 @@ import abc
2
2
  import os
3
3
  import io
4
4
  import base64
5
- from typing import Union, List, Tuple
5
+ from typing import Dict, List, Tuple
6
+ import json
7
+ import json_repair
6
8
  import importlib.util
7
9
  from pdf2image import convert_from_path, pdfinfo_from_path
8
10
  from PIL import Image
9
11
  import asyncio
12
+ import warnings
10
13
 
11
14
 
12
15
  class DataLoader(abc.ABC):
@@ -229,6 +232,55 @@ def clean_markdown(text:str) -> str:
229
232
  cleaned_text = text.replace("```markdown", "").replace("```", "")
230
233
  return cleaned_text
231
234
 
235
+ def _find_dict_strings( text: str) -> List[str]:
236
+ """
237
+ Extracts balanced JSON-like dictionaries from a string, even if nested.
238
+
239
+ Parameters:
240
+ -----------
241
+ text : str
242
+ the input text containing JSON-like structures.
243
+
244
+ Returns : List[str]
245
+ A list of valid JSON-like strings representing dictionaries.
246
+ """
247
+ open_brace = 0
248
+ start = -1
249
+ json_objects = []
250
+
251
+ for i, char in enumerate(text):
252
+ if char == '{':
253
+ if open_brace == 0:
254
+ # start of a new JSON object
255
+ start = i
256
+ open_brace += 1
257
+ elif char == '}':
258
+ open_brace -= 1
259
+ if open_brace == 0 and start != -1:
260
+ json_objects.append(text[start:i + 1])
261
+ start = -1
262
+
263
+ return json_objects
264
+
265
+ def extract_json(gen_text:str) -> List[Dict[str, str]]:
266
+ """
267
+ This method inputs a generated text and output a JSON of information tuples
268
+ """
269
+ out = []
270
+ dict_str_list = _find_dict_strings(gen_text)
271
+ for dict_str in dict_str_list:
272
+ try:
273
+ dict_obj = json.loads(dict_str)
274
+ out.append(dict_obj)
275
+ except json.JSONDecodeError:
276
+ dict_obj = json_repair.repair_json(dict_str, skip_json_loads=True, return_objects=True)
277
+ if dict_obj:
278
+ warnings.warn(f'JSONDecodeError detected, fixed with repair_json:\n{dict_str}', RuntimeWarning)
279
+ out.append(dict_obj)
280
+ else:
281
+ warnings.warn(f'JSONDecodeError could not be fixed:\n{dict_str}', RuntimeWarning)
282
+ return out
283
+
232
284
  def get_default_page_delimiter(output_mode:str) -> str:
233
285
  """
234
286
  Returns the default page delimiter based on the environment variable.
@@ -243,8 +295,8 @@ def get_default_page_delimiter(output_mode:str) -> str:
243
295
  str
244
296
  The default page delimiter.
245
297
  """
246
- if output_mode not in ["markdown", "HTML", "text"]:
247
- raise ValueError("output_mode must be 'markdown', 'HTML', or 'text'")
298
+ if output_mode not in ["markdown", "HTML", "text", "JSON"]:
299
+ raise ValueError("output_mode must be 'markdown', 'HTML', 'text', or 'JSON'")
248
300
 
249
301
  if output_mode == "markdown":
250
302
  return "\n\n---\n\n"
@@ -252,6 +304,8 @@ def get_default_page_delimiter(output_mode:str) -> str:
252
304
  return "<br><br>"
253
305
  elif output_mode == "text":
254
306
  return "\n\n---\n\n"
307
+ elif output_mode == "JSON":
308
+ return "\n\n---\n\n"
255
309
 
256
310
 
257
311
  class ImageProcessor: