vlm4ocr 0.1.0__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vlm4ocr/__init__.py +3 -1
- vlm4ocr/assets/default_prompt_templates/ocr_JSON_system_prompt.txt +1 -0
- vlm4ocr/cli.py +276 -287
- vlm4ocr/data_types.py +109 -0
- vlm4ocr/ocr_engines.py +363 -195
- vlm4ocr/utils.py +386 -39
- vlm4ocr/vlm_engines.py +316 -190
- {vlm4ocr-0.1.0.dist-info → vlm4ocr-0.3.0.dist-info}/METADATA +5 -1
- vlm4ocr-0.3.0.dist-info/RECORD +17 -0
- vlm4ocr-0.1.0.dist-info/RECORD +0 -15
- {vlm4ocr-0.1.0.dist-info → vlm4ocr-0.3.0.dist-info}/WHEEL +0 -0
- {vlm4ocr-0.1.0.dist-info → vlm4ocr-0.3.0.dist-info}/entry_points.txt +0 -0
vlm4ocr/cli.py
CHANGED
|
@@ -1,378 +1,367 @@
|
|
|
1
|
-
# vlm4ocr/cli.py
|
|
2
|
-
|
|
3
1
|
import argparse
|
|
4
2
|
import os
|
|
5
3
|
import sys
|
|
6
4
|
import logging
|
|
5
|
+
import asyncio
|
|
6
|
+
import time
|
|
7
7
|
|
|
8
8
|
# Attempt to import from the local package structure
|
|
9
|
-
# This allows running the script directly for development,
|
|
10
|
-
# assuming the script is in vlm4ocr/vlm4ocr/cli.py and the package root is vlm4ocr/vlm4ocr
|
|
11
9
|
try:
|
|
12
|
-
from .ocr_engines import OCREngine
|
|
13
|
-
from .vlm_engines import OpenAIVLMEngine, AzureOpenAIVLMEngine, OllamaVLMEngine
|
|
10
|
+
from .ocr_engines import OCREngine
|
|
11
|
+
from .vlm_engines import OpenAIVLMEngine, AzureOpenAIVLMEngine, OllamaVLMEngine, BasicVLMConfig
|
|
12
|
+
from .data_types import OCRResult
|
|
14
13
|
except ImportError:
|
|
15
|
-
# Fallback for when the package is installed
|
|
14
|
+
# Fallback for when the package is installed
|
|
16
15
|
from vlm4ocr.ocr_engines import OCREngine
|
|
17
|
-
from vlm4ocr.vlm_engines import OpenAIVLMEngine, AzureOpenAIVLMEngine, OllamaVLMEngine
|
|
16
|
+
from vlm4ocr.vlm_engines import OpenAIVLMEngine, AzureOpenAIVLMEngine, OllamaVLMEngine, BasicVLMConfig
|
|
17
|
+
from vlm4ocr.data_types import OCRResult
|
|
18
|
+
|
|
19
|
+
import tqdm.asyncio
|
|
18
20
|
|
|
19
|
-
#
|
|
20
|
-
logging.basicConfig(
|
|
21
|
-
|
|
21
|
+
# --- Global logger setup (console) ---
|
|
22
|
+
logging.basicConfig(
|
|
23
|
+
level=logging.INFO,
|
|
24
|
+
format='%(asctime)s - %(levelname)s: %(message)s',
|
|
25
|
+
datefmt='%Y-%m-%d %H:%M:%S'
|
|
26
|
+
)
|
|
27
|
+
logger = logging.getLogger("vlm4ocr_cli")
|
|
22
28
|
|
|
23
|
-
# Define supported extensions here, ideally this should be sourced from ocr_engines.py
|
|
24
29
|
SUPPORTED_IMAGE_EXTS_CLI = ['.pdf', '.tif', '.tiff', '.png', '.jpg', '.jpeg', '.bmp', '.gif', '.webp']
|
|
25
|
-
OUTPUT_EXTENSIONS = {'markdown': '.md', 'HTML':'.html', 'text':'txt'}
|
|
30
|
+
OUTPUT_EXTENSIONS = {'markdown': '.md', 'HTML':'.html', 'text':'.txt'}
|
|
26
31
|
|
|
27
|
-
def
|
|
32
|
+
def get_output_path_for_ocr_result(input_file_path, specified_output_path_arg, output_mode, num_total_inputs, base_output_dir_if_no_specific_path):
|
|
28
33
|
"""
|
|
29
|
-
|
|
30
|
-
|
|
34
|
+
Determines the full output path for a given OCR result file.
|
|
35
|
+
Output filename format: <original_basename>_ocr.<new_extension>
|
|
36
|
+
Example: input "abc.pdf", output_mode "markdown" -> "abc.pdf_ocr.md"
|
|
31
37
|
"""
|
|
38
|
+
original_basename = os.path.basename(input_file_path)
|
|
39
|
+
output_filename_core = f"{original_basename}_ocr"
|
|
40
|
+
|
|
41
|
+
output_filename_ext = OUTPUT_EXTENSIONS.get(output_mode, '.txt')
|
|
42
|
+
final_output_filename = f"{output_filename_core}{output_filename_ext}"
|
|
43
|
+
|
|
44
|
+
if specified_output_path_arg: # If --output_path is used
|
|
45
|
+
# Scenario 1: Multiple input files, --output_path is expected to be a directory.
|
|
46
|
+
if num_total_inputs > 1 and os.path.isdir(specified_output_path_arg):
|
|
47
|
+
return os.path.join(specified_output_path_arg, final_output_filename)
|
|
48
|
+
# Scenario 2: Single input file.
|
|
49
|
+
# --output_path could be a full file path OR a directory.
|
|
50
|
+
elif num_total_inputs == 1:
|
|
51
|
+
if os.path.isdir(specified_output_path_arg): # If --output_path is a directory for the single file
|
|
52
|
+
return os.path.join(specified_output_path_arg, final_output_filename)
|
|
53
|
+
else: # If --output_path is a specific file name for the single file
|
|
54
|
+
return specified_output_path_arg
|
|
55
|
+
# Scenario 3: Multiple input files, but --output_path is NOT a directory (error, handled before this fn)
|
|
56
|
+
# or other edge cases, fall back to base_output_dir_if_no_specific_path
|
|
57
|
+
else:
|
|
58
|
+
return os.path.join(base_output_dir_if_no_specific_path, final_output_filename)
|
|
59
|
+
else: # No --output_path, save to the determined base output directory
|
|
60
|
+
return os.path.join(base_output_dir_if_no_specific_path, final_output_filename)
|
|
61
|
+
|
|
62
|
+
def setup_file_logger(log_dir, timestamp_str, debug_mode):
|
|
63
|
+
"""Sets up a file handler for logging."""
|
|
64
|
+
log_file_name = f"vlm4ocr_{timestamp_str}.log"
|
|
65
|
+
log_file_path = os.path.join(log_dir, log_file_name)
|
|
66
|
+
|
|
67
|
+
file_handler = logging.FileHandler(log_file_path, mode='a')
|
|
68
|
+
formatter = logging.Formatter('%(asctime)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
|
|
69
|
+
file_handler.setFormatter(formatter)
|
|
70
|
+
|
|
71
|
+
log_level = logging.DEBUG if debug_mode else logging.INFO
|
|
72
|
+
file_handler.setLevel(log_level)
|
|
73
|
+
|
|
74
|
+
logger.addHandler(file_handler)
|
|
75
|
+
logger.info(f"Logging to file: {log_file_path}")
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def main():
|
|
32
79
|
parser = argparse.ArgumentParser(
|
|
33
|
-
description="VLM4OCR: Perform OCR on images, PDFs, or TIFF files using Vision Language Models.",
|
|
34
|
-
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
|
80
|
+
description="VLM4OCR: Perform OCR on images, PDFs, or TIFF files using Vision Language Models. Processing is concurrent by default.",
|
|
81
|
+
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
|
35
82
|
)
|
|
36
83
|
|
|
37
|
-
# --- Input/Output Arguments ---
|
|
38
84
|
io_group = parser.add_argument_group("Input/Output Options")
|
|
39
|
-
io_group.add_argument(
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
"If a directory is provided, all supported files within will be processed."
|
|
44
|
-
)
|
|
45
|
-
io_group.add_argument(
|
|
46
|
-
"--output_mode",
|
|
47
|
-
choices=["markdown", "HTML", "text"],
|
|
48
|
-
default="markdown",
|
|
49
|
-
help="Desired output format for the OCR results."
|
|
50
|
-
)
|
|
51
|
-
io_group.add_argument(
|
|
52
|
-
"--output_file",
|
|
53
|
-
help="Optional: Path to a file to save the output. "
|
|
54
|
-
"If input_path is a directory, this should be a directory where results will be saved "
|
|
55
|
-
"(one file per input, with original name and new extension). "
|
|
56
|
-
"If not provided, output is written to files in the current working directory "
|
|
57
|
-
"(e.g., 'input_name_ocr.output_mode')."
|
|
58
|
-
)
|
|
85
|
+
io_group.add_argument("--input_path", required=True, help="Path to a single input file or a directory of files.")
|
|
86
|
+
io_group.add_argument("--output_mode", choices=["markdown", "HTML", "text"], default="markdown", help="Output format.")
|
|
87
|
+
io_group.add_argument("--output_path", help="Optional: Path to save OCR results. If input_path is a directory of multiple files, this should be an output directory. If input is a single file, this can be a full file path or a directory. If not provided, results are saved to the current working directory (or a sub-directory for logs if --log is used).")
|
|
88
|
+
io_group.add_argument("--skip_existing", action="store_true", help="Skip processing files that already have OCR results in the output directory.")
|
|
59
89
|
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
"
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
"
|
|
70
|
-
required=True,
|
|
71
|
-
help="The specific model identifier for the chosen VLM engine. "
|
|
72
|
-
"E.g., 'gpt-4o' for OpenAI, 'deployment-name' for Azure, "
|
|
73
|
-
"'Qwen/Qwen2.5-VL-7B-Instruct' for OpenAI-compatible, "
|
|
74
|
-
"or 'llava:latest' for Ollama."
|
|
90
|
+
image_processing_group = parser.add_argument_group("Image Processing Parameters")
|
|
91
|
+
image_processing_group.add_argument(
|
|
92
|
+
"--rotate_correction",
|
|
93
|
+
action="store_true",
|
|
94
|
+
help="Enable automatic rotation correction for input images. This requires Tesseract OCR to be installed and configured correctly.")
|
|
95
|
+
image_processing_group.add_argument(
|
|
96
|
+
"--max_dimension_pixels",
|
|
97
|
+
type=int,
|
|
98
|
+
default=4000,
|
|
99
|
+
help="Maximum dimension (width or height) in pixels for input images. Images larger than this will be resized to fit within this limit while maintaining aspect ratio."
|
|
75
100
|
)
|
|
76
101
|
|
|
77
|
-
|
|
102
|
+
vlm_engine_group = parser.add_argument_group("VLM Engine Options")
|
|
103
|
+
vlm_engine_group.add_argument("--vlm_engine", choices=["openai", "azure_openai", "ollama", "openai_compatible"], required=True, help="VLM engine.")
|
|
104
|
+
vlm_engine_group.add_argument("--model", required=True, help="Model identifier for the VLM engine.")
|
|
105
|
+
vlm_engine_group.add_argument("--max_new_tokens", type=int, default=4096, help="Max new tokens for VLM.")
|
|
106
|
+
vlm_engine_group.add_argument("--temperature", type=float, default=0.0, help="Sampling temperature.")
|
|
107
|
+
|
|
78
108
|
openai_group = parser.add_argument_group("OpenAI & OpenAI-Compatible Options")
|
|
79
|
-
openai_group.add_argument(
|
|
80
|
-
|
|
81
|
-
default=os.environ.get("OPENAI_API_KEY"),
|
|
82
|
-
help="API key for OpenAI or OpenAI-compatible service. "
|
|
83
|
-
"Can also be set via OPENAI_API_KEY environment variable."
|
|
84
|
-
)
|
|
85
|
-
openai_group.add_argument(
|
|
86
|
-
"--base_url",
|
|
87
|
-
help="Base URL for OpenAI-compatible services (e.g., vLLM endpoint like 'http://localhost:8000/v1'). "
|
|
88
|
-
"Not used for official OpenAI API."
|
|
89
|
-
)
|
|
109
|
+
openai_group.add_argument("--api_key", default=os.environ.get("OPENAI_API_KEY"), help="API key.")
|
|
110
|
+
openai_group.add_argument("--base_url", help="Base URL for OpenAI-compatible services.")
|
|
90
111
|
|
|
91
|
-
# --- Azure OpenAI Engine Arguments ---
|
|
92
112
|
azure_group = parser.add_argument_group("Azure OpenAI Options")
|
|
93
|
-
azure_group.add_argument(
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
help="API key for Azure OpenAI service. "
|
|
97
|
-
"Can also be set via AZURE_OPENAI_API_KEY environment variable."
|
|
98
|
-
)
|
|
99
|
-
azure_group.add_argument(
|
|
100
|
-
"--azure_endpoint",
|
|
101
|
-
default=os.environ.get("AZURE_OPENAI_ENDPOINT"),
|
|
102
|
-
help="Endpoint URL for Azure OpenAI service. "
|
|
103
|
-
"Can also be set via AZURE_OPENAI_ENDPOINT environment variable."
|
|
104
|
-
)
|
|
105
|
-
azure_group.add_argument(
|
|
106
|
-
"--azure_api_version",
|
|
107
|
-
default=os.environ.get("AZURE_OPENAI_API_VERSION"),
|
|
108
|
-
help="API version for Azure OpenAI service (e.g., '2024-02-01'). "
|
|
109
|
-
"Can also be set via AZURE_OPENAI_API_VERSION environment variable."
|
|
110
|
-
)
|
|
113
|
+
azure_group.add_argument("--azure_api_key", default=os.environ.get("AZURE_OPENAI_API_KEY"), help="Azure API key.")
|
|
114
|
+
azure_group.add_argument("--azure_endpoint", default=os.environ.get("AZURE_OPENAI_ENDPOINT"), help="Azure endpoint URL.")
|
|
115
|
+
azure_group.add_argument("--azure_api_version", default=os.environ.get("AZURE_OPENAI_API_VERSION"), help="Azure API version.")
|
|
111
116
|
|
|
112
|
-
# --- Ollama Engine Arguments ---
|
|
113
117
|
ollama_group = parser.add_argument_group("Ollama Options")
|
|
114
|
-
ollama_group.add_argument(
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
help="Host URL for the Ollama server."
|
|
118
|
-
)
|
|
119
|
-
ollama_group.add_argument(
|
|
120
|
-
"--ollama_num_ctx",
|
|
121
|
-
type=int,
|
|
122
|
-
default=4096,
|
|
123
|
-
help="Context length for Ollama models."
|
|
124
|
-
)
|
|
125
|
-
ollama_group.add_argument(
|
|
126
|
-
"--ollama_keep_alive",
|
|
127
|
-
type=int,
|
|
128
|
-
default=300, # Default from OllamaVLMEngine
|
|
129
|
-
help="Seconds to keep the Ollama model loaded after the last call."
|
|
130
|
-
)
|
|
131
|
-
|
|
118
|
+
ollama_group.add_argument("--ollama_host", default="http://localhost:11434", help="Ollama host URL.")
|
|
119
|
+
ollama_group.add_argument("--ollama_num_ctx", type=int, default=4096, help="Context length for Ollama.")
|
|
120
|
+
ollama_group.add_argument("--ollama_keep_alive", type=int, default=300, help="Ollama keep_alive seconds.")
|
|
132
121
|
|
|
133
|
-
# --- OCR Engine Parameters ---
|
|
134
122
|
ocr_params_group = parser.add_argument_group("OCR Engine Parameters")
|
|
135
|
-
ocr_params_group.add_argument(
|
|
136
|
-
"--user_prompt",
|
|
137
|
-
help="Optional: Custom user prompt to provide context about the image/PDF/TIFF."
|
|
138
|
-
)
|
|
139
|
-
# REMOVED --system_prompt argument
|
|
140
|
-
ocr_params_group.add_argument(
|
|
141
|
-
"--max_new_tokens",
|
|
142
|
-
type=int,
|
|
143
|
-
default=4096, # Default from OCREngine
|
|
144
|
-
help="Maximum number of new tokens the VLM can generate."
|
|
145
|
-
)
|
|
146
|
-
ocr_params_group.add_argument(
|
|
147
|
-
"--temperature",
|
|
148
|
-
type=float,
|
|
149
|
-
default=0.0, # Default from OCREngine
|
|
150
|
-
help="Temperature for token sampling (0.0 for deterministic output)."
|
|
151
|
-
)
|
|
123
|
+
ocr_params_group.add_argument("--user_prompt", help="Custom user prompt.")
|
|
152
124
|
|
|
153
|
-
# --- Processing Options ---
|
|
154
125
|
processing_group = parser.add_argument_group("Processing Options")
|
|
155
|
-
processing_group.add_argument(
|
|
156
|
-
"--concurrent",
|
|
157
|
-
action="store_true",
|
|
158
|
-
help="Enable concurrent processing for multiple files or PDF/TIFF pages."
|
|
159
|
-
)
|
|
160
126
|
processing_group.add_argument(
|
|
161
127
|
"--concurrent_batch_size",
|
|
162
128
|
type=int,
|
|
163
|
-
default=
|
|
164
|
-
help="
|
|
129
|
+
default=4,
|
|
130
|
+
help="Number of images/pages to process concurrently. Set to 1 for sequential processing of VLM calls."
|
|
165
131
|
)
|
|
166
132
|
processing_group.add_argument(
|
|
167
|
-
"--
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
processing_group.add_argument(
|
|
172
|
-
"--debug",
|
|
173
|
-
action="store_true",
|
|
174
|
-
help="Enable debug level logging for more detailed information."
|
|
133
|
+
"--max_file_load",
|
|
134
|
+
type=int,
|
|
135
|
+
default=-1,
|
|
136
|
+
help="Number of input files to pre-load. Set to -1 for automatic config: 2 * concurrent_batch_size."
|
|
175
137
|
)
|
|
138
|
+
# --verbose flag was removed by user in previous version provided
|
|
139
|
+
processing_group.add_argument("--log", action="store_true", help="Enable writing logs to a timestamped file in the output directory.")
|
|
140
|
+
processing_group.add_argument("--debug", action="store_true", help="Enable debug level logging for console (and file if --log is active).")
|
|
176
141
|
|
|
177
142
|
args = parser.parse_args()
|
|
143
|
+
|
|
144
|
+
current_timestamp_str = time.strftime("%Y%m%d_%H%M%S")
|
|
178
145
|
|
|
146
|
+
# --- Configure Logger Level based on args ---
|
|
179
147
|
if args.debug:
|
|
180
|
-
logging.getLogger().setLevel(logging.DEBUG)
|
|
181
148
|
logger.setLevel(logging.DEBUG)
|
|
182
|
-
logger
|
|
183
|
-
logger.
|
|
184
|
-
|
|
185
|
-
logger.
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
149
|
+
# Set root logger to DEBUG only if our specific logger is DEBUG, to avoid overly verbose library logs unless intended.
|
|
150
|
+
if logger.getEffectiveLevel() <= logging.DEBUG:
|
|
151
|
+
logging.getLogger().setLevel(logging.DEBUG)
|
|
152
|
+
logger.debug("Debug mode enabled for console.")
|
|
153
|
+
else:
|
|
154
|
+
logger.setLevel(logging.INFO) # Default for our CLI's own messages
|
|
155
|
+
logging.getLogger().setLevel(logging.WARNING) # Keep external libraries quieter by default
|
|
156
|
+
|
|
157
|
+
if args.concurrent_batch_size < 1:
|
|
158
|
+
parser.error("--concurrent_batch_size must be 1 or greater.")
|
|
159
|
+
|
|
160
|
+
# --- Determine Effective Output Directory (for logs and default OCR outputs) ---
|
|
161
|
+
effective_output_dir = os.getcwd() # Default if no --output_path
|
|
162
|
+
|
|
163
|
+
# Preliminary check to see if multiple files will be processed
|
|
164
|
+
_is_multi_file_scenario = False
|
|
165
|
+
if os.path.isdir(args.input_path):
|
|
166
|
+
_temp_files_list = [f for f in os.listdir(args.input_path) if os.path.isfile(os.path.join(args.input_path, f)) and os.path.splitext(f)[1].lower() in SUPPORTED_IMAGE_EXTS_CLI]
|
|
167
|
+
if len(_temp_files_list) > 1:
|
|
168
|
+
_is_multi_file_scenario = True
|
|
169
|
+
|
|
170
|
+
if args.output_path:
|
|
171
|
+
if _is_multi_file_scenario: # Input is a dir with multiple files
|
|
172
|
+
if os.path.exists(args.output_path) and not os.path.isdir(args.output_path):
|
|
173
|
+
logger.critical(f"Output path '{args.output_path}' must be a directory when processing multiple files. It currently points to a file.")
|
|
174
|
+
sys.exit(1)
|
|
175
|
+
effective_output_dir = args.output_path # --output_path is the directory for outputs and logs
|
|
176
|
+
else: # Single input file scenario
|
|
177
|
+
# If args.output_path is a directory, use it.
|
|
178
|
+
# If args.output_path is a file path, use its directory for logs.
|
|
179
|
+
if os.path.isdir(args.output_path):
|
|
180
|
+
effective_output_dir = args.output_path
|
|
181
|
+
else: # Assumed to be a file path
|
|
182
|
+
dir_name = os.path.dirname(args.output_path)
|
|
183
|
+
if dir_name: # If output_path includes a directory
|
|
184
|
+
effective_output_dir = dir_name
|
|
185
|
+
else: # output_path is just a filename, logs go to CWD
|
|
186
|
+
effective_output_dir = os.getcwd()
|
|
187
|
+
|
|
188
|
+
if not os.path.exists(effective_output_dir):
|
|
189
|
+
logger.info(f"Creating output directory: {effective_output_dir}")
|
|
190
|
+
os.makedirs(effective_output_dir, exist_ok=True)
|
|
191
|
+
|
|
192
|
+
# --- Setup File Logger (if --log is specified) ---
|
|
193
|
+
if args.log:
|
|
194
|
+
setup_file_logger(effective_output_dir, current_timestamp_str, args.debug)
|
|
195
|
+
|
|
196
|
+
logger.debug(f"Parsed arguments: {args}")
|
|
197
|
+
|
|
194
198
|
# --- Initialize VLM Engine ---
|
|
195
199
|
vlm_engine_instance = None
|
|
196
200
|
try:
|
|
197
201
|
logger.info(f"Initializing VLM engine: {args.vlm_engine} with model: {args.model}")
|
|
202
|
+
config = BasicVLMConfig(
|
|
203
|
+
max_new_tokens=args.max_new_tokens,
|
|
204
|
+
temperature=args.temperature
|
|
205
|
+
)
|
|
198
206
|
if args.vlm_engine == "openai":
|
|
199
|
-
if not args.api_key:
|
|
200
|
-
|
|
201
|
-
vlm_engine_instance = OpenAIVLMEngine(
|
|
202
|
-
model=args.model,
|
|
203
|
-
api_key=args.api_key
|
|
204
|
-
# reasoning_model removed
|
|
205
|
-
)
|
|
207
|
+
if not args.api_key: parser.error("--api_key (or OPENAI_API_KEY) is required for OpenAI.")
|
|
208
|
+
vlm_engine_instance = OpenAIVLMEngine(model=args.model, api_key=args.api_key, config=config)
|
|
206
209
|
elif args.vlm_engine == "openai_compatible":
|
|
207
|
-
if not args.
|
|
208
|
-
|
|
209
|
-
if not args.base_url:
|
|
210
|
-
parser.error("--base_url is required for openai_compatible engine.")
|
|
211
|
-
vlm_engine_instance = OpenAIVLMEngine(
|
|
212
|
-
model=args.model,
|
|
213
|
-
api_key=args.api_key,
|
|
214
|
-
base_url=args.base_url
|
|
215
|
-
# reasoning_model removed
|
|
216
|
-
)
|
|
210
|
+
if not args.base_url: parser.error("--base_url is required for openai_compatible.")
|
|
211
|
+
vlm_engine_instance = OpenAIVLMEngine(model=args.model, api_key=args.api_key, base_url=args.base_url, config=config)
|
|
217
212
|
elif args.vlm_engine == "azure_openai":
|
|
218
|
-
if not args.azure_api_key:
|
|
219
|
-
|
|
220
|
-
if not args.
|
|
221
|
-
|
|
222
|
-
if not args.azure_api_version:
|
|
223
|
-
parser.error("--azure_api_version (or AZURE_OPENAI_API_VERSION env var) is required for Azure OpenAI engine.")
|
|
224
|
-
vlm_engine_instance = AzureOpenAIVLMEngine(
|
|
225
|
-
model=args.model,
|
|
226
|
-
api_key=args.azure_api_key,
|
|
227
|
-
azure_endpoint=args.azure_endpoint,
|
|
228
|
-
api_version=args.azure_api_version
|
|
229
|
-
# reasoning_model removed
|
|
230
|
-
)
|
|
213
|
+
if not args.azure_api_key: parser.error("--azure_api_key (or AZURE_OPENAI_API_KEY) is required.")
|
|
214
|
+
if not args.azure_endpoint: parser.error("--azure_endpoint (or AZURE_OPENAI_ENDPOINT) is required.")
|
|
215
|
+
if not args.azure_api_version: parser.error("--azure_api_version (or AZURE_OPENAI_API_VERSION) is required.")
|
|
216
|
+
vlm_engine_instance = AzureOpenAIVLMEngine(model=args.model, api_key=args.azure_api_key, azure_endpoint=args.azure_endpoint, api_version=args.azure_api_version, config=config)
|
|
231
217
|
elif args.vlm_engine == "ollama":
|
|
232
|
-
vlm_engine_instance = OllamaVLMEngine(
|
|
233
|
-
model_name=args.model, # OllamaVLMEngine expects model_name
|
|
234
|
-
host=args.ollama_host,
|
|
235
|
-
num_ctx=args.ollama_num_ctx,
|
|
236
|
-
keep_alive=args.ollama_keep_alive
|
|
237
|
-
)
|
|
238
|
-
else:
|
|
239
|
-
# This case should be caught by argparse choices, but as a safeguard:
|
|
240
|
-
logger.error(f"Invalid VLM engine specified: {args.vlm_engine}")
|
|
241
|
-
sys.exit(1)
|
|
218
|
+
vlm_engine_instance = OllamaVLMEngine(model_name=args.model, host=args.ollama_host, num_ctx=args.ollama_num_ctx, keep_alive=args.ollama_keep_alive, config=config)
|
|
242
219
|
logger.info("VLM engine initialized successfully.")
|
|
243
|
-
|
|
244
220
|
except ImportError as e:
|
|
245
|
-
logger.error(f"Failed to import
|
|
246
|
-
"Please ensure the necessary dependencies (e.g., 'openai', 'ollama') are installed.")
|
|
221
|
+
logger.error(f"Failed to import library for {args.vlm_engine}: {e}. Install dependencies.")
|
|
247
222
|
sys.exit(1)
|
|
248
223
|
except Exception as e:
|
|
249
224
|
logger.error(f"Error initializing VLM engine '{args.vlm_engine}': {e}")
|
|
250
|
-
if args.debug:
|
|
251
|
-
logger.exception("Traceback for VLM engine initialization error:")
|
|
225
|
+
if args.debug: logger.exception("Traceback:")
|
|
252
226
|
sys.exit(1)
|
|
253
227
|
|
|
254
228
|
# --- Initialize OCR Engine ---
|
|
255
229
|
try:
|
|
256
230
|
logger.info(f"Initializing OCR engine with output mode: {args.output_mode}")
|
|
257
|
-
ocr_engine_instance = OCREngine(
|
|
258
|
-
vlm_engine=vlm_engine_instance,
|
|
259
|
-
output_mode=args.output_mode,
|
|
260
|
-
# system_prompt removed, OCREngine will use its default
|
|
261
|
-
user_prompt=args.user_prompt
|
|
262
|
-
)
|
|
231
|
+
ocr_engine_instance = OCREngine(vlm_engine=vlm_engine_instance, output_mode=args.output_mode, user_prompt=args.user_prompt)
|
|
263
232
|
logger.info("OCR engine initialized successfully.")
|
|
264
233
|
except Exception as e:
|
|
265
234
|
logger.error(f"Error initializing OCR engine: {e}")
|
|
266
|
-
if args.debug:
|
|
267
|
-
logger.exception("Traceback for OCR engine initialization error:")
|
|
235
|
+
if args.debug: logger.exception("Traceback:")
|
|
268
236
|
sys.exit(1)
|
|
269
237
|
|
|
270
|
-
# --- Prepare input file paths ---
|
|
238
|
+
# --- Prepare input file paths (actual list) ---
|
|
271
239
|
input_files_to_process = []
|
|
272
240
|
if os.path.isdir(args.input_path):
|
|
273
|
-
logger.info(f"Input
|
|
241
|
+
logger.info(f"Input is directory: {args.input_path}. Scanning for files...")
|
|
274
242
|
for item in os.listdir(args.input_path):
|
|
275
243
|
item_path = os.path.join(args.input_path, item)
|
|
276
|
-
if os.path.isfile(item_path):
|
|
277
|
-
|
|
278
|
-
if file_ext in SUPPORTED_IMAGE_EXTS_CLI:
|
|
279
|
-
input_files_to_process.append(item_path)
|
|
244
|
+
if os.path.isfile(item_path) and os.path.splitext(item)[1].lower() in SUPPORTED_IMAGE_EXTS_CLI:
|
|
245
|
+
input_files_to_process.append(item_path)
|
|
280
246
|
if not input_files_to_process:
|
|
281
|
-
logger.error(f"No supported files
|
|
247
|
+
logger.error(f"No supported files found in directory: {args.input_path}")
|
|
282
248
|
sys.exit(1)
|
|
283
|
-
logger.info(f"Found {len(input_files_to_process)}
|
|
249
|
+
logger.info(f"Found {len(input_files_to_process)} files to process.")
|
|
284
250
|
elif os.path.isfile(args.input_path):
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
logger.error(f"Input file '{args.input_path}' is not a supported file type. Supported: {SUPPORTED_IMAGE_EXTS_CLI}")
|
|
251
|
+
if os.path.splitext(args.input_path)[1].lower() not in SUPPORTED_IMAGE_EXTS_CLI:
|
|
252
|
+
logger.error(f"Input file '{args.input_path}' is not supported. Supported: {SUPPORTED_IMAGE_EXTS_CLI}")
|
|
288
253
|
sys.exit(1)
|
|
289
254
|
input_files_to_process = [args.input_path]
|
|
290
255
|
logger.info(f"Processing single input file: {args.input_path}")
|
|
291
256
|
else:
|
|
292
|
-
logger.error(f"Input path
|
|
257
|
+
logger.error(f"Input path not valid: {args.input_path}")
|
|
293
258
|
sys.exit(1)
|
|
259
|
+
|
|
260
|
+
# --- Skip existing files if --skip_existing is used ---
|
|
261
|
+
if args.skip_existing:
|
|
262
|
+
logger.info("Checking for existing OCR results in output path to skip...")
|
|
263
|
+
# Check each input file against the expected output file
|
|
264
|
+
existing_files = os.listdir(effective_output_dir)
|
|
265
|
+
filtered_input_files_to_process = []
|
|
266
|
+
for input_file in input_files_to_process:
|
|
267
|
+
expected_output_name = get_output_path_for_ocr_result(input_file, args.output_path, args.output_mode, len(input_files_to_process), effective_output_dir)
|
|
268
|
+
if os.path.basename(expected_output_name) not in existing_files:
|
|
269
|
+
filtered_input_files_to_process.append(input_file)
|
|
294
270
|
|
|
271
|
+
original_num_files = len(input_files_to_process)
|
|
272
|
+
after_filter_num_files = len(filtered_input_files_to_process)
|
|
273
|
+
input_files_to_process = filtered_input_files_to_process
|
|
274
|
+
logger.info(f"Dropped {original_num_files - after_filter_num_files} existing files. Number of input files to process after filtering: {len(input_files_to_process)}")
|
|
275
|
+
|
|
276
|
+
else:
|
|
277
|
+
logger.info("All input files will be processed (`--skip_existing=False`).")
|
|
278
|
+
# This re-evaluation is useful if the initial _is_multi_file_scenario was just for log dir
|
|
279
|
+
num_actual_files = len(input_files_to_process)
|
|
295
280
|
|
|
296
281
|
# --- Run OCR ---
|
|
297
282
|
try:
|
|
298
|
-
logger.info("
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
283
|
+
logger.info(f"Processing with concurrent_batch_size: {args.concurrent_batch_size}.")
|
|
284
|
+
|
|
285
|
+
async def process_and_write_concurrently():
|
|
286
|
+
ocr_task_generator = ocr_engine_instance.concurrent_ocr(
|
|
287
|
+
file_paths=input_files_to_process,
|
|
288
|
+
rotate_correction=args.rotate_correction,
|
|
289
|
+
max_dimension_pixels=args.max_dimension_pixels,
|
|
290
|
+
concurrent_batch_size=args.concurrent_batch_size,
|
|
291
|
+
max_file_load=args.max_file_load if args.max_file_load > 0 else None
|
|
292
|
+
)
|
|
293
|
+
|
|
294
|
+
# Progress bar always attempted if tqdm is available and files exist,
|
|
295
|
+
# console verbosity controlled by logger level.
|
|
296
|
+
show_progress_bar = (num_actual_files > 0)
|
|
297
|
+
|
|
298
|
+
iterator_wrapper = tqdm.asyncio.tqdm(
|
|
299
|
+
ocr_task_generator,
|
|
300
|
+
total=num_actual_files,
|
|
301
|
+
desc="Processing files",
|
|
302
|
+
unit="file",
|
|
303
|
+
disable=not show_progress_bar # disable if no files, or can remove this disable if tqdm handles total=0
|
|
304
|
+
)
|
|
305
|
+
|
|
306
|
+
async for result_object in iterator_wrapper:
|
|
307
|
+
if not isinstance(result_object, OCRResult):
|
|
308
|
+
logger.warning(f"Received unexpected data type: {type(result_object)}")
|
|
309
|
+
continue
|
|
310
|
+
|
|
311
|
+
input_file_path_from_result = result_object.input_dir
|
|
312
|
+
# For get_output_path_for_ocr_result, effective_output_dir is the base if args.output_path isn't specific enough
|
|
313
|
+
current_ocr_output_file_path = get_output_path_for_ocr_result(
|
|
314
|
+
input_file_path_from_result, args.output_path, args.output_mode,
|
|
315
|
+
num_actual_files, effective_output_dir
|
|
316
|
+
)
|
|
319
317
|
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
# Ensure its directory exists.
|
|
324
|
-
output_target_dir = os.path.dirname(args.output_file)
|
|
325
|
-
if output_target_dir and not os.path.exists(output_target_dir):
|
|
326
|
-
logger.info(f"Creating output directory: {output_target_dir}")
|
|
327
|
-
os.makedirs(output_target_dir, exist_ok=True)
|
|
328
|
-
else: # Should not happen if logic above is correct
|
|
329
|
-
output_target_dir = os.getcwd()
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
for i, input_file_path in enumerate(input_files_to_process):
|
|
333
|
-
if os.path.isdir(args.input_path) and len(input_files_to_process) > 1:
|
|
334
|
-
# Multiple inputs, save into the directory specified by args.output_file
|
|
335
|
-
base_name = os.path.basename(input_file_path)
|
|
336
|
-
name_part, _ = os.path.splitext(base_name)
|
|
337
|
-
output_filename = f"{name_part}_ocr{OUTPUT_EXTENSIONS[args.output_mode]}"
|
|
338
|
-
full_output_path = os.path.join(args.output_file, output_filename)
|
|
318
|
+
if result_object.status == "error":
|
|
319
|
+
error_message = result_object.get_page(0) if len(result_object) > 0 else 'Unknown error during OCR'
|
|
320
|
+
logger.error(f"OCR failed for {result_object.filename}: {error_message}")
|
|
339
321
|
else:
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
322
|
+
try:
|
|
323
|
+
content_to_write = result_object.to_string()
|
|
324
|
+
with open(current_ocr_output_file_path, "w", encoding="utf-8") as f:
|
|
325
|
+
f.write(content_to_write)
|
|
326
|
+
# Log less verbosely to console if progress bar is active
|
|
327
|
+
if not show_progress_bar or logger.getEffectiveLevel() <= logging.DEBUG:
|
|
328
|
+
logger.info(f"OCR result for '{input_file_path_from_result}' saved to: {current_ocr_output_file_path}")
|
|
329
|
+
except Exception as e:
|
|
330
|
+
logger.error(f"Error writing output for '{input_file_path_from_result}' to '{current_ocr_output_file_path}': {e}")
|
|
331
|
+
|
|
332
|
+
if hasattr(iterator_wrapper, 'close') and isinstance(iterator_wrapper, tqdm.asyncio.tqdm):
|
|
333
|
+
if iterator_wrapper.n < iterator_wrapper.total:
|
|
334
|
+
iterator_wrapper.n = iterator_wrapper.total
|
|
335
|
+
iterator_wrapper.refresh()
|
|
336
|
+
iterator_wrapper.close()
|
|
337
|
+
|
|
338
|
+
try:
|
|
339
|
+
asyncio.run(process_and_write_concurrently())
|
|
340
|
+
except RuntimeError as e:
|
|
341
|
+
if "asyncio.run() cannot be called from a running event loop" in str(e):
|
|
342
|
+
logger.warning("asyncio.run() error. Attempting to use existing loop.")
|
|
343
|
+
loop = asyncio.get_event_loop_policy().get_event_loop()
|
|
344
|
+
if loop.is_running():
|
|
345
|
+
logger.critical("Cannot execute in current asyncio context. If in Jupyter, try 'import nest_asyncio; nest_asyncio.apply()'.")
|
|
346
|
+
sys.exit(1)
|
|
347
|
+
else:
|
|
348
|
+
loop.run_until_complete(process_and_write_concurrently())
|
|
349
|
+
else: raise e
|
|
350
|
+
|
|
351
|
+
logger.info("All processing finished.")
|
|
364
352
|
|
|
365
353
|
except FileNotFoundError as e:
|
|
366
|
-
logger.error(f"File not found
|
|
354
|
+
logger.error(f"File not found: {e}")
|
|
355
|
+
if args.debug: logger.exception("Traceback:")
|
|
367
356
|
sys.exit(1)
|
|
368
|
-
except ValueError as e:
|
|
369
|
-
logger.error(f"Input
|
|
357
|
+
except ValueError as e:
|
|
358
|
+
logger.error(f"Input/Value Error: {e}")
|
|
359
|
+
if args.debug: logger.exception("Traceback:")
|
|
370
360
|
sys.exit(1)
|
|
371
361
|
except Exception as e:
|
|
372
|
-
logger.error(f"
|
|
373
|
-
if args.debug:
|
|
374
|
-
logger.exception("Traceback for OCR processing error:")
|
|
362
|
+
logger.error(f"Unexpected error during main processing: {e}")
|
|
363
|
+
if args.debug: logger.exception("Traceback:")
|
|
375
364
|
sys.exit(1)
|
|
376
365
|
|
|
377
366
|
if __name__ == "__main__":
|
|
378
|
-
main()
|
|
367
|
+
main()
|