vlm4ocr 0.0.1__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vlm4ocr/__init__.py +3 -1
- vlm4ocr/assets/default_prompt_templates/ocr_HTML_system_prompt.txt +1 -0
- vlm4ocr/assets/default_prompt_templates/ocr_HTML_user_prompt.txt +1 -0
- vlm4ocr/assets/default_prompt_templates/ocr_text_user_prompt.txt +1 -0
- vlm4ocr/cli.py +367 -0
- vlm4ocr/data_types.py +109 -0
- vlm4ocr/ocr_engines.py +359 -195
- vlm4ocr/utils.py +328 -18
- vlm4ocr/vlm_engines.py +317 -191
- {vlm4ocr-0.0.1.dist-info → vlm4ocr-0.2.0.dist-info}/METADATA +4 -2
- vlm4ocr-0.2.0.dist-info/RECORD +16 -0
- vlm4ocr-0.2.0.dist-info/entry_points.txt +3 -0
- vlm4ocr-0.0.1.dist-info/RECORD +0 -10
- /vlm4ocr/assets/default_prompt_templates/{ocr_user_prompt.txt → ocr_markdown_user_prompt.txt} +0 -0
- {vlm4ocr-0.0.1.dist-info → vlm4ocr-0.2.0.dist-info}/WHEEL +0 -0
vlm4ocr/__init__.py
CHANGED
|
@@ -1,7 +1,9 @@
|
|
|
1
1
|
from .ocr_engines import OCREngine
|
|
2
|
-
from .vlm_engines import OllamaVLMEngine, OpenAIVLMEngine, AzureOpenAIVLMEngine
|
|
2
|
+
from .vlm_engines import BasicVLMConfig, OpenAIReasoningVLMConfig, OllamaVLMEngine, OpenAIVLMEngine, AzureOpenAIVLMEngine
|
|
3
3
|
|
|
4
4
|
__all__ = [
|
|
5
|
+
"BasicVLMConfig",
|
|
6
|
+
"OpenAIReasoningVLMConfig",
|
|
5
7
|
"OCREngine",
|
|
6
8
|
"OllamaVLMEngine",
|
|
7
9
|
"OpenAIVLMEngine",
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
You are a helpful assistant that can convert scanned documents into functional HTML. Your output is accurate and well-formatted, starting with <html> and ending with </html>. You will only output the HTML without any additional explanations or comments. The HTML should include all text, tables, and lists with appropriate tags (e.g., "table", "tbody", "tr", "li") and styles (e.g., "font-family", "color", "font-size") that represent the text contents in the input. You will ignore images, icons, or anything that can not be converted into text.
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
Convert contents in this image into HTML.
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
Convert contents in this image into plain text.
|
vlm4ocr/cli.py
ADDED
|
@@ -0,0 +1,367 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
import os
|
|
3
|
+
import sys
|
|
4
|
+
import logging
|
|
5
|
+
import asyncio
|
|
6
|
+
import time
|
|
7
|
+
|
|
8
|
+
# Attempt to import from the local package structure
|
|
9
|
+
try:
|
|
10
|
+
from .ocr_engines import OCREngine
|
|
11
|
+
from .vlm_engines import OpenAIVLMEngine, AzureOpenAIVLMEngine, OllamaVLMEngine, BasicVLMConfig
|
|
12
|
+
from .data_types import OCRResult
|
|
13
|
+
except ImportError:
|
|
14
|
+
# Fallback for when the package is installed
|
|
15
|
+
from vlm4ocr.ocr_engines import OCREngine
|
|
16
|
+
from vlm4ocr.vlm_engines import OpenAIVLMEngine, AzureOpenAIVLMEngine, OllamaVLMEngine, BasicVLMConfig
|
|
17
|
+
from vlm4ocr.data_types import OCRResult
|
|
18
|
+
|
|
19
|
+
import tqdm.asyncio
|
|
20
|
+
|
|
21
|
+
# --- Global logger setup (console) ---
# Baseline console logging config; per-run file logging is added later by
# setup_file_logger when --log is passed.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s: %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
# Module-level logger used by every function in this CLI.
logger = logging.getLogger("vlm4ocr_cli")

# File extensions (lowercase, with dot) the CLI accepts as OCR input.
SUPPORTED_IMAGE_EXTS_CLI = ['.pdf', '.tif', '.tiff', '.png', '.jpg', '.jpeg', '.bmp', '.gif', '.webp']
# Maps each --output_mode choice to the extension of the written result file.
OUTPUT_EXTENSIONS = {'markdown': '.md', 'HTML':'.html', 'text':'.txt'}
|
|
31
|
+
|
|
32
|
+
def get_output_path_for_ocr_result(input_file_path, specified_output_path_arg, output_mode, num_total_inputs, base_output_dir_if_no_specific_path):
    """
    Resolve the destination path for one OCR result file.

    The generated filename is "<original basename>_ocr<ext>", where <ext>
    comes from OUTPUT_EXTENSIONS for the given output_mode (falling back
    to ".txt"). Example: input "abc.pdf" with output_mode "markdown"
    produces "abc.pdf_ocr.md".
    """
    extension = OUTPUT_EXTENSIONS.get(output_mode, '.txt')
    final_output_filename = f"{os.path.basename(input_file_path)}_ocr{extension}"

    if not specified_output_path_arg:
        # No --output_path given: write into the default base directory.
        return os.path.join(base_output_dir_if_no_specific_path, final_output_filename)

    # Scenario 1: multiple input files and --output_path is a directory.
    if num_total_inputs > 1 and os.path.isdir(specified_output_path_arg):
        return os.path.join(specified_output_path_arg, final_output_filename)

    # Scenario 2: single input file; --output_path may be a directory or a
    # full file path.
    if num_total_inputs == 1:
        if os.path.isdir(specified_output_path_arg):
            return os.path.join(specified_output_path_arg, final_output_filename)
        return specified_output_path_arg

    # Scenario 3: multiple inputs but --output_path is not a directory
    # (validated before this function) or another edge case — fall back to
    # the base output directory.
    return os.path.join(base_output_dir_if_no_specific_path, final_output_filename)
|
|
61
|
+
|
|
62
|
+
def setup_file_logger(log_dir, timestamp_str, debug_mode):
    """Attach a timestamped FileHandler to the CLI logger.

    Creates "vlm4ocr_<timestamp>.log" inside log_dir (append mode) and
    logs at DEBUG level when debug_mode is truthy, otherwise INFO.
    """
    log_file_path = os.path.join(log_dir, f"vlm4ocr_{timestamp_str}.log")

    handler = logging.FileHandler(log_file_path, mode='a')
    handler.setFormatter(
        logging.Formatter(
            '%(asctime)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S',
        )
    )
    handler.setLevel(logging.DEBUG if debug_mode else logging.INFO)

    logger.addHandler(handler)
    logger.info(f"Logging to file: {log_file_path}")
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def main():
    """Entry point for the vlm4ocr CLI.

    Parses command-line arguments, configures console/file logging,
    initializes the selected VLM engine and the OCR engine, collects the
    input file list, then runs concurrent OCR and writes each result to
    its resolved output path. Exits with status 1 on any fatal error.
    """
    parser = argparse.ArgumentParser(
        description="VLM4OCR: Perform OCR on images, PDFs, or TIFF files using Vision Language Models. Processing is concurrent by default.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    io_group = parser.add_argument_group("Input/Output Options")
    io_group.add_argument("--input_path", required=True, help="Path to a single input file or a directory of files.")
    io_group.add_argument("--output_mode", choices=["markdown", "HTML", "text"], default="markdown", help="Output format.")
    io_group.add_argument("--output_path", help="Optional: Path to save OCR results. If input_path is a directory of multiple files, this should be an output directory. If input is a single file, this can be a full file path or a directory. If not provided, results are saved to the current working directory (or a sub-directory for logs if --log is used).")
    io_group.add_argument("--skip_existing", action="store_true", help="Skip processing files that already have OCR results in the output directory.")

    image_processing_group = parser.add_argument_group("Image Processing Parameters")
    image_processing_group.add_argument(
        "--rotate_correction",
        action="store_true",
        help="Enable automatic rotation correction for input images. This requires Tesseract OCR to be installed and configured correctly.")
    image_processing_group.add_argument(
        "--max_dimension_pixels",
        type=int,
        default=4000,
        help="Maximum dimension (width or height) in pixels for input images. Images larger than this will be resized to fit within this limit while maintaining aspect ratio."
    )

    vlm_engine_group = parser.add_argument_group("VLM Engine Options")
    vlm_engine_group.add_argument("--vlm_engine", choices=["openai", "azure_openai", "ollama", "openai_compatible"], required=True, help="VLM engine.")
    vlm_engine_group.add_argument("--model", required=True, help="Model identifier for the VLM engine.")
    vlm_engine_group.add_argument("--max_new_tokens", type=int, default=4096, help="Max new tokens for VLM.")
    vlm_engine_group.add_argument("--temperature", type=float, default=0.0, help="Sampling temperature.")

    openai_group = parser.add_argument_group("OpenAI & OpenAI-Compatible Options")
    openai_group.add_argument("--api_key", default=os.environ.get("OPENAI_API_KEY"), help="API key.")
    openai_group.add_argument("--base_url", help="Base URL for OpenAI-compatible services.")

    azure_group = parser.add_argument_group("Azure OpenAI Options")
    azure_group.add_argument("--azure_api_key", default=os.environ.get("AZURE_OPENAI_API_KEY"), help="Azure API key.")
    azure_group.add_argument("--azure_endpoint", default=os.environ.get("AZURE_OPENAI_ENDPOINT"), help="Azure endpoint URL.")
    azure_group.add_argument("--azure_api_version", default=os.environ.get("AZURE_OPENAI_API_VERSION"), help="Azure API version.")

    ollama_group = parser.add_argument_group("Ollama Options")
    ollama_group.add_argument("--ollama_host", default="http://localhost:11434", help="Ollama host URL.")
    ollama_group.add_argument("--ollama_num_ctx", type=int, default=4096, help="Context length for Ollama.")
    ollama_group.add_argument("--ollama_keep_alive", type=int, default=300, help="Ollama keep_alive seconds.")

    ocr_params_group = parser.add_argument_group("OCR Engine Parameters")
    ocr_params_group.add_argument("--user_prompt", help="Custom user prompt.")

    processing_group = parser.add_argument_group("Processing Options")
    processing_group.add_argument(
        "--concurrent_batch_size",
        type=int,
        default=4,
        help="Number of images/pages to process concurrently. Set to 1 for sequential processing of VLM calls."
    )
    processing_group.add_argument(
        "--max_file_load",
        type=int,
        default=-1,
        help="Number of input files to pre-load. Set to -1 for automatic config: 2 * concurrent_batch_size."
    )
    # --verbose flag was removed by user in previous version provided
    processing_group.add_argument("--log", action="store_true", help="Enable writing logs to a timestamped file in the output directory.")
    processing_group.add_argument("--debug", action="store_true", help="Enable debug level logging for console (and file if --log is active).")

    args = parser.parse_args()

    # Timestamp shared by this run's log file name.
    current_timestamp_str = time.strftime("%Y%m%d_%H%M%S")

    # --- Configure Logger Level based on args ---
    if args.debug:
        logger.setLevel(logging.DEBUG)
        # Set root logger to DEBUG only if our specific logger is DEBUG, to avoid overly verbose library logs unless intended.
        if logger.getEffectiveLevel() <= logging.DEBUG:
            logging.getLogger().setLevel(logging.DEBUG)
        logger.debug("Debug mode enabled for console.")
    else:
        logger.setLevel(logging.INFO)  # Default for our CLI's own messages
        logging.getLogger().setLevel(logging.WARNING)  # Keep external libraries quieter by default

    if args.concurrent_batch_size < 1:
        parser.error("--concurrent_batch_size must be 1 or greater.")

    # --- Determine Effective Output Directory (for logs and default OCR outputs) ---
    effective_output_dir = os.getcwd()  # Default if no --output_path

    # Preliminary check to see if multiple files will be processed
    _is_multi_file_scenario = False
    if os.path.isdir(args.input_path):
        _temp_files_list = [f for f in os.listdir(args.input_path) if os.path.isfile(os.path.join(args.input_path, f)) and os.path.splitext(f)[1].lower() in SUPPORTED_IMAGE_EXTS_CLI]
        if len(_temp_files_list) > 1:
            _is_multi_file_scenario = True

    if args.output_path:
        if _is_multi_file_scenario:  # Input is a dir with multiple files
            if os.path.exists(args.output_path) and not os.path.isdir(args.output_path):
                logger.critical(f"Output path '{args.output_path}' must be a directory when processing multiple files. It currently points to a file.")
                sys.exit(1)
            effective_output_dir = args.output_path  # --output_path is the directory for outputs and logs
        else:  # Single input file scenario
            # If args.output_path is a directory, use it.
            # If args.output_path is a file path, use its directory for logs.
            if os.path.isdir(args.output_path):
                effective_output_dir = args.output_path
            else:  # Assumed to be a file path
                dir_name = os.path.dirname(args.output_path)
                if dir_name:  # If output_path includes a directory
                    effective_output_dir = dir_name
                else:  # output_path is just a filename, logs go to CWD
                    effective_output_dir = os.getcwd()

    if not os.path.exists(effective_output_dir):
        logger.info(f"Creating output directory: {effective_output_dir}")
        os.makedirs(effective_output_dir, exist_ok=True)

    # --- Setup File Logger (if --log is specified) ---
    if args.log:
        setup_file_logger(effective_output_dir, current_timestamp_str, args.debug)

    logger.debug(f"Parsed arguments: {args}")

    # --- Initialize VLM Engine ---
    vlm_engine_instance = None
    try:
        logger.info(f"Initializing VLM engine: {args.vlm_engine} with model: {args.model}")
        config = BasicVLMConfig(
            max_new_tokens=args.max_new_tokens,
            temperature=args.temperature
        )
        if args.vlm_engine == "openai":
            if not args.api_key: parser.error("--api_key (or OPENAI_API_KEY) is required for OpenAI.")
            vlm_engine_instance = OpenAIVLMEngine(model=args.model, api_key=args.api_key, config=config)
        elif args.vlm_engine == "openai_compatible":
            if not args.base_url: parser.error("--base_url is required for openai_compatible.")
            vlm_engine_instance = OpenAIVLMEngine(model=args.model, api_key=args.api_key, base_url=args.base_url, config=config)
        elif args.vlm_engine == "azure_openai":
            if not args.azure_api_key: parser.error("--azure_api_key (or AZURE_OPENAI_API_KEY) is required.")
            if not args.azure_endpoint: parser.error("--azure_endpoint (or AZURE_OPENAI_ENDPOINT) is required.")
            if not args.azure_api_version: parser.error("--azure_api_version (or AZURE_OPENAI_API_VERSION) is required.")
            vlm_engine_instance = AzureOpenAIVLMEngine(model=args.model, api_key=args.azure_api_key, azure_endpoint=args.azure_endpoint, api_version=args.azure_api_version, config=config)
        elif args.vlm_engine == "ollama":
            vlm_engine_instance = OllamaVLMEngine(model_name=args.model, host=args.ollama_host, num_ctx=args.ollama_num_ctx, keep_alive=args.ollama_keep_alive, config=config)
        logger.info("VLM engine initialized successfully.")
    except ImportError as e:
        logger.error(f"Failed to import library for {args.vlm_engine}: {e}. Install dependencies.")
        sys.exit(1)
    except Exception as e:
        logger.error(f"Error initializing VLM engine '{args.vlm_engine}': {e}")
        if args.debug: logger.exception("Traceback:")
        sys.exit(1)

    # --- Initialize OCR Engine ---
    try:
        logger.info(f"Initializing OCR engine with output mode: {args.output_mode}")
        ocr_engine_instance = OCREngine(vlm_engine=vlm_engine_instance, output_mode=args.output_mode, user_prompt=args.user_prompt)
        logger.info("OCR engine initialized successfully.")
    except Exception as e:
        logger.error(f"Error initializing OCR engine: {e}")
        if args.debug: logger.exception("Traceback:")
        sys.exit(1)

    # --- Prepare input file paths (actual list) ---
    input_files_to_process = []
    if os.path.isdir(args.input_path):
        logger.info(f"Input is directory: {args.input_path}. Scanning for files...")
        for item in os.listdir(args.input_path):
            item_path = os.path.join(args.input_path, item)
            if os.path.isfile(item_path) and os.path.splitext(item)[1].lower() in SUPPORTED_IMAGE_EXTS_CLI:
                input_files_to_process.append(item_path)
        if not input_files_to_process:
            logger.error(f"No supported files found in directory: {args.input_path}")
            sys.exit(1)
        logger.info(f"Found {len(input_files_to_process)} files to process.")
    elif os.path.isfile(args.input_path):
        if os.path.splitext(args.input_path)[1].lower() not in SUPPORTED_IMAGE_EXTS_CLI:
            logger.error(f"Input file '{args.input_path}' is not supported. Supported: {SUPPORTED_IMAGE_EXTS_CLI}")
            sys.exit(1)
        input_files_to_process = [args.input_path]
        logger.info(f"Processing single input file: {args.input_path}")
    else:
        logger.error(f"Input path not valid: {args.input_path}")
        sys.exit(1)

    # --- Skip existing files if --skip_existing is used ---
    if args.skip_existing:
        logger.info("Checking for existing OCR results in output path to skip...")
        # Check each input file against the expected output file
        existing_files = os.listdir(effective_output_dir)
        filtered_input_files_to_process = []
        for input_file in input_files_to_process:
            expected_output_name = get_output_path_for_ocr_result(input_file, args.output_path, args.output_mode, len(input_files_to_process), effective_output_dir)
            if os.path.basename(expected_output_name) not in existing_files:
                filtered_input_files_to_process.append(input_file)

        original_num_files = len(input_files_to_process)
        after_filter_num_files = len(filtered_input_files_to_process)
        input_files_to_process = filtered_input_files_to_process
        logger.info(f"Dropped {original_num_files - after_filter_num_files} existing files. Number of input files to process after filtering: {len(input_files_to_process)}")

    else:
        logger.info("All input files will be processed (`--skip_existing=False`).")
    # This re-evaluation is useful if the initial _is_multi_file_scenario was just for log dir
    num_actual_files = len(input_files_to_process)

    # --- Run OCR ---
    try:
        logger.info(f"Processing with concurrent_batch_size: {args.concurrent_batch_size}.")

        async def process_and_write_concurrently():
            # Async generator yields one OCRResult per input file as it completes.
            ocr_task_generator = ocr_engine_instance.concurrent_ocr(
                file_paths=input_files_to_process,
                rotate_correction=args.rotate_correction,
                max_dimension_pixels=args.max_dimension_pixels,
                concurrent_batch_size=args.concurrent_batch_size,
                max_file_load=args.max_file_load if args.max_file_load > 0 else None
            )

            # Progress bar always attempted if tqdm is available and files exist,
            # console verbosity controlled by logger level.
            show_progress_bar = (num_actual_files > 0)

            iterator_wrapper = tqdm.asyncio.tqdm(
                ocr_task_generator,
                total=num_actual_files,
                desc="Processing files",
                unit="file",
                disable=not show_progress_bar  # disable if no files, or can remove this disable if tqdm handles total=0
            )

            async for result_object in iterator_wrapper:
                if not isinstance(result_object, OCRResult):
                    logger.warning(f"Received unexpected data type: {type(result_object)}")
                    continue

                input_file_path_from_result = result_object.input_dir
                # For get_output_path_for_ocr_result, effective_output_dir is the base if args.output_path isn't specific enough
                current_ocr_output_file_path = get_output_path_for_ocr_result(
                    input_file_path_from_result, args.output_path, args.output_mode,
                    num_actual_files, effective_output_dir
                )

                if result_object.status == "error":
                    error_message = result_object.get_page(0) if len(result_object) > 0 else 'Unknown error during OCR'
                    logger.error(f"OCR failed for {result_object.filename}: {error_message}")
                else:
                    try:
                        content_to_write = result_object.to_string()
                        with open(current_ocr_output_file_path, "w", encoding="utf-8") as f:
                            f.write(content_to_write)
                        # Log less verbosely to console if progress bar is active
                        if not show_progress_bar or logger.getEffectiveLevel() <= logging.DEBUG:
                            logger.info(f"OCR result for '{input_file_path_from_result}' saved to: {current_ocr_output_file_path}")
                    except Exception as e:
                        logger.error(f"Error writing output for '{input_file_path_from_result}' to '{current_ocr_output_file_path}': {e}")

            # Force the bar to 100% before closing in case counting lagged.
            if hasattr(iterator_wrapper, 'close') and isinstance(iterator_wrapper, tqdm.asyncio.tqdm):
                if iterator_wrapper.n < iterator_wrapper.total:
                    iterator_wrapper.n = iterator_wrapper.total
                    iterator_wrapper.refresh()
                iterator_wrapper.close()

        try:
            asyncio.run(process_and_write_concurrently())
        except RuntimeError as e:
            # asyncio.run() fails inside an already-running loop (e.g. Jupyter);
            # fall back to the existing loop when it is not currently running.
            if "asyncio.run() cannot be called from a running event loop" in str(e):
                logger.warning("asyncio.run() error. Attempting to use existing loop.")
                loop = asyncio.get_event_loop_policy().get_event_loop()
                if loop.is_running():
                    logger.critical("Cannot execute in current asyncio context. If in Jupyter, try 'import nest_asyncio; nest_asyncio.apply()'.")
                    sys.exit(1)
                else:
                    loop.run_until_complete(process_and_write_concurrently())
            else: raise e

        logger.info("All processing finished.")

    except FileNotFoundError as e:
        logger.error(f"File not found: {e}")
        if args.debug: logger.exception("Traceback:")
        sys.exit(1)
    except ValueError as e:
        logger.error(f"Input/Value Error: {e}")
        if args.debug: logger.exception("Traceback:")
        sys.exit(1)
    except Exception as e:
        logger.error(f"Unexpected error during main processing: {e}")
        if args.debug: logger.exception("Traceback:")
        sys.exit(1)
|
|
365
|
+
|
|
366
|
+
# Script entry point: run the CLI when executed directly (not imported).
if __name__ == "__main__":
    main()
|
vlm4ocr/data_types.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from typing import List, Literal
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
from vlm4ocr.utils import get_default_page_delimiter
|
|
5
|
+
|
|
6
|
+
OutputMode = Literal["markdown", "HTML", "text"]

@dataclass
class OCRResult:
    """
    This class represents the result of an OCR process for one input file.

    Parameters:
    ----------
    input_dir : str
        Path to the input file (e.g., image, PDF, tiff).
    output_mode : str
        The output format. Must be 'markdown', 'HTML', or 'text'.
    pages : List[dict]
        A list of page dicts, each with keys 'text' (the OCR text of the
        page) and 'image_processing_status' (per-page preprocessing info).
    """
    input_dir: str
    output_mode: OutputMode
    pages: List[dict] = field(default_factory=list)
    filename: str = field(init=False)  # derived from input_dir in __post_init__
    status: str = field(init=False, default="processing")  # lifecycle marker, updated externally (e.g. "error")

    def __post_init__(self):
        """
        Called after the dataclass-generated __init__ method.
        Derives `filename` and validates `output_mode` and `pages`.
        """
        self.filename = os.path.basename(self.input_dir)

        # output_mode validation
        if self.output_mode not in ["markdown", "HTML", "text"]:
            raise ValueError("output_mode must be 'markdown', 'HTML', or 'text'")

        # pages validation
        if not isinstance(self.pages, list):
            raise ValueError("pages must be a list of dict")
        for i, page_content in enumerate(self.pages):
            if not isinstance(page_content, dict):
                raise ValueError(f"Each page must be a dict. Page at index {i} is not a dict.")

    def add_page(self, text: str, image_processing_status: dict):
        """
        This method adds a new page to the OCRResult object.

        Parameters:
        ----------
        text : str
            The OCR result text of the page.
        image_processing_status : dict
            A dictionary containing the image processing status for the page.
            It can include keys like 'rotate_correction', 'max_dimension_pixels', etc.
        """
        if not isinstance(text, str):
            raise ValueError("text must be a string")
        if not isinstance(image_processing_status, dict):
            raise ValueError("image_processing_status must be a dict")

        self.pages.append({
            "text": text,
            "image_processing_status": image_processing_status
        })

    def __len__(self):
        return len(self.pages)

    def get_page(self, idx):
        """Return the page dict at index idx; raises IndexError when out of range."""
        if not isinstance(idx, int):
            raise ValueError("Index must be an integer")
        if idx < 0 or idx >= len(self.pages):
            raise IndexError(f"Index out of range. The OCRResult has {len(self.pages)} pages, but index {idx} was requested.")

        return self.pages[idx]

    def __iter__(self):
        return iter(self.pages)

    def __repr__(self):
        return f"OCRResult(filename={self.filename}, output_mode={self.output_mode}, pages_count={len(self.pages)}, status={self.status})"

    def to_string(self, page_delimiter: str = "auto") -> str:
        """
        Convert the OCRResult object to a string representation.

        Parameters:
        ----------
        page_delimiter : str, Optional
            The delimiter to use between pages.
            If 'auto', it will be set to the default page delimiter for the output mode:
                'markdown' -> '\n\n---\n\n'
                'HTML' -> '<br><br>'
                'text' -> '\n\n---\n\n'
        """
        if not isinstance(page_delimiter, str):
            raise ValueError("page_delimiter must be a string")

        # Resolve locally instead of mutating state inside a conversion method.
        delimiter = get_default_page_delimiter(self.output_mode) if page_delimiter == "auto" else page_delimiter
        # Retained for backward compatibility: earlier versions exposed the
        # resolved delimiter as an instance attribute.
        self.page_delimiter = delimiter

        return delimiter.join(page.get("text", "") for page in self.pages)
|