vlm4ocr 0.3.1__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
vlm4ocr/vlm_engines.py CHANGED
@@ -1,338 +1,29 @@
1
1
  import abc
2
- import importlib.util
3
- from typing import Any, List, Dict, Union, Generator
4
- import warnings
5
- import os
6
- import re
2
+ from typing import List, Dict
7
3
  from PIL import Image
8
4
  from vlm4ocr.utils import image_to_base64
9
-
10
-
11
- class VLMConfig(abc.ABC):
12
- def __init__(self, **kwargs):
13
- """
14
- This is an abstract class to provide interfaces for VLM configuration.
15
- Child classes that inherit from this class can be used in extractors and prompt editors.
16
- Common VLM parameters: max_new_tokens, temperature, top_p, top_k, min_p.
17
- """
18
- self.params = kwargs.copy()
19
-
20
- @abc.abstractmethod
21
- def preprocess_messages(self, messages:List[Dict[str,str]]) -> List[Dict[str,str]]:
22
- """
23
- This method preprocesses the input messages before passing them to the VLM.
24
-
25
- Parameters:
26
- ----------
27
- messages : List[Dict[str,str]]
28
- a list of dict with role and content. role must be one of {"system", "user", "assistant"}
29
-
30
- Returns:
31
- -------
32
- messages : List[Dict[str,str]]
33
- a list of dict with role and content. role must be one of {"system", "user", "assistant"}
34
- """
35
- return NotImplemented
36
-
37
- @abc.abstractmethod
38
- def postprocess_response(self, response:Union[str, Dict[str, str], Generator[str, None, None]]) -> Union[str, Generator[str, None, None]]:
39
- """
40
- This method postprocesses the VLM response after it is generated.
41
-
42
- Parameters:
43
- ----------
44
- response : Union[str, Generator[str, None, None]]
45
- the VLM response. Can be a string or a generator.
46
-
47
- Returns:
48
- -------
49
- response : str
50
- the postprocessed VLM response
51
- """
52
- return NotImplemented
53
-
54
-
55
- class BasicVLMConfig(VLMConfig):
56
- def __init__(self, max_new_tokens:int=2048, temperature:float=0.0, **kwargs):
57
- """
58
- The basic VLM configuration for most non-reasoning models.
59
- """
60
- super().__init__(**kwargs)
61
- self.max_new_tokens = max_new_tokens
62
- self.temperature = temperature
63
- self.params["max_new_tokens"] = self.max_new_tokens
64
- self.params["temperature"] = self.temperature
65
-
66
- def preprocess_messages(self, messages:List[Dict[str,str]]) -> List[Dict[str,str]]:
67
- """
68
- This method preprocesses the input messages before passing them to the VLM.
69
-
70
- Parameters:
71
- ----------
72
- messages : List[Dict[str,str]]
73
- a list of dict with role and content. role must be one of {"system", "user", "assistant"}
74
-
75
- Returns:
76
- -------
77
- messages : List[Dict[str,str]]
78
- a list of dict with role and content. role must be one of {"system", "user", "assistant"}
79
- """
80
- return messages
81
-
82
- def postprocess_response(self, response:Union[str, Dict[str, str], Generator[str, None, None]]) -> Union[Dict[str, str], Generator[Dict[str, str], None, None]]:
83
- """
84
- This method postprocesses the VLM response after it is generated.
85
-
86
- Parameters:
87
- ----------
88
- response : Union[str, Generator[str, None, None]]
89
- the VLM response. Can be a string or a generator.
90
-
91
- Returns: Union[str, Generator[Dict[str, str], None, None]]
92
- the postprocessed VLM response.
93
- if input is a generator, the output will be a generator {"type": "response", "data": <content>}.
94
- """
95
- if isinstance(response, str):
96
- return {"response": response}
97
-
98
- elif isinstance(response, dict):
99
- if "response" in response:
100
- return response
101
- else:
102
- warnings.warn(f"Invalid response dict keys: {response.keys()}. Returning default empty dict.", UserWarning)
103
- return {"response": ""}
104
-
105
- def _process_stream():
106
- for chunk in response:
107
- if isinstance(chunk, dict):
108
- yield chunk
109
- elif isinstance(chunk, str):
110
- yield {"type": "response", "data": chunk}
111
-
112
- return _process_stream()
113
-
114
- class ReasoningVLMConfig(VLMConfig):
115
- def __init__(self, thinking_token_start="<think>", thinking_token_end="</think>", **kwargs):
116
- """
117
- The general configuration for reasoning vision models.
118
- """
119
- super().__init__(**kwargs)
120
- self.thinking_token_start = thinking_token_start
121
- self.thinking_token_end = thinking_token_end
122
-
123
- def preprocess_messages(self, messages:List[Dict[str,str]]) -> List[Dict[str,str]]:
124
- """
125
- This method preprocesses the input messages before passing them to the VLM.
126
-
127
- Parameters:
128
- ----------
129
- messages : List[Dict[str,str]]
130
- a list of dict with role and content. role must be one of {"system", "user", "assistant"}
131
-
132
- Returns:
133
- -------
134
- messages : List[Dict[str,str]]
135
- a list of dict with role and content. role must be one of {"system", "user", "assistant"}
136
- """
137
- return messages.copy()
138
-
139
- def postprocess_response(self, response:Union[str, Dict[str, str], Generator[str, None, None]]) -> Union[Dict[str,str], Generator[Dict[str,str], None, None]]:
140
- """
141
- This method postprocesses the VLM response after it is generated.
142
- 1. If input is a string, it will extract the reasoning and response based on the thinking tokens.
143
- 2. If input is a dict, it should contain keys "reasoning" and "response". This is for inference engines that already parse reasoning and response.
144
- 3. If input is a generator,
145
- a. if the chunk is a dict, it should contain keys "type" and "data". This is for inference engines that already parse reasoning and response.
146
- b. if the chunk is a string, it will yield dicts with keys "type" and "data" based on the thinking tokens.
147
-
148
- Parameters:
149
- ----------
150
- response : Union[str, Generator[str, None, None]]
151
- the VLM response. Can be a string or a generator.
152
-
153
- Returns:
154
- -------
155
- response : Union[str, Generator[str, None, None]]
156
- the postprocessed LLM response as a dict {"reasoning": <reasoning>, "response": <content>}
157
- if input is a generator, the output will be a generator {"type": <reasoning or response>, "data": <content>}.
158
- """
159
- if isinstance(response, str):
160
- # get contents between thinking_token_start and thinking_token_end
161
- pattern = f"{re.escape(self.thinking_token_start)}(.*?){re.escape(self.thinking_token_end)}"
162
- match = re.search(pattern, response, re.DOTALL)
163
- reasoning = match.group(1) if match else ""
164
- # get response AFTER thinking_token_end
165
- response = re.sub(f".*?{self.thinking_token_end}", "", response, flags=re.DOTALL).strip()
166
- return {"reasoning": reasoning, "response": response}
167
-
168
- elif isinstance(response, dict):
169
- if "reasoning" in response and "response" in response:
170
- return response
171
- else:
172
- warnings.warn(f"Invalid response dict keys: {response.keys()}. Returning default empty dict.", UserWarning)
173
- return {"reasoning": "", "response": ""}
174
-
175
- elif isinstance(response, Generator):
176
- def _process_stream():
177
- think_flag = False
178
- buffer = ""
179
- for chunk in response:
180
- if isinstance(chunk, dict):
181
- yield chunk
182
-
183
- elif isinstance(chunk, str):
184
- buffer += chunk
185
- # switch between reasoning and response
186
- if self.thinking_token_start in buffer:
187
- think_flag = True
188
- buffer = buffer.replace(self.thinking_token_start, "")
189
- elif self.thinking_token_end in buffer:
190
- think_flag = False
191
- buffer = buffer.replace(self.thinking_token_end, "")
192
-
193
- # if chunk is in thinking block, tag it as reasoning; else tag it as response
194
- if chunk not in [self.thinking_token_start, self.thinking_token_end]:
195
- if think_flag:
196
- yield {"type": "reasoning", "data": chunk}
197
- else:
198
- yield {"type": "response", "data": chunk}
199
-
200
- return _process_stream()
201
-
202
- else:
203
- warnings.warn(f"Invalid response type: {type(response)}. Returning default empty dict.", UserWarning)
204
- return {"reasoning": "", "response": ""}
205
-
206
-
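To make the token-splitting behaviour concrete, here is a minimal sketch of how the parser above handles a single string (in 0.4.1 this logic is assumed to move into llm_inference_engine's ReasoningLLMConfig unchanged; the example values are illustrative):

```python
# Sketch of the 0.3.1 ReasoningVLMConfig.postprocess_response behaviour.
config = ReasoningVLMConfig(thinking_token_start="<think>", thinking_token_end="</think>")

out = config.postprocess_response("<think>page looks rotated</think>Transcribed text")
# Expected result per the implementation above:
# {"reasoning": "page looks rotated", "response": "Transcribed text"}
```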
207
- class OpenAIReasoningVLMConfig(ReasoningVLMConfig):
208
- def __init__(self, reasoning_effort:str="low", **kwargs):
209
- """
210
- The OpenAI "o" series configuration.
211
- 1. The reasoning effort is set to "low" by default.
212
- 2. The temperature parameter is not supported and will be ignored.
213
- 3. The system prompt is not supported and will be concatenated to the next user prompt.
214
-
215
- Parameters:
216
- ----------
217
- reasoning_effort : str, Optional
218
- the reasoning effort. Must be one of {"low", "medium", "high"}. Default is "low".
219
- """
220
- super().__init__(**kwargs)
221
- if reasoning_effort not in ["low", "medium", "high"]:
222
- raise ValueError("reasoning_effort must be one of {'low', 'medium', 'high'}.")
223
-
224
- self.reasoning_effort = reasoning_effort
225
- self.params["reasoning_effort"] = self.reasoning_effort
226
-
227
- if "temperature" in self.params:
228
- warnings.warn("Reasoning models do not support temperature parameter. Will be ignored.", UserWarning)
229
- self.params.pop("temperature")
230
-
231
- def preprocess_messages(self, messages:List[Dict[str,str]]) -> List[Dict[str,str]]:
232
- """
233
- Concatenate system prompts to the next user prompt.
234
-
235
- Parameters:
236
- ----------
237
- messages : List[Dict[str,str]]
238
- a list of dict with role and content. role must be one of {"system", "user", "assistant"}
239
-
240
- Returns:
241
- -------
242
- messages : List[Dict[str,str]]
243
- a list of dict with role and content. role must be one of {"system", "user", "assistant"}
244
- """
245
- system_prompt_holder = ""
246
- new_messages = []
247
- for i, message in enumerate(messages):
248
- # if system prompt, store it in system_prompt_holder
249
- if message['role'] == 'system':
250
- system_prompt_holder = message['content']
251
- # if user prompt, concatenate it with system_prompt_holder
252
- elif message['role'] == 'user':
253
- if system_prompt_holder:
254
- new_message = {'role': message['role'], 'content': f"{system_prompt_holder} {message['content']}"}
255
- system_prompt_holder = ""
256
- else:
257
- new_message = {'role': message['role'], 'content': message['content']}
258
-
259
- new_messages.append(new_message)
260
- # if assistant/other prompt, do nothing
261
- else:
262
- new_message = {'role': message['role'], 'content': message['content']}
263
- new_messages.append(new_message)
264
-
265
- return new_messages
266
-
267
-
268
- class MessagesLogger:
269
- def __init__(self):
270
- """
271
- This class is used to log the messages for InferenceEngine.chat().
272
- """
273
- self.messages_log = []
274
-
275
- def log_messages(self, messages : List[Dict[str,str]]):
276
- """
277
- This method logs the messages to a list.
278
- """
279
- self.messages_log.append(messages)
280
-
281
- def get_messages_log(self) -> List[List[Dict[str,str]]]:
282
- """
283
- This method returns a copy of the current messages log
284
- """
285
- return self.messages_log.copy()
286
-
287
- def clear_messages_log(self):
288
- """
289
- This method clears the current messages log
290
- """
291
- self.messages_log.clear()
292
-
293
-
294
- class VLMEngine:
295
- @abc.abstractmethod
296
- def __init__(self, config:VLMConfig, **kwrs):
297
- """
298
- This is an abstract class to provide interfaces for VLM inference engines.
299
- Child classes that inherit from this class can be used in extractors. Must implement the chat() method.
300
-
301
- Parameters:
302
- ----------
303
- config : VLMConfig
304
- the VLM configuration. Must be a child class of VLMConfig.
305
- """
306
- return NotImplemented
307
-
308
- @abc.abstractmethod
309
- def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False,
310
- messages_logger:MessagesLogger=None) -> Union[Dict[str, str], Generator[Dict[str, str], None, None]]:
311
- """
312
- This method inputs chat messages and outputs VLM generated text.
313
-
314
- Parameters:
315
- ----------
316
- messages : List[Dict[str,str]]
317
- a list of dict with role and content. role must be one of {"system", "user", "assistant"}
318
- verbose : bool, Optional
319
- if True, VLM generated text will be printed in terminal in real-time.
320
- stream : bool, Optional
321
- if True, returns a generator that yields the output in real-time.
322
- messages_logger : MessagesLogger, Optional
323
- the message logger that logs the chat messages.
324
- """
325
- return NotImplemented
326
-
5
+ from vlm4ocr.data_types import FewShotExample
6
+ from llm_inference_engine.llm_configs import (
7
+ LLMConfig as VLMConfig,
8
+ BasicLLMConfig as BasicVLMConfig,
9
+ ReasoningLLMConfig as ReasoningVLMConfig,
10
+ OpenAIReasoningLLMConfig as OpenAIReasoningVLMConfig
11
+ )
12
+ from llm_inference_engine.utils import MessagesLogger
13
+ from llm_inference_engine.engines import (
14
+ InferenceEngine,
15
+ OllamaInferenceEngine,
16
+ OpenAICompatibleInferenceEngine,
17
+ VLLMInferenceEngine,
18
+ OpenRouterInferenceEngine,
19
+ OpenAIInferenceEngine,
20
+ AzureOpenAIInferenceEngine,
21
+ )
22
+
23
+
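In 0.4.1 these names are thin aliases of their llm_inference_engine counterparts, so 0.3.1-style imports from vlm4ocr.vlm_engines keep resolving. A minimal sketch, assuming the aliased classes still accept the keyword arguments the removed 0.3.1 classes took (max_new_tokens, temperature, thinking_token_start/thinking_token_end):

```python
# Sketch only: the aliases above keep the 0.3.1 names importable.
# Assumption: BasicLLMConfig/ReasoningLLMConfig accept the same keyword
# arguments as the removed BasicVLMConfig/ReasoningVLMConfig did.
from vlm4ocr.vlm_engines import BasicVLMConfig, ReasoningVLMConfig

basic_config = BasicVLMConfig(max_new_tokens=2048, temperature=0.0)
reasoning_config = ReasoningVLMConfig(thinking_token_start="<think>",
                                      thinking_token_end="</think>")
```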
24
+ class VLMEngine(InferenceEngine):
327
25
  @abc.abstractmethod
328
- def chat_async(self, messages:List[Dict[str,str]], messages_logger:MessagesLogger=None) -> Dict[str, str]:
329
- """
330
- The async version of chat method. Streaming is not supported.
331
- """
332
- return NotImplemented
333
-
334
- @abc.abstractmethod
335
- def get_ocr_messages(self, system_prompt:str, user_prompt:str, image:Image.Image) -> List[Dict[str,str]]:
26
+ def get_ocr_messages(self, system_prompt:str, user_prompt:str, image:Image.Image, few_shot_examples:List[FewShotExample]=None) -> List[Dict[str,str]]:
336
27
  """
337
28
  This method inputs an image and returns the corresponding chat messages for the inference engine.
338
29
 
@@ -344,220 +35,14 @@ class VLMEngine:
344
35
  the user prompt.
345
36
  image : Image.Image
346
37
  the image for OCR.
38
+ few_shot_examples : List[FewShotExample], Optional
39
+ list of few-shot examples.
347
40
  """
348
41
  return NotImplemented
349
-
350
- def _format_config(self) -> Dict[str, Any]:
351
- """
352
- This method formats the VLM configuration with the correct key for the inference engine.
353
-
354
- Return : Dict[str, Any]
355
- the config parameters.
356
- """
357
- return NotImplemented
358
-
359
-
360
- class OllamaVLMEngine(VLMEngine):
361
- def __init__(self, model_name:str, num_ctx:int=8192, keep_alive:int=300, config:VLMConfig=None, **kwrs):
362
- """
363
- The Ollama inference engine.
364
-
365
- Parameters:
366
- ----------
367
- model_name : str
368
- the model name exactly as shown in >> ollama ls
369
- num_ctx : int, Optional
370
- context length that LLM will evaluate.
371
- keep_alive : int, Optional
372
- seconds to hold the LLM after the last API call.
373
- config : LLMConfig
374
- the LLM configuration.
375
- """
376
- if importlib.util.find_spec("ollama") is None:
377
- raise ImportError("ollama-python not found. Please install ollama-python (```pip install ollama```).")
378
-
379
- from ollama import Client, AsyncClient
380
- self.client = Client(**kwrs)
381
- self.async_client = AsyncClient(**kwrs)
382
- self.model_name = model_name
383
- self.num_ctx = num_ctx
384
- self.keep_alive = keep_alive
385
- self.config = config if config else BasicVLMConfig()
386
- self.formatted_params = self._format_config()
387
-
388
- def _format_config(self) -> Dict[str, Any]:
389
- """
390
- This method formats the LLM configuration with the correct key for the inference engine.
391
- """
392
- formatted_params = self.config.params.copy()
393
- if "max_new_tokens" in formatted_params:
394
- formatted_params["num_predict"] = formatted_params["max_new_tokens"]
395
- formatted_params.pop("max_new_tokens")
396
42
 
397
- return formatted_params
398
43
 
399
- def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False,
400
- messages_logger:MessagesLogger=None) -> Union[Dict[str,str], Generator[Dict[str, str], None, None]]:
401
- """
402
- This method inputs chat messages and outputs VLM generated text.
403
-
404
- Parameters:
405
- ----------
406
- messages : List[Dict[str,str]]
407
- a list of dict with role and content. role must be one of {"system", "user", "assistant"}
408
- verbose : bool, Optional
409
- if True, VLM generated text will be printed in terminal in real-time.
410
- stream : bool, Optional
411
- if True, returns a generator that yields the output in real-time.
412
- Messages_logger : MessagesLogger, Optional
413
- the message logger that logs the chat messages.
414
-
415
- Returns:
416
- -------
417
- response : Union[Dict[str,str], Generator[Dict[str, str], None, None]]
418
- a dict {"reasoning": <reasoning>, "response": <response>} or Generator {"type": <reasoning or response>, "data": <content>}
419
- """
420
- processed_messages = self.config.preprocess_messages(messages)
421
-
422
- options={'num_ctx': self.num_ctx, **self.formatted_params}
423
- if stream:
424
- def _stream_generator():
425
- response_stream = self.client.chat(
426
- model=self.model_name,
427
- messages=processed_messages,
428
- options=options,
429
- stream=True,
430
- keep_alive=self.keep_alive
431
- )
432
- res = {"reasoning": "", "response": ""}
433
- for chunk in response_stream:
434
- if hasattr(chunk.message, 'thinking') and chunk.message.thinking:
435
- content_chunk = getattr(getattr(chunk, 'message', {}), 'thinking', '')
436
- res["reasoning"] += content_chunk
437
- yield {"type": "reasoning", "data": content_chunk}
438
- else:
439
- content_chunk = getattr(getattr(chunk, 'message', {}), 'content', '')
440
- res["response"] += content_chunk
441
- yield {"type": "response", "data": content_chunk}
442
-
443
- if chunk.done_reason == "length":
444
- warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
445
-
446
- # Postprocess response
447
- res_dict = self.config.postprocess_response(res)
448
- # Write to messages log
449
- if messages_logger:
450
- # replace images content with a placeholder "[image]" to save space
451
- for messages in processed_messages:
452
- if "images" in messages:
453
- messages["images"] = ["[image]" for _ in messages["images"]]
454
-
455
- processed_messages.append({"role": "assistant",
456
- "content": res_dict.get("response", ""),
457
- "reasoning": res_dict.get("reasoning", "")})
458
- messages_logger.log_messages(processed_messages)
459
-
460
- return self.config.postprocess_response(_stream_generator())
461
-
462
- elif verbose:
463
- response = self.client.chat(
464
- model=self.model_name,
465
- messages=processed_messages,
466
- options=options,
467
- stream=True,
468
- keep_alive=self.keep_alive
469
- )
470
-
471
- res = {"reasoning": "", "response": ""}
472
- phase = ""
473
- for chunk in response:
474
- if hasattr(chunk.message, 'thinking') and chunk.message.thinking:
475
- if phase != "reasoning":
476
- print("\n--- Reasoning ---")
477
- phase = "reasoning"
478
-
479
- content_chunk = getattr(getattr(chunk, 'message', {}), 'thinking', '')
480
- res["reasoning"] += content_chunk
481
- else:
482
- if phase != "response":
483
- print("\n--- Response ---")
484
- phase = "response"
485
- content_chunk = getattr(getattr(chunk, 'message', {}), 'content', '')
486
- res["response"] += content_chunk
487
-
488
- print(content_chunk, end='', flush=True)
489
-
490
- if chunk.done_reason == "length":
491
- warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
492
- print('\n')
493
-
494
- else:
495
- response = self.client.chat(
496
- model=self.model_name,
497
- messages=processed_messages,
498
- options=options,
499
- stream=False,
500
- keep_alive=self.keep_alive
501
- )
502
- res = {"reasoning": getattr(getattr(response, 'message', {}), 'thinking', ''),
503
- "response": getattr(getattr(response, 'message', {}), 'content', '')}
504
-
505
- if response.done_reason == "length":
506
- warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
507
-
508
- # Postprocess response
509
- res_dict = self.config.postprocess_response(res)
510
- # Write to messages log
511
- if messages_logger:
512
- # replace images content with a placeholder "[image]" to save space
513
- for messages in processed_messages:
514
- if "images" in messages:
515
- messages["images"] = ["[image]" for _ in messages["images"]]
516
-
517
- processed_messages.append({"role": "assistant",
518
- "content": res_dict.get("response", ""),
519
- "reasoning": res_dict.get("reasoning", "")})
520
- messages_logger.log_messages(processed_messages)
521
-
522
- return res_dict
523
-
524
-
525
- async def chat_async(self, messages:List[Dict[str,str]], messages_logger:MessagesLogger=None) -> Dict[str,str]:
526
- """
527
- Async version of chat method. Streaming is not supported.
528
- """
529
- processed_messages = self.config.preprocess_messages(messages)
530
-
531
- response = await self.async_client.chat(
532
- model=self.model_name,
533
- messages=processed_messages,
534
- options={'num_ctx': self.num_ctx, **self.formatted_params},
535
- stream=False,
536
- keep_alive=self.keep_alive
537
- )
538
-
539
- res = {"reasoning": getattr(getattr(response, 'message', {}), 'thinking', ''),
540
- "response": getattr(getattr(response, 'message', {}), 'content', '')}
541
-
542
- if response.done_reason == "length":
543
- warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
544
- # Postprocess response
545
- res_dict = self.config.postprocess_response(res)
546
- # Write to messages log
547
- if messages_logger:
548
- # replace images content with a placeholder "[image]" to save space
549
- for messages in processed_messages:
550
- if "images" in messages:
551
- messages["images"] = ["[image]" for _ in messages["images"]]
552
-
553
- processed_messages.append({"role": "assistant",
554
- "content": res_dict.get("response", ""),
555
- "reasoning": res_dict.get("reasoning", "")})
556
- messages_logger.log_messages(processed_messages)
557
-
558
- return res_dict
559
-
560
- def get_ocr_messages(self, system_prompt:str, user_prompt:str, image:Image.Image) -> List[Dict[str,str]]:
44
+ class OllamaVLMEngine(OllamaInferenceEngine, VLMEngine):
45
+ def get_ocr_messages(self, system_prompt:str, user_prompt:str, image:Image.Image, few_shot_examples:List[FewShotExample]=None) -> List[Dict[str,str]]:
561
46
  """
562
47
  This method inputs an image and returns the corresponding chat messages for the inference engine.
563
48
 
@@ -569,230 +54,37 @@ class OllamaVLMEngine(VLMEngine):
569
54
  the user prompt.
570
55
  image : Image.Image
571
56
  the image for OCR.
57
+ few_shot_examples : List[FewShotExample], Optional
58
+ list of few-shot examples.
572
59
  """
573
60
  base64_str = image_to_base64(image)
574
- return [
575
- {"role": "system", "content": system_prompt},
576
- {
577
- "role": "user",
578
- "content": user_prompt,
579
- "images": [base64_str]
580
- }
581
- ]
582
-
583
-
584
- class OpenAICompatibleVLMEngine(VLMEngine):
585
- def __init__(self, model:str, api_key:str, base_url:str, config:VLMConfig=None, **kwrs):
586
- """
587
- General OpenAI-compatible server inference engine.
588
- https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html
589
-
590
- For parameters and documentation, refer to https://platform.openai.com/docs/api-reference/introduction
591
-
592
- Parameters:
593
- ----------
594
- model_name : str
595
- model name as shown in the vLLM server
596
- api_key : str
597
- the API key for the vLLM server.
598
- base_url : str
599
- the base url for the vLLM server.
600
- config : LLMConfig
601
- the LLM configuration.
602
- """
603
- if importlib.util.find_spec("openai") is None:
604
- raise ImportError("OpenAI Python API library not found. Please install OpenAI (```pip install openai```).")
605
-
606
- from openai import OpenAI, AsyncOpenAI
607
- from openai.types.chat import ChatCompletionChunk
608
- self.ChatCompletionChunk = ChatCompletionChunk
609
- super().__init__(config)
610
- self.client = OpenAI(api_key=api_key, base_url=base_url, **kwrs)
611
- self.async_client = AsyncOpenAI(api_key=api_key, base_url=base_url, **kwrs)
612
- self.model = model
613
- self.config = config if config else BasicVLMConfig()
614
- self.formatted_params = self._format_config()
615
-
616
- def _format_config(self) -> Dict[str, Any]:
617
- """
618
- This method formats the VLM configuration with the correct key for the inference engine.
619
- """
620
- formatted_params = self.config.params.copy()
621
- if "max_new_tokens" in formatted_params:
622
- formatted_params["max_completion_tokens"] = formatted_params["max_new_tokens"]
623
- formatted_params.pop("max_new_tokens")
624
-
625
- return formatted_params
626
-
627
-
628
- def _format_response(self, response: Any) -> Dict[str, str]:
629
- """
630
- This method formats the response from the OpenAI API into a dict with keys "type" and "data".
631
-
632
- Parameters:
633
- ----------
634
- response : Any
635
- the response from OpenAI-compatible API. Could be a dict, generator, or object.
636
- """
637
- if isinstance(response, self.ChatCompletionChunk):
638
- chunk_text = getattr(response.choices[0].delta, "content", "")
639
- if chunk_text is None:
640
- chunk_text = ""
641
- return {"type": "response", "data": chunk_text}
642
-
643
- return {"response": getattr(response.choices[0].message, "content", "")}
644
-
645
- def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False,
646
- messages_logger:MessagesLogger=None) -> Union[Dict[str, str], Generator[Dict[str, str], None, None]]:
647
- """
648
- This method inputs chat messages and outputs LLM generated text.
649
-
650
- Parameters:
651
- ----------
652
- messages : List[Dict[str,str]]
653
- a list of dict with role and content. role must be one of {"system", "user", "assistant"}
654
- verbose : bool, Optional
655
- if True, VLM generated text will be printed in terminal in real-time.
656
- stream : bool, Optional
657
- if True, returns a generator that yields the output in real-time.
658
- messages_logger : MessagesLogger, Optional
659
- the message logger that logs the chat messages.
660
-
661
- Returns:
662
- -------
663
- response : Union[Dict[str,str], Generator[Dict[str, str], None, None]]
664
- a dict {"reasoning": <reasoning>, "response": <response>} or Generator {"type": <reasoning or response>, "data": <content>}
665
- """
666
- processed_messages = self.config.preprocess_messages(messages)
667
-
668
- if stream:
669
- def _stream_generator():
670
- response_stream = self.client.chat.completions.create(
671
- model=self.model,
672
- messages=processed_messages,
673
- stream=True,
674
- **self.formatted_params
675
- )
676
- res_text = ""
677
- for chunk in response_stream:
678
- if len(chunk.choices) > 0:
679
- chunk_dict = self._format_response(chunk)
680
- yield chunk_dict
681
-
682
- res_text += chunk_dict["data"]
683
- if chunk.choices[0].finish_reason == "length":
684
- warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
685
-
686
- # Postprocess response
687
- res_dict = self.config.postprocess_response(res_text)
688
- # Write to messages log
689
- if messages_logger:
690
- # replace images content with a placeholder "[image]" to save space
691
- for messages in processed_messages:
692
- if "content" in messages and isinstance(messages["content"], list):
693
- for content in messages["content"]:
694
- if isinstance(content, dict) and content.get("type") == "image_url":
695
- content["image_url"]["url"] = "[image]"
696
-
697
- processed_messages.append({"role": "assistant",
698
- "content": res_dict.get("response", ""),
699
- "reasoning": res_dict.get("reasoning", "")})
700
- messages_logger.log_messages(processed_messages)
701
-
702
- return self.config.postprocess_response(_stream_generator())
703
-
704
- elif verbose:
705
- response = self.client.chat.completions.create(
706
- model=self.model,
707
- messages=processed_messages,
708
- stream=True,
709
- **self.formatted_params
710
- )
711
- res = {"reasoning": "", "response": ""}
712
- phase = ""
713
- for chunk in response:
714
- if len(chunk.choices) > 0:
715
- chunk_dict = self._format_response(chunk)
716
- chunk_text = chunk_dict["data"]
717
- res[chunk_dict["type"]] += chunk_text
718
- if phase != chunk_dict["type"] and chunk_text != "":
719
- print(f"\n--- {chunk_dict['type'].capitalize()} ---")
720
- phase = chunk_dict["type"]
721
-
722
- print(chunk_text, end="", flush=True)
723
- if chunk.choices[0].finish_reason == "length":
724
- warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
725
-
726
- print('\n')
727
-
728
- else:
729
- response = self.client.chat.completions.create(
730
- model=self.model,
731
- messages=processed_messages,
732
- stream=False,
733
- **self.formatted_params
734
- )
735
- res = self._format_response(response)
61
+ output_messages = []
62
+ # system message
63
+ system_message = {"role": "system", "content": system_prompt}
64
+ output_messages.append(system_message)
65
+
66
+ # few-shot examples
67
+ if few_shot_examples is not None:
68
+ for example in few_shot_examples:
69
+ if not isinstance(example, FewShotExample):
70
+ raise ValueError("Few-shot example must be a FewShotExample object.")
71
+
72
+ example_image_b64 = image_to_base64(example.image)
73
+ example_user_message = {"role": "user", "content": user_prompt, "images": [example_image_b64]}
74
+ example_agent_message = {"role": "assistant", "content": example.text}
75
+ output_messages.append(example_user_message)
76
+ output_messages.append(example_agent_message)
736
77
 
737
- if response.choices[0].finish_reason == "length":
738
- warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
739
-
740
- # Postprocess response
741
- res_dict = self.config.postprocess_response(res)
742
- # Write to messages log
743
- if messages_logger:
744
- # replace images content with a placeholder "[image]" to save space
745
- for messages in processed_messages:
746
- if "content" in messages and isinstance(messages["content"], list):
747
- for content in messages["content"]:
748
- if isinstance(content, dict) and content.get("type") == "image_url":
749
- content["image_url"]["url"] = "[image]"
78
+ # user message
79
+ user_message = {"role": "user", "content": user_prompt, "images": [base64_str]}
80
+ output_messages.append(user_message)
750
81
 
751
- processed_messages.append({"role": "assistant",
752
- "content": res_dict.get("response", ""),
753
- "reasoning": res_dict.get("reasoning", "")})
754
- messages_logger.log_messages(processed_messages)
82
+ return output_messages
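A sketch of how the new few-shot path is exercised for the Ollama message format. Only the get_ocr_messages signature and the .image/.text attributes come from the code above; the OllamaVLMEngine and FewShotExample constructor arguments are assumptions carried over from the 0.3.1 signatures:

```python
from PIL import Image
from vlm4ocr.data_types import FewShotExample
from vlm4ocr.vlm_engines import OllamaVLMEngine

# Assumed constructor: 0.3.1 took model_name; the inherited 0.4.1
# OllamaInferenceEngine constructor is assumed to stay compatible.
engine = OllamaVLMEngine(model_name="llama3.2-vision")

# Assumed FewShotExample(image=..., text=...); the diff only shows that
# examples expose .image and .text.
example = FewShotExample(image=Image.open("example_page.png"),
                         text="Ground-truth OCR text for the example page")

messages = engine.get_ocr_messages(
    system_prompt="You are an OCR engine.",
    user_prompt="Transcribe this page.",
    image=Image.open("page.png"),
    few_shot_examples=[example],
)
# Resulting order: system message, then one user/assistant pair per example
# (the user turn repeats user_prompt and carries the example image in
# "images"), then the final user message with the target image.
```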
755
83
 
756
- return res_dict
757
-
758
84
 
759
- async def chat_async(self, messages:List[Dict[str,str]], messages_logger:MessagesLogger=None) -> Dict[str,str]:
760
- """
761
- Async version of chat method. Streaming is not supported.
762
- """
763
- processed_messages = self.config.preprocess_messages(messages)
764
-
765
- response = await self.async_client.chat.completions.create(
766
- model=self.model,
767
- messages=processed_messages,
768
- stream=False,
769
- **self.formatted_params
770
- )
771
-
772
- if response.choices[0].finish_reason == "length":
773
- warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
774
-
775
- res = self._format_response(response)
776
-
777
- # Postprocess response
778
- res_dict = self.config.postprocess_response(res)
779
- # Write to messages log
780
- if messages_logger:
781
- # replace images content with a placeholder "[image]" to save space
782
- for messages in processed_messages:
783
- if "content" in messages and isinstance(messages["content"], list):
784
- for content in messages["content"]:
785
- if isinstance(content, dict) and content.get("type") == "image_url":
786
- content["image_url"]["url"] = "[image]"
787
-
788
- processed_messages.append({"role": "assistant",
789
- "content": res_dict.get("response", ""),
790
- "reasoning": res_dict.get("reasoning", "")})
791
- messages_logger.log_messages(processed_messages)
792
-
793
- return res_dict
794
-
795
- def get_ocr_messages(self, system_prompt:str, user_prompt:str, image:Image.Image, format:str='png', detail:str="high") -> List[Dict[str,str]]:
85
+ class OpenAICompatibleVLMEngine(OpenAICompatibleInferenceEngine, VLMEngine):
86
+ def get_ocr_messages(self, system_prompt:str, user_prompt:str, image:Image.Image, format:str='png',
87
+ detail:str="high", few_shot_examples:List[FewShotExample]=None) -> List[Dict[str,str]]:
796
88
  """
797
89
  This method inputs an image and returns the corresponding chat messages for the inference engine.
798
90
 
@@ -808,295 +100,97 @@ class OpenAICompatibleVLMEngine(VLMEngine):
808
100
  the image format.
809
101
  detail : str, Optional
810
102
  the detail level of the image. Default is "high".
103
+ few_shot_examples : List[FewShotExample], Optional
104
+ list of few-shot examples.
811
105
  """
812
106
  base64_str = image_to_base64(image)
813
- return [
814
- {"role": "system", "content": system_prompt},
815
- {
816
- "role": "user",
817
- "content": [
818
- {
819
- "type": "image_url",
820
- "image_url": {
821
- "url": f"data:image/{format};base64,{base64_str}",
822
- "detail": detail
107
+ output_messages = []
108
+ # system message
109
+ system_message = {"role": "system", "content": system_prompt}
110
+ output_messages.append(system_message)
111
+
112
+ # few-shot examples
113
+ if few_shot_examples is not None:
114
+ for example in few_shot_examples:
115
+ if not isinstance(example, FewShotExample):
116
+ raise ValueError("Few-shot example must be a FewShotExample object.")
117
+
118
+ example_image_b64 = image_to_base64(example.image)
119
+ example_user_message = {
120
+ "role": "user",
121
+ "content": [
122
+ {
123
+ "type": "image_url",
124
+ "image_url": {
125
+ "url": f"data:image/{format};base64,{example_image_b64}",
126
+ "detail": detail
127
+ },
823
128
  },
129
+ {"type": "text", "text": user_prompt},
130
+ ],
131
+ }
132
+ example_agent_message = {"role": "assistant", "content": example.text}
133
+ output_messages.append(example_user_message)
134
+ output_messages.append(example_agent_message)
135
+
136
+ # user message
137
+ user_message = {
138
+ "role": "user",
139
+ "content": [
140
+ {
141
+ "type": "image_url",
142
+ "image_url": {
143
+ "url": f"data:image/{format};base64,{base64_str}",
144
+ "detail": detail
824
145
  },
825
- {"type": "text", "text": user_prompt},
826
- ],
827
- },
828
- ]
829
-
830
-
831
- class VLLMVLMEngine(OpenAICompatibleVLMEngine):
832
- def __init__(self, model:str, api_key:str="", base_url:str="http://localhost:8000/v1", config:VLMConfig=None, **kwrs):
833
- """
834
- vLLM OpenAI compatible server inference engine.
835
- https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html
836
-
837
- For parameters and documentation, refer to https://platform.openai.com/docs/api-reference/introduction
838
-
839
- Parameters:
840
- ----------
841
- model_name : str
842
- model name as shown in the vLLM server
843
- api_key : str, Optional
844
- the API key for the vLLM server.
845
- base_url : str, Optional
846
- the base url for the vLLM server.
847
- config : LLMConfig
848
- the LLM configuration.
849
- """
850
- super().__init__(model, api_key, base_url, config, **kwrs)
851
-
852
-
853
- def _format_response(self, response: Any) -> Dict[str, str]:
854
- """
855
- This method formats the response from the OpenAI API into a dict with keys "type" and "data".
856
-
857
- Parameters:
858
- ----------
859
- response : Any
860
- the response from OpenAI-compatible API. Could be a dict, generator, or object.
861
- """
862
- if isinstance(response, self.ChatCompletionChunk):
863
- if hasattr(response.choices[0].delta, "reasoning_content") and getattr(response.choices[0].delta, "reasoning_content") is not None:
864
- chunk_text = getattr(response.choices[0].delta, "reasoning_content", "")
865
- if chunk_text is None:
866
- chunk_text = ""
867
- return {"type": "reasoning", "data": chunk_text}
868
- else:
869
- chunk_text = getattr(response.choices[0].delta, "content", "")
870
- if chunk_text is None:
871
- chunk_text = ""
872
- return {"type": "response", "data": chunk_text}
873
-
874
- return {"reasoning": getattr(response.choices[0].message, "reasoning_content", ""),
875
- "response": getattr(response.choices[0].message, "content", "")}
876
-
877
-
878
- class OpenRouterVLMEngine(OpenAICompatibleVLMEngine):
879
- def __init__(self, model:str, api_key:str=None, base_url:str="https://openrouter.ai/api/v1", config:VLMConfig=None, **kwrs):
880
- """
881
- OpenRouter OpenAI-compatible server inference engine.
882
-
883
- Parameters:
884
- ----------
885
- model_name : str
886
- model name as shown in the vLLM server
887
- api_key : str, Optional
888
- the API key for the vLLM server. If None, will use the key in os.environ['OPENROUTER_API_KEY'].
889
- base_url : str, Optional
890
- the base url for the vLLM server.
891
- config : LLMConfig
892
- the LLM configuration.
893
- """
894
- self.api_key = api_key
895
- if self.api_key is None:
896
- self.api_key = os.getenv("OPENROUTER_API_KEY")
897
- super().__init__(model, self.api_key, base_url, config, **kwrs)
898
-
899
- def _format_response(self, response: Any) -> Dict[str, str]:
900
- """
901
- This method formats the response from the OpenAI API into a dict with keys "type" and "data".
902
-
903
- Parameters:
904
- ----------
905
- response : Any
906
- the response from OpenAI-compatible API. Could be a dict, generator, or object.
907
- """
908
- if isinstance(response, self.ChatCompletionChunk):
909
- if hasattr(response.choices[0].delta, "reasoning") and getattr(response.choices[0].delta, "reasoning") is not None:
910
- chunk_text = getattr(response.choices[0].delta, "reasoning", "")
911
- if chunk_text is None:
912
- chunk_text = ""
913
- return {"type": "reasoning", "data": chunk_text}
914
- else:
915
- chunk_text = getattr(response.choices[0].delta, "content", "")
916
- if chunk_text is None:
917
- chunk_text = ""
918
- return {"type": "response", "data": chunk_text}
919
-
920
- return {"reasoning": getattr(response.choices[0].message, "reasoning", ""),
921
- "response": getattr(response.choices[0].message, "content", "")}
922
-
923
-
924
- class OpenAIVLMEngine(VLMEngine):
925
- def __init__(self, model:str, config:VLMConfig=None, **kwrs):
926
- """
927
- The OpenAI API inference engine. Supports OpenAI models and OpenAI compatible servers:
928
- - vLLM OpenAI compatible server (https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html)
929
-
930
- For parameters and documentation, refer to https://platform.openai.com/docs/api-reference/introduction
931
-
932
- Parameters:
933
- ----------
934
- model_name : str
935
- model name as described in https://platform.openai.com/docs/models
936
- config : VLMConfig, Optional
937
- the VLM configuration. Must be a child class of VLMConfig.
938
- """
939
- if importlib.util.find_spec("openai") is None:
940
- raise ImportError("OpenAI Python API library not found. Please install OpenAI (```pip install openai```).")
941
-
942
- from openai import OpenAI, AsyncOpenAI
943
- self.client = OpenAI(**kwrs)
944
- self.async_client = AsyncOpenAI(**kwrs)
945
- self.model = model
946
- self.config = config if config else BasicVLMConfig()
947
- self.formatted_params = self._format_config()
948
-
949
- def _format_config(self) -> Dict[str, Any]:
950
- """
951
- This method formats the LLM configuration with the correct key for the inference engine.
952
- """
953
- formatted_params = self.config.params.copy()
954
- if "max_new_tokens" in formatted_params:
955
- formatted_params["max_completion_tokens"] = formatted_params["max_new_tokens"]
956
- formatted_params.pop("max_new_tokens")
957
-
958
- return formatted_params
959
-
960
- def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False, messages_logger:MessagesLogger=None) -> Union[Dict[str, str], Generator[Dict[str, str], None, None]]:
961
- """
962
- This method inputs chat messages and outputs LLM generated text.
963
-
964
- Parameters:
965
- ----------
966
- messages : List[Dict[str,str]]
967
- a list of dict with role and content. role must be one of {"system", "user", "assistant"}
968
- verbose : bool, Optional
969
- if True, VLM generated text will be printed in terminal in real-time.
970
- stream : bool, Optional
971
- if True, returns a generator that yields the output in real-time.
972
- messages_logger : MessagesLogger, Optional
973
- the message logger that logs the chat messages.
974
-
975
- Returns:
976
- -------
977
- response : Union[Dict[str,str], Generator[Dict[str, str], None, None]]
978
- a dict {"reasoning": <reasoning>, "response": <response>} or Generator {"type": <reasoning or response>, "data": <content>}
979
- """
980
- processed_messages = self.config.preprocess_messages(messages)
981
-
982
- if stream:
983
- def _stream_generator():
984
- response_stream = self.client.chat.completions.create(
985
- model=self.model,
986
- messages=processed_messages,
987
- stream=True,
988
- **self.formatted_params
989
- )
990
- res_text = ""
991
- for chunk in response_stream:
992
- if len(chunk.choices) > 0:
993
- chunk_text = chunk.choices[0].delta.content
994
- if chunk_text is not None:
995
- res_text += chunk_text
996
- yield chunk_text
997
- if chunk.choices[0].finish_reason == "length":
998
- warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
999
-
1000
- # Postprocess response
1001
- res_dict = self.config.postprocess_response(res_text)
1002
- # Write to messages log
1003
- if messages_logger:
1004
- # replace images content with a placeholder "[image]" to save space
1005
- for messages in processed_messages:
1006
- if "content" in messages and isinstance(messages["content"], list):
1007
- for content in messages["content"]:
1008
- if isinstance(content, dict) and content.get("type") == "image_url":
1009
- content["image_url"]["url"] = "[image]"
1010
-
1011
- processed_messages.append({"role": "assistant",
1012
- "content": res_dict.get("response", ""),
1013
- "reasoning": res_dict.get("reasoning", "")})
1014
- messages_logger.log_messages(processed_messages)
1015
-
1016
- return self.config.postprocess_response(_stream_generator())
1017
-
1018
- elif verbose:
1019
- response = self.client.chat.completions.create(
1020
- model=self.model,
1021
- messages=processed_messages,
1022
- stream=True,
1023
- **self.formatted_params
1024
- )
1025
- res = ''
1026
- for chunk in response:
1027
- if len(chunk.choices) > 0:
1028
- if chunk.choices[0].delta.content is not None:
1029
- res += chunk.choices[0].delta.content
1030
- print(chunk.choices[0].delta.content, end="", flush=True)
1031
- if chunk.choices[0].finish_reason == "length":
1032
- warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
1033
-
1034
- print('\n')
1035
-
1036
- else:
1037
- response = self.client.chat.completions.create(
1038
- model=self.model,
1039
- messages=processed_messages,
1040
- stream=False,
1041
- **self.formatted_params
1042
- )
1043
- res = response.choices[0].message.content
1044
-
1045
- # Postprocess response
1046
- res_dict = self.config.postprocess_response(res)
1047
- # Write to messages log
1048
- if messages_logger:
1049
- # replace images content with a placeholder "[image]" to save space
1050
- for messages in processed_messages:
1051
- if "content" in messages and isinstance(messages["content"], list):
1052
- for content in messages["content"]:
1053
- if isinstance(content, dict) and content.get("type") == "image_url":
1054
- content["image_url"]["url"] = "[image]"
1055
-
1056
- processed_messages.append({"role": "assistant",
1057
- "content": res_dict.get("response", ""),
1058
- "reasoning": res_dict.get("reasoning", "")})
1059
- messages_logger.log_messages(processed_messages)
1060
-
1061
- return res_dict
1062
-
1063
-
1064
- async def chat_async(self, messages:List[Dict[str,str]], messages_logger:MessagesLogger=None) -> Dict[str,str]:
1065
- """
1066
- Async version of chat method. Streaming is not supported.
1067
- """
1068
- processed_messages = self.config.preprocess_messages(messages)
1069
-
1070
- response = await self.async_client.chat.completions.create(
1071
- model=self.model,
1072
- messages=processed_messages,
1073
- stream=False,
1074
- **self.formatted_params
1075
- )
1076
-
1077
- if response.choices[0].finish_reason == "length":
1078
- warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
1079
-
1080
- res = response.choices[0].message.content
1081
- # Postprocess response
1082
- res_dict = self.config.postprocess_response(res)
1083
- # Write to messages log
1084
- if messages_logger:
1085
- # replace images content with a placeholder "[image]" to save space
1086
- for messages in processed_messages:
1087
- if "content" in messages and isinstance(messages["content"], list):
1088
- for content in messages["content"]:
1089
- if isinstance(content, dict) and content.get("type") == "image_url":
1090
- content["image_url"]["url"] = "[image]"
1091
-
1092
- processed_messages.append({"role": "assistant",
1093
- "content": res_dict.get("response", ""),
1094
- "reasoning": res_dict.get("reasoning", "")})
1095
- messages_logger.log_messages(processed_messages)
1096
-
1097
- return res_dict
1098
-
1099
- def get_ocr_messages(self, system_prompt:str, user_prompt:str, image:Image.Image, format:str='png', detail:str="high") -> List[Dict[str,str]]:
146
+ },
147
+ {"type": "text", "text": user_prompt},
148
+ ],
149
+ }
150
+ output_messages.append(user_message)
151
+ return output_messages
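The OpenAI-compatible variant wraps each image in a data URL and forwards a detail hint; a short sketch of how the extra keyword arguments shape the payload (the engine construction assumes the model/api_key/base_url parameters of the removed 0.3.1 __init__ survive on the inherited constructor):

```python
from PIL import Image
from vlm4ocr.vlm_engines import OpenAICompatibleVLMEngine

# Assumed constructor arguments, mirroring the removed 0.3.1 __init__.
engine = OpenAICompatibleVLMEngine(model="my-vlm",
                                   api_key="EMPTY",
                                   base_url="http://localhost:8000/v1")

messages = engine.get_ocr_messages(
    system_prompt="You are an OCR engine.",
    user_prompt="Transcribe this page.",
    image=Image.open("page.png"),
    format="jpeg",   # becomes "url": "data:image/jpeg;base64,..."
    detail="low",    # forwarded as image_url["detail"]
)
```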
152
+
153
+
154
+ class VLLMVLMEngine(VLLMInferenceEngine, OpenAICompatibleVLMEngine):
155
+ """
156
+ vLLM OpenAI compatible server inference engine.
157
+ https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html
158
+
159
+ For parameters and documentation, refer to https://platform.openai.com/docs/api-reference/introduction
160
+
161
+ Parameters:
162
+ ----------
163
+ model_name : str
164
+ model name as shown in the vLLM server
165
+ api_key : str, Optional
166
+ the API key for the vLLM server.
167
+ base_url : str, Optional
168
+ the base url for the vLLM server.
169
+ config : LLMConfig
170
+ the LLM configuration.
171
+ """
172
+ pass
173
+
174
+ class OpenRouterVLMEngine(OpenRouterInferenceEngine, OpenAICompatibleVLMEngine):
175
+ """
176
+ OpenRouter OpenAI-compatible server inference engine.
177
+
178
+ Parameters:
179
+ ----------
180
+ model_name : str
181
+ model name as shown on OpenRouter
182
+ api_key : str, Optional
183
+ the API key for OpenRouter. If None, will use the key in os.environ['OPENROUTER_API_KEY'].
184
+ base_url : str, Optional
185
+ the base url for the OpenRouter API.
186
+ config : LLMConfig
187
+ the LLM configuration.
188
+ """
189
+ pass
190
+
191
+ class OpenAIVLMEngine(OpenAIInferenceEngine, VLMEngine):
192
+ def get_ocr_messages(self, system_prompt:str, user_prompt:str, image:Image.Image, format:str='png',
193
+ detail:str="high", few_shot_examples:List[FewShotExample]=None) -> List[Dict[str,str]]:
1100
194
  """
1101
195
  This method inputs an image and returns the corresponding chat messages for the inference engine.
1102
196
 
@@ -1112,52 +206,71 @@ class OpenAIVLMEngine(VLMEngine):
1112
206
  the image format.
1113
207
  detail : str, Optional
1114
208
  the detail level of the image. Default is "high".
209
+ few_shot_examples : List[FewShotExample], Optional
210
+ list of few-shot examples. Each example is a FewShotExample with an image (PIL.Image.Image) and text (str).
1115
211
  """
1116
212
  base64_str = image_to_base64(image)
1117
- return [
1118
- {"role": "system", "content": system_prompt},
1119
- {
1120
- "role": "user",
1121
- "content": [
1122
- {
1123
- "type": "image_url",
1124
- "image_url": {
1125
- "url": f"data:image/{format};base64,{base64_str}",
1126
- "detail": detail
213
+ output_messages = []
214
+ # system message
215
+ system_message = {"role": "system", "content": system_prompt}
216
+ output_messages.append(system_message)
217
+
218
+ # few-shot examples
219
+ if few_shot_examples is not None:
220
+ for example in few_shot_examples:
221
+ if not isinstance(example, FewShotExample):
222
+ raise ValueError("Few-shot example must be a FewShotExample object.")
223
+
224
+ example_image_b64 = image_to_base64(example.image)
225
+ example_user_message = {
226
+ "role": "user",
227
+ "content": [
228
+ {
229
+ "type": "image_url",
230
+ "image_url": {
231
+ "url": f"data:image/{format};base64,{example_image_b64}",
232
+ "detail": detail
233
+ },
1127
234
  },
235
+ {"type": "text", "text": user_prompt},
236
+ ],
237
+ }
238
+ example_agent_message = {"role": "assistant", "content": example.text}
239
+ output_messages.append(example_user_message)
240
+ output_messages.append(example_agent_message)
241
+
242
+ # user message
243
+ user_message = {
244
+ "role": "user",
245
+ "content": [
246
+ {
247
+ "type": "image_url",
248
+ "image_url": {
249
+ "url": f"data:image/{format};base64,{base64_str}",
250
+ "detail": detail
1128
251
  },
1129
- {"type": "text", "text": user_prompt},
1130
- ],
1131
- },
1132
- ]
1133
-
1134
-
1135
- class AzureOpenAIVLMEngine(OpenAIVLMEngine):
1136
- def __init__(self, model:str, api_version:str, config:VLMConfig=None, **kwrs):
1137
- """
1138
- The Azure OpenAI API inference engine.
1139
- For parameters and documentation, refer to
1140
- - https://azure.microsoft.com/en-us/products/ai-services/openai-service
1141
- - https://learn.microsoft.com/en-us/azure/ai-services/openai/quickstart
1142
-
1143
- Parameters:
1144
- ----------
1145
- model : str
1146
- model name as described in https://platform.openai.com/docs/models
1147
- api_version : str
1148
- the Azure OpenAI API version
1149
- config : LLMConfig
1150
- the LLM configuration.
1151
- """
1152
- if importlib.util.find_spec("openai") is None:
1153
- raise ImportError("OpenAI Python API library not found. Please install OpenAI (```pip install openai```).")
1154
-
1155
- from openai import AzureOpenAI, AsyncAzureOpenAI
1156
- self.model = model
1157
- self.api_version = api_version
1158
- self.client = AzureOpenAI(api_version=self.api_version,
1159
- **kwrs)
1160
- self.async_client = AsyncAzureOpenAI(api_version=self.api_version,
1161
- **kwrs)
1162
- self.config = config if config else BasicVLMConfig()
1163
- self.formatted_params = self._format_config()
252
+ },
253
+ {"type": "text", "text": user_prompt},
254
+ ],
255
+ }
256
+ output_messages.append(user_message)
257
+ return output_messages
258
+
259
+
260
+ class AzureOpenAIVLMEngine(AzureOpenAIInferenceEngine, OpenAIVLMEngine):
261
+ """
262
+ The Azure OpenAI API inference engine.
263
+ For parameters and documentation, refer to
264
+ - https://azure.microsoft.com/en-us/products/ai-services/openai-service
265
+ - https://learn.microsoft.com/en-us/azure/ai-services/openai/quickstart
266
+
267
+ Parameters:
268
+ ----------
269
+ model : str
270
+ model name as described in https://platform.openai.com/docs/models
271
+ api_version : str
272
+ the Azure OpenAI API version
273
+ config : LLMConfig
274
+ the LLM configuration.
275
+ """
276
+ pass
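Putting the pieces together, a sketch of the 0.4.1 OCR message flow against the OpenAI backend. The engine, config, and FewShotExample constructor arguments are assumptions (those constructors now come from llm_inference_engine and vlm4ocr.data_types and are not shown in this diff); the get_ocr_messages call matches the signature above:

```python
from PIL import Image
from vlm4ocr.data_types import FewShotExample
from vlm4ocr.vlm_engines import BasicVLMConfig, OpenAIVLMEngine

# Assumed constructor arguments, mirroring the removed 0.3.1 OpenAIVLMEngine;
# the model name is illustrative.
engine = OpenAIVLMEngine(model="gpt-4o", config=BasicVLMConfig(temperature=0.0))

# Assumed FewShotExample(image=..., text=...); the diff shows .image/.text.
example = FewShotExample(image=Image.open("example_page.png"),
                         text="Ground-truth OCR text for the example page")

messages = engine.get_ocr_messages(
    system_prompt="You are an OCR engine that returns faithful text.",
    user_prompt="Transcribe this page.",
    image=Image.open("page.png"),
    detail="high",
    few_shot_examples=[example],
)
# chat() is inherited from the llm_inference_engine base; the 0.3.1 engines
# returned {"reasoning": ..., "response": ...} and this is assumed unchanged.
result = engine.chat(messages)
```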