llm-ie 1.2.2__py3-none-any.whl → 1.2.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
llm_ie/engines.py CHANGED
@@ -1,4 +1,5 @@
  import abc
+ import os
  import re
  import warnings
  import importlib.util
@@ -33,13 +34,13 @@ class LLMConfig(abc.ABC):
  return NotImplemented

  @abc.abstractmethod
- def postprocess_response(self, response:Union[str, Generator[str, None, None]]) -> Union[Dict[str,str], Generator[Dict[str, str], None, None]]:
+ def postprocess_response(self, response:Union[str, Dict[str, str], Generator[str, None, None]]) -> Union[Dict[str,str], Generator[Dict[str, str], None, None]]:
  """
  This method postprocesses the LLM response after it is generated.

  Parameters:
  ----------
- response : Union[str, Generator[Dict[str, str], None, None]]
+ response : Union[str, Dict[str, str], Generator[Dict[str, str], None, None]]
  the LLM response. Can be a dict or a generator.

  Returns:
@@ -75,15 +76,15 @@ class BasicLLMConfig(LLMConfig):
  messages : List[Dict[str,str]]
  a list of dict with role and content. role must be one of {"system", "user", "assistant"}
  """
- return messages
+ return messages.copy()

- def postprocess_response(self, response:Union[str, Generator[str, None, None]]) -> Union[Dict[str,str], Generator[Dict[str, str], None, None]]:
+ def postprocess_response(self, response:Union[str, Dict[str, str], Generator[str, None, None]]) -> Union[Dict[str,str], Generator[Dict[str, str], None, None]]:
  """
  This method postprocesses the LLM response after it is generated.

  Parameters:
  ----------
- response : Union[str, Generator[str, None, None]]
+ response : Union[str, Dict[str, str], Generator[str, None, None]]
  the LLM response. Can be a string or a generator.

  Returns: Union[Dict[str,str], Generator[Dict[str, str], None, None]]
@@ -93,13 +94,27 @@ class BasicLLMConfig(LLMConfig):
  """
  if isinstance(response, str):
  return {"response": response}
+
+ elif isinstance(response, dict):
+ if "response" in response:
+ return response
+ else:
+ warnings.warn(f"Invalid response dict keys: {response.keys()}. Returning default empty dict.", UserWarning)
+ return {"response": ""}

- def _process_stream():
- for chunk in response:
- yield {"type": "response", "data": chunk}
+ elif isinstance(response, Generator):
+ def _process_stream():
+ for chunk in response:
+ if isinstance(chunk, dict):
+ yield chunk
+ elif isinstance(chunk, str):
+ yield {"type": "response", "data": chunk}

- return _process_stream()
+ return _process_stream()

+ else:
+ warnings.warn(f"Invalid response type: {type(response)}. Returning default empty dict.", UserWarning)
+ return {"response": ""}

  class ReasoningLLMConfig(LLMConfig):
  def __init__(self, thinking_token_start="<think>", thinking_token_end="</think>", **kwargs):
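A minimal sketch of the updated `BasicLLMConfig.postprocess_response` behavior shown in the hunk above (illustrative values; assumes the class is imported from `llm_ie.engines`):

```python
# Illustrative sketch of BasicLLMConfig.postprocess_response in 1.2.4 (values are made up).
from llm_ie.engines import BasicLLMConfig

config = BasicLLMConfig()
config.postprocess_response("Hello")                   # {"response": "Hello"}
config.postprocess_response({"response": "Hello"})     # dict with a "response" key is passed through
config.postprocess_response({"unexpected": "Hello"})   # warns, returns {"response": ""}
```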
@@ -124,11 +139,16 @@ class ReasoningLLMConfig(LLMConfig):
  messages : List[Dict[str,str]]
  a list of dict with role and content. role must be one of {"system", "user", "assistant"}
  """
- return messages
+ return messages.copy()

- def postprocess_response(self, response:Union[str, Generator[str, None, None]]) -> Union[Dict[str,str], Generator[Dict[str,str], None, None]]:
+ def postprocess_response(self, response:Union[str, Dict[str, str], Generator[str, None, None]]) -> Union[Dict[str,str], Generator[Dict[str,str], None, None]]:
  """
  This method postprocesses the LLM response after it is generated.
+ 1. If input is a string, it will extract the reasoning and response based on the thinking tokens.
+ 2. If input is a dict, it should contain keys "reasoning" and "response". This is for inference engines that already parse reasoning and response.
+ 3. If input is a generator,
+ a. if the chunk is a dict, it should contain keys "type" and "data". This is for inference engines that already parse reasoning and response.
+ b. if the chunk is a string, it will yield dicts with keys "type" and "data" based on the thinking tokens.

  Parameters:
  ----------
@@ -143,18 +163,29 @@ class ReasoningLLMConfig(LLMConfig):
  """
  if isinstance(response, str):
  # get contents between thinking_token_start and thinking_token_end
- match = re.search(f"{self.thinking_token_start}.*?{self.thinking_token_end}", response, re.DOTALL)
- reasoning = match.group(0) if match else ""
+ pattern = f"{re.escape(self.thinking_token_start)}(.*?){re.escape(self.thinking_token_end)}"
+ match = re.search(pattern, response, re.DOTALL)
+ reasoning = match.group(1) if match else ""
  # get response AFTER thinking_token_end
  response = re.sub(f".*?{self.thinking_token_end}", "", response, flags=re.DOTALL).strip()
  return {"reasoning": reasoning, "response": response}

- if isinstance(response, Generator):
+ elif isinstance(response, dict):
+ if "reasoning" in response and "response" in response:
+ return response
+ else:
+ warnings.warn(f"Invalid response dict keys: {response.keys()}. Returning default empty dict.", UserWarning)
+ return {"reasoning": "", "response": ""}
+
+ elif isinstance(response, Generator):
  def _process_stream():
  think_flag = False
  buffer = ""
  for chunk in response:
- if isinstance(chunk, str):
+ if isinstance(chunk, dict):
+ yield chunk
+
+ elif isinstance(chunk, str):
  buffer += chunk
  # switch between reasoning and response
  if self.thinking_token_start in buffer:
@@ -173,6 +204,9 @@ class ReasoningLLMConfig(LLMConfig):

  return _process_stream()

+ else:
+ warnings.warn(f"Invalid response type: {type(response)}. Returning default empty dict.", UserWarning)
+ return {"reasoning": "", "response": ""}

  class Qwen3LLMConfig(ReasoningLLMConfig):
  def __init__(self, thinking_mode:bool=True, **kwargs):
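For the string case above, the escaped pattern now captures only the text between the thinking tokens. A minimal sketch with the default `<think>`/`</think>` tokens (illustrative text):

```python
# Illustrative sketch of ReasoningLLMConfig.postprocess_response on a plain string in 1.2.4.
from llm_ie.engines import ReasoningLLMConfig

config = ReasoningLLMConfig()  # defaults: thinking_token_start="<think>", thinking_token_end="</think>"
out = config.postprocess_response("<think>Check the context first.</think>The answer is 42.")
# out == {"reasoning": "Check the context first.", "response": "The answer is 42."}
```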
@@ -279,6 +313,32 @@ class OpenAIReasoningLLMConfig(ReasoningLLMConfig):
  return new_messages


+ class MessagesLogger:
+ def __init__(self):
+ """
+ This class is used to log the messages for InferenceEngine.chat().
+ """
+ self.messages_log = []
+
+ def log_messages(self, messages : List[Dict[str,str]]):
+ """
+ This method logs the messages to a list.
+ """
+ self.messages_log.append(messages)
+
+ def get_messages_log(self) -> List[List[Dict[str,str]]]:
+ """
+ This method returns a copy of the current messages log
+ """
+ return self.messages_log.copy()
+
+ def clear_messages_log(self):
+ """
+ This method clears the current messages log
+ """
+ self.messages_log.clear()
+
+
  class InferenceEngine:
  @abc.abstractmethod
  def __init__(self, config:LLMConfig, **kwrs):
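A minimal sketch of the new `MessagesLogger` added above, wired into `chat()` (the engine constructor arguments are illustrative; any `InferenceEngine` subclass in 1.2.4 accepts `messages_logger`):

```python
# Illustrative sketch: logging chat turns with the new MessagesLogger (1.2.4).
from llm_ie.engines import OllamaInferenceEngine, BasicLLMConfig, MessagesLogger

logger = MessagesLogger()
engine = OllamaInferenceEngine(model_name="llama3.1:8b", config=BasicLLMConfig())  # model name is a placeholder

engine.chat([{"role": "user", "content": "Hi"}], messages_logger=logger)

# Each logged conversation ends with the assistant turn, including parsed reasoning where available.
print(logger.get_messages_log())
logger.clear_messages_log()
```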
@@ -293,10 +353,16 @@ class InferenceEngine:
  """
  return NotImplemented

+ def get_messages_log(self) -> List[List[Dict[str,str]]]:
+ return self.messages_log.copy()
+
+ def clear_messages_log(self):
+ self.messages_log = []
+

  @abc.abstractmethod
- def chat(self, messages:List[Dict[str,str]],
- verbose:bool=False, stream:bool=False) -> Union[Dict[str,str], Generator[Dict[str, str], None, None]]:
+ def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False,
+ messages_logger:MessagesLogger=None) -> Union[Dict[str,str], Generator[Dict[str, str], None, None]]:
  """
  This method inputs chat messages and outputs LLM generated text.

@@ -308,6 +374,8 @@ class InferenceEngine:
  if True, LLM generated text will be printed in terminal in real-time.
  stream : bool, Optional
  if True, returns a generator that yields the output in real-time.
+ Messages_logger : MessagesLogger, Optional
+ the message logger that logs the chat messages.

  Returns:
  -------
@@ -346,6 +414,7 @@ class LlamaCppInferenceEngine(InferenceEngine):
  the LLM configuration.
  """
  from llama_cpp import Llama
+ super().__init__(config)
  self.repo_id = repo_id
  self.gguf_filename = gguf_filename
  self.n_ctx = n_ctx
@@ -378,7 +447,7 @@ class LlamaCppInferenceEngine(InferenceEngine):

  return formatted_params

- def chat(self, messages:List[Dict[str,str]], verbose:bool=False) -> Dict[str,str]:
+ def chat(self, messages:List[Dict[str,str]], verbose:bool=False, messages_logger:MessagesLogger=None) -> Dict[str,str]:
  """
  This method inputs chat messages and outputs LLM generated text.

@@ -388,15 +457,18 @@ class LlamaCppInferenceEngine(InferenceEngine):
  a list of dict with role and content. role must be one of {"system", "user", "assistant"}
  verbose : bool, Optional
  if True, LLM generated text will be printed in terminal in real-time.
+ messages_logger : MessagesLogger, Optional
+ the message logger that logs the chat messages.
  """
+ # Preprocess messages
  processed_messages = self.config.preprocess_messages(messages)
-
+ # Generate response
  response = self.model.create_chat_completion(
  messages=processed_messages,
  stream=verbose,
  **self.formatted_params
  )
-
+
  if verbose:
  res = ''
  for chunk in response:
@@ -408,7 +480,16 @@ class LlamaCppInferenceEngine(InferenceEngine):
  return self.config.postprocess_response(res)

  res = response['choices'][0]['message']['content']
- return self.config.postprocess_response(res)
+ # Postprocess response
+ res_dict = self.config.postprocess_response(res)
+ # Write to messages log
+ if messages_logger:
+ processed_messages.append({"role": "assistant",
+ "content": res_dict.get("response", ""),
+ "reasoning": res_dict.get("reasoning", "")})
+ messages_logger.log_messages(processed_messages)
+
+ return res_dict


  class OllamaInferenceEngine(InferenceEngine):
@@ -431,6 +512,7 @@ class OllamaInferenceEngine(InferenceEngine):
  raise ImportError("ollama-python not found. Please install ollama-python (```pip install ollama```).")

  from ollama import Client, AsyncClient
+ super().__init__(config)
  self.client = Client(**kwrs)
  self.async_client = AsyncClient(**kwrs)
  self.model_name = model_name
@@ -450,8 +532,8 @@ class OllamaInferenceEngine(InferenceEngine):

  return formatted_params

- def chat(self, messages:List[Dict[str,str]],
- verbose:bool=False, stream:bool=False) -> Union[Dict[str,str], Generator[Dict[str, str], None, None]]:
+ def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False,
+ messages_logger:MessagesLogger=None) -> Union[Dict[str,str], Generator[Dict[str, str], None, None]]:
  """
  This method inputs chat messages and outputs VLM generated text.

@@ -463,6 +545,8 @@ class OllamaInferenceEngine(InferenceEngine):
  if True, VLM generated text will be printed in terminal in real-time.
  stream : bool, Optional
  if True, returns a generator that yields the output in real-time.
+ Messages_logger : MessagesLogger, Optional
+ the message logger that logs the chat messages.

  Returns:
  -------
@@ -481,10 +565,28 @@ class OllamaInferenceEngine(InferenceEngine):
  stream=True,
  keep_alive=self.keep_alive
  )
+ res = {"reasoning": "", "response": ""}
  for chunk in response_stream:
- content_chunk = chunk.get('message', {}).get('content')
- if content_chunk:
- yield content_chunk
+ if hasattr(chunk.message, 'thinking') and chunk.message.thinking:
+ content_chunk = getattr(getattr(chunk, 'message', {}), 'thinking', '')
+ res["reasoning"] += content_chunk
+ yield {"type": "reasoning", "data": content_chunk}
+ else:
+ content_chunk = getattr(getattr(chunk, 'message', {}), 'content', '')
+ res["response"] += content_chunk
+ yield {"type": "response", "data": content_chunk}
+
+ if chunk.done_reason == "length":
+ warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
+
+ # Postprocess response
+ res_dict = self.config.postprocess_response(res)
+ # Write to messages log
+ if messages_logger:
+ processed_messages.append({"role": "assistant",
+ "content": res_dict.get("response", ""),
+ "reasoning": res_dict.get("reasoning", "")})
+ messages_logger.log_messages(processed_messages)

  return self.config.postprocess_response(_stream_generator())
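With `stream=True`, `OllamaInferenceEngine.chat` now yields dicts with `type` and `data` keys instead of raw text chunks. A minimal consumption sketch, assuming `engine` is built as in the earlier sketch:

```python
# Illustrative sketch: consuming the dict chunks yielded by chat(stream=True) in 1.2.4.
messages = [{"role": "user", "content": "Summarize this note."}]
for chunk in engine.chat(messages, stream=True):
    if chunk["type"] == "reasoning":
        pass  # thinking tokens, when the model/config produces them
    elif chunk["type"] == "response":
        print(chunk["data"], end="", flush=True)
```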
 
@@ -497,14 +599,29 @@ class OllamaInferenceEngine(InferenceEngine):
  keep_alive=self.keep_alive
  )

- res = ''
+ res = {"reasoning": "", "response": ""}
+ phase = ""
  for chunk in response:
- content_chunk = chunk.get('message', {}).get('content')
+ if hasattr(chunk.message, 'thinking') and chunk.message.thinking:
+ if phase != "reasoning":
+ print("\n--- Reasoning ---")
+ phase = "reasoning"
+
+ content_chunk = getattr(getattr(chunk, 'message', {}), 'thinking', '')
+ res["reasoning"] += content_chunk
+ else:
+ if phase != "response":
+ print("\n--- Response ---")
+ phase = "response"
+ content_chunk = getattr(getattr(chunk, 'message', {}), 'content', '')
+ res["response"] += content_chunk
+
  print(content_chunk, end='', flush=True)
- res += content_chunk
+
+ if chunk.done_reason == "length":
+ warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
  print('\n')
- return self.config.postprocess_response(res)
-
+
  else:
  response = self.client.chat(
  model=self.model_name,
@@ -513,11 +630,25 @@ class OllamaInferenceEngine(InferenceEngine):
  stream=False,
  keep_alive=self.keep_alive
  )
- res = response.get('message', {}).get('content')
- return self.config.postprocess_response(res)
+ res = {"reasoning": getattr(getattr(response, 'message', {}), 'thinking', ''),
+ "response": getattr(getattr(response, 'message', {}), 'content', '')}
+
+ if response.done_reason == "length":
+ warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
+
+ # Postprocess response
+ res_dict = self.config.postprocess_response(res)
+ # Write to messages log
+ if messages_logger:
+ processed_messages.append({"role": "assistant",
+ "content": res_dict.get("response", ""),
+ "reasoning": res_dict.get("reasoning", "")})
+ messages_logger.log_messages(processed_messages)
+
+ return res_dict


- async def chat_async(self, messages:List[Dict[str,str]]) -> Dict[str,str]:
+ async def chat_async(self, messages:List[Dict[str,str]], messages_logger:MessagesLogger=None) -> Dict[str,str]:
  """
  Async version of chat method. Streaming is not supported.
  """
@@ -531,8 +662,21 @@ class OllamaInferenceEngine(InferenceEngine):
  keep_alive=self.keep_alive
  )

- res = response['message']['content']
- return self.config.postprocess_response(res)
+ res = {"reasoning": getattr(getattr(response, 'message', {}), 'thinking', ''),
+ "response": getattr(getattr(response, 'message', {}), 'content', '')}
+
+ if response.done_reason == "length":
+ warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
+ # Postprocess response
+ res_dict = self.config.postprocess_response(res)
+ # Write to messages log
+ if messages_logger:
+ processed_messages.append({"role": "assistant",
+ "content": res_dict.get("response", ""),
+ "reasoning": res_dict.get("reasoning", "")})
+ messages_logger.log_messages(processed_messages)
+
+ return res_dict


  class HuggingFaceHubInferenceEngine(InferenceEngine):
@@ -558,6 +702,7 @@ class HuggingFaceHubInferenceEngine(InferenceEngine):
  raise ImportError("huggingface-hub not found. Please install huggingface-hub (```pip install huggingface-hub```).")

  from huggingface_hub import InferenceClient, AsyncInferenceClient
+ super().__init__(config)
  self.model = model
  self.base_url = base_url
  self.client = InferenceClient(model=model, token=token, base_url=base_url, api_key=api_key, **kwrs)
@@ -577,8 +722,8 @@ class HuggingFaceHubInferenceEngine(InferenceEngine):
  return formatted_params


- def chat(self, messages:List[Dict[str,str]],
- verbose:bool=False, stream:bool=False) -> Union[Dict[str,str], Generator[Dict[str, str], None, None]]:
+ def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False,
+ messages_logger:MessagesLogger=None) -> Union[Dict[str,str], Generator[Dict[str, str], None, None]]:
  """
  This method inputs chat messages and outputs LLM generated text.

@@ -590,7 +735,9 @@ class HuggingFaceHubInferenceEngine(InferenceEngine):
  if True, VLM generated text will be printed in terminal in real-time.
  stream : bool, Optional
  if True, returns a generator that yields the output in real-time.
-
+ messages_logger : MessagesLogger, Optional
+ the message logger that logs the chat messages.
+
  Returns:
  -------
  response : Union[Dict[str,str], Generator[Dict[str, str], None, None]]
@@ -605,11 +752,22 @@ class HuggingFaceHubInferenceEngine(InferenceEngine):
  stream=True,
  **self.formatted_params
  )
+ res_text = ""
  for chunk in response_stream:
  content_chunk = chunk.get('choices')[0].get('delta').get('content')
  if content_chunk:
+ res_text += content_chunk
  yield content_chunk

+ # Postprocess response
+ res_dict = self.config.postprocess_response(res_text)
+ # Write to messages log
+ if messages_logger:
+ processed_messages.append({"role": "assistant",
+ "content": res_dict.get("response", ""),
+ "reasoning": res_dict.get("reasoning", "")})
+ messages_logger.log_messages(processed_messages)
+
  return self.config.postprocess_response(_stream_generator())

  elif verbose:
@@ -625,7 +783,7 @@ class HuggingFaceHubInferenceEngine(InferenceEngine):
  if content_chunk:
  res += content_chunk
  print(content_chunk, end='', flush=True)
- return self.config.postprocess_response(res)
+

  else:
  response = self.client.chat.completions.create(
@@ -634,9 +792,20 @@ class HuggingFaceHubInferenceEngine(InferenceEngine):
  **self.formatted_params
  )
  res = response.choices[0].message.content
- return self.config.postprocess_response(res)
+
+ # Postprocess response
+ res_dict = self.config.postprocess_response(res)
+ # Write to messages log
+ if messages_logger:
+ processed_messages.append({"role": "assistant",
+ "content": res_dict.get("response", ""),
+ "reasoning": res_dict.get("reasoning", "")})
+ messages_logger.log_messages(processed_messages)
+
+ return res_dict
+

- async def chat_async(self, messages:List[Dict[str,str]]) -> Dict[str,str]:
+ async def chat_async(self, messages:List[Dict[str,str]], messages_logger:MessagesLogger=None) -> Dict[str,str]:
  """
  Async version of chat method. Streaming is not supported.
  """
@@ -649,16 +818,343 @@ class HuggingFaceHubInferenceEngine(InferenceEngine):
  )

  res = response.choices[0].message.content
- return self.config.postprocess_response(res)
+ # Postprocess response
+ res_dict = self.config.postprocess_response(res)
+ # Write to messages log
+ if messages_logger:
+ processed_messages.append({"role": "assistant",
+ "content": res_dict.get("response", ""),
+ "reasoning": res_dict.get("reasoning", "")})
+ messages_logger.log_messages(processed_messages)
+
+ return res_dict
+
+
+ class OpenAICompatibleInferenceEngine(InferenceEngine):
+ def __init__(self, model:str, api_key:str, base_url:str, config:LLMConfig=None, **kwrs):
+ """
+ General OpenAI-compatible server inference engine.
+ https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html
+
+ For parameters and documentation, refer to https://platform.openai.com/docs/api-reference/introduction
+
+ Parameters:
+ ----------
+ model_name : str
+ model name as shown in the vLLM server
+ api_key : str
+ the API key for the vLLM server.
+ base_url : str
+ the base url for the vLLM server.
+ config : LLMConfig
+ the LLM configuration.
+ """
+ if importlib.util.find_spec("openai") is None:
+ raise ImportError("OpenAI Python API library not found. Please install OpanAI (```pip install openai```).")

+ from openai import OpenAI, AsyncOpenAI
+ from openai.types.chat import ChatCompletionChunk
+ self.ChatCompletionChunk = ChatCompletionChunk
+ super().__init__(config)
+ self.client = OpenAI(api_key=api_key, base_url=base_url, **kwrs)
+ self.async_client = AsyncOpenAI(api_key=api_key, base_url=base_url, **kwrs)
+ self.model = model
+ self.config = config if config else BasicLLMConfig()
+ self.formatted_params = self._format_config()
+
+ def _format_config(self) -> Dict[str, Any]:
+ """
+ This method format the LLM configuration with the correct key for the inference engine.
+ """
+ formatted_params = self.config.params.copy()
+ if "max_new_tokens" in formatted_params:
+ formatted_params["max_completion_tokens"] = formatted_params["max_new_tokens"]
+ formatted_params.pop("max_new_tokens")
+
+ return formatted_params
+
+ @abc.abstractmethod
+ def _format_response(self, response: Any) -> Dict[str, str]:
+ """
+ This method format the response from OpenAI API to a dict with keys "type" and "data".
+
+ Parameters:
+ ----------
+ response : Any
+ the response from OpenAI-compatible API. Could be a dict, generator, or object.
+ """
+ return NotImplemented
+
+ def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False, messages_logger:MessagesLogger=None) -> Union[Dict[str, str], Generator[Dict[str, str], None, None]]:
+ """
+ This method inputs chat messages and outputs LLM generated text.
+
+ Parameters:
+ ----------
+ messages : List[Dict[str,str]]
+ a list of dict with role and content. role must be one of {"system", "user", "assistant"}
+ verbose : bool, Optional
+ if True, VLM generated text will be printed in terminal in real-time.
+ stream : bool, Optional
+ if True, returns a generator that yields the output in real-time.
+ messages_logger : MessagesLogger, Optional
+ the message logger that logs the chat messages.
+
+ Returns:
+ -------
+ response : Union[Dict[str,str], Generator[Dict[str, str], None, None]]
+ a dict {"reasoning": <reasoning>, "response": <response>} or Generator {"type": <reasoning or response>, "data": <content>}
+ """
+ processed_messages = self.config.preprocess_messages(messages)
+
+ if stream:
+ def _stream_generator():
+ response_stream = self.client.chat.completions.create(
+ model=self.model,
+ messages=processed_messages,
+ stream=True,
+ **self.formatted_params
+ )
+ res_text = ""
+ for chunk in response_stream:
+ if len(chunk.choices) > 0:
+ chunk_dict = self._format_response(chunk)
+ yield chunk_dict
+
+ res_text += chunk_dict["data"]
+ if chunk.choices[0].finish_reason == "length":
+ warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
+
+ # Postprocess response
+ res_dict = self.config.postprocess_response(res_text)
+ # Write to messages log
+ if messages_logger:
+ processed_messages.append({"role": "assistant",
+ "content": res_dict.get("response", ""),
+ "reasoning": res_dict.get("reasoning", "")})
+ messages_logger.log_messages(processed_messages)
+
+ return self.config.postprocess_response(_stream_generator())
+
+ elif verbose:
+ response = self.client.chat.completions.create(
+ model=self.model,
+ messages=processed_messages,
+ stream=True,
+ **self.formatted_params
+ )
+ res = {"reasoning": "", "response": ""}
+ phase = ""
+ for chunk in response:
+ if len(chunk.choices) > 0:
+ chunk_dict = self._format_response(chunk)
+ chunk_text = chunk_dict["data"]
+ res[chunk_dict["type"]] += chunk_text
+ if phase != chunk_dict["type"] and chunk_text != "":
+ print(f"\n--- {chunk_dict['type'].capitalize()} ---")
+ phase = chunk_dict["type"]
+
+ print(chunk_text, end="", flush=True)
+ if chunk.choices[0].finish_reason == "length":
+ warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
+
+ print('\n')
+
+ else:
+ response = self.client.chat.completions.create(
+ model=self.model,
+ messages=processed_messages,
+ stream=False,
+ **self.formatted_params
+ )
+ res = self._format_response(response)
+
+ if response.choices[0].finish_reason == "length":
+ warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
+
+ # Postprocess response
+ res_dict = self.config.postprocess_response(res)
+ # Write to messages log
+ if messages_logger:
+ processed_messages.append({"role": "assistant",
+ "content": res_dict.get("response", ""),
+ "reasoning": res_dict.get("reasoning", "")})
+ messages_logger.log_messages(processed_messages)
+
+ return res_dict
+
+
+ async def chat_async(self, messages:List[Dict[str,str]], messages_logger:MessagesLogger=None) -> Dict[str,str]:
+ """
+ Async version of chat method. Streaming is not supported.
+ """
+ processed_messages = self.config.preprocess_messages(messages)
+
+ response = await self.async_client.chat.completions.create(
+ model=self.model,
+ messages=processed_messages,
+ stream=False,
+ **self.formatted_params
+ )
+
+ if response.choices[0].finish_reason == "length":
+ warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
+
+ res = self._format_response(response)
+
+ # Postprocess response
+ res_dict = self.config.postprocess_response(res)
+ # Write to messages log
+ if messages_logger:
+ processed_messages.append({"role": "assistant",
+ "content": res_dict.get("response", ""),
+ "reasoning": res_dict.get("reasoning", "")})
+ messages_logger.log_messages(processed_messages)
+
+ return res_dict
+
+
+ class VLLMInferenceEngine(OpenAICompatibleInferenceEngine):
+ def __init__(self, model:str, api_key:str="", base_url:str="http://localhost:8000/v1", config:LLMConfig=None, **kwrs):
+ """
+ vLLM OpenAI compatible server inference engine.
+ https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html
+
+ For parameters and documentation, refer to https://platform.openai.com/docs/api-reference/introduction
+
+ Parameters:
+ ----------
+ model_name : str
+ model name as shown in the vLLM server
+ api_key : str, Optional
+ the API key for the vLLM server.
+ base_url : str, Optional
+ the base url for the vLLM server.
+ config : LLMConfig
+ the LLM configuration.
+ """
+ super().__init__(model, api_key, base_url, config, **kwrs)
+
+
+ def _format_response(self, response: Any) -> Dict[str, str]:
+ """
+ This method format the response from OpenAI API to a dict with keys "type" and "data".
+
+ Parameters:
+ ----------
+ response : Any
+ the response from OpenAI-compatible API. Could be a dict, generator, or object.
+ """
+ if isinstance(response, self.ChatCompletionChunk):
+ if hasattr(response.choices[0].delta, "reasoning_content") and getattr(response.choices[0].delta, "reasoning_content") is not None:
+ chunk_text = getattr(response.choices[0].delta, "reasoning_content", "")
+ if chunk_text is None:
+ chunk_text = ""
+ return {"type": "reasoning", "data": chunk_text}
+ else:
+ chunk_text = getattr(response.choices[0].delta, "content", "")
+ if chunk_text is None:
+ chunk_text = ""
+ return {"type": "response", "data": chunk_text}
+
+ return {"reasoning": getattr(response.choices[0].message, "reasoning_content", ""),
+ "response": getattr(response.choices[0].message, "content", "")}
+
+ class SGLangInferenceEngine(OpenAICompatibleInferenceEngine):
+ def __init__(self, model:str, api_key:str="", base_url:str="http://localhost:30000/v1", config:LLMConfig=None, **kwrs):
+ """
+ SGLang OpenAI compatible API inference engine.
+ https://docs.sglang.ai/basic_usage/openai_api.html
+
+ Parameters:
+ ----------
+ model_name : str
+ model name as shown in the vLLM server
+ api_key : str, Optional
+ the API key for the vLLM server.
+ base_url : str, Optional
+ the base url for the vLLM server.
+ config : LLMConfig
+ the LLM configuration.
+ """
+ super().__init__(model, api_key, base_url, config, **kwrs)
+
+
+ def _format_response(self, response: Any) -> Dict[str, str]:
+ """
+ This method format the response from OpenAI API to a dict with keys "type" and "data".
+
+ Parameters:
+ ----------
+ response : Any
+ the response from OpenAI-compatible API. Could be a dict, generator, or object.
+ """
+ if isinstance(response, self.ChatCompletionChunk):
+ if hasattr(response.choices[0].delta, "reasoning_content") and getattr(response.choices[0].delta, "reasoning_content") is not None:
+ chunk_text = getattr(response.choices[0].delta, "reasoning_content", "")
+ if chunk_text is None:
+ chunk_text = ""
+ return {"type": "reasoning", "data": chunk_text}
+ else:
+ chunk_text = getattr(response.choices[0].delta, "content", "")
+ if chunk_text is None:
+ chunk_text = ""
+ return {"type": "response", "data": chunk_text}
+
+ return {"reasoning": getattr(response.choices[0].message, "reasoning_content", ""),
+ "response": getattr(response.choices[0].message, "content", "")}
+
+
+ class OpenRouterInferenceEngine(OpenAICompatibleInferenceEngine):
+ def __init__(self, model:str, api_key:str=None, base_url:str="https://openrouter.ai/api/v1", config:LLMConfig=None, **kwrs):
+ """
+ OpenRouter OpenAI-compatible server inference engine.
+
+ Parameters:
+ ----------
+ model_name : str
+ model name as shown in the vLLM server
+ api_key : str, Optional
+ the API key for the vLLM server. If None, will use the key in os.environ['OPENROUTER_API_KEY'].
+ base_url : str, Optional
+ the base url for the vLLM server.
+ config : LLMConfig
+ the LLM configuration.
+ """
+ self.api_key = api_key
+ if self.api_key is None:
+ self.api_key = os.getenv("OPENROUTER_API_KEY")
+ super().__init__(model, self.api_key, base_url, config, **kwrs)
+
+ def _format_response(self, response: Any) -> Dict[str, str]:
+ """
+ This method format the response from OpenAI API to a dict with keys "type" and "data".
+
+ Parameters:
+ ----------
+ response : Any
+ the response from OpenAI-compatible API. Could be a dict, generator, or object.
+ """
+ if isinstance(response, self.ChatCompletionChunk):
+ if hasattr(response.choices[0].delta, "reasoning") and getattr(response.choices[0].delta, "reasoning") is not None:
+ chunk_text = getattr(response.choices[0].delta, "reasoning", "")
+ if chunk_text is None:
+ chunk_text = ""
+ return {"type": "reasoning", "data": chunk_text}
+ else:
+ chunk_text = getattr(response.choices[0].delta, "content", "")
+ if chunk_text is None:
+ chunk_text = ""
+ return {"type": "response", "data": chunk_text}
+
+ return {"reasoning": getattr(response.choices[0].message, "reasoning", ""),
+ "response": getattr(response.choices[0].message, "content", "")}
+

  class OpenAIInferenceEngine(InferenceEngine):
  def __init__(self, model:str, config:LLMConfig=None, **kwrs):
  """
- The OpenAI API inference engine. Supports OpenAI models and OpenAI compatible servers:
- - vLLM OpenAI compatible server (https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html)
- - Llama.cpp OpenAI compatible server (https://llama-cpp-python.readthedocs.io/en/latest/server/)
-
+ The OpenAI API inference engine.
  For parameters and documentation, refer to https://platform.openai.com/docs/api-reference/introduction

  Parameters:
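A minimal sketch of the new OpenAI-compatible engines introduced above (model names are placeholders; `OpenRouterInferenceEngine` falls back to `os.environ['OPENROUTER_API_KEY']` when `api_key` is None):

```python
# Illustrative sketch: the OpenAI-compatible engines added in 1.2.4 (model names are placeholders).
from llm_ie.engines import VLLMInferenceEngine, OpenRouterInferenceEngine, ReasoningLLMConfig

vllm_engine = VLLMInferenceEngine(model="Qwen/Qwen3-8B",              # assumed model name
                                  base_url="http://localhost:8000/v1",
                                  config=ReasoningLLMConfig())
result = vllm_engine.chat([{"role": "user", "content": "Extract the diagnosis."}])
# result == {"reasoning": ..., "response": ...}

openrouter_engine = OpenRouterInferenceEngine(model="openai/gpt-4o-mini",  # assumed model name
                                              config=None)  # defaults to BasicLLMConfig
```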
@@ -670,6 +1166,7 @@ class OpenAIInferenceEngine(InferenceEngine):
  raise ImportError("OpenAI Python API library not found. Please install OpanAI (```pip install openai```).")

  from openai import OpenAI, AsyncOpenAI
+ super().__init__(config)
  self.client = OpenAI(**kwrs)
  self.async_client = AsyncOpenAI(**kwrs)
  self.model = model
@@ -687,7 +1184,7 @@ class OpenAIInferenceEngine(InferenceEngine):

  return formatted_params

- def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False) -> Union[Dict[str, str], Generator[Dict[str, str], None, None]]:
+ def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False, messages_logger:MessagesLogger=None) -> Union[Dict[str, str], Generator[Dict[str, str], None, None]]:
  """
  This method inputs chat messages and outputs LLM generated text.

@@ -699,6 +1196,8 @@ class OpenAIInferenceEngine(InferenceEngine):
  if True, VLM generated text will be printed in terminal in real-time.
  stream : bool, Optional
  if True, returns a generator that yields the output in real-time.
+ messages_logger : MessagesLogger, Optional
+ the message logger that logs the chat messages.

  Returns:
  -------
@@ -715,13 +1214,25 @@ class OpenAIInferenceEngine(InferenceEngine):
  stream=True,
  **self.formatted_params
  )
+ res_text = ""
  for chunk in response_stream:
  if len(chunk.choices) > 0:
- if chunk.choices[0].delta.content is not None:
- yield chunk.choices[0].delta.content
+ chunk_text = chunk.choices[0].delta.content
+ if chunk_text is not None:
+ res_text += chunk_text
+ yield chunk_text
  if chunk.choices[0].finish_reason == "length":
  warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)

+ # Postprocess response
+ res_dict = self.config.postprocess_response(res_text)
+ # Write to messages log
+ if messages_logger:
+ processed_messages.append({"role": "assistant",
+ "content": res_dict.get("response", ""),
+ "reasoning": res_dict.get("reasoning", "")})
+ messages_logger.log_messages(processed_messages)
+
  return self.config.postprocess_response(_stream_generator())

  elif verbose:
@@ -741,7 +1252,7 @@ class OpenAIInferenceEngine(InferenceEngine):
  warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)

  print('\n')
- return self.config.postprocess_response(res)
+
  else:
  response = self.client.chat.completions.create(
  model=self.model,
@@ -750,10 +1261,20 @@ class OpenAIInferenceEngine(InferenceEngine):
  **self.formatted_params
  )
  res = response.choices[0].message.content
- return self.config.postprocess_response(res)
+
+ # Postprocess response
+ res_dict = self.config.postprocess_response(res)
+ # Write to messages log
+ if messages_logger:
+ processed_messages.append({"role": "assistant",
+ "content": res_dict.get("response", ""),
+ "reasoning": res_dict.get("reasoning", "")})
+ messages_logger.log_messages(processed_messages)
+
+ return res_dict


- async def chat_async(self, messages:List[Dict[str,str]]) -> Dict[str,str]:
+ async def chat_async(self, messages:List[Dict[str,str]], messages_logger:MessagesLogger=None) -> Dict[str,str]:
  """
  Async version of chat method. Streaming is not supported.
  """
@@ -770,7 +1291,16 @@ class OpenAIInferenceEngine(InferenceEngine):
  warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)

  res = response.choices[0].message.content
- return self.config.postprocess_response(res)
+ # Postprocess response
+ res_dict = self.config.postprocess_response(res)
+ # Write to messages log
+ if messages_logger:
+ processed_messages.append({"role": "assistant",
+ "content": res_dict.get("response", ""),
+ "reasoning": res_dict.get("reasoning", "")})
+ messages_logger.log_messages(processed_messages)
+
+ return res_dict


  class AzureOpenAIInferenceEngine(OpenAIInferenceEngine):
@@ -825,6 +1355,7 @@ class LiteLLMInferenceEngine(InferenceEngine):
  raise ImportError("litellm not found. Please install litellm (```pip install litellm```).")

  import litellm
+ super().__init__(config)
  self.litellm = litellm
  self.model = model
  self.base_url = base_url
@@ -843,7 +1374,7 @@ class LiteLLMInferenceEngine(InferenceEngine):

  return formatted_params

- def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False) -> Union[Dict[str,str], Generator[Dict[str, str], None, None]]:
+ def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False, messages_logger:MessagesLogger=None) -> Union[Dict[str,str], Generator[Dict[str, str], None, None]]:
  """
  This method inputs chat messages and outputs LLM generated text.

@@ -855,6 +1386,8 @@ class LiteLLMInferenceEngine(InferenceEngine):
  if True, VLM generated text will be printed in terminal in real-time.
  stream : bool, Optional
  if True, returns a generator that yields the output in real-time.
+ messages_logger: MessagesLogger, Optional
+ a messages logger that logs the messages.

  Returns:
  -------
@@ -873,12 +1406,22 @@ class LiteLLMInferenceEngine(InferenceEngine):
  api_key=self.api_key,
  **self.formatted_params
  )
-
+ res_text = ""
  for chunk in response_stream:
  chunk_content = chunk.get('choices')[0].get('delta').get('content')
  if chunk_content:
+ res_text += chunk_content
  yield chunk_content

+ # Postprocess response
+ res_dict = self.config.postprocess_response(res_text)
+ # Write to messages log
+ if messages_logger:
+ processed_messages.append({"role": "assistant",
+ "content": res_dict.get("response", ""),
+ "reasoning": res_dict.get("reasoning", "")})
+ messages_logger.log_messages(processed_messages)
+
  return self.config.postprocess_response(_stream_generator())

  elif verbose:
@@ -897,8 +1440,6 @@ class LiteLLMInferenceEngine(InferenceEngine):
  if chunk_content:
  res += chunk_content
  print(chunk_content, end='', flush=True)
-
- return self.config.postprocess_response(res)

  else:
  response = self.litellm.completion(
@@ -910,9 +1451,19 @@ class LiteLLMInferenceEngine(InferenceEngine):
  **self.formatted_params
  )
  res = response.choices[0].message.content
- return self.config.postprocess_response(res)
+
+ # Postprocess response
+ res_dict = self.config.postprocess_response(res)
+ # Write to messages log
+ if messages_logger:
+ processed_messages.append({"role": "assistant",
+ "content": res_dict.get("response", ""),
+ "reasoning": res_dict.get("reasoning", "")})
+ messages_logger.log_messages(processed_messages)
+
+ return res_dict

- async def chat_async(self, messages:List[Dict[str,str]]) -> Dict[str,str]:
+ async def chat_async(self, messages:List[Dict[str,str]], messages_logger:MessagesLogger=None) -> Dict[str,str]:
  """
  Async version of chat method. Streaming is not supported.
  """
@@ -928,4 +1479,13 @@ class LiteLLMInferenceEngine(InferenceEngine):
  )

  res = response.get('choices')[0].get('message').get('content')
- return self.config.postprocess_response(res)
+
+ # Postprocess response
+ res_dict = self.config.postprocess_response(res)
+ # Write to messages log
+ if messages_logger:
+ processed_messages.append({"role": "assistant",
+ "content": res_dict.get("response", ""),
+ "reasoning": res_dict.get("reasoning", "")})
+ messages_logger.log_messages(processed_messages)
+ return res_dict