vlm4ocr 0.2.0__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
vlm4ocr/__init__.py CHANGED
@@ -1,11 +1,15 @@
  from .ocr_engines import OCREngine
- from .vlm_engines import BasicVLMConfig, OpenAIReasoningVLMConfig, OllamaVLMEngine, OpenAIVLMEngine, AzureOpenAIVLMEngine
+ from .vlm_engines import BasicVLMConfig, ReasoningVLMConfig, OpenAIReasoningVLMConfig, OllamaVLMEngine, OpenAICompatibleVLMEngine, VLLMVLMEngine, OpenRouterVLMEngine, OpenAIVLMEngine, AzureOpenAIVLMEngine

  __all__ = [
  "BasicVLMConfig",
+ "ReasoningVLMConfig",
  "OpenAIReasoningVLMConfig",
  "OCREngine",
  "OllamaVLMEngine",
+ "OpenAICompatibleVLMEngine",
+ "VLLMVLMEngine",
+ "OpenRouterVLMEngine",
  "OpenAIVLMEngine",
  "AzureOpenAIVLMEngine"
  ]
vlm4ocr/assets/default_prompt_templates/ocr_JSON_system_prompt.txt ADDED
@@ -0,0 +1 @@
+ You are a helpful assistant that can convert scanned documents into JSON format. Your output is accurate and well-formatted, starting with ```json and ending with ```. You will only output the JSON text without any additional explanations or comments. The JSON should include all text, tables, and lists with appropriate keys and values. You will ignore images, icons, or anything that can not be converted into text.
vlm4ocr/cli.py CHANGED
@@ -4,18 +4,9 @@ import sys
  import logging
  import asyncio
  import time
-
- # Attempt to import from the local package structure
- try:
- from .ocr_engines import OCREngine
- from .vlm_engines import OpenAIVLMEngine, AzureOpenAIVLMEngine, OllamaVLMEngine, BasicVLMConfig
- from .data_types import OCRResult
- except ImportError:
- # Fallback for when the package is installed
- from vlm4ocr.ocr_engines import OCREngine
- from vlm4ocr.vlm_engines import OpenAIVLMEngine, AzureOpenAIVLMEngine, OllamaVLMEngine, BasicVLMConfig
- from vlm4ocr.data_types import OCRResult
-
+ from .ocr_engines import OCREngine
+ from .vlm_engines import OpenAICompatibleVLMEngine, OpenAIVLMEngine, AzureOpenAIVLMEngine, OllamaVLMEngine, BasicVLMConfig
+ from .data_types import OCRResult
  import tqdm.asyncio

  # --- Global logger setup (console) ---
@@ -208,7 +199,7 @@ def main():
  vlm_engine_instance = OpenAIVLMEngine(model=args.model, api_key=args.api_key, config=config)
  elif args.vlm_engine == "openai_compatible":
  if not args.base_url: parser.error("--base_url is required for openai_compatible.")
- vlm_engine_instance = OpenAIVLMEngine(model=args.model, api_key=args.api_key, base_url=args.base_url, config=config)
+ vlm_engine_instance = OpenAICompatibleVLMEngine(model=args.model, api_key=args.api_key, base_url=args.base_url, config=config)
  elif args.vlm_engine == "azure_openai":
  if not args.azure_api_key: parser.error("--azure_api_key (or AZURE_OPENAI_API_KEY) is required.")
  if not args.azure_endpoint: parser.error("--azure_endpoint (or AZURE_OPENAI_ENDPOINT) is required.")
vlm4ocr/data_types.py CHANGED
@@ -1,9 +1,9 @@
  import os
- from typing import List, Literal
+ from typing import List, Dict, Literal
  from dataclasses import dataclass, field
  from vlm4ocr.utils import get_default_page_delimiter

- OutputMode = Literal["markdown", "HTML", "text"]
+ OutputMode = Literal["markdown", "HTML", "text", "JSON"]

  @dataclass
  class OCRResult:
@@ -24,6 +24,7 @@ class OCRResult:
  pages: List[dict] = field(default_factory=list)
  filename: str = field(init=False)
  status: str = field(init=False, default="processing")
+ messages_log: List[List[Dict[str,str]]] = field(default_factory=list)

  def __post_init__(self):
  """
@@ -33,8 +34,8 @@ class OCRResult:
  self.filename = os.path.basename(self.input_dir)

  # output_mode validation
- if self.output_mode not in ["markdown", "HTML", "text"]:
- raise ValueError("output_mode must be 'markdown', 'HTML', or 'text'")
+ if self.output_mode not in ["markdown", "HTML", "text", "JSON"]:
+ raise ValueError("output_mode must be 'markdown', 'HTML', 'text', or 'JSON'")

  # pages validation
  if not isinstance(self.pages, list):
@@ -67,10 +68,6 @@ class OCRResult:
  }
  self.pages.append(page)

-
- def __len__(self):
- return len(self.pages)
-
  def get_page(self, idx):
  if not isinstance(idx, int):
  raise ValueError("Index must be an integer")
@@ -78,6 +75,21 @@ class OCRResult:
  raise IndexError(f"Index out of range. The OCRResult has {len(self.pages)} pages, but index {idx} was requested.")

  return self.pages[idx]
+
+ def clear_messages_log(self):
+ self.messages_log = []
+
+ def add_messages_to_log(self, messages: List[Dict[str,str]]):
+ if not isinstance(messages, list):
+ raise ValueError("messages must be a list of dict")
+
+ self.messages_log.extend(messages)
+
+ def get_messages_log(self) -> List[List[Dict[str,str]]]:
+ return self.messages_log.copy()
+
+ def __len__(self):
+ return len(self.pages)

  def __iter__(self):
  return iter(self.pages)
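
A brief sketch of how the new messages_log on OCRResult might be inspected after a run; the result variable is assumed to come from an OCREngine call (see ocr_engines.py below), while filename, status, __len__, get_messages_log, and clear_messages_log are the members added above:

    # 'result' is assumed to be an OCRResult returned by an OCREngine run (illustrative only)
    print(result.filename, result.status, len(result))    # len() counts OCR'd pages
    for conversation in result.get_messages_log():         # one messages list per VLM call
        for message in conversation:
            print(message.get("role"), str(message.get("content"))[:80])
    result.clear_messages_log()                             # drop the log once it is no longer needed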
vlm4ocr/ocr_engines.py CHANGED
@@ -3,10 +3,10 @@ from typing import Tuple, List, Dict, Union, Generator, AsyncGenerator, Iterable
  import importlib
  import asyncio
  from colorama import Fore, Style
- from PIL import Image
- from vlm4ocr.utils import DataLoader, PDFDataLoader, TIFFDataLoader, ImageDataLoader, ImageProcessor, clean_markdown, get_default_page_delimiter
+ import json
+ from vlm4ocr.utils import DataLoader, PDFDataLoader, TIFFDataLoader, ImageDataLoader, ImageProcessor, clean_markdown, extract_json, get_default_page_delimiter
  from vlm4ocr.data_types import OCRResult
- from vlm4ocr.vlm_engines import VLMEngine
+ from vlm4ocr.vlm_engines import VLMEngine, MessagesLogger

  SUPPORTED_IMAGE_EXTS = ['.pdf', '.tif', '.tiff', '.png', '.jpg', '.jpeg', '.bmp', '.gif', '.webp']

@@ -21,7 +21,7 @@ class OCREngine:
  inference_engine : InferenceEngine
  The inference engine to use for OCR.
  output_mode : str, Optional
- The output format. Must be 'markdown', 'HTML', or 'text'.
+ The output format. Must be 'markdown', 'HTML', 'text', or 'JSON'.
  system_prompt : str, Optional
  Custom system prompt. We recommend use a default system prompt by leaving this blank.
  user_prompt : str, Optional
@@ -33,8 +33,8 @@ class OCREngine:
  self.vlm_engine = vlm_engine

  # Check output mode
- if output_mode not in ["markdown", "HTML", "text"]:
- raise ValueError("output_mode must be 'markdown', 'HTML', or 'text'")
+ if output_mode not in ["markdown", "HTML", "text", "JSON"]:
+ raise ValueError("output_mode must be 'markdown', 'HTML', 'text', or 'JSON'.")
  self.output_mode = output_mode

  # System prompt
@@ -49,6 +49,9 @@ class OCREngine:
  if isinstance(user_prompt, str) and user_prompt:
  self.user_prompt = user_prompt
  else:
+ if self.output_mode == "JSON":
+ raise ValueError("user_prompt must be provided when output_mode is 'JSON' to define the JSON structure.")
+
  prompt_template_path = importlib.resources.files('vlm4ocr.assets.default_prompt_templates').joinpath(f'ocr_{self.output_mode}_user_prompt.txt')
  with prompt_template_path.open('r', encoding='utf-8') as f:
  self.user_prompt = f.read()
@@ -123,7 +126,8 @@ class OCREngine:
  stream=True
  )
  for chunk in response_stream:
- yield {"type": "ocr_chunk", "data": chunk}
+ if chunk["type"] == "response":
+ yield {"type": "ocr_chunk", "data": chunk["data"]}

  if i < len(images) - 1:
  yield {"type": "page_delimiter", "data": get_default_page_delimiter(self.output_mode)}
@@ -154,7 +158,8 @@ class OCREngine:
  stream=True
  )
  for chunk in response_stream:
- yield {"type": "ocr_chunk", "data": chunk}
+ if chunk["type"] == "response":
+ yield {"type": "ocr_chunk", "data": chunk["data"]}


  def sequential_ocr(self, file_paths: Union[str, Iterable[str]], rotate_correction:bool=False,
@@ -268,18 +273,32 @@ class OCREngine:

  try:
  messages = self.vlm_engine.get_ocr_messages(self.system_prompt, self.user_prompt, image)
+ # Define a messages logger to capture messages
+ messages_logger = MessagesLogger()
+ # Generate response
  response = self.vlm_engine.chat(
  messages,
  verbose=verbose,
- stream=False
+ stream=False,
+ messages_logger=messages_logger
  )
+ ocr_text = response["response"]
  # Clean the response if output mode is markdown
  if self.output_mode == "markdown":
- response = clean_markdown(response)
+ ocr_text = clean_markdown(ocr_text)
+
+ # Parse the response if output mode is JSON
+ elif self.output_mode == "JSON":
+ json_list = extract_json(ocr_text)
+ # Serialize the JSON list to a string
+ ocr_text = json.dumps(json_list, indent=4)

  # Add the page to the OCR result
- ocr_result.add_page(text=response,
+ ocr_result.add_page(text=ocr_text,
  image_processing_status=image_processing_status)
+
+ # Add messages log to the OCR result
+ ocr_result.add_messages_to_log(messages_logger.get_messages_log())

  except Exception as page_e:
  ocr_result.status = "error"
@@ -378,6 +397,7 @@ class OCREngine:
  filename = os.path.basename(file_path)
  file_ext = os.path.splitext(file_path)[1].lower()
  result = OCRResult(input_dir=file_path, output_mode=self.output_mode)
+ messages_logger = MessagesLogger()
  # check file extension
  if file_ext not in SUPPORTED_IMAGE_EXTS:
  result.status = "error"
@@ -407,7 +427,8 @@ class OCREngine:
  data_loader=data_loader,
  page_index=page_index,
  rotate_correction=rotate_correction,
- max_dimension_pixels=max_dimension_pixels
+ max_dimension_pixels=max_dimension_pixels,
+ messages_logger=messages_logger
  )
  page_processing_tasks.append(task)

@@ -419,14 +440,17 @@ class OCREngine:
  except Exception as e:
  result.status = "error"
  result.add_page(text=f"Error during OCR for {filename}: {str(e)}", image_processing_status={})
+ result.add_messages_to_log(messages_logger.get_messages_log())
  return result

  # Set status to success if no errors occurred
  result.status = "success"
+ result.add_messages_to_log(messages_logger.get_messages_log())
  return result

  async def _ocr_page_with_semaphore(self, vlm_call_semaphore: asyncio.Semaphore, data_loader: DataLoader,
- page_index:int, rotate_correction:bool=False, max_dimension_pixels:int=None) -> Tuple[str, Dict[str, str]]:
+ page_index:int, rotate_correction:bool=False, max_dimension_pixels:int=None,
+ messages_logger:MessagesLogger=None) -> Tuple[str, Dict[str, str]]:
  """
  This internal method takes a semaphore and OCR a single image/page using the VLM inference engine.

@@ -467,9 +491,19 @@ class OCREngine:
  }

  messages = self.vlm_engine.get_ocr_messages(self.system_prompt, self.user_prompt, image)
- ocr_text = await self.vlm_engine.chat_async(
+ response = await self.vlm_engine.chat_async(
  messages,
+ messages_logger=messages_logger
  )
+ ocr_text = response["response"]
+ # Clean the OCR text if output mode is markdown
  if self.output_mode == "markdown":
  ocr_text = clean_markdown(ocr_text)
+
+ # Parse the response if output mode is JSON
+ elif self.output_mode == "JSON":
+ json_list = extract_json(ocr_text)
+ # Serialize the JSON list to a string
+ ocr_text = json.dumps(json_list, indent=4)
+
  return ocr_text, image_processing_status
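
A minimal sketch of the new JSON output mode end to end; the model name, API key, file name, and JSON keys are illustrative assumptions, while OCREngine, OpenAIVLMEngine, output_mode="JSON", the user_prompt requirement, and sequential_ocr come from the code above (assuming the VLM engine is the first OCREngine argument):

    from vlm4ocr import OCREngine, OpenAIVLMEngine

    vlm = OpenAIVLMEngine(model="gpt-4o-mini", api_key="sk-...")   # assumed model name and key
    ocr = OCREngine(
        vlm,
        output_mode="JSON",
        # JSON mode raises ValueError unless a user_prompt defines the target structure
        user_prompt="Return a JSON object with keys 'vendor', 'date', and 'total'.",
    )
    results = ocr.sequential_ocr("invoice.pdf")   # assumed sample file; each page's text is a json.dumps string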
vlm4ocr/utils.py CHANGED
@@ -2,11 +2,14 @@ import abc
  import os
  import io
  import base64
- from typing import Union, List, Tuple
+ from typing import Dict, List, Tuple
+ import json
+ import json_repair
  import importlib.util
  from pdf2image import convert_from_path, pdfinfo_from_path
  from PIL import Image
  import asyncio
+ import warnings


  class DataLoader(abc.ABC):
@@ -229,6 +232,55 @@ def clean_markdown(text:str) -> str:
  cleaned_text = text.replace("```markdown", "").replace("```", "")
  return cleaned_text

+ def _find_dict_strings( text: str) -> List[str]:
+ """
+ Extracts balanced JSON-like dictionaries from a string, even if nested.
+
+ Parameters:
+ -----------
+ text : str
+ the input text containing JSON-like structures.
+
+ Returns : List[str]
+ A list of valid JSON-like strings representing dictionaries.
+ """
+ open_brace = 0
+ start = -1
+ json_objects = []
+
+ for i, char in enumerate(text):
+ if char == '{':
+ if open_brace == 0:
+ # start of a new JSON object
+ start = i
+ open_brace += 1
+ elif char == '}':
+ open_brace -= 1
+ if open_brace == 0 and start != -1:
+ json_objects.append(text[start:i + 1])
+ start = -1
+
+ return json_objects
+
+ def extract_json(gen_text:str) -> List[Dict[str, str]]:
+ """
+ This method inputs a generated text and output a JSON of information tuples
+ """
+ out = []
+ dict_str_list = _find_dict_strings(gen_text)
+ for dict_str in dict_str_list:
+ try:
+ dict_obj = json.loads(dict_str)
+ out.append(dict_obj)
+ except json.JSONDecodeError:
+ dict_obj = json_repair.repair_json(dict_str, skip_json_loads=True, return_objects=True)
+ if dict_obj:
+ warnings.warn(f'JSONDecodeError detected, fixed with repair_json:\n{dict_str}', RuntimeWarning)
+ out.append(dict_obj)
+ else:
+ warnings.warn(f'JSONDecodeError could not be fixed:\n{dict_str}', RuntimeWarning)
+ return out
+
  def get_default_page_delimiter(output_mode:str) -> str:
  """
  Returns the default page delimiter based on the environment variable.
@@ -243,8 +295,8 @@ def get_default_page_delimiter(output_mode:str) -> str:
  str
  The default page delimiter.
  """
- if output_mode not in ["markdown", "HTML", "text"]:
- raise ValueError("output_mode must be 'markdown', 'HTML', or 'text'")
+ if output_mode not in ["markdown", "HTML", "text", "JSON"]:
+ raise ValueError("output_mode must be 'markdown', 'HTML', 'text', or 'JSON'")

  if output_mode == "markdown":
  return "\n\n---\n\n"
@@ -252,6 +304,8 @@ def get_default_page_delimiter(output_mode:str) -> str:
  return "<br><br>"
  elif output_mode == "text":
  return "\n\n---\n\n"
+ elif output_mode == "JSON":
+ return "\n\n---\n\n"


  class ImageProcessor:
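
For illustration, the new extract_json helper applied to a made-up string; the function, its brace matching, and the json_repair fallback are the ones added above:

    from vlm4ocr.utils import extract_json

    text = 'Page 1: {"title": "Invoice", "total": "42.00"} and a malformed one {"name": "A",}'
    print(extract_json(text))
    # -> [{'title': 'Invoice', 'total': '42.00'}, {'name': 'A'}]
    #    the second object is fixed by json_repair and a RuntimeWarning is emitted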
vlm4ocr/vlm_engines.py CHANGED
@@ -2,6 +2,8 @@ import abc
  import importlib.util
  from typing import Any, List, Dict, Union, Generator
  import warnings
+ import os
+ import re
  from PIL import Image
  from vlm4ocr.utils import image_to_base64

@@ -33,7 +35,7 @@ class VLMConfig(abc.ABC):
  return NotImplemented

  @abc.abstractmethod
- def postprocess_response(self, response:Union[str, Generator[str, None, None]]) -> Union[str, Generator[str, None, None]]:
+ def postprocess_response(self, response:Union[str, Dict[str, str], Generator[str, None, None]]) -> Union[str, Generator[str, None, None]]:
  """
  This method postprocesses the VLM response after it is generated.

@@ -77,7 +79,7 @@ class BasicVLMConfig(VLMConfig):
  """
  return messages

- def postprocess_response(self, response:Union[str, Generator[str, None, None]]) -> Union[str, Generator[str, None, None]]:
+ def postprocess_response(self, response:Union[str, Dict[str, str], Generator[str, None, None]]) -> Union[Dict[str, str], Generator[Dict[str, str], None, None]]:
  """
  This method postprocesses the VLM response after it is generated.

@@ -88,19 +90,121 @@ class BasicVLMConfig(VLMConfig):

  Returns: Union[str, Generator[Dict[str, str], None, None]]
  the postprocessed VLM response.
- if input is a generator, the output will be a generator {"data": <content>}.
+ if input is a generator, the output will be a generator {"type": "response", "data": <content>}.
  """
  if isinstance(response, str):
- return response
+ return {"response": response}
+
+ elif isinstance(response, dict):
+ if "response" in response:
+ return response
+ else:
+ warnings.warn(f"Invalid response dict keys: {response.keys()}. Returning default empty dict.", UserWarning)
+ return {"response": ""}

  def _process_stream():
  for chunk in response:
- yield chunk
+ if isinstance(chunk, dict):
+ yield chunk
+ elif isinstance(chunk, str):
+ yield {"type": "response", "data": chunk}

  return _process_stream()
+
+ class ReasoningVLMConfig(VLMConfig):
+ def __init__(self, thinking_token_start="<think>", thinking_token_end="</think>", **kwargs):
+ """
+ The general configuration for reasoning vision models.
+ """
+ super().__init__(**kwargs)
+ self.thinking_token_start = thinking_token_start
+ self.thinking_token_end = thinking_token_end
+
+ def preprocess_messages(self, messages:List[Dict[str,str]]) -> List[Dict[str,str]]:
+ """
+ This method preprocesses the input messages before passing them to the VLM.
+
+ Parameters:
+ ----------
+ messages : List[Dict[str,str]]
+ a list of dict with role and content. role must be one of {"system", "user", "assistant"}
+
+ Returns:
+ -------
+ messages : List[Dict[str,str]]
+ a list of dict with role and content. role must be one of {"system", "user", "assistant"}
+ """
+ return messages.copy()

+ def postprocess_response(self, response:Union[str, Dict[str, str], Generator[str, None, None]]) -> Union[Dict[str,str], Generator[Dict[str,str], None, None]]:
+ """
+ This method postprocesses the VLM response after it is generated.
+ 1. If input is a string, it will extract the reasoning and response based on the thinking tokens.
+ 2. If input is a dict, it should contain keys "reasoning" and "response". This is for inference engines that already parse reasoning and response.
+ 3. If input is a generator,
+ a. if the chunk is a dict, it should contain keys "type" and "data". This is for inference engines that already parse reasoning and response.
+ b. if the chunk is a string, it will yield dicts with keys "type" and "data" based on the thinking tokens.

- class OpenAIReasoningVLMConfig(VLMConfig):
+ Parameters:
+ ----------
+ response : Union[str, Generator[str, None, None]]
+ the VLM response. Can be a string or a generator.
+
+ Returns:
+ -------
+ response : Union[str, Generator[str, None, None]]
+ the postprocessed LLM response as a dict {"reasoning": <reasoning>, "response": <content>}
+ if input is a generator, the output will be a generator {"type": <reasoning or response>, "data": <content>}.
+ """
+ if isinstance(response, str):
+ # get contents between thinking_token_start and thinking_token_end
+ pattern = f"{re.escape(self.thinking_token_start)}(.*?){re.escape(self.thinking_token_end)}"
+ match = re.search(pattern, response, re.DOTALL)
+ reasoning = match.group(1) if match else ""
+ # get response AFTER thinking_token_end
+ response = re.sub(f".*?{self.thinking_token_end}", "", response, flags=re.DOTALL).strip()
+ return {"reasoning": reasoning, "response": response}
+
+ elif isinstance(response, dict):
+ if "reasoning" in response and "response" in response:
+ return response
+ else:
+ warnings.warn(f"Invalid response dict keys: {response.keys()}. Returning default empty dict.", UserWarning)
+ return {"reasoning": "", "response": ""}
+
+ elif isinstance(response, Generator):
+ def _process_stream():
+ think_flag = False
+ buffer = ""
+ for chunk in response:
+ if isinstance(chunk, dict):
+ yield chunk
+
+ elif isinstance(chunk, str):
+ buffer += chunk
+ # switch between reasoning and response
+ if self.thinking_token_start in buffer:
+ think_flag = True
+ buffer = buffer.replace(self.thinking_token_start, "")
+ elif self.thinking_token_end in buffer:
+ think_flag = False
+ buffer = buffer.replace(self.thinking_token_end, "")
+
+ # if chunk is in thinking block, tag it as reasoning; else tag it as response
+ if chunk not in [self.thinking_token_start, self.thinking_token_end]:
+ if think_flag:
+ yield {"type": "reasoning", "data": chunk}
+ else:
+ yield {"type": "response", "data": chunk}
+
+ return _process_stream()
+
+ else:
+ warnings.warn(f"Invalid response type: {type(response)}. Returning default empty dict.", UserWarning)
+ return {"reasoning": "", "response": ""}
+
+
+ class OpenAIReasoningVLMConfig(ReasoningVLMConfig):
  def __init__(self, reasoning_effort:str="low", **kwargs):
  """
  The OpenAI "o" series configuration.
@@ -160,27 +264,31 @@ class OpenAIReasoningVLMConfig(VLMConfig):

  return new_messages

- def postprocess_response(self, response:Union[str, Generator[str, None, None]]) -> Union[str, Generator[Dict[str, str], None, None]]:
- """
- This method postprocesses the VLM response after it is generated.

- Parameters:
- ----------
- response : Union[str, Generator[str, None, None]]
- the VLM response. Can be a string or a generator.
-
- Returns: Union[str, Generator[Dict[str, str], None, None]]
- the postprocessed VLM response.
- if input is a generator, the output will be a generator {"type": "response", "data": <content>}.
+ class MessagesLogger:
+ def __init__(self):
  """
- if isinstance(response, str):
- return response
+ This class is used to log the messages for InferenceEngine.chat().
+ """
+ self.messages_log = []

- def _process_stream():
- for chunk in response:
- yield {"type": "response", "data": chunk}
+ def log_messages(self, messages : List[Dict[str,str]]):
+ """
+ This method logs the messages to a list.
+ """
+ self.messages_log.append(messages)

- return _process_stream()
+ def get_messages_log(self) -> List[List[Dict[str,str]]]:
+ """
+ This method returns a copy of the current messages log
+ """
+ return self.messages_log.copy()
+
+ def clear_messages_log(self):
+ """
+ This method clears the current messages log
+ """
+ self.messages_log.clear()


  class VLMEngine:
@@ -198,7 +306,8 @@ class VLMEngine:
  return NotImplemented

  @abc.abstractmethod
- def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False) -> Union[str, Generator[str, None, None]]:
+ def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False,
+ messages_logger:MessagesLogger=None) -> Union[Dict[str, str], Generator[Dict[str, str], None, None]]:
  """
  This method inputs chat messages and outputs VLM generated text.

@@ -210,11 +319,13 @@ class VLMEngine:
  if True, VLM generated text will be printed in terminal in real-time.
  stream : bool, Optional
  if True, returns a generator that yields the output in real-time.
+ Messages_logger : MessagesLogger, Optional
+ the message logger that logs the chat messages.
  """
  return NotImplemented

  @abc.abstractmethod
- def chat_async(self, messages:List[Dict[str,str]]) -> str:
+ def chat_async(self, messages:List[Dict[str,str]], messages_logger:MessagesLogger=None) -> Dict[str, str]:
  """
  The async version of chat method. Streaming is not supported.
  """
@@ -285,7 +396,8 @@ class OllamaVLMEngine(VLMEngine):

  return formatted_params

- def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False) -> Union[str, Generator[str, None, None]]:
+ def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False,
+ messages_logger:MessagesLogger=None) -> Union[Dict[str,str], Generator[Dict[str, str], None, None]]:
  """
  This method inputs chat messages and outputs VLM generated text.

@@ -297,6 +409,13 @@ class OllamaVLMEngine(VLMEngine):
  if True, VLM generated text will be printed in terminal in real-time.
  stream : bool, Optional
  if True, returns a generator that yields the output in real-time.
+ Messages_logger : MessagesLogger, Optional
+ the message logger that logs the chat messages.
+
+ Returns:
+ -------
+ response : Union[Dict[str,str], Generator[Dict[str, str], None, None]]
+ a dict {"reasoning": <reasoning>, "response": <response>} or Generator {"type": <reasoning or response>, "data": <content>}
  """
  processed_messages = self.config.preprocess_messages(messages)

@@ -310,10 +429,33 @@ class OllamaVLMEngine(VLMEngine):
  stream=True,
  keep_alive=self.keep_alive
  )
+ res = {"reasoning": "", "response": ""}
  for chunk in response_stream:
- content_chunk = chunk.get('message', {}).get('content')
- if content_chunk:
- yield content_chunk
+ if hasattr(chunk.message, 'thinking') and chunk.message.thinking:
+ content_chunk = getattr(getattr(chunk, 'message', {}), 'thinking', '')
+ res["reasoning"] += content_chunk
+ yield {"type": "reasoning", "data": content_chunk}
+ else:
+ content_chunk = getattr(getattr(chunk, 'message', {}), 'content', '')
+ res["response"] += content_chunk
+ yield {"type": "response", "data": content_chunk}
+
+ if chunk.done_reason == "length":
+ warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
+
+ # Postprocess response
+ res_dict = self.config.postprocess_response(res)
+ # Write to messages log
+ if messages_logger:
+ # replace images content with a placeholder "[image]" to save space
+ for messages in processed_messages:
+ if "images" in messages:
+ messages["images"] = ["[image]" for _ in messages["images"]]
+
+ processed_messages.append({"role": "assistant",
+ "content": res_dict.get("response", ""),
+ "reasoning": res_dict.get("reasoning", "")})
+ messages_logger.log_messages(processed_messages)

  return self.config.postprocess_response(_stream_generator())

@@ -326,14 +468,29 @@ class OllamaVLMEngine(VLMEngine):
  keep_alive=self.keep_alive
  )

- res = ''
+ res = {"reasoning": "", "response": ""}
+ phase = ""
  for chunk in response:
- content_chunk = chunk.get('message', {}).get('content')
+ if hasattr(chunk.message, 'thinking') and chunk.message.thinking:
+ if phase != "reasoning":
+ print("\n--- Reasoning ---")
+ phase = "reasoning"
+
+ content_chunk = getattr(getattr(chunk, 'message', {}), 'thinking', '')
+ res["reasoning"] += content_chunk
+ else:
+ if phase != "response":
+ print("\n--- Response ---")
+ phase = "response"
+ content_chunk = getattr(getattr(chunk, 'message', {}), 'content', '')
+ res["response"] += content_chunk
+
  print(content_chunk, end='', flush=True)
- res += content_chunk
+
+ if chunk.done_reason == "length":
+ warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
  print('\n')
- return self.config.postprocess_response(res)
-
+
  else:
  response = self.client.chat(
  model=self.model_name,
@@ -342,11 +499,30 @@ class OllamaVLMEngine(VLMEngine):
  stream=False,
  keep_alive=self.keep_alive
  )
- res = response.get('message', {}).get('content')
- return self.config.postprocess_response(res)
+ res = {"reasoning": getattr(getattr(response, 'message', {}), 'thinking', ''),
+ "response": getattr(getattr(response, 'message', {}), 'content', '')}
+
+ if response.done_reason == "length":
+ warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
+
+ # Postprocess response
+ res_dict = self.config.postprocess_response(res)
+ # Write to messages log
+ if messages_logger:
+ # replace images content with a placeholder "[image]" to save space
+ for messages in processed_messages:
+ if "images" in messages:
+ messages["images"] = ["[image]" for _ in messages["images"]]
+
+ processed_messages.append({"role": "assistant",
+ "content": res_dict.get("response", ""),
+ "reasoning": res_dict.get("reasoning", "")})
+ messages_logger.log_messages(processed_messages)
+
+ return res_dict


- async def chat_async(self, messages:List[Dict[str,str]]) -> str:
+ async def chat_async(self, messages:List[Dict[str,str]], messages_logger:MessagesLogger=None) -> Dict[str,str]:
  """
  Async version of chat method. Streaming is not supported.
  """
@@ -360,8 +536,26 @@ class OllamaVLMEngine(VLMEngine):
  keep_alive=self.keep_alive
  )

- res = response['message']['content']
- return self.config.postprocess_response(res)
+ res = {"reasoning": getattr(getattr(response, 'message', {}), 'thinking', ''),
+ "response": getattr(getattr(response, 'message', {}), 'content', '')}
+
+ if response.done_reason == "length":
+ warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
+ # Postprocess response
+ res_dict = self.config.postprocess_response(res)
+ # Write to messages log
+ if messages_logger:
+ # replace images content with a placeholder "[image]" to save space
+ for messages in processed_messages:
+ if "images" in messages:
+ messages["images"] = ["[image]" for _ in messages["images"]]
+
+ processed_messages.append({"role": "assistant",
+ "content": res_dict.get("response", ""),
+ "reasoning": res_dict.get("reasoning", "")})
+ messages_logger.log_messages(processed_messages)
+
+ return res_dict

  def get_ocr_messages(self, system_prompt:str, user_prompt:str, image:Image.Image) -> List[Dict[str,str]]:
  """
@@ -387,6 +581,346 @@ class OllamaVLMEngine(VLMEngine):
  ]


+ class OpenAICompatibleVLMEngine(VLMEngine):
+ def __init__(self, model:str, api_key:str, base_url:str, config:VLMConfig=None, **kwrs):
+ """
+ General OpenAI-compatible server inference engine.
+ https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html
+
+ For parameters and documentation, refer to https://platform.openai.com/docs/api-reference/introduction
+
+ Parameters:
+ ----------
+ model_name : str
+ model name as shown in the vLLM server
+ api_key : str
+ the API key for the vLLM server.
+ base_url : str
+ the base url for the vLLM server.
+ config : LLMConfig
+ the LLM configuration.
+ """
+ if importlib.util.find_spec("openai") is None:
+ raise ImportError("OpenAI Python API library not found. Please install OpanAI (```pip install openai```).")
+
+ from openai import OpenAI, AsyncOpenAI
+ from openai.types.chat import ChatCompletionChunk
+ self.ChatCompletionChunk = ChatCompletionChunk
+ super().__init__(config)
+ self.client = OpenAI(api_key=api_key, base_url=base_url, **kwrs)
+ self.async_client = AsyncOpenAI(api_key=api_key, base_url=base_url, **kwrs)
+ self.model = model
+ self.config = config if config else BasicVLMConfig()
+ self.formatted_params = self._format_config()
+
+ def _format_config(self) -> Dict[str, Any]:
+ """
+ This method format the VLM configuration with the correct key for the inference engine.
+ """
+ formatted_params = self.config.params.copy()
+ if "max_new_tokens" in formatted_params:
+ formatted_params["max_completion_tokens"] = formatted_params["max_new_tokens"]
+ formatted_params.pop("max_new_tokens")
+
+ return formatted_params
+
+
+ def _format_response(self, response: Any) -> Dict[str, str]:
+ """
+ This method format the response from OpenAI API to a dict with keys "type" and "data".
+
+ Parameters:
+ ----------
+ response : Any
+ the response from OpenAI-compatible API. Could be a dict, generator, or object.
+ """
+ if isinstance(response, self.ChatCompletionChunk):
+ chunk_text = getattr(response.choices[0].delta, "content", "")
+ if chunk_text is None:
+ chunk_text = ""
+ return {"type": "response", "data": chunk_text}
+
+ return {"response": getattr(response.choices[0].message, "content", "")}
+
+ def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False,
+ messages_logger:MessagesLogger=None) -> Union[Dict[str, str], Generator[Dict[str, str], None, None]]:
+ """
+ This method inputs chat messages and outputs LLM generated text.
+
+ Parameters:
+ ----------
+ messages : List[Dict[str,str]]
+ a list of dict with role and content. role must be one of {"system", "user", "assistant"}
+ verbose : bool, Optional
+ if True, VLM generated text will be printed in terminal in real-time.
+ stream : bool, Optional
+ if True, returns a generator that yields the output in real-time.
+ messages_logger : MessagesLogger, Optional
+ the message logger that logs the chat messages.
+
+ Returns:
+ -------
+ response : Union[Dict[str,str], Generator[Dict[str, str], None, None]]
+ a dict {"reasoning": <reasoning>, "response": <response>} or Generator {"type": <reasoning or response>, "data": <content>}
+ """
+ processed_messages = self.config.preprocess_messages(messages)
+
+ if stream:
+ def _stream_generator():
+ response_stream = self.client.chat.completions.create(
+ model=self.model,
+ messages=processed_messages,
+ stream=True,
+ **self.formatted_params
+ )
+ res_text = ""
+ for chunk in response_stream:
+ if len(chunk.choices) > 0:
+ chunk_dict = self._format_response(chunk)
+ yield chunk_dict
+
+ res_text += chunk_dict["data"]
+ if chunk.choices[0].finish_reason == "length":
+ warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
+
+ # Postprocess response
+ res_dict = self.config.postprocess_response(res_text)
+ # Write to messages log
+ if messages_logger:
+ # replace images content with a placeholder "[image]" to save space
+ for messages in processed_messages:
+ if "content" in messages and isinstance(messages["content"], list):
+ for content in messages["content"]:
+ if isinstance(content, dict) and content.get("type") == "image_url":
+ content["image_url"]["url"] = "[image]"
+
+ processed_messages.append({"role": "assistant",
+ "content": res_dict.get("response", ""),
+ "reasoning": res_dict.get("reasoning", "")})
+ messages_logger.log_messages(processed_messages)
+
+ return self.config.postprocess_response(_stream_generator())
+
+ elif verbose:
+ response = self.client.chat.completions.create(
+ model=self.model,
+ messages=processed_messages,
+ stream=True,
+ **self.formatted_params
+ )
+ res = {"reasoning": "", "response": ""}
+ phase = ""
+ for chunk in response:
+ if len(chunk.choices) > 0:
+ chunk_dict = self._format_response(chunk)
+ chunk_text = chunk_dict["data"]
+ res[chunk_dict["type"]] += chunk_text
+ if phase != chunk_dict["type"] and chunk_text != "":
+ print(f"\n--- {chunk_dict['type'].capitalize()} ---")
+ phase = chunk_dict["type"]
+
+ print(chunk_text, end="", flush=True)
+ if chunk.choices[0].finish_reason == "length":
+ warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
+
+ print('\n')
+
+ else:
+ response = self.client.chat.completions.create(
+ model=self.model,
+ messages=processed_messages,
+ stream=False,
+ **self.formatted_params
+ )
+ res = self._format_response(response)
+
+ if response.choices[0].finish_reason == "length":
+ warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
+
+ # Postprocess response
+ res_dict = self.config.postprocess_response(res)
+ # Write to messages log
+ if messages_logger:
+ # replace images content with a placeholder "[image]" to save space
+ for messages in processed_messages:
+ if "content" in messages and isinstance(messages["content"], list):
+ for content in messages["content"]:
+ if isinstance(content, dict) and content.get("type") == "image_url":
+ content["image_url"]["url"] = "[image]"
+
+ processed_messages.append({"role": "assistant",
+ "content": res_dict.get("response", ""),
+ "reasoning": res_dict.get("reasoning", "")})
+ messages_logger.log_messages(processed_messages)
+
+ return res_dict
+
+
+ async def chat_async(self, messages:List[Dict[str,str]], messages_logger:MessagesLogger=None) -> Dict[str,str]:
+ """
+ Async version of chat method. Streaming is not supported.
+ """
+ processed_messages = self.config.preprocess_messages(messages)
+
+ response = await self.async_client.chat.completions.create(
+ model=self.model,
+ messages=processed_messages,
+ stream=False,
+ **self.formatted_params
+ )
+
+ if response.choices[0].finish_reason == "length":
+ warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
+
+ res = self._format_response(response)
+
+ # Postprocess response
+ res_dict = self.config.postprocess_response(res)
+ # Write to messages log
+ if messages_logger:
+ # replace images content with a placeholder "[image]" to save space
+ for messages in processed_messages:
+ if "content" in messages and isinstance(messages["content"], list):
+ for content in messages["content"]:
+ if isinstance(content, dict) and content.get("type") == "image_url":
+ content["image_url"]["url"] = "[image]"
+
+ processed_messages.append({"role": "assistant",
+ "content": res_dict.get("response", ""),
+ "reasoning": res_dict.get("reasoning", "")})
+ messages_logger.log_messages(processed_messages)
+
+ return res_dict
+
+ def get_ocr_messages(self, system_prompt:str, user_prompt:str, image:Image.Image, format:str='png', detail:str="high") -> List[Dict[str,str]]:
+ """
+ This method inputs an image and returns the correesponding chat messages for the inference engine.
+
+ Parameters:
+ ----------
+ system_prompt : str
+ the system prompt.
+ user_prompt : str
+ the user prompt.
+ image : Image.Image
+ the image for OCR.
+ format : str, Optional
+ the image format.
+ detail : str, Optional
+ the detail level of the image. Default is "high".
+ """
+ base64_str = image_to_base64(image)
+ return [
+ {"role": "system", "content": system_prompt},
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "image_url",
+ "image_url": {
+ "url": f"data:image/{format};base64,{base64_str}",
+ "detail": detail
+ },
+ },
+ {"type": "text", "text": user_prompt},
+ ],
+ },
+ ]
+
+
+ class VLLMVLMEngine(OpenAICompatibleVLMEngine):
+ def __init__(self, model:str, api_key:str="", base_url:str="http://localhost:8000/v1", config:VLMConfig=None, **kwrs):
+ """
+ vLLM OpenAI compatible server inference engine.
+ https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html
+
+ For parameters and documentation, refer to https://platform.openai.com/docs/api-reference/introduction
+
+ Parameters:
+ ----------
+ model_name : str
+ model name as shown in the vLLM server
+ api_key : str, Optional
+ the API key for the vLLM server.
+ base_url : str, Optional
+ the base url for the vLLM server.
+ config : LLMConfig
+ the LLM configuration.
+ """
+ super().__init__(model, api_key, base_url, config, **kwrs)
+
+
+ def _format_response(self, response: Any) -> Dict[str, str]:
+ """
+ This method format the response from OpenAI API to a dict with keys "type" and "data".
+
+ Parameters:
+ ----------
+ response : Any
+ the response from OpenAI-compatible API. Could be a dict, generator, or object.
+ """
+ if isinstance(response, self.ChatCompletionChunk):
+ if hasattr(response.choices[0].delta, "reasoning_content") and getattr(response.choices[0].delta, "reasoning_content") is not None:
+ chunk_text = getattr(response.choices[0].delta, "reasoning_content", "")
+ if chunk_text is None:
+ chunk_text = ""
+ return {"type": "reasoning", "data": chunk_text}
+ else:
+ chunk_text = getattr(response.choices[0].delta, "content", "")
+ if chunk_text is None:
+ chunk_text = ""
+ return {"type": "response", "data": chunk_text}
+
+ return {"reasoning": getattr(response.choices[0].message, "reasoning_content", ""),
+ "response": getattr(response.choices[0].message, "content", "")}
+
+
+ class OpenRouterVLMEngine(OpenAICompatibleVLMEngine):
+ def __init__(self, model:str, api_key:str=None, base_url:str="https://openrouter.ai/api/v1", config:VLMConfig=None, **kwrs):
+ """
+ OpenRouter OpenAI-compatible server inference engine.
+
+ Parameters:
+ ----------
+ model_name : str
+ model name as shown in the vLLM server
+ api_key : str, Optional
+ the API key for the vLLM server. If None, will use the key in os.environ['OPENROUTER_API_KEY'].
+ base_url : str, Optional
+ the base url for the vLLM server.
+ config : LLMConfig
+ the LLM configuration.
+ """
+ self.api_key = api_key
+ if self.api_key is None:
+ self.api_key = os.getenv("OPENROUTER_API_KEY")
+ super().__init__(model, self.api_key, base_url, config, **kwrs)
+
+ def _format_response(self, response: Any) -> Dict[str, str]:
+ """
+ This method format the response from OpenAI API to a dict with keys "type" and "data".
+
+ Parameters:
+ ----------
+ response : Any
+ the response from OpenAI-compatible API. Could be a dict, generator, or object.
+ """
+ if isinstance(response, self.ChatCompletionChunk):
+ if hasattr(response.choices[0].delta, "reasoning") and getattr(response.choices[0].delta, "reasoning") is not None:
+ chunk_text = getattr(response.choices[0].delta, "reasoning", "")
+ if chunk_text is None:
+ chunk_text = ""
+ return {"type": "reasoning", "data": chunk_text}
+ else:
+ chunk_text = getattr(response.choices[0].delta, "content", "")
+ if chunk_text is None:
+ chunk_text = ""
+ return {"type": "response", "data": chunk_text}
+
+ return {"reasoning": getattr(response.choices[0].message, "reasoning", ""),
+ "response": getattr(response.choices[0].message, "content", "")}
+
+
  class OpenAIVLMEngine(VLMEngine):
  def __init__(self, model:str, config:VLMConfig=None, **kwrs):
  """
@@ -423,7 +957,7 @@ class OpenAIVLMEngine(VLMEngine):

  return formatted_params

- def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False) -> Union[str, Generator[str, None, None]]:
+ def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False, messages_logger:MessagesLogger=None) -> Union[Dict[str, str], Generator[Dict[str, str], None, None]]:
  """
  This method inputs chat messages and outputs LLM generated text.

@@ -435,6 +969,13 @@ class OpenAIVLMEngine(VLMEngine):
  if True, VLM generated text will be printed in terminal in real-time.
  stream : bool, Optional
  if True, returns a generator that yields the output in real-time.
+ messages_logger : MessagesLogger, Optional
+ the message logger that logs the chat messages.
+
+ Returns:
+ -------
+ response : Union[Dict[str,str], Generator[Dict[str, str], None, None]]
+ a dict {"reasoning": <reasoning>, "response": <response>} or Generator {"type": <reasoning or response>, "data": <content>}
  """
  processed_messages = self.config.preprocess_messages(messages)

@@ -446,13 +987,32 @@ class OpenAIVLMEngine(VLMEngine):
  stream=True,
  **self.formatted_params
  )
+ res_text = ""
  for chunk in response_stream:
  if len(chunk.choices) > 0:
- if chunk.choices[0].delta.content is not None:
- yield chunk.choices[0].delta.content
+ chunk_text = chunk.choices[0].delta.content
+ if chunk_text is not None:
+ res_text += chunk_text
+ yield chunk_text
  if chunk.choices[0].finish_reason == "length":
  warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)

+ # Postprocess response
+ res_dict = self.config.postprocess_response(res_text)
+ # Write to messages log
+ if messages_logger:
+ # replace images content with a placeholder "[image]" to save space
+ for messages in processed_messages:
+ if "content" in messages and isinstance(messages["content"], list):
+ for content in messages["content"]:
+ if isinstance(content, dict) and content.get("type") == "image_url":
+ content["image_url"]["url"] = "[image]"
+
+ processed_messages.append({"role": "assistant",
+ "content": res_dict.get("response", ""),
+ "reasoning": res_dict.get("reasoning", "")})
+ messages_logger.log_messages(processed_messages)
+
  return self.config.postprocess_response(_stream_generator())

  elif verbose:
@@ -472,7 +1032,7 @@ class OpenAIVLMEngine(VLMEngine):
  warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)

  print('\n')
- return self.config.postprocess_response(res)
+
  else:
  response = self.client.chat.completions.create(
  model=self.model,
@@ -481,10 +1041,27 @@ class OpenAIVLMEngine(VLMEngine):
  **self.formatted_params
  )
  res = response.choices[0].message.content
- return self.config.postprocess_response(res)
+
+ # Postprocess response
+ res_dict = self.config.postprocess_response(res)
+ # Write to messages log
+ if messages_logger:
+ # replace images content with a placeholder "[image]" to save space
+ for messages in processed_messages:
+ if "content" in messages and isinstance(messages["content"], list):
+ for content in messages["content"]:
+ if isinstance(content, dict) and content.get("type") == "image_url":
+ content["image_url"]["url"] = "[image]"
+
+ processed_messages.append({"role": "assistant",
+ "content": res_dict.get("response", ""),
+ "reasoning": res_dict.get("reasoning", "")})
+ messages_logger.log_messages(processed_messages)
+
+ return res_dict


- async def chat_async(self, messages:List[Dict[str,str]]) -> str:
+ async def chat_async(self, messages:List[Dict[str,str]], messages_logger:MessagesLogger=None) -> Dict[str,str]:
  """
  Async version of chat method. Streaming is not supported.
  """
@@ -501,7 +1078,23 @@ class OpenAIVLMEngine(VLMEngine):
  warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)

  res = response.choices[0].message.content
- return self.config.postprocess_response(res)
+ # Postprocess response
+ res_dict = self.config.postprocess_response(res)
+ # Write to messages log
+ if messages_logger:
+ # replace images content with a placeholder "[image]" to save space
+ for messages in processed_messages:
+ if "content" in messages and isinstance(messages["content"], list):
+ for content in messages["content"]:
+ if isinstance(content, dict) and content.get("type") == "image_url":
+ content["image_url"]["url"] = "[image]"
+
+ processed_messages.append({"role": "assistant",
+ "content": res_dict.get("response", ""),
+ "reasoning": res_dict.get("reasoning", "")})
+ messages_logger.log_messages(processed_messages)
+
+ return res_dict

  def get_ocr_messages(self, system_prompt:str, user_prompt:str, image:Image.Image, format:str='png', detail:str="high") -> List[Dict[str,str]]:
  """
vlm4ocr-0.2.0.dist-info/METADATA → vlm4ocr-0.3.1.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: vlm4ocr
- Version: 0.2.0
+ Version: 0.3.1
  Summary: Python package and Web App for OCR with vision language models.
  License: MIT
  Author: Enshuo (David) Hsu
@@ -10,6 +10,8 @@ Classifier: Programming Language :: Python :: 3
  Classifier: Programming Language :: Python :: 3.11
  Classifier: Programming Language :: Python :: 3.12
  Provides-Extra: tesseract
+ Requires-Dist: colorama (>=0.4.4)
+ Requires-Dist: json-repair (>=0.30.0)
  Requires-Dist: pdf2image (>=1.16.0)
  Requires-Dist: pillow (>=10.0.0)
  Requires-Dist: pytesseract (>=0.3.13) ; extra == "tesseract"
vlm4ocr-0.3.1.dist-info/RECORD ADDED
@@ -0,0 +1,17 @@
+ vlm4ocr/__init__.py,sha256=NpJ-jquqaXo-uHPcMOYUtqToLWLxixftPQn7epD2XbY,506
+ vlm4ocr/assets/default_prompt_templates/ocr_HTML_system_prompt.txt,sha256=igPOntiLDZXTB71-QrTmMJveb6XC1TgArg1serPc9V8,547
+ vlm4ocr/assets/default_prompt_templates/ocr_HTML_user_prompt.txt,sha256=cVn538JojZfCtIhfrcOPWt0dO7dtDqgB9xdS_5VvAqo,41
+ vlm4ocr/assets/default_prompt_templates/ocr_JSON_system_prompt.txt,sha256=v-fUw53gkngc_dz9TMH2abALDsAEZfe-zJ2u3-SO4ck,417
+ vlm4ocr/assets/default_prompt_templates/ocr_markdown_system_prompt.txt,sha256=pIsYO2G3jkZ5EWg7MJixre3Itz1oPqJSduUZT34_RNY,436
+ vlm4ocr/assets/default_prompt_templates/ocr_markdown_user_prompt.txt,sha256=61EJv8POsQGIIUVwCjDU73lMXJE7F3qhPIYl6zSbl1Q,45
+ vlm4ocr/assets/default_prompt_templates/ocr_text_system_prompt.txt,sha256=WbLSOerqFjlYGaGWJ-w2enhky1WhnPl011s0fgRPgnQ,398
+ vlm4ocr/assets/default_prompt_templates/ocr_text_user_prompt.txt,sha256=ftgNAIPy_UlrcY6m7-IkH2ApHkCzRnymra1w2wg60Ks,47
+ vlm4ocr/cli.py,sha256=mq5fbJQvgUm89Vd9v2SIW9ARsGex-8V46-r3-evjYrs,19966
+ vlm4ocr/data_types.py,sha256=BOcq5KsZFJ_-Fxb9A4IJfOd0x5u-1tUQkYbWAJayuPM,4416
+ vlm4ocr/ocr_engines.py,sha256=up7p9xGIeBdwQgqChlr7lsTMWTVFtSWzwlFZp2wKAxk,25431
+ vlm4ocr/utils.py,sha256=nQhUskOze99wCVMKmvsen0dhq-9NdN4EPC_bdYfkjgA,13611
+ vlm4ocr/vlm_engines.py,sha256=rfb4P1fhpY6ClC27FMhYCWOaIjCipZCx3gPrNnDbF0w,50209
+ vlm4ocr-0.3.1.dist-info/METADATA,sha256=_l03maaznCHetgYPATohqd_yFJenWE57sdw_JLaVmc0,710
+ vlm4ocr-0.3.1.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
+ vlm4ocr-0.3.1.dist-info/entry_points.txt,sha256=qzWUk_QTZ12cH4DLjjfqce89EAlOydD85dreRRZF3K4,44
+ vlm4ocr-0.3.1.dist-info/RECORD,,
vlm4ocr-0.2.0.dist-info/RECORD DELETED
@@ -1,16 +0,0 @@
- vlm4ocr/__init__.py,sha256=k5TZY0LmRnjGyjHD0H5AxJHJMw_cS2SzGxTJ0NQbQsc,315
- vlm4ocr/assets/default_prompt_templates/ocr_HTML_system_prompt.txt,sha256=igPOntiLDZXTB71-QrTmMJveb6XC1TgArg1serPc9V8,547
- vlm4ocr/assets/default_prompt_templates/ocr_HTML_user_prompt.txt,sha256=cVn538JojZfCtIhfrcOPWt0dO7dtDqgB9xdS_5VvAqo,41
- vlm4ocr/assets/default_prompt_templates/ocr_markdown_system_prompt.txt,sha256=pIsYO2G3jkZ5EWg7MJixre3Itz1oPqJSduUZT34_RNY,436
- vlm4ocr/assets/default_prompt_templates/ocr_markdown_user_prompt.txt,sha256=61EJv8POsQGIIUVwCjDU73lMXJE7F3qhPIYl6zSbl1Q,45
- vlm4ocr/assets/default_prompt_templates/ocr_text_system_prompt.txt,sha256=WbLSOerqFjlYGaGWJ-w2enhky1WhnPl011s0fgRPgnQ,398
- vlm4ocr/assets/default_prompt_templates/ocr_text_user_prompt.txt,sha256=ftgNAIPy_UlrcY6m7-IkH2ApHkCzRnymra1w2wg60Ks,47
- vlm4ocr/cli.py,sha256=b13WswreFxTNLA7n2F2jPR7Wrb2Onb06zFnvf7MOLi0,20268
- vlm4ocr/data_types.py,sha256=OnbI5IFonp5jPUgq0RIHHSzR9EypJ3jaGoPxtbBpS04,3919
- vlm4ocr/ocr_engines.py,sha256=KDU70U-SWkFUZR5fV0vKlnP_9HE0pE5ghzU7ihdTMjc,23605
- vlm4ocr/utils.py,sha256=mwp3YDQaoS7YObXtLeuinSEGQ0fJ5KeQMAZmONIKvcg,11907
- vlm4ocr/vlm_engines.py,sha256=jQuRZ5HlJtTtJXESiFcoYQXwX-lYu0gc-KKOpRLuW6A,22331
- vlm4ocr-0.2.0.dist-info/METADATA,sha256=bpANMPUizxAnWDgrUO3bqNFHnlK72FnCX-7Q7x2bSpA,638
- vlm4ocr-0.2.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
- vlm4ocr-0.2.0.dist-info/entry_points.txt,sha256=qzWUk_QTZ12cH4DLjjfqce89EAlOydD85dreRRZF3K4,44
- vlm4ocr-0.2.0.dist-info/RECORD,,