vlm4ocr 0.3.1__py3-none-any.whl → 0.4.1__py3-none-any.whl
- vlm4ocr/__init__.py +2 -0
- vlm4ocr/cli.py +69 -15
- vlm4ocr/data_types.py +40 -2
- vlm4ocr/ocr_engines.py +56 -30
- vlm4ocr/vlm_engines.py +202 -1089
- {vlm4ocr-0.3.1.dist-info → vlm4ocr-0.4.1.dist-info}/METADATA +2 -1
- {vlm4ocr-0.3.1.dist-info → vlm4ocr-0.4.1.dist-info}/RECORD +9 -9
- {vlm4ocr-0.3.1.dist-info → vlm4ocr-0.4.1.dist-info}/WHEEL +0 -0
- {vlm4ocr-0.3.1.dist-info → vlm4ocr-0.4.1.dist-info}/entry_points.txt +0 -0
vlm4ocr/vlm_engines.py
CHANGED

@@ -1,338 +1,29 @@
 import abc
-import
-from typing import Any, List, Dict, Union, Generator
-import warnings
-import os
-import re
+from typing import List, Dict
 from PIL import Image
 from vlm4ocr.utils import image_to_base64
-
-
-
-        Returns:
-        -------
-        messages : List[Dict[str,str]]
-            a list of dict with role and content. role must be one of {"system", "user", "assistant"}
-        """
-        return NotImplemented
-
-    @abc.abstractmethod
-    def postprocess_response(self, response:Union[str, Dict[str, str], Generator[str, None, None]]) -> Union[str, Generator[str, None, None]]:
-        """
-        This method postprocesses the VLM response after it is generated.
-
-        Parameters:
-        ----------
-        response : Union[str, Generator[str, None, None]]
-            the VLM response. Can be a string or a generator.
-
-        Returns:
-        -------
-        response : str
-            the postprocessed VLM response
-        """
-        return NotImplemented
-
-
-class BasicVLMConfig(VLMConfig):
-    def __init__(self, max_new_tokens:int=2048, temperature:float=0.0, **kwargs):
-        """
-        The basic VLM configuration for most non-reasoning models.
-        """
-        super().__init__(**kwargs)
-        self.max_new_tokens = max_new_tokens
-        self.temperature = temperature
-        self.params["max_new_tokens"] = self.max_new_tokens
-        self.params["temperature"] = self.temperature
-
-    def preprocess_messages(self, messages:List[Dict[str,str]]) -> List[Dict[str,str]]:
-        """
-        This method preprocesses the input messages before passing them to the VLM.
-
-        Parameters:
-        ----------
-        messages : List[Dict[str,str]]
-            a list of dict with role and content. role must be one of {"system", "user", "assistant"}
-
-        Returns:
-        -------
-        messages : List[Dict[str,str]]
-            a list of dict with role and content. role must be one of {"system", "user", "assistant"}
-        """
-        return messages
-
-    def postprocess_response(self, response:Union[str, Dict[str, str], Generator[str, None, None]]) -> Union[Dict[str, str], Generator[Dict[str, str], None, None]]:
-        """
-        This method postprocesses the VLM response after it is generated.
-
-        Parameters:
-        ----------
-        response : Union[str, Generator[str, None, None]]
-            the VLM response. Can be a string or a generator.
-
-        Returns: Union[str, Generator[Dict[str, str], None, None]]
-            the postprocessed VLM response.
-            if input is a generator, the output will be a generator {"type": "response", "data": <content>}.
-        """
-        if isinstance(response, str):
-            return {"response": response}
-
-        elif isinstance(response, dict):
-            if "response" in response:
-                return response
-            else:
-                warnings.warn(f"Invalid response dict keys: {response.keys()}. Returning default empty dict.", UserWarning)
-                return {"response": ""}
-
-        def _process_stream():
-            for chunk in response:
-                if isinstance(chunk, dict):
-                    yield chunk
-                elif isinstance(chunk, str):
-                    yield {"type": "response", "data": chunk}
-
-        return _process_stream()
-
-
-class ReasoningVLMConfig(VLMConfig):
-    def __init__(self, thinking_token_start="<think>", thinking_token_end="</think>", **kwargs):
-        """
-        The general configuration for reasoning vision models.
-        """
-        super().__init__(**kwargs)
-        self.thinking_token_start = thinking_token_start
-        self.thinking_token_end = thinking_token_end
-
-    def preprocess_messages(self, messages:List[Dict[str,str]]) -> List[Dict[str,str]]:
-        """
-        This method preprocesses the input messages before passing them to the VLM.
-
-        Parameters:
-        ----------
-        messages : List[Dict[str,str]]
-            a list of dict with role and content. role must be one of {"system", "user", "assistant"}
-
-        Returns:
-        -------
-        messages : List[Dict[str,str]]
-            a list of dict with role and content. role must be one of {"system", "user", "assistant"}
-        """
-        return messages.copy()
-
-    def postprocess_response(self, response:Union[str, Dict[str, str], Generator[str, None, None]]) -> Union[Dict[str,str], Generator[Dict[str,str], None, None]]:
-        """
-        This method postprocesses the VLM response after it is generated.
-        1. If input is a string, it will extract the reasoning and response based on the thinking tokens.
-        2. If input is a dict, it should contain keys "reasoning" and "response". This is for inference engines that already parse reasoning and response.
-        3. If input is a generator,
-            a. if the chunk is a dict, it should contain keys "type" and "data". This is for inference engines that already parse reasoning and response.
-            b. if the chunk is a string, it will yield dicts with keys "type" and "data" based on the thinking tokens.
-
-        Parameters:
-        ----------
-        response : Union[str, Generator[str, None, None]]
-            the VLM response. Can be a string or a generator.
-
-        Returns:
-        -------
-        response : Union[str, Generator[str, None, None]]
-            the postprocessed LLM response as a dict {"reasoning": <reasoning>, "response": <content>}
-            if input is a generator, the output will be a generator {"type": <reasoning or response>, "data": <content>}.
-        """
-        if isinstance(response, str):
-            # get contents between thinking_token_start and thinking_token_end
-            pattern = f"{re.escape(self.thinking_token_start)}(.*?){re.escape(self.thinking_token_end)}"
-            match = re.search(pattern, response, re.DOTALL)
-            reasoning = match.group(1) if match else ""
-            # get response AFTER thinking_token_end
-            response = re.sub(f".*?{self.thinking_token_end}", "", response, flags=re.DOTALL).strip()
-            return {"reasoning": reasoning, "response": response}
-
-        elif isinstance(response, dict):
-            if "reasoning" in response and "response" in response:
-                return response
-            else:
-                warnings.warn(f"Invalid response dict keys: {response.keys()}. Returning default empty dict.", UserWarning)
-                return {"reasoning": "", "response": ""}
-
-        elif isinstance(response, Generator):
-            def _process_stream():
-                think_flag = False
-                buffer = ""
-                for chunk in response:
-                    if isinstance(chunk, dict):
-                        yield chunk
-
-                    elif isinstance(chunk, str):
-                        buffer += chunk
-                        # switch between reasoning and response
-                        if self.thinking_token_start in buffer:
-                            think_flag = True
-                            buffer = buffer.replace(self.thinking_token_start, "")
-                        elif self.thinking_token_end in buffer:
-                            think_flag = False
-                            buffer = buffer.replace(self.thinking_token_end, "")
-
-                        # if chunk is in thinking block, tag it as reasoning; else tag it as response
-                        if chunk not in [self.thinking_token_start, self.thinking_token_end]:
-                            if think_flag:
-                                yield {"type": "reasoning", "data": chunk}
-                            else:
-                                yield {"type": "response", "data": chunk}
-
-            return _process_stream()
-
-        else:
-            warnings.warn(f"Invalid response type: {type(response)}. Returning default empty dict.", UserWarning)
-            return {"reasoning": "", "response": ""}
-
-
-class OpenAIReasoningVLMConfig(ReasoningVLMConfig):
-    def __init__(self, reasoning_effort:str="low", **kwargs):
-        """
-        The OpenAI "o" series configuration.
-        1. The reasoning effort is set to "low" by default.
-        2. The temperature parameter is not supported and will be ignored.
-        3. The system prompt is not supported and will be concatenated to the next user prompt.
-
-        Parameters:
-        ----------
-        reasoning_effort : str, Optional
-            the reasoning effort. Must be one of {"low", "medium", "high"}. Default is "low".
-        """
-        super().__init__(**kwargs)
-        if reasoning_effort not in ["low", "medium", "high"]:
-            raise ValueError("reasoning_effort must be one of {'low', 'medium', 'high'}.")
-
-        self.reasoning_effort = reasoning_effort
-        self.params["reasoning_effort"] = self.reasoning_effort
-
-        if "temperature" in self.params:
-            warnings.warn("Reasoning models do not support temperature parameter. Will be ignored.", UserWarning)
-            self.params.pop("temperature")
-
-    def preprocess_messages(self, messages:List[Dict[str,str]]) -> List[Dict[str,str]]:
-        """
-        Concatenate system prompts to the next user prompt.
-
-        Parameters:
-        ----------
-        messages : List[Dict[str,str]]
-            a list of dict with role and content. role must be one of {"system", "user", "assistant"}
-
-        Returns:
-        -------
-        messages : List[Dict[str,str]]
-            a list of dict with role and content. role must be one of {"system", "user", "assistant"}
-        """
-        system_prompt_holder = ""
-        new_messages = []
-        for i, message in enumerate(messages):
-            # if system prompt, store it in system_prompt_holder
-            if message['role'] == 'system':
-                system_prompt_holder = message['content']
-            # if user prompt, concatenate it with system_prompt_holder
-            elif message['role'] == 'user':
-                if system_prompt_holder:
-                    new_message = {'role': message['role'], 'content': f"{system_prompt_holder} {message['content']}"}
-                    system_prompt_holder = ""
-                else:
-                    new_message = {'role': message['role'], 'content': message['content']}
-
-                new_messages.append(new_message)
-            # if assistant/other prompt, do nothing
-            else:
-                new_message = {'role': message['role'], 'content': message['content']}
-                new_messages.append(new_message)
-
-        return new_messages
-
-
-class MessagesLogger:
-    def __init__(self):
-        """
-        This class is used to log the messages for InferenceEngine.chat().
-        """
-        self.messages_log = []
-
-    def log_messages(self, messages : List[Dict[str,str]]):
-        """
-        This method logs the messages to a list.
-        """
-        self.messages_log.append(messages)
-
-    def get_messages_log(self) -> List[List[Dict[str,str]]]:
-        """
-        This method returns a copy of the current messages log
-        """
-        return self.messages_log.copy()
-
-    def clear_messages_log(self):
-        """
-        This method clears the current messages log
-        """
-        self.messages_log.clear()
-
-
-class VLMEngine:
-    @abc.abstractmethod
-    def __init__(self, config:VLMConfig, **kwrs):
-        """
-        This is an abstract class to provide interfaces for VLM inference engines.
-        Children classes that inherts this class can be used in extrators. Must implement chat() method.
-
-        Parameters:
-        ----------
-        config : VLMConfig
-            the VLM configuration. Must be a child class of VLMConfig.
-        """
-        return NotImplemented
-
-    @abc.abstractmethod
-    def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False,
-             messages_logger:MessagesLogger=None) -> Union[Dict[str, str], Generator[Dict[str, str], None, None]]:
-        """
-        This method inputs chat messages and outputs VLM generated text.
-
-        Parameters:
-        ----------
-        messages : List[Dict[str,str]]
-            a list of dict with role and content. role must be one of {"system", "user", "assistant"}
-        verbose : bool, Optional
-            if True, VLM generated text will be printed in terminal in real-time.
-        stream : bool, Optional
-            if True, returns a generator that yields the output in real-time.
-        Messages_logger : MessagesLogger, Optional
-            the message logger that logs the chat messages.
-        """
-        return NotImplemented
-
+from vlm4ocr.data_types import FewShotExample
+from llm_inference_engine.llm_configs import (
+    LLMConfig as VLMConfig,
+    BasicLLMConfig as BasicVLMConfig,
+    ReasoningLLMConfig as ReasoningVLMConfig,
+    OpenAIReasoningLLMConfig as OpenAIReasoningVLMConfig
+)
+from llm_inference_engine.utils import MessagesLogger
+from llm_inference_engine.engines import (
+    InferenceEngine,
+    OllamaInferenceEngine,
+    OpenAICompatibleInferenceEngine,
+    VLLMInferenceEngine,
+    OpenRouterInferenceEngine,
+    OpenAIInferenceEngine,
+    AzureOpenAIInferenceEngine,
+)
+
+
+class VLMEngine(InferenceEngine):
     @abc.abstractmethod
-    def
-        """
-        The async version of chat method. Streaming is not supported.
-        """
-        return NotImplemented
-
-    @abc.abstractmethod
-    def get_ocr_messages(self, system_prompt:str, user_prompt:str, image:Image.Image) -> List[Dict[str,str]]:
+    def get_ocr_messages(self, system_prompt:str, user_prompt:str, image:Image.Image, few_shot_examples:List[FewShotExample]=None) -> List[Dict[str,str]]:
         """
         This method inputs an image and returns the correesponding chat messages for the inference engine.

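Everything removed in the hunk above (the VLM config classes and MessagesLogger) now comes from the separate llm_inference_engine package and is re-exported under the old VLM names, so imports such as `BasicVLMConfig` or `ReasoningVLMConfig` from `vlm4ocr.vlm_engines` keep working. A minimal usage sketch; the constructor arguments shown are the 0.3.1 ones and are assumed to be unchanged in the upstream package:

```python
# Sketch only: assumes BasicLLMConfig / ReasoningLLMConfig keep the 0.3.1 parameters.
from vlm4ocr.vlm_engines import BasicVLMConfig, ReasoningVLMConfig

config = BasicVLMConfig(max_new_tokens=2048, temperature=0.0)
reasoning_config = ReasoningVLMConfig(thinking_token_start="<think>",
                                      thinking_token_end="</think>")
```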
@@ -344,220 +35,14 @@ class VLMEngine:
             the user prompt.
         image : Image.Image
             the image for OCR.
+        few_shot_examples : List[FewShotExample], Optional
+            list of few-shot examples.
         """
         return NotImplemented
-
-    def _format_config(self) -> Dict[str, Any]:
-        """
-        This method format the VLM configuration with the correct key for the inference engine.
-
-        Return : Dict[str, Any]
-            the config parameters.
-        """
-        return NotImplemented
-
-
-class OllamaVLMEngine(VLMEngine):
-    def __init__(self, model_name:str, num_ctx:int=8192, keep_alive:int=300, config:VLMConfig=None, **kwrs):
-        """
-        The Ollama inference engine.
-
-        Parameters:
-        ----------
-        model_name : str
-            the model name exactly as shown in >> ollama ls
-        num_ctx : int, Optional
-            context length that LLM will evaluate.
-        keep_alive : int, Optional
-            seconds to hold the LLM after the last API call.
-        config : LLMConfig
-            the LLM configuration.
-        """
-        if importlib.util.find_spec("ollama") is None:
-            raise ImportError("ollama-python not found. Please install ollama-python (```pip install ollama```).")
-
-        from ollama import Client, AsyncClient
-        self.client = Client(**kwrs)
-        self.async_client = AsyncClient(**kwrs)
-        self.model_name = model_name
-        self.num_ctx = num_ctx
-        self.keep_alive = keep_alive
-        self.config = config if config else BasicVLMConfig()
-        self.formatted_params = self._format_config()
-
-    def _format_config(self) -> Dict[str, Any]:
-        """
-        This method format the LLM configuration with the correct key for the inference engine.
-        """
-        formatted_params = self.config.params.copy()
-        if "max_new_tokens" in formatted_params:
-            formatted_params["num_predict"] = formatted_params["max_new_tokens"]
-            formatted_params.pop("max_new_tokens")
 
-        return formatted_params
 
-
-
-        """
-        This method inputs chat messages and outputs VLM generated text.
-
-        Parameters:
-        ----------
-        messages : List[Dict[str,str]]
-            a list of dict with role and content. role must be one of {"system", "user", "assistant"}
-        verbose : bool, Optional
-            if True, VLM generated text will be printed in terminal in real-time.
-        stream : bool, Optional
-            if True, returns a generator that yields the output in real-time.
-        Messages_logger : MessagesLogger, Optional
-            the message logger that logs the chat messages.
-
-        Returns:
-        -------
-        response : Union[Dict[str,str], Generator[Dict[str, str], None, None]]
-            a dict {"reasoning": <reasoning>, "response": <response>} or Generator {"type": <reasoning or response>, "data": <content>}
-        """
-        processed_messages = self.config.preprocess_messages(messages)
-
-        options={'num_ctx': self.num_ctx, **self.formatted_params}
-        if stream:
-            def _stream_generator():
-                response_stream = self.client.chat(
-                    model=self.model_name,
-                    messages=processed_messages,
-                    options=options,
-                    stream=True,
-                    keep_alive=self.keep_alive
-                )
-                res = {"reasoning": "", "response": ""}
-                for chunk in response_stream:
-                    if hasattr(chunk.message, 'thinking') and chunk.message.thinking:
-                        content_chunk = getattr(getattr(chunk, 'message', {}), 'thinking', '')
-                        res["reasoning"] += content_chunk
-                        yield {"type": "reasoning", "data": content_chunk}
-                    else:
-                        content_chunk = getattr(getattr(chunk, 'message', {}), 'content', '')
-                        res["response"] += content_chunk
-                        yield {"type": "response", "data": content_chunk}
-
-                    if chunk.done_reason == "length":
-                        warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
-
-                # Postprocess response
-                res_dict = self.config.postprocess_response(res)
-                # Write to messages log
-                if messages_logger:
-                    # replace images content with a placeholder "[image]" to save space
-                    for messages in processed_messages:
-                        if "images" in messages:
-                            messages["images"] = ["[image]" for _ in messages["images"]]
-
-                    processed_messages.append({"role": "assistant",
-                                               "content": res_dict.get("response", ""),
-                                               "reasoning": res_dict.get("reasoning", "")})
-                    messages_logger.log_messages(processed_messages)
-
-            return self.config.postprocess_response(_stream_generator())
-
-        elif verbose:
-            response = self.client.chat(
-                model=self.model_name,
-                messages=processed_messages,
-                options=options,
-                stream=True,
-                keep_alive=self.keep_alive
-            )
-
-            res = {"reasoning": "", "response": ""}
-            phase = ""
-            for chunk in response:
-                if hasattr(chunk.message, 'thinking') and chunk.message.thinking:
-                    if phase != "reasoning":
-                        print("\n--- Reasoning ---")
-                        phase = "reasoning"
-
-                    content_chunk = getattr(getattr(chunk, 'message', {}), 'thinking', '')
-                    res["reasoning"] += content_chunk
-                else:
-                    if phase != "response":
-                        print("\n--- Response ---")
-                        phase = "response"
-                    content_chunk = getattr(getattr(chunk, 'message', {}), 'content', '')
-                    res["response"] += content_chunk
-
-                print(content_chunk, end='', flush=True)
-
-                if chunk.done_reason == "length":
-                    warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
-            print('\n')
-
-        else:
-            response = self.client.chat(
-                model=self.model_name,
-                messages=processed_messages,
-                options=options,
-                stream=False,
-                keep_alive=self.keep_alive
-            )
-            res = {"reasoning": getattr(getattr(response, 'message', {}), 'thinking', ''),
-                   "response": getattr(getattr(response, 'message', {}), 'content', '')}
-
-            if response.done_reason == "length":
-                warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
-
-        # Postprocess response
-        res_dict = self.config.postprocess_response(res)
-        # Write to messages log
-        if messages_logger:
-            # replace images content with a placeholder "[image]" to save space
-            for messages in processed_messages:
-                if "images" in messages:
-                    messages["images"] = ["[image]" for _ in messages["images"]]
-
-            processed_messages.append({"role": "assistant",
-                                       "content": res_dict.get("response", ""),
-                                       "reasoning": res_dict.get("reasoning", "")})
-            messages_logger.log_messages(processed_messages)
-
-        return res_dict
-
-    async def chat_async(self, messages:List[Dict[str,str]], messages_logger:MessagesLogger=None) -> Dict[str,str]:
-        """
-        Async version of chat method. Streaming is not supported.
-        """
-        processed_messages = self.config.preprocess_messages(messages)
-
-        response = await self.async_client.chat(
-            model=self.model_name,
-            messages=processed_messages,
-            options={'num_ctx': self.num_ctx, **self.formatted_params},
-            stream=False,
-            keep_alive=self.keep_alive
-        )
-
-        res = {"reasoning": getattr(getattr(response, 'message', {}), 'thinking', ''),
-               "response": getattr(getattr(response, 'message', {}), 'content', '')}
-
-        if response.done_reason == "length":
-            warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
-        # Postprocess response
-        res_dict = self.config.postprocess_response(res)
-        # Write to messages log
-        if messages_logger:
-            # replace images content with a placeholder "[image]" to save space
-            for messages in processed_messages:
-                if "images" in messages:
-                    messages["images"] = ["[image]" for _ in messages["images"]]
-
-            processed_messages.append({"role": "assistant",
-                                       "content": res_dict.get("response", ""),
-                                       "reasoning": res_dict.get("reasoning", "")})
-            messages_logger.log_messages(processed_messages)
-
-        return res_dict
-
-    def get_ocr_messages(self, system_prompt:str, user_prompt:str, image:Image.Image) -> List[Dict[str,str]]:
+class OllamaVLMEngine(OllamaInferenceEngine, VLMEngine):
+    def get_ocr_messages(self, system_prompt:str, user_prompt:str, image:Image.Image, few_shot_examples:List[FewShotExample]=None) -> List[Dict[str,str]]:
         """
         This method inputs an image and returns the correesponding chat messages for the inference engine.

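The hunk above drops the hand-written Ollama client plumbing: chat(), chat_async(), and _format_config() are now inherited from OllamaInferenceEngine, while VLMEngine only contributes the abstract get_ocr_messages(). A generic sketch of that composition pattern, using placeholder class names rather than the real llm_inference_engine API:

```python
# Illustration of the mixin composition used above (placeholder names, not the real API).
class TransportEngine:                  # stands in for OllamaInferenceEngine
    def chat(self, messages):           # transport-specific request/response handling
        raise NotImplementedError

class OcrMessageBuilder:                # stands in for VLMEngine
    def get_ocr_messages(self, system_prompt, user_prompt, image):
        raise NotImplementedError

class ConcreteVLMEngine(TransportEngine, OcrMessageBuilder):
    # The concrete engine only overrides the message builder; chat() comes from
    # the transport base via the MRO.
    def get_ocr_messages(self, system_prompt, user_prompt, image):
        return [{"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt}]
```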
@@ -569,230 +54,37 @@ class OllamaVLMEngine(VLMEngine):
             the user prompt.
         image : Image.Image
             the image for OCR.
+        few_shot_examples : List[FewShotExample], Optional
+            list of few-shot examples.
         """
         base64_str = image_to_base64(image)
-
-
-
-        For parameters and documentation, refer to https://platform.openai.com/docs/api-reference/introduction
-
-        Parameters:
-        ----------
-        model_name : str
-            model name as shown in the vLLM server
-        api_key : str
-            the API key for the vLLM server.
-        base_url : str
-            the base url for the vLLM server.
-        config : LLMConfig
-            the LLM configuration.
-        """
-        if importlib.util.find_spec("openai") is None:
-            raise ImportError("OpenAI Python API library not found. Please install OpanAI (```pip install openai```).")
-
-        from openai import OpenAI, AsyncOpenAI
-        from openai.types.chat import ChatCompletionChunk
-        self.ChatCompletionChunk = ChatCompletionChunk
-        super().__init__(config)
-        self.client = OpenAI(api_key=api_key, base_url=base_url, **kwrs)
-        self.async_client = AsyncOpenAI(api_key=api_key, base_url=base_url, **kwrs)
-        self.model = model
-        self.config = config if config else BasicVLMConfig()
-        self.formatted_params = self._format_config()
-
-    def _format_config(self) -> Dict[str, Any]:
-        """
-        This method format the VLM configuration with the correct key for the inference engine.
-        """
-        formatted_params = self.config.params.copy()
-        if "max_new_tokens" in formatted_params:
-            formatted_params["max_completion_tokens"] = formatted_params["max_new_tokens"]
-            formatted_params.pop("max_new_tokens")
-
-        return formatted_params
-
-    def _format_response(self, response: Any) -> Dict[str, str]:
-        """
-        This method format the response from OpenAI API to a dict with keys "type" and "data".
-
-        Parameters:
-        ----------
-        response : Any
-            the response from OpenAI-compatible API. Could be a dict, generator, or object.
-        """
-        if isinstance(response, self.ChatCompletionChunk):
-            chunk_text = getattr(response.choices[0].delta, "content", "")
-            if chunk_text is None:
-                chunk_text = ""
-            return {"type": "response", "data": chunk_text}
-
-        return {"response": getattr(response.choices[0].message, "content", "")}
-
-    def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False,
-             messages_logger:MessagesLogger=None) -> Union[Dict[str, str], Generator[Dict[str, str], None, None]]:
-        """
-        This method inputs chat messages and outputs LLM generated text.
-
-        Parameters:
-        ----------
-        messages : List[Dict[str,str]]
-            a list of dict with role and content. role must be one of {"system", "user", "assistant"}
-        verbose : bool, Optional
-            if True, VLM generated text will be printed in terminal in real-time.
-        stream : bool, Optional
-            if True, returns a generator that yields the output in real-time.
-        messages_logger : MessagesLogger, Optional
-            the message logger that logs the chat messages.
-
-        Returns:
-        -------
-        response : Union[Dict[str,str], Generator[Dict[str, str], None, None]]
-            a dict {"reasoning": <reasoning>, "response": <response>} or Generator {"type": <reasoning or response>, "data": <content>}
-        """
-        processed_messages = self.config.preprocess_messages(messages)
-
-        if stream:
-            def _stream_generator():
-                response_stream = self.client.chat.completions.create(
-                    model=self.model,
-                    messages=processed_messages,
-                    stream=True,
-                    **self.formatted_params
-                )
-                res_text = ""
-                for chunk in response_stream:
-                    if len(chunk.choices) > 0:
-                        chunk_dict = self._format_response(chunk)
-                        yield chunk_dict
-
-                        res_text += chunk_dict["data"]
-                        if chunk.choices[0].finish_reason == "length":
-                            warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
-
-                # Postprocess response
-                res_dict = self.config.postprocess_response(res_text)
-                # Write to messages log
-                if messages_logger:
-                    # replace images content with a placeholder "[image]" to save space
-                    for messages in processed_messages:
-                        if "content" in messages and isinstance(messages["content"], list):
-                            for content in messages["content"]:
-                                if isinstance(content, dict) and content.get("type") == "image_url":
-                                    content["image_url"]["url"] = "[image]"
-
-                    processed_messages.append({"role": "assistant",
-                                               "content": res_dict.get("response", ""),
-                                               "reasoning": res_dict.get("reasoning", "")})
-                    messages_logger.log_messages(processed_messages)
-
-            return self.config.postprocess_response(_stream_generator())
-
-        elif verbose:
-            response = self.client.chat.completions.create(
-                model=self.model,
-                messages=processed_messages,
-                stream=True,
-                **self.formatted_params
-            )
-            res = {"reasoning": "", "response": ""}
-            phase = ""
-            for chunk in response:
-                if len(chunk.choices) > 0:
-                    chunk_dict = self._format_response(chunk)
-                    chunk_text = chunk_dict["data"]
-                    res[chunk_dict["type"]] += chunk_text
-                    if phase != chunk_dict["type"] and chunk_text != "":
-                        print(f"\n--- {chunk_dict['type'].capitalize()} ---")
-                        phase = chunk_dict["type"]
-
-                    print(chunk_text, end="", flush=True)
-                    if chunk.choices[0].finish_reason == "length":
-                        warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
-
-            print('\n')
-
-        else:
-            response = self.client.chat.completions.create(
-                model=self.model,
-                messages=processed_messages,
-                stream=False,
-                **self.formatted_params
-            )
-            res = self._format_response(response)
+        output_messages = []
+        # system message
+        system_message = {"role": "system", "content": system_prompt}
+        output_messages.append(system_message)
+
+        # few-shot examples
+        if few_shot_examples is not None:
+            for example in few_shot_examples:
+                if not isinstance(example, FewShotExample):
+                    raise ValueError("Few-shot example must be a FewShotExample object.")
+
+                example_image_b64 = image_to_base64(example.image)
+                example_user_message = {"role": "user", "content": user_prompt, "images": [example_image_b64]}
+                example_agent_message = {"role": "assistant", "content": example.text}
+                output_messages.append(example_user_message)
+                output_messages.append(example_agent_message)
 
-
-
-        # Postprocess response
-        res_dict = self.config.postprocess_response(res)
-        # Write to messages log
-        if messages_logger:
-            # replace images content with a placeholder "[image]" to save space
-            for messages in processed_messages:
-                if "content" in messages and isinstance(messages["content"], list):
-                    for content in messages["content"]:
-                        if isinstance(content, dict) and content.get("type") == "image_url":
-                            content["image_url"]["url"] = "[image]"
+        # user message
+        user_message = {"role": "user", "content": user_prompt, "images": [base64_str]}
+        output_messages.append(user_message)
 
-
-                                       "content": res_dict.get("response", ""),
-                                       "reasoning": res_dict.get("reasoning", "")})
-            messages_logger.log_messages(processed_messages)
+        return output_messages
 
-        return res_dict
-
 
-
-
-        """
-        processed_messages = self.config.preprocess_messages(messages)
-
-        response = await self.async_client.chat.completions.create(
-            model=self.model,
-            messages=processed_messages,
-            stream=False,
-            **self.formatted_params
-        )
-
-        if response.choices[0].finish_reason == "length":
-            warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
-
-        res = self._format_response(response)
-
-        # Postprocess response
-        res_dict = self.config.postprocess_response(res)
-        # Write to messages log
-        if messages_logger:
-            # replace images content with a placeholder "[image]" to save space
-            for messages in processed_messages:
-                if "content" in messages and isinstance(messages["content"], list):
-                    for content in messages["content"]:
-                        if isinstance(content, dict) and content.get("type") == "image_url":
-                            content["image_url"]["url"] = "[image]"
-
-            processed_messages.append({"role": "assistant",
-                                       "content": res_dict.get("response", ""),
-                                       "reasoning": res_dict.get("reasoning", "")})
-            messages_logger.log_messages(processed_messages)
-
-        return res_dict
-
-    def get_ocr_messages(self, system_prompt:str, user_prompt:str, image:Image.Image, format:str='png', detail:str="high") -> List[Dict[str,str]]:
+class OpenAICompatibleVLMEngine(OpenAICompatibleInferenceEngine, VLMEngine):
+    def get_ocr_messages(self, system_prompt:str, user_prompt:str, image:Image.Image, format:str='png',
+                         detail:str="high", few_shot_examples:List[FewShotExample]=None) -> List[Dict[str,str]]:
        """
        This method inputs an image and returns the correesponding chat messages for the inference engine.

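With the Ollama-format builder above, each few-shot example is injected as a user/assistant turn pair ahead of the real page. A hypothetical end-to-end usage sketch; only get_ocr_messages() and the example.image / example.text attributes are taken from this diff, while the FewShotExample and OllamaVLMEngine constructor arguments are assumptions:

```python
# Hypothetical usage sketch (constructor signatures are assumptions).
from PIL import Image
from vlm4ocr.data_types import FewShotExample
from vlm4ocr.vlm_engines import OllamaVLMEngine

engine = OllamaVLMEngine(model_name="llama3.2-vision")          # assumed signature
example = FewShotExample(image=Image.open("sample_page.png"),   # assumed constructor
                         text="Ground-truth transcription of the sample page")

messages = engine.get_ocr_messages(
    system_prompt="You are an OCR engine.",
    user_prompt="Transcribe this page.",
    image=Image.open("target_page.png"),
    few_shot_examples=[example],
)
# Resulting order: system, example user (with image), example assistant, real user (with image).
```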
@@ -808,295 +100,97 @@ class OpenAICompatibleVLMEngine(VLMEngine):
             the image format.
         detail : str, Optional
             the detail level of the image. Default is "high".
+        few_shot_examples : List[FewShotExample], Optional
+            list of few-shot examples.
         """
         base64_str = image_to_base64(image)
-
-
-
+        output_messages = []
+        # system message
+        system_message = {"role": "system", "content": system_prompt}
+        output_messages.append(system_message)
+
+        # few-shot examples
+        if few_shot_examples is not None:
+            for example in few_shot_examples:
+                if not isinstance(example, FewShotExample):
+                    raise ValueError("Few-shot example must be a FewShotExample object.")
+
+                example_image_b64 = image_to_base64(example.image)
+                example_user_message = {
+                    "role": "user",
+                    "content": [
+                        {
+                            "type": "image_url",
+                            "image_url": {
+                                "url": f"data:image/{format};base64,{example_image_b64}",
+                                "detail": detail
+                            },
                         },
+                        {"type": "text", "text": user_prompt},
+                    ],
+                }
+                example_agent_message = {"role": "assistant", "content": example.text}
+                output_messages.append(example_user_message)
+                output_messages.append(example_agent_message)
+
+        # user message
+        user_message = {
+            "role": "user",
+            "content": [
+                {
+                    "type": "image_url",
+                    "image_url": {
+                        "url": f"data:image/{format};base64,{base64_str}",
+                        "detail": detail
                     },
-
-
-
-        return {"reasoning": getattr(response.choices[0].message, "reasoning_content", ""),
-                "response": getattr(response.choices[0].message, "content", "")}
-
-
-class OpenRouterVLMEngine(OpenAICompatibleVLMEngine):
-    def __init__(self, model:str, api_key:str=None, base_url:str="https://openrouter.ai/api/v1", config:VLMConfig=None, **kwrs):
-        """
-        OpenRouter OpenAI-compatible server inference engine.
-
-        Parameters:
-        ----------
-        model_name : str
-            model name as shown in the vLLM server
-        api_key : str, Optional
-            the API key for the vLLM server. If None, will use the key in os.environ['OPENROUTER_API_KEY'].
-        base_url : str, Optional
-            the base url for the vLLM server.
-        config : LLMConfig
-            the LLM configuration.
-        """
-        self.api_key = api_key
-        if self.api_key is None:
-            self.api_key = os.getenv("OPENROUTER_API_KEY")
-        super().__init__(model, self.api_key, base_url, config, **kwrs)
-
-    def _format_response(self, response: Any) -> Dict[str, str]:
-        """
-        This method format the response from OpenAI API to a dict with keys "type" and "data".
-
-        Parameters:
-        ----------
-        response : Any
-            the response from OpenAI-compatible API. Could be a dict, generator, or object.
-        """
-        if isinstance(response, self.ChatCompletionChunk):
-            if hasattr(response.choices[0].delta, "reasoning") and getattr(response.choices[0].delta, "reasoning") is not None:
-                chunk_text = getattr(response.choices[0].delta, "reasoning", "")
-                if chunk_text is None:
-                    chunk_text = ""
-                return {"type": "reasoning", "data": chunk_text}
-            else:
-                chunk_text = getattr(response.choices[0].delta, "content", "")
-                if chunk_text is None:
-                    chunk_text = ""
-                return {"type": "response", "data": chunk_text}
-
-        return {"reasoning": getattr(response.choices[0].message, "reasoning", ""),
-                "response": getattr(response.choices[0].message, "content", "")}
-
-
-class OpenAIVLMEngine(VLMEngine):
-    def __init__(self, model:str, config:VLMConfig=None, **kwrs):
-        """
-        The OpenAI API inference engine. Supports OpenAI models and OpenAI compatible servers:
-        - vLLM OpenAI compatible server (https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html)
-
-        For parameters and documentation, refer to https://platform.openai.com/docs/api-reference/introduction
-
-        Parameters:
-        ----------
-        model_name : str
-            model name as described in https://platform.openai.com/docs/models
-        config : VLMConfig, Optional
-            the VLM configuration. Must be a child class of VLMConfig.
-        """
-        if importlib.util.find_spec("openai") is None:
-            raise ImportError("OpenAI Python API library not found. Please install OpanAI (```pip install openai```).")
-
-        from openai import OpenAI, AsyncOpenAI
-        self.client = OpenAI(**kwrs)
-        self.async_client = AsyncOpenAI(**kwrs)
-        self.model = model
-        self.config = config if config else BasicVLMConfig()
-        self.formatted_params = self._format_config()
-
-    def _format_config(self) -> Dict[str, Any]:
-        """
-        This method format the LLM configuration with the correct key for the inference engine.
-        """
-        formatted_params = self.config.params.copy()
-        if "max_new_tokens" in formatted_params:
-            formatted_params["max_completion_tokens"] = formatted_params["max_new_tokens"]
-            formatted_params.pop("max_new_tokens")
-
-        return formatted_params
-
-    def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False, messages_logger:MessagesLogger=None) -> Union[Dict[str, str], Generator[Dict[str, str], None, None]]:
-        """
-        This method inputs chat messages and outputs LLM generated text.
-
-        Parameters:
-        ----------
-        messages : List[Dict[str,str]]
-            a list of dict with role and content. role must be one of {"system", "user", "assistant"}
-        verbose : bool, Optional
-            if True, VLM generated text will be printed in terminal in real-time.
-        stream : bool, Optional
-            if True, returns a generator that yields the output in real-time.
-        messages_logger : MessagesLogger, Optional
-            the message logger that logs the chat messages.
-
-        Returns:
-        -------
-        response : Union[Dict[str,str], Generator[Dict[str, str], None, None]]
-            a dict {"reasoning": <reasoning>, "response": <response>} or Generator {"type": <reasoning or response>, "data": <content>}
-        """
-        processed_messages = self.config.preprocess_messages(messages)
-
-        if stream:
-            def _stream_generator():
-                response_stream = self.client.chat.completions.create(
-                    model=self.model,
-                    messages=processed_messages,
-                    stream=True,
-                    **self.formatted_params
-                )
-                res_text = ""
-                for chunk in response_stream:
-                    if len(chunk.choices) > 0:
-                        chunk_text = chunk.choices[0].delta.content
-                        if chunk_text is not None:
-                            res_text += chunk_text
-                            yield chunk_text
-                        if chunk.choices[0].finish_reason == "length":
-                            warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
-
-                # Postprocess response
-                res_dict = self.config.postprocess_response(res_text)
-                # Write to messages log
-                if messages_logger:
-                    # replace images content with a placeholder "[image]" to save space
-                    for messages in processed_messages:
-                        if "content" in messages and isinstance(messages["content"], list):
-                            for content in messages["content"]:
-                                if isinstance(content, dict) and content.get("type") == "image_url":
-                                    content["image_url"]["url"] = "[image]"
-
-                    processed_messages.append({"role": "assistant",
-                                               "content": res_dict.get("response", ""),
-                                               "reasoning": res_dict.get("reasoning", "")})
-                    messages_logger.log_messages(processed_messages)
-
-            return self.config.postprocess_response(_stream_generator())
-
-        elif verbose:
-            response = self.client.chat.completions.create(
-                model=self.model,
-                messages=processed_messages,
-                stream=True,
-                **self.formatted_params
-            )
-            res = ''
-            for chunk in response:
-                if len(chunk.choices) > 0:
-                    if chunk.choices[0].delta.content is not None:
-                        res += chunk.choices[0].delta.content
-                        print(chunk.choices[0].delta.content, end="", flush=True)
-                    if chunk.choices[0].finish_reason == "length":
-                        warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
-
-            print('\n')
-
-        else:
-            response = self.client.chat.completions.create(
-                model=self.model,
-                messages=processed_messages,
-                stream=False,
-                **self.formatted_params
-            )
-            res = response.choices[0].message.content
-
-        # Postprocess response
-        res_dict = self.config.postprocess_response(res)
-        # Write to messages log
-        if messages_logger:
-            # replace images content with a placeholder "[image]" to save space
-            for messages in processed_messages:
-                if "content" in messages and isinstance(messages["content"], list):
-                    for content in messages["content"]:
-                        if isinstance(content, dict) and content.get("type") == "image_url":
-                            content["image_url"]["url"] = "[image]"
-
-            processed_messages.append({"role": "assistant",
-                                       "content": res_dict.get("response", ""),
-                                       "reasoning": res_dict.get("reasoning", "")})
-            messages_logger.log_messages(processed_messages)
-
-        return res_dict
-
-    async def chat_async(self, messages:List[Dict[str,str]], messages_logger:MessagesLogger=None) -> Dict[str,str]:
-        """
-        Async version of chat method. Streaming is not supported.
-        """
-        processed_messages = self.config.preprocess_messages(messages)
-
-        response = await self.async_client.chat.completions.create(
-            model=self.model,
-            messages=processed_messages,
-            stream=False,
-            **self.formatted_params
-        )
-
-        if response.choices[0].finish_reason == "length":
-            warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
-
-        res = response.choices[0].message.content
-        # Postprocess response
-        res_dict = self.config.postprocess_response(res)
-        # Write to messages log
-        if messages_logger:
-            # replace images content with a placeholder "[image]" to save space
-            for messages in processed_messages:
-                if "content" in messages and isinstance(messages["content"], list):
-                    for content in messages["content"]:
-                        if isinstance(content, dict) and content.get("type") == "image_url":
-                            content["image_url"]["url"] = "[image]"
-
-            processed_messages.append({"role": "assistant",
-                                       "content": res_dict.get("response", ""),
-                                       "reasoning": res_dict.get("reasoning", "")})
-            messages_logger.log_messages(processed_messages)
-
-        return res_dict
-
-    def get_ocr_messages(self, system_prompt:str, user_prompt:str, image:Image.Image, format:str='png', detail:str="high") -> List[Dict[str,str]]:
+                },
+                {"type": "text", "text": user_prompt},
+            ],
+        }
+        output_messages.append(user_message)
+        return output_messages
+
+
+class VLLMVLMEngine(VLLMInferenceEngine, OpenAICompatibleVLMEngine):
+    """
+    vLLM OpenAI compatible server inference engine.
+    https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html
+
+    For parameters and documentation, refer to https://platform.openai.com/docs/api-reference/introduction
+
+    Parameters:
+    ----------
+    model_name : str
+        model name as shown in the vLLM server
+    api_key : str, Optional
+        the API key for the vLLM server.
+    base_url : str, Optional
+        the base url for the vLLM server.
+    config : LLMConfig
+        the LLM configuration.
+    """
+    pass
+
+class OpenRouterVLMEngine(OpenRouterInferenceEngine, OpenAICompatibleVLMEngine):
+    """
+    OpenRouter OpenAI-compatible server inference engine.
+
+    Parameters:
+    ----------
+    model_name : str
+        model name as shown in the vLLM server
+    api_key : str, Optional
+        the API key for the vLLM server. If None, will use the key in os.environ['OPENROUTER_API_KEY'].
+    base_url : str, Optional
+        the base url for the vLLM server.
+    config : LLMConfig
+        the LLM configuration.
+    """
+    pass
+
+class OpenAIVLMEngine(OpenAIInferenceEngine, VLMEngine):
+    def get_ocr_messages(self, system_prompt:str, user_prompt:str, image:Image.Image, format:str='png',
+                         detail:str="high", few_shot_examples:List[FewShotExample]=None) -> List[Dict[str,str]]:
         """
         This method inputs an image and returns the correesponding chat messages for the inference engine.

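For the OpenAI-style builders, each image is inlined as a base64 data URL with a detail hint, and few-shot pairs precede the real request in the same content-list format. The shape of one user turn built by the added code above, shown with placeholder values:

```python
# Shape of one user turn produced by the OpenAI-style get_ocr_messages() above;
# "<b64>" stands in for the base64-encoded image bytes.
user_message = {
    "role": "user",
    "content": [
        {"type": "image_url",
         "image_url": {"url": "data:image/png;base64,<b64>", "detail": "high"}},
        {"type": "text", "text": "Transcribe this page."},
    ],
}
```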
@@ -1112,52 +206,71 @@ class OpenAIVLMEngine(VLMEngine):
             the image format.
         detail : str, Optional
             the detail level of the image. Default is "high".
+        few_shot_examples : List[FewShotExample], Optional
+            list of few-shot examples. Each example is a dict with keys "image" (PIL.Image.Image) and "text" (str).
         """
         base64_str = image_to_base64(image)
-
-
-
+        output_messages = []
+        # system message
+        system_message = {"role": "system", "content": system_prompt}
+        output_messages.append(system_message)
+
+        # few-shot examples
+        if few_shot_examples is not None:
+            for example in few_shot_examples:
+                if not isinstance(example, FewShotExample):
+                    raise ValueError("Few-shot example must be a FewShotExample object.")
+
+                example_image_b64 = image_to_base64(example.image)
+                example_user_message = {
+                    "role": "user",
+                    "content": [
+                        {
+                            "type": "image_url",
+                            "image_url": {
+                                "url": f"data:image/{format};base64,{example_image_b64}",
+                                "detail": detail
+                            },
                         },
+                        {"type": "text", "text": user_prompt},
+                    ],
+                }
+                example_agent_message = {"role": "assistant", "content": example.text}
+                output_messages.append(example_user_message)
+                output_messages.append(example_agent_message)
+
+        # user message
+        user_message = {
+            "role": "user",
+            "content": [
+                {
+                    "type": "image_url",
+                    "image_url": {
+                        "url": f"data:image/{format};base64,{base64_str}",
+                        "detail": detail
                     },
-
-
-
-        from openai import AzureOpenAI, AsyncAzureOpenAI
-        self.model = model
-        self.api_version = api_version
-        self.client = AzureOpenAI(api_version=self.api_version,
-                                  **kwrs)
-        self.async_client = AsyncAzureOpenAI(api_version=self.api_version,
-                                             **kwrs)
-        self.config = config if config else BasicVLMConfig()
-        self.formatted_params = self._format_config()
+                },
+                {"type": "text", "text": user_prompt},
+            ],
+        }
+        output_messages.append(user_message)
+        return output_messages
+
+
+class AzureOpenAIVLMEngine(AzureOpenAIInferenceEngine, OpenAIVLMEngine):
+    """
+    The Azure OpenAI API inference engine.
+    For parameters and documentation, refer to
+    - https://azure.microsoft.com/en-us/products/ai-services/openai-service
+    - https://learn.microsoft.com/en-us/azure/ai-services/openai/quickstart
+
+    Parameters:
+    ----------
+    model : str
+        model name as described in https://platform.openai.com/docs/models
+    api_version : str
+        the Azure OpenAI API version
+    config : LLMConfig
+        the LLM configuration.
+    """
+    pass
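Taken together, the 0.4.1 engines are thin subclasses that pair an llm_inference_engine transport with one of the OCR message builders above, so swapping backends means constructing a different subclass. A hypothetical selection helper; only the class names and bases come from this diff, the constructor arguments are assumptions:

```python
# Hypothetical backend-selection sketch (constructor arguments are assumptions).
from vlm4ocr.vlm_engines import (
    OllamaVLMEngine, OpenAICompatibleVLMEngine, VLLMVLMEngine,
    OpenRouterVLMEngine, OpenAIVLMEngine, AzureOpenAIVLMEngine,
)

def make_engine(backend: str):
    # Each branch returns an engine exposing the same get_ocr_messages() interface.
    if backend == "ollama":
        return OllamaVLMEngine(model_name="llama3.2-vision")      # assumed signature
    if backend == "openai":
        return OpenAIVLMEngine(model="gpt-4o")                    # assumed signature
    if backend == "openrouter":
        return OpenRouterVLMEngine(model="qwen/qwen2.5-vl-72b-instruct")  # assumed signature
    raise ValueError(f"Unknown backend: {backend}")
```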