llm-ie 1.2.1__py3-none-any.whl → 1.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
llm_ie/engines.py CHANGED
@@ -1,4 +1,5 @@
  import abc
+ import os
  import re
  import warnings
  import importlib.util
@@ -33,18 +34,18 @@ class LLMConfig(abc.ABC):
  return NotImplemented
 
  @abc.abstractmethod
- def postprocess_response(self, response:Union[str, Generator[str, None, None]]) -> Union[str, Generator[str, None, None]]:
+ def postprocess_response(self, response:Union[str, Dict[str, str], Generator[str, None, None]]) -> Union[Dict[str,str], Generator[Dict[str, str], None, None]]:
  """
  This method postprocesses the LLM response after it is generated.
 
  Parameters:
  ----------
- response : Union[str, Generator[str, None, None]]
- the LLM response. Can be a string or a generator.
+ response : Union[str, Dict[str, str], Generator[Dict[str, str], None, None]]
+ the LLM response. Can be a dict or a generator.
 
  Returns:
  -------
- response : str
+ response : Union[Dict[str,str], Generator[Dict[str, str], None, None]]
  the postprocessed LLM response
  """
  return NotImplemented
@@ -75,47 +76,58 @@ class BasicLLMConfig(LLMConfig):
  messages : List[Dict[str,str]]
  a list of dict with role and content. role must be one of {"system", "user", "assistant"}
  """
- return messages
+ return messages.copy()
 
- def postprocess_response(self, response:Union[str, Generator[str, None, None]]) -> Union[str, Generator[Dict[str, str], None, None]]:
+ def postprocess_response(self, response:Union[str, Dict[str, str], Generator[str, None, None]]) -> Union[Dict[str,str], Generator[Dict[str, str], None, None]]:
  """
  This method postprocesses the LLM response after it is generated.
 
  Parameters:
  ----------
- response : Union[str, Generator[str, None, None]]
+ response : Union[str, Dict[str, str], Generator[str, None, None]]
  the LLM response. Can be a string or a generator.
 
- Returns: Union[str, Generator[Dict[str, str], None, None]]
+ Returns: Union[Dict[str,str], Generator[Dict[str, str], None, None]]
  the postprocessed LLM response.
- if input is a generator, the output will be a generator {"data": <content>}.
+ If input is a string, the output will be a dict {"response": <response>}.
+ if input is a generator, the output will be a generator {"type": "response", "data": <content>}.
  """
  if isinstance(response, str):
- return response
+ return {"response": response}
+
+ elif isinstance(response, dict):
+ if "response" in response:
+ return response
+ else:
+ warnings.warn(f"Invalid response dict keys: {response.keys()}. Returning default empty dict.", UserWarning)
+ return {"response": ""}
 
- def _process_stream():
- for chunk in response:
- yield {"type": "response", "data": chunk}
+ elif isinstance(response, Generator):
+ def _process_stream():
+ for chunk in response:
+ if isinstance(chunk, dict):
+ yield chunk
+ elif isinstance(chunk, str):
+ yield {"type": "response", "data": chunk}
+
+ return _process_stream()
 
- return _process_stream()
+ else:
+ warnings.warn(f"Invalid response type: {type(response)}. Returning default empty dict.", UserWarning)
+ return {"response": ""}
 
- class Qwen3LLMConfig(LLMConfig):
- def __init__(self, thinking_mode:bool=True, **kwargs):
+ class ReasoningLLMConfig(LLMConfig):
+ def __init__(self, thinking_token_start="<think>", thinking_token_end="</think>", **kwargs):
  """
- The Qwen3 LLM configuration for reasoning models.
-
- Parameters:
- ----------
- thinking_mode : bool, Optional
- if True, a special token "/think" will be placed after each system and user prompt. Otherwise, "/no_think" will be placed.
+ The general LLM configuration for reasoning models.
  """
  super().__init__(**kwargs)
- self.thinking_mode = thinking_mode
+ self.thinking_token_start = thinking_token_start
+ self.thinking_token_end = thinking_token_end
 
  def preprocess_messages(self, messages:List[Dict[str,str]]) -> List[Dict[str,str]]:
  """
- Append a special token to the system and user prompts.
- The token is "/think" if thinking_mode is True, otherwise "/no_think".
+ This method preprocesses the input messages before passing them to the LLM.
 
  Parameters:
  ----------
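Taken together, the changes above mean `BasicLLMConfig.postprocess_response` no longer hands back a bare string. A minimal sketch of the new contract from the caller's side, based only on the code in these hunks (the `llm_ie.engines` import path is inferred from the file name in this diff):

```python
from llm_ie.engines import BasicLLMConfig

config = BasicLLMConfig()

# A plain string is now wrapped in a dict keyed by "response".
print(config.postprocess_response("Hello"))
# {'response': 'Hello'}

# A generator of string chunks is re-emitted as {"type": "response", "data": <chunk>} dicts.
chunks = (piece for piece in ["Hel", "lo"])   # a real generator, not a list iterator
for item in config.postprocess_response(chunks):
    print(item)
# {'type': 'response', 'data': 'Hel'}
# {'type': 'response', 'data': 'lo'}
```

Dict inputs pass through unchanged when they already carry a `response` key; anything else triggers the `UserWarning` and an empty default.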
@@ -127,23 +139,16 @@ class Qwen3LLMConfig(LLMConfig):
  messages : List[Dict[str,str]]
  a list of dict with role and content. role must be one of {"system", "user", "assistant"}
  """
- thinking_token = "/think" if self.thinking_mode else "/no_think"
- new_messages = []
- for message in messages:
- if message['role'] in ['system', 'user']:
- new_message = {'role': message['role'], 'content': f"{message['content']} {thinking_token}"}
- else:
- new_message = {'role': message['role'], 'content': message['content']}
+ return messages.copy()
 
- new_messages.append(new_message)
-
- return new_messages
-
- def postprocess_response(self, response:Union[str, Generator[str, None, None]]) -> Union[str, Generator[Dict[str,str], None, None]]:
+ def postprocess_response(self, response:Union[str, Dict[str, str], Generator[str, None, None]]) -> Union[Dict[str,str], Generator[Dict[str,str], None, None]]:
  """
- If input is a generator, tag contents in <think> and </think> as {"type": "reasoning", "data": <content>},
- and the rest as {"type": "response", "data": <content>}.
- If input is a string, drop contents in <think> and </think>.
+ This method postprocesses the LLM response after it is generated.
+ 1. If input is a string, it will extract the reasoning and response based on the thinking tokens.
+ 2. If input is a dict, it should contain keys "reasoning" and "response". This is for inference engines that already parse reasoning and response.
+ 3. If input is a generator,
+ a. if the chunk is a dict, it should contain keys "type" and "data". This is for inference engines that already parse reasoning and response.
+ b. if the chunk is a string, it will yield dicts with keys "type" and "data" based on the thinking tokens.
 
  Parameters:
  ----------
@@ -153,38 +158,99 @@ class Qwen3LLMConfig(LLMConfig):
  Returns:
  -------
  response : Union[str, Generator[str, None, None]]
- the postprocessed LLM response.
+ the postprocessed LLM response as a dict {"reasoning": <reasoning>, "response": <content>}
  if input is a generator, the output will be a generator {"type": <reasoning or response>, "data": <content>}.
  """
  if isinstance(response, str):
- return re.sub(r"<think>.*?</think>\s*", "", response, flags=re.DOTALL).strip()
+ # get contents between thinking_token_start and thinking_token_end
+ pattern = f"{re.escape(self.thinking_token_start)}(.*?){re.escape(self.thinking_token_end)}"
+ match = re.search(pattern, response, re.DOTALL)
+ reasoning = match.group(1) if match else ""
+ # get response AFTER thinking_token_end
+ response = re.sub(f".*?{self.thinking_token_end}", "", response, flags=re.DOTALL).strip()
+ return {"reasoning": reasoning, "response": response}
+
+ elif isinstance(response, dict):
+ if "reasoning" in response and "response" in response:
+ return response
+ else:
+ warnings.warn(f"Invalid response dict keys: {response.keys()}. Returning default empty dict.", UserWarning)
+ return {"reasoning": "", "response": ""}
 
- if isinstance(response, Generator):
+ elif isinstance(response, Generator):
  def _process_stream():
  think_flag = False
  buffer = ""
  for chunk in response:
- if isinstance(chunk, str):
+ if isinstance(chunk, dict):
+ yield chunk
+
+ elif isinstance(chunk, str):
  buffer += chunk
  # switch between reasoning and response
- if "<think>" in buffer:
+ if self.thinking_token_start in buffer:
  think_flag = True
- buffer = buffer.replace("<think>", "")
- elif "</think>" in buffer:
+ buffer = buffer.replace(self.thinking_token_start, "")
+ elif self.thinking_token_end in buffer:
  think_flag = False
- buffer = buffer.replace("</think>", "")
+ buffer = buffer.replace(self.thinking_token_end, "")
 
  # if chunk is in thinking block, tag it as reasoning; else tag it as response
- if chunk not in ["<think>", "</think>"]:
+ if chunk not in [self.thinking_token_start, self.thinking_token_end]:
  if think_flag:
  yield {"type": "reasoning", "data": chunk}
  else:
  yield {"type": "response", "data": chunk}
 
  return _process_stream()
+
+ else:
+ warnings.warn(f"Invalid response type: {type(response)}. Returning default empty dict.", UserWarning)
+ return {"reasoning": "", "response": ""}
+
+ class Qwen3LLMConfig(ReasoningLLMConfig):
+ def __init__(self, thinking_mode:bool=True, **kwargs):
+ """
+ The Qwen3 **hybrid thinking** LLM configuration.
+ For Qwen3 thinking 2507, use ReasoningLLMConfig instead; for Qwen3 Instruct, use BasicLLMConfig instead.
+
+ Parameters:
+ ----------
+ thinking_mode : bool, Optional
+ if True, a special token "/think" will be placed after each system and user prompt. Otherwise, "/no_think" will be placed.
+ """
+ super().__init__(**kwargs)
+ self.thinking_mode = thinking_mode
 
+ def preprocess_messages(self, messages:List[Dict[str,str]]) -> List[Dict[str,str]]:
+ """
+ Append a special token to the system and user prompts.
+ The token is "/think" if thinking_mode is True, otherwise "/no_think".
 
- class OpenAIReasoningLLMConfig(LLMConfig):
+ Parameters:
+ ----------
+ messages : List[Dict[str,str]]
+ a list of dict with role and content. role must be one of {"system", "user", "assistant"}
+
+ Returns:
+ -------
+ messages : List[Dict[str,str]]
+ a list of dict with role and content. role must be one of {"system", "user", "assistant"}
+ """
+ thinking_token = "/think" if self.thinking_mode else "/no_think"
+ new_messages = []
+ for message in messages:
+ if message['role'] in ['system', 'user']:
+ new_message = {'role': message['role'], 'content': f"{message['content']} {thinking_token}"}
+ else:
+ new_message = {'role': message['role'], 'content': message['content']}
+
+ new_messages.append(new_message)
+
+ return new_messages
+
+
+ class OpenAIReasoningLLMConfig(ReasoningLLMConfig):
  def __init__(self, reasoning_effort:str=None, **kwargs):
  """
  The OpenAI "o" series configuration.
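The new `ReasoningLLMConfig` generalizes the old Qwen3-specific logic: it splits a raw completion into reasoning and response around configurable thinking tokens, and `Qwen3LLMConfig` now only layers the `/think` / `/no_think` switch on top of it. A small sketch of that behavior, assuming the `llm_ie.engines` import path; the example text is made up:

```python
from llm_ie.engines import ReasoningLLMConfig, Qwen3LLMConfig

# Defaults match the hunk above: thinking_token_start="<think>", thinking_token_end="</think>"
config = ReasoningLLMConfig()

raw = "<think>The note lists one medication.</think>Metformin 500 mg twice daily."
print(config.postprocess_response(raw))
# {'reasoning': 'The note lists one medication.', 'response': 'Metformin 500 mg twice daily.'}

# Dicts from engines that already separate the two phases pass through as-is.
print(config.postprocess_response({"reasoning": "parsed upstream", "response": "Metformin"}))

# Qwen3LLMConfig keeps only the hybrid-thinking prompt switch.
qwen = Qwen3LLMConfig(thinking_mode=False)
print(qwen.preprocess_messages([{"role": "user", "content": "Extract the medication."}]))
# [{'role': 'user', 'content': 'Extract the medication. /no_think'}]
```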
@@ -246,27 +312,31 @@ class OpenAIReasoningLLMConfig(LLMConfig):
 
  return new_messages
 
- def postprocess_response(self, response:Union[str, Generator[str, None, None]]) -> Union[str, Generator[Dict[str, str], None, None]]:
- """
- This method postprocesses the LLM response after it is generated.
 
- Parameters:
- ----------
- response : Union[str, Generator[str, None, None]]
- the LLM response. Can be a string or a generator.
-
- Returns: Union[str, Generator[Dict[str, str], None, None]]
- the postprocessed LLM response.
- if input is a generator, the output will be a generator {"type": "response", "data": <content>}.
+ class MessagesLogger:
+ def __init__(self):
  """
- if isinstance(response, str):
- return response
+ This class is used to log the messages for InferenceEngine.chat().
+ """
+ self.messages_log = []
 
- def _process_stream():
- for chunk in response:
- yield {"type": "response", "data": chunk}
+ def log_messages(self, messages : List[Dict[str,str]]):
+ """
+ This method logs the messages to a list.
+ """
+ self.messages_log.append(messages)
 
- return _process_stream()
+ def get_messages_log(self) -> List[List[Dict[str,str]]]:
+ """
+ This method returns a copy of the current messages log
+ """
+ return self.messages_log.copy()
+
+ def clear_messages_log(self):
+ """
+ This method clears the current messages log
+ """
+ self.messages_log.clear()
 
 
  class InferenceEngine:
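The `MessagesLogger` added here gives every `chat()` call a place to append the full, preprocessed conversation, including the assistant turn. A usage sketch; the engine choice and model tag are illustrative placeholders, not part of this diff:

```python
from llm_ie.engines import OllamaInferenceEngine, BasicLLMConfig, MessagesLogger

# Any InferenceEngine subclass works the same way; the model tag is a placeholder.
engine = OllamaInferenceEngine(model_name="llama3.1:8b", config=BasicLLMConfig())
logger = MessagesLogger()

result = engine.chat(
    messages=[{"role": "user", "content": "Say hello."}],
    messages_logger=logger,
)
print(result["response"])

# One entry per chat() call, each a full message list ending with the assistant turn.
for conversation in logger.get_messages_log():
    print(len(conversation), conversation[-1]["role"])

logger.clear_messages_log()
```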
@@ -283,10 +353,16 @@ class InferenceEngine:
  """
  return NotImplemented
 
+ def get_messages_log(self) -> List[List[Dict[str,str]]]:
+ return self.messages_log.copy()
+
+ def clear_messages_log(self):
+ self.messages_log = []
+
 
  @abc.abstractmethod
- def chat(self, messages:List[Dict[str,str]],
- verbose:bool=False, stream:bool=False) -> Union[str, Generator[Dict[str, str], None, None]]:
+ def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False,
+ messages_logger:MessagesLogger=None) -> Union[Dict[str,str], Generator[Dict[str, str], None, None]]:
  """
  This method inputs chat messages and outputs LLM generated text.
 
@@ -298,6 +374,13 @@ class InferenceEngine:
  if True, LLM generated text will be printed in terminal in real-time.
  stream : bool, Optional
  if True, returns a generator that yields the output in real-time.
+ Messages_logger : MessagesLogger, Optional
+ the message logger that logs the chat messages.
+
+ Returns:
+ -------
+ response : Union[Dict[str,str], Generator[Dict[str, str], None, None]]
+ a dict {"reasoning": <reasoning>, "response": <response>} or Generator {"type": <reasoning or response>, "data": <content>}
  """
  return NotImplemented
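The abstract signature above fixes the new return contract for every engine: a `{"reasoning", "response"}` dict when not streaming, and a generator of `{"type", "data"}` dicts when `stream=True`. A short consumption sketch under those assumptions (engine construction and model tag are illustrative):

```python
from llm_ie.engines import OllamaInferenceEngine, ReasoningLLMConfig

engine = OllamaInferenceEngine(model_name="qwen3:8b", config=ReasoningLLMConfig())
messages = [{"role": "user", "content": "What is 2 + 2?"}]

# Non-streaming: one dict with both phases.
out = engine.chat(messages)
print(out["reasoning"])
print(out["response"])

# Streaming: dicts tagged "reasoning" or "response" as they arrive.
for chunk in engine.chat(messages, stream=True):
    if chunk["type"] == "response":
        print(chunk["data"], end="", flush=True)
```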
 
@@ -331,6 +414,7 @@ class LlamaCppInferenceEngine(InferenceEngine):
  the LLM configuration.
  """
  from llama_cpp import Llama
+ super().__init__(config)
  self.repo_id = repo_id
  self.gguf_filename = gguf_filename
  self.n_ctx = n_ctx
@@ -363,7 +447,7 @@ class LlamaCppInferenceEngine(InferenceEngine):
 
  return formatted_params
 
- def chat(self, messages:List[Dict[str,str]], verbose:bool=False) -> str:
+ def chat(self, messages:List[Dict[str,str]], verbose:bool=False, messages_logger:MessagesLogger=None) -> Dict[str,str]:
  """
  This method inputs chat messages and outputs LLM generated text.
 
@@ -373,15 +457,18 @@ class LlamaCppInferenceEngine(InferenceEngine):
  a list of dict with role and content. role must be one of {"system", "user", "assistant"}
  verbose : bool, Optional
  if True, LLM generated text will be printed in terminal in real-time.
+ messages_logger : MessagesLogger, Optional
+ the message logger that logs the chat messages.
  """
+ # Preprocess messages
  processed_messages = self.config.preprocess_messages(messages)
-
+ # Generate response
  response = self.model.create_chat_completion(
  messages=processed_messages,
  stream=verbose,
  **self.formatted_params
  )
-
+
  if verbose:
  res = ''
  for chunk in response:
@@ -393,7 +480,16 @@ class LlamaCppInferenceEngine(InferenceEngine):
  return self.config.postprocess_response(res)
 
  res = response['choices'][0]['message']['content']
- return self.config.postprocess_response(res)
+ # Postprocess response
+ res_dict = self.config.postprocess_response(res)
+ # Write to messages log
+ if messages_logger:
+ processed_messages.append({"role": "assistant",
+ "content": res_dict.get("response", ""),
+ "reasoning": res_dict.get("reasoning", "")})
+ messages_logger.log_messages(processed_messages)
+
+ return res_dict
 
 
  class OllamaInferenceEngine(InferenceEngine):
@@ -416,6 +512,7 @@ class OllamaInferenceEngine(InferenceEngine):
  raise ImportError("ollama-python not found. Please install ollama-python (```pip install ollama```).")
 
  from ollama import Client, AsyncClient
+ super().__init__(config)
  self.client = Client(**kwrs)
  self.async_client = AsyncClient(**kwrs)
  self.model_name = model_name
@@ -435,8 +532,8 @@ class OllamaInferenceEngine(InferenceEngine):
 
  return formatted_params
 
- def chat(self, messages:List[Dict[str,str]],
- verbose:bool=False, stream:bool=False) -> Union[str, Generator[Dict[str, str], None, None]]:
+ def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False,
+ messages_logger:MessagesLogger=None) -> Union[Dict[str,str], Generator[Dict[str, str], None, None]]:
  """
  This method inputs chat messages and outputs VLM generated text.
 
@@ -448,6 +545,13 @@ class OllamaInferenceEngine(InferenceEngine):
  if True, VLM generated text will be printed in terminal in real-time.
  stream : bool, Optional
  if True, returns a generator that yields the output in real-time.
+ Messages_logger : MessagesLogger, Optional
+ the message logger that logs the chat messages.
+
+ Returns:
+ -------
+ response : Union[Dict[str,str], Generator[Dict[str, str], None, None]]
+ a dict {"reasoning": <reasoning>, "response": <response>} or Generator {"type": <reasoning or response>, "data": <content>}
  """
  processed_messages = self.config.preprocess_messages(messages)
 
@@ -461,10 +565,28 @@ class OllamaInferenceEngine(InferenceEngine):
  stream=True,
  keep_alive=self.keep_alive
  )
+ res = {"reasoning": "", "response": ""}
  for chunk in response_stream:
- content_chunk = chunk.get('message', {}).get('content')
- if content_chunk:
- yield content_chunk
+ if hasattr(chunk.message, 'thinking') and chunk.message.thinking:
+ content_chunk = getattr(getattr(chunk, 'message', {}), 'thinking', '')
+ res["reasoning"] += content_chunk
+ yield {"type": "reasoning", "data": content_chunk}
+ else:
+ content_chunk = getattr(getattr(chunk, 'message', {}), 'content', '')
+ res["response"] += content_chunk
+ yield {"type": "response", "data": content_chunk}
+
+ if chunk.done_reason == "length":
+ warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
+
+ # Postprocess response
+ res_dict = self.config.postprocess_response(res)
+ # Write to messages log
+ if messages_logger:
+ processed_messages.append({"role": "assistant",
+ "content": res_dict.get("response", ""),
+ "reasoning": res_dict.get("reasoning", "")})
+ messages_logger.log_messages(processed_messages)
 
  return self.config.postprocess_response(_stream_generator())
 
@@ -477,14 +599,29 @@ class OllamaInferenceEngine(InferenceEngine):
  keep_alive=self.keep_alive
  )
 
- res = ''
+ res = {"reasoning": "", "response": ""}
+ phase = ""
  for chunk in response:
- content_chunk = chunk.get('message', {}).get('content')
+ if hasattr(chunk.message, 'thinking') and chunk.message.thinking:
+ if phase != "reasoning":
+ print("\n--- Reasoning ---")
+ phase = "reasoning"
+
+ content_chunk = getattr(getattr(chunk, 'message', {}), 'thinking', '')
+ res["reasoning"] += content_chunk
+ else:
+ if phase != "response":
+ print("\n--- Response ---")
+ phase = "response"
+ content_chunk = getattr(getattr(chunk, 'message', {}), 'content', '')
+ res["response"] += content_chunk
+
  print(content_chunk, end='', flush=True)
- res += content_chunk
+
+ if chunk.done_reason == "length":
+ warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
  print('\n')
- return self.config.postprocess_response(res)
-
+
  else:
  response = self.client.chat(
  model=self.model_name,
@@ -493,11 +630,25 @@ class OllamaInferenceEngine(InferenceEngine):
  stream=False,
  keep_alive=self.keep_alive
  )
- res = response.get('message', {}).get('content')
- return self.config.postprocess_response(res)
+ res = {"reasoning": getattr(getattr(response, 'message', {}), 'thinking', ''),
+ "response": getattr(getattr(response, 'message', {}), 'content', '')}
+
+ if response.done_reason == "length":
+ warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
+
+ # Postprocess response
+ res_dict = self.config.postprocess_response(res)
+ # Write to messages log
+ if messages_logger:
+ processed_messages.append({"role": "assistant",
+ "content": res_dict.get("response", ""),
+ "reasoning": res_dict.get("reasoning", "")})
+ messages_logger.log_messages(processed_messages)
+
+ return res_dict
 
 
- async def chat_async(self, messages:List[Dict[str,str]]) -> str:
+ async def chat_async(self, messages:List[Dict[str,str]], messages_logger:MessagesLogger=None) -> Dict[str,str]:
  """
  Async version of chat method. Streaming is not supported.
  """
@@ -511,8 +662,21 @@ class OllamaInferenceEngine(InferenceEngine):
  keep_alive=self.keep_alive
  )
 
- res = response['message']['content']
- return self.config.postprocess_response(res)
+ res = {"reasoning": getattr(getattr(response, 'message', {}), 'thinking', ''),
+ "response": getattr(getattr(response, 'message', {}), 'content', '')}
+
+ if response.done_reason == "length":
+ warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
+ # Postprocess response
+ res_dict = self.config.postprocess_response(res)
+ # Write to messages log
+ if messages_logger:
+ processed_messages.append({"role": "assistant",
+ "content": res_dict.get("response", ""),
+ "reasoning": res_dict.get("reasoning", "")})
+ messages_logger.log_messages(processed_messages)
+
+ return res_dict
 
 
  class HuggingFaceHubInferenceEngine(InferenceEngine):
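`chat_async` now mirrors the synchronous method: it returns the same `{"reasoning", "response"}` dict and accepts an optional `messages_logger`. A minimal asyncio sketch with a placeholder model tag:

```python
import asyncio

from llm_ie.engines import OllamaInferenceEngine, ReasoningLLMConfig, MessagesLogger

async def main():
    engine = OllamaInferenceEngine(model_name="qwen3:8b", config=ReasoningLLMConfig())
    logger = MessagesLogger()
    result = await engine.chat_async(
        messages=[{"role": "user", "content": "Return the patient's age from: 63 y/o male."}],
        messages_logger=logger,
    )
    print(result["reasoning"])
    print(result["response"])
    print(len(logger.get_messages_log()))  # 1

asyncio.run(main())
```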
@@ -538,6 +702,7 @@ class HuggingFaceHubInferenceEngine(InferenceEngine):
  raise ImportError("huggingface-hub not found. Please install huggingface-hub (```pip install huggingface-hub```).")
 
  from huggingface_hub import InferenceClient, AsyncInferenceClient
+ super().__init__(config)
  self.model = model
  self.base_url = base_url
  self.client = InferenceClient(model=model, token=token, base_url=base_url, api_key=api_key, **kwrs)
@@ -557,8 +722,8 @@ class HuggingFaceHubInferenceEngine(InferenceEngine):
  return formatted_params
 
 
- def chat(self, messages:List[Dict[str,str]],
- verbose:bool=False, stream:bool=False) -> Union[str, Generator[Dict[str, str], None, None]]:
+ def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False,
+ messages_logger:MessagesLogger=None) -> Union[Dict[str,str], Generator[Dict[str, str], None, None]]:
  """
  This method inputs chat messages and outputs LLM generated text.
 
@@ -570,6 +735,13 @@ class HuggingFaceHubInferenceEngine(InferenceEngine):
  if True, VLM generated text will be printed in terminal in real-time.
  stream : bool, Optional
  if True, returns a generator that yields the output in real-time.
+ messages_logger : MessagesLogger, Optional
+ the message logger that logs the chat messages.
+
+ Returns:
+ -------
+ response : Union[Dict[str,str], Generator[Dict[str, str], None, None]]
+ a dict {"reasoning": <reasoning>, "response": <response>} or Generator {"type": <reasoning or response>, "data": <content>}
  """
  processed_messages = self.config.preprocess_messages(messages)
 
@@ -580,11 +752,22 @@ class HuggingFaceHubInferenceEngine(InferenceEngine):
  stream=True,
  **self.formatted_params
  )
+ res_text = ""
  for chunk in response_stream:
  content_chunk = chunk.get('choices')[0].get('delta').get('content')
  if content_chunk:
+ res_text += content_chunk
  yield content_chunk
 
+ # Postprocess response
+ res_dict = self.config.postprocess_response(res_text)
+ # Write to messages log
+ if messages_logger:
+ processed_messages.append({"role": "assistant",
+ "content": res_dict.get("response", ""),
+ "reasoning": res_dict.get("reasoning", "")})
+ messages_logger.log_messages(processed_messages)
+
  return self.config.postprocess_response(_stream_generator())
 
  elif verbose:
@@ -600,7 +783,7 @@ class HuggingFaceHubInferenceEngine(InferenceEngine):
  if content_chunk:
  res += content_chunk
  print(content_chunk, end='', flush=True)
- return self.config.postprocess_response(res)
+
 
  else:
  response = self.client.chat.completions.create(
@@ -609,9 +792,20 @@ class HuggingFaceHubInferenceEngine(InferenceEngine):
  **self.formatted_params
  )
  res = response.choices[0].message.content
- return self.config.postprocess_response(res)
+
+ # Postprocess response
+ res_dict = self.config.postprocess_response(res)
+ # Write to messages log
+ if messages_logger:
+ processed_messages.append({"role": "assistant",
+ "content": res_dict.get("response", ""),
+ "reasoning": res_dict.get("reasoning", "")})
+ messages_logger.log_messages(processed_messages)
+
+ return res_dict
 
- async def chat_async(self, messages:List[Dict[str,str]]) -> str:
+
+ async def chat_async(self, messages:List[Dict[str,str]], messages_logger:MessagesLogger=None) -> Dict[str,str]:
  """
  Async version of chat method. Streaming is not supported.
  """
@@ -624,16 +818,299 @@ class HuggingFaceHubInferenceEngine(InferenceEngine):
  )
 
  res = response.choices[0].message.content
- return self.config.postprocess_response(res)
+ # Postprocess response
+ res_dict = self.config.postprocess_response(res)
+ # Write to messages log
+ if messages_logger:
+ processed_messages.append({"role": "assistant",
+ "content": res_dict.get("response", ""),
+ "reasoning": res_dict.get("reasoning", "")})
+ messages_logger.log_messages(processed_messages)
+
+ return res_dict
+
+
+ class OpenAICompatibleInferenceEngine(InferenceEngine):
+ def __init__(self, model:str, api_key:str, base_url:str, config:LLMConfig=None, **kwrs):
+ """
+ General OpenAI-compatible server inference engine.
+ https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html
+
+ For parameters and documentation, refer to https://platform.openai.com/docs/api-reference/introduction
+
+ Parameters:
+ ----------
+ model_name : str
+ model name as shown in the vLLM server
+ api_key : str
+ the API key for the vLLM server.
+ base_url : str
+ the base url for the vLLM server.
+ config : LLMConfig
+ the LLM configuration.
+ """
+ if importlib.util.find_spec("openai") is None:
+ raise ImportError("OpenAI Python API library not found. Please install OpanAI (```pip install openai```).")
 
+ from openai import OpenAI, AsyncOpenAI
+ from openai.types.chat import ChatCompletionChunk
+ self.ChatCompletionChunk = ChatCompletionChunk
+ super().__init__(config)
+ self.client = OpenAI(api_key=api_key, base_url=base_url, **kwrs)
+ self.async_client = AsyncOpenAI(api_key=api_key, base_url=base_url, **kwrs)
+ self.model = model
+ self.config = config if config else BasicLLMConfig()
+ self.formatted_params = self._format_config()
+
+ def _format_config(self) -> Dict[str, Any]:
+ """
+ This method format the LLM configuration with the correct key for the inference engine.
+ """
+ formatted_params = self.config.params.copy()
+ if "max_new_tokens" in formatted_params:
+ formatted_params["max_completion_tokens"] = formatted_params["max_new_tokens"]
+ formatted_params.pop("max_new_tokens")
+
+ return formatted_params
+
+ @abc.abstractmethod
+ def _format_response(self, response: Any) -> Dict[str, str]:
+ """
+ This method format the response from OpenAI API to a dict with keys "type" and "data".
+
+ Parameters:
+ ----------
+ response : Any
+ the response from OpenAI-compatible API. Could be a dict, generator, or object.
+ """
+ return NotImplemented
+
+ def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False, messages_logger:MessagesLogger=None) -> Union[Dict[str, str], Generator[Dict[str, str], None, None]]:
+ """
+ This method inputs chat messages and outputs LLM generated text.
+
+ Parameters:
+ ----------
+ messages : List[Dict[str,str]]
+ a list of dict with role and content. role must be one of {"system", "user", "assistant"}
+ verbose : bool, Optional
+ if True, VLM generated text will be printed in terminal in real-time.
+ stream : bool, Optional
+ if True, returns a generator that yields the output in real-time.
+ messages_logger : MessagesLogger, Optional
+ the message logger that logs the chat messages.
+
+ Returns:
+ -------
+ response : Union[Dict[str,str], Generator[Dict[str, str], None, None]]
+ a dict {"reasoning": <reasoning>, "response": <response>} or Generator {"type": <reasoning or response>, "data": <content>}
+ """
+ processed_messages = self.config.preprocess_messages(messages)
+
+ if stream:
+ def _stream_generator():
+ response_stream = self.client.chat.completions.create(
+ model=self.model,
+ messages=processed_messages,
+ stream=True,
+ **self.formatted_params
+ )
+ res_text = ""
+ for chunk in response_stream:
+ if len(chunk.choices) > 0:
+ chunk_dict = self._format_response(chunk)
+ yield chunk_dict
+
+ res_text += chunk_dict["data"]
+ if chunk.choices[0].finish_reason == "length":
+ warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
+
+ # Postprocess response
+ res_dict = self.config.postprocess_response(res_text)
+ # Write to messages log
+ if messages_logger:
+ processed_messages.append({"role": "assistant",
+ "content": res_dict.get("response", ""),
+ "reasoning": res_dict.get("reasoning", "")})
+ messages_logger.log_messages(processed_messages)
+
+ return self.config.postprocess_response(_stream_generator())
+
+ elif verbose:
+ response = self.client.chat.completions.create(
+ model=self.model,
+ messages=processed_messages,
+ stream=True,
+ **self.formatted_params
+ )
+ res = {"reasoning": "", "response": ""}
+ phase = ""
+ for chunk in response:
+ if len(chunk.choices) > 0:
+ chunk_dict = self._format_response(chunk)
+ chunk_text = chunk_dict["data"]
+ res[chunk_dict["type"]] += chunk_text
+ if phase != chunk_dict["type"] and chunk_text != "":
+ print(f"\n--- {chunk_dict['type'].capitalize()} ---")
+ phase = chunk_dict["type"]
+
+ print(chunk_text, end="", flush=True)
+ if chunk.choices[0].finish_reason == "length":
+ warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
+
+ print('\n')
+
+ else:
+ response = self.client.chat.completions.create(
+ model=self.model,
+ messages=processed_messages,
+ stream=False,
+ **self.formatted_params
+ )
+ res = self._format_response(response)
+
+ if response.choices[0].finish_reason == "length":
+ warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
+
+ # Postprocess response
+ res_dict = self.config.postprocess_response(res)
+ # Write to messages log
+ if messages_logger:
+ processed_messages.append({"role": "assistant",
+ "content": res_dict.get("response", ""),
+ "reasoning": res_dict.get("reasoning", "")})
+ messages_logger.log_messages(processed_messages)
+
+ return res_dict
+
+
+ async def chat_async(self, messages:List[Dict[str,str]], messages_logger:MessagesLogger=None) -> Dict[str,str]:
+ """
+ Async version of chat method. Streaming is not supported.
+ """
+ processed_messages = self.config.preprocess_messages(messages)
+
+ response = await self.async_client.chat.completions.create(
+ model=self.model,
+ messages=processed_messages,
+ stream=False,
+ **self.formatted_params
+ )
+
+ if response.choices[0].finish_reason == "length":
+ warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
+
+ res = self._format_response(response)
+
+ # Postprocess response
+ res_dict = self.config.postprocess_response(res)
+ # Write to messages log
+ if messages_logger:
+ processed_messages.append({"role": "assistant",
+ "content": res_dict.get("response", ""),
+ "reasoning": res_dict.get("reasoning", "")})
+ messages_logger.log_messages(processed_messages)
+
+ return res_dict
+
+
+ class VLLMInferenceEngine(OpenAICompatibleInferenceEngine):
+ def __init__(self, model:str, api_key:str="", base_url:str="http://localhost:8000/v1", config:LLMConfig=None, **kwrs):
+ """
+ vLLM OpenAI compatible server inference engine.
+ https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html
+
+ For parameters and documentation, refer to https://platform.openai.com/docs/api-reference/introduction
+
+ Parameters:
+ ----------
+ model_name : str
+ model name as shown in the vLLM server
+ api_key : str, Optional
+ the API key for the vLLM server.
+ base_url : str, Optional
+ the base url for the vLLM server.
+ config : LLMConfig
+ the LLM configuration.
+ """
+ super().__init__(model, api_key, base_url, config, **kwrs)
+
+
+ def _format_response(self, response: Any) -> Dict[str, str]:
+ """
+ This method format the response from OpenAI API to a dict with keys "type" and "data".
+
+ Parameters:
+ ----------
+ response : Any
+ the response from OpenAI-compatible API. Could be a dict, generator, or object.
+ """
+ if isinstance(response, self.ChatCompletionChunk):
+ if hasattr(response.choices[0].delta, "reasoning_content") and getattr(response.choices[0].delta, "reasoning_content") is not None:
+ chunk_text = getattr(response.choices[0].delta, "reasoning_content", "")
+ if chunk_text is None:
+ chunk_text = ""
+ return {"type": "reasoning", "data": chunk_text}
+ else:
+ chunk_text = getattr(response.choices[0].delta, "content", "")
+ if chunk_text is None:
+ chunk_text = ""
+ return {"type": "response", "data": chunk_text}
+
+ return {"reasoning": getattr(response.choices[0].message, "reasoning_content", ""),
+ "response": getattr(response.choices[0].message, "content", "")}
+
+
+ class OpenRouterInferenceEngine(OpenAICompatibleInferenceEngine):
+ def __init__(self, model:str, api_key:str=None, base_url:str="https://openrouter.ai/api/v1", config:LLMConfig=None, **kwrs):
+ """
+ OpenRouter OpenAI-compatible server inference engine.
+
+ Parameters:
+ ----------
+ model_name : str
+ model name as shown in the vLLM server
+ api_key : str, Optional
+ the API key for the vLLM server. If None, will use the key in os.environ['OPENROUTER_API_KEY'].
+ base_url : str, Optional
+ the base url for the vLLM server.
+ config : LLMConfig
+ the LLM configuration.
+ """
+ self.api_key = api_key
+ if self.api_key is None:
+ self.api_key = os.getenv("OPENROUTER_API_KEY")
+ super().__init__(model, self.api_key, base_url, config, **kwrs)
+
+ def _format_response(self, response: Any) -> Dict[str, str]:
+ """
+ This method format the response from OpenAI API to a dict with keys "type" and "data".
+
+ Parameters:
+ ----------
+ response : Any
+ the response from OpenAI-compatible API. Could be a dict, generator, or object.
+ """
+ if isinstance(response, self.ChatCompletionChunk):
+ if hasattr(response.choices[0].delta, "reasoning") and getattr(response.choices[0].delta, "reasoning") is not None:
+ chunk_text = getattr(response.choices[0].delta, "reasoning", "")
+ if chunk_text is None:
+ chunk_text = ""
+ return {"type": "reasoning", "data": chunk_text}
+ else:
+ chunk_text = getattr(response.choices[0].delta, "content", "")
+ if chunk_text is None:
+ chunk_text = ""
+ return {"type": "response", "data": chunk_text}
+
+ return {"reasoning": getattr(response.choices[0].message, "reasoning", ""),
+ "response": getattr(response.choices[0].message, "content", "")}
+
 
  class OpenAIInferenceEngine(InferenceEngine):
  def __init__(self, model:str, config:LLMConfig=None, **kwrs):
  """
- The OpenAI API inference engine. Supports OpenAI models and OpenAI compatible servers:
- - vLLM OpenAI compatible server (https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html)
- - Llama.cpp OpenAI compatible server (https://llama-cpp-python.readthedocs.io/en/latest/server/)
-
+ The OpenAI API inference engine.
  For parameters and documentation, refer to https://platform.openai.com/docs/api-reference/introduction
 
  Parameters:
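The hunk above introduces `OpenAICompatibleInferenceEngine` and two concrete subclasses; the main difference between them is which delta field carries the reasoning stream (`reasoning_content` for vLLM, `reasoning` for OpenRouter). A construction sketch; the model identifiers are placeholders and a running server or valid API key is assumed:

```python
from llm_ie.engines import (
    VLLMInferenceEngine,
    OpenRouterInferenceEngine,
    ReasoningLLMConfig,
)

# vLLM OpenAI-compatible server on the default local port.
vllm_engine = VLLMInferenceEngine(
    model="Qwen/Qwen3-8B",
    base_url="http://localhost:8000/v1",
    config=ReasoningLLMConfig(),
)

# OpenRouter: with api_key=None the key is read from OPENROUTER_API_KEY.
openrouter_engine = OpenRouterInferenceEngine(
    model="deepseek/deepseek-r1",
    config=ReasoningLLMConfig(),
)

out = vllm_engine.chat([{"role": "user", "content": "Extract all medication names."}])
print(out["reasoning"])   # filled from reasoning_content or <think> blocks, when present
print(out["response"])
```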
@@ -645,6 +1122,7 @@ class OpenAIInferenceEngine(InferenceEngine):
  raise ImportError("OpenAI Python API library not found. Please install OpanAI (```pip install openai```).")
 
  from openai import OpenAI, AsyncOpenAI
+ super().__init__(config)
  self.client = OpenAI(**kwrs)
  self.async_client = AsyncOpenAI(**kwrs)
  self.model = model
@@ -662,7 +1140,7 @@ class OpenAIInferenceEngine(InferenceEngine):
 
  return formatted_params
 
- def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False) -> Union[str, Generator[Dict[str, str], None, None]]:
+ def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False, messages_logger:MessagesLogger=None) -> Union[Dict[str, str], Generator[Dict[str, str], None, None]]:
  """
  This method inputs chat messages and outputs LLM generated text.
 
@@ -674,6 +1152,13 @@ class OpenAIInferenceEngine(InferenceEngine):
  if True, VLM generated text will be printed in terminal in real-time.
  stream : bool, Optional
  if True, returns a generator that yields the output in real-time.
+ messages_logger : MessagesLogger, Optional
+ the message logger that logs the chat messages.
+
+ Returns:
+ -------
+ response : Union[Dict[str,str], Generator[Dict[str, str], None, None]]
+ a dict {"reasoning": <reasoning>, "response": <response>} or Generator {"type": <reasoning or response>, "data": <content>}
  """
  processed_messages = self.config.preprocess_messages(messages)
 
@@ -685,13 +1170,25 @@ class OpenAIInferenceEngine(InferenceEngine):
  stream=True,
  **self.formatted_params
  )
+ res_text = ""
  for chunk in response_stream:
  if len(chunk.choices) > 0:
- if chunk.choices[0].delta.content is not None:
- yield chunk.choices[0].delta.content
+ chunk_text = chunk.choices[0].delta.content
+ if chunk_text is not None:
+ res_text += chunk_text
+ yield chunk_text
  if chunk.choices[0].finish_reason == "length":
  warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
 
+ # Postprocess response
+ res_dict = self.config.postprocess_response(res_text)
+ # Write to messages log
+ if messages_logger:
+ processed_messages.append({"role": "assistant",
+ "content": res_dict.get("response", ""),
+ "reasoning": res_dict.get("reasoning", "")})
+ messages_logger.log_messages(processed_messages)
+
  return self.config.postprocess_response(_stream_generator())
 
  elif verbose:
@@ -711,7 +1208,7 @@ class OpenAIInferenceEngine(InferenceEngine):
  warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
 
  print('\n')
- return self.config.postprocess_response(res)
+
  else:
  response = self.client.chat.completions.create(
  model=self.model,
@@ -720,10 +1217,20 @@ class OpenAIInferenceEngine(InferenceEngine):
  **self.formatted_params
  )
  res = response.choices[0].message.content
- return self.config.postprocess_response(res)
+
+ # Postprocess response
+ res_dict = self.config.postprocess_response(res)
+ # Write to messages log
+ if messages_logger:
+ processed_messages.append({"role": "assistant",
+ "content": res_dict.get("response", ""),
+ "reasoning": res_dict.get("reasoning", "")})
+ messages_logger.log_messages(processed_messages)
+
+ return res_dict
 
 
- async def chat_async(self, messages:List[Dict[str,str]]) -> str:
+ async def chat_async(self, messages:List[Dict[str,str]], messages_logger:MessagesLogger=None) -> Dict[str,str]:
  """
  Async version of chat method. Streaming is not supported.
  """
@@ -740,7 +1247,16 @@ class OpenAIInferenceEngine(InferenceEngine):
  warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
 
  res = response.choices[0].message.content
- return self.config.postprocess_response(res)
+ # Postprocess response
+ res_dict = self.config.postprocess_response(res)
+ # Write to messages log
+ if messages_logger:
+ processed_messages.append({"role": "assistant",
+ "content": res_dict.get("response", ""),
+ "reasoning": res_dict.get("reasoning", "")})
+ messages_logger.log_messages(processed_messages)
+
+ return res_dict
 
 
  class AzureOpenAIInferenceEngine(OpenAIInferenceEngine):
@@ -795,6 +1311,7 @@ class LiteLLMInferenceEngine(InferenceEngine):
  raise ImportError("litellm not found. Please install litellm (```pip install litellm```).")
 
  import litellm
+ super().__init__(config)
  self.litellm = litellm
  self.model = model
  self.base_url = base_url
@@ -813,7 +1330,7 @@ class LiteLLMInferenceEngine(InferenceEngine):
 
  return formatted_params
 
- def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False) -> Union[str, Generator[Dict[str, str], None, None]]:
+ def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False, messages_logger:MessagesLogger=None) -> Union[Dict[str,str], Generator[Dict[str, str], None, None]]:
  """
  This method inputs chat messages and outputs LLM generated text.
 
@@ -825,6 +1342,13 @@ class LiteLLMInferenceEngine(InferenceEngine):
  if True, VLM generated text will be printed in terminal in real-time.
  stream : bool, Optional
  if True, returns a generator that yields the output in real-time.
+ messages_logger: MessagesLogger, Optional
+ a messages logger that logs the messages.
+
+ Returns:
+ -------
+ response : Union[Dict[str,str], Generator[Dict[str, str], None, None]]
+ a dict {"reasoning": <reasoning>, "response": <response>} or Generator {"type": <reasoning or response>, "data": <content>}
  """
  processed_messages = self.config.preprocess_messages(messages)
 
@@ -838,12 +1362,22 @@ class LiteLLMInferenceEngine(InferenceEngine):
  api_key=self.api_key,
  **self.formatted_params
  )
-
+ res_text = ""
  for chunk in response_stream:
  chunk_content = chunk.get('choices')[0].get('delta').get('content')
  if chunk_content:
+ res_text += chunk_content
  yield chunk_content
 
+ # Postprocess response
+ res_dict = self.config.postprocess_response(res_text)
+ # Write to messages log
+ if messages_logger:
+ processed_messages.append({"role": "assistant",
+ "content": res_dict.get("response", ""),
+ "reasoning": res_dict.get("reasoning", "")})
+ messages_logger.log_messages(processed_messages)
+
  return self.config.postprocess_response(_stream_generator())
 
  elif verbose:
@@ -862,8 +1396,6 @@ class LiteLLMInferenceEngine(InferenceEngine):
  if chunk_content:
  res += chunk_content
  print(chunk_content, end='', flush=True)
-
- return self.config.postprocess_response(res)
 
  else:
  response = self.litellm.completion(
@@ -875,9 +1407,19 @@ class LiteLLMInferenceEngine(InferenceEngine):
  **self.formatted_params
  )
  res = response.choices[0].message.content
- return self.config.postprocess_response(res)
+
+ # Postprocess response
+ res_dict = self.config.postprocess_response(res)
+ # Write to messages log
+ if messages_logger:
+ processed_messages.append({"role": "assistant",
+ "content": res_dict.get("response", ""),
+ "reasoning": res_dict.get("reasoning", "")})
+ messages_logger.log_messages(processed_messages)
+
+ return res_dict
 
- async def chat_async(self, messages:List[Dict[str,str]]) -> str:
+ async def chat_async(self, messages:List[Dict[str,str]], messages_logger:MessagesLogger=None) -> Dict[str,str]:
  """
  Async version of chat method. Streaming is not supported.
  """
@@ -893,4 +1435,13 @@ class LiteLLMInferenceEngine(InferenceEngine):
  )
 
  res = response.get('choices')[0].get('message').get('content')
- return self.config.postprocess_response(res)
+
+ # Postprocess response
+ res_dict = self.config.postprocess_response(res)
+ # Write to messages log
+ if messages_logger:
+ processed_messages.append({"role": "assistant",
+ "content": res_dict.get("response", ""),
+ "reasoning": res_dict.get("reasoning", "")})
+ messages_logger.log_messages(processed_messages)
+ return res_dict
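A closing sketch that ties the 1.2.3 pieces together: an engine, a reasoning-aware config, and a `MessagesLogger` in one call. The model name is illustrative and the OpenAI client is assumed to pick up `OPENAI_API_KEY` from the environment:

```python
from llm_ie.engines import OpenAIInferenceEngine, OpenAIReasoningLLMConfig, MessagesLogger

engine = OpenAIInferenceEngine(
    model="o4-mini",
    config=OpenAIReasoningLLMConfig(reasoning_effort="low"),
)
logger = MessagesLogger()

result = engine.chat(
    messages=[{"role": "user", "content": "List the lab values in: WBC 9.1, Hgb 13.2."}],
    messages_logger=logger,
)
print(result["response"])
print(len(logger.get_messages_log()))  # one logged conversation
```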