vlm4ocr 0.1.0__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
vlm4ocr/vlm_engines.py CHANGED
@@ -1,23 +1,204 @@
1
1
  import abc
2
2
  import importlib.util
3
- from typing import List, Dict, Union, Generator
3
+ from typing import Any, List, Dict, Union, Generator
4
4
  import warnings
5
5
  from PIL import Image
6
6
  from vlm4ocr.utils import image_to_base64
7
7
 
8
8
 
9
+ class VLMConfig(abc.ABC):
10
+ def __init__(self, **kwargs):
11
+ """
12
+ This is an abstract class to provide interfaces for VLM configuration.
13
+ Child classes that inherit this class can be used in extractors and the prompt editor.
14
+ Common VLM parameters: max_new_tokens, temperature, top_p, top_k, min_p.
15
+ """
16
+ self.params = kwargs.copy()
17
+
18
+ @abc.abstractmethod
19
+ def preprocess_messages(self, messages:List[Dict[str,str]]) -> List[Dict[str,str]]:
20
+ """
21
+ This method preprocesses the input messages before passing them to the VLM.
22
+
23
+ Parameters:
24
+ ----------
25
+ messages : List[Dict[str,str]]
26
+ a list of dict with role and content. role must be one of {"system", "user", "assistant"}
27
+
28
+ Returns:
29
+ -------
30
+ messages : List[Dict[str,str]]
31
+ a list of dict with role and content. role must be one of {"system", "user", "assistant"}
32
+ """
33
+ return NotImplemented
34
+
35
+ @abc.abstractmethod
36
+ def postprocess_response(self, response:Union[str, Generator[str, None, None]]) -> Union[str, Generator[str, None, None]]:
37
+ """
38
+ This method postprocesses the VLM response after it is generated.
39
+
40
+ Parameters:
41
+ ----------
42
+ response : Union[str, Generator[str, None, None]]
43
+ the VLM response. Can be a string or a generator.
44
+
45
+ Returns:
46
+ -------
47
+ response : str
48
+ the postprocessed VLM response
49
+ """
50
+ return NotImplemented
51
+
52
+
53
+ class BasicVLMConfig(VLMConfig):
54
+ def __init__(self, max_new_tokens:int=2048, temperature:float=0.0, **kwargs):
55
+ """
56
+ The basic VLM configuration for most non-reasoning models.
57
+ """
58
+ super().__init__(**kwargs)
59
+ self.max_new_tokens = max_new_tokens
60
+ self.temperature = temperature
61
+ self.params["max_new_tokens"] = self.max_new_tokens
62
+ self.params["temperature"] = self.temperature
63
+
64
+ def preprocess_messages(self, messages:List[Dict[str,str]]) -> List[Dict[str,str]]:
65
+ """
66
+ This method preprocesses the input messages before passing them to the VLM.
67
+
68
+ Parameters:
69
+ ----------
70
+ messages : List[Dict[str,str]]
71
+ a list of dict with role and content. role must be one of {"system", "user", "assistant"}
72
+
73
+ Returns:
74
+ -------
75
+ messages : List[Dict[str,str]]
76
+ a list of dict with role and content. role must be one of {"system", "user", "assistant"}
77
+ """
78
+ return messages
79
+
80
+ def postprocess_response(self, response:Union[str, Generator[str, None, None]]) -> Union[str, Generator[str, None, None]]:
81
+ """
82
+ This method postprocesses the VLM response after it is generated.
83
+
84
+ Parameters:
85
+ ----------
86
+ response : Union[str, Generator[str, None, None]]
87
+ the VLM response. Can be a string or a generator.
88
+
89
+ Returns: Union[str, Generator[str, None, None]]
90
+ the postprocessed VLM response.
91
+ if input is a generator, the output will be a generator yielding the same string chunks.
92
+ """
93
+ if isinstance(response, str):
94
+ return response
95
+
96
+ def _process_stream():
97
+ for chunk in response:
98
+ yield chunk
99
+
100
+ return _process_stream()
101
+
102
+
103
+ class OpenAIReasoningVLMConfig(VLMConfig):
104
+ def __init__(self, reasoning_effort:str="low", **kwargs):
105
+ """
106
+ The OpenAI "o" series configuration.
107
+ 1. The reasoning effort is set to "low" by default.
108
+ 2. The temperature parameter is not supported and will be ignored.
109
+ 3. The system prompt is not supported and will be concatenated to the next user prompt.
110
+
111
+ Parameters:
112
+ ----------
113
+ reasoning_effort : str, Optional
114
+ the reasoning effort. Must be one of {"low", "medium", "high"}. Default is "low".
115
+ """
116
+ super().__init__(**kwargs)
117
+ if reasoning_effort not in ["low", "medium", "high"]:
118
+ raise ValueError("reasoning_effort must be one of {'low', 'medium', 'high'}.")
119
+
120
+ self.reasoning_effort = reasoning_effort
121
+ self.params["reasoning_effort"] = self.reasoning_effort
122
+
123
+ if "temperature" in self.params:
124
+ warnings.warn("Reasoning models do not support temperature parameter. Will be ignored.", UserWarning)
125
+ self.params.pop("temperature")
126
+
127
+ def preprocess_messages(self, messages:List[Dict[str,str]]) -> List[Dict[str,str]]:
128
+ """
129
+ Concatenate system prompts to the next user prompt.
130
+
131
+ Parameters:
132
+ ----------
133
+ messages : List[Dict[str,str]]
134
+ a list of dict with role and content. role must be one of {"system", "user", "assistant"}
135
+
136
+ Returns:
137
+ -------
138
+ messages : List[Dict[str,str]]
139
+ a list of dict with role and content. role must be one of {"system", "user", "assistant"}
140
+ """
141
+ system_prompt_holder = ""
142
+ new_messages = []
143
+ for i, message in enumerate(messages):
144
+ # if system prompt, store it in system_prompt_holder
145
+ if message['role'] == 'system':
146
+ system_prompt_holder = message['content']
147
+ # if user prompt, concatenate it with system_prompt_holder
148
+ elif message['role'] == 'user':
149
+ if system_prompt_holder:
150
+ new_message = {'role': message['role'], 'content': f"{system_prompt_holder} {message['content']}"}
151
+ system_prompt_holder = ""
152
+ else:
153
+ new_message = {'role': message['role'], 'content': message['content']}
154
+
155
+ new_messages.append(new_message)
156
+ # if assistant/other message, keep it unchanged
157
+ else:
158
+ new_message = {'role': message['role'], 'content': message['content']}
159
+ new_messages.append(new_message)
160
+
161
+ return new_messages
162
+
163
+ def postprocess_response(self, response:Union[str, Generator[str, None, None]]) -> Union[str, Generator[Dict[str, str], None, None]]:
164
+ """
165
+ This method postprocesses the VLM response after it is generated.
166
+
167
+ Parameters:
168
+ ----------
169
+ response : Union[str, Generator[str, None, None]]
170
+ the VLM response. Can be a string or a generator.
171
+
172
+ Returns: Union[str, Generator[Dict[str, str], None, None]]
173
+ the postprocessed VLM response.
174
+ if input is a generator, the output will be a generator {"type": "response", "data": <content>}.
175
+ """
176
+ if isinstance(response, str):
177
+ return response
178
+
179
+ def _process_stream():
180
+ for chunk in response:
181
+ yield {"type": "response", "data": chunk}
182
+
183
+ return _process_stream()
184
+
185
+
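The hunk above introduces the new VLMConfig hierarchy. As a rough illustration of how these configuration objects behave on their own, the sketch below constructs a BasicVLMConfig and an OpenAIReasoningVLMConfig and calls their message hooks directly; the prompt strings and parameter values are placeholders chosen for the example, not values taken from the package.

```python
# Illustrative sketch only: exercises the config classes added in the hunk above.
# Assumes they are importable from vlm4ocr.vlm_engines (the module shown in this diff).
from vlm4ocr.vlm_engines import BasicVLMConfig, OpenAIReasoningVLMConfig

basic = BasicVLMConfig(max_new_tokens=1024, temperature=0.2)
print(basic.params)  # {'max_new_tokens': 1024, 'temperature': 0.2}

reasoning = OpenAIReasoningVLMConfig(reasoning_effort="medium")
messages = [
    {"role": "system", "content": "You are an OCR assistant."},
    {"role": "user", "content": "Transcribe the attached page."},
]
# preprocess_messages folds the system prompt into the next user message,
# since the "o" series models do not accept a system role.
print(reasoning.preprocess_messages(messages))
# -> [{'role': 'user', 'content': 'You are an OCR assistant. Transcribe the attached page.'}]
```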
9
186
  class VLMEngine:
10
187
  @abc.abstractmethod
11
- def __init__(self):
188
+ def __init__(self, config:VLMConfig, **kwrs):
12
189
  """
13
190
  This is an abstract class to provide interfaces for VLM inference engines.
14
191
  Child classes that inherit this class can be used in extractors. Must implement the chat() method.
192
+
193
+ Parameters:
194
+ ----------
195
+ config : VLMConfig
196
+ the VLM configuration. Must be a child class of VLMConfig.
15
197
  """
16
198
  return NotImplemented
17
199
 
18
200
  @abc.abstractmethod
19
- def chat(self, messages:List[Dict[str,str]], max_new_tokens:int=4096, temperature:float=0.0,
20
- verbose:bool=False, stream:bool=False, **kwrs) -> Union[str, Generator[str, None, None]]:
201
+ def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False) -> Union[str, Generator[str, None, None]]:
21
202
  """
22
203
  This method inputs chat messages and outputs VLM generated text.
23
204
 
@@ -25,10 +206,6 @@ class VLMEngine:
25
206
  ----------
26
207
  messages : List[Dict[str,str]]
27
208
  a list of dict with role and content. role must be one of {"system", "user", "assistant"}
28
- max_new_tokens : str, Optional
29
- the max number of new tokens VLM can generate.
30
- temperature : float, Optional
31
- the temperature for token sampling.
32
209
  verbose : bool, Optional
33
210
  if True, VLM generated text will be printed in terminal in real-time.
34
211
  stream : bool, Optional
@@ -37,14 +214,14 @@ class VLMEngine:
37
214
  return NotImplemented
38
215
 
39
216
  @abc.abstractmethod
40
- def chat_async(self, messages:List[Dict[str,str]], max_new_tokens:int=4096, temperature:float=0.0, **kwrs) -> str:
217
+ def chat_async(self, messages:List[Dict[str,str]]) -> str:
41
218
  """
42
219
  The async version of chat method. Streaming is not supported.
43
220
  """
44
221
  return NotImplemented
45
222
 
46
223
  @abc.abstractmethod
47
- def get_ocr_messages(self, system_prompt:str, user_prompt:str, image_path:str) -> List[Dict[str,str]]:
224
+ def get_ocr_messages(self, system_prompt:str, user_prompt:str, image:Image.Image) -> List[Dict[str,str]]:
48
225
  """
49
226
  This method inputs an image and returns the corresponding chat messages for the inference engine.
50
227
 
@@ -54,14 +231,23 @@ class VLMEngine:
54
231
  the system prompt.
55
232
  user_prompt : str
56
233
  the user prompt.
57
- image_path : str
58
- the image path for OCR.
234
+ image : Image.Image
235
+ the image for OCR.
236
+ """
237
+ return NotImplemented
238
+
239
+ def _format_config(self) -> Dict[str, Any]:
240
+ """
241
+ This method formats the VLM configuration with the correct keys for the inference engine.
242
+
243
+ Return : Dict[str, Any]
244
+ the config parameters.
59
245
  """
60
246
  return NotImplemented
61
247
 
62
248
 
63
249
  class OllamaVLMEngine(VLMEngine):
64
- def __init__(self, model_name:str, num_ctx:int=4096, keep_alive:int=300, **kwrs):
250
+ def __init__(self, model_name:str, num_ctx:int=8192, keep_alive:int=300, config:VLMConfig=None, **kwrs):
65
251
  """
66
252
  The Ollama inference engine.
67
253
 
@@ -70,9 +256,11 @@ class OllamaVLMEngine(VLMEngine):
70
256
  model_name : str
71
257
  the model name exactly as shown in >> ollama ls
72
258
  num_ctx : int, Optional
73
- context length that VLM will evaluate.
259
+ context length that LLM will evaluate.
74
260
  keep_alive : int, Optional
75
- seconds to hold the VLM after the last API call.
261
+ seconds to hold the LLM after the last API call.
262
+ config : VLMConfig, Optional
263
+ the VLM configuration. Must be a child class of VLMConfig.
76
264
  """
77
265
  if importlib.util.find_spec("ollama") is None:
78
266
  raise ImportError("ollama-python not found. Please install ollama-python (```pip install ollama```).")
@@ -83,9 +271,21 @@ class OllamaVLMEngine(VLMEngine):
83
271
  self.model_name = model_name
84
272
  self.num_ctx = num_ctx
85
273
  self.keep_alive = keep_alive
274
+ self.config = config if config else BasicVLMConfig()
275
+ self.formatted_params = self._format_config()
276
+
277
+ def _format_config(self) -> Dict[str, Any]:
278
+ """
279
+ This method format the LLM configuration with the correct key for the inference engine.
280
+ """
281
+ formatted_params = self.config.params.copy()
282
+ if "max_new_tokens" in formatted_params:
283
+ formatted_params["num_predict"] = formatted_params["max_new_tokens"]
284
+ formatted_params.pop("max_new_tokens")
86
285
 
87
- def chat(self, messages:List[Dict[str,str]], max_new_tokens:int=4096, temperature:float=0.0,
88
- verbose:bool=False, stream:bool=False, **kwrs) -> Union[str, Generator[str, None, None]]:
286
+ return formatted_params
287
+
288
+ def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False) -> Union[str, Generator[str, None, None]]:
89
289
  """
90
290
  This method inputs chat messages and outputs VLM generated text.
91
291
 
@@ -93,21 +293,19 @@ class OllamaVLMEngine(VLMEngine):
93
293
  ----------
94
294
  messages : List[Dict[str,str]]
95
295
  a list of dict with role and content. role must be one of {"system", "user", "assistant"}
96
- max_new_tokens : str, Optional
97
- the max number of new tokens VLM can generate.
98
- temperature : float, Optional
99
- the temperature for token sampling.
100
296
  verbose : bool, Optional
101
297
  if True, VLM generated text will be printed in terminal in real-time.
102
298
  stream : bool, Optional
103
299
  if True, returns a generator that yields the output in real-time.
104
300
  """
105
- options={'temperature':temperature, 'num_ctx': self.num_ctx, 'num_predict': max_new_tokens, **kwrs}
301
+ processed_messages = self.config.preprocess_messages(messages)
302
+
303
+ options={'num_ctx': self.num_ctx, **self.formatted_params}
106
304
  if stream:
107
305
  def _stream_generator():
108
306
  response_stream = self.client.chat(
109
307
  model=self.model_name,
110
- messages=messages,
308
+ messages=processed_messages,
111
309
  options=options,
112
310
  stream=True,
113
311
  keep_alive=self.keep_alive
@@ -117,12 +315,12 @@ class OllamaVLMEngine(VLMEngine):
117
315
  if content_chunk:
118
316
  yield content_chunk
119
317
 
120
- return _stream_generator()
318
+ return self.config.postprocess_response(_stream_generator())
121
319
 
122
320
  elif verbose:
123
321
  response = self.client.chat(
124
322
  model=self.model_name,
125
- messages=messages,
323
+ messages=processed_messages,
126
324
  options=options,
127
325
  stream=True,
128
326
  keep_alive=self.keep_alive
@@ -134,24 +332,36 @@ class OllamaVLMEngine(VLMEngine):
134
332
  print(content_chunk, end='', flush=True)
135
333
  res += content_chunk
136
334
  print('\n')
137
- return res
335
+ return self.config.postprocess_response(res)
336
+
337
+ else:
338
+ response = self.client.chat(
339
+ model=self.model_name,
340
+ messages=processed_messages,
341
+ options=options,
342
+ stream=False,
343
+ keep_alive=self.keep_alive
344
+ )
345
+ res = response.get('message', {}).get('content')
346
+ return self.config.postprocess_response(res)
138
347
 
139
- return response.get('message', {}).get('content', '')
140
-
141
348
 
142
- async def chat_async(self, messages:List[Dict[str,str]], max_new_tokens:int=4096, temperature:float=0.0, **kwrs) -> str:
349
+ async def chat_async(self, messages:List[Dict[str,str]]) -> str:
143
350
  """
144
351
  Async version of chat method. Streaming is not supported.
145
352
  """
353
+ processed_messages = self.config.preprocess_messages(messages)
354
+
146
355
  response = await self.async_client.chat(
147
356
  model=self.model_name,
148
- messages=messages,
149
- options={'temperature':temperature, 'num_ctx': self.num_ctx, 'num_predict': max_new_tokens, **kwrs},
357
+ messages=processed_messages,
358
+ options={'num_ctx': self.num_ctx, **self.formatted_params},
150
359
  stream=False,
151
360
  keep_alive=self.keep_alive
152
361
  )
153
362
 
154
- return response.get('message', {}).get('content', '')
363
+ res = response['message']['content']
364
+ return self.config.postprocess_response(res)
155
365
 
156
366
  def get_ocr_messages(self, system_prompt:str, user_prompt:str, image:Image.Image) -> List[Dict[str,str]]:
157
367
  """
@@ -163,8 +373,8 @@ class OllamaVLMEngine(VLMEngine):
163
373
  the system prompt.
164
374
  user_prompt : str
165
375
  the user prompt.
166
- image_path : str
167
- the image path for OCR.
376
+ image : Image.Image
377
+ the image for OCR.
168
378
  """
169
379
  base64_str = image_to_base64(image)
170
380
  return [
@@ -178,7 +388,7 @@ class OllamaVLMEngine(VLMEngine):
178
388
 
179
389
 
180
390
  class OpenAIVLMEngine(VLMEngine):
181
- def __init__(self, model:str, reasoning_model:bool=False, **kwrs):
391
+ def __init__(self, model:str, config:VLMConfig=None, **kwrs):
182
392
  """
183
393
  The OpenAI API inference engine. Supports OpenAI models and OpenAI compatible servers:
184
394
  - vLLM OpenAI compatible server (https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html)
@@ -189,8 +399,8 @@ class OpenAIVLMEngine(VLMEngine):
189
399
  ----------
190
400
  model : str
191
401
  model name as described in https://platform.openai.com/docs/models
192
- reasoning_model : bool, Optional
193
- indicator for OpenAI reasoning models ("o" series).
402
+ config : VLMConfig, Optional
403
+ the VLM configuration. Must be a child class of VLMConfig.
194
404
  """
195
405
  if importlib.util.find_spec("openai") is None:
196
406
  raise ImportError("OpenAI Python API library not found. Please install OpanAI (```pip install openai```).")
@@ -199,184 +409,99 @@ class OpenAIVLMEngine(VLMEngine):
199
409
  self.client = OpenAI(**kwrs)
200
410
  self.async_client = AsyncOpenAI(**kwrs)
201
411
  self.model = model
202
- self.reasoning_model = reasoning_model
412
+ self.config = config if config else BasicVLMConfig()
413
+ self.formatted_params = self._format_config()
203
414
 
204
- def chat(self, messages:List[Dict[str,str]], max_new_tokens:int=4096, temperature:float=0.0,
205
- verbose:bool=False, stream:bool=False, **kwrs) -> Union[str, Generator[str, None, None]]:
415
+ def _format_config(self) -> Dict[str, Any]:
206
416
  """
207
- This method inputs chat messages and outputs VLM generated text.
417
+ This method formats the VLM configuration with the correct keys for the inference engine.
418
+ """
419
+ formatted_params = self.config.params.copy()
420
+ if "max_new_tokens" in formatted_params:
421
+ formatted_params["max_completion_tokens"] = formatted_params["max_new_tokens"]
422
+ formatted_params.pop("max_new_tokens")
423
+
424
+ return formatted_params
425
+
426
+ def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False) -> Union[str, Generator[str, None, None]]:
427
+ """
428
+ This method inputs chat messages and outputs VLM generated text.
208
429
 
209
430
  Parameters:
210
431
  ----------
211
432
  messages : List[Dict[str,str]]
212
433
  a list of dict with role and content. role must be one of {"system", "user", "assistant"}
213
- max_new_tokens : str, Optional
214
- the max number of new tokens VLM can generate.
215
- temperature : float, Optional
216
- the temperature for token sampling.
217
434
  verbose : bool, Optional
218
- if True, VLM generated text will be printed in terminal in real-time.
435
+ if True, VLM generated text will be printed in terminal in real-time.
219
436
  stream : bool, Optional
220
437
  if True, returns a generator that yields the output in real-time.
221
438
  """
222
- # For reasoning models
223
- if self.reasoning_model:
224
- # Reasoning models do not support temperature parameter
225
- if temperature != 0.0:
226
- warnings.warn("Reasoning models do not support temperature parameter. Will be ignored.", UserWarning)
439
+ processed_messages = self.config.preprocess_messages(messages)
227
440
 
228
- # Reasoning models do not support system prompts
229
- if any(msg['role'] == 'system' for msg in messages):
230
- warnings.warn("Reasoning models do not support system prompts. Will be ignored.", UserWarning)
231
- messages = [msg for msg in messages if msg['role'] != 'system']
232
-
233
-
234
- if stream:
235
- def _stream_generator():
236
- response_stream = self.client.chat.completions.create(
237
- model=self.model,
238
- messages=messages,
239
- max_completion_tokens=max_new_tokens,
240
- stream=True,
241
- **kwrs
242
- )
243
- for chunk in response_stream:
244
- if len(chunk.choices) > 0:
245
- if chunk.choices[0].delta.content is not None:
246
- yield chunk.choices[0].delta.content
247
- if chunk.choices[0].finish_reason == "length":
248
- warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
249
- if self.reasoning_model:
250
- warnings.warn("max_new_tokens includes reasoning tokens and output tokens.", UserWarning)
251
- return _stream_generator()
252
-
253
- elif verbose:
254
- response = self.client.chat.completions.create(
255
- model=self.model,
256
- messages=messages,
257
- max_completion_tokens=max_new_tokens,
258
- stream=True,
259
- **kwrs
260
- )
261
- res = ''
262
- for chunk in response:
263
- if len(chunk.choices) > 0:
264
- if chunk.choices[0].delta.content is not None:
265
- res += chunk.choices[0].delta.content
266
- print(chunk.choices[0].delta.content, end="", flush=True)
267
- if chunk.choices[0].finish_reason == "length":
268
- warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
269
- if self.reasoning_model:
270
- warnings.warn("max_new_tokens includes reasoning tokens and output tokens.", UserWarning)
271
-
272
- print('\n')
273
- return res
274
- else:
275
- response = self.client.chat.completions.create(
276
- model=self.model,
277
- messages=messages,
278
- max_completion_tokens=max_new_tokens,
279
- stream=False,
280
- **kwrs
281
- )
282
- return response.choices[0].message.content
283
-
284
- # For non-reasoning models
285
- else:
286
- if stream:
287
- def _stream_generator():
288
- response_stream = self.client.chat.completions.create(
289
- model=self.model,
290
- messages=messages,
291
- max_tokens=max_new_tokens,
292
- temperature=temperature,
293
- stream=True,
294
- **kwrs
295
- )
296
- for chunk in response_stream:
297
- if len(chunk.choices) > 0:
298
- if chunk.choices[0].delta.content is not None:
299
- yield chunk.choices[0].delta.content
300
- if chunk.choices[0].finish_reason == "length":
301
- warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
302
- if self.reasoning_model:
303
- warnings.warn("max_new_tokens includes reasoning tokens and output tokens.", UserWarning)
304
- return _stream_generator()
305
-
306
- elif verbose:
307
- response = self.client.chat.completions.create(
308
- model=self.model,
309
- messages=messages,
310
- max_tokens=max_new_tokens,
311
- temperature=temperature,
312
- stream=True,
313
- **kwrs
314
- )
315
- res = ''
316
- for chunk in response:
441
+ if stream:
442
+ def _stream_generator():
443
+ response_stream = self.client.chat.completions.create(
444
+ model=self.model,
445
+ messages=processed_messages,
446
+ stream=True,
447
+ **self.formatted_params
448
+ )
449
+ for chunk in response_stream:
317
450
  if len(chunk.choices) > 0:
318
451
  if chunk.choices[0].delta.content is not None:
319
- res += chunk.choices[0].delta.content
320
- print(chunk.choices[0].delta.content, end="", flush=True)
452
+ yield chunk.choices[0].delta.content
321
453
  if chunk.choices[0].finish_reason == "length":
322
454
  warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
323
- if self.reasoning_model:
324
- warnings.warn("max_new_tokens includes reasoning tokens and output tokens.", UserWarning)
325
455
 
326
- print('\n')
327
- return res
456
+ return self.config.postprocess_response(_stream_generator())
328
457
 
329
- else:
330
- response = self.client.chat.completions.create(
331
- model=self.model,
332
- messages=messages,
333
- max_tokens=max_new_tokens,
334
- temperature=temperature,
335
- stream=False,
336
- **kwrs
337
- )
338
-
339
- return response.choices[0].message.content
340
-
341
-
342
- async def chat_async(self, messages:List[Dict[str,str]], max_new_tokens:int=4096, temperature:float=0.0, **kwrs) -> str:
343
- """
344
- Async version of chat method. Streaming is not supported.
345
- """
346
- if self.reasoning_model:
347
- # Reasoning models do not support temperature parameter
348
- if temperature != 0.0:
349
- warnings.warn("Reasoning models do not support temperature parameter. Will be ignored.", UserWarning)
350
-
351
- # Reasoning models do not support system prompts
352
- if any(msg['role'] == 'system' for msg in messages):
353
- warnings.warn("Reasoning models do not support system prompts. Will be ignored.", UserWarning)
354
- messages = [msg for msg in messages if msg['role'] != 'system']
355
-
356
- response = await self.async_client.chat.completions.create(
458
+ elif verbose:
459
+ response = self.client.chat.completions.create(
357
460
  model=self.model,
358
- messages=messages,
359
- max_completion_tokens=max_new_tokens,
360
- stream=False,
361
- **kwrs
461
+ messages=processed_messages,
462
+ stream=True,
463
+ **self.formatted_params
362
464
  )
465
+ res = ''
466
+ for chunk in response:
467
+ if len(chunk.choices) > 0:
468
+ if chunk.choices[0].delta.content is not None:
469
+ res += chunk.choices[0].delta.content
470
+ print(chunk.choices[0].delta.content, end="", flush=True)
471
+ if chunk.choices[0].finish_reason == "length":
472
+ warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
363
473
 
474
+ print('\n')
475
+ return self.config.postprocess_response(res)
364
476
  else:
365
- response = await self.async_client.chat.completions.create(
477
+ response = self.client.chat.completions.create(
366
478
  model=self.model,
367
- messages=messages,
368
- max_tokens=max_new_tokens,
369
- temperature=temperature,
479
+ messages=processed_messages,
370
480
  stream=False,
371
- **kwrs
481
+ **self.formatted_params
372
482
  )
483
+ res = response.choices[0].message.content
484
+ return self.config.postprocess_response(res)
485
+
486
+
487
+ async def chat_async(self, messages:List[Dict[str,str]]) -> str:
488
+ """
489
+ Async version of chat method. Streaming is not supported.
490
+ """
491
+ processed_messages = self.config.preprocess_messages(messages)
492
+
493
+ response = await self.async_client.chat.completions.create(
494
+ model=self.model,
495
+ messages=processed_messages,
496
+ stream=False,
497
+ **self.formatted_params
498
+ )
373
499
 
374
500
  if response.choices[0].finish_reason == "length":
375
501
  warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
376
- if self.reasoning_model:
377
- warnings.warn("max_new_tokens includes reasoning tokens and output tokens.", UserWarning)
378
502
 
379
- return response.choices[0].message.content
503
+ res = response.choices[0].message.content
504
+ return self.config.postprocess_response(res)
380
505
 
381
506
  def get_ocr_messages(self, system_prompt:str, user_prompt:str, image:Image.Image, format:str='png', detail:str="high") -> List[Dict[str,str]]:
382
507
  """
@@ -415,7 +540,7 @@ class OpenAIVLMEngine(VLMEngine):
415
540
 
416
541
 
417
542
  class AzureOpenAIVLMEngine(OpenAIVLMEngine):
418
- def __init__(self, model:str, api_version:str, reasoning_model:bool=False, **kwrs):
543
+ def __init__(self, model:str, api_version:str, config:VLMConfig=None, **kwrs):
419
544
  """
420
545
  The Azure OpenAI API inference engine.
421
546
  For parameters and documentation, refer to
@@ -428,8 +553,8 @@ class AzureOpenAIVLMEngine(OpenAIVLMEngine):
428
553
  model name as described in https://platform.openai.com/docs/models
429
554
  api_version : str
430
555
  the Azure OpenAI API version
431
- reasoning_model : bool, Optional
432
- indicator for OpenAI reasoning models ("o" series).
556
+ config : VLMConfig, Optional
557
+ the VLM configuration. Must be a child class of VLMConfig.
433
558
  """
434
559
  if importlib.util.find_spec("openai") is None:
435
560
  raise ImportError("OpenAI Python API library not found. Please install OpanAI (```pip install openai```).")
@@ -441,4 +566,5 @@ class AzureOpenAIVLMEngine(OpenAIVLMEngine):
441
566
  **kwrs)
442
567
  self.async_client = AsyncAzureOpenAI(api_version=self.api_version,
443
568
  **kwrs)
444
- self.reasoning_model = reasoning_model
569
+ self.config = config if config else BasicVLMConfig()
570
+ self.formatted_params = self._format_config()
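Taken together, the changes above move the per-call sampling arguments (max_new_tokens, temperature) out of chat()/chat_async() and into a VLMConfig object passed to the engine constructor; each engine's _format_config() then maps them to backend-specific keys (num_predict for Ollama, max_completion_tokens for OpenAI). A hedged end-to-end sketch of the 0.3.0 call pattern follows; the model name, prompts, and image path are placeholders, and an OPENAI_API_KEY environment variable is assumed to be set.

```python
# Illustrative sketch of the 0.3.0 API surface shown in this diff; not taken from
# the package's own examples. Model name, prompts, and the image path are placeholders.
from PIL import Image
from vlm4ocr.vlm_engines import BasicVLMConfig, OpenAIVLMEngine

# Sampling options now live in the config instead of per-call arguments.
config = BasicVLMConfig(max_new_tokens=2048, temperature=0.0)
engine = OpenAIVLMEngine(model="gpt-4o-mini", config=config)  # reads OPENAI_API_KEY from the environment

image = Image.open("page_001.png")  # any PIL image to OCR
messages = engine.get_ocr_messages(
    system_prompt="You are an OCR engine. Return only the text on the page.",
    user_prompt="Transcribe this page.",
    image=image,
)

# chat() no longer accepts max_new_tokens/temperature; they come from the config
# and are translated by _format_config() before the API call.
text = engine.chat(messages, stream=False)
print(text)
```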