vlm4ocr 0.4.0__py3-none-any.whl → 0.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vlm4ocr/ocr_engines.py +4 -7
- vlm4ocr/vlm_engines.py +77 -1047
- {vlm4ocr-0.4.0.dist-info → vlm4ocr-0.4.2.dist-info}/METADATA +2 -1
- {vlm4ocr-0.4.0.dist-info → vlm4ocr-0.4.2.dist-info}/RECORD +6 -6
- {vlm4ocr-0.4.0.dist-info → vlm4ocr-0.4.2.dist-info}/WHEEL +0 -0
- {vlm4ocr-0.4.0.dist-info → vlm4ocr-0.4.2.dist-info}/entry_points.txt +0 -0
vlm4ocr/ocr_engines.py
CHANGED
|
@@ -126,9 +126,8 @@ class OCREngine:
|
|
|
126
126
|
few_shot_examples=few_shot_examples)
|
|
127
127
|
|
|
128
128
|
# Stream response
|
|
129
|
-
response_stream = self.vlm_engine.
|
|
130
|
-
messages
|
|
131
|
-
stream=True
|
|
129
|
+
response_stream = self.vlm_engine.chat_stream(
|
|
130
|
+
messages
|
|
132
131
|
)
|
|
133
132
|
for chunk in response_stream:
|
|
134
133
|
if chunk["type"] == "response":
|
|
@@ -163,9 +162,8 @@ class OCREngine:
|
|
|
163
162
|
image=image,
|
|
164
163
|
few_shot_examples=few_shot_examples)
|
|
165
164
|
# Stream response
|
|
166
|
-
response_stream = self.vlm_engine.
|
|
167
|
-
messages
|
|
168
|
-
stream=True
|
|
165
|
+
response_stream = self.vlm_engine.chat_stream(
|
|
166
|
+
messages
|
|
169
167
|
)
|
|
170
168
|
for chunk in response_stream:
|
|
171
169
|
if chunk["type"] == "response":
|
|
@@ -295,7 +293,6 @@ class OCREngine:
|
|
|
295
293
|
response = self.vlm_engine.chat(
|
|
296
294
|
messages,
|
|
297
295
|
verbose=verbose,
|
|
298
|
-
stream=False,
|
|
299
296
|
messages_logger=messages_logger
|
|
300
297
|
)
|
|
301
298
|
ocr_text = response["response"]
|
vlm4ocr/vlm_engines.py
CHANGED
|
@@ -1,337 +1,27 @@
|
|
|
1
1
|
import abc
|
|
2
|
-
import
|
|
3
|
-
from typing import Any, List, Dict, Union, Generator
|
|
4
|
-
import warnings
|
|
5
|
-
import os
|
|
6
|
-
import re
|
|
2
|
+
from typing import List, Dict
|
|
7
3
|
from PIL import Image
|
|
8
4
|
from vlm4ocr.utils import image_to_base64
|
|
9
5
|
from vlm4ocr.data_types import FewShotExample
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
a list of dict with role and content. role must be one of {"system", "user", "assistant"}
|
|
30
|
-
|
|
31
|
-
Returns:
|
|
32
|
-
-------
|
|
33
|
-
messages : List[Dict[str,str]]
|
|
34
|
-
a list of dict with role and content. role must be one of {"system", "user", "assistant"}
|
|
35
|
-
"""
|
|
36
|
-
return NotImplemented
|
|
37
|
-
|
|
38
|
-
@abc.abstractmethod
|
|
39
|
-
def postprocess_response(self, response:Union[str, Dict[str, str], Generator[str, None, None]]) -> Union[str, Generator[str, None, None]]:
|
|
40
|
-
"""
|
|
41
|
-
This method postprocesses the VLM response after it is generated.
|
|
42
|
-
|
|
43
|
-
Parameters:
|
|
44
|
-
----------
|
|
45
|
-
response : Union[str, Generator[str, None, None]]
|
|
46
|
-
the VLM response. Can be a string or a generator.
|
|
47
|
-
|
|
48
|
-
Returns:
|
|
49
|
-
-------
|
|
50
|
-
response : str
|
|
51
|
-
the postprocessed VLM response
|
|
52
|
-
"""
|
|
53
|
-
return NotImplemented
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
class BasicVLMConfig(VLMConfig):
|
|
57
|
-
def __init__(self, max_new_tokens:int=2048, temperature:float=0.0, **kwargs):
|
|
58
|
-
"""
|
|
59
|
-
The basic VLM configuration for most non-reasoning models.
|
|
60
|
-
"""
|
|
61
|
-
super().__init__(**kwargs)
|
|
62
|
-
self.max_new_tokens = max_new_tokens
|
|
63
|
-
self.temperature = temperature
|
|
64
|
-
self.params["max_new_tokens"] = self.max_new_tokens
|
|
65
|
-
self.params["temperature"] = self.temperature
|
|
66
|
-
|
|
67
|
-
def preprocess_messages(self, messages:List[Dict[str,str]]) -> List[Dict[str,str]]:
|
|
68
|
-
"""
|
|
69
|
-
This method preprocesses the input messages before passing them to the VLM.
|
|
70
|
-
|
|
71
|
-
Parameters:
|
|
72
|
-
----------
|
|
73
|
-
messages : List[Dict[str,str]]
|
|
74
|
-
a list of dict with role and content. role must be one of {"system", "user", "assistant"}
|
|
75
|
-
|
|
76
|
-
Returns:
|
|
77
|
-
-------
|
|
78
|
-
messages : List[Dict[str,str]]
|
|
79
|
-
a list of dict with role and content. role must be one of {"system", "user", "assistant"}
|
|
80
|
-
"""
|
|
81
|
-
return messages
|
|
82
|
-
|
|
83
|
-
def postprocess_response(self, response:Union[str, Dict[str, str], Generator[str, None, None]]) -> Union[Dict[str, str], Generator[Dict[str, str], None, None]]:
|
|
84
|
-
"""
|
|
85
|
-
This method postprocesses the VLM response after it is generated.
|
|
86
|
-
|
|
87
|
-
Parameters:
|
|
88
|
-
----------
|
|
89
|
-
response : Union[str, Generator[str, None, None]]
|
|
90
|
-
the VLM response. Can be a string or a generator.
|
|
91
|
-
|
|
92
|
-
Returns: Union[str, Generator[Dict[str, str], None, None]]
|
|
93
|
-
the postprocessed VLM response.
|
|
94
|
-
if input is a generator, the output will be a generator {"type": "response", "data": <content>}.
|
|
95
|
-
"""
|
|
96
|
-
if isinstance(response, str):
|
|
97
|
-
return {"response": response}
|
|
98
|
-
|
|
99
|
-
elif isinstance(response, dict):
|
|
100
|
-
if "response" in response:
|
|
101
|
-
return response
|
|
102
|
-
else:
|
|
103
|
-
warnings.warn(f"Invalid response dict keys: {response.keys()}. Returning default empty dict.", UserWarning)
|
|
104
|
-
return {"response": ""}
|
|
105
|
-
|
|
106
|
-
def _process_stream():
|
|
107
|
-
for chunk in response:
|
|
108
|
-
if isinstance(chunk, dict):
|
|
109
|
-
yield chunk
|
|
110
|
-
elif isinstance(chunk, str):
|
|
111
|
-
yield {"type": "response", "data": chunk}
|
|
112
|
-
|
|
113
|
-
return _process_stream()
|
|
114
|
-
|
|
115
|
-
class ReasoningVLMConfig(VLMConfig):
|
|
116
|
-
def __init__(self, thinking_token_start="<think>", thinking_token_end="</think>", **kwargs):
|
|
117
|
-
"""
|
|
118
|
-
The general configuration for reasoning vision models.
|
|
119
|
-
"""
|
|
120
|
-
super().__init__(**kwargs)
|
|
121
|
-
self.thinking_token_start = thinking_token_start
|
|
122
|
-
self.thinking_token_end = thinking_token_end
|
|
123
|
-
|
|
124
|
-
def preprocess_messages(self, messages:List[Dict[str,str]]) -> List[Dict[str,str]]:
|
|
125
|
-
"""
|
|
126
|
-
This method preprocesses the input messages before passing them to the VLM.
|
|
127
|
-
|
|
128
|
-
Parameters:
|
|
129
|
-
----------
|
|
130
|
-
messages : List[Dict[str,str]]
|
|
131
|
-
a list of dict with role and content. role must be one of {"system", "user", "assistant"}
|
|
132
|
-
|
|
133
|
-
Returns:
|
|
134
|
-
-------
|
|
135
|
-
messages : List[Dict[str,str]]
|
|
136
|
-
a list of dict with role and content. role must be one of {"system", "user", "assistant"}
|
|
137
|
-
"""
|
|
138
|
-
return messages.copy()
|
|
139
|
-
|
|
140
|
-
def postprocess_response(self, response:Union[str, Dict[str, str], Generator[str, None, None]]) -> Union[Dict[str,str], Generator[Dict[str,str], None, None]]:
|
|
141
|
-
"""
|
|
142
|
-
This method postprocesses the VLM response after it is generated.
|
|
143
|
-
1. If input is a string, it will extract the reasoning and response based on the thinking tokens.
|
|
144
|
-
2. If input is a dict, it should contain keys "reasoning" and "response". This is for inference engines that already parse reasoning and response.
|
|
145
|
-
3. If input is a generator,
|
|
146
|
-
a. if the chunk is a dict, it should contain keys "type" and "data". This is for inference engines that already parse reasoning and response.
|
|
147
|
-
b. if the chunk is a string, it will yield dicts with keys "type" and "data" based on the thinking tokens.
|
|
148
|
-
|
|
149
|
-
Parameters:
|
|
150
|
-
----------
|
|
151
|
-
response : Union[str, Generator[str, None, None]]
|
|
152
|
-
the VLM response. Can be a string or a generator.
|
|
153
|
-
|
|
154
|
-
Returns:
|
|
155
|
-
-------
|
|
156
|
-
response : Union[str, Generator[str, None, None]]
|
|
157
|
-
the postprocessed LLM response as a dict {"reasoning": <reasoning>, "response": <content>}
|
|
158
|
-
if input is a generator, the output will be a generator {"type": <reasoning or response>, "data": <content>}.
|
|
159
|
-
"""
|
|
160
|
-
if isinstance(response, str):
|
|
161
|
-
# get contents between thinking_token_start and thinking_token_end
|
|
162
|
-
pattern = f"{re.escape(self.thinking_token_start)}(.*?){re.escape(self.thinking_token_end)}"
|
|
163
|
-
match = re.search(pattern, response, re.DOTALL)
|
|
164
|
-
reasoning = match.group(1) if match else ""
|
|
165
|
-
# get response AFTER thinking_token_end
|
|
166
|
-
response = re.sub(f".*?{self.thinking_token_end}", "", response, flags=re.DOTALL).strip()
|
|
167
|
-
return {"reasoning": reasoning, "response": response}
|
|
168
|
-
|
|
169
|
-
elif isinstance(response, dict):
|
|
170
|
-
if "reasoning" in response and "response" in response:
|
|
171
|
-
return response
|
|
172
|
-
else:
|
|
173
|
-
warnings.warn(f"Invalid response dict keys: {response.keys()}. Returning default empty dict.", UserWarning)
|
|
174
|
-
return {"reasoning": "", "response": ""}
|
|
175
|
-
|
|
176
|
-
elif isinstance(response, Generator):
|
|
177
|
-
def _process_stream():
|
|
178
|
-
think_flag = False
|
|
179
|
-
buffer = ""
|
|
180
|
-
for chunk in response:
|
|
181
|
-
if isinstance(chunk, dict):
|
|
182
|
-
yield chunk
|
|
183
|
-
|
|
184
|
-
elif isinstance(chunk, str):
|
|
185
|
-
buffer += chunk
|
|
186
|
-
# switch between reasoning and response
|
|
187
|
-
if self.thinking_token_start in buffer:
|
|
188
|
-
think_flag = True
|
|
189
|
-
buffer = buffer.replace(self.thinking_token_start, "")
|
|
190
|
-
elif self.thinking_token_end in buffer:
|
|
191
|
-
think_flag = False
|
|
192
|
-
buffer = buffer.replace(self.thinking_token_end, "")
|
|
193
|
-
|
|
194
|
-
# if chunk is in thinking block, tag it as reasoning; else tag it as response
|
|
195
|
-
if chunk not in [self.thinking_token_start, self.thinking_token_end]:
|
|
196
|
-
if think_flag:
|
|
197
|
-
yield {"type": "reasoning", "data": chunk}
|
|
198
|
-
else:
|
|
199
|
-
yield {"type": "response", "data": chunk}
|
|
200
|
-
|
|
201
|
-
return _process_stream()
|
|
202
|
-
|
|
203
|
-
else:
|
|
204
|
-
warnings.warn(f"Invalid response type: {type(response)}. Returning default empty dict.", UserWarning)
|
|
205
|
-
return {"reasoning": "", "response": ""}
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
class OpenAIReasoningVLMConfig(ReasoningVLMConfig):
|
|
209
|
-
def __init__(self, reasoning_effort:str="low", **kwargs):
|
|
210
|
-
"""
|
|
211
|
-
The OpenAI "o" series configuration.
|
|
212
|
-
1. The reasoning effort is set to "low" by default.
|
|
213
|
-
2. The temperature parameter is not supported and will be ignored.
|
|
214
|
-
3. The system prompt is not supported and will be concatenated to the next user prompt.
|
|
215
|
-
|
|
216
|
-
Parameters:
|
|
217
|
-
----------
|
|
218
|
-
reasoning_effort : str, Optional
|
|
219
|
-
the reasoning effort. Must be one of {"low", "medium", "high"}. Default is "low".
|
|
220
|
-
"""
|
|
221
|
-
super().__init__(**kwargs)
|
|
222
|
-
if reasoning_effort not in ["low", "medium", "high"]:
|
|
223
|
-
raise ValueError("reasoning_effort must be one of {'low', 'medium', 'high'}.")
|
|
224
|
-
|
|
225
|
-
self.reasoning_effort = reasoning_effort
|
|
226
|
-
self.params["reasoning_effort"] = self.reasoning_effort
|
|
227
|
-
|
|
228
|
-
if "temperature" in self.params:
|
|
229
|
-
warnings.warn("Reasoning models do not support temperature parameter. Will be ignored.", UserWarning)
|
|
230
|
-
self.params.pop("temperature")
|
|
231
|
-
|
|
232
|
-
def preprocess_messages(self, messages:List[Dict[str,str]]) -> List[Dict[str,str]]:
|
|
233
|
-
"""
|
|
234
|
-
Concatenate system prompts to the next user prompt.
|
|
235
|
-
|
|
236
|
-
Parameters:
|
|
237
|
-
----------
|
|
238
|
-
messages : List[Dict[str,str]]
|
|
239
|
-
a list of dict with role and content. role must be one of {"system", "user", "assistant"}
|
|
240
|
-
|
|
241
|
-
Returns:
|
|
242
|
-
-------
|
|
243
|
-
messages : List[Dict[str,str]]
|
|
244
|
-
a list of dict with role and content. role must be one of {"system", "user", "assistant"}
|
|
245
|
-
"""
|
|
246
|
-
system_prompt_holder = ""
|
|
247
|
-
new_messages = []
|
|
248
|
-
for i, message in enumerate(messages):
|
|
249
|
-
# if system prompt, store it in system_prompt_holder
|
|
250
|
-
if message['role'] == 'system':
|
|
251
|
-
system_prompt_holder = message['content']
|
|
252
|
-
# if user prompt, concatenate it with system_prompt_holder
|
|
253
|
-
elif message['role'] == 'user':
|
|
254
|
-
if system_prompt_holder:
|
|
255
|
-
new_message = {'role': message['role'], 'content': f"{system_prompt_holder} {message['content']}"}
|
|
256
|
-
system_prompt_holder = ""
|
|
257
|
-
else:
|
|
258
|
-
new_message = {'role': message['role'], 'content': message['content']}
|
|
259
|
-
|
|
260
|
-
new_messages.append(new_message)
|
|
261
|
-
# if assistant/other prompt, do nothing
|
|
262
|
-
else:
|
|
263
|
-
new_message = {'role': message['role'], 'content': message['content']}
|
|
264
|
-
new_messages.append(new_message)
|
|
265
|
-
|
|
266
|
-
return new_messages
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
class MessagesLogger:
|
|
270
|
-
def __init__(self):
|
|
271
|
-
"""
|
|
272
|
-
This class is used to log the messages for InferenceEngine.chat().
|
|
273
|
-
"""
|
|
274
|
-
self.messages_log = []
|
|
275
|
-
|
|
276
|
-
def log_messages(self, messages : List[Dict[str,str]]):
|
|
277
|
-
"""
|
|
278
|
-
This method logs the messages to a list.
|
|
279
|
-
"""
|
|
280
|
-
self.messages_log.append(messages)
|
|
281
|
-
|
|
282
|
-
def get_messages_log(self) -> List[List[Dict[str,str]]]:
|
|
283
|
-
"""
|
|
284
|
-
This method returns a copy of the current messages log
|
|
285
|
-
"""
|
|
286
|
-
return self.messages_log.copy()
|
|
287
|
-
|
|
288
|
-
def clear_messages_log(self):
|
|
289
|
-
"""
|
|
290
|
-
This method clears the current messages log
|
|
291
|
-
"""
|
|
292
|
-
self.messages_log.clear()
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
class VLMEngine:
|
|
296
|
-
@abc.abstractmethod
|
|
297
|
-
def __init__(self, config:VLMConfig, **kwrs):
|
|
298
|
-
"""
|
|
299
|
-
This is an abstract class to provide interfaces for VLM inference engines.
|
|
300
|
-
Children classes that inherts this class can be used in extrators. Must implement chat() method.
|
|
301
|
-
|
|
302
|
-
Parameters:
|
|
303
|
-
----------
|
|
304
|
-
config : VLMConfig
|
|
305
|
-
the VLM configuration. Must be a child class of VLMConfig.
|
|
306
|
-
"""
|
|
307
|
-
return NotImplemented
|
|
308
|
-
|
|
309
|
-
@abc.abstractmethod
|
|
310
|
-
def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False,
|
|
311
|
-
messages_logger:MessagesLogger=None) -> Union[Dict[str, str], Generator[Dict[str, str], None, None]]:
|
|
312
|
-
"""
|
|
313
|
-
This method inputs chat messages and outputs VLM generated text.
|
|
314
|
-
|
|
315
|
-
Parameters:
|
|
316
|
-
----------
|
|
317
|
-
messages : List[Dict[str,str]]
|
|
318
|
-
a list of dict with role and content. role must be one of {"system", "user", "assistant"}
|
|
319
|
-
verbose : bool, Optional
|
|
320
|
-
if True, VLM generated text will be printed in terminal in real-time.
|
|
321
|
-
stream : bool, Optional
|
|
322
|
-
if True, returns a generator that yields the output in real-time.
|
|
323
|
-
Messages_logger : MessagesLogger, Optional
|
|
324
|
-
the message logger that logs the chat messages.
|
|
325
|
-
"""
|
|
326
|
-
return NotImplemented
|
|
327
|
-
|
|
328
|
-
@abc.abstractmethod
|
|
329
|
-
def chat_async(self, messages:List[Dict[str,str]], messages_logger:MessagesLogger=None) -> Dict[str, str]:
|
|
330
|
-
"""
|
|
331
|
-
The async version of chat method. Streaming is not supported.
|
|
332
|
-
"""
|
|
333
|
-
return NotImplemented
|
|
334
|
-
|
|
6
|
+
from llm_inference_engine.llm_configs import (
|
|
7
|
+
LLMConfig as VLMConfig,
|
|
8
|
+
BasicLLMConfig as BasicVLMConfig,
|
|
9
|
+
ReasoningLLMConfig as ReasoningVLMConfig,
|
|
10
|
+
OpenAIReasoningLLMConfig as OpenAIReasoningVLMConfig
|
|
11
|
+
)
|
|
12
|
+
from llm_inference_engine.utils import MessagesLogger
|
|
13
|
+
from llm_inference_engine.engines import (
|
|
14
|
+
InferenceEngine,
|
|
15
|
+
OllamaInferenceEngine,
|
|
16
|
+
OpenAICompatibleInferenceEngine,
|
|
17
|
+
VLLMInferenceEngine,
|
|
18
|
+
OpenRouterInferenceEngine,
|
|
19
|
+
OpenAIInferenceEngine,
|
|
20
|
+
AzureOpenAIInferenceEngine,
|
|
21
|
+
)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class VLMEngine(InferenceEngine):
|
|
335
25
|
@abc.abstractmethod
|
|
336
26
|
def get_ocr_messages(self, system_prompt:str, user_prompt:str, image:Image.Image, few_shot_examples:List[FewShotExample]=None) -> List[Dict[str,str]]:
|
|
337
27
|
"""
|
|
@@ -349,217 +39,9 @@ class VLMEngine:
|
|
|
349
39
|
list of few-shot examples.
|
|
350
40
|
"""
|
|
351
41
|
return NotImplemented
|
|
352
|
-
|
|
353
|
-
def _format_config(self) -> Dict[str, Any]:
|
|
354
|
-
"""
|
|
355
|
-
This method format the VLM configuration with the correct key for the inference engine.
|
|
356
|
-
|
|
357
|
-
Return : Dict[str, Any]
|
|
358
|
-
the config parameters.
|
|
359
|
-
"""
|
|
360
|
-
return NotImplemented
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
class OllamaVLMEngine(VLMEngine):
|
|
364
|
-
def __init__(self, model_name:str, num_ctx:int=8192, keep_alive:int=300, config:VLMConfig=None, **kwrs):
|
|
365
|
-
"""
|
|
366
|
-
The Ollama inference engine.
|
|
367
|
-
|
|
368
|
-
Parameters:
|
|
369
|
-
----------
|
|
370
|
-
model_name : str
|
|
371
|
-
the model name exactly as shown in >> ollama ls
|
|
372
|
-
num_ctx : int, Optional
|
|
373
|
-
context length that LLM will evaluate.
|
|
374
|
-
keep_alive : int, Optional
|
|
375
|
-
seconds to hold the LLM after the last API call.
|
|
376
|
-
config : LLMConfig
|
|
377
|
-
the LLM configuration.
|
|
378
|
-
"""
|
|
379
|
-
if importlib.util.find_spec("ollama") is None:
|
|
380
|
-
raise ImportError("ollama-python not found. Please install ollama-python (```pip install ollama```).")
|
|
381
|
-
|
|
382
|
-
from ollama import Client, AsyncClient
|
|
383
|
-
self.client = Client(**kwrs)
|
|
384
|
-
self.async_client = AsyncClient(**kwrs)
|
|
385
|
-
self.model_name = model_name
|
|
386
|
-
self.num_ctx = num_ctx
|
|
387
|
-
self.keep_alive = keep_alive
|
|
388
|
-
self.config = config if config else BasicVLMConfig()
|
|
389
|
-
self.formatted_params = self._format_config()
|
|
390
|
-
|
|
391
|
-
def _format_config(self) -> Dict[str, Any]:
|
|
392
|
-
"""
|
|
393
|
-
This method format the LLM configuration with the correct key for the inference engine.
|
|
394
|
-
"""
|
|
395
|
-
formatted_params = self.config.params.copy()
|
|
396
|
-
if "max_new_tokens" in formatted_params:
|
|
397
|
-
formatted_params["num_predict"] = formatted_params["max_new_tokens"]
|
|
398
|
-
formatted_params.pop("max_new_tokens")
|
|
399
|
-
|
|
400
|
-
return formatted_params
|
|
401
|
-
|
|
402
|
-
def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False,
|
|
403
|
-
messages_logger:MessagesLogger=None) -> Union[Dict[str,str], Generator[Dict[str, str], None, None]]:
|
|
404
|
-
"""
|
|
405
|
-
This method inputs chat messages and outputs VLM generated text.
|
|
406
|
-
|
|
407
|
-
Parameters:
|
|
408
|
-
----------
|
|
409
|
-
messages : List[Dict[str,str]]
|
|
410
|
-
a list of dict with role and content. role must be one of {"system", "user", "assistant"}
|
|
411
|
-
verbose : bool, Optional
|
|
412
|
-
if True, VLM generated text will be printed in terminal in real-time.
|
|
413
|
-
stream : bool, Optional
|
|
414
|
-
if True, returns a generator that yields the output in real-time.
|
|
415
|
-
Messages_logger : MessagesLogger, Optional
|
|
416
|
-
the message logger that logs the chat messages.
|
|
417
|
-
|
|
418
|
-
Returns:
|
|
419
|
-
-------
|
|
420
|
-
response : Union[Dict[str,str], Generator[Dict[str, str], None, None]]
|
|
421
|
-
a dict {"reasoning": <reasoning>, "response": <response>} or Generator {"type": <reasoning or response>, "data": <content>}
|
|
422
|
-
"""
|
|
423
|
-
processed_messages = self.config.preprocess_messages(messages)
|
|
424
42
|
|
|
425
|
-
options={'num_ctx': self.num_ctx, **self.formatted_params}
|
|
426
|
-
if stream:
|
|
427
|
-
def _stream_generator():
|
|
428
|
-
response_stream = self.client.chat(
|
|
429
|
-
model=self.model_name,
|
|
430
|
-
messages=processed_messages,
|
|
431
|
-
options=options,
|
|
432
|
-
stream=True,
|
|
433
|
-
keep_alive=self.keep_alive
|
|
434
|
-
)
|
|
435
|
-
res = {"reasoning": "", "response": ""}
|
|
436
|
-
for chunk in response_stream:
|
|
437
|
-
if hasattr(chunk.message, 'thinking') and chunk.message.thinking:
|
|
438
|
-
content_chunk = getattr(getattr(chunk, 'message', {}), 'thinking', '')
|
|
439
|
-
res["reasoning"] += content_chunk
|
|
440
|
-
yield {"type": "reasoning", "data": content_chunk}
|
|
441
|
-
else:
|
|
442
|
-
content_chunk = getattr(getattr(chunk, 'message', {}), 'content', '')
|
|
443
|
-
res["response"] += content_chunk
|
|
444
|
-
yield {"type": "response", "data": content_chunk}
|
|
445
43
|
|
|
446
|
-
|
|
447
|
-
warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
|
|
448
|
-
|
|
449
|
-
# Postprocess response
|
|
450
|
-
res_dict = self.config.postprocess_response(res)
|
|
451
|
-
# Write to messages log
|
|
452
|
-
if messages_logger:
|
|
453
|
-
# replace images content with a placeholder "[image]" to save space
|
|
454
|
-
for messages in processed_messages:
|
|
455
|
-
if "images" in messages:
|
|
456
|
-
messages["images"] = ["[image]" for _ in messages["images"]]
|
|
457
|
-
|
|
458
|
-
processed_messages.append({"role": "assistant",
|
|
459
|
-
"content": res_dict.get("response", ""),
|
|
460
|
-
"reasoning": res_dict.get("reasoning", "")})
|
|
461
|
-
messages_logger.log_messages(processed_messages)
|
|
462
|
-
|
|
463
|
-
return self.config.postprocess_response(_stream_generator())
|
|
464
|
-
|
|
465
|
-
elif verbose:
|
|
466
|
-
response = self.client.chat(
|
|
467
|
-
model=self.model_name,
|
|
468
|
-
messages=processed_messages,
|
|
469
|
-
options=options,
|
|
470
|
-
stream=True,
|
|
471
|
-
keep_alive=self.keep_alive
|
|
472
|
-
)
|
|
473
|
-
|
|
474
|
-
res = {"reasoning": "", "response": ""}
|
|
475
|
-
phase = ""
|
|
476
|
-
for chunk in response:
|
|
477
|
-
if hasattr(chunk.message, 'thinking') and chunk.message.thinking:
|
|
478
|
-
if phase != "reasoning":
|
|
479
|
-
print("\n--- Reasoning ---")
|
|
480
|
-
phase = "reasoning"
|
|
481
|
-
|
|
482
|
-
content_chunk = getattr(getattr(chunk, 'message', {}), 'thinking', '')
|
|
483
|
-
res["reasoning"] += content_chunk
|
|
484
|
-
else:
|
|
485
|
-
if phase != "response":
|
|
486
|
-
print("\n--- Response ---")
|
|
487
|
-
phase = "response"
|
|
488
|
-
content_chunk = getattr(getattr(chunk, 'message', {}), 'content', '')
|
|
489
|
-
res["response"] += content_chunk
|
|
490
|
-
|
|
491
|
-
print(content_chunk, end='', flush=True)
|
|
492
|
-
|
|
493
|
-
if chunk.done_reason == "length":
|
|
494
|
-
warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
|
|
495
|
-
print('\n')
|
|
496
|
-
|
|
497
|
-
else:
|
|
498
|
-
response = self.client.chat(
|
|
499
|
-
model=self.model_name,
|
|
500
|
-
messages=processed_messages,
|
|
501
|
-
options=options,
|
|
502
|
-
stream=False,
|
|
503
|
-
keep_alive=self.keep_alive
|
|
504
|
-
)
|
|
505
|
-
res = {"reasoning": getattr(getattr(response, 'message', {}), 'thinking', ''),
|
|
506
|
-
"response": getattr(getattr(response, 'message', {}), 'content', '')}
|
|
507
|
-
|
|
508
|
-
if response.done_reason == "length":
|
|
509
|
-
warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
|
|
510
|
-
|
|
511
|
-
# Postprocess response
|
|
512
|
-
res_dict = self.config.postprocess_response(res)
|
|
513
|
-
# Write to messages log
|
|
514
|
-
if messages_logger:
|
|
515
|
-
# replace images content with a placeholder "[image]" to save space
|
|
516
|
-
for messages in processed_messages:
|
|
517
|
-
if "images" in messages:
|
|
518
|
-
messages["images"] = ["[image]" for _ in messages["images"]]
|
|
519
|
-
|
|
520
|
-
processed_messages.append({"role": "assistant",
|
|
521
|
-
"content": res_dict.get("response", ""),
|
|
522
|
-
"reasoning": res_dict.get("reasoning", "")})
|
|
523
|
-
messages_logger.log_messages(processed_messages)
|
|
524
|
-
|
|
525
|
-
return res_dict
|
|
526
|
-
|
|
527
|
-
|
|
528
|
-
async def chat_async(self, messages:List[Dict[str,str]], messages_logger:MessagesLogger=None) -> Dict[str,str]:
|
|
529
|
-
"""
|
|
530
|
-
Async version of chat method. Streaming is not supported.
|
|
531
|
-
"""
|
|
532
|
-
processed_messages = self.config.preprocess_messages(messages)
|
|
533
|
-
|
|
534
|
-
response = await self.async_client.chat(
|
|
535
|
-
model=self.model_name,
|
|
536
|
-
messages=processed_messages,
|
|
537
|
-
options={'num_ctx': self.num_ctx, **self.formatted_params},
|
|
538
|
-
stream=False,
|
|
539
|
-
keep_alive=self.keep_alive
|
|
540
|
-
)
|
|
541
|
-
|
|
542
|
-
res = {"reasoning": getattr(getattr(response, 'message', {}), 'thinking', ''),
|
|
543
|
-
"response": getattr(getattr(response, 'message', {}), 'content', '')}
|
|
544
|
-
|
|
545
|
-
if response.done_reason == "length":
|
|
546
|
-
warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
|
|
547
|
-
# Postprocess response
|
|
548
|
-
res_dict = self.config.postprocess_response(res)
|
|
549
|
-
# Write to messages log
|
|
550
|
-
if messages_logger:
|
|
551
|
-
# replace images content with a placeholder "[image]" to save space
|
|
552
|
-
for messages in processed_messages:
|
|
553
|
-
if "images" in messages:
|
|
554
|
-
messages["images"] = ["[image]" for _ in messages["images"]]
|
|
555
|
-
|
|
556
|
-
processed_messages.append({"role": "assistant",
|
|
557
|
-
"content": res_dict.get("response", ""),
|
|
558
|
-
"reasoning": res_dict.get("reasoning", "")})
|
|
559
|
-
messages_logger.log_messages(processed_messages)
|
|
560
|
-
|
|
561
|
-
return res_dict
|
|
562
|
-
|
|
44
|
+
class OllamaVLMEngine(OllamaInferenceEngine, VLMEngine):
|
|
563
45
|
def get_ocr_messages(self, system_prompt:str, user_prompt:str, image:Image.Image, few_shot_examples:List[FewShotExample]=None) -> List[Dict[str,str]]:
|
|
564
46
|
"""
|
|
565
47
|
This method inputs an image and returns the correesponding chat messages for the inference engine.
|
|
@@ -600,217 +82,7 @@ class OllamaVLMEngine(VLMEngine):
|
|
|
600
82
|
return output_messages
|
|
601
83
|
|
|
602
84
|
|
|
603
|
-
class OpenAICompatibleVLMEngine(VLMEngine):
|
|
604
|
-
def __init__(self, model:str, api_key:str, base_url:str, config:VLMConfig=None, **kwrs):
|
|
605
|
-
"""
|
|
606
|
-
General OpenAI-compatible server inference engine.
|
|
607
|
-
https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html
|
|
608
|
-
|
|
609
|
-
For parameters and documentation, refer to https://platform.openai.com/docs/api-reference/introduction
|
|
610
|
-
|
|
611
|
-
Parameters:
|
|
612
|
-
----------
|
|
613
|
-
model_name : str
|
|
614
|
-
model name as shown in the vLLM server
|
|
615
|
-
api_key : str
|
|
616
|
-
the API key for the vLLM server.
|
|
617
|
-
base_url : str
|
|
618
|
-
the base url for the vLLM server.
|
|
619
|
-
config : LLMConfig
|
|
620
|
-
the LLM configuration.
|
|
621
|
-
"""
|
|
622
|
-
if importlib.util.find_spec("openai") is None:
|
|
623
|
-
raise ImportError("OpenAI Python API library not found. Please install OpanAI (```pip install openai```).")
|
|
624
|
-
|
|
625
|
-
from openai import OpenAI, AsyncOpenAI
|
|
626
|
-
from openai.types.chat import ChatCompletionChunk
|
|
627
|
-
self.ChatCompletionChunk = ChatCompletionChunk
|
|
628
|
-
super().__init__(config)
|
|
629
|
-
self.client = OpenAI(api_key=api_key, base_url=base_url, **kwrs)
|
|
630
|
-
self.async_client = AsyncOpenAI(api_key=api_key, base_url=base_url, **kwrs)
|
|
631
|
-
self.model = model
|
|
632
|
-
self.config = config if config else BasicVLMConfig()
|
|
633
|
-
self.formatted_params = self._format_config()
|
|
634
|
-
|
|
635
|
-
def _format_config(self) -> Dict[str, Any]:
|
|
636
|
-
"""
|
|
637
|
-
This method format the VLM configuration with the correct key for the inference engine.
|
|
638
|
-
"""
|
|
639
|
-
formatted_params = self.config.params.copy()
|
|
640
|
-
if "max_new_tokens" in formatted_params:
|
|
641
|
-
formatted_params["max_completion_tokens"] = formatted_params["max_new_tokens"]
|
|
642
|
-
formatted_params.pop("max_new_tokens")
|
|
643
|
-
|
|
644
|
-
return formatted_params
|
|
645
|
-
|
|
646
|
-
|
|
647
|
-
def _format_response(self, response: Any) -> Dict[str, str]:
|
|
648
|
-
"""
|
|
649
|
-
This method format the response from OpenAI API to a dict with keys "type" and "data".
|
|
650
|
-
|
|
651
|
-
Parameters:
|
|
652
|
-
----------
|
|
653
|
-
response : Any
|
|
654
|
-
the response from OpenAI-compatible API. Could be a dict, generator, or object.
|
|
655
|
-
"""
|
|
656
|
-
if isinstance(response, self.ChatCompletionChunk):
|
|
657
|
-
chunk_text = getattr(response.choices[0].delta, "content", "")
|
|
658
|
-
if chunk_text is None:
|
|
659
|
-
chunk_text = ""
|
|
660
|
-
return {"type": "response", "data": chunk_text}
|
|
661
|
-
|
|
662
|
-
return {"response": getattr(response.choices[0].message, "content", "")}
|
|
663
|
-
|
|
664
|
-
def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False,
         messages_logger:MessagesLogger=None) -> Union[Dict[str, str], Generator[Dict[str, str], None, None]]:
    """
    This method inputs chat messages and outputs LLM generated text.

    Parameters:
    ----------
    messages : List[Dict[str,str]]
        a list of dict with role and content. role must be one of {"system", "user", "assistant"}
    verbose : bool, Optional
        if True, VLM generated text will be printed in terminal in real-time.
    stream : bool, Optional
        if True, returns a generator that yields the output in real-time. Takes precedence over verbose.
    messages_logger : MessagesLogger, Optional
        the message logger that logs the chat messages.

    Returns:
    -------
    response : Union[Dict[str,str], Generator[Dict[str, str], None, None]]
        a dict {"reasoning": <reasoning>, "response": <response>} or Generator {"type": <reasoning or response>, "data": <content>}
    """
    processed_messages = self.config.preprocess_messages(messages)

    def _request(streaming):
        # Single place that issues the API call for all three modes.
        return self.client.chat.completions.create(
            model=self.model,
            messages=processed_messages,
            stream=streaming,
            **self.formatted_params
        )

    def _warn_if_truncated(choice):
        if choice.finish_reason == "length":
            warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)

    def _redact(part):
        # replace images content with a placeholder "[image]" to save space
        if isinstance(part, dict) and part.get("type") == "image_url":
            return {**part, "image_url": {**part["image_url"], "url": "[image]"}}
        return part

    def _log_exchange(res_dict):
        # Log a redacted copy: the messages handed to this method must keep
        # their image data and must not grow an extra assistant turn.
        if not messages_logger:
            return
        transcript = []
        for message in processed_messages:
            if "content" in message and isinstance(message["content"], list):
                message = {**message, "content": [_redact(part) for part in message["content"]]}
            transcript.append(message)
        transcript.append({"role": "assistant",
                           "content": res_dict.get("response", ""),
                           "reasoning": res_dict.get("reasoning", "")})
        messages_logger.log_messages(transcript)

    if stream:
        def _stream_generator():
            streamed = {"reasoning": "", "response": ""}
            for chunk in _request(True):
                if not chunk.choices:
                    continue
                chunk_dict = self._format_response(chunk)
                yield chunk_dict

                # Keep reasoning and response text apart, as the verbose path does,
                # so reasoning chunks are not logged as part of the answer.
                kind = chunk_dict["type"]
                streamed[kind] = streamed.get(kind, "") + chunk_dict["data"]
                _warn_if_truncated(chunk.choices[0])

            # Without reasoning chunks, hand over the plain text exactly as before.
            res = streamed if streamed["reasoning"] else streamed["response"]
            _log_exchange(self.config.postprocess_response(res))

        return self.config.postprocess_response(_stream_generator())

    if verbose:
        res = {"reasoning": "", "response": ""}
        phase = ""
        for chunk in _request(True):
            if not chunk.choices:
                continue
            chunk_dict = self._format_response(chunk)
            chunk_text = chunk_dict["data"]
            res[chunk_dict["type"]] += chunk_text
            # Print a header whenever the output switches between reasoning and response.
            if phase != chunk_dict["type"] and chunk_text != "":
                print(f"\n--- {chunk_dict['type'].capitalize()} ---")
                phase = chunk_dict["type"]

            print(chunk_text, end="", flush=True)
            _warn_if_truncated(chunk.choices[0])

        print('\n')

    else:
        response = _request(False)
        res = self._format_response(response)
        _warn_if_truncated(response.choices[0])

    # Postprocess response
    res_dict = self.config.postprocess_response(res)
    # Write to messages log
    _log_exchange(res_dict)

    return res_dict
|
|
776
|
-
|
|
777
|
-
|
|
778
|
-
async def chat_async(self, messages:List[Dict[str,str]], messages_logger:MessagesLogger=None) -> Dict[str,str]:
    """
    Async version of chat method. Streaming is not supported.

    Parameters:
    ----------
    messages : List[Dict[str,str]]
        a list of dict with role and content.
    messages_logger : MessagesLogger, Optional
        the message logger that logs the chat messages.

    Returns:
    -------
    response : Dict[str,str]
        the value returned by self.config.postprocess_response for the formatted response.
    """
    processed_messages = self.config.preprocess_messages(messages)

    response = await self.async_client.chat.completions.create(
        model=self.model,
        messages=processed_messages,
        stream=False,
        **self.formatted_params
    )

    if response.choices[0].finish_reason == "length":
        warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)

    # Postprocess response
    res_dict = self.config.postprocess_response(self._format_response(response))

    # Write to messages log
    if messages_logger:
        # Build a redacted copy instead of editing the messages in place, so the
        # caller's messages keep their image data and gain no extra turn.
        transcript = []
        for message in processed_messages:
            if "content" in message and isinstance(message["content"], list):
                parts = []
                for part in message["content"]:
                    # replace images content with a placeholder "[image]" to save space
                    if isinstance(part, dict) and part.get("type") == "image_url":
                        part = {**part, "image_url": {**part["image_url"], "url": "[image]"}}
                    parts.append(part)
                message = {**message, "content": parts}
            transcript.append(message)

        transcript.append({"role": "assistant",
                           "content": res_dict.get("response", ""),
                           "reasoning": res_dict.get("reasoning", "")})
        messages_logger.log_messages(transcript)

    return res_dict
|
|
813
|
-
|
|
85
|
+
class OpenAICompatibleVLMEngine(OpenAICompatibleInferenceEngine, VLMEngine):
|
|
814
86
|
def get_ocr_messages(self, system_prompt:str, user_prompt:str, image:Image.Image, format:str='png',
|
|
815
87
|
detail:str="high", few_shot_examples:List[FewShotExample]=None) -> List[Dict[str,str]]:
|
|
816
88
|
"""
|
|
@@ -879,274 +151,44 @@ class OpenAICompatibleVLMEngine(VLMEngine):
|
|
|
879
151
|
return output_messages
|
|
880
152
|
|
|
881
153
|
|
|
882
|
-
class VLLMVLMEngine(OpenAICompatibleVLMEngine):
|
|
883
|
-
|
|
884
|
-
|
|
885
|
-
|
|
886
|
-
|
|
887
|
-
|
|
888
|
-
|
|
889
|
-
|
|
890
|
-
|
|
891
|
-
|
|
892
|
-
|
|
893
|
-
|
|
894
|
-
|
|
895
|
-
|
|
896
|
-
|
|
897
|
-
|
|
898
|
-
|
|
899
|
-
|
|
900
|
-
|
|
901
|
-
|
|
902
|
-
|
|
903
|
-
|
|
904
|
-
|
|
905
|
-
|
|
906
|
-
|
|
907
|
-
|
|
908
|
-
|
|
909
|
-
|
|
910
|
-
|
|
911
|
-
|
|
912
|
-
|
|
913
|
-
|
|
914
|
-
|
|
915
|
-
|
|
916
|
-
|
|
917
|
-
|
|
918
|
-
|
|
919
|
-
|
|
920
|
-
chunk_text = getattr(response.choices[0].delta, "content", "")
|
|
921
|
-
if chunk_text is None:
|
|
922
|
-
chunk_text = ""
|
|
923
|
-
return {"type": "response", "data": chunk_text}
|
|
924
|
-
|
|
925
|
-
return {"reasoning": getattr(response.choices[0].message, "reasoning_content", ""),
|
|
926
|
-
"response": getattr(response.choices[0].message, "content", "")}
|
|
927
|
-
|
|
928
|
-
|
|
929
|
-
class OpenRouterVLMEngine(OpenAICompatibleVLMEngine):
    def __init__(self, model:str, api_key:str=None, base_url:str="https://openrouter.ai/api/v1", config:VLMConfig=None, **kwrs):
        """
        OpenRouter OpenAI-compatible server inference engine.

        Parameters:
        ----------
        model : str
            model name to request from the OpenRouter endpoint.
        api_key : str, Optional
            the OpenRouter API key. If None, will use the key in os.environ['OPENROUTER_API_KEY'].
        base_url : str, Optional
            the base url of the OpenRouter API.
        config : VLMConfig
            the VLM configuration.
        """
        self.api_key = api_key if api_key is not None else os.getenv("OPENROUTER_API_KEY")
        super().__init__(model, self.api_key, base_url, config, **kwrs)

    def _format_response(self, response: Any) -> Dict[str, str]:
        """
        Convert an OpenRouter response into the dict format used by this engine.

        Parameters:
        ----------
        response : Any
            either a streaming chunk (an instance of self.ChatCompletionChunk) or a
            complete, non-streaming chat completion object.

        Returns:
        -------
        formatted : Dict[str, str]
            {"type": "reasoning" or "response", "data": <text>} for a streaming chunk, or
            {"reasoning": <text>, "response": <text>} for a complete response.
            <text> is never None.
        """
        choice = response.choices[0]

        if isinstance(response, self.ChatCompletionChunk):
            # A chunk counts as reasoning whenever the delta carries a non-None
            # "reasoning" attribute; otherwise it is treated as response text.
            reasoning = getattr(choice.delta, "reasoning", None)
            if reasoning is not None:
                return {"type": "reasoning", "data": reasoning}
            return {"type": "response", "data": getattr(choice.delta, "content", None) or ""}

        # getattr's default only covers a missing attribute; normalise an
        # attribute that is present but None as well.
        return {"reasoning": getattr(choice.message, "reasoning", None) or "",
                "response": getattr(choice.message, "content", None) or ""}
|
|
973
|
-
|
|
974
|
-
|
|
975
|
-
class OpenAIVLMEngine(VLMEngine):
|
|
976
|
-
def __init__(self, model:str, config:VLMConfig=None, **kwrs):
    """
    The OpenAI API inference engine. Supports OpenAI models and OpenAI compatible servers:
    - vLLM OpenAI compatible server (https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html)

    For parameters and documentation, refer to https://platform.openai.com/docs/api-reference/introduction

    Parameters:
    ----------
    model : str
        model name as described in https://platform.openai.com/docs/models
    config : VLMConfig, Optional
        the VLM configuration. Must be a child class of VLMConfig.
        A BasicVLMConfig is created when omitted.
    **kwrs
        keyword arguments passed unchanged to both openai.OpenAI and openai.AsyncOpenAI.

    Raises:
    ------
    ImportError
        if the openai package is not installed.
    """
    if importlib.util.find_spec("openai") is None:
        raise ImportError("OpenAI Python API library not found. Please install OpenAI (```pip install openai```).")

    # Imported here so the module itself can be imported without openai installed.
    from openai import OpenAI, AsyncOpenAI
    self.client = OpenAI(**kwrs)
    self.async_client = AsyncOpenAI(**kwrs)
    self.model = model
    self.config = config if config else BasicVLMConfig()
    self.formatted_params = self._format_config()
|
|
999
|
-
|
|
1000
|
-
def _format_config(self) -> Dict[str, Any]:
    """
    Build the request parameters for the inference engine from the configuration.

    Works on a copy of self.config.params. A "max_new_tokens" entry, when present,
    is renamed to "max_completion_tokens".
    """
    params = self.config.params.copy()
    if "max_new_tokens" in params:
        # Removing first and assigning second leaves the same final key order.
        params["max_completion_tokens"] = params.pop("max_new_tokens")
    return params
|
|
1010
|
-
|
|
1011
|
-
def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False, messages_logger:MessagesLogger=None) -> Union[Dict[str, str], Generator[Dict[str, str], None, None]]:
    """
    This method inputs chat messages and outputs LLM generated text.

    Parameters:
    ----------
    messages : List[Dict[str,str]]
        a list of dict with role and content. role must be one of {"system", "user", "assistant"}
    verbose : bool, Optional
        if True, VLM generated text will be printed in terminal in real-time.
    stream : bool, Optional
        if True, returns a generator that yields the output in real-time. Takes precedence over verbose.
    messages_logger : MessagesLogger, Optional
        the message logger that logs the chat messages.

    Returns:
    -------
    response : Union[Dict[str,str], Generator[Dict[str, str], None, None]]
        a dict {"reasoning": <reasoning>, "response": <response>} or Generator {"type": <reasoning or response>, "data": <content>}
    """
    processed_messages = self.config.preprocess_messages(messages)

    def _request(streaming):
        # Single place that issues the API call for all three modes.
        return self.client.chat.completions.create(
            model=self.model,
            messages=processed_messages,
            stream=streaming,
            **self.formatted_params
        )

    def _warn_if_truncated(choice):
        if choice.finish_reason == "length":
            warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)

    def _redact(part):
        # replace images content with a placeholder "[image]" to save space
        if isinstance(part, dict) and part.get("type") == "image_url":
            return {**part, "image_url": {**part["image_url"], "url": "[image]"}}
        return part

    def _log_exchange(res_dict):
        # Log a redacted copy: the messages handed to this method must keep
        # their image data and must not grow an extra assistant turn.
        if not messages_logger:
            return
        transcript = []
        for message in processed_messages:
            if "content" in message and isinstance(message["content"], list):
                message = {**message, "content": [_redact(part) for part in message["content"]]}
            transcript.append(message)
        transcript.append({"role": "assistant",
                           "content": res_dict.get("response", ""),
                           "reasoning": res_dict.get("reasoning", "")})
        messages_logger.log_messages(transcript)

    if stream:
        def _stream_generator():
            pieces = []
            for chunk in _request(True):
                if not chunk.choices:
                    continue
                chunk_text = chunk.choices[0].delta.content
                if chunk_text is not None:
                    pieces.append(chunk_text)
                    yield chunk_text
                _warn_if_truncated(chunk.choices[0])

            _log_exchange(self.config.postprocess_response("".join(pieces)))

        return self.config.postprocess_response(_stream_generator())

    if verbose:
        pieces = []
        for chunk in _request(True):
            if not chunk.choices:
                continue
            chunk_text = chunk.choices[0].delta.content
            if chunk_text is not None:
                pieces.append(chunk_text)
                print(chunk_text, end="", flush=True)
            _warn_if_truncated(chunk.choices[0])

        print('\n')
        res = "".join(pieces)

    else:
        response = _request(False)
        choice = response.choices[0]
        # Warn on truncated output here too, like the streaming and async paths.
        _warn_if_truncated(choice)
        # message.content can be None; pass on an empty string instead.
        res = choice.message.content or ""

    # Postprocess response
    res_dict = self.config.postprocess_response(res)
    # Write to messages log
    _log_exchange(res_dict)

    return res_dict
|
|
1113
|
-
|
|
1114
|
-
|
|
1115
|
-
async def chat_async(self, messages:List[Dict[str,str]], messages_logger:MessagesLogger=None) -> Dict[str,str]:
    """
    Async version of chat method. Streaming is not supported.

    Parameters:
    ----------
    messages : List[Dict[str,str]]
        a list of dict with role and content.
    messages_logger : MessagesLogger, Optional
        the message logger that logs the chat messages.

    Returns:
    -------
    response : Dict[str,str]
        the value returned by self.config.postprocess_response for the generated text.
    """
    processed_messages = self.config.preprocess_messages(messages)

    response = await self.async_client.chat.completions.create(
        model=self.model,
        messages=processed_messages,
        stream=False,
        **self.formatted_params
    )
    choice = response.choices[0]

    if choice.finish_reason == "length":
        warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)

    # message.content can be None; pass on an empty string instead.
    res = choice.message.content or ""
    # Postprocess response
    res_dict = self.config.postprocess_response(res)

    # Write to messages log
    if messages_logger:
        # Build a redacted copy instead of editing the messages in place, so the
        # caller's messages keep their image data and gain no extra turn.
        transcript = []
        for message in processed_messages:
            if "content" in message and isinstance(message["content"], list):
                parts = []
                for part in message["content"]:
                    # replace images content with a placeholder "[image]" to save space
                    if isinstance(part, dict) and part.get("type") == "image_url":
                        part = {**part, "image_url": {**part["image_url"], "url": "[image]"}}
                    parts.append(part)
                message = {**message, "content": parts}
            transcript.append(message)

        transcript.append({"role": "assistant",
                           "content": res_dict.get("response", ""),
                           "reasoning": res_dict.get("reasoning", "")})
        messages_logger.log_messages(transcript)

    return res_dict
|
|
1149
|
-
|
|
154
|
+
class VLLMVLMEngine(VLLMInferenceEngine, OpenAICompatibleVLMEngine):
    """
    Inference engine for a vLLM OpenAI compatible server.
    See https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html

    This class defines no members of its own; everything comes from its two
    base classes, VLLMInferenceEngine and OpenAICompatibleVLMEngine.

    For parameters and documentation, refer to https://platform.openai.com/docs/api-reference/introduction

    Parameters:
    ----------
    model_name : str
        model name as shown in the vLLM server
    api_key : str, Optional
        the API key for the vLLM server.
    base_url : str, Optional
        the base url for the vLLM server.
    config : LLMConfig
        the LLM configuration.
    """
|
|
173
|
+
|
|
174
|
+
class OpenRouterVLMEngine(OpenRouterInferenceEngine, OpenAICompatibleVLMEngine):
    """
    Inference engine for the OpenRouter OpenAI-compatible server.

    This class defines no members of its own; everything comes from its two
    base classes, OpenRouterInferenceEngine and OpenAICompatibleVLMEngine.

    Parameters:
    ----------
    model_name : str
        model name to request from the server
    api_key : str, Optional
        the API key for the server. If None, will use the key in os.environ['OPENROUTER_API_KEY'].
    base_url : str, Optional
        the base url for the server.
    config : LLMConfig
        the LLM configuration.
    """
|
|
190
|
+
|
|
191
|
+
class OpenAIVLMEngine(OpenAIInferenceEngine, VLMEngine):
|
|
1150
192
|
def get_ocr_messages(self, system_prompt:str, user_prompt:str, image:Image.Image, format:str='png',
|
|
1151
193
|
detail:str="high", few_shot_examples:List[FewShotExample]=None) -> List[Dict[str,str]]:
|
|
1152
194
|
"""
|
|
@@ -1215,32 +257,20 @@ class OpenAIVLMEngine(VLMEngine):
|
|
|
1215
257
|
return output_messages
|
|
1216
258
|
|
|
1217
259
|
|
|
1218
|
-
class AzureOpenAIVLMEngine(OpenAIVLMEngine):
|
|
1219
|
-
|
|
1220
|
-
|
|
1221
|
-
|
|
1222
|
-
|
|
1223
|
-
|
|
1224
|
-
|
|
1225
|
-
|
|
1226
|
-
|
|
1227
|
-
|
|
1228
|
-
model
|
|
1229
|
-
|
|
1230
|
-
|
|
1231
|
-
|
|
1232
|
-
|
|
1233
|
-
|
|
1234
|
-
|
|
1235
|
-
if importlib.util.find_spec("openai") is None:
|
|
1236
|
-
raise ImportError("OpenAI Python API library not found. Please install OpanAI (```pip install openai```).")
|
|
1237
|
-
|
|
1238
|
-
from openai import AzureOpenAI, AsyncAzureOpenAI
|
|
1239
|
-
self.model = model
|
|
1240
|
-
self.api_version = api_version
|
|
1241
|
-
self.client = AzureOpenAI(api_version=self.api_version,
|
|
1242
|
-
**kwrs)
|
|
1243
|
-
self.async_client = AsyncAzureOpenAI(api_version=self.api_version,
|
|
1244
|
-
**kwrs)
|
|
1245
|
-
self.config = config if config else BasicVLMConfig()
|
|
1246
|
-
self.formatted_params = self._format_config()
|
|
260
|
+
class AzureOpenAIVLMEngine(AzureOpenAIInferenceEngine, OpenAIVLMEngine):
    """
    Inference engine for the Azure OpenAI API.

    This class defines no members of its own; everything comes from its two
    base classes, AzureOpenAIInferenceEngine and OpenAIVLMEngine.

    For parameters and documentation, refer to
    - https://azure.microsoft.com/en-us/products/ai-services/openai-service
    - https://learn.microsoft.com/en-us/azure/ai-services/openai/quickstart

    Parameters:
    ----------
    model : str
        model name as described in https://platform.openai.com/docs/models
    api_version : str
        the Azure OpenAI API version
    config : LLMConfig
        the LLM configuration.
    """
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: vlm4ocr
|
|
3
|
-
Version: 0.4.
|
|
3
|
+
Version: 0.4.2
|
|
4
4
|
Summary: Python package and Web App for OCR with vision language models.
|
|
5
5
|
License: MIT
|
|
6
6
|
Author: Enshuo (David) Hsu
|
|
@@ -12,6 +12,7 @@ Classifier: Programming Language :: Python :: 3.12
|
|
|
12
12
|
Provides-Extra: tesseract
|
|
13
13
|
Requires-Dist: colorama (>=0.4.4)
|
|
14
14
|
Requires-Dist: json-repair (>=0.30.0)
|
|
15
|
+
Requires-Dist: llm-inference-engine (>=0.1.5)
|
|
15
16
|
Requires-Dist: pdf2image (>=1.16.0)
|
|
16
17
|
Requires-Dist: pillow (>=10.0.0)
|
|
17
18
|
Requires-Dist: pytesseract (>=0.3.13) ; extra == "tesseract"
|
|
@@ -8,10 +8,10 @@ vlm4ocr/assets/default_prompt_templates/ocr_text_system_prompt.txt,sha256=WbLSOe
|
|
|
8
8
|
vlm4ocr/assets/default_prompt_templates/ocr_text_user_prompt.txt,sha256=ftgNAIPy_UlrcY6m7-IkH2ApHkCzRnymra1w2wg60Ks,47
|
|
9
9
|
vlm4ocr/cli.py,sha256=qFFIynex4sQSmT9ryjO2fPrkWfMyy3Aefp0p99rD3lU,22741
|
|
10
10
|
vlm4ocr/data_types.py,sha256=DAlMl6UsfajVs7tIVjl2E8hT8BQg2fcOW7SDG12uIaA,5922
|
|
11
|
-
vlm4ocr/ocr_engines.py,sha256=
|
|
11
|
+
vlm4ocr/ocr_engines.py,sha256=R02usLqJtyNLf6kHu90LHote7Dw-wrh7dWWkqgfW8aE,26954
|
|
12
12
|
vlm4ocr/utils.py,sha256=nQhUskOze99wCVMKmvsen0dhq-9NdN4EPC_bdYfkjgA,13611
|
|
13
|
-
vlm4ocr/vlm_engines.py,sha256=
|
|
14
|
-
vlm4ocr-0.4.
|
|
15
|
-
vlm4ocr-0.4.
|
|
16
|
-
vlm4ocr-0.4.
|
|
17
|
-
vlm4ocr-0.4.
|
|
13
|
+
vlm4ocr/vlm_engines.py,sha256=Rv8-QcOBJYgjjaBLMHfAanQr9aTbH7rSCtuf9dH3lTc,10298
|
|
14
|
+
vlm4ocr-0.4.2.dist-info/METADATA,sha256=GswZXF6XUU2W8xEx61mxeRfYW98Tv9L-WSE2XDB7zy0,756
|
|
15
|
+
vlm4ocr-0.4.2.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
|
16
|
+
vlm4ocr-0.4.2.dist-info/entry_points.txt,sha256=qzWUk_QTZ12cH4DLjjfqce89EAlOydD85dreRRZF3K4,44
|
|
17
|
+
vlm4ocr-0.4.2.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|