vlm4ocr 0.3.0__py3-none-any.whl → 0.4.0__py3-none-any.whl

vlm4ocr/vlm_engines.py CHANGED
@@ -2,8 +2,11 @@ import abc
 import importlib.util
 from typing import Any, List, Dict, Union, Generator
 import warnings
+import os
+import re
 from PIL import Image
 from vlm4ocr.utils import image_to_base64
+from vlm4ocr.data_types import FewShotExample
 
 
 class VLMConfig(abc.ABC):
@@ -33,7 +36,7 @@ class VLMConfig(abc.ABC):
         return NotImplemented
 
     @abc.abstractmethod
-    def postprocess_response(self, response:Union[str, Generator[str, None, None]]) -> Union[str, Generator[str, None, None]]:
+    def postprocess_response(self, response:Union[str, Dict[str, str], Generator[str, None, None]]) -> Union[str, Generator[str, None, None]]:
         """
         This method postprocesses the VLM response after it is generated.
 
@@ -77,7 +80,7 @@ class BasicVLMConfig(VLMConfig):
         """
         return messages
 
-    def postprocess_response(self, response:Union[str, Generator[str, None, None]]) -> Union[str, Generator[str, None, None]]:
+    def postprocess_response(self, response:Union[str, Dict[str, str], Generator[str, None, None]]) -> Union[Dict[str, str], Generator[Dict[str, str], None, None]]:
         """
         This method postprocesses the VLM response after it is generated.
 
@@ -88,19 +91,121 @@ class BasicVLMConfig(VLMConfig):
 
         Returns: Union[str, Generator[Dict[str, str], None, None]]
             the postprocessed VLM response.
-            if input is a generator, the output will be a generator {"data": <content>}.
+            if input is a generator, the output will be a generator {"type": "response", "data": <content>}.
         """
         if isinstance(response, str):
-            return response
+            return {"response": response}
+
+        elif isinstance(response, dict):
+            if "response" in response:
+                return response
+            else:
+                warnings.warn(f"Invalid response dict keys: {response.keys()}. Returning default empty dict.", UserWarning)
+                return {"response": ""}
 
         def _process_stream():
             for chunk in response:
-                yield chunk
+                if isinstance(chunk, dict):
+                    yield chunk
+                elif isinstance(chunk, str):
+                    yield {"type": "response", "data": chunk}
 
         return _process_stream()
+
+
+class ReasoningVLMConfig(VLMConfig):
+    def __init__(self, thinking_token_start="<think>", thinking_token_end="</think>", **kwargs):
+        """
+        The general configuration for reasoning vision models.
+        """
+        super().__init__(**kwargs)
+        self.thinking_token_start = thinking_token_start
+        self.thinking_token_end = thinking_token_end
+
+    def preprocess_messages(self, messages:List[Dict[str,str]]) -> List[Dict[str,str]]:
+        """
+        This method preprocesses the input messages before passing them to the VLM.
+
+        Parameters:
+        ----------
+        messages : List[Dict[str,str]]
+            a list of dict with role and content. role must be one of {"system", "user", "assistant"}
+
+        Returns:
+        -------
+        messages : List[Dict[str,str]]
+            a list of dict with role and content. role must be one of {"system", "user", "assistant"}
+        """
+        return messages.copy()
 
+    def postprocess_response(self, response:Union[str, Dict[str, str], Generator[str, None, None]]) -> Union[Dict[str,str], Generator[Dict[str,str], None, None]]:
+        """
+        This method postprocesses the VLM response after it is generated.
+        1. If input is a string, it will extract the reasoning and response based on the thinking tokens.
+        2. If input is a dict, it should contain keys "reasoning" and "response". This is for inference engines that already parse reasoning and response.
+        3. If input is a generator,
+            a. if the chunk is a dict, it should contain keys "type" and "data". This is for inference engines that already parse reasoning and response.
+            b. if the chunk is a string, it will yield dicts with keys "type" and "data" based on the thinking tokens.
 
-class OpenAIReasoningVLMConfig(VLMConfig):
+        Parameters:
+        ----------
+        response : Union[str, Generator[str, None, None]]
+            the VLM response. Can be a string or a generator.
+
+        Returns:
+        -------
+        response : Union[str, Generator[str, None, None]]
+            the postprocessed LLM response as a dict {"reasoning": <reasoning>, "response": <content>}
+            if input is a generator, the output will be a generator {"type": <reasoning or response>, "data": <content>}.
+        """
+        if isinstance(response, str):
+            # get contents between thinking_token_start and thinking_token_end
+            pattern = f"{re.escape(self.thinking_token_start)}(.*?){re.escape(self.thinking_token_end)}"
+            match = re.search(pattern, response, re.DOTALL)
+            reasoning = match.group(1) if match else ""
+            # get response AFTER thinking_token_end
+            response = re.sub(f".*?{self.thinking_token_end}", "", response, flags=re.DOTALL).strip()
+            return {"reasoning": reasoning, "response": response}
+
+        elif isinstance(response, dict):
+            if "reasoning" in response and "response" in response:
+                return response
+            else:
+                warnings.warn(f"Invalid response dict keys: {response.keys()}. Returning default empty dict.", UserWarning)
+                return {"reasoning": "", "response": ""}
+
+        elif isinstance(response, Generator):
+            def _process_stream():
+                think_flag = False
+                buffer = ""
+                for chunk in response:
+                    if isinstance(chunk, dict):
+                        yield chunk
+
+                    elif isinstance(chunk, str):
+                        buffer += chunk
+                        # switch between reasoning and response
+                        if self.thinking_token_start in buffer:
+                            think_flag = True
+                            buffer = buffer.replace(self.thinking_token_start, "")
+                        elif self.thinking_token_end in buffer:
+                            think_flag = False
+                            buffer = buffer.replace(self.thinking_token_end, "")
+
+                        # if chunk is in thinking block, tag it as reasoning; else tag it as response
+                        if chunk not in [self.thinking_token_start, self.thinking_token_end]:
+                            if think_flag:
+                                yield {"type": "reasoning", "data": chunk}
+                            else:
+                                yield {"type": "response", "data": chunk}
+
+            return _process_stream()
+
+        else:
+            warnings.warn(f"Invalid response type: {type(response)}. Returning default empty dict.", UserWarning)
+            return {"reasoning": "", "response": ""}
+
+
+class OpenAIReasoningVLMConfig(ReasoningVLMConfig):
     def __init__(self, reasoning_effort:str="low", **kwargs):
         """
         The OpenAI "o" series configuration.
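To make the new parsing concrete, here is a minimal sketch (not part of the diff) of how `ReasoningVLMConfig.postprocess_response` splits a completion around the default `<think>` tokens; the input string is invented for illustration:

```python
from vlm4ocr.vlm_engines import ReasoningVLMConfig

config = ReasoningVLMConfig()  # defaults: thinking_token_start="<think>", thinking_token_end="</think>"

# Hypothetical raw completion from a reasoning VLM
raw = "<think>The header says INVOICE, so transcribe it first.</think>INVOICE #42"
result = config.postprocess_response(raw)
# result == {"reasoning": "The header says INVOICE, so transcribe it first.",
#            "response": "INVOICE #42"}
```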
@@ -160,27 +265,31 @@ class OpenAIReasoningVLMConfig(VLMConfig):
 
         return new_messages
 
-    def postprocess_response(self, response:Union[str, Generator[str, None, None]]) -> Union[str, Generator[Dict[str, str], None, None]]:
-        """
-        This method postprocesses the VLM response after it is generated.
 
-        Parameters:
-        ----------
-        response : Union[str, Generator[str, None, None]]
-            the VLM response. Can be a string or a generator.
-
-        Returns: Union[str, Generator[Dict[str, str], None, None]]
-            the postprocessed VLM response.
-            if input is a generator, the output will be a generator {"type": "response", "data": <content>}.
+class MessagesLogger:
+    def __init__(self):
         """
-        if isinstance(response, str):
-            return response
+        This class is used to log the messages for InferenceEngine.chat().
+        """
+        self.messages_log = []
 
-        def _process_stream():
-            for chunk in response:
-                yield {"type": "response", "data": chunk}
+    def log_messages(self, messages : List[Dict[str,str]]):
+        """
+        This method logs the messages to a list.
+        """
+        self.messages_log.append(messages)
 
-        return _process_stream()
+    def get_messages_log(self) -> List[List[Dict[str,str]]]:
+        """
+        This method returns a copy of the current messages log
+        """
+        return self.messages_log.copy()
+
+    def clear_messages_log(self):
+        """
+        This method clears the current messages log
+        """
+        self.messages_log.clear()
 
 
 class VLMEngine:
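A quick usage sketch for the new logger (illustrative only; `engine` stands for any `VLMEngine` subclass that accepts `messages_logger`):

```python
from vlm4ocr.vlm_engines import MessagesLogger

logger = MessagesLogger()
# Engines append one fully-resolved conversation per call when a logger is passed,
# e.g. engine.chat(messages, messages_logger=logger)
logger.log_messages([{"role": "user", "content": "hello"},
                     {"role": "assistant", "content": "hi", "reasoning": ""}])
assert len(logger.get_messages_log()) == 1

logger.clear_messages_log()  # start fresh for the next document
```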
@@ -198,7 +307,8 @@ class VLMEngine:
         return NotImplemented
 
     @abc.abstractmethod
-    def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False) -> Union[str, Generator[str, None, None]]:
+    def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False,
+             messages_logger:MessagesLogger=None) -> Union[Dict[str, str], Generator[Dict[str, str], None, None]]:
         """
         This method inputs chat messages and outputs VLM generated text.
 
@@ -210,18 +320,20 @@ class VLMEngine:
             if True, VLM generated text will be printed in terminal in real-time.
         stream : bool, Optional
             if True, returns a generator that yields the output in real-time.
+        messages_logger : MessagesLogger, Optional
+            the message logger that logs the chat messages.
         """
         return NotImplemented
 
     @abc.abstractmethod
-    def chat_async(self, messages:List[Dict[str,str]]) -> str:
+    def chat_async(self, messages:List[Dict[str,str]], messages_logger:MessagesLogger=None) -> Dict[str, str]:
         """
         The async version of chat method. Streaming is not supported.
         """
         return NotImplemented
 
     @abc.abstractmethod
-    def get_ocr_messages(self, system_prompt:str, user_prompt:str, image:Image.Image) -> List[Dict[str,str]]:
+    def get_ocr_messages(self, system_prompt:str, user_prompt:str, image:Image.Image, few_shot_examples:List[FewShotExample]=None) -> List[Dict[str,str]]:
         """
         This method inputs an image and returns the corresponding chat messages for the inference engine.
 
@@ -233,6 +345,8 @@ class VLMEngine:
             the user prompt.
         image : Image.Image
             the image for OCR.
+        few_shot_examples : List[FewShotExample], Optional
+            list of few-shot examples.
         """
         return NotImplemented
 
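The few-shot hook can be exercised as in the sketch below; `engine` stands for any concrete `VLMEngine`, the file names are invented, and `FewShotExample(image=..., text=...)` is assumed to mirror the `.image`/`.text` attributes the engines read:

```python
from PIL import Image
from vlm4ocr.data_types import FewShotExample

# One worked example: a page image plus its gold transcription
example = FewShotExample(image=Image.open("example_page.png"),
                         text="INVOICE #41\nTotal: $12.00")

messages = engine.get_ocr_messages(system_prompt="You are an OCR engine.",
                                   user_prompt="Transcribe this page.",
                                   image=Image.open("page.png"),
                                   few_shot_examples=[example])
# Layout: system prompt, then one user/assistant pair per example, then the target page.
```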
@@ -285,7 +399,8 @@ class OllamaVLMEngine(VLMEngine):
 
         return formatted_params
 
-    def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False) -> Union[str, Generator[str, None, None]]:
+    def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False,
+             messages_logger:MessagesLogger=None) -> Union[Dict[str,str], Generator[Dict[str, str], None, None]]:
         """
         This method inputs chat messages and outputs VLM generated text.
 
@@ -297,6 +412,13 @@ class OllamaVLMEngine(VLMEngine):
             if True, VLM generated text will be printed in terminal in real-time.
         stream : bool, Optional
             if True, returns a generator that yields the output in real-time.
+        messages_logger : MessagesLogger, Optional
+            the message logger that logs the chat messages.
+
+        Returns:
+        -------
+        response : Union[Dict[str,str], Generator[Dict[str, str], None, None]]
+            a dict {"reasoning": <reasoning>, "response": <response>} or Generator {"type": <reasoning or response>, "data": <content>}
         """
         processed_messages = self.config.preprocess_messages(messages)
 
@@ -310,10 +432,33 @@ class OllamaVLMEngine(VLMEngine):
                     stream=True,
                     keep_alive=self.keep_alive
                 )
+                res = {"reasoning": "", "response": ""}
                 for chunk in response_stream:
-                    content_chunk = chunk.get('message', {}).get('content')
-                    if content_chunk:
-                        yield content_chunk
+                    if hasattr(chunk.message, 'thinking') and chunk.message.thinking:
+                        content_chunk = getattr(getattr(chunk, 'message', {}), 'thinking', '')
+                        res["reasoning"] += content_chunk
+                        yield {"type": "reasoning", "data": content_chunk}
+                    else:
+                        content_chunk = getattr(getattr(chunk, 'message', {}), 'content', '')
+                        res["response"] += content_chunk
+                        yield {"type": "response", "data": content_chunk}
+
+                    if chunk.done_reason == "length":
+                        warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
+
+                # Postprocess response
+                res_dict = self.config.postprocess_response(res)
+                # Write to messages log
+                if messages_logger:
+                    # replace images content with a placeholder "[image]" to save space
+                    for messages in processed_messages:
+                        if "images" in messages:
+                            messages["images"] = ["[image]" for _ in messages["images"]]
+
+                    processed_messages.append({"role": "assistant",
+                                               "content": res_dict.get("response", ""),
+                                               "reasoning": res_dict.get("reasoning", "")})
+                    messages_logger.log_messages(processed_messages)
 
             return self.config.postprocess_response(_stream_generator())
 
@@ -326,14 +471,29 @@ class OllamaVLMEngine(VLMEngine):
                 keep_alive=self.keep_alive
             )
 
-            res = ''
+            res = {"reasoning": "", "response": ""}
+            phase = ""
             for chunk in response:
-                content_chunk = chunk.get('message', {}).get('content')
+                if hasattr(chunk.message, 'thinking') and chunk.message.thinking:
+                    if phase != "reasoning":
+                        print("\n--- Reasoning ---")
+                        phase = "reasoning"
+
+                    content_chunk = getattr(getattr(chunk, 'message', {}), 'thinking', '')
+                    res["reasoning"] += content_chunk
+                else:
+                    if phase != "response":
+                        print("\n--- Response ---")
+                        phase = "response"
+                    content_chunk = getattr(getattr(chunk, 'message', {}), 'content', '')
+                    res["response"] += content_chunk
+
                 print(content_chunk, end='', flush=True)
-                res += content_chunk
+
+                if chunk.done_reason == "length":
+                    warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
             print('\n')
-            return self.config.postprocess_response(res)
-
+
         else:
             response = self.client.chat(
                 model=self.model_name,
@@ -342,11 +502,30 @@ class OllamaVLMEngine(VLMEngine):
                 stream=False,
                 keep_alive=self.keep_alive
             )
-            res = response.get('message', {}).get('content')
-            return self.config.postprocess_response(res)
+            res = {"reasoning": getattr(getattr(response, 'message', {}), 'thinking', ''),
+                   "response": getattr(getattr(response, 'message', {}), 'content', '')}
+
+            if response.done_reason == "length":
+                warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
+
+            # Postprocess response
+            res_dict = self.config.postprocess_response(res)
+            # Write to messages log
+            if messages_logger:
+                # replace images content with a placeholder "[image]" to save space
+                for messages in processed_messages:
+                    if "images" in messages:
+                        messages["images"] = ["[image]" for _ in messages["images"]]
+
+                processed_messages.append({"role": "assistant",
+                                           "content": res_dict.get("response", ""),
+                                           "reasoning": res_dict.get("reasoning", "")})
+                messages_logger.log_messages(processed_messages)
+
+            return res_dict
 
 
-    async def chat_async(self, messages:List[Dict[str,str]]) -> str:
+    async def chat_async(self, messages:List[Dict[str,str]], messages_logger:MessagesLogger=None) -> Dict[str,str]:
         """
         Async version of chat method. Streaming is not supported.
         """
@@ -360,10 +539,28 @@ class OllamaVLMEngine(VLMEngine):
             keep_alive=self.keep_alive
         )
 
-        res = response['message']['content']
-        return self.config.postprocess_response(res)
+        res = {"reasoning": getattr(getattr(response, 'message', {}), 'thinking', ''),
+               "response": getattr(getattr(response, 'message', {}), 'content', '')}
+
+        if response.done_reason == "length":
+            warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
+        # Postprocess response
+        res_dict = self.config.postprocess_response(res)
+        # Write to messages log
+        if messages_logger:
+            # replace images content with a placeholder "[image]" to save space
+            for messages in processed_messages:
+                if "images" in messages:
+                    messages["images"] = ["[image]" for _ in messages["images"]]
+
+            processed_messages.append({"role": "assistant",
+                                       "content": res_dict.get("response", ""),
+                                       "reasoning": res_dict.get("reasoning", "")})
+            messages_logger.log_messages(processed_messages)
+
+        return res_dict
 
-    def get_ocr_messages(self, system_prompt:str, user_prompt:str, image:Image.Image) -> List[Dict[str,str]]:
+    def get_ocr_messages(self, system_prompt:str, user_prompt:str, image:Image.Image, few_shot_examples:List[FewShotExample]=None) -> List[Dict[str,str]]:
         """
         This method inputs an image and returns the corresponding chat messages for the inference engine.
 
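Callers of the Ollama engine now receive structured output instead of a bare string; a small consumption sketch (the engine and messages are assumed constructed earlier):

```python
# Non-streaming: chat() now returns a dict rather than a plain string
result = ollama_engine.chat(messages)  # {"reasoning": "...", "response": "..."}
ocr_text = result["response"]

# Streaming: chunks are tagged, so reasoning can be routed away from the transcript
for chunk in ollama_engine.chat(messages, stream=True):
    if chunk["type"] == "response":
        print(chunk["data"], end="", flush=True)
```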
@@ -375,16 +572,404 @@ class OllamaVLMEngine(VLMEngine):
             the user prompt.
         image : Image.Image
             the image for OCR.
+        few_shot_examples : List[FewShotExample], Optional
+            list of few-shot examples.
         """
         base64_str = image_to_base64(image)
-        return [
-            {"role": "system", "content": system_prompt},
-            {
-                "role": "user",
-                "content": user_prompt,
-                "images": [base64_str]
-            }
-        ]
+        output_messages = []
+        # system message
+        system_message = {"role": "system", "content": system_prompt}
+        output_messages.append(system_message)
+
+        # few-shot examples
+        if few_shot_examples is not None:
+            for example in few_shot_examples:
+                if not isinstance(example, FewShotExample):
+                    raise ValueError("Few-shot example must be a FewShotExample object.")
+
+                example_image_b64 = image_to_base64(example.image)
+                example_user_message = {"role": "user", "content": user_prompt, "images": [example_image_b64]}
+                example_agent_message = {"role": "assistant", "content": example.text}
+                output_messages.append(example_user_message)
+                output_messages.append(example_agent_message)
+
+        # user message
+        user_message = {"role": "user", "content": user_prompt, "images": [base64_str]}
+        output_messages.append(user_message)
+
+        return output_messages
+
+
+class OpenAICompatibleVLMEngine(VLMEngine):
+    def __init__(self, model:str, api_key:str, base_url:str, config:VLMConfig=None, **kwrs):
+        """
+        General OpenAI-compatible server inference engine.
+        https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html
+
+        For parameters and documentation, refer to https://platform.openai.com/docs/api-reference/introduction
+
+        Parameters:
+        ----------
+        model : str
+            model name as shown on the server
+        api_key : str
+            the API key for the server.
+        base_url : str
+            the base url for the server.
+        config : VLMConfig
+            the VLM configuration.
+        """
+        if importlib.util.find_spec("openai") is None:
+            raise ImportError("OpenAI Python API library not found. Please install OpenAI (```pip install openai```).")
+
+        from openai import OpenAI, AsyncOpenAI
+        from openai.types.chat import ChatCompletionChunk
+        self.ChatCompletionChunk = ChatCompletionChunk
+        super().__init__(config)
+        self.client = OpenAI(api_key=api_key, base_url=base_url, **kwrs)
+        self.async_client = AsyncOpenAI(api_key=api_key, base_url=base_url, **kwrs)
+        self.model = model
+        self.config = config if config else BasicVLMConfig()
+        self.formatted_params = self._format_config()
+
+    def _format_config(self) -> Dict[str, Any]:
+        """
+        This method formats the VLM configuration with the correct key for the inference engine.
+        """
+        formatted_params = self.config.params.copy()
+        if "max_new_tokens" in formatted_params:
+            formatted_params["max_completion_tokens"] = formatted_params["max_new_tokens"]
+            formatted_params.pop("max_new_tokens")
+
+        return formatted_params
+
+
+    def _format_response(self, response: Any) -> Dict[str, str]:
+        """
+        This method formats the response from the OpenAI API to a dict with keys "type" and "data".
+
+        Parameters:
+        ----------
+        response : Any
+            the response from an OpenAI-compatible API. Could be a dict, generator, or object.
+        """
+        if isinstance(response, self.ChatCompletionChunk):
+            chunk_text = getattr(response.choices[0].delta, "content", "")
+            if chunk_text is None:
+                chunk_text = ""
+            return {"type": "response", "data": chunk_text}
+
+        return {"response": getattr(response.choices[0].message, "content", "")}
+
+    def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False,
+             messages_logger:MessagesLogger=None) -> Union[Dict[str, str], Generator[Dict[str, str], None, None]]:
+        """
+        This method inputs chat messages and outputs LLM generated text.
+
+        Parameters:
+        ----------
+        messages : List[Dict[str,str]]
+            a list of dict with role and content. role must be one of {"system", "user", "assistant"}
+        verbose : bool, Optional
+            if True, VLM generated text will be printed in terminal in real-time.
+        stream : bool, Optional
+            if True, returns a generator that yields the output in real-time.
+        messages_logger : MessagesLogger, Optional
+            the message logger that logs the chat messages.
+
+        Returns:
+        -------
+        response : Union[Dict[str,str], Generator[Dict[str, str], None, None]]
+            a dict {"reasoning": <reasoning>, "response": <response>} or Generator {"type": <reasoning or response>, "data": <content>}
+        """
+        processed_messages = self.config.preprocess_messages(messages)
+
+        if stream:
+            def _stream_generator():
+                response_stream = self.client.chat.completions.create(
+                    model=self.model,
+                    messages=processed_messages,
+                    stream=True,
+                    **self.formatted_params
+                )
+                res_text = ""
+                for chunk in response_stream:
+                    if len(chunk.choices) > 0:
+                        chunk_dict = self._format_response(chunk)
+                        yield chunk_dict
+
+                        res_text += chunk_dict["data"]
+                        if chunk.choices[0].finish_reason == "length":
+                            warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
+
+                # Postprocess response
+                res_dict = self.config.postprocess_response(res_text)
+                # Write to messages log
+                if messages_logger:
+                    # replace images content with a placeholder "[image]" to save space
+                    for messages in processed_messages:
+                        if "content" in messages and isinstance(messages["content"], list):
+                            for content in messages["content"]:
+                                if isinstance(content, dict) and content.get("type") == "image_url":
+                                    content["image_url"]["url"] = "[image]"
+
+                    processed_messages.append({"role": "assistant",
+                                               "content": res_dict.get("response", ""),
+                                               "reasoning": res_dict.get("reasoning", "")})
+                    messages_logger.log_messages(processed_messages)
+
+            return self.config.postprocess_response(_stream_generator())
+
+        elif verbose:
+            response = self.client.chat.completions.create(
+                model=self.model,
+                messages=processed_messages,
+                stream=True,
+                **self.formatted_params
+            )
+            res = {"reasoning": "", "response": ""}
+            phase = ""
+            for chunk in response:
+                if len(chunk.choices) > 0:
+                    chunk_dict = self._format_response(chunk)
+                    chunk_text = chunk_dict["data"]
+                    res[chunk_dict["type"]] += chunk_text
+                    if phase != chunk_dict["type"] and chunk_text != "":
+                        print(f"\n--- {chunk_dict['type'].capitalize()} ---")
+                        phase = chunk_dict["type"]
+
+                    print(chunk_text, end="", flush=True)
+                    if chunk.choices[0].finish_reason == "length":
+                        warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
+
+            print('\n')
+
+        else:
+            response = self.client.chat.completions.create(
+                model=self.model,
+                messages=processed_messages,
+                stream=False,
+                **self.formatted_params
+            )
+            res = self._format_response(response)
+
+            if response.choices[0].finish_reason == "length":
+                warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
+
+        # Postprocess response
+        res_dict = self.config.postprocess_response(res)
+        # Write to messages log
+        if messages_logger:
+            # replace images content with a placeholder "[image]" to save space
+            for messages in processed_messages:
+                if "content" in messages and isinstance(messages["content"], list):
+                    for content in messages["content"]:
+                        if isinstance(content, dict) and content.get("type") == "image_url":
+                            content["image_url"]["url"] = "[image]"
+
+            processed_messages.append({"role": "assistant",
+                                       "content": res_dict.get("response", ""),
+                                       "reasoning": res_dict.get("reasoning", "")})
+            messages_logger.log_messages(processed_messages)
+
+        return res_dict
+
+
+    async def chat_async(self, messages:List[Dict[str,str]], messages_logger:MessagesLogger=None) -> Dict[str,str]:
+        """
+        Async version of chat method. Streaming is not supported.
+        """
+        processed_messages = self.config.preprocess_messages(messages)
+
+        response = await self.async_client.chat.completions.create(
+            model=self.model,
+            messages=processed_messages,
+            stream=False,
+            **self.formatted_params
+        )
+
+        if response.choices[0].finish_reason == "length":
+            warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
+
+        res = self._format_response(response)
+
+        # Postprocess response
+        res_dict = self.config.postprocess_response(res)
+        # Write to messages log
+        if messages_logger:
+            # replace images content with a placeholder "[image]" to save space
+            for messages in processed_messages:
+                if "content" in messages and isinstance(messages["content"], list):
+                    for content in messages["content"]:
+                        if isinstance(content, dict) and content.get("type") == "image_url":
+                            content["image_url"]["url"] = "[image]"
+
+            processed_messages.append({"role": "assistant",
+                                       "content": res_dict.get("response", ""),
+                                       "reasoning": res_dict.get("reasoning", "")})
+            messages_logger.log_messages(processed_messages)
+
+        return res_dict
+
+    def get_ocr_messages(self, system_prompt:str, user_prompt:str, image:Image.Image, format:str='png',
+                         detail:str="high", few_shot_examples:List[FewShotExample]=None) -> List[Dict[str,str]]:
+        """
+        This method inputs an image and returns the corresponding chat messages for the inference engine.
+
+        Parameters:
+        ----------
+        system_prompt : str
+            the system prompt.
+        user_prompt : str
+            the user prompt.
+        image : Image.Image
+            the image for OCR.
+        format : str, Optional
+            the image format.
+        detail : str, Optional
+            the detail level of the image. Default is "high".
+        few_shot_examples : List[FewShotExample], Optional
+            list of few-shot examples.
+        """
+        base64_str = image_to_base64(image)
+        output_messages = []
+        # system message
+        system_message = {"role": "system", "content": system_prompt}
+        output_messages.append(system_message)
+
+        # few-shot examples
+        if few_shot_examples is not None:
+            for example in few_shot_examples:
+                if not isinstance(example, FewShotExample):
+                    raise ValueError("Few-shot example must be a FewShotExample object.")
+
+                example_image_b64 = image_to_base64(example.image)
+                example_user_message = {
+                    "role": "user",
+                    "content": [
+                        {
+                            "type": "image_url",
+                            "image_url": {
+                                "url": f"data:image/{format};base64,{example_image_b64}",
+                                "detail": detail
+                            },
+                        },
+                        {"type": "text", "text": user_prompt},
+                    ],
+                }
+                example_agent_message = {"role": "assistant", "content": example.text}
+                output_messages.append(example_user_message)
+                output_messages.append(example_agent_message)
+
+        # user message
+        user_message = {
+            "role": "user",
+            "content": [
+                {
+                    "type": "image_url",
+                    "image_url": {
+                        "url": f"data:image/{format};base64,{base64_str}",
+                        "detail": detail
+                    },
+                },
+                {"type": "text", "text": user_prompt},
+            ],
+        }
+        output_messages.append(user_message)
+        return output_messages
+
+
+class VLLMVLMEngine(OpenAICompatibleVLMEngine):
+    def __init__(self, model:str, api_key:str="", base_url:str="http://localhost:8000/v1", config:VLMConfig=None, **kwrs):
+        """
+        vLLM OpenAI compatible server inference engine.
+        https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html
+
+        For parameters and documentation, refer to https://platform.openai.com/docs/api-reference/introduction
+
+        Parameters:
+        ----------
+        model : str
+            model name as shown in the vLLM server
+        api_key : str, Optional
+            the API key for the vLLM server.
+        base_url : str, Optional
+            the base url for the vLLM server.
+        config : VLMConfig
+            the VLM configuration.
+        """
+        super().__init__(model, api_key, base_url, config, **kwrs)
+
+
+    def _format_response(self, response: Any) -> Dict[str, str]:
+        """
+        This method formats the response from the OpenAI API to a dict with keys "type" and "data".
+
+        Parameters:
+        ----------
+        response : Any
+            the response from an OpenAI-compatible API. Could be a dict, generator, or object.
+        """
+        if isinstance(response, self.ChatCompletionChunk):
+            if hasattr(response.choices[0].delta, "reasoning_content") and getattr(response.choices[0].delta, "reasoning_content") is not None:
+                chunk_text = getattr(response.choices[0].delta, "reasoning_content", "")
+                if chunk_text is None:
+                    chunk_text = ""
+                return {"type": "reasoning", "data": chunk_text}
+            else:
+                chunk_text = getattr(response.choices[0].delta, "content", "")
+                if chunk_text is None:
+                    chunk_text = ""
+                return {"type": "response", "data": chunk_text}
+
+        return {"reasoning": getattr(response.choices[0].message, "reasoning_content", ""),
+                "response": getattr(response.choices[0].message, "content", "")}
+
+
+class OpenRouterVLMEngine(OpenAICompatibleVLMEngine):
+    def __init__(self, model:str, api_key:str=None, base_url:str="https://openrouter.ai/api/v1", config:VLMConfig=None, **kwrs):
+        """
+        OpenRouter OpenAI-compatible server inference engine.
+
+        Parameters:
+        ----------
+        model : str
+            model name as shown in OpenRouter
+        api_key : str, Optional
+            the API key for OpenRouter. If None, will use the key in os.environ['OPENROUTER_API_KEY'].
+        base_url : str, Optional
+            the base url for the OpenRouter API.
+        config : VLMConfig
+            the VLM configuration.
+        """
+        self.api_key = api_key
+        if self.api_key is None:
+            self.api_key = os.getenv("OPENROUTER_API_KEY")
+        super().__init__(model, self.api_key, base_url, config, **kwrs)
+
+    def _format_response(self, response: Any) -> Dict[str, str]:
+        """
+        This method formats the response from the OpenAI API to a dict with keys "type" and "data".
+
+        Parameters:
+        ----------
+        response : Any
+            the response from an OpenAI-compatible API. Could be a dict, generator, or object.
+        """
+        if isinstance(response, self.ChatCompletionChunk):
+            if hasattr(response.choices[0].delta, "reasoning") and getattr(response.choices[0].delta, "reasoning") is not None:
+                chunk_text = getattr(response.choices[0].delta, "reasoning", "")
+                if chunk_text is None:
+                    chunk_text = ""
+                return {"type": "reasoning", "data": chunk_text}
+            else:
+                chunk_text = getattr(response.choices[0].delta, "content", "")
+                if chunk_text is None:
+                    chunk_text = ""
+                return {"type": "response", "data": chunk_text}
+
+        return {"reasoning": getattr(response.choices[0].message, "reasoning", ""),
+                "response": getattr(response.choices[0].message, "content", "")}
 
 
 class OpenAIVLMEngine(VLMEngine):
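A minimal instantiation sketch for the new engines (the model names are invented; the vLLM default `base_url` comes from the signature above):

```python
from vlm4ocr.vlm_engines import VLLMVLMEngine, OpenRouterVLMEngine, ReasoningVLMConfig

# Local vLLM server on the default http://localhost:8000/v1
vllm_engine = VLLMVLMEngine(model="Qwen/Qwen2.5-VL-7B-Instruct")

# OpenRouter; api_key=None falls back to os.environ["OPENROUTER_API_KEY"]
router_engine = OpenRouterVLMEngine(model="qwen/qwen2.5-vl-72b-instruct",
                                    config=ReasoningVLMConfig())
```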
@@ -423,7 +1008,7 @@ class OpenAIVLMEngine(VLMEngine):
 
         return formatted_params
 
-    def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False) -> Union[str, Generator[str, None, None]]:
+    def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False, messages_logger:MessagesLogger=None) -> Union[Dict[str, str], Generator[Dict[str, str], None, None]]:
         """
         This method inputs chat messages and outputs LLM generated text.
 
@@ -435,6 +1020,13 @@ class OpenAIVLMEngine(VLMEngine):
             if True, VLM generated text will be printed in terminal in real-time.
         stream : bool, Optional
             if True, returns a generator that yields the output in real-time.
+        messages_logger : MessagesLogger, Optional
+            the message logger that logs the chat messages.
+
+        Returns:
+        -------
+        response : Union[Dict[str,str], Generator[Dict[str, str], None, None]]
+            a dict {"reasoning": <reasoning>, "response": <response>} or Generator {"type": <reasoning or response>, "data": <content>}
         """
         processed_messages = self.config.preprocess_messages(messages)
 
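Consuming the stream together with a logger might look like this sketch (`openai_engine` and `messages` are assumed built earlier):

```python
logger = MessagesLogger()
for chunk in openai_engine.chat(messages, stream=True, messages_logger=logger):
    if chunk["type"] == "response":  # reasoning chunks can be routed elsewhere
        print(chunk["data"], end="", flush=True)

# Once the generator is exhausted, the log holds the conversation with base64
# image payloads replaced by "[image]" placeholders, plus the assistant turn.
conversations = logger.get_messages_log()
```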
@@ -446,13 +1038,32 @@ class OpenAIVLMEngine(VLMEngine):
                     stream=True,
                     **self.formatted_params
                 )
+                res_text = ""
                 for chunk in response_stream:
                     if len(chunk.choices) > 0:
-                        if chunk.choices[0].delta.content is not None:
-                            yield chunk.choices[0].delta.content
+                        chunk_text = chunk.choices[0].delta.content
+                        if chunk_text is not None:
+                            res_text += chunk_text
+                            yield chunk_text
                         if chunk.choices[0].finish_reason == "length":
                             warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
 
+                # Postprocess response
+                res_dict = self.config.postprocess_response(res_text)
+                # Write to messages log
+                if messages_logger:
+                    # replace images content with a placeholder "[image]" to save space
+                    for messages in processed_messages:
+                        if "content" in messages and isinstance(messages["content"], list):
+                            for content in messages["content"]:
+                                if isinstance(content, dict) and content.get("type") == "image_url":
+                                    content["image_url"]["url"] = "[image]"
+
+                    processed_messages.append({"role": "assistant",
+                                               "content": res_dict.get("response", ""),
+                                               "reasoning": res_dict.get("reasoning", "")})
+                    messages_logger.log_messages(processed_messages)
+
             return self.config.postprocess_response(_stream_generator())
 
         elif verbose:
@@ -472,7 +1083,7 @@ class OpenAIVLMEngine(VLMEngine):
                         warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
 
             print('\n')
-            return self.config.postprocess_response(res)
+
         else:
             response = self.client.chat.completions.create(
                 model=self.model,
@@ -481,10 +1092,27 @@ class OpenAIVLMEngine(VLMEngine):
                 **self.formatted_params
             )
             res = response.choices[0].message.content
-            return self.config.postprocess_response(res)
+
+        # Postprocess response
+        res_dict = self.config.postprocess_response(res)
+        # Write to messages log
+        if messages_logger:
+            # replace images content with a placeholder "[image]" to save space
+            for messages in processed_messages:
+                if "content" in messages and isinstance(messages["content"], list):
+                    for content in messages["content"]:
+                        if isinstance(content, dict) and content.get("type") == "image_url":
+                            content["image_url"]["url"] = "[image]"
+
+            processed_messages.append({"role": "assistant",
+                                       "content": res_dict.get("response", ""),
+                                       "reasoning": res_dict.get("reasoning", "")})
+            messages_logger.log_messages(processed_messages)
+
+        return res_dict
 
 
-    async def chat_async(self, messages:List[Dict[str,str]]) -> str:
+    async def chat_async(self, messages:List[Dict[str,str]], messages_logger:MessagesLogger=None) -> Dict[str,str]:
         """
         Async version of chat method. Streaming is not supported.
         """
@@ -501,9 +1129,26 @@ class OpenAIVLMEngine(VLMEngine):
             warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
 
         res = response.choices[0].message.content
-        return self.config.postprocess_response(res)
+        # Postprocess response
+        res_dict = self.config.postprocess_response(res)
+        # Write to messages log
+        if messages_logger:
+            # replace images content with a placeholder "[image]" to save space
+            for messages in processed_messages:
+                if "content" in messages and isinstance(messages["content"], list):
+                    for content in messages["content"]:
+                        if isinstance(content, dict) and content.get("type") == "image_url":
+                            content["image_url"]["url"] = "[image]"
+
+            processed_messages.append({"role": "assistant",
+                                       "content": res_dict.get("response", ""),
+                                       "reasoning": res_dict.get("reasoning", "")})
+            messages_logger.log_messages(processed_messages)
+
+        return res_dict
 
-    def get_ocr_messages(self, system_prompt:str, user_prompt:str, image:Image.Image, format:str='png', detail:str="high") -> List[Dict[str,str]]:
+    def get_ocr_messages(self, system_prompt:str, user_prompt:str, image:Image.Image, format:str='png',
+                         detail:str="high", few_shot_examples:List[FewShotExample]=None) -> List[Dict[str,str]]:
         """
         This method inputs an image and returns the corresponding chat messages for the inference engine.
 
@@ -519,24 +1164,55 @@ class OpenAIVLMEngine(VLMEngine):
             the image format.
         detail : str, Optional
             the detail level of the image. Default is "high".
+        few_shot_examples : List[FewShotExample], Optional
+            list of few-shot examples. Each example is a FewShotExample with an image (PIL.Image.Image) and a text (str).
         """
         base64_str = image_to_base64(image)
-        return [
-            {"role": "system", "content": system_prompt},
-            {
-                "role": "user",
-                "content": [
-                    {
-                        "type": "image_url",
-                        "image_url": {
-                            "url": f"data:image/{format};base64,{base64_str}",
-                            "detail": detail
+        output_messages = []
+        # system message
+        system_message = {"role": "system", "content": system_prompt}
+        output_messages.append(system_message)
+
+        # few-shot examples
+        if few_shot_examples is not None:
+            for example in few_shot_examples:
+                if not isinstance(example, FewShotExample):
+                    raise ValueError("Few-shot example must be a FewShotExample object.")
+
+                example_image_b64 = image_to_base64(example.image)
+                example_user_message = {
+                    "role": "user",
+                    "content": [
+                        {
+                            "type": "image_url",
+                            "image_url": {
+                                "url": f"data:image/{format};base64,{example_image_b64}",
+                                "detail": detail
+                            },
                         },
+                        {"type": "text", "text": user_prompt},
+                    ],
+                }
+                example_agent_message = {"role": "assistant", "content": example.text}
+                output_messages.append(example_user_message)
+                output_messages.append(example_agent_message)
+
+        # user message
+        user_message = {
+            "role": "user",
+            "content": [
+                {
+                    "type": "image_url",
+                    "image_url": {
+                        "url": f"data:image/{format};base64,{base64_str}",
+                        "detail": detail
                     },
-                    {"type": "text", "text": user_prompt},
-                ],
-            },
-        ]
+                },
+                {"type": "text", "text": user_prompt},
+            ],
+        }
+        output_messages.append(user_message)
+        return output_messages
 
 
 class AzureOpenAIVLMEngine(OpenAIVLMEngine):