vlm4ocr 0.3.0__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vlm4ocr/__init__.py +5 -1
- vlm4ocr/cli.py +4 -13
- vlm4ocr/data_types.py +17 -5
- vlm4ocr/ocr_engines.py +30 -13
- vlm4ocr/vlm_engines.py +639 -46
- {vlm4ocr-0.3.0.dist-info → vlm4ocr-0.3.1.dist-info}/METADATA +1 -1
- {vlm4ocr-0.3.0.dist-info → vlm4ocr-0.3.1.dist-info}/RECORD +9 -9
- {vlm4ocr-0.3.0.dist-info → vlm4ocr-0.3.1.dist-info}/WHEEL +0 -0
- {vlm4ocr-0.3.0.dist-info → vlm4ocr-0.3.1.dist-info}/entry_points.txt +0 -0
vlm4ocr/__init__.py
CHANGED
@@ -1,11 +1,15 @@
 from .ocr_engines import OCREngine
-from .vlm_engines import BasicVLMConfig, OpenAIReasoningVLMConfig, OllamaVLMEngine, OpenAIVLMEngine, AzureOpenAIVLMEngine
+from .vlm_engines import BasicVLMConfig, ReasoningVLMConfig, OpenAIReasoningVLMConfig, OllamaVLMEngine, OpenAICompatibleVLMEngine, VLLMVLMEngine, OpenRouterVLMEngine, OpenAIVLMEngine, AzureOpenAIVLMEngine

 __all__ = [
     "BasicVLMConfig",
+    "ReasoningVLMConfig",
     "OpenAIReasoningVLMConfig",
     "OCREngine",
     "OllamaVLMEngine",
+    "OpenAICompatibleVLMEngine",
+    "VLLMVLMEngine",
+    "OpenRouterVLMEngine",
     "OpenAIVLMEngine",
     "AzureOpenAIVLMEngine"
 ]
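For orientation, the sketch below (not part of the package) shows how the enlarged 0.3.1 public API might be wired together. Only the engine and config names come from the exports above; the OCREngine keyword arguments are assumptions.

# Hypothetical usage sketch; OCREngine's constructor arguments are assumed, not shown in this diff.
from vlm4ocr import OCREngine, OpenAICompatibleVLMEngine, ReasoningVLMConfig

config = ReasoningVLMConfig()  # splits <think>...</think> reasoning from the OCR text
vlm = OpenAICompatibleVLMEngine(model="my-vlm", api_key="EMPTY",
                                base_url="http://localhost:8000/v1", config=config)
ocr = OCREngine(vlm_engine=vlm, output_mode="markdown")  # assumed signature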
vlm4ocr/cli.py
CHANGED
@@ -4,18 +4,9 @@ import sys
 import logging
 import asyncio
 import time
-
-
-try:
-    from .ocr_engines import OCREngine
-    from .vlm_engines import OpenAIVLMEngine, AzureOpenAIVLMEngine, OllamaVLMEngine, BasicVLMConfig
-    from .data_types import OCRResult
-except ImportError:
-    # Fallback for when the package is installed
-    from vlm4ocr.ocr_engines import OCREngine
-    from vlm4ocr.vlm_engines import OpenAIVLMEngine, AzureOpenAIVLMEngine, OllamaVLMEngine, BasicVLMConfig
-    from vlm4ocr.data_types import OCRResult
-
+from .ocr_engines import OCREngine
+from .vlm_engines import OpenAICompatibleVLMEngine, OpenAIVLMEngine, AzureOpenAIVLMEngine, OllamaVLMEngine, BasicVLMConfig
+from .data_types import OCRResult
 import tqdm.asyncio

 # --- Global logger setup (console) ---
@@ -208,7 +199,7 @@ def main():
         vlm_engine_instance = OpenAIVLMEngine(model=args.model, api_key=args.api_key, config=config)
     elif args.vlm_engine == "openai_compatible":
         if not args.base_url: parser.error("--base_url is required for openai_compatible.")
-        vlm_engine_instance =
+        vlm_engine_instance = OpenAICompatibleVLMEngine(model=args.model, api_key=args.api_key, base_url=args.base_url, config=config)
     elif args.vlm_engine == "azure_openai":
         if not args.azure_api_key: parser.error("--azure_api_key (or AZURE_OPENAI_API_KEY) is required.")
         if not args.azure_endpoint: parser.error("--azure_endpoint (or AZURE_OPENAI_ENDPOINT) is required.")
vlm4ocr/data_types.py
CHANGED
@@ -1,5 +1,5 @@
 import os
-from typing import List, Literal
+from typing import List, Dict, Literal
 from dataclasses import dataclass, field
 from vlm4ocr.utils import get_default_page_delimiter

@@ -24,6 +24,7 @@ class OCRResult:
     pages: List[dict] = field(default_factory=list)
     filename: str = field(init=False)
     status: str = field(init=False, default="processing")
+    messages_log: List[List[Dict[str,str]]] = field(default_factory=list)

     def __post_init__(self):
         """
@@ -67,10 +68,6 @@ class OCRResult:
         }
         self.pages.append(page)

-
-    def __len__(self):
-        return len(self.pages)
-
     def get_page(self, idx):
         if not isinstance(idx, int):
             raise ValueError("Index must be an integer")
@@ -78,6 +75,21 @@ class OCRResult:
             raise IndexError(f"Index out of range. The OCRResult has {len(self.pages)} pages, but index {idx} was requested.")

         return self.pages[idx]
+
+    def clear_messages_log(self):
+        self.messages_log = []
+
+    def add_messages_to_log(self, messages: List[Dict[str,str]]):
+        if not isinstance(messages, list):
+            raise ValueError("messages must be a list of dict")
+
+        self.messages_log.extend(messages)
+
+    def get_messages_log(self) -> List[List[Dict[str,str]]]:
+        return self.messages_log.copy()
+
+    def __len__(self):
+        return len(self.pages)

     def __iter__(self):
         return iter(self.pages)
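The new messages_log field and its accessors above make it possible to inspect the exact chat exchanges behind an OCR run. A small hypothetical sketch, assuming `result` is an OCRResult produced by OCREngine in 0.3.1:

# Hypothetical: each entry in the log is one conversation (a list of role/content dicts).
for conversation in result.get_messages_log():
    for message in conversation:
        # assistant entries may also carry a "reasoning" field (see vlm_engines.py below)
        print(message["role"], str(message.get("content", ""))[:80])

result.clear_messages_log()  # drop the log once it has been inspected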
vlm4ocr/ocr_engines.py
CHANGED
@@ -6,7 +6,7 @@ from colorama import Fore, Style
 import json
 from vlm4ocr.utils import DataLoader, PDFDataLoader, TIFFDataLoader, ImageDataLoader, ImageProcessor, clean_markdown, extract_json, get_default_page_delimiter
 from vlm4ocr.data_types import OCRResult
-from vlm4ocr.vlm_engines import VLMEngine
+from vlm4ocr.vlm_engines import VLMEngine, MessagesLogger

 SUPPORTED_IMAGE_EXTS = ['.pdf', '.tif', '.tiff', '.png', '.jpg', '.jpeg', '.bmp', '.gif', '.webp']

@@ -126,7 +126,8 @@ class OCREngine:
                 stream=True
             )
             for chunk in response_stream:
-
+                if chunk["type"] == "response":
+                    yield {"type": "ocr_chunk", "data": chunk["data"]}

             if i < len(images) - 1:
                 yield {"type": "page_delimiter", "data": get_default_page_delimiter(self.output_mode)}
@@ -157,7 +158,8 @@ class OCREngine:
                 stream=True
             )
             for chunk in response_stream:
-
+                if chunk["type"] == "response":
+                    yield {"type": "ocr_chunk", "data": chunk["data"]}


     def sequential_ocr(self, file_paths: Union[str, Iterable[str]], rotate_correction:bool=False,
@@ -271,24 +273,32 @@ class OCREngine:

             try:
                 messages = self.vlm_engine.get_ocr_messages(self.system_prompt, self.user_prompt, image)
+                # Define a messages logger to capture messages
+                messages_logger = MessagesLogger()
+                # Generate response
                 response = self.vlm_engine.chat(
                     messages,
                     verbose=verbose,
-                    stream=False
+                    stream=False,
+                    messages_logger=messages_logger
                 )
+                ocr_text = response["response"]
                 # Clean the response if output mode is markdown
                 if self.output_mode == "markdown":
-
+                    ocr_text = clean_markdown(ocr_text)

                 # Parse the response if output mode is JSON
-
-                    json_list = extract_json(
+                elif self.output_mode == "JSON":
+                    json_list = extract_json(ocr_text)
                     # Serialize the JSON list to a string
-
+                    ocr_text = json.dumps(json_list, indent=4)

                 # Add the page to the OCR result
-                ocr_result.add_page(text=
+                ocr_result.add_page(text=ocr_text,
                                     image_processing_status=image_processing_status)
+
+                # Add messages log to the OCR result
+                ocr_result.add_messages_to_log(messages_logger.get_messages_log())

             except Exception as page_e:
                 ocr_result.status = "error"
@@ -387,6 +397,7 @@ class OCREngine:
         filename = os.path.basename(file_path)
         file_ext = os.path.splitext(file_path)[1].lower()
         result = OCRResult(input_dir=file_path, output_mode=self.output_mode)
+        messages_logger = MessagesLogger()
        # check file extension
         if file_ext not in SUPPORTED_IMAGE_EXTS:
             result.status = "error"
@@ -416,7 +427,8 @@ class OCREngine:
                     data_loader=data_loader,
                     page_index=page_index,
                     rotate_correction=rotate_correction,
-                    max_dimension_pixels=max_dimension_pixels
+                    max_dimension_pixels=max_dimension_pixels,
+                    messages_logger=messages_logger
                 )
                 page_processing_tasks.append(task)

@@ -428,14 +440,17 @@ class OCREngine:
         except Exception as e:
             result.status = "error"
             result.add_page(text=f"Error during OCR for {filename}: {str(e)}", image_processing_status={})
+            result.add_messages_to_log(messages_logger.get_messages_log())
             return result

         # Set status to success if no errors occurred
         result.status = "success"
+        result.add_messages_to_log(messages_logger.get_messages_log())
         return result

     async def _ocr_page_with_semaphore(self, vlm_call_semaphore: asyncio.Semaphore, data_loader: DataLoader,
-                                       page_index:int, rotate_correction:bool=False, max_dimension_pixels:int=None
+                                       page_index:int, rotate_correction:bool=False, max_dimension_pixels:int=None,
+                                       messages_logger:MessagesLogger=None) -> Tuple[str, Dict[str, str]]:
         """
         This internal method takes a semaphore and OCR a single image/page using the VLM inference engine.

@@ -476,15 +491,17 @@ class OCREngine:
                 }

             messages = self.vlm_engine.get_ocr_messages(self.system_prompt, self.user_prompt, image)
-
+            response = await self.vlm_engine.chat_async(
                 messages,
+                messages_logger=messages_logger
             )
+            ocr_text = response["response"]
             # Clean the OCR text if output mode is markdown
             if self.output_mode == "markdown":
                 ocr_text = clean_markdown(ocr_text)

             # Parse the response if output mode is JSON
-
+            elif self.output_mode == "JSON":
                 json_list = extract_json(ocr_text)
                 # Serialize the JSON list to a string
                 ocr_text = json.dumps(json_list, indent=4)
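To make the flow above concrete, here is a hypothetical end-to-end sketch (not from the package) of the path a MessagesLogger travels from a chat call into an OCRResult, using only interfaces visible in this diff (OpenAICompatibleVLMEngine is defined in vlm_engines.py below). The model name, base URL, prompt strings, and file name are placeholders.

from PIL import Image
from vlm4ocr.vlm_engines import OpenAICompatibleVLMEngine, MessagesLogger
from vlm4ocr.data_types import OCRResult

vlm = OpenAICompatibleVLMEngine(model="my-vlm", api_key="EMPTY",
                                base_url="http://localhost:8000/v1")
image = Image.open("page_1.png")

logger = MessagesLogger()
messages = vlm.get_ocr_messages("You are an OCR engine.", "Transcribe this page.", image)
reply = vlm.chat(messages, verbose=False, stream=False, messages_logger=logger)

result = OCRResult(input_dir="page_1.png", output_mode="markdown")
result.add_page(text=reply["response"], image_processing_status={})
result.add_messages_to_log(logger.get_messages_log())  # same call OCREngine now makes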
vlm4ocr/vlm_engines.py
CHANGED
@@ -2,6 +2,8 @@ import abc
 import importlib.util
 from typing import Any, List, Dict, Union, Generator
 import warnings
+import os
+import re
 from PIL import Image
 from vlm4ocr.utils import image_to_base64

@@ -33,7 +35,7 @@ class VLMConfig(abc.ABC):
         return NotImplemented

     @abc.abstractmethod
-    def postprocess_response(self, response:Union[str, Generator[str, None, None]]) -> Union[str, Generator[str, None, None]]:
+    def postprocess_response(self, response:Union[str, Dict[str, str], Generator[str, None, None]]) -> Union[str, Generator[str, None, None]]:
        """
         This method postprocesses the VLM response after it is generated.

@@ -77,7 +79,7 @@ class BasicVLMConfig(VLMConfig):
         """
         return messages

-    def postprocess_response(self, response:Union[str, Generator[str, None, None]]) -> Union[str, Generator[str, None, None]]:
+    def postprocess_response(self, response:Union[str, Dict[str, str], Generator[str, None, None]]) -> Union[Dict[str, str], Generator[Dict[str, str], None, None]]:
         """
         This method postprocesses the VLM response after it is generated.

@@ -88,19 +90,121 @@ class BasicVLMConfig(VLMConfig):

         Returns: Union[str, Generator[Dict[str, str], None, None]]
             the postprocessed VLM response.
-            if input is a generator, the output will be a generator {"data": <content>}.
+            if input is a generator, the output will be a generator {"type": "response", "data": <content>}.
         """
         if isinstance(response, str):
-            return response
+            return {"response": response}
+
+        elif isinstance(response, dict):
+            if "response" in response:
+                return response
+            else:
+                warnings.warn(f"Invalid response dict keys: {response.keys()}. Returning default empty dict.", UserWarning)
+                return {"response": ""}

         def _process_stream():
             for chunk in response:
-
+                if isinstance(chunk, dict):
+                    yield chunk
+                elif isinstance(chunk, str):
+                    yield {"type": "response", "data": chunk}

         return _process_stream()
+
+class ReasoningVLMConfig(VLMConfig):
+    def __init__(self, thinking_token_start="<think>", thinking_token_end="</think>", **kwargs):
+        """
+        The general configuration for reasoning vision models.
+        """
+        super().__init__(**kwargs)
+        self.thinking_token_start = thinking_token_start
+        self.thinking_token_end = thinking_token_end
+
+    def preprocess_messages(self, messages:List[Dict[str,str]]) -> List[Dict[str,str]]:
+        """
+        This method preprocesses the input messages before passing them to the VLM.
+
+        Parameters:
+        ----------
+        messages : List[Dict[str,str]]
+            a list of dict with role and content. role must be one of {"system", "user", "assistant"}
+
+        Returns:
+        -------
+        messages : List[Dict[str,str]]
+            a list of dict with role and content. role must be one of {"system", "user", "assistant"}
+        """
+        return messages.copy()

+    def postprocess_response(self, response:Union[str, Dict[str, str], Generator[str, None, None]]) -> Union[Dict[str,str], Generator[Dict[str,str], None, None]]:
+        """
+        This method postprocesses the VLM response after it is generated.
+        1. If input is a string, it will extract the reasoning and response based on the thinking tokens.
+        2. If input is a dict, it should contain keys "reasoning" and "response". This is for inference engines that already parse reasoning and response.
+        3. If input is a generator,
+            a. if the chunk is a dict, it should contain keys "type" and "data". This is for inference engines that already parse reasoning and response.
+            b. if the chunk is a string, it will yield dicts with keys "type" and "data" based on the thinking tokens.

-class OpenAIReasoningVLMConfig(VLMConfig):
+        Parameters:
+        ----------
+        response : Union[str, Generator[str, None, None]]
+            the VLM response. Can be a string or a generator.
+
+        Returns:
+        -------
+        response : Union[str, Generator[str, None, None]]
+            the postprocessed LLM response as a dict {"reasoning": <reasoning>, "response": <content>}
+            if input is a generator, the output will be a generator {"type": <reasoning or response>, "data": <content>}.
+        """
+        if isinstance(response, str):
+            # get contents between thinking_token_start and thinking_token_end
+            pattern = f"{re.escape(self.thinking_token_start)}(.*?){re.escape(self.thinking_token_end)}"
+            match = re.search(pattern, response, re.DOTALL)
+            reasoning = match.group(1) if match else ""
+            # get response AFTER thinking_token_end
+            response = re.sub(f".*?{self.thinking_token_end}", "", response, flags=re.DOTALL).strip()
+            return {"reasoning": reasoning, "response": response}
+
+        elif isinstance(response, dict):
+            if "reasoning" in response and "response" in response:
+                return response
+            else:
+                warnings.warn(f"Invalid response dict keys: {response.keys()}. Returning default empty dict.", UserWarning)
+                return {"reasoning": "", "response": ""}
+
+        elif isinstance(response, Generator):
+            def _process_stream():
+                think_flag = False
+                buffer = ""
+                for chunk in response:
+                    if isinstance(chunk, dict):
+                        yield chunk
+
+                    elif isinstance(chunk, str):
+                        buffer += chunk
+                        # switch between reasoning and response
+                        if self.thinking_token_start in buffer:
+                            think_flag = True
+                            buffer = buffer.replace(self.thinking_token_start, "")
+                        elif self.thinking_token_end in buffer:
+                            think_flag = False
+                            buffer = buffer.replace(self.thinking_token_end, "")
+
+                        # if chunk is in thinking block, tag it as reasoning; else tag it as response
+                        if chunk not in [self.thinking_token_start, self.thinking_token_end]:
+                            if think_flag:
+                                yield {"type": "reasoning", "data": chunk}
+                            else:
+                                yield {"type": "response", "data": chunk}
+
+            return _process_stream()
+
+        else:
+            warnings.warn(f"Invalid response type: {type(response)}. Returning default empty dict.", UserWarning)
+            return {"reasoning": "", "response": ""}
+
+
+class OpenAIReasoningVLMConfig(ReasoningVLMConfig):
     def __init__(self, reasoning_effort:str="low", **kwargs):
         """
         The OpenAI "o" series configuration.
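As a quick, hypothetical check of the behaviour added above (not part of the package), ReasoningVLMConfig.postprocess_response is expected to split a <think>-tagged reply into reasoning and response; the sample text is made up.

from vlm4ocr.vlm_engines import ReasoningVLMConfig

config = ReasoningVLMConfig()  # defaults: "<think>" / "</think>"
raw = "<think>The page looks like a two-column table.</think>| Name | Qty |"
out = config.postprocess_response(raw)
# expected: {"reasoning": "The page looks like a two-column table.",
#            "response": "| Name | Qty |"}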
@@ -160,27 +264,31 @@ class OpenAIReasoningVLMConfig(VLMConfig)

         return new_messages

-    def postprocess_response(self, response:Union[str, Generator[str, None, None]]) -> Union[str, Generator[Dict[str, str], None, None]]:
-        """
-        This method postprocesses the VLM response after it is generated.

-
-
-        response : Union[str, Generator[str, None, None]]
-            the VLM response. Can be a string or a generator.
-
-        Returns: Union[str, Generator[Dict[str, str], None, None]]
-            the postprocessed VLM response.
-            if input is a generator, the output will be a generator {"type": "response", "data": <content>}.
+class MessagesLogger:
+    def __init__(self):
         """
-
-
+        This class is used to log the messages for InferenceEngine.chat().
+        """
+        self.messages_log = []

-
-
-
+    def log_messages(self, messages : List[Dict[str,str]]):
+        """
+        This method logs the messages to a list.
+        """
+        self.messages_log.append(messages)

-
+    def get_messages_log(self) -> List[List[Dict[str,str]]]:
+        """
+        This method returns a copy of the current messages log
+        """
+        return self.messages_log.copy()
+
+    def clear_messages_log(self):
+        """
+        This method clears the current messages log
+        """
+        self.messages_log.clear()


 class VLMEngine:
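Based on the logging calls added to the engines below, one logged conversation is expected to look roughly like the following; the prompt strings are placeholders, and image payloads are replaced with "[image]" before logging.

# Hypothetical shape of one entry returned by MessagesLogger.get_messages_log()
logged_conversation = [
    {"role": "system", "content": "You are an OCR engine."},
    {"role": "user", "content": [
        {"type": "image_url", "image_url": {"url": "[image]", "detail": "high"}},
        {"type": "text", "text": "Transcribe this page."},
    ]},
    {"role": "assistant", "content": "OCR text ...", "reasoning": ""},
]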
@@ -198,7 +306,8 @@ class VLMEngine:
         return NotImplemented

     @abc.abstractmethod
-    def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False
+    def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False,
+             messages_logger:MessagesLogger=None) -> Union[Dict[str, str], Generator[Dict[str, str], None, None]]:
         """
         This method inputs chat messages and outputs VLM generated text.

@@ -210,11 +319,13 @@ class VLMEngine:
             if True, VLM generated text will be printed in terminal in real-time.
         stream : bool, Optional
             if True, returns a generator that yields the output in real-time.
+        Messages_logger : MessagesLogger, Optional
+            the message logger that logs the chat messages.
         """
         return NotImplemented

     @abc.abstractmethod
-    def chat_async(self, messages:List[Dict[str,str]]) -> str:
+    def chat_async(self, messages:List[Dict[str,str]], messages_logger:MessagesLogger=None) -> Dict[str, str]:
         """
         The async version of chat method. Streaming is not supported.
         """
@@ -285,7 +396,8 @@ class OllamaVLMEngine(VLMEngine):

         return formatted_params

-    def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False
+    def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False,
+             messages_logger:MessagesLogger=None) -> Union[Dict[str,str], Generator[Dict[str, str], None, None]]:
         """
         This method inputs chat messages and outputs VLM generated text.

@@ -297,6 +409,13 @@ class OllamaVLMEngine(VLMEngine):
             if True, VLM generated text will be printed in terminal in real-time.
         stream : bool, Optional
             if True, returns a generator that yields the output in real-time.
+        Messages_logger : MessagesLogger, Optional
+            the message logger that logs the chat messages.
+
+        Returns:
+        -------
+        response : Union[Dict[str,str], Generator[Dict[str, str], None, None]]
+            a dict {"reasoning": <reasoning>, "response": <response>} or Generator {"type": <reasoning or response>, "data": <content>}
         """
         processed_messages = self.config.preprocess_messages(messages)

@@ -310,10 +429,33 @@ class OllamaVLMEngine(VLMEngine):
                     stream=True,
                     keep_alive=self.keep_alive
                 )
+                res = {"reasoning": "", "response": ""}
                 for chunk in response_stream:
-
-
-
+                    if hasattr(chunk.message, 'thinking') and chunk.message.thinking:
+                        content_chunk = getattr(getattr(chunk, 'message', {}), 'thinking', '')
+                        res["reasoning"] += content_chunk
+                        yield {"type": "reasoning", "data": content_chunk}
+                    else:
+                        content_chunk = getattr(getattr(chunk, 'message', {}), 'content', '')
+                        res["response"] += content_chunk
+                        yield {"type": "response", "data": content_chunk}
+
+                    if chunk.done_reason == "length":
+                        warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
+
+                # Postprocess response
+                res_dict = self.config.postprocess_response(res)
+                # Write to messages log
+                if messages_logger:
+                    # replace images content with a placeholder "[image]" to save space
+                    for messages in processed_messages:
+                        if "images" in messages:
+                            messages["images"] = ["[image]" for _ in messages["images"]]
+
+                    processed_messages.append({"role": "assistant",
+                                               "content": res_dict.get("response", ""),
+                                               "reasoning": res_dict.get("reasoning", "")})
+                    messages_logger.log_messages(processed_messages)

             return self.config.postprocess_response(_stream_generator())

@@ -326,14 +468,29 @@ class OllamaVLMEngine(VLMEngine):
                 keep_alive=self.keep_alive
             )

-            res =
+            res = {"reasoning": "", "response": ""}
+            phase = ""
             for chunk in response:
-
+                if hasattr(chunk.message, 'thinking') and chunk.message.thinking:
+                    if phase != "reasoning":
+                        print("\n--- Reasoning ---")
+                        phase = "reasoning"
+
+                    content_chunk = getattr(getattr(chunk, 'message', {}), 'thinking', '')
+                    res["reasoning"] += content_chunk
+                else:
+                    if phase != "response":
+                        print("\n--- Response ---")
+                        phase = "response"
+                    content_chunk = getattr(getattr(chunk, 'message', {}), 'content', '')
+                    res["response"] += content_chunk
+
                 print(content_chunk, end='', flush=True)
-
+
+                if chunk.done_reason == "length":
+                    warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
             print('\n')
-
-
+
         else:
             response = self.client.chat(
                 model=self.model_name,
@@ -342,11 +499,30 @@ class OllamaVLMEngine(VLMEngine):
                 stream=False,
                 keep_alive=self.keep_alive
             )
-            res = response
-
+            res = {"reasoning": getattr(getattr(response, 'message', {}), 'thinking', ''),
+                   "response": getattr(getattr(response, 'message', {}), 'content', '')}
+
+            if response.done_reason == "length":
+                warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
+
+        # Postprocess response
+        res_dict = self.config.postprocess_response(res)
+        # Write to messages log
+        if messages_logger:
+            # replace images content with a placeholder "[image]" to save space
+            for messages in processed_messages:
+                if "images" in messages:
+                    messages["images"] = ["[image]" for _ in messages["images"]]
+
+            processed_messages.append({"role": "assistant",
+                                       "content": res_dict.get("response", ""),
+                                       "reasoning": res_dict.get("reasoning", "")})
+            messages_logger.log_messages(processed_messages)
+
+        return res_dict


-    async def chat_async(self, messages:List[Dict[str,str]]) -> str:
+    async def chat_async(self, messages:List[Dict[str,str]], messages_logger:MessagesLogger=None) -> Dict[str,str]:
         """
         Async version of chat method. Streaming is not supported.
         """
@@ -360,8 +536,26 @@ class OllamaVLMEngine(VLMEngine):
             keep_alive=self.keep_alive
         )

-        res = response
-
+        res = {"reasoning": getattr(getattr(response, 'message', {}), 'thinking', ''),
+               "response": getattr(getattr(response, 'message', {}), 'content', '')}
+
+        if response.done_reason == "length":
+            warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
+        # Postprocess response
+        res_dict = self.config.postprocess_response(res)
+        # Write to messages log
+        if messages_logger:
+            # replace images content with a placeholder "[image]" to save space
+            for messages in processed_messages:
+                if "images" in messages:
+                    messages["images"] = ["[image]" for _ in messages["images"]]
+
+            processed_messages.append({"role": "assistant",
+                                       "content": res_dict.get("response", ""),
+                                       "reasoning": res_dict.get("reasoning", "")})
+            messages_logger.log_messages(processed_messages)
+
+        return res_dict

     def get_ocr_messages(self, system_prompt:str, user_prompt:str, image:Image.Image) -> List[Dict[str,str]]:
         """
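A brief, hypothetical consumer of the streaming protocol above (not from the package): chat(stream=True) now yields dicts tagged as reasoning or response, so a caller can route the two separately. Here `vlm` is assumed to be an already constructed OllamaVLMEngine and `messages` a list produced by get_ocr_messages().

for chunk in vlm.chat(messages, stream=True):
    if chunk["type"] == "reasoning":
        pass                                      # e.g. hide or log the chain of thought
    elif chunk["type"] == "response":
        print(chunk["data"], end="", flush=True)  # the OCR text itself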
@@ -387,6 +581,346 @@ class OllamaVLMEngine(VLMEngine):
         ]


+class OpenAICompatibleVLMEngine(VLMEngine):
+    def __init__(self, model:str, api_key:str, base_url:str, config:VLMConfig=None, **kwrs):
+        """
+        General OpenAI-compatible server inference engine.
+        https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html
+
+        For parameters and documentation, refer to https://platform.openai.com/docs/api-reference/introduction
+
+        Parameters:
+        ----------
+        model_name : str
+            model name as shown in the vLLM server
+        api_key : str
+            the API key for the vLLM server.
+        base_url : str
+            the base url for the vLLM server.
+        config : LLMConfig
+            the LLM configuration.
+        """
+        if importlib.util.find_spec("openai") is None:
+            raise ImportError("OpenAI Python API library not found. Please install OpanAI (```pip install openai```).")
+
+        from openai import OpenAI, AsyncOpenAI
+        from openai.types.chat import ChatCompletionChunk
+        self.ChatCompletionChunk = ChatCompletionChunk
+        super().__init__(config)
+        self.client = OpenAI(api_key=api_key, base_url=base_url, **kwrs)
+        self.async_client = AsyncOpenAI(api_key=api_key, base_url=base_url, **kwrs)
+        self.model = model
+        self.config = config if config else BasicVLMConfig()
+        self.formatted_params = self._format_config()
+
+    def _format_config(self) -> Dict[str, Any]:
+        """
+        This method format the VLM configuration with the correct key for the inference engine.
+        """
+        formatted_params = self.config.params.copy()
+        if "max_new_tokens" in formatted_params:
+            formatted_params["max_completion_tokens"] = formatted_params["max_new_tokens"]
+            formatted_params.pop("max_new_tokens")
+
+        return formatted_params
+
+
+    def _format_response(self, response: Any) -> Dict[str, str]:
+        """
+        This method format the response from OpenAI API to a dict with keys "type" and "data".
+
+        Parameters:
+        ----------
+        response : Any
+            the response from OpenAI-compatible API. Could be a dict, generator, or object.
+        """
+        if isinstance(response, self.ChatCompletionChunk):
+            chunk_text = getattr(response.choices[0].delta, "content", "")
+            if chunk_text is None:
+                chunk_text = ""
+            return {"type": "response", "data": chunk_text}
+
+        return {"response": getattr(response.choices[0].message, "content", "")}
+
+    def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False,
+             messages_logger:MessagesLogger=None) -> Union[Dict[str, str], Generator[Dict[str, str], None, None]]:
+        """
+        This method inputs chat messages and outputs LLM generated text.
+
+        Parameters:
+        ----------
+        messages : List[Dict[str,str]]
+            a list of dict with role and content. role must be one of {"system", "user", "assistant"}
+        verbose : bool, Optional
+            if True, VLM generated text will be printed in terminal in real-time.
+        stream : bool, Optional
+            if True, returns a generator that yields the output in real-time.
+        messages_logger : MessagesLogger, Optional
+            the message logger that logs the chat messages.
+
+        Returns:
+        -------
+        response : Union[Dict[str,str], Generator[Dict[str, str], None, None]]
+            a dict {"reasoning": <reasoning>, "response": <response>} or Generator {"type": <reasoning or response>, "data": <content>}
+        """
+        processed_messages = self.config.preprocess_messages(messages)
+
+        if stream:
+            def _stream_generator():
+                response_stream = self.client.chat.completions.create(
+                    model=self.model,
+                    messages=processed_messages,
+                    stream=True,
+                    **self.formatted_params
+                )
+                res_text = ""
+                for chunk in response_stream:
+                    if len(chunk.choices) > 0:
+                        chunk_dict = self._format_response(chunk)
+                        yield chunk_dict
+
+                        res_text += chunk_dict["data"]
+                        if chunk.choices[0].finish_reason == "length":
+                            warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
+
+                # Postprocess response
+                res_dict = self.config.postprocess_response(res_text)
+                # Write to messages log
+                if messages_logger:
+                    # replace images content with a placeholder "[image]" to save space
+                    for messages in processed_messages:
+                        if "content" in messages and isinstance(messages["content"], list):
+                            for content in messages["content"]:
+                                if isinstance(content, dict) and content.get("type") == "image_url":
+                                    content["image_url"]["url"] = "[image]"
+
+                    processed_messages.append({"role": "assistant",
+                                               "content": res_dict.get("response", ""),
+                                               "reasoning": res_dict.get("reasoning", "")})
+                    messages_logger.log_messages(processed_messages)
+
+            return self.config.postprocess_response(_stream_generator())
+
+        elif verbose:
+            response = self.client.chat.completions.create(
+                model=self.model,
+                messages=processed_messages,
+                stream=True,
+                **self.formatted_params
+            )
+            res = {"reasoning": "", "response": ""}
+            phase = ""
+            for chunk in response:
+                if len(chunk.choices) > 0:
+                    chunk_dict = self._format_response(chunk)
+                    chunk_text = chunk_dict["data"]
+                    res[chunk_dict["type"]] += chunk_text
+                    if phase != chunk_dict["type"] and chunk_text != "":
+                        print(f"\n--- {chunk_dict['type'].capitalize()} ---")
+                        phase = chunk_dict["type"]
+
+                    print(chunk_text, end="", flush=True)
+                    if chunk.choices[0].finish_reason == "length":
+                        warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
+
+            print('\n')
+
+        else:
+            response = self.client.chat.completions.create(
+                model=self.model,
+                messages=processed_messages,
+                stream=False,
+                **self.formatted_params
+            )
+            res = self._format_response(response)
+
+            if response.choices[0].finish_reason == "length":
+                warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
+
+        # Postprocess response
+        res_dict = self.config.postprocess_response(res)
+        # Write to messages log
+        if messages_logger:
+            # replace images content with a placeholder "[image]" to save space
+            for messages in processed_messages:
+                if "content" in messages and isinstance(messages["content"], list):
+                    for content in messages["content"]:
+                        if isinstance(content, dict) and content.get("type") == "image_url":
+                            content["image_url"]["url"] = "[image]"
+
+            processed_messages.append({"role": "assistant",
+                                       "content": res_dict.get("response", ""),
+                                       "reasoning": res_dict.get("reasoning", "")})
+            messages_logger.log_messages(processed_messages)
+
+        return res_dict
+
+
+    async def chat_async(self, messages:List[Dict[str,str]], messages_logger:MessagesLogger=None) -> Dict[str,str]:
+        """
+        Async version of chat method. Streaming is not supported.
+        """
+        processed_messages = self.config.preprocess_messages(messages)
+
+        response = await self.async_client.chat.completions.create(
+            model=self.model,
+            messages=processed_messages,
+            stream=False,
+            **self.formatted_params
+        )
+
+        if response.choices[0].finish_reason == "length":
+            warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
+
+        res = self._format_response(response)
+
+        # Postprocess response
+        res_dict = self.config.postprocess_response(res)
+        # Write to messages log
+        if messages_logger:
+            # replace images content with a placeholder "[image]" to save space
+            for messages in processed_messages:
+                if "content" in messages and isinstance(messages["content"], list):
+                    for content in messages["content"]:
+                        if isinstance(content, dict) and content.get("type") == "image_url":
+                            content["image_url"]["url"] = "[image]"
+
+            processed_messages.append({"role": "assistant",
+                                       "content": res_dict.get("response", ""),
+                                       "reasoning": res_dict.get("reasoning", "")})
+            messages_logger.log_messages(processed_messages)
+
+        return res_dict
+
+    def get_ocr_messages(self, system_prompt:str, user_prompt:str, image:Image.Image, format:str='png', detail:str="high") -> List[Dict[str,str]]:
+        """
+        This method inputs an image and returns the correesponding chat messages for the inference engine.
+
+        Parameters:
+        ----------
+        system_prompt : str
+            the system prompt.
+        user_prompt : str
+            the user prompt.
+        image : Image.Image
+            the image for OCR.
+        format : str, Optional
+            the image format.
+        detail : str, Optional
+            the detail level of the image. Default is "high".
+        """
+        base64_str = image_to_base64(image)
+        return [
+            {"role": "system", "content": system_prompt},
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "image_url",
+                        "image_url": {
+                            "url": f"data:image/{format};base64,{base64_str}",
+                            "detail": detail
+                        },
+                    },
+                    {"type": "text", "text": user_prompt},
+                ],
+            },
+        ]
+
+
+class VLLMVLMEngine(OpenAICompatibleVLMEngine):
+    def __init__(self, model:str, api_key:str="", base_url:str="http://localhost:8000/v1", config:VLMConfig=None, **kwrs):
+        """
+        vLLM OpenAI compatible server inference engine.
+        https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html
+
+        For parameters and documentation, refer to https://platform.openai.com/docs/api-reference/introduction
+
+        Parameters:
+        ----------
+        model_name : str
+            model name as shown in the vLLM server
+        api_key : str, Optional
+            the API key for the vLLM server.
+        base_url : str, Optional
+            the base url for the vLLM server.
+        config : LLMConfig
+            the LLM configuration.
+        """
+        super().__init__(model, api_key, base_url, config, **kwrs)
+
+
+    def _format_response(self, response: Any) -> Dict[str, str]:
+        """
+        This method format the response from OpenAI API to a dict with keys "type" and "data".
+
+        Parameters:
+        ----------
+        response : Any
+            the response from OpenAI-compatible API. Could be a dict, generator, or object.
+        """
+        if isinstance(response, self.ChatCompletionChunk):
+            if hasattr(response.choices[0].delta, "reasoning_content") and getattr(response.choices[0].delta, "reasoning_content") is not None:
+                chunk_text = getattr(response.choices[0].delta, "reasoning_content", "")
+                if chunk_text is None:
+                    chunk_text = ""
+                return {"type": "reasoning", "data": chunk_text}
+            else:
+                chunk_text = getattr(response.choices[0].delta, "content", "")
+                if chunk_text is None:
+                    chunk_text = ""
+                return {"type": "response", "data": chunk_text}
+
+        return {"reasoning": getattr(response.choices[0].message, "reasoning_content", ""),
+                "response": getattr(response.choices[0].message, "content", "")}
+
+
+class OpenRouterVLMEngine(OpenAICompatibleVLMEngine):
+    def __init__(self, model:str, api_key:str=None, base_url:str="https://openrouter.ai/api/v1", config:VLMConfig=None, **kwrs):
+        """
+        OpenRouter OpenAI-compatible server inference engine.
+
+        Parameters:
+        ----------
+        model_name : str
+            model name as shown in the vLLM server
+        api_key : str, Optional
+            the API key for the vLLM server. If None, will use the key in os.environ['OPENROUTER_API_KEY'].
+        base_url : str, Optional
+            the base url for the vLLM server.
+        config : LLMConfig
+            the LLM configuration.
+        """
+        self.api_key = api_key
+        if self.api_key is None:
+            self.api_key = os.getenv("OPENROUTER_API_KEY")
+        super().__init__(model, self.api_key, base_url, config, **kwrs)
+
+    def _format_response(self, response: Any) -> Dict[str, str]:
+        """
+        This method format the response from OpenAI API to a dict with keys "type" and "data".
+
+        Parameters:
+        ----------
+        response : Any
+            the response from OpenAI-compatible API. Could be a dict, generator, or object.
+        """
+        if isinstance(response, self.ChatCompletionChunk):
+            if hasattr(response.choices[0].delta, "reasoning") and getattr(response.choices[0].delta, "reasoning") is not None:
+                chunk_text = getattr(response.choices[0].delta, "reasoning", "")
+                if chunk_text is None:
+                    chunk_text = ""
+                return {"type": "reasoning", "data": chunk_text}
+            else:
+                chunk_text = getattr(response.choices[0].delta, "content", "")
+                if chunk_text is None:
+                    chunk_text = ""
+                return {"type": "response", "data": chunk_text}
+
+        return {"reasoning": getattr(response.choices[0].message, "reasoning", ""),
+                "response": getattr(response.choices[0].message, "content", "")}
+
+
 class OpenAIVLMEngine(VLMEngine):
     def __init__(self, model:str, config:VLMConfig=None, **kwrs):
         """
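For context, a hypothetical sketch (not from the package) of constructing the three engines added above; the model names are placeholders.

from vlm4ocr.vlm_engines import (OpenAICompatibleVLMEngine, VLLMVLMEngine,
                                 OpenRouterVLMEngine, ReasoningVLMConfig)

# Any server exposing the OpenAI chat-completions API:
generic = OpenAICompatibleVLMEngine(model="my-vlm", api_key="EMPTY",
                                    base_url="http://my-server:8080/v1")

# vLLM: base_url defaults to http://localhost:8000/v1 and reasoning_content fields are parsed.
local = VLLMVLMEngine(model="Qwen/Qwen2.5-VL-7B-Instruct")

# OpenRouter: falls back to os.environ["OPENROUTER_API_KEY"] when api_key is None.
routed = OpenRouterVLMEngine(model="qwen/qwen2.5-vl-72b-instruct", config=ReasoningVLMConfig())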
@@ -423,7 +957,7 @@ class OpenAIVLMEngine(VLMEngine):

         return formatted_params

-    def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False) -> Union[str, Generator[str, None, None]]:
+    def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False, messages_logger:MessagesLogger=None) -> Union[Dict[str, str], Generator[Dict[str, str], None, None]]:
         """
         This method inputs chat messages and outputs LLM generated text.

@@ -435,6 +969,13 @@ class OpenAIVLMEngine(VLMEngine):
             if True, VLM generated text will be printed in terminal in real-time.
         stream : bool, Optional
             if True, returns a generator that yields the output in real-time.
+        messages_logger : MessagesLogger, Optional
+            the message logger that logs the chat messages.
+
+        Returns:
+        -------
+        response : Union[Dict[str,str], Generator[Dict[str, str], None, None]]
+            a dict {"reasoning": <reasoning>, "response": <response>} or Generator {"type": <reasoning or response>, "data": <content>}
         """
         processed_messages = self.config.preprocess_messages(messages)

@@ -446,13 +987,32 @@ class OpenAIVLMEngine(VLMEngine):
                     stream=True,
                     **self.formatted_params
                 )
+                res_text = ""
                 for chunk in response_stream:
                     if len(chunk.choices) > 0:
-
-
+                        chunk_text = chunk.choices[0].delta.content
+                        if chunk_text is not None:
+                            res_text += chunk_text
+                            yield chunk_text
                        if chunk.choices[0].finish_reason == "length":
                            warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)

+                # Postprocess response
+                res_dict = self.config.postprocess_response(res_text)
+                # Write to messages log
+                if messages_logger:
+                    # replace images content with a placeholder "[image]" to save space
+                    for messages in processed_messages:
+                        if "content" in messages and isinstance(messages["content"], list):
+                            for content in messages["content"]:
+                                if isinstance(content, dict) and content.get("type") == "image_url":
+                                    content["image_url"]["url"] = "[image]"
+
+                    processed_messages.append({"role": "assistant",
+                                               "content": res_dict.get("response", ""),
+                                               "reasoning": res_dict.get("reasoning", "")})
+                    messages_logger.log_messages(processed_messages)
+
             return self.config.postprocess_response(_stream_generator())

         elif verbose:
@@ -472,7 +1032,7 @@ class OpenAIVLMEngine(VLMEngine):
                         warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)

             print('\n')
-
+
         else:
             response = self.client.chat.completions.create(
                 model=self.model,
@@ -481,10 +1041,27 @@ class OpenAIVLMEngine(VLMEngine):
                 **self.formatted_params
             )
             res = response.choices[0].message.content
-
+
+        # Postprocess response
+        res_dict = self.config.postprocess_response(res)
+        # Write to messages log
+        if messages_logger:
+            # replace images content with a placeholder "[image]" to save space
+            for messages in processed_messages:
+                if "content" in messages and isinstance(messages["content"], list):
+                    for content in messages["content"]:
+                        if isinstance(content, dict) and content.get("type") == "image_url":
+                            content["image_url"]["url"] = "[image]"
+
+            processed_messages.append({"role": "assistant",
+                                       "content": res_dict.get("response", ""),
+                                       "reasoning": res_dict.get("reasoning", "")})
+            messages_logger.log_messages(processed_messages)
+
+        return res_dict


-    async def chat_async(self, messages:List[Dict[str,str]]) -> str:
+    async def chat_async(self, messages:List[Dict[str,str]], messages_logger:MessagesLogger=None) -> Dict[str,str]:
         """
         Async version of chat method. Streaming is not supported.
         """
@@ -501,7 +1078,23 @@ class OpenAIVLMEngine(VLMEngine):
            warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)

        res = response.choices[0].message.content
-
+        # Postprocess response
+        res_dict = self.config.postprocess_response(res)
+        # Write to messages log
+        if messages_logger:
+            # replace images content with a placeholder "[image]" to save space
+            for messages in processed_messages:
+                if "content" in messages and isinstance(messages["content"], list):
+                    for content in messages["content"]:
+                        if isinstance(content, dict) and content.get("type") == "image_url":
+                            content["image_url"]["url"] = "[image]"
+
+            processed_messages.append({"role": "assistant",
+                                       "content": res_dict.get("response", ""),
+                                       "reasoning": res_dict.get("reasoning", "")})
+            messages_logger.log_messages(processed_messages)
+
+        return res_dict

     def get_ocr_messages(self, system_prompt:str, user_prompt:str, image:Image.Image, format:str='png', detail:str="high") -> List[Dict[str,str]]:
         """
{vlm4ocr-0.3.0.dist-info → vlm4ocr-0.3.1.dist-info}/RECORD
CHANGED
@@ -1,4 +1,4 @@
-vlm4ocr/__init__.py,sha256=
+vlm4ocr/__init__.py,sha256=NpJ-jquqaXo-uHPcMOYUtqToLWLxixftPQn7epD2XbY,506
 vlm4ocr/assets/default_prompt_templates/ocr_HTML_system_prompt.txt,sha256=igPOntiLDZXTB71-QrTmMJveb6XC1TgArg1serPc9V8,547
 vlm4ocr/assets/default_prompt_templates/ocr_HTML_user_prompt.txt,sha256=cVn538JojZfCtIhfrcOPWt0dO7dtDqgB9xdS_5VvAqo,41
 vlm4ocr/assets/default_prompt_templates/ocr_JSON_system_prompt.txt,sha256=v-fUw53gkngc_dz9TMH2abALDsAEZfe-zJ2u3-SO4ck,417
@@ -6,12 +6,12 @@ vlm4ocr/assets/default_prompt_templates/ocr_markdown_system_prompt.txt,sha256=pI
 vlm4ocr/assets/default_prompt_templates/ocr_markdown_user_prompt.txt,sha256=61EJv8POsQGIIUVwCjDU73lMXJE7F3qhPIYl6zSbl1Q,45
 vlm4ocr/assets/default_prompt_templates/ocr_text_system_prompt.txt,sha256=WbLSOerqFjlYGaGWJ-w2enhky1WhnPl011s0fgRPgnQ,398
 vlm4ocr/assets/default_prompt_templates/ocr_text_user_prompt.txt,sha256=ftgNAIPy_UlrcY6m7-IkH2ApHkCzRnymra1w2wg60Ks,47
-vlm4ocr/cli.py,sha256=
-vlm4ocr/data_types.py,sha256=
-vlm4ocr/ocr_engines.py,sha256=
+vlm4ocr/cli.py,sha256=mq5fbJQvgUm89Vd9v2SIW9ARsGex-8V46-r3-evjYrs,19966
+vlm4ocr/data_types.py,sha256=BOcq5KsZFJ_-Fxb9A4IJfOd0x5u-1tUQkYbWAJayuPM,4416
+vlm4ocr/ocr_engines.py,sha256=up7p9xGIeBdwQgqChlr7lsTMWTVFtSWzwlFZp2wKAxk,25431
 vlm4ocr/utils.py,sha256=nQhUskOze99wCVMKmvsen0dhq-9NdN4EPC_bdYfkjgA,13611
-vlm4ocr/vlm_engines.py,sha256=
-vlm4ocr-0.3.
-vlm4ocr-0.3.
-vlm4ocr-0.3.
-vlm4ocr-0.3.
+vlm4ocr/vlm_engines.py,sha256=rfb4P1fhpY6ClC27FMhYCWOaIjCipZCx3gPrNnDbF0w,50209
+vlm4ocr-0.3.1.dist-info/METADATA,sha256=_l03maaznCHetgYPATohqd_yFJenWE57sdw_JLaVmc0,710
+vlm4ocr-0.3.1.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
+vlm4ocr-0.3.1.dist-info/entry_points.txt,sha256=qzWUk_QTZ12cH4DLjjfqce89EAlOydD85dreRRZF3K4,44
+vlm4ocr-0.3.1.dist-info/RECORD,,
{vlm4ocr-0.3.0.dist-info → vlm4ocr-0.3.1.dist-info}/WHEEL
File without changes
{vlm4ocr-0.3.0.dist-info → vlm4ocr-0.3.1.dist-info}/entry_points.txt
File without changes