vlm4ocr 0.3.0__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vlm4ocr/__init__.py +7 -1
- vlm4ocr/cli.py +73 -28
- vlm4ocr/data_types.py +57 -7
- vlm4ocr/ocr_engines.py +85 -42
- vlm4ocr/vlm_engines.py +747 -71
- {vlm4ocr-0.3.0.dist-info → vlm4ocr-0.4.0.dist-info}/METADATA +1 -1
- {vlm4ocr-0.3.0.dist-info → vlm4ocr-0.4.0.dist-info}/RECORD +9 -9
- {vlm4ocr-0.3.0.dist-info → vlm4ocr-0.4.0.dist-info}/WHEEL +0 -0
- {vlm4ocr-0.3.0.dist-info → vlm4ocr-0.4.0.dist-info}/entry_points.txt +0 -0
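
The functional changes in this release are concentrated in vlm4ocr/vlm_engines.py: a new ReasoningVLMConfig that separates a reasoning model's "thinking" output from its final answer, a MessagesLogger that records the chat messages an engine sends, new OpenAI-compatible, vLLM, and OpenRouter engines, and few-shot example support in get_ocr_messages. As a rough, hedged sketch (not taken from the package's documentation), the snippet below shows how the two new helper classes added in the diff behave; it assumes both names can be imported from vlm4ocr.vlm_engines, which is where this diff defines them.

    # Hedged sketch based only on the code added in the diff below; not official usage docs.
    from vlm4ocr.vlm_engines import ReasoningVLMConfig, MessagesLogger

    config = ReasoningVLMConfig()  # defaults to <think> ... </think> thinking tokens

    # A full (non-streaming) reasoning-model reply is split into reasoning and response.
    raw = "<think>The page looks like a two-column invoice.</think>Invoice #123\nTotal: $40"
    parsed = config.postprocess_response(raw)
    print(parsed["reasoning"])  # "The page looks like a two-column invoice."
    print(parsed["response"])   # "Invoice #123\nTotal: $40"

    # MessagesLogger simply accumulates message lists; the engines call log_messages()
    # at the end of chat() / chat_async() when a logger is passed in.
    logger = MessagesLogger()
    logger.log_messages([{"role": "user", "content": "OCR this image", "images": ["[image]"]},
                         {"role": "assistant", "content": "Invoice #123", "reasoning": ""}])
    print(len(logger.get_messages_log()))  # 1
    logger.clear_messages_log()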
vlm4ocr/vlm_engines.py
CHANGED
@@ -2,8 +2,11 @@ import abc
 import importlib.util
 from typing import Any, List, Dict, Union, Generator
 import warnings
+import os
+import re
 from PIL import Image
 from vlm4ocr.utils import image_to_base64
+from vlm4ocr.data_types import FewShotExample
 
 
 class VLMConfig(abc.ABC):
@@ -33,7 +36,7 @@ class VLMConfig(abc.ABC):
         return NotImplemented
 
     @abc.abstractmethod
-    def postprocess_response(self, response:Union[str, Generator[str, None, None]]) -> Union[str, Generator[str, None, None]]:
+    def postprocess_response(self, response:Union[str, Dict[str, str], Generator[str, None, None]]) -> Union[str, Generator[str, None, None]]:
         """
         This method postprocesses the VLM response after it is generated.
 
@@ -77,7 +80,7 @@ class BasicVLMConfig(VLMConfig):
         """
         return messages
 
-    def postprocess_response(self, response:Union[str, Generator[str, None, None]]) -> Union[str, Generator[str, None, None]]:
+    def postprocess_response(self, response:Union[str, Dict[str, str], Generator[str, None, None]]) -> Union[Dict[str, str], Generator[Dict[str, str], None, None]]:
         """
         This method postprocesses the VLM response after it is generated.
 
@@ -88,19 +91,121 @@ class BasicVLMConfig(VLMConfig):
 
         Returns: Union[str, Generator[Dict[str, str], None, None]]
             the postprocessed VLM response.
-            if input is a generator, the output will be a generator {"data": <content>}.
+            if input is a generator, the output will be a generator {"type": "response", "data": <content>}.
         """
         if isinstance(response, str):
-            return response
+            return {"response": response}
+
+        elif isinstance(response, dict):
+            if "response" in response:
+                return response
+            else:
+                warnings.warn(f"Invalid response dict keys: {response.keys()}. Returning default empty dict.", UserWarning)
+                return {"response": ""}
 
         def _process_stream():
             for chunk in response:
-
+                if isinstance(chunk, dict):
+                    yield chunk
+                elif isinstance(chunk, str):
+                    yield {"type": "response", "data": chunk}
 
         return _process_stream()
+
+class ReasoningVLMConfig(VLMConfig):
+    def __init__(self, thinking_token_start="<think>", thinking_token_end="</think>", **kwargs):
+        """
+        The general configuration for reasoning vision models.
+        """
+        super().__init__(**kwargs)
+        self.thinking_token_start = thinking_token_start
+        self.thinking_token_end = thinking_token_end
+
+    def preprocess_messages(self, messages:List[Dict[str,str]]) -> List[Dict[str,str]]:
+        """
+        This method preprocesses the input messages before passing them to the VLM.
+
+        Parameters:
+        ----------
+        messages : List[Dict[str,str]]
+            a list of dict with role and content. role must be one of {"system", "user", "assistant"}
+
+        Returns:
+        -------
+        messages : List[Dict[str,str]]
+            a list of dict with role and content. role must be one of {"system", "user", "assistant"}
+        """
+        return messages.copy()
 
+    def postprocess_response(self, response:Union[str, Dict[str, str], Generator[str, None, None]]) -> Union[Dict[str,str], Generator[Dict[str,str], None, None]]:
+        """
+        This method postprocesses the VLM response after it is generated.
+        1. If input is a string, it will extract the reasoning and response based on the thinking tokens.
+        2. If input is a dict, it should contain keys "reasoning" and "response". This is for inference engines that already parse reasoning and response.
+        3. If input is a generator,
+            a. if the chunk is a dict, it should contain keys "type" and "data". This is for inference engines that already parse reasoning and response.
+            b. if the chunk is a string, it will yield dicts with keys "type" and "data" based on the thinking tokens.
 
-
+        Parameters:
+        ----------
+        response : Union[str, Generator[str, None, None]]
+            the VLM response. Can be a string or a generator.
+
+        Returns:
+        -------
+        response : Union[str, Generator[str, None, None]]
+            the postprocessed LLM response as a dict {"reasoning": <reasoning>, "response": <content>}
+            if input is a generator, the output will be a generator {"type": <reasoning or response>, "data": <content>}.
+        """
+        if isinstance(response, str):
+            # get contents between thinking_token_start and thinking_token_end
+            pattern = f"{re.escape(self.thinking_token_start)}(.*?){re.escape(self.thinking_token_end)}"
+            match = re.search(pattern, response, re.DOTALL)
+            reasoning = match.group(1) if match else ""
+            # get response AFTER thinking_token_end
+            response = re.sub(f".*?{self.thinking_token_end}", "", response, flags=re.DOTALL).strip()
+            return {"reasoning": reasoning, "response": response}
+
+        elif isinstance(response, dict):
+            if "reasoning" in response and "response" in response:
+                return response
+            else:
+                warnings.warn(f"Invalid response dict keys: {response.keys()}. Returning default empty dict.", UserWarning)
+                return {"reasoning": "", "response": ""}
+
+        elif isinstance(response, Generator):
+            def _process_stream():
+                think_flag = False
+                buffer = ""
+                for chunk in response:
+                    if isinstance(chunk, dict):
+                        yield chunk
+
+                    elif isinstance(chunk, str):
+                        buffer += chunk
+                        # switch between reasoning and response
+                        if self.thinking_token_start in buffer:
+                            think_flag = True
+                            buffer = buffer.replace(self.thinking_token_start, "")
+                        elif self.thinking_token_end in buffer:
+                            think_flag = False
+                            buffer = buffer.replace(self.thinking_token_end, "")
+
+                        # if chunk is in thinking block, tag it as reasoning; else tag it as response
+                        if chunk not in [self.thinking_token_start, self.thinking_token_end]:
+                            if think_flag:
+                                yield {"type": "reasoning", "data": chunk}
+                            else:
+                                yield {"type": "response", "data": chunk}
+
+            return _process_stream()
+
+        else:
+            warnings.warn(f"Invalid response type: {type(response)}. Returning default empty dict.", UserWarning)
+            return {"reasoning": "", "response": ""}
+
+
+class OpenAIReasoningVLMConfig(ReasoningVLMConfig):
     def __init__(self, reasoning_effort:str="low", **kwargs):
         """
         The OpenAI "o" series configuration.
@@ -160,27 +265,31 @@ class OpenAIReasoningVLMConfig(VLMConfig):
 
         return new_messages
 
-    def postprocess_response(self, response:Union[str, Generator[str, None, None]]) -> Union[str, Generator[Dict[str, str], None, None]]:
-        """
-        This method postprocesses the VLM response after it is generated.
 
-
-
-        response : Union[str, Generator[str, None, None]]
-            the VLM response. Can be a string or a generator.
-
-        Returns: Union[str, Generator[Dict[str, str], None, None]]
-            the postprocessed VLM response.
-            if input is a generator, the output will be a generator {"type": "response", "data": <content>}.
+class MessagesLogger:
+    def __init__(self):
         """
-
-
+        This class is used to log the messages for InferenceEngine.chat().
+        """
+        self.messages_log = []
 
-
-
-
+    def log_messages(self, messages : List[Dict[str,str]]):
+        """
+        This method logs the messages to a list.
+        """
+        self.messages_log.append(messages)
 
-
+    def get_messages_log(self) -> List[List[Dict[str,str]]]:
+        """
+        This method returns a copy of the current messages log
+        """
+        return self.messages_log.copy()
+
+    def clear_messages_log(self):
+        """
+        This method clears the current messages log
+        """
+        self.messages_log.clear()
 
 
 class VLMEngine:
@@ -198,7 +307,8 @@ class VLMEngine:
         return NotImplemented
 
     @abc.abstractmethod
-    def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False
+    def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False,
+             messages_logger:MessagesLogger=None) -> Union[Dict[str, str], Generator[Dict[str, str], None, None]]:
         """
         This method inputs chat messages and outputs VLM generated text.
 
@@ -210,18 +320,20 @@ class VLMEngine:
             if True, VLM generated text will be printed in terminal in real-time.
         stream : bool, Optional
             if True, returns a generator that yields the output in real-time.
+        Messages_logger : MessagesLogger, Optional
+            the message logger that logs the chat messages.
         """
         return NotImplemented
 
     @abc.abstractmethod
-    def chat_async(self, messages:List[Dict[str,str]]) -> str:
+    def chat_async(self, messages:List[Dict[str,str]], messages_logger:MessagesLogger=None) -> Dict[str, str]:
         """
         The async version of chat method. Streaming is not supported.
         """
         return NotImplemented
 
     @abc.abstractmethod
-    def get_ocr_messages(self, system_prompt:str, user_prompt:str, image:Image.Image) -> List[Dict[str,str]]:
+    def get_ocr_messages(self, system_prompt:str, user_prompt:str, image:Image.Image, few_shot_examples:List[FewShotExample]=None) -> List[Dict[str,str]]:
         """
         This method inputs an image and returns the correesponding chat messages for the inference engine.
 
@@ -233,6 +345,8 @@ class VLMEngine:
             the user prompt.
         image : Image.Image
             the image for OCR.
+        few_shot_examples : List[FewShotExample], Optional
+            list of few-shot examples.
         """
         return NotImplemented
 
@@ -285,7 +399,8 @@ class OllamaVLMEngine(VLMEngine):
 
         return formatted_params
 
-    def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False
+    def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False,
+             messages_logger:MessagesLogger=None) -> Union[Dict[str,str], Generator[Dict[str, str], None, None]]:
         """
         This method inputs chat messages and outputs VLM generated text.
 
@@ -297,6 +412,13 @@ class OllamaVLMEngine(VLMEngine):
             if True, VLM generated text will be printed in terminal in real-time.
         stream : bool, Optional
             if True, returns a generator that yields the output in real-time.
+        Messages_logger : MessagesLogger, Optional
+            the message logger that logs the chat messages.
+
+        Returns:
+        -------
+        response : Union[Dict[str,str], Generator[Dict[str, str], None, None]]
+            a dict {"reasoning": <reasoning>, "response": <response>} or Generator {"type": <reasoning or response>, "data": <content>}
         """
         processed_messages = self.config.preprocess_messages(messages)
 
@@ -310,10 +432,33 @@ class OllamaVLMEngine(VLMEngine):
                     stream=True,
                     keep_alive=self.keep_alive
                 )
+                res = {"reasoning": "", "response": ""}
                 for chunk in response_stream:
-
-
-
+                    if hasattr(chunk.message, 'thinking') and chunk.message.thinking:
+                        content_chunk = getattr(getattr(chunk, 'message', {}), 'thinking', '')
+                        res["reasoning"] += content_chunk
+                        yield {"type": "reasoning", "data": content_chunk}
+                    else:
+                        content_chunk = getattr(getattr(chunk, 'message', {}), 'content', '')
+                        res["response"] += content_chunk
+                        yield {"type": "response", "data": content_chunk}
+
+                    if chunk.done_reason == "length":
+                        warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
+
+                # Postprocess response
+                res_dict = self.config.postprocess_response(res)
+                # Write to messages log
+                if messages_logger:
+                    # replace images content with a placeholder "[image]" to save space
+                    for messages in processed_messages:
+                        if "images" in messages:
+                            messages["images"] = ["[image]" for _ in messages["images"]]
+
+                    processed_messages.append({"role": "assistant",
+                                               "content": res_dict.get("response", ""),
+                                               "reasoning": res_dict.get("reasoning", "")})
+                    messages_logger.log_messages(processed_messages)
 
             return self.config.postprocess_response(_stream_generator())
 
@@ -326,14 +471,29 @@ class OllamaVLMEngine(VLMEngine):
                 keep_alive=self.keep_alive
             )
 
-            res =
+            res = {"reasoning": "", "response": ""}
+            phase = ""
            for chunk in response:
-
+                if hasattr(chunk.message, 'thinking') and chunk.message.thinking:
+                    if phase != "reasoning":
+                        print("\n--- Reasoning ---")
+                        phase = "reasoning"
+
+                    content_chunk = getattr(getattr(chunk, 'message', {}), 'thinking', '')
+                    res["reasoning"] += content_chunk
+                else:
+                    if phase != "response":
+                        print("\n--- Response ---")
+                        phase = "response"
+                    content_chunk = getattr(getattr(chunk, 'message', {}), 'content', '')
+                    res["response"] += content_chunk
+
                 print(content_chunk, end='', flush=True)
-
+
+                if chunk.done_reason == "length":
+                    warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
             print('\n')
-
-
+
         else:
             response = self.client.chat(
                 model=self.model_name,
@@ -342,11 +502,30 @@ class OllamaVLMEngine(VLMEngine):
                 stream=False,
                 keep_alive=self.keep_alive
             )
-            res = response
-
+            res = {"reasoning": getattr(getattr(response, 'message', {}), 'thinking', ''),
+                   "response": getattr(getattr(response, 'message', {}), 'content', '')}
+
+            if response.done_reason == "length":
+                warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
+
+        # Postprocess response
+        res_dict = self.config.postprocess_response(res)
+        # Write to messages log
+        if messages_logger:
+            # replace images content with a placeholder "[image]" to save space
+            for messages in processed_messages:
+                if "images" in messages:
+                    messages["images"] = ["[image]" for _ in messages["images"]]
+
+            processed_messages.append({"role": "assistant",
+                                       "content": res_dict.get("response", ""),
+                                       "reasoning": res_dict.get("reasoning", "")})
+            messages_logger.log_messages(processed_messages)
+
+        return res_dict
 
 
-    async def chat_async(self, messages:List[Dict[str,str]]) -> str:
+    async def chat_async(self, messages:List[Dict[str,str]], messages_logger:MessagesLogger=None) -> Dict[str,str]:
         """
         Async version of chat method. Streaming is not supported.
         """
@@ -360,10 +539,28 @@ class OllamaVLMEngine(VLMEngine):
             keep_alive=self.keep_alive
         )
 
-        res = response
-
+        res = {"reasoning": getattr(getattr(response, 'message', {}), 'thinking', ''),
+               "response": getattr(getattr(response, 'message', {}), 'content', '')}
+
+        if response.done_reason == "length":
+            warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
+        # Postprocess response
+        res_dict = self.config.postprocess_response(res)
+        # Write to messages log
+        if messages_logger:
+            # replace images content with a placeholder "[image]" to save space
+            for messages in processed_messages:
+                if "images" in messages:
+                    messages["images"] = ["[image]" for _ in messages["images"]]
+
+            processed_messages.append({"role": "assistant",
+                                       "content": res_dict.get("response", ""),
+                                       "reasoning": res_dict.get("reasoning", "")})
+            messages_logger.log_messages(processed_messages)
+
+        return res_dict
 
-    def get_ocr_messages(self, system_prompt:str, user_prompt:str, image:Image.Image) -> List[Dict[str,str]]:
+    def get_ocr_messages(self, system_prompt:str, user_prompt:str, image:Image.Image, few_shot_examples:List[FewShotExample]=None) -> List[Dict[str,str]]:
         """
         This method inputs an image and returns the correesponding chat messages for the inference engine.
 
@@ -375,16 +572,404 @@ class OllamaVLMEngine(VLMEngine):
             the user prompt.
         image : Image.Image
             the image for OCR.
+        few_shot_examples : List[FewShotExample], Optional
+            list of few-shot examples.
         """
         base64_str = image_to_base64(image)
-
-
-
-
-
-
-
-
+        output_messages = []
+        # system message
+        system_message = {"role": "system", "content": system_prompt}
+        output_messages.append(system_message)
+
+        # few-shot examples
+        if few_shot_examples is not None:
+            for example in few_shot_examples:
+                if not isinstance(example, FewShotExample):
+                    raise ValueError("Few-shot example must be a FewShotExample object.")
+
+                example_image_b64 = image_to_base64(example.image)
+                example_user_message = {"role": "user", "content": user_prompt, "images": [example_image_b64]}
+                example_agent_message = {"role": "assistant", "content": example.text}
+                output_messages.append(example_user_message)
+                output_messages.append(example_agent_message)
+
+        # user message
+        user_message = {"role": "user", "content": user_prompt, "images": [base64_str]}
+        output_messages.append(user_message)
+
+        return output_messages
+
+
+class OpenAICompatibleVLMEngine(VLMEngine):
+    def __init__(self, model:str, api_key:str, base_url:str, config:VLMConfig=None, **kwrs):
+        """
+        General OpenAI-compatible server inference engine.
+        https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html
+
+        For parameters and documentation, refer to https://platform.openai.com/docs/api-reference/introduction
+
+        Parameters:
+        ----------
+        model_name : str
+            model name as shown in the vLLM server
+        api_key : str
+            the API key for the vLLM server.
+        base_url : str
+            the base url for the vLLM server.
+        config : LLMConfig
+            the LLM configuration.
+        """
+        if importlib.util.find_spec("openai") is None:
+            raise ImportError("OpenAI Python API library not found. Please install OpanAI (```pip install openai```).")
+
+        from openai import OpenAI, AsyncOpenAI
+        from openai.types.chat import ChatCompletionChunk
+        self.ChatCompletionChunk = ChatCompletionChunk
+        super().__init__(config)
+        self.client = OpenAI(api_key=api_key, base_url=base_url, **kwrs)
+        self.async_client = AsyncOpenAI(api_key=api_key, base_url=base_url, **kwrs)
+        self.model = model
+        self.config = config if config else BasicVLMConfig()
+        self.formatted_params = self._format_config()
+
+    def _format_config(self) -> Dict[str, Any]:
+        """
+        This method format the VLM configuration with the correct key for the inference engine.
+        """
+        formatted_params = self.config.params.copy()
+        if "max_new_tokens" in formatted_params:
+            formatted_params["max_completion_tokens"] = formatted_params["max_new_tokens"]
+            formatted_params.pop("max_new_tokens")
+
+        return formatted_params
+
+
+    def _format_response(self, response: Any) -> Dict[str, str]:
+        """
+        This method format the response from OpenAI API to a dict with keys "type" and "data".
+
+        Parameters:
+        ----------
+        response : Any
+            the response from OpenAI-compatible API. Could be a dict, generator, or object.
+        """
+        if isinstance(response, self.ChatCompletionChunk):
+            chunk_text = getattr(response.choices[0].delta, "content", "")
+            if chunk_text is None:
+                chunk_text = ""
+            return {"type": "response", "data": chunk_text}
+
+        return {"response": getattr(response.choices[0].message, "content", "")}
+
+    def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False,
+             messages_logger:MessagesLogger=None) -> Union[Dict[str, str], Generator[Dict[str, str], None, None]]:
+        """
+        This method inputs chat messages and outputs LLM generated text.
+
+        Parameters:
+        ----------
+        messages : List[Dict[str,str]]
+            a list of dict with role and content. role must be one of {"system", "user", "assistant"}
+        verbose : bool, Optional
+            if True, VLM generated text will be printed in terminal in real-time.
+        stream : bool, Optional
+            if True, returns a generator that yields the output in real-time.
+        messages_logger : MessagesLogger, Optional
+            the message logger that logs the chat messages.
+
+        Returns:
+        -------
+        response : Union[Dict[str,str], Generator[Dict[str, str], None, None]]
+            a dict {"reasoning": <reasoning>, "response": <response>} or Generator {"type": <reasoning or response>, "data": <content>}
+        """
+        processed_messages = self.config.preprocess_messages(messages)
+
+        if stream:
+            def _stream_generator():
+                response_stream = self.client.chat.completions.create(
+                    model=self.model,
+                    messages=processed_messages,
+                    stream=True,
+                    **self.formatted_params
+                )
+                res_text = ""
+                for chunk in response_stream:
+                    if len(chunk.choices) > 0:
+                        chunk_dict = self._format_response(chunk)
+                        yield chunk_dict
+
+                        res_text += chunk_dict["data"]
+                        if chunk.choices[0].finish_reason == "length":
+                            warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
+
+                # Postprocess response
+                res_dict = self.config.postprocess_response(res_text)
+                # Write to messages log
+                if messages_logger:
+                    # replace images content with a placeholder "[image]" to save space
+                    for messages in processed_messages:
+                        if "content" in messages and isinstance(messages["content"], list):
+                            for content in messages["content"]:
+                                if isinstance(content, dict) and content.get("type") == "image_url":
+                                    content["image_url"]["url"] = "[image]"
+
+                    processed_messages.append({"role": "assistant",
+                                               "content": res_dict.get("response", ""),
+                                               "reasoning": res_dict.get("reasoning", "")})
+                    messages_logger.log_messages(processed_messages)
+
+            return self.config.postprocess_response(_stream_generator())
+
+        elif verbose:
+            response = self.client.chat.completions.create(
+                model=self.model,
+                messages=processed_messages,
+                stream=True,
+                **self.formatted_params
+            )
+            res = {"reasoning": "", "response": ""}
+            phase = ""
+            for chunk in response:
+                if len(chunk.choices) > 0:
+                    chunk_dict = self._format_response(chunk)
+                    chunk_text = chunk_dict["data"]
+                    res[chunk_dict["type"]] += chunk_text
+                    if phase != chunk_dict["type"] and chunk_text != "":
+                        print(f"\n--- {chunk_dict['type'].capitalize()} ---")
+                        phase = chunk_dict["type"]
+
+                    print(chunk_text, end="", flush=True)
+                    if chunk.choices[0].finish_reason == "length":
+                        warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
+
+            print('\n')
+
+        else:
+            response = self.client.chat.completions.create(
+                model=self.model,
+                messages=processed_messages,
+                stream=False,
+                **self.formatted_params
+            )
+            res = self._format_response(response)
+
+            if response.choices[0].finish_reason == "length":
+                warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
+
+        # Postprocess response
+        res_dict = self.config.postprocess_response(res)
+        # Write to messages log
+        if messages_logger:
+            # replace images content with a placeholder "[image]" to save space
+            for messages in processed_messages:
+                if "content" in messages and isinstance(messages["content"], list):
+                    for content in messages["content"]:
+                        if isinstance(content, dict) and content.get("type") == "image_url":
+                            content["image_url"]["url"] = "[image]"
+
+            processed_messages.append({"role": "assistant",
+                                       "content": res_dict.get("response", ""),
+                                       "reasoning": res_dict.get("reasoning", "")})
+            messages_logger.log_messages(processed_messages)
+
+        return res_dict
+
+
+    async def chat_async(self, messages:List[Dict[str,str]], messages_logger:MessagesLogger=None) -> Dict[str,str]:
+        """
+        Async version of chat method. Streaming is not supported.
+        """
+        processed_messages = self.config.preprocess_messages(messages)
+
+        response = await self.async_client.chat.completions.create(
+            model=self.model,
+            messages=processed_messages,
+            stream=False,
+            **self.formatted_params
+        )
+
+        if response.choices[0].finish_reason == "length":
+            warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
+
+        res = self._format_response(response)
+
+        # Postprocess response
+        res_dict = self.config.postprocess_response(res)
+        # Write to messages log
+        if messages_logger:
+            # replace images content with a placeholder "[image]" to save space
+            for messages in processed_messages:
+                if "content" in messages and isinstance(messages["content"], list):
+                    for content in messages["content"]:
+                        if isinstance(content, dict) and content.get("type") == "image_url":
+                            content["image_url"]["url"] = "[image]"
+
+            processed_messages.append({"role": "assistant",
+                                       "content": res_dict.get("response", ""),
+                                       "reasoning": res_dict.get("reasoning", "")})
+            messages_logger.log_messages(processed_messages)
+
+        return res_dict
+
+    def get_ocr_messages(self, system_prompt:str, user_prompt:str, image:Image.Image, format:str='png',
+                         detail:str="high", few_shot_examples:List[FewShotExample]=None) -> List[Dict[str,str]]:
+        """
+        This method inputs an image and returns the correesponding chat messages for the inference engine.
+
+        Parameters:
+        ----------
+        system_prompt : str
+            the system prompt.
+        user_prompt : str
+            the user prompt.
+        image : Image.Image
+            the image for OCR.
+        format : str, Optional
+            the image format.
+        detail : str, Optional
+            the detail level of the image. Default is "high".
+        few_shot_examples : List[FewShotExample], Optional
+            list of few-shot examples.
+        """
+        base64_str = image_to_base64(image)
+        output_messages = []
+        # system message
+        system_message = {"role": "system", "content": system_prompt}
+        output_messages.append(system_message)
+
+        # few-shot examples
+        if few_shot_examples is not None:
+            for example in few_shot_examples:
+                if not isinstance(example, FewShotExample):
+                    raise ValueError("Few-shot example must be a FewShotExample object.")
+
+                example_image_b64 = image_to_base64(example.image)
+                example_user_message = {
+                    "role": "user",
+                    "content": [
+                        {
+                            "type": "image_url",
+                            "image_url": {
+                                "url": f"data:image/{format};base64,{example_image_b64}",
+                                "detail": detail
+                            },
+                        },
+                        {"type": "text", "text": user_prompt},
+                    ],
+                }
+                example_agent_message = {"role": "assistant", "content": example.text}
+                output_messages.append(example_user_message)
+                output_messages.append(example_agent_message)
+
+        # user message
+        user_message = {
+            "role": "user",
+            "content": [
+                {
+                    "type": "image_url",
+                    "image_url": {
+                        "url": f"data:image/{format};base64,{base64_str}",
+                        "detail": detail
+                    },
+                },
+                {"type": "text", "text": user_prompt},
+            ],
+        }
+        output_messages.append(user_message)
+        return output_messages
+
+
+class VLLMVLMEngine(OpenAICompatibleVLMEngine):
+    def __init__(self, model:str, api_key:str="", base_url:str="http://localhost:8000/v1", config:VLMConfig=None, **kwrs):
+        """
+        vLLM OpenAI compatible server inference engine.
+        https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html
+
+        For parameters and documentation, refer to https://platform.openai.com/docs/api-reference/introduction
+
+        Parameters:
+        ----------
+        model_name : str
+            model name as shown in the vLLM server
+        api_key : str, Optional
+            the API key for the vLLM server.
+        base_url : str, Optional
+            the base url for the vLLM server.
+        config : LLMConfig
+            the LLM configuration.
+        """
+        super().__init__(model, api_key, base_url, config, **kwrs)
+
+
+    def _format_response(self, response: Any) -> Dict[str, str]:
+        """
+        This method format the response from OpenAI API to a dict with keys "type" and "data".
+
+        Parameters:
+        ----------
+        response : Any
+            the response from OpenAI-compatible API. Could be a dict, generator, or object.
+        """
+        if isinstance(response, self.ChatCompletionChunk):
+            if hasattr(response.choices[0].delta, "reasoning_content") and getattr(response.choices[0].delta, "reasoning_content") is not None:
+                chunk_text = getattr(response.choices[0].delta, "reasoning_content", "")
+                if chunk_text is None:
+                    chunk_text = ""
+                return {"type": "reasoning", "data": chunk_text}
+            else:
+                chunk_text = getattr(response.choices[0].delta, "content", "")
+                if chunk_text is None:
+                    chunk_text = ""
+                return {"type": "response", "data": chunk_text}
+
+        return {"reasoning": getattr(response.choices[0].message, "reasoning_content", ""),
+                "response": getattr(response.choices[0].message, "content", "")}
+
+
+class OpenRouterVLMEngine(OpenAICompatibleVLMEngine):
+    def __init__(self, model:str, api_key:str=None, base_url:str="https://openrouter.ai/api/v1", config:VLMConfig=None, **kwrs):
+        """
+        OpenRouter OpenAI-compatible server inference engine.
+
+        Parameters:
+        ----------
+        model_name : str
+            model name as shown in the vLLM server
+        api_key : str, Optional
+            the API key for the vLLM server. If None, will use the key in os.environ['OPENROUTER_API_KEY'].
+        base_url : str, Optional
+            the base url for the vLLM server.
+        config : LLMConfig
+            the LLM configuration.
+        """
+        self.api_key = api_key
+        if self.api_key is None:
+            self.api_key = os.getenv("OPENROUTER_API_KEY")
+        super().__init__(model, self.api_key, base_url, config, **kwrs)
+
+    def _format_response(self, response: Any) -> Dict[str, str]:
+        """
+        This method format the response from OpenAI API to a dict with keys "type" and "data".
+
+        Parameters:
+        ----------
+        response : Any
+            the response from OpenAI-compatible API. Could be a dict, generator, or object.
+        """
+        if isinstance(response, self.ChatCompletionChunk):
+            if hasattr(response.choices[0].delta, "reasoning") and getattr(response.choices[0].delta, "reasoning") is not None:
+                chunk_text = getattr(response.choices[0].delta, "reasoning", "")
+                if chunk_text is None:
+                    chunk_text = ""
+                return {"type": "reasoning", "data": chunk_text}
+            else:
+                chunk_text = getattr(response.choices[0].delta, "content", "")
+                if chunk_text is None:
+                    chunk_text = ""
+                return {"type": "response", "data": chunk_text}
+
+        return {"reasoning": getattr(response.choices[0].message, "reasoning", ""),
+                "response": getattr(response.choices[0].message, "content", "")}
 
 
 class OpenAIVLMEngine(VLMEngine):
@@ -423,7 +1008,7 @@ class OpenAIVLMEngine(VLMEngine):
 
         return formatted_params
 
-    def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False) -> Union[str, Generator[str, None, None]]:
+    def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False, messages_logger:MessagesLogger=None) -> Union[Dict[str, str], Generator[Dict[str, str], None, None]]:
         """
         This method inputs chat messages and outputs LLM generated text.
 
@@ -435,6 +1020,13 @@ class OpenAIVLMEngine(VLMEngine):
             if True, VLM generated text will be printed in terminal in real-time.
         stream : bool, Optional
            if True, returns a generator that yields the output in real-time.
+        messages_logger : MessagesLogger, Optional
+            the message logger that logs the chat messages.
+
+        Returns:
+        -------
+        response : Union[Dict[str,str], Generator[Dict[str, str], None, None]]
+            a dict {"reasoning": <reasoning>, "response": <response>} or Generator {"type": <reasoning or response>, "data": <content>}
         """
         processed_messages = self.config.preprocess_messages(messages)
 
@@ -446,13 +1038,32 @@ class OpenAIVLMEngine(VLMEngine):
                     stream=True,
                     **self.formatted_params
                 )
+                res_text = ""
                 for chunk in response_stream:
                     if len(chunk.choices) > 0:
-
-
+                        chunk_text = chunk.choices[0].delta.content
+                        if chunk_text is not None:
+                            res_text += chunk_text
+                            yield chunk_text
                     if chunk.choices[0].finish_reason == "length":
                         warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
 
+                # Postprocess response
+                res_dict = self.config.postprocess_response(res_text)
+                # Write to messages log
+                if messages_logger:
+                    # replace images content with a placeholder "[image]" to save space
+                    for messages in processed_messages:
+                        if "content" in messages and isinstance(messages["content"], list):
+                            for content in messages["content"]:
+                                if isinstance(content, dict) and content.get("type") == "image_url":
+                                    content["image_url"]["url"] = "[image]"
+
+                    processed_messages.append({"role": "assistant",
+                                               "content": res_dict.get("response", ""),
+                                               "reasoning": res_dict.get("reasoning", "")})
+                    messages_logger.log_messages(processed_messages)
+
             return self.config.postprocess_response(_stream_generator())
 
         elif verbose:
@@ -472,7 +1083,7 @@ class OpenAIVLMEngine(VLMEngine):
                     warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
 
             print('\n')
-
+
         else:
             response = self.client.chat.completions.create(
                 model=self.model,
@@ -481,10 +1092,27 @@ class OpenAIVLMEngine(VLMEngine):
                 **self.formatted_params
             )
             res = response.choices[0].message.content
-
+
+        # Postprocess response
+        res_dict = self.config.postprocess_response(res)
+        # Write to messages log
+        if messages_logger:
+            # replace images content with a placeholder "[image]" to save space
+            for messages in processed_messages:
+                if "content" in messages and isinstance(messages["content"], list):
+                    for content in messages["content"]:
+                        if isinstance(content, dict) and content.get("type") == "image_url":
+                            content["image_url"]["url"] = "[image]"
+
+            processed_messages.append({"role": "assistant",
+                                       "content": res_dict.get("response", ""),
+                                       "reasoning": res_dict.get("reasoning", "")})
+            messages_logger.log_messages(processed_messages)
+
+        return res_dict
 
 
-    async def chat_async(self, messages:List[Dict[str,str]]) -> str:
+    async def chat_async(self, messages:List[Dict[str,str]], messages_logger:MessagesLogger=None) -> Dict[str,str]:
        """
         Async version of chat method. Streaming is not supported.
         """
@@ -501,9 +1129,26 @@ class OpenAIVLMEngine(VLMEngine):
             warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
 
         res = response.choices[0].message.content
-
+        # Postprocess response
+        res_dict = self.config.postprocess_response(res)
+        # Write to messages log
+        if messages_logger:
+            # replace images content with a placeholder "[image]" to save space
+            for messages in processed_messages:
+                if "content" in messages and isinstance(messages["content"], list):
+                    for content in messages["content"]:
+                        if isinstance(content, dict) and content.get("type") == "image_url":
+                            content["image_url"]["url"] = "[image]"
+
+            processed_messages.append({"role": "assistant",
+                                       "content": res_dict.get("response", ""),
+                                       "reasoning": res_dict.get("reasoning", "")})
+            messages_logger.log_messages(processed_messages)
+
+        return res_dict
 
-    def get_ocr_messages(self, system_prompt:str, user_prompt:str, image:Image.Image, format:str='png',
+    def get_ocr_messages(self, system_prompt:str, user_prompt:str, image:Image.Image, format:str='png',
+                         detail:str="high", few_shot_examples:List[FewShotExample]=None) -> List[Dict[str,str]]:
         """
         This method inputs an image and returns the correesponding chat messages for the inference engine.
 
@@ -519,24 +1164,55 @@ class OpenAIVLMEngine(VLMEngine):
             the image format.
         detail : str, Optional
             the detail level of the image. Default is "high".
+        few_shot_examples : List[FewShotExample], Optional
+            list of few-shot examples. Each example is a dict with keys "image" (PIL.Image.Image) and "text" (str).
         """
         base64_str = image_to_base64(image)
-
-
-
-
-
-
-
-
-
-
+        output_messages = []
+        # system message
+        system_message = {"role": "system", "content": system_prompt}
+        output_messages.append(system_message)
+
+        # few-shot examples
+        if few_shot_examples is not None:
+            for example in few_shot_examples:
+                if not isinstance(example, FewShotExample):
+                    raise ValueError("Few-shot example must be a FewShotExample object.")
+
+                example_image_b64 = image_to_base64(example.image)
+                example_user_message = {
+                    "role": "user",
+                    "content": [
+                        {
+                            "type": "image_url",
+                            "image_url": {
+                                "url": f"data:image/{format};base64,{example_image_b64}",
+                                "detail": detail
+                            },
                         },
+                        {"type": "text", "text": user_prompt},
+                    ],
+                }
+                example_agent_message = {"role": "assistant", "content": example.text}
+                output_messages.append(example_user_message)
+                output_messages.append(example_agent_message)
+
+        # user message
+        user_message = {
+            "role": "user",
+            "content": [
+                {
+                    "type": "image_url",
+                    "image_url": {
+                        "url": f"data:image/{format};base64,{base64_str}",
+                        "detail": detail
                    },
-
-
-
-
+                },
+                {"type": "text", "text": user_prompt},
+            ],
+        }
+        output_messages.append(user_message)
+        return output_messages
 
 
 class AzureOpenAIVLMEngine(OpenAIVLMEngine):