vlm4ocr-0.0.1.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
vlm4ocr-0.0.1/PKG-INFO ADDED
@@ -0,0 +1,16 @@
+ Metadata-Version: 2.1
+ Name: vlm4ocr
+ Version: 0.0.1
+ Summary: OCR with vision language models.
+ License: MIT
+ Author: Enshuo (David) Hsu
+ Requires-Python: >=3.11,<4.0
+ Classifier: License :: OSI Approved :: MIT License
+ Classifier: Programming Language :: Python :: 3
+ Classifier: Programming Language :: Python :: 3.11
+ Classifier: Programming Language :: Python :: 3.12
+ Requires-Dist: pdf2image (>=1.16.0)
+ Requires-Dist: pillow (>=10.0.0)
+ Description-Content-Type: text/markdown
+
+ This is the README for the vlm4ocr Python package.
vlm4ocr-0.0.1/README.md ADDED
@@ -0,0 +1 @@
+ This is the README for the vlm4ocr Python package.
vlm4ocr-0.0.1/pyproject.toml ADDED
@@ -0,0 +1,23 @@
+ [tool.poetry]
+ name = "vlm4ocr"
+ version = "0.0.1"
+ description = "OCR with vision language models."
+ authors = ["Enshuo (David) Hsu"]
+ license = "MIT"
+ readme = "README.md"
+
+ exclude = [
+     "tests/**",
+     "develop/**"
+ ]
+
+
+ [tool.poetry.dependencies]
+ python = "^3.11"
+ pdf2image = ">=1.16.0"
+ pillow = ">=10.0.0"
+
+
+ [build-system]
+ requires = ["poetry-core"]
+ build-backend = "poetry.core.masonry.api"
vlm4ocr-0.0.1/vlm4ocr/__init__.py ADDED
@@ -0,0 +1,9 @@
+ from .ocr_engines import OCREngine
+ from .vlm_engines import OllamaVLMEngine, OpenAIVLMEngine, AzureOpenAIVLMEngine
+
+ __all__ = [
+     "OCREngine",
+     "OllamaVLMEngine",
+     "OpenAIVLMEngine",
+     "AzureOpenAIVLMEngine"
+ ]
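
The `__init__.py` above defines the package's public API. As a quick orientation, a minimal sketch of wiring these exports together follows; the model name and the local Ollama server are illustrative assumptions, not part of the package:

```python
# Minimal sketch of the public API re-exported above.
# The model name is a placeholder; any vision-capable model listed by `ollama ls` would do.
from vlm4ocr import OCREngine, OllamaVLMEngine

vlm = OllamaVLMEngine(model_name="llama3.2-vision:11b")   # assumes a local Ollama server
ocr = OCREngine(vlm_engine=vlm, output_mode="markdown")
```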
vlm4ocr-0.0.1/vlm4ocr/assets/default_prompt_templates/ocr_markdown_system_prompt.txt ADDED
@@ -0,0 +1 @@
+ You are a helpful assistant that can convert scanned documents into markdown text. Your output is accurate and well-formatted, starting with ```markdown and ending with ```. You will only output the markdown text without any additional explanations or comments. The markdown should include all text, tables, and lists with appropriate headers (e.g., "##"). You will ignore images, icons, or anything that can not be converted into text.
vlm4ocr-0.0.1/vlm4ocr/assets/default_prompt_templates/ocr_text_system_prompt.txt ADDED
@@ -0,0 +1 @@
+ You are a helpful assistant that can convert scanned documents into plain text. Your output is accurate and well-formatted. You will only output the plain text without any additional explanations or comments. The plain text should include all text, tables, and lists with appropriate layout (e.g., \n, space, and tab). You will ignore images, icons, or anything that can not be converted into text.
vlm4ocr-0.0.1/vlm4ocr/assets/default_prompt_templates/ocr_user_prompt.txt ADDED
@@ -0,0 +1 @@
+ Convert contents in this image into markdown.
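
The three files above are the default prompt templates that OCREngine loads when no custom prompts are given. A sketch of overriding the default user prompt with document-specific context (the engine and wording below are illustrative assumptions):

```python
# Sketch: supplying a custom user prompt with document-specific context.
# The VLM engine and prompt wording are illustrative, not package defaults.
from vlm4ocr import OCREngine, OllamaVLMEngine

vlm = OllamaVLMEngine(model_name="llama3.2-vision:11b")
ocr = OCREngine(
    vlm_engine=vlm,
    output_mode="text",
    user_prompt="This is a scanned invoice. Convert its contents into plain text.",
)
```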
vlm4ocr-0.0.1/vlm4ocr/ocr_engines.py ADDED
@@ -0,0 +1,311 @@
+ import os
+ import warnings
+ from typing import List, Dict, Union, Generator, Iterable
+ import importlib.resources
+ import asyncio
+ from vlm4ocr.utils import get_images_from_pdf, get_image_from_file, clean_markdown
+ from vlm4ocr.vlm_engines import VLMEngine
+
+ SUPPORTED_IMAGE_EXTS = ['.png', '.jpg', '.jpeg', '.bmp', '.gif', '.webp']
+
+ class OCREngine:
+     def __init__(self, vlm_engine:VLMEngine, output_mode:str="markdown", system_prompt:str=None, user_prompt:str=None, page_delimiter:str="\n\n---\n\n"):
+         """
+         This class takes an image or PDF file path and processes it with a VLM inference engine. Outputs plain text or markdown.
+
+         Parameters:
+         -----------
+         vlm_engine : VLMEngine
+             The VLM inference engine to use for OCR.
+         output_mode : str, Optional
+             The output format. Must be 'markdown' or 'text'.
+         system_prompt : str, Optional
+             Custom system prompt. We recommend using the default system prompt by leaving this blank.
+         user_prompt : str, Optional
+             Custom user prompt. It is good to include some information about the document. If not specified, a default is used.
+         page_delimiter : str, Optional
+             The delimiter to insert between PDF pages.
+         """
+         # Check inference engine
+         if not isinstance(vlm_engine, VLMEngine):
+             raise TypeError("vlm_engine must be an instance of VLMEngine")
+         self.vlm_engine = vlm_engine
+
+         # Check output mode
+         if output_mode not in ["markdown", "text"]:
+             raise ValueError("output_mode must be 'markdown' or 'text'")
+         self.output_mode = output_mode
+
+         # System prompt
+         if isinstance(system_prompt, str) and system_prompt:
+             self.system_prompt = system_prompt
+         else:
+             file_path = importlib.resources.files('vlm4ocr.assets.default_prompt_templates').joinpath(f'ocr_{self.output_mode}_system_prompt.txt')
+             with open(file_path, 'r', encoding='utf-8') as f:
+                 self.system_prompt = f.read()
+
+         # User prompt
+         if isinstance(user_prompt, str) and user_prompt:
+             self.user_prompt = user_prompt
+         else:
+             file_path = importlib.resources.files('vlm4ocr.assets.default_prompt_templates').joinpath('ocr_user_prompt.txt')
+             with open(file_path, 'r', encoding='utf-8') as f:
+                 self.user_prompt = f.read()
+
+         # Page delimiter
+         if isinstance(page_delimiter, str):
+             self.page_delimiter = page_delimiter
+         else:
+             raise ValueError("page_delimiter must be a string")
+
+
+     def stream_ocr(self, file_path: str, max_new_tokens:int=4096, temperature:float=0.0, **kwrs) -> Generator[str, None, None]:
+         """
+         This method takes a file path (image or PDF) and streams OCR results in real time. This is useful for frontend applications.
+
+         Parameters:
+         -----------
+         file_path : str
+             The path to the image or PDF file. Must be one of '.pdf', '.png', '.jpg', '.jpeg', '.bmp', '.gif', '.webp'
+         max_new_tokens : int, Optional
+             The maximum number of tokens to generate.
+         temperature : float, Optional
+             The temperature to use for sampling.
+
+         Returns:
+         --------
+         Generator[str, None, None]
+             A generator that yields the output.
+         """
+         # Check file path
+         if not isinstance(file_path, str):
+             raise TypeError("file_path must be a string")
+
+         # Check file extension
+         file_ext = os.path.splitext(file_path)[1].lower()
+         if file_ext not in SUPPORTED_IMAGE_EXTS and file_ext != '.pdf':
+             raise ValueError(f"Unsupported file type: {file_ext}. Supported types are: {SUPPORTED_IMAGE_EXTS + ['.pdf']}")
+
+         # PDF
+         if file_ext == '.pdf':
+             images = get_images_from_pdf(file_path)
+             if not images:
+                 raise ValueError(f"No images extracted from PDF: {file_path}")
+             for i, image in enumerate(images):
+                 messages = self.vlm_engine.get_ocr_messages(self.system_prompt, self.user_prompt, image)
+                 response_stream = self.vlm_engine.chat(
+                     messages,
+                     max_new_tokens=max_new_tokens,
+                     temperature=temperature,
+                     stream=True,
+                     **kwrs
+                 )
+                 for chunk in response_stream:
+                     yield chunk
+
+                 if i < len(images) - 1:
+                     yield self.page_delimiter
+
+         # Image
+         else:
+             image = get_image_from_file(file_path)
+             messages = self.vlm_engine.get_ocr_messages(self.system_prompt, self.user_prompt, image)
+             response_stream = self.vlm_engine.chat(
+                 messages,
+                 max_new_tokens=max_new_tokens,
+                 temperature=temperature,
+                 stream=True,
+                 **kwrs
+             )
+             for chunk in response_stream:
+                 yield chunk
+
+
+     def run_ocr(self, file_paths: Union[str, Iterable[str]], max_new_tokens:int=4096, temperature:float=0.0,
+                 verbose:bool=False, concurrent:bool=False, concurrent_batch_size:int=32, **kwrs) -> List[str]:
+         """
+         This method takes a file path or a list of file paths (image or PDF) and performs OCR using the VLM inference engine.
+
+         Parameters:
+         -----------
+         file_paths : Union[str, Iterable[str]]
+             A file path or a list of file paths to process. Must be one of '.pdf', '.png', '.jpg', '.jpeg', '.bmp', '.gif', '.webp'
+         max_new_tokens : int, Optional
+             The maximum number of tokens to generate.
+         temperature : float, Optional
+             The temperature to use for sampling.
+         verbose : bool, Optional
+             If True, the output is printed to the terminal.
+         concurrent : bool, Optional
+             If True, the files are processed concurrently.
+         concurrent_batch_size : int, Optional
+             The number of images/pages to process concurrently.
+         """
+         # If file_paths is a string, convert it to a list
+         if isinstance(file_paths, str):
+             file_paths = [file_paths]
+
+         if not isinstance(file_paths, Iterable):
+             raise TypeError("file_paths must be a string or an iterable of strings")
+
+         # Check that all file paths are valid
+         for file_path in file_paths:
+             if not isinstance(file_path, str):
+                 raise TypeError("file_paths must be a string or an iterable of strings")
+             file_ext = os.path.splitext(file_path)[1].lower()
+             if file_ext not in SUPPORTED_IMAGE_EXTS and file_ext != '.pdf':
+                 raise ValueError(f"Unsupported file type: {file_ext}. Supported types are: {SUPPORTED_IMAGE_EXTS + ['.pdf']}")
+
+
+         # Concurrent processing
+         if concurrent:
+             # Check concurrent_batch_size
+             if concurrent_batch_size <= 0:
+                 raise ValueError("concurrent_batch_size must be greater than 0")
+
+             if verbose:
+                 warnings.warn("verbose is not supported for concurrent processing.", UserWarning)
+
+             return asyncio.run(self._run_ocr_async(file_paths,
+                                                    max_new_tokens=max_new_tokens,
+                                                    temperature=temperature,
+                                                    concurrent_batch_size=concurrent_batch_size,
+                                                    **kwrs))
+
+         # Sync processing
+         return self._run_ocr(file_paths, max_new_tokens=max_new_tokens, temperature=temperature, verbose=verbose, **kwrs)
+
+
+     def _run_ocr(self, file_paths: Union[str, Iterable[str]], max_new_tokens:int=4096,
+                  temperature:float=0.0, verbose:bool=False, **kwrs) -> List[str]:
+         """
+         This method takes a file path or a list of file paths (image or PDF) and performs OCR using the VLM inference engine.
+
+         Parameters:
+         -----------
+         file_paths : Union[str, Iterable[str]]
+             A file path or a list of file paths to process. Must be one of '.pdf', '.png', '.jpg', '.jpeg', '.bmp', '.gif', '.webp'
+         max_new_tokens : int, Optional
+             The maximum number of tokens to generate.
+         temperature : float, Optional
+             The temperature to use for sampling.
+         verbose : bool, Optional
+             If True, the output is printed to the terminal.
+
+         Returns:
+         --------
+         List[str]
+             A list of strings containing the OCR results.
+         """
+         ocr_results = []
+         for file_path in file_paths:
+             file_ext = os.path.splitext(file_path)[1].lower()
+             # PDF
+             if file_ext == '.pdf':
+                 images = get_images_from_pdf(file_path)
+                 if not images:
+                     raise ValueError(f"No images extracted from PDF: {file_path}")
+                 pdf_results = []
+                 for image in images:
+                     messages = self.vlm_engine.get_ocr_messages(self.system_prompt, self.user_prompt, image)
+                     response = self.vlm_engine.chat(
+                         messages,
+                         max_new_tokens=max_new_tokens,
+                         temperature=temperature,
+                         verbose=verbose,
+                         stream=False,
+                         **kwrs
+                     )
+                     pdf_results.append(response)
+
+                 ocr_text = self.page_delimiter.join(pdf_results)
+             # Image
+             else:
+                 image = get_image_from_file(file_path)
+                 messages = self.vlm_engine.get_ocr_messages(self.system_prompt, self.user_prompt, image)
+                 ocr_text = self.vlm_engine.chat(
+                     messages,
+                     max_new_tokens=max_new_tokens,
+                     temperature=temperature,
+                     verbose=verbose,
+                     stream=False,
+                     **kwrs
+                 )
+
+             # Clean markdown
+             if self.output_mode == "markdown":
+                 ocr_text = clean_markdown(ocr_text)
+             ocr_results.append(ocr_text)
+
+         return ocr_results
+
+
+     async def _run_ocr_async(self, file_paths: Union[str, Iterable[str]], max_new_tokens:int=4096,
+                              temperature:float=0.0, concurrent_batch_size:int=32, **kwrs) -> List[str]:
+         """
+         This is the async version of the _run_ocr method.
+         """
+         # Flatten pages/images in file_paths
+         flat_page_list = []
+         for file_path in file_paths:
+             file_ext = os.path.splitext(file_path)[1].lower()
+             # PDF
+             if file_ext == '.pdf':
+                 images = get_images_from_pdf(file_path)
+                 if not images:
+                     raise ValueError(f"No images extracted from PDF: {file_path}")
+                 for page_num, image in enumerate(images):
+                     flat_page_list.append({'file_path': file_path, 'file_type': "PDF", "image": image, "page_num": page_num, "total_page_count": len(images)})
+             # Image
+             else:
+                 image = get_image_from_file(file_path)
+                 flat_page_list.append({'file_path': file_path, 'file_type': "image", "image": image})
+
+         # Process images with asyncio.Semaphore
+         semaphore = asyncio.Semaphore(concurrent_batch_size)
+         async def semaphore_helper(page:Dict, max_new_tokens:int, temperature:float, **kwrs):
+             try:
+                 messages = self.vlm_engine.get_ocr_messages(self.system_prompt, self.user_prompt, page["image"])
+                 async with semaphore:
+                     async_task = self.vlm_engine.chat_async(
+                         messages,
+                         max_new_tokens=max_new_tokens,
+                         temperature=temperature,
+                         **kwrs
+                     )
+                     return await async_task
+             except Exception as e:
+                 print(f"Error processing image: {e}")
+                 return f"[Error: {e}]"
+
+         tasks = []
+         for page in flat_page_list:
+             async_task = semaphore_helper(
+                 page,
+                 max_new_tokens=max_new_tokens,
+                 temperature=temperature,
+                 **kwrs
+             )
+             tasks.append(asyncio.create_task(async_task))
+
+         responses = await asyncio.gather(*tasks)
+
+         # Restructure the results
+         ocr_results = []
+         pdf_page_text_buffer = ""
+         for page, ocr_text in zip(flat_page_list, responses):
+             # PDF
+             if page['file_type'] == "PDF":
+                 pdf_page_text_buffer += ocr_text + self.page_delimiter
+                 if page['page_num'] == page['total_page_count'] - 1:
+                     if self.output_mode == "markdown":
+                         pdf_page_text_buffer = clean_markdown(pdf_page_text_buffer)
+                     ocr_results.append(pdf_page_text_buffer)
+                     pdf_page_text_buffer = ""
+             # Image
+             if page['file_type'] == "image":
+                 if self.output_mode == "markdown":
+                     ocr_text = clean_markdown(ocr_text)
+                 ocr_results.append(ocr_text)
+
+         return ocr_results
+
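
A hedged usage sketch of the OCREngine API defined above; the file paths and the OpenAI model are placeholders, and the OpenAI client is assumed to read its API key from the environment:

```python
# Sketch: batch OCR and streaming OCR with OCREngine.
# "scan.pdf", "page.png", and the model name are illustrative placeholders.
from vlm4ocr import OCREngine, OpenAIVLMEngine

vlm = OpenAIVLMEngine(model="gpt-4o-mini")               # assumes OPENAI_API_KEY is set
ocr = OCREngine(vlm_engine=vlm, output_mode="markdown")

# Batch mode: one result string per input file; PDF pages are joined with page_delimiter.
results = ocr.run_ocr(["scan.pdf", "page.png"], concurrent=True, concurrent_batch_size=8)

# Streaming mode: chunks are yielded as the model generates them.
for chunk in ocr.stream_ocr("scan.pdf"):
    print(chunk, end="", flush=True)
```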
vlm4ocr-0.0.1/vlm4ocr/utils.py ADDED
@@ -0,0 +1,47 @@
+ import os
+ import io
+ import base64
+ from typing import List
+ from pdf2image import convert_from_path
+ from PIL import Image
+
+ def get_images_from_pdf(file_path: str) -> List[Image.Image]:
+     """ Extracts images from a PDF file. """
+     try:
+         images = convert_from_path(file_path)
+         if not images:
+             print(f"Warning: No images extracted from PDF: {file_path}")
+         return images
+     except Exception as e:
+         print(f"Error converting PDF to images: {e}")
+         raise ValueError(f"Failed to process PDF file '{os.path.basename(file_path)}'. Ensure poppler is installed and the file is valid.") from e
+
+
+ def get_image_from_file(file_path: str) -> Image.Image:
+     """ Loads a single image file. """
+     try:
+         image = Image.open(file_path)
+         image.load()
+         return image
+     except FileNotFoundError:
+         raise FileNotFoundError(f"Image file not found: {file_path}")
+     except Exception as e:
+         raise ValueError(f"Failed to load image file '{os.path.basename(file_path)}': {e}") from e
+
+
+ def image_to_base64(image:Image.Image, format:str="png") -> str:
+     """ Converts an image to a base64 string. """
+     try:
+         buffered = io.BytesIO()
+         image.save(buffered, format=format)
+         img_bytes = buffered.getvalue()
+         encoded_bytes = base64.b64encode(img_bytes)
+         base64_encoded_string = encoded_bytes.decode('utf-8')
+         return base64_encoded_string
+     except Exception as e:
+         print(f"Error converting image to base64: {e}")
+         raise ValueError(f"Failed to convert image to base64: {e}") from e
+
+ def clean_markdown(text:str) -> str:
+     cleaned_text = text.replace("```markdown", "").replace("```", "")
+     return cleaned_text
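
A short sketch exercising the helpers above; the sample file name is an assumption:

```python
# Illustrative use of the utility functions; "sample.pdf" is a placeholder path.
from vlm4ocr.utils import get_images_from_pdf, image_to_base64, clean_markdown

pages = get_images_from_pdf("sample.pdf")            # one PIL.Image.Image per page
b64 = image_to_base64(pages[0], format="png")        # base64 string for a chat message payload
text = clean_markdown("```markdown\n# Title\n```")   # strips the code fences, keeping "# Title"
```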
vlm4ocr-0.0.1/vlm4ocr/vlm_engines.py ADDED
@@ -0,0 +1,444 @@
+ import abc
+ import importlib.util
+ from typing import List, Dict, Union, Generator
+ import warnings
+ from PIL import Image
+ from vlm4ocr.utils import image_to_base64
+
+
+ class VLMEngine:
+     @abc.abstractmethod
+     def __init__(self):
+         """
+         This is an abstract class that provides interfaces for VLM inference engines.
+         Child classes that inherit this class can be used in extractors. Must implement the chat() method.
+         """
+         return NotImplemented
+
+     @abc.abstractmethod
+     def chat(self, messages:List[Dict[str,str]], max_new_tokens:int=4096, temperature:float=0.0,
+              verbose:bool=False, stream:bool=False, **kwrs) -> Union[str, Generator[str, None, None]]:
+         """
+         This method inputs chat messages and outputs VLM generated text.
+
+         Parameters:
+         ----------
+         messages : List[Dict[str,str]]
+             a list of dicts with role and content. role must be one of {"system", "user", "assistant"}
+         max_new_tokens : int, Optional
+             the max number of new tokens the VLM can generate.
+         temperature : float, Optional
+             the temperature for token sampling.
+         verbose : bool, Optional
+             if True, VLM generated text is printed to the terminal in real time.
+         stream : bool, Optional
+             if True, returns a generator that yields the output in real time.
+         """
+         return NotImplemented
+
+     @abc.abstractmethod
+     def chat_async(self, messages:List[Dict[str,str]], max_new_tokens:int=4096, temperature:float=0.0, **kwrs) -> str:
+         """
+         The async version of the chat method. Streaming is not supported.
+         """
+         return NotImplemented
+
+     @abc.abstractmethod
+     def get_ocr_messages(self, system_prompt:str, user_prompt:str, image:Image.Image) -> List[Dict[str,str]]:
+         """
+         This method inputs an image and returns the corresponding chat messages for the inference engine.
+
+         Parameters:
+         ----------
+         system_prompt : str
+             the system prompt.
+         user_prompt : str
+             the user prompt.
+         image : Image.Image
+             the image for OCR.
+         """
+         return NotImplemented
+
+
+ class OllamaVLMEngine(VLMEngine):
+     def __init__(self, model_name:str, num_ctx:int=4096, keep_alive:int=300, **kwrs):
+         """
+         The Ollama inference engine.
+
+         Parameters:
+         ----------
+         model_name : str
+             the model name exactly as shown in >> ollama ls
+         num_ctx : int, Optional
+             context length that the VLM will evaluate.
+         keep_alive : int, Optional
+             seconds to keep the VLM loaded after the last API call.
+         """
+         if importlib.util.find_spec("ollama") is None:
+             raise ImportError("ollama-python not found. Please install ollama-python (```pip install ollama```).")
+
+         from ollama import Client, AsyncClient
+         self.client = Client(**kwrs)
+         self.async_client = AsyncClient(**kwrs)
+         self.model_name = model_name
+         self.num_ctx = num_ctx
+         self.keep_alive = keep_alive
+
+     def chat(self, messages:List[Dict[str,str]], max_new_tokens:int=4096, temperature:float=0.0,
+              verbose:bool=False, stream:bool=False, **kwrs) -> Union[str, Generator[str, None, None]]:
+         """
+         This method inputs chat messages and outputs VLM generated text.
+
+         Parameters:
+         ----------
+         messages : List[Dict[str,str]]
+             a list of dicts with role and content. role must be one of {"system", "user", "assistant"}
+         max_new_tokens : int, Optional
+             the max number of new tokens the VLM can generate.
+         temperature : float, Optional
+             the temperature for token sampling.
+         verbose : bool, Optional
+             if True, VLM generated text is printed to the terminal in real time.
+         stream : bool, Optional
+             if True, returns a generator that yields the output in real time.
+         """
+         options={'temperature':temperature, 'num_ctx': self.num_ctx, 'num_predict': max_new_tokens, **kwrs}
+         if stream:
+             def _stream_generator():
+                 response_stream = self.client.chat(
+                     model=self.model_name,
+                     messages=messages,
+                     options=options,
+                     stream=True,
+                     keep_alive=self.keep_alive
+                 )
+                 for chunk in response_stream:
+                     content_chunk = chunk.get('message', {}).get('content')
+                     if content_chunk:
+                         yield content_chunk
+
+             return _stream_generator()
+
+         elif verbose:
+             response = self.client.chat(
+                 model=self.model_name,
+                 messages=messages,
+                 options=options,
+                 stream=True,
+                 keep_alive=self.keep_alive
+             )
+
+             res = ''
+             for chunk in response:
+                 content_chunk = chunk.get('message', {}).get('content')
+                 if content_chunk:
+                     print(content_chunk, end='', flush=True)
+                     res += content_chunk
+             print('\n')
+             return res
+
+         response = self.client.chat(
+             model=self.model_name,
+             messages=messages,
+             options=options,
+             stream=False,
+             keep_alive=self.keep_alive
+         )
+         return response.get('message', {}).get('content', '')
+
+
+     async def chat_async(self, messages:List[Dict[str,str]], max_new_tokens:int=4096, temperature:float=0.0, **kwrs) -> str:
+         """
+         Async version of the chat method. Streaming is not supported.
+         """
+         response = await self.async_client.chat(
+             model=self.model_name,
+             messages=messages,
+             options={'temperature':temperature, 'num_ctx': self.num_ctx, 'num_predict': max_new_tokens, **kwrs},
+             stream=False,
+             keep_alive=self.keep_alive
+         )
+
+         return response.get('message', {}).get('content', '')
+
+     def get_ocr_messages(self, system_prompt:str, user_prompt:str, image:Image.Image) -> List[Dict[str,str]]:
+         """
+         This method inputs an image and returns the corresponding chat messages for the inference engine.
+
+         Parameters:
+         ----------
+         system_prompt : str
+             the system prompt.
+         user_prompt : str
+             the user prompt.
+         image : Image.Image
+             the image for OCR.
+         """
+         base64_str = image_to_base64(image)
+         return [
+             {"role": "system", "content": system_prompt},
+             {
+                 "role": "user",
+                 "content": user_prompt,
+                 "images": [base64_str]
+             }
+         ]
+
+
+ class OpenAIVLMEngine(VLMEngine):
+     def __init__(self, model:str, reasoning_model:bool=False, **kwrs):
+         """
+         The OpenAI API inference engine. Supports OpenAI models and OpenAI compatible servers:
+         - vLLM OpenAI compatible server (https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html)
+
+         For parameters and documentation, refer to https://platform.openai.com/docs/api-reference/introduction
+
+         Parameters:
+         ----------
+         model : str
+             model name as described in https://platform.openai.com/docs/models
+         reasoning_model : bool, Optional
+             indicator for OpenAI reasoning models ("o" series).
+         """
+         if importlib.util.find_spec("openai") is None:
+             raise ImportError("OpenAI Python API library not found. Please install OpenAI (```pip install openai```).")
+
+         from openai import OpenAI, AsyncOpenAI
+         self.client = OpenAI(**kwrs)
+         self.async_client = AsyncOpenAI(**kwrs)
+         self.model = model
+         self.reasoning_model = reasoning_model
+
+     def chat(self, messages:List[Dict[str,str]], max_new_tokens:int=4096, temperature:float=0.0,
+              verbose:bool=False, stream:bool=False, **kwrs) -> Union[str, Generator[str, None, None]]:
+         """
+         This method inputs chat messages and outputs VLM generated text.
+
+         Parameters:
+         ----------
+         messages : List[Dict[str,str]]
+             a list of dicts with role and content. role must be one of {"system", "user", "assistant"}
+         max_new_tokens : int, Optional
+             the max number of new tokens the VLM can generate.
+         temperature : float, Optional
+             the temperature for token sampling.
+         verbose : bool, Optional
+             if True, VLM generated text is printed to the terminal in real time.
+         stream : bool, Optional
+             if True, returns a generator that yields the output in real time.
+         """
+         # For reasoning models
+         if self.reasoning_model:
+             # Reasoning models do not support the temperature parameter
+             if temperature != 0.0:
+                 warnings.warn("Reasoning models do not support temperature parameter. Will be ignored.", UserWarning)
+
+             # Reasoning models do not support system prompts
+             if any(msg['role'] == 'system' for msg in messages):
+                 warnings.warn("Reasoning models do not support system prompts. Will be ignored.", UserWarning)
+                 messages = [msg for msg in messages if msg['role'] != 'system']
+
+
+             if stream:
+                 def _stream_generator():
+                     response_stream = self.client.chat.completions.create(
+                         model=self.model,
+                         messages=messages,
+                         max_completion_tokens=max_new_tokens,
+                         stream=True,
+                         **kwrs
+                     )
+                     for chunk in response_stream:
+                         if len(chunk.choices) > 0:
+                             if chunk.choices[0].delta.content is not None:
+                                 yield chunk.choices[0].delta.content
+                             if chunk.choices[0].finish_reason == "length":
+                                 warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
+                                 if self.reasoning_model:
+                                     warnings.warn("max_new_tokens includes reasoning tokens and output tokens.", UserWarning)
+                 return _stream_generator()
+
+             elif verbose:
+                 response = self.client.chat.completions.create(
+                     model=self.model,
+                     messages=messages,
+                     max_completion_tokens=max_new_tokens,
+                     stream=True,
+                     **kwrs
+                 )
+                 res = ''
+                 for chunk in response:
+                     if len(chunk.choices) > 0:
+                         if chunk.choices[0].delta.content is not None:
+                             res += chunk.choices[0].delta.content
+                             print(chunk.choices[0].delta.content, end="", flush=True)
+                         if chunk.choices[0].finish_reason == "length":
+                             warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
+                             if self.reasoning_model:
+                                 warnings.warn("max_new_tokens includes reasoning tokens and output tokens.", UserWarning)
+
+                 print('\n')
+                 return res
+             else:
+                 response = self.client.chat.completions.create(
+                     model=self.model,
+                     messages=messages,
+                     max_completion_tokens=max_new_tokens,
+                     stream=False,
+                     **kwrs
+                 )
+                 return response.choices[0].message.content
+
+         # For non-reasoning models
+         else:
+             if stream:
+                 def _stream_generator():
+                     response_stream = self.client.chat.completions.create(
+                         model=self.model,
+                         messages=messages,
+                         max_tokens=max_new_tokens,
+                         temperature=temperature,
+                         stream=True,
+                         **kwrs
+                     )
+                     for chunk in response_stream:
+                         if len(chunk.choices) > 0:
+                             if chunk.choices[0].delta.content is not None:
+                                 yield chunk.choices[0].delta.content
+                             if chunk.choices[0].finish_reason == "length":
+                                 warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
+                                 if self.reasoning_model:
+                                     warnings.warn("max_new_tokens includes reasoning tokens and output tokens.", UserWarning)
+                 return _stream_generator()
+
+             elif verbose:
+                 response = self.client.chat.completions.create(
+                     model=self.model,
+                     messages=messages,
+                     max_tokens=max_new_tokens,
+                     temperature=temperature,
+                     stream=True,
+                     **kwrs
+                 )
+                 res = ''
+                 for chunk in response:
+                     if len(chunk.choices) > 0:
+                         if chunk.choices[0].delta.content is not None:
+                             res += chunk.choices[0].delta.content
+                             print(chunk.choices[0].delta.content, end="", flush=True)
+                         if chunk.choices[0].finish_reason == "length":
+                             warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
+                             if self.reasoning_model:
+                                 warnings.warn("max_new_tokens includes reasoning tokens and output tokens.", UserWarning)
+
+                 print('\n')
+                 return res
+
+             else:
+                 response = self.client.chat.completions.create(
+                     model=self.model,
+                     messages=messages,
+                     max_tokens=max_new_tokens,
+                     temperature=temperature,
+                     stream=False,
+                     **kwrs
+                 )
+
+                 return response.choices[0].message.content
+
+
+     async def chat_async(self, messages:List[Dict[str,str]], max_new_tokens:int=4096, temperature:float=0.0, **kwrs) -> str:
+         """
+         Async version of the chat method. Streaming is not supported.
+         """
+         if self.reasoning_model:
+             # Reasoning models do not support the temperature parameter
+             if temperature != 0.0:
+                 warnings.warn("Reasoning models do not support temperature parameter. Will be ignored.", UserWarning)
+
+             # Reasoning models do not support system prompts
+             if any(msg['role'] == 'system' for msg in messages):
+                 warnings.warn("Reasoning models do not support system prompts. Will be ignored.", UserWarning)
+                 messages = [msg for msg in messages if msg['role'] != 'system']
+
+             response = await self.async_client.chat.completions.create(
+                 model=self.model,
+                 messages=messages,
+                 max_completion_tokens=max_new_tokens,
+                 stream=False,
+                 **kwrs
+             )
+
+         else:
+             response = await self.async_client.chat.completions.create(
+                 model=self.model,
+                 messages=messages,
+                 max_tokens=max_new_tokens,
+                 temperature=temperature,
+                 stream=False,
+                 **kwrs
+             )
+
+         if response.choices[0].finish_reason == "length":
+             warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
+             if self.reasoning_model:
+                 warnings.warn("max_new_tokens includes reasoning tokens and output tokens.", UserWarning)
+
+         return response.choices[0].message.content
+
+     def get_ocr_messages(self, system_prompt:str, user_prompt:str, image:Image.Image, format:str='png', detail:str="high") -> List[Dict[str,str]]:
+         """
+         This method inputs an image and returns the corresponding chat messages for the inference engine.
+
+         Parameters:
+         ----------
+         system_prompt : str
+             the system prompt.
+         user_prompt : str
+             the user prompt.
+         image : Image.Image
+             the image for OCR.
+         format : str, Optional
+             the image format.
+         detail : str, Optional
+             the detail level of the image. Default is "high".
+         """
+         base64_str = image_to_base64(image)
+         return [
+             {"role": "system", "content": system_prompt},
+             {
+                 "role": "user",
+                 "content": [
+                     {
+                         "type": "image_url",
+                         "image_url": {
+                             "url": f"data:image/{format};base64,{base64_str}",
+                             "detail": detail
+                         },
+                     },
+                     {"type": "text", "text": user_prompt},
+                 ],
+             },
+         ]
+
+
+ class AzureOpenAIVLMEngine(OpenAIVLMEngine):
+     def __init__(self, model:str, api_version:str, reasoning_model:bool=False, **kwrs):
+         """
+         The Azure OpenAI API inference engine.
+         For parameters and documentation, refer to
+         - https://azure.microsoft.com/en-us/products/ai-services/openai-service
+         - https://learn.microsoft.com/en-us/azure/ai-services/openai/quickstart
+
+         Parameters:
+         ----------
+         model : str
+             model name as described in https://platform.openai.com/docs/models
+         api_version : str
+             the Azure OpenAI API version
+         reasoning_model : bool, Optional
+             indicator for OpenAI reasoning models ("o" series).
+         """
+         if importlib.util.find_spec("openai") is None:
+             raise ImportError("OpenAI Python API library not found. Please install OpenAI (```pip install openai```).")
+
+         from openai import AzureOpenAI, AsyncAzureOpenAI
+         self.model = model
+         self.api_version = api_version
+         self.client = AzureOpenAI(api_version=self.api_version,
+                                   **kwrs)
+         self.async_client = AsyncAzureOpenAI(api_version=self.api_version,
+                                              **kwrs)
+         self.reasoning_model = reasoning_model
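
For completeness, a sketch of pointing the engines above at different backends; the endpoints, deployment, and model names are placeholders, and credentials are assumed to come from the usual environment variables:

```python
# Illustrative engine construction; endpoints, deployments, and model names are placeholders.
from vlm4ocr import OllamaVLMEngine, OpenAIVLMEngine, AzureOpenAIVLMEngine

# Local Ollama server (model name exactly as shown by `ollama ls`).
ollama_engine = OllamaVLMEngine(model_name="llama3.2-vision:11b", num_ctx=8192)

# OpenAI-compatible server, e.g. a vLLM endpoint; extra kwargs are forwarded to the OpenAI client.
vllm_engine = OpenAIVLMEngine(
    model="Qwen/Qwen2-VL-7B-Instruct",
    base_url="http://localhost:8000/v1",
    api_key="EMPTY",
)

# Azure OpenAI deployment; extra kwargs are forwarded to the AzureOpenAI client.
azure_engine = AzureOpenAIVLMEngine(
    model="my-gpt-4o-deployment",
    api_version="2024-02-15-preview",
    azure_endpoint="https://my-resource.openai.azure.com",
)
```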