vlm4ocr 0.1.0__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vlm4ocr/__init__.py +3 -1
- vlm4ocr/assets/default_prompt_templates/ocr_JSON_system_prompt.txt +1 -0
- vlm4ocr/cli.py +276 -287
- vlm4ocr/data_types.py +109 -0
- vlm4ocr/ocr_engines.py +363 -195
- vlm4ocr/utils.py +386 -39
- vlm4ocr/vlm_engines.py +316 -190
- {vlm4ocr-0.1.0.dist-info → vlm4ocr-0.3.0.dist-info}/METADATA +5 -1
- vlm4ocr-0.3.0.dist-info/RECORD +17 -0
- vlm4ocr-0.1.0.dist-info/RECORD +0 -15
- {vlm4ocr-0.1.0.dist-info → vlm4ocr-0.3.0.dist-info}/WHEEL +0 -0
- {vlm4ocr-0.1.0.dist-info → vlm4ocr-0.3.0.dist-info}/entry_points.txt +0 -0
vlm4ocr/vlm_engines.py
CHANGED
@@ -1,23 +1,204 @@
 import abc
 import importlib.util
-from typing import List, Dict, Union, Generator
+from typing import Any, List, Dict, Union, Generator
 import warnings
 from PIL import Image
 from vlm4ocr.utils import image_to_base64


+class VLMConfig(abc.ABC):
+    def __init__(self, **kwargs):
+        """
+        This is an abstract class to provide interfaces for VLM configuration.
+        Children classes that inherts this class can be used in extrators and prompt editor.
+        Common VLM parameters: max_new_tokens, temperature, top_p, top_k, min_p.
+        """
+        self.params = kwargs.copy()
+
+    @abc.abstractmethod
+    def preprocess_messages(self, messages:List[Dict[str,str]]) -> List[Dict[str,str]]:
+        """
+        This method preprocesses the input messages before passing them to the VLM.
+
+        Parameters:
+        ----------
+        messages : List[Dict[str,str]]
+            a list of dict with role and content. role must be one of {"system", "user", "assistant"}
+
+        Returns:
+        -------
+        messages : List[Dict[str,str]]
+            a list of dict with role and content. role must be one of {"system", "user", "assistant"}
+        """
+        return NotImplemented
+
+    @abc.abstractmethod
+    def postprocess_response(self, response:Union[str, Generator[str, None, None]]) -> Union[str, Generator[str, None, None]]:
+        """
+        This method postprocesses the VLM response after it is generated.
+
+        Parameters:
+        ----------
+        response : Union[str, Generator[str, None, None]]
+            the VLM response. Can be a string or a generator.
+
+        Returns:
+        -------
+        response : str
+            the postprocessed VLM response
+        """
+        return NotImplemented
+
+
+class BasicVLMConfig(VLMConfig):
+    def __init__(self, max_new_tokens:int=2048, temperature:float=0.0, **kwargs):
+        """
+        The basic VLM configuration for most non-reasoning models.
+        """
+        super().__init__(**kwargs)
+        self.max_new_tokens = max_new_tokens
+        self.temperature = temperature
+        self.params["max_new_tokens"] = self.max_new_tokens
+        self.params["temperature"] = self.temperature
+
+    def preprocess_messages(self, messages:List[Dict[str,str]]) -> List[Dict[str,str]]:
+        """
+        This method preprocesses the input messages before passing them to the VLM.
+
+        Parameters:
+        ----------
+        messages : List[Dict[str,str]]
+            a list of dict with role and content. role must be one of {"system", "user", "assistant"}
+
+        Returns:
+        -------
+        messages : List[Dict[str,str]]
+            a list of dict with role and content. role must be one of {"system", "user", "assistant"}
+        """
+        return messages
+
+    def postprocess_response(self, response:Union[str, Generator[str, None, None]]) -> Union[str, Generator[str, None, None]]:
+        """
+        This method postprocesses the VLM response after it is generated.
+
+        Parameters:
+        ----------
+        response : Union[str, Generator[str, None, None]]
+            the VLM response. Can be a string or a generator.
+
+        Returns: Union[str, Generator[Dict[str, str], None, None]]
+            the postprocessed VLM response.
+            if input is a generator, the output will be a generator {"data": <content>}.
+        """
+        if isinstance(response, str):
+            return response
+
+        def _process_stream():
+            for chunk in response:
+                yield chunk
+
+        return _process_stream()
+
+
+class OpenAIReasoningVLMConfig(VLMConfig):
+    def __init__(self, reasoning_effort:str="low", **kwargs):
+        """
+        The OpenAI "o" series configuration.
+        1. The reasoning effort is set to "low" by default.
+        2. The temperature parameter is not supported and will be ignored.
+        3. The system prompt is not supported and will be concatenated to the next user prompt.
+
+        Parameters:
+        ----------
+        reasoning_effort : str, Optional
+            the reasoning effort. Must be one of {"low", "medium", "high"}. Default is "low".
+        """
+        super().__init__(**kwargs)
+        if reasoning_effort not in ["low", "medium", "high"]:
+            raise ValueError("reasoning_effort must be one of {'low', 'medium', 'high'}.")
+
+        self.reasoning_effort = reasoning_effort
+        self.params["reasoning_effort"] = self.reasoning_effort
+
+        if "temperature" in self.params:
+            warnings.warn("Reasoning models do not support temperature parameter. Will be ignored.", UserWarning)
+            self.params.pop("temperature")
+
+    def preprocess_messages(self, messages:List[Dict[str,str]]) -> List[Dict[str,str]]:
+        """
+        Concatenate system prompts to the next user prompt.
+
+        Parameters:
+        ----------
+        messages : List[Dict[str,str]]
+            a list of dict with role and content. role must be one of {"system", "user", "assistant"}
+
+        Returns:
+        -------
+        messages : List[Dict[str,str]]
+            a list of dict with role and content. role must be one of {"system", "user", "assistant"}
+        """
+        system_prompt_holder = ""
+        new_messages = []
+        for i, message in enumerate(messages):
+            # if system prompt, store it in system_prompt_holder
+            if message['role'] == 'system':
+                system_prompt_holder = message['content']
+            # if user prompt, concatenate it with system_prompt_holder
+            elif message['role'] == 'user':
+                if system_prompt_holder:
+                    new_message = {'role': message['role'], 'content': f"{system_prompt_holder} {message['content']}"}
+                    system_prompt_holder = ""
+                else:
+                    new_message = {'role': message['role'], 'content': message['content']}
+
+                new_messages.append(new_message)
+            # if assistant/other prompt, do nothing
+            else:
+                new_message = {'role': message['role'], 'content': message['content']}
+                new_messages.append(new_message)
+
+        return new_messages
+
+    def postprocess_response(self, response:Union[str, Generator[str, None, None]]) -> Union[str, Generator[Dict[str, str], None, None]]:
+        """
+        This method postprocesses the VLM response after it is generated.
+
+        Parameters:
+        ----------
+        response : Union[str, Generator[str, None, None]]
+            the VLM response. Can be a string or a generator.
+
+        Returns: Union[str, Generator[Dict[str, str], None, None]]
+            the postprocessed VLM response.
+            if input is a generator, the output will be a generator {"type": "response", "data": <content>}.
+        """
+        if isinstance(response, str):
+            return response
+
+        def _process_stream():
+            for chunk in response:
+                yield {"type": "response", "data": chunk}
+
+        return _process_stream()
+
+
 class VLMEngine:
     @abc.abstractmethod
-    def __init__(self):
+    def __init__(self, config:VLMConfig, **kwrs):
         """
         This is an abstract class to provide interfaces for VLM inference engines.
         Children classes that inherts this class can be used in extrators. Must implement chat() method.
+
+        Parameters:
+        ----------
+        config : VLMConfig
+            the VLM configuration. Must be a child class of VLMConfig.
         """
         return NotImplemented

     @abc.abstractmethod
-    def chat(self, messages:List[Dict[str,str]],
-             verbose:bool=False, stream:bool=False, **kwrs) -> Union[str, Generator[str, None, None]]:
+    def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False) -> Union[str, Generator[str, None, None]]:
         """
         This method inputs chat messages and outputs VLM generated text.

@@ -25,10 +206,6 @@ class VLMEngine:
         ----------
         messages : List[Dict[str,str]]
             a list of dict with role and content. role must be one of {"system", "user", "assistant"}
-        max_new_tokens : str, Optional
-            the max number of new tokens VLM can generate.
-        temperature : float, Optional
-            the temperature for token sampling.
         verbose : bool, Optional
             if True, VLM generated text will be printed in terminal in real-time.
         stream : bool, Optional
@@ -37,14 +214,14 @@ class VLMEngine:
         return NotImplemented

     @abc.abstractmethod
-    def chat_async(self, messages:List[Dict[str,str]]
+    def chat_async(self, messages:List[Dict[str,str]]) -> str:
         """
         The async version of chat method. Streaming is not supported.
         """
         return NotImplemented

     @abc.abstractmethod
-    def get_ocr_messages(self, system_prompt:str, user_prompt:str,
+    def get_ocr_messages(self, system_prompt:str, user_prompt:str, image:Image.Image) -> List[Dict[str,str]]:
         """
         This method inputs an image and returns the correesponding chat messages for the inference engine.

@@ -54,14 +231,23 @@ class VLMEngine:
             the system prompt.
         user_prompt : str
             the user prompt.
-
-            the image
+        image : Image.Image
+            the image for OCR.
+        """
+        return NotImplemented
+
+    def _format_config(self) -> Dict[str, Any]:
+        """
+        This method format the VLM configuration with the correct key for the inference engine.
+
+        Return : Dict[str, Any]
+            the config parameters.
         """
         return NotImplemented


 class OllamaVLMEngine(VLMEngine):
-    def __init__(self, model_name:str, num_ctx:int=
+    def __init__(self, model_name:str, num_ctx:int=8192, keep_alive:int=300, config:VLMConfig=None, **kwrs):
         """
         The Ollama inference engine.

@@ -70,9 +256,11 @@ class OllamaVLMEngine(VLMEngine):
         model_name : str
             the model name exactly as shown in >> ollama ls
         num_ctx : int, Optional
-            context length that
+            context length that LLM will evaluate.
         keep_alive : int, Optional
-            seconds to hold the
+            seconds to hold the LLM after the last API call.
+        config : LLMConfig
+            the LLM configuration.
         """
         if importlib.util.find_spec("ollama") is None:
             raise ImportError("ollama-python not found. Please install ollama-python (```pip install ollama```).")
@@ -83,9 +271,21 @@ class OllamaVLMEngine(VLMEngine):
         self.model_name = model_name
         self.num_ctx = num_ctx
         self.keep_alive = keep_alive
+        self.config = config if config else BasicVLMConfig()
+        self.formatted_params = self._format_config()
+
+    def _format_config(self) -> Dict[str, Any]:
+        """
+        This method format the LLM configuration with the correct key for the inference engine.
+        """
+        formatted_params = self.config.params.copy()
+        if "max_new_tokens" in formatted_params:
+            formatted_params["num_predict"] = formatted_params["max_new_tokens"]
+            formatted_params.pop("max_new_tokens")

-
-
+        return formatted_params
+
+    def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False) -> Union[str, Generator[str, None, None]]:
         """
         This method inputs chat messages and outputs VLM generated text.

@@ -93,21 +293,19 @@ class OllamaVLMEngine(VLMEngine):
         ----------
         messages : List[Dict[str,str]]
             a list of dict with role and content. role must be one of {"system", "user", "assistant"}
-        max_new_tokens : str, Optional
-            the max number of new tokens VLM can generate.
-        temperature : float, Optional
-            the temperature for token sampling.
         verbose : bool, Optional
             if True, VLM generated text will be printed in terminal in real-time.
         stream : bool, Optional
             if True, returns a generator that yields the output in real-time.
         """
-
+        processed_messages = self.config.preprocess_messages(messages)
+
+        options={'num_ctx': self.num_ctx, **self.formatted_params}
         if stream:
             def _stream_generator():
                 response_stream = self.client.chat(
                     model=self.model_name,
-                    messages=
+                    messages=processed_messages,
                     options=options,
                     stream=True,
                     keep_alive=self.keep_alive
@@ -117,12 +315,12 @@ class OllamaVLMEngine(VLMEngine):
                     if content_chunk:
                         yield content_chunk

-            return _stream_generator()
+            return self.config.postprocess_response(_stream_generator())

         elif verbose:
             response = self.client.chat(
                 model=self.model_name,
-                messages=
+                messages=processed_messages,
                 options=options,
                 stream=True,
                 keep_alive=self.keep_alive
@@ -134,24 +332,36 @@ class OllamaVLMEngine(VLMEngine):
                 print(content_chunk, end='', flush=True)
                 res += content_chunk
             print('\n')
-            return res
+            return self.config.postprocess_response(res)
+
+        else:
+            response = self.client.chat(
+                model=self.model_name,
+                messages=processed_messages,
+                options=options,
+                stream=False,
+                keep_alive=self.keep_alive
+            )
+            res = response.get('message', {}).get('content')
+            return self.config.postprocess_response(res)

-        return response.get('message', {}).get('content', '')
-

-    async def chat_async(self, messages:List[Dict[str,str]]
+    async def chat_async(self, messages:List[Dict[str,str]]) -> str:
         """
         Async version of chat method. Streaming is not supported.
         """
+        processed_messages = self.config.preprocess_messages(messages)
+
         response = await self.async_client.chat(
             model=self.model_name,
-            messages=
-            options={'
+            messages=processed_messages,
+            options={'num_ctx': self.num_ctx, **self.formatted_params},
             stream=False,
             keep_alive=self.keep_alive
         )

-
+        res = response['message']['content']
+        return self.config.postprocess_response(res)

     def get_ocr_messages(self, system_prompt:str, user_prompt:str, image:Image.Image) -> List[Dict[str,str]]:
         """
@@ -163,8 +373,8 @@ class OllamaVLMEngine(VLMEngine):
             the system prompt.
         user_prompt : str
             the user prompt.
-
-            the image
+        image : Image.Image
+            the image for OCR.
         """
         base64_str = image_to_base64(image)
         return [
@@ -178,7 +388,7 @@ class OllamaVLMEngine(VLMEngine):


 class OpenAIVLMEngine(VLMEngine):
-    def __init__(self, model:str,
+    def __init__(self, model:str, config:VLMConfig=None, **kwrs):
         """
         The OpenAI API inference engine. Supports OpenAI models and OpenAI compatible servers:
         - vLLM OpenAI compatible server (https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html)
@@ -189,8 +399,8 @@ class OpenAIVLMEngine(VLMEngine):
         ----------
         model_name : str
             model name as described in https://platform.openai.com/docs/models
-
-
+        config : VLMConfig, Optional
+            the VLM configuration. Must be a child class of VLMConfig.
         """
         if importlib.util.find_spec("openai") is None:
             raise ImportError("OpenAI Python API library not found. Please install OpanAI (```pip install openai```).")
@@ -199,184 +409,99 @@ class OpenAIVLMEngine(VLMEngine):
         self.client = OpenAI(**kwrs)
         self.async_client = AsyncOpenAI(**kwrs)
         self.model = model
-        self.
+        self.config = config if config else BasicVLMConfig()
+        self.formatted_params = self._format_config()

-    def
-             verbose:bool=False, stream:bool=False, **kwrs) -> Union[str, Generator[str, None, None]]:
+    def _format_config(self) -> Dict[str, Any]:
         """
-        This method
+        This method format the LLM configuration with the correct key for the inference engine.
+        """
+        formatted_params = self.config.params.copy()
+        if "max_new_tokens" in formatted_params:
+            formatted_params["max_completion_tokens"] = formatted_params["max_new_tokens"]
+            formatted_params.pop("max_new_tokens")
+
+        return formatted_params
+
+    def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False) -> Union[str, Generator[str, None, None]]:
+        """
+        This method inputs chat messages and outputs LLM generated text.

         Parameters:
         ----------
         messages : List[Dict[str,str]]
             a list of dict with role and content. role must be one of {"system", "user", "assistant"}
-        max_new_tokens : str, Optional
-            the max number of new tokens VLM can generate.
-        temperature : float, Optional
-            the temperature for token sampling.
         verbose : bool, Optional
-            if True, VLM generated text will be printed in terminal in real-time.
+            if True, VLM generated text will be printed in terminal in real-time.
         stream : bool, Optional
             if True, returns a generator that yields the output in real-time.
         """
-
-        if self.reasoning_model:
-            # Reasoning models do not support temperature parameter
-            if temperature != 0.0:
-                warnings.warn("Reasoning models do not support temperature parameter. Will be ignored.", UserWarning)
+        processed_messages = self.config.preprocess_messages(messages)

-
-
-
-
-
-
-
-
-
-                    model=self.model,
-                    messages=messages,
-                    max_completion_tokens=max_new_tokens,
-                    stream=True,
-                    **kwrs
-                )
-                for chunk in response_stream:
-                    if len(chunk.choices) > 0:
-                        if chunk.choices[0].delta.content is not None:
-                            yield chunk.choices[0].delta.content
-                        if chunk.choices[0].finish_reason == "length":
-                            warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
-                            if self.reasoning_model:
-                                warnings.warn("max_new_tokens includes reasoning tokens and output tokens.", UserWarning)
-                return _stream_generator()
-
-            elif verbose:
-                response = self.client.chat.completions.create(
-                    model=self.model,
-                    messages=messages,
-                    max_completion_tokens=max_new_tokens,
-                    stream=True,
-                    **kwrs
-                )
-                res = ''
-                for chunk in response:
-                    if len(chunk.choices) > 0:
-                        if chunk.choices[0].delta.content is not None:
-                            res += chunk.choices[0].delta.content
-                            print(chunk.choices[0].delta.content, end="", flush=True)
-                        if chunk.choices[0].finish_reason == "length":
-                            warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
-                            if self.reasoning_model:
-                                warnings.warn("max_new_tokens includes reasoning tokens and output tokens.", UserWarning)
-
-                print('\n')
-                return res
-            else:
-                response = self.client.chat.completions.create(
-                    model=self.model,
-                    messages=messages,
-                    max_completion_tokens=max_new_tokens,
-                    stream=False,
-                    **kwrs
-                )
-                return response.choices[0].message.content
-
-        # For non-reasoning models
-        else:
-            if stream:
-                def _stream_generator():
-                    response_stream = self.client.chat.completions.create(
-                        model=self.model,
-                        messages=messages,
-                        max_tokens=max_new_tokens,
-                        temperature=temperature,
-                        stream=True,
-                        **kwrs
-                    )
-                    for chunk in response_stream:
-                        if len(chunk.choices) > 0:
-                            if chunk.choices[0].delta.content is not None:
-                                yield chunk.choices[0].delta.content
-                            if chunk.choices[0].finish_reason == "length":
-                                warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
-                                if self.reasoning_model:
-                                    warnings.warn("max_new_tokens includes reasoning tokens and output tokens.", UserWarning)
-                return _stream_generator()
-
-            elif verbose:
-                response = self.client.chat.completions.create(
-                    model=self.model,
-                    messages=messages,
-                    max_tokens=max_new_tokens,
-                    temperature=temperature,
-                    stream=True,
-                    **kwrs
-                )
-                res = ''
-                for chunk in response:
+        if stream:
+            def _stream_generator():
+                response_stream = self.client.chat.completions.create(
+                    model=self.model,
+                    messages=processed_messages,
+                    stream=True,
+                    **self.formatted_params
+                )
+                for chunk in response_stream:
                     if len(chunk.choices) > 0:
                         if chunk.choices[0].delta.content is not None:
-
-                            print(chunk.choices[0].delta.content, end="", flush=True)
+                            yield chunk.choices[0].delta.content
                         if chunk.choices[0].finish_reason == "length":
                             warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
-                            if self.reasoning_model:
-                                warnings.warn("max_new_tokens includes reasoning tokens and output tokens.", UserWarning)

-
-                return res
+            return self.config.postprocess_response(_stream_generator())

-
-
-                    model=self.model,
-                    messages=messages,
-                    max_tokens=max_new_tokens,
-                    temperature=temperature,
-                    stream=False,
-                    **kwrs
-                )
-
-                return response.choices[0].message.content
-
-
-    async def chat_async(self, messages:List[Dict[str,str]], max_new_tokens:int=4096, temperature:float=0.0, **kwrs) -> str:
-        """
-        Async version of chat method. Streaming is not supported.
-        """
-        if self.reasoning_model:
-            # Reasoning models do not support temperature parameter
-            if temperature != 0.0:
-                warnings.warn("Reasoning models do not support temperature parameter. Will be ignored.", UserWarning)
-
-            # Reasoning models do not support system prompts
-            if any(msg['role'] == 'system' for msg in messages):
-                warnings.warn("Reasoning models do not support system prompts. Will be ignored.", UserWarning)
-                messages = [msg for msg in messages if msg['role'] != 'system']
-
-        response = await self.async_client.chat.completions.create(
+        elif verbose:
+            response = self.client.chat.completions.create(
                 model=self.model,
-                messages=
-
-
-            **kwrs
+                messages=processed_messages,
+                stream=True,
+                **self.formatted_params
             )
+            res = ''
+            for chunk in response:
+                if len(chunk.choices) > 0:
+                    if chunk.choices[0].delta.content is not None:
+                        res += chunk.choices[0].delta.content
+                        print(chunk.choices[0].delta.content, end="", flush=True)
+                    if chunk.choices[0].finish_reason == "length":
+                        warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)

+            print('\n')
+            return self.config.postprocess_response(res)
         else:
-            response =
+            response = self.client.chat.completions.create(
                 model=self.model,
-                messages=
-                max_tokens=max_new_tokens,
-                temperature=temperature,
+                messages=processed_messages,
                 stream=False,
-                **
+                **self.formatted_params
             )
+            res = response.choices[0].message.content
+            return self.config.postprocess_response(res)
+
+
+    async def chat_async(self, messages:List[Dict[str,str]]) -> str:
+        """
+        Async version of chat method. Streaming is not supported.
+        """
+        processed_messages = self.config.preprocess_messages(messages)
+
+        response = await self.async_client.chat.completions.create(
+            model=self.model,
+            messages=processed_messages,
+            stream=False,
+            **self.formatted_params
+        )

         if response.choices[0].finish_reason == "length":
             warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
-            if self.reasoning_model:
-                warnings.warn("max_new_tokens includes reasoning tokens and output tokens.", UserWarning)

-
+        res = response.choices[0].message.content
+        return self.config.postprocess_response(res)

     def get_ocr_messages(self, system_prompt:str, user_prompt:str, image:Image.Image, format:str='png', detail:str="high") -> List[Dict[str,str]]:
         """
@@ -415,7 +540,7 @@ class OpenAIVLMEngine(VLMEngine):


 class AzureOpenAIVLMEngine(OpenAIVLMEngine):
-    def __init__(self, model:str, api_version:str,
+    def __init__(self, model:str, api_version:str, config:VLMConfig=None, **kwrs):
         """
         The Azure OpenAI API inference engine.
         For parameters and documentation, refer to
@@ -428,8 +553,8 @@ class AzureOpenAIVLMEngine(OpenAIVLMEngine):
             model name as described in https://platform.openai.com/docs/models
         api_version : str
             the Azure OpenAI API version
-
-
+        config : LLMConfig
+            the LLM configuration.
         """
         if importlib.util.find_spec("openai") is None:
             raise ImportError("OpenAI Python API library not found. Please install OpanAI (```pip install openai```).")
@@ -441,4 +566,5 @@ class AzureOpenAIVLMEngine(OpenAIVLMEngine):
                                             **kwrs)
         self.async_client = AsyncAzureOpenAI(api_version=self.api_version,
                                             **kwrs)
-        self.
+        self.config = config if config else BasicVLMConfig()
+        self.formatted_params = self._format_config()