vlm4ocr 0.3.0__tar.gz → 0.3.1__tar.gz
This diff compares the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
- {vlm4ocr-0.3.0 → vlm4ocr-0.3.1}/PKG-INFO +1 -1
- {vlm4ocr-0.3.0 → vlm4ocr-0.3.1}/pyproject.toml +1 -1
- vlm4ocr-0.3.1/vlm4ocr/__init__.py +15 -0
- {vlm4ocr-0.3.0 → vlm4ocr-0.3.1}/vlm4ocr/cli.py +4 -13
- {vlm4ocr-0.3.0 → vlm4ocr-0.3.1}/vlm4ocr/data_types.py +17 -5
- {vlm4ocr-0.3.0 → vlm4ocr-0.3.1}/vlm4ocr/ocr_engines.py +30 -13
- vlm4ocr-0.3.1/vlm4ocr/vlm_engines.py +1163 -0
- vlm4ocr-0.3.0/vlm4ocr/__init__.py +0 -11
- vlm4ocr-0.3.0/vlm4ocr/vlm_engines.py +0 -570
- {vlm4ocr-0.3.0 → vlm4ocr-0.3.1}/README.md +0 -0
- {vlm4ocr-0.3.0 → vlm4ocr-0.3.1}/vlm4ocr/assets/default_prompt_templates/ocr_HTML_system_prompt.txt +0 -0
- {vlm4ocr-0.3.0 → vlm4ocr-0.3.1}/vlm4ocr/assets/default_prompt_templates/ocr_HTML_user_prompt.txt +0 -0
- {vlm4ocr-0.3.0 → vlm4ocr-0.3.1}/vlm4ocr/assets/default_prompt_templates/ocr_JSON_system_prompt.txt +0 -0
- {vlm4ocr-0.3.0 → vlm4ocr-0.3.1}/vlm4ocr/assets/default_prompt_templates/ocr_markdown_system_prompt.txt +0 -0
- {vlm4ocr-0.3.0 → vlm4ocr-0.3.1}/vlm4ocr/assets/default_prompt_templates/ocr_markdown_user_prompt.txt +0 -0
- {vlm4ocr-0.3.0 → vlm4ocr-0.3.1}/vlm4ocr/assets/default_prompt_templates/ocr_text_system_prompt.txt +0 -0
- {vlm4ocr-0.3.0 → vlm4ocr-0.3.1}/vlm4ocr/assets/default_prompt_templates/ocr_text_user_prompt.txt +0 -0
- {vlm4ocr-0.3.0 → vlm4ocr-0.3.1}/vlm4ocr/utils.py +0 -0

vlm4ocr-0.3.1/vlm4ocr/__init__.py (new file)

@@ -0,0 +1,15 @@
+from .ocr_engines import OCREngine
+from .vlm_engines import BasicVLMConfig, ReasoningVLMConfig, OpenAIReasoningVLMConfig, OllamaVLMEngine, OpenAICompatibleVLMEngine, VLLMVLMEngine, OpenRouterVLMEngine, OpenAIVLMEngine, AzureOpenAIVLMEngine
+
+__all__ = [
+    "BasicVLMConfig",
+    "ReasoningVLMConfig",
+    "OpenAIReasoningVLMConfig",
+    "OCREngine",
+    "OllamaVLMEngine",
+    "OpenAICompatibleVLMEngine",
+    "VLLMVLMEngine",
+    "OpenRouterVLMEngine",
+    "OpenAIVLMEngine",
+    "AzureOpenAIVLMEngine"
+]
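The new package-level `__init__.py` re-exports the OCR engine, the VLM engines, and the config classes, so 0.3.1 users can import everything from the `vlm4ocr` root. A minimal usage sketch; only the `OpenAIVLMEngine(model=..., api_key=..., config=...)` call is confirmed by the cli.py hunk further down, while the `OCREngine` arguments and the model name are placeholders:

```python
from vlm4ocr import OCREngine, OpenAIVLMEngine, BasicVLMConfig

# OpenAIVLMEngine keyword arguments mirror the call restored in cli.py below.
vlm_engine = OpenAIVLMEngine(model="gpt-4o", api_key="sk-...", config=BasicVLMConfig())

# Hypothetical OCREngine construction; its signature is not shown in this diff.
ocr_engine = OCREngine(vlm_engine=vlm_engine, output_mode="markdown")
```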

{vlm4ocr-0.3.0 → vlm4ocr-0.3.1}/vlm4ocr/cli.py

@@ -4,18 +4,9 @@ import sys
 import logging
 import asyncio
 import time
-
-
-try:
-    from .ocr_engines import OCREngine
-    from .vlm_engines import OpenAIVLMEngine, AzureOpenAIVLMEngine, OllamaVLMEngine, BasicVLMConfig
-    from .data_types import OCRResult
-except ImportError:
-    # Fallback for when the package is installed
-    from vlm4ocr.ocr_engines import OCREngine
-    from vlm4ocr.vlm_engines import OpenAIVLMEngine, AzureOpenAIVLMEngine, OllamaVLMEngine, BasicVLMConfig
-    from vlm4ocr.data_types import OCRResult
-
+from .ocr_engines import OCREngine
+from .vlm_engines import OpenAICompatibleVLMEngine, OpenAIVLMEngine, AzureOpenAIVLMEngine, OllamaVLMEngine, BasicVLMConfig
+from .data_types import OCRResult
 import tqdm.asyncio
 
 # --- Global logger setup (console) ---

@@ -208,7 +199,7 @@ def main():
         vlm_engine_instance = OpenAIVLMEngine(model=args.model, api_key=args.api_key, config=config)
     elif args.vlm_engine == "openai_compatible":
         if not args.base_url: parser.error("--base_url is required for openai_compatible.")
-        vlm_engine_instance =
+        vlm_engine_instance = OpenAICompatibleVLMEngine(model=args.model, api_key=args.api_key, base_url=args.base_url, config=config)
     elif args.vlm_engine == "azure_openai":
         if not args.azure_api_key: parser.error("--azure_api_key (or AZURE_OPENAI_API_KEY) is required.")
         if not args.azure_endpoint: parser.error("--azure_endpoint (or AZURE_OPENAI_ENDPOINT) is required.")
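For comparison, the restored `openai_compatible` branch maps directly onto the library call. A sketch of constructing the same engine from Python; only the keyword names come from this hunk, while the endpoint, key, and model name are placeholders:

```python
from vlm4ocr import OpenAICompatibleVLMEngine, BasicVLMConfig

vlm_engine = OpenAICompatibleVLMEngine(
    model="my-vision-model",              # placeholder model name
    api_key="EMPTY",                      # placeholder key for a local server
    base_url="http://localhost:8000/v1",  # placeholder OpenAI-compatible endpoint
    config=BasicVLMConfig(),              # assumed default construction
)
```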

{vlm4ocr-0.3.0 → vlm4ocr-0.3.1}/vlm4ocr/data_types.py

@@ -1,5 +1,5 @@
 import os
-from typing import List, Literal
+from typing import List, Dict, Literal
 from dataclasses import dataclass, field
 from vlm4ocr.utils import get_default_page_delimiter
 

@@ -24,6 +24,7 @@ class OCRResult:
     pages: List[dict] = field(default_factory=list)
     filename: str = field(init=False)
     status: str = field(init=False, default="processing")
+    messages_log: List[List[Dict[str,str]]] = field(default_factory=list)
 
     def __post_init__(self):
         """

@@ -67,10 +68,6 @@ class OCRResult:
         }
         self.pages.append(page)
 
-
-    def __len__(self):
-        return len(self.pages)
-
     def get_page(self, idx):
         if not isinstance(idx, int):
             raise ValueError("Index must be an integer")

@@ -78,6 +75,21 @@ class OCRResult:
             raise IndexError(f"Index out of range. The OCRResult has {len(self.pages)} pages, but index {idx} was requested.")
 
         return self.pages[idx]
+
+    def clear_messages_log(self):
+        self.messages_log = []
+
+    def add_messages_to_log(self, messages: List[Dict[str,str]]):
+        if not isinstance(messages, list):
+            raise ValueError("messages must be a list of dict")
+
+        self.messages_log.extend(messages)
+
+    def get_messages_log(self) -> List[List[Dict[str,str]]]:
+        return self.messages_log.copy()
+
+    def __len__(self):
+        return len(self.pages)
 
     def __iter__(self):
         return iter(self.pages)
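Together with the `messages_log` field added above, these helpers give `OCRResult` a simple append-only log of the chat messages used for each page. A short sketch of the accessors; the message contents are invented for illustration, and the constructor follows the `OCRResult(input_dir=..., output_mode=...)` call visible in ocr_engines.py below:

```python
from vlm4ocr.data_types import OCRResult

result = OCRResult(input_dir="scan.pdf", output_mode="markdown")  # constructor usage as in ocr_engines.py
result.add_messages_to_log([{"role": "user", "content": "OCR this page"}])  # extends the log in place
log = result.get_messages_log()   # returns a copy, so callers cannot mutate the internal list
result.clear_messages_log()       # resets the log to an empty list
```

Note that `len(result)` still counts pages rather than log entries; `__len__` was only moved below the new helpers, unchanged.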

{vlm4ocr-0.3.0 → vlm4ocr-0.3.1}/vlm4ocr/ocr_engines.py

@@ -6,7 +6,7 @@ from colorama import Fore, Style
 import json
 from vlm4ocr.utils import DataLoader, PDFDataLoader, TIFFDataLoader, ImageDataLoader, ImageProcessor, clean_markdown, extract_json, get_default_page_delimiter
 from vlm4ocr.data_types import OCRResult
-from vlm4ocr.vlm_engines import VLMEngine
+from vlm4ocr.vlm_engines import VLMEngine, MessagesLogger
 
 SUPPORTED_IMAGE_EXTS = ['.pdf', '.tif', '.tiff', '.png', '.jpg', '.jpeg', '.bmp', '.gif', '.webp']
 

@@ -126,7 +126,8 @@ class OCREngine:
                 stream=True
             )
             for chunk in response_stream:
-
+                if chunk["type"] == "response":
+                    yield {"type": "ocr_chunk", "data": chunk["data"]}
 
             if i < len(images) - 1:
                 yield {"type": "page_delimiter", "data": get_default_page_delimiter(self.output_mode)}

@@ -157,7 +158,8 @@ class OCREngine:
                 stream=True
             )
             for chunk in response_stream:
-
+                if chunk["type"] == "response":
+                    yield {"type": "ocr_chunk", "data": chunk["data"]}
 
 
     def sequential_ocr(self, file_paths: Union[str, Iterable[str]], rotate_correction:bool=False,
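Both streaming paths now inspect the chunk dictionaries coming back from the VLM engine and forward only `"response"` chunks, re-tagged as `"ocr_chunk"`. A consumer-side sketch of that contract; the generator passed in is hypothetical, and only the chunk shapes `{"type": "ocr_chunk", ...}` and `{"type": "page_delimiter", ...}` are taken from this diff:

```python
def collect_ocr_text(stream) -> str:
    """Concatenate streamed OCR chunks, keeping page delimiters where they arrive."""
    parts = []
    for chunk in stream:
        # The OCR engine yields dicts tagged either "ocr_chunk" or "page_delimiter".
        if chunk["type"] in ("ocr_chunk", "page_delimiter"):
            parts.append(chunk["data"])
    return "".join(parts)
```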

@@ -271,24 +273,32 @@ class OCREngine:
 
             try:
                 messages = self.vlm_engine.get_ocr_messages(self.system_prompt, self.user_prompt, image)
+                # Define a messages logger to capture messages
+                messages_logger = MessagesLogger()
+                # Generate response
                 response = self.vlm_engine.chat(
                     messages,
                     verbose=verbose,
-                    stream=False
+                    stream=False,
+                    messages_logger=messages_logger
                 )
+                ocr_text = response["response"]
                 # Clean the response if output mode is markdown
                 if self.output_mode == "markdown":
-
+                    ocr_text = clean_markdown(ocr_text)
 
                 # Parse the response if output mode is JSON
-
-                    json_list = extract_json(
+                elif self.output_mode == "JSON":
+                    json_list = extract_json(ocr_text)
                     # Serialize the JSON list to a string
-
+                    ocr_text = json.dumps(json_list, indent=4)
 
                 # Add the page to the OCR result
-                ocr_result.add_page(text=
+                ocr_result.add_page(text=ocr_text,
                                     image_processing_status=image_processing_status)
+
+                # Add messages log to the OCR result
+                ocr_result.add_messages_to_log(messages_logger.get_messages_log())
 
             except Exception as page_e:
                 ocr_result.status = "error"
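`MessagesLogger` itself lives in the rewritten `vlm_engines.py`, which this diff does not show line by line. From its use here, it is default-constructed, passed to `chat()` via a `messages_logger` keyword, and later drained with `get_messages_log()`. A stand-in sketch that satisfies the same interface; the `record` method name and the internal storage are assumptions:

```python
from typing import Dict, List

class MessagesLogger:
    """Minimal stand-in for vlm4ocr.vlm_engines.MessagesLogger (interface inferred from ocr_engines.py)."""

    def __init__(self):
        self._log: List[List[Dict[str, str]]] = []

    def record(self, messages: List[Dict[str, str]]) -> None:
        # Assumed hook the VLM engine would call with the messages it sends.
        self._log.append(messages)

    def get_messages_log(self) -> List[List[Dict[str, str]]]:
        # Confirmed by ocr_engines.py: the OCR engine copies this into OCRResult.
        return list(self._log)
```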

@@ -387,6 +397,7 @@ class OCREngine:
         filename = os.path.basename(file_path)
         file_ext = os.path.splitext(file_path)[1].lower()
         result = OCRResult(input_dir=file_path, output_mode=self.output_mode)
+        messages_logger = MessagesLogger()
         # check file extension
         if file_ext not in SUPPORTED_IMAGE_EXTS:
             result.status = "error"

@@ -416,7 +427,8 @@ class OCREngine:
                 data_loader=data_loader,
                 page_index=page_index,
                 rotate_correction=rotate_correction,
-                max_dimension_pixels=max_dimension_pixels
+                max_dimension_pixels=max_dimension_pixels,
+                messages_logger=messages_logger
             )
             page_processing_tasks.append(task)
 

@@ -428,14 +440,17 @@ class OCREngine:
             except Exception as e:
                 result.status = "error"
                 result.add_page(text=f"Error during OCR for {filename}: {str(e)}", image_processing_status={})
+                result.add_messages_to_log(messages_logger.get_messages_log())
                 return result
 
         # Set status to success if no errors occurred
         result.status = "success"
+        result.add_messages_to_log(messages_logger.get_messages_log())
         return result
 
     async def _ocr_page_with_semaphore(self, vlm_call_semaphore: asyncio.Semaphore, data_loader: DataLoader,
-                                       page_index:int, rotate_correction:bool=False, max_dimension_pixels:int=None
+                                       page_index:int, rotate_correction:bool=False, max_dimension_pixels:int=None,
+                                       messages_logger:MessagesLogger=None) -> Tuple[str, Dict[str, str]]:
         """
         This internal method takes a semaphore and OCR a single image/page using the VLM inference engine.
 
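The `messages_logger` is threaded through `_ocr_page_with_semaphore`, which still bounds concurrency with an `asyncio.Semaphore`. A generic, self-contained sketch of that concurrency pattern; the sleep stands in for the real VLM request:

```python
import asyncio

async def ocr_page_with_limit(semaphore: asyncio.Semaphore, page_index: int) -> str:
    # At most N pages are processed at once, where N is the semaphore's initial value.
    async with semaphore:
        await asyncio.sleep(0.1)  # stand-in for an awaited VLM request
        return f"page {page_index} done"

async def main() -> None:
    semaphore = asyncio.Semaphore(4)  # e.g. allow 4 concurrent VLM calls
    tasks = [ocr_page_with_limit(semaphore, i) for i in range(10)]
    print(await asyncio.gather(*tasks))

if __name__ == "__main__":
    asyncio.run(main())
```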

@@ -476,15 +491,17 @@ class OCREngine:
                 }
 
                 messages = self.vlm_engine.get_ocr_messages(self.system_prompt, self.user_prompt, image)
-
+                response = await self.vlm_engine.chat_async(
                     messages,
+                    messages_logger=messages_logger
                 )
+                ocr_text = response["response"]
                 # Clean the OCR text if output mode is markdown
                 if self.output_mode == "markdown":
                     ocr_text = clean_markdown(ocr_text)
 
                 # Parse the response if output mode is JSON
-
+                elif self.output_mode == "JSON":
                     json_list = extract_json(ocr_text)
                     # Serialize the JSON list to a string
                     ocr_text = json.dumps(json_list, indent=4)