llm_ie-1.2.2-py3-none-any.whl → llm_ie-1.2.3-py3-none-any.whl
- llm_ie/__init__.py +5 -4
- llm_ie/chunkers.py +44 -5
- llm_ie/data_types.py +23 -37
- llm_ie/engines.py +577 -61
- llm_ie/extractors.py +335 -219
- {llm_ie-1.2.2.dist-info → llm_ie-1.2.3.dist-info}/METADATA +1 -1
- {llm_ie-1.2.2.dist-info → llm_ie-1.2.3.dist-info}/RECORD +8 -8
- {llm_ie-1.2.2.dist-info → llm_ie-1.2.3.dist-info}/WHEEL +0 -0
llm_ie/engines.py (CHANGED)
```diff
@@ -1,4 +1,5 @@
 import abc
+import os
 import re
 import warnings
 import importlib.util
@@ -33,13 +34,13 @@ class LLMConfig(abc.ABC):
         return NotImplemented

     @abc.abstractmethod
-    def postprocess_response(self, response:Union[str, Generator[str, None, None]]) -> Union[Dict[str,str], Generator[Dict[str, str], None, None]]:
+    def postprocess_response(self, response:Union[str, Dict[str, str], Generator[str, None, None]]) -> Union[Dict[str,str], Generator[Dict[str, str], None, None]]:
         """
         This method postprocesses the LLM response after it is generated.

         Parameters:
         ----------
-        response : Union[str, Generator[Dict[str, str], None, None]]
+        response : Union[str, Dict[str, str], Generator[Dict[str, str], None, None]]
             the LLM response. Can be a dict or a generator.

         Returns:
@@ -75,15 +76,15 @@ class BasicLLMConfig(LLMConfig):
         messages : List[Dict[str,str]]
             a list of dict with role and content. role must be one of {"system", "user", "assistant"}
         """
-        return messages
+        return messages.copy()

-    def postprocess_response(self, response:Union[str, Generator[str, None, None]]) -> Union[Dict[str,str], Generator[Dict[str, str], None, None]]:
+    def postprocess_response(self, response:Union[str, Dict[str, str], Generator[str, None, None]]) -> Union[Dict[str,str], Generator[Dict[str, str], None, None]]:
         """
         This method postprocesses the LLM response after it is generated.

         Parameters:
         ----------
-        response : Union[str, Generator[str, None, None]]
+        response : Union[str, Dict[str, str], Generator[str, None, None]]
             the LLM response. Can be a string or a generator.

         Returns: Union[Dict[str,str], Generator[Dict[str, str], None, None]]
@@ -93,13 +94,27 @@ class BasicLLMConfig(LLMConfig):
         """
         if isinstance(response, str):
             return {"response": response}
+
+        elif isinstance(response, dict):
+            if "response" in response:
+                return response
+            else:
+                warnings.warn(f"Invalid response dict keys: {response.keys()}. Returning default empty dict.", UserWarning)
+                return {"response": ""}

-
-
-
+        elif isinstance(response, Generator):
+            def _process_stream():
+                for chunk in response:
+                    if isinstance(chunk, dict):
+                        yield chunk
+                    elif isinstance(chunk, str):
+                        yield {"type": "response", "data": chunk}

-
+            return _process_stream()

+        else:
+            warnings.warn(f"Invalid response type: {type(response)}. Returning default empty dict.", UserWarning)
+            return {"response": ""}

 class ReasoningLLMConfig(LLMConfig):
     def __init__(self, thinking_token_start="<think>", thinking_token_end="</think>", **kwargs):
```
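For orientation, `BasicLLMConfig.postprocess_response` now accepts three input shapes instead of one. A minimal sketch of how a caller might exercise the new branches (illustrative only; the import path `llm_ie.engines` and the key names follow the diff above):

```python
from llm_ie.engines import BasicLLMConfig

config = BasicLLMConfig()

# String input: wrapped into a dict
print(config.postprocess_response("Hello"))             # {"response": "Hello"}

# Dict input: passed through when it already carries a "response" key
print(config.postprocess_response({"response": "Hi"}))  # {"response": "Hi"}

# Generator input: string chunks are re-yielded as {"type": "response", "data": ...}
chunks = (c for c in ["He", "llo"])
for item in config.postprocess_response(chunks):
    print(item)  # {"type": "response", "data": "He"}, then {"type": "response", "data": "llo"}
```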
```diff
@@ -124,11 +139,16 @@ class ReasoningLLMConfig(LLMConfig):
         messages : List[Dict[str,str]]
             a list of dict with role and content. role must be one of {"system", "user", "assistant"}
         """
-        return messages
+        return messages.copy()

-    def postprocess_response(self, response:Union[str, Generator[str, None, None]]) -> Union[Dict[str,str], Generator[Dict[str,str], None, None]]:
+    def postprocess_response(self, response:Union[str, Dict[str, str], Generator[str, None, None]]) -> Union[Dict[str,str], Generator[Dict[str,str], None, None]]:
         """
         This method postprocesses the LLM response after it is generated.
+        1. If input is a string, it will extract the reasoning and response based on the thinking tokens.
+        2. If input is a dict, it should contain keys "reasoning" and "response". This is for inference engines that already parse reasoning and response.
+        3. If input is a generator,
+            a. if the chunk is a dict, it should contain keys "type" and "data". This is for inference engines that already parse reasoning and response.
+            b. if the chunk is a string, it will yield dicts with keys "type" and "data" based on the thinking tokens.

         Parameters:
         ----------
@@ -143,18 +163,29 @@ class ReasoningLLMConfig(LLMConfig):
         """
         if isinstance(response, str):
             # get contents between thinking_token_start and thinking_token_end
-
-
+            pattern = f"{re.escape(self.thinking_token_start)}(.*?){re.escape(self.thinking_token_end)}"
+            match = re.search(pattern, response, re.DOTALL)
+            reasoning = match.group(1) if match else ""
             # get response AFTER thinking_token_end
             response = re.sub(f".*?{self.thinking_token_end}", "", response, flags=re.DOTALL).strip()
             return {"reasoning": reasoning, "response": response}

-
+        elif isinstance(response, dict):
+            if "reasoning" in response and "response" in response:
+                return response
+            else:
+                warnings.warn(f"Invalid response dict keys: {response.keys()}. Returning default empty dict.", UserWarning)
+                return {"reasoning": "", "response": ""}
+
+        elif isinstance(response, Generator):
             def _process_stream():
                 think_flag = False
                 buffer = ""
                 for chunk in response:
-                    if isinstance(chunk,
+                    if isinstance(chunk, dict):
+                        yield chunk
+
+                    elif isinstance(chunk, str):
                         buffer += chunk
                         # switch between reasoning and response
                         if self.thinking_token_start in buffer:
@@ -173,6 +204,9 @@ class ReasoningLLMConfig(LLMConfig):

             return _process_stream()

+        else:
+            warnings.warn(f"Invalid response type: {type(response)}. Returning default empty dict.", UserWarning)
+            return {"reasoning": "", "response": ""}

 class Qwen3LLMConfig(ReasoningLLMConfig):
     def __init__(self, thinking_mode:bool=True, **kwargs):
```
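The string branch of `ReasoningLLMConfig.postprocess_response` splits a raw completion around the thinking tokens, while the new dict branch simply validates and passes through pre-split output. A minimal sketch, assuming the default `<think>`/`</think>` tokens from the diff above:

```python
from llm_ie.engines import ReasoningLLMConfig

config = ReasoningLLMConfig()  # defaults: thinking_token_start="<think>", thinking_token_end="</think>"

# String input: reasoning is taken from between the thinking tokens,
# and the response is whatever follows the closing token.
out = config.postprocess_response("<think>step 1, step 2</think>The answer is 42.")
print(out)  # {"reasoning": "step 1, step 2", "response": "The answer is 42."}

# Dict input (engines that already separate reasoning and response) is passed through unchanged.
out = config.postprocess_response({"reasoning": "step 1", "response": "42"})
print(out)
```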
```diff
@@ -279,6 +313,32 @@ class OpenAIReasoningLLMConfig(ReasoningLLMConfig):
         return new_messages


+class MessagesLogger:
+    def __init__(self):
+        """
+        This class is used to log the messages for InferenceEngine.chat().
+        """
+        self.messages_log = []
+
+    def log_messages(self, messages : List[Dict[str,str]]):
+        """
+        This method logs the messages to a list.
+        """
+        self.messages_log.append(messages)
+
+    def get_messages_log(self) -> List[List[Dict[str,str]]]:
+        """
+        This method returns a copy of the current messages log
+        """
+        return self.messages_log.copy()
+
+    def clear_messages_log(self):
+        """
+        This method clears the current messages log
+        """
+        self.messages_log.clear()
+
+
 class InferenceEngine:
     @abc.abstractmethod
     def __init__(self, config:LLMConfig, **kwrs):
@@ -293,10 +353,16 @@ class InferenceEngine:
         """
         return NotImplemented

+    def get_messages_log(self) -> List[List[Dict[str,str]]]:
+        return self.messages_log.copy()
+
+    def clear_messages_log(self):
+        self.messages_log = []
+

     @abc.abstractmethod
-    def chat(self, messages:List[Dict[str,str]],
-
+    def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False,
+             messages_logger:MessagesLogger=None) -> Union[Dict[str,str], Generator[Dict[str, str], None, None]]:
         """
         This method inputs chat messages and outputs LLM generated text.

@@ -308,6 +374,8 @@ class InferenceEngine:
             if True, LLM generated text will be printed in terminal in real-time.
         stream : bool, Optional
             if True, returns a generator that yields the output in real-time.
+        Messages_logger : MessagesLogger, Optional
+            the message logger that logs the chat messages.

         Returns:
         -------
```
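The new `MessagesLogger` threads through every engine's `chat()`/`chat_async()`: after post-processing, the engine appends an assistant turn of the form `{"role": "assistant", "content": ..., "reasoning": ...}` to the processed messages and hands the whole conversation to the logger. A hedged usage sketch (the choice of engine, its constructor arguments, and the model name are placeholders; the logger API is taken from the diff above):

```python
from llm_ie.engines import OllamaInferenceEngine, ReasoningLLMConfig, MessagesLogger

# Hypothetical setup: any InferenceEngine subclass accepts messages_logger the same way.
engine = OllamaInferenceEngine(model_name="qwen3:8b", config=ReasoningLLMConfig())
logger = MessagesLogger()

engine.chat(
    messages=[{"role": "user", "content": "Summarize this note ..."}],
    messages_logger=logger,
)

# Each logged item is the full message list for one call, ending with the
# assistant turn: {"role": "assistant", "content": <response>, "reasoning": <reasoning>}
log = logger.get_messages_log()
print(len(log), log[0][-1]["role"])
logger.clear_messages_log()
```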
```diff
@@ -346,6 +414,7 @@ class LlamaCppInferenceEngine(InferenceEngine):
             the LLM configuration.
         """
         from llama_cpp import Llama
+        super().__init__(config)
         self.repo_id = repo_id
         self.gguf_filename = gguf_filename
         self.n_ctx = n_ctx
@@ -378,7 +447,7 @@ class LlamaCppInferenceEngine(InferenceEngine):

         return formatted_params

-    def chat(self, messages:List[Dict[str,str]], verbose:bool=False) -> Dict[str,str]:
+    def chat(self, messages:List[Dict[str,str]], verbose:bool=False, messages_logger:MessagesLogger=None) -> Dict[str,str]:
         """
         This method inputs chat messages and outputs LLM generated text.

@@ -388,15 +457,18 @@ class LlamaCppInferenceEngine(InferenceEngine):
             a list of dict with role and content. role must be one of {"system", "user", "assistant"}
         verbose : bool, Optional
             if True, LLM generated text will be printed in terminal in real-time.
+        messages_logger : MessagesLogger, Optional
+            the message logger that logs the chat messages.
         """
+        # Preprocess messages
         processed_messages = self.config.preprocess_messages(messages)
-
+        # Generate response
         response = self.model.create_chat_completion(
             messages=processed_messages,
             stream=verbose,
             **self.formatted_params
         )
-
+
         if verbose:
             res = ''
             for chunk in response:
@@ -408,7 +480,16 @@ class LlamaCppInferenceEngine(InferenceEngine):
             return self.config.postprocess_response(res)

         res = response['choices'][0]['message']['content']
-
+        # Postprocess response
+        res_dict = self.config.postprocess_response(res)
+        # Write to messages log
+        if messages_logger:
+            processed_messages.append({"role": "assistant",
+                                       "content": res_dict.get("response", ""),
+                                       "reasoning": res_dict.get("reasoning", "")})
+            messages_logger.log_messages(processed_messages)
+
+        return res_dict


 class OllamaInferenceEngine(InferenceEngine):
@@ -431,6 +512,7 @@ class OllamaInferenceEngine(InferenceEngine):
             raise ImportError("ollama-python not found. Please install ollama-python (```pip install ollama```).")

         from ollama import Client, AsyncClient
+        super().__init__(config)
         self.client = Client(**kwrs)
         self.async_client = AsyncClient(**kwrs)
         self.model_name = model_name
@@ -450,8 +532,8 @@ class OllamaInferenceEngine(InferenceEngine):

         return formatted_params

-    def chat(self, messages:List[Dict[str,str]],
-
+    def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False,
+             messages_logger:MessagesLogger=None) -> Union[Dict[str,str], Generator[Dict[str, str], None, None]]:
         """
         This method inputs chat messages and outputs VLM generated text.

@@ -463,6 +545,8 @@ class OllamaInferenceEngine(InferenceEngine):
             if True, VLM generated text will be printed in terminal in real-time.
         stream : bool, Optional
             if True, returns a generator that yields the output in real-time.
+        Messages_logger : MessagesLogger, Optional
+            the message logger that logs the chat messages.

         Returns:
         -------
@@ -481,10 +565,28 @@ class OllamaInferenceEngine(InferenceEngine):
                     stream=True,
                     keep_alive=self.keep_alive
                 )
+                res = {"reasoning": "", "response": ""}
                 for chunk in response_stream:
-
-
-
+                    if hasattr(chunk.message, 'thinking') and chunk.message.thinking:
+                        content_chunk = getattr(getattr(chunk, 'message', {}), 'thinking', '')
+                        res["reasoning"] += content_chunk
+                        yield {"type": "reasoning", "data": content_chunk}
+                    else:
+                        content_chunk = getattr(getattr(chunk, 'message', {}), 'content', '')
+                        res["response"] += content_chunk
+                        yield {"type": "response", "data": content_chunk}
+
+                    if chunk.done_reason == "length":
+                        warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
+
+                # Postprocess response
+                res_dict = self.config.postprocess_response(res)
+                # Write to messages log
+                if messages_logger:
+                    processed_messages.append({"role": "assistant",
+                                               "content": res_dict.get("response", ""),
+                                               "reasoning": res_dict.get("reasoning", "")})
+                    messages_logger.log_messages(processed_messages)

             return self.config.postprocess_response(_stream_generator())

@@ -497,14 +599,29 @@ class OllamaInferenceEngine(InferenceEngine):
                 keep_alive=self.keep_alive
             )

-            res =
+            res = {"reasoning": "", "response": ""}
+            phase = ""
             for chunk in response:
-
+                if hasattr(chunk.message, 'thinking') and chunk.message.thinking:
+                    if phase != "reasoning":
+                        print("\n--- Reasoning ---")
+                        phase = "reasoning"
+
+                    content_chunk = getattr(getattr(chunk, 'message', {}), 'thinking', '')
+                    res["reasoning"] += content_chunk
+                else:
+                    if phase != "response":
+                        print("\n--- Response ---")
+                        phase = "response"
+                    content_chunk = getattr(getattr(chunk, 'message', {}), 'content', '')
+                    res["response"] += content_chunk
+
                 print(content_chunk, end='', flush=True)
-
+
+                if chunk.done_reason == "length":
+                    warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
             print('\n')
-
-
+
         else:
             response = self.client.chat(
                 model=self.model_name,
@@ -513,11 +630,25 @@ class OllamaInferenceEngine(InferenceEngine):
                 stream=False,
                 keep_alive=self.keep_alive
             )
-            res = response
-
+            res = {"reasoning": getattr(getattr(response, 'message', {}), 'thinking', ''),
+                   "response": getattr(getattr(response, 'message', {}), 'content', '')}
+
+            if response.done_reason == "length":
+                warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
+
+        # Postprocess response
+        res_dict = self.config.postprocess_response(res)
+        # Write to messages log
+        if messages_logger:
+            processed_messages.append({"role": "assistant",
+                                       "content": res_dict.get("response", ""),
+                                       "reasoning": res_dict.get("reasoning", "")})
+            messages_logger.log_messages(processed_messages)
+
+        return res_dict


-    async def chat_async(self, messages:List[Dict[str,str]]) -> Dict[str,str]:
+    async def chat_async(self, messages:List[Dict[str,str]], messages_logger:MessagesLogger=None) -> Dict[str,str]:
         """
         Async version of chat method. Streaming is not supported.
         """
@@ -531,8 +662,21 @@ class OllamaInferenceEngine(InferenceEngine):
             keep_alive=self.keep_alive
         )

-        res = response
-
+        res = {"reasoning": getattr(getattr(response, 'message', {}), 'thinking', ''),
+               "response": getattr(getattr(response, 'message', {}), 'content', '')}
+
+        if response.done_reason == "length":
+            warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
+        # Postprocess response
+        res_dict = self.config.postprocess_response(res)
+        # Write to messages log
+        if messages_logger:
+            processed_messages.append({"role": "assistant",
+                                       "content": res_dict.get("response", ""),
+                                       "reasoning": res_dict.get("reasoning", "")})
+            messages_logger.log_messages(processed_messages)
+
+        return res_dict


 class HuggingFaceHubInferenceEngine(InferenceEngine):
```
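These Ollama hunks are the consumers of the new dict branch in `postprocess_response`: the engine now separates `message.thinking` from `message.content` itself and hands the config a pre-split dict rather than raw text. A small, runnable sketch of that data flow, using a stand-in object in place of a real ollama-python response (the `thinking`/`content` field names follow the diff above; the values are hypothetical):

```python
from types import SimpleNamespace
from llm_ie.engines import ReasoningLLMConfig

# Stand-in for an ollama-python response whose message already carries a separate
# "thinking" field alongside the visible answer.
response = SimpleNamespace(
    message=SimpleNamespace(thinking="compare the two dates", content="2019-03-01")
)

# Shape of what OllamaInferenceEngine now builds in the non-streaming path.
res = {"reasoning": getattr(response.message, "thinking", ""),
       "response": getattr(response.message, "content", "")}

# Because the engine pre-splits reasoning and response, the config's dict branch
# validates the keys and passes the dict through unchanged.
print(ReasoningLLMConfig().postprocess_response(res))
# {"reasoning": "compare the two dates", "response": "2019-03-01"}
```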
```diff
@@ -558,6 +702,7 @@ class HuggingFaceHubInferenceEngine(InferenceEngine):
             raise ImportError("huggingface-hub not found. Please install huggingface-hub (```pip install huggingface-hub```).")

         from huggingface_hub import InferenceClient, AsyncInferenceClient
+        super().__init__(config)
         self.model = model
         self.base_url = base_url
         self.client = InferenceClient(model=model, token=token, base_url=base_url, api_key=api_key, **kwrs)
@@ -577,8 +722,8 @@ class HuggingFaceHubInferenceEngine(InferenceEngine):
         return formatted_params


-    def chat(self, messages:List[Dict[str,str]],
-
+    def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False,
+             messages_logger:MessagesLogger=None) -> Union[Dict[str,str], Generator[Dict[str, str], None, None]]:
         """
         This method inputs chat messages and outputs LLM generated text.

@@ -590,7 +735,9 @@ class HuggingFaceHubInferenceEngine(InferenceEngine):
             if True, VLM generated text will be printed in terminal in real-time.
         stream : bool, Optional
             if True, returns a generator that yields the output in real-time.
-
+        messages_logger : MessagesLogger, Optional
+            the message logger that logs the chat messages.
+
         Returns:
         -------
         response : Union[Dict[str,str], Generator[Dict[str, str], None, None]]
@@ -605,11 +752,22 @@ class HuggingFaceHubInferenceEngine(InferenceEngine):
                     stream=True,
                     **self.formatted_params
                 )
+                res_text = ""
                 for chunk in response_stream:
                     content_chunk = chunk.get('choices')[0].get('delta').get('content')
                     if content_chunk:
+                        res_text += content_chunk
                         yield content_chunk

+                # Postprocess response
+                res_dict = self.config.postprocess_response(res_text)
+                # Write to messages log
+                if messages_logger:
+                    processed_messages.append({"role": "assistant",
+                                               "content": res_dict.get("response", ""),
+                                               "reasoning": res_dict.get("reasoning", "")})
+                    messages_logger.log_messages(processed_messages)
+
             return self.config.postprocess_response(_stream_generator())

         elif verbose:
@@ -625,7 +783,7 @@ class HuggingFaceHubInferenceEngine(InferenceEngine):
                 if content_chunk:
                     res += content_chunk
                     print(content_chunk, end='', flush=True)
-
+

         else:
             response = self.client.chat.completions.create(
@@ -634,9 +792,20 @@ class HuggingFaceHubInferenceEngine(InferenceEngine):
                 **self.formatted_params
             )
             res = response.choices[0].message.content
-
+
+        # Postprocess response
+        res_dict = self.config.postprocess_response(res)
+        # Write to messages log
+        if messages_logger:
+            processed_messages.append({"role": "assistant",
+                                       "content": res_dict.get("response", ""),
+                                       "reasoning": res_dict.get("reasoning", "")})
+            messages_logger.log_messages(processed_messages)
+
+        return res_dict

-
+
+    async def chat_async(self, messages:List[Dict[str,str]], messages_logger:MessagesLogger=None) -> Dict[str,str]:
         """
         Async version of chat method. Streaming is not supported.
         """
@@ -649,16 +818,299 @@ class HuggingFaceHubInferenceEngine(InferenceEngine):
         )

         res = response.choices[0].message.content
-
+        # Postprocess response
+        res_dict = self.config.postprocess_response(res)
+        # Write to messages log
+        if messages_logger:
+            processed_messages.append({"role": "assistant",
+                                       "content": res_dict.get("response", ""),
+                                       "reasoning": res_dict.get("reasoning", "")})
+            messages_logger.log_messages(processed_messages)
+
+        return res_dict
+
+
+class OpenAICompatibleInferenceEngine(InferenceEngine):
+    def __init__(self, model:str, api_key:str, base_url:str, config:LLMConfig=None, **kwrs):
+        """
+        General OpenAI-compatible server inference engine.
+        https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html
+
+        For parameters and documentation, refer to https://platform.openai.com/docs/api-reference/introduction
+
+        Parameters:
+        ----------
+        model_name : str
+            model name as shown in the vLLM server
+        api_key : str
+            the API key for the vLLM server.
+        base_url : str
+            the base url for the vLLM server.
+        config : LLMConfig
+            the LLM configuration.
+        """
+        if importlib.util.find_spec("openai") is None:
+            raise ImportError("OpenAI Python API library not found. Please install OpanAI (```pip install openai```).")
+
+        from openai import OpenAI, AsyncOpenAI
+        from openai.types.chat import ChatCompletionChunk
+        self.ChatCompletionChunk = ChatCompletionChunk
+        super().__init__(config)
+        self.client = OpenAI(api_key=api_key, base_url=base_url, **kwrs)
+        self.async_client = AsyncOpenAI(api_key=api_key, base_url=base_url, **kwrs)
+        self.model = model
+        self.config = config if config else BasicLLMConfig()
+        self.formatted_params = self._format_config()
+
+    def _format_config(self) -> Dict[str, Any]:
+        """
+        This method format the LLM configuration with the correct key for the inference engine.
+        """
+        formatted_params = self.config.params.copy()
+        if "max_new_tokens" in formatted_params:
+            formatted_params["max_completion_tokens"] = formatted_params["max_new_tokens"]
+            formatted_params.pop("max_new_tokens")
+
+        return formatted_params
+
+    @abc.abstractmethod
+    def _format_response(self, response: Any) -> Dict[str, str]:
+        """
+        This method format the response from OpenAI API to a dict with keys "type" and "data".
+
+        Parameters:
+        ----------
+        response : Any
+            the response from OpenAI-compatible API. Could be a dict, generator, or object.
+        """
+        return NotImplemented
+
+    def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False, messages_logger:MessagesLogger=None) -> Union[Dict[str, str], Generator[Dict[str, str], None, None]]:
+        """
+        This method inputs chat messages and outputs LLM generated text.
+
+        Parameters:
+        ----------
+        messages : List[Dict[str,str]]
+            a list of dict with role and content. role must be one of {"system", "user", "assistant"}
+        verbose : bool, Optional
+            if True, VLM generated text will be printed in terminal in real-time.
+        stream : bool, Optional
+            if True, returns a generator that yields the output in real-time.
+        messages_logger : MessagesLogger, Optional
+            the message logger that logs the chat messages.
+
+        Returns:
+        -------
+        response : Union[Dict[str,str], Generator[Dict[str, str], None, None]]
+            a dict {"reasoning": <reasoning>, "response": <response>} or Generator {"type": <reasoning or response>, "data": <content>}
+        """
+        processed_messages = self.config.preprocess_messages(messages)
+
+        if stream:
+            def _stream_generator():
+                response_stream = self.client.chat.completions.create(
+                    model=self.model,
+                    messages=processed_messages,
+                    stream=True,
+                    **self.formatted_params
+                )
+                res_text = ""
+                for chunk in response_stream:
+                    if len(chunk.choices) > 0:
+                        chunk_dict = self._format_response(chunk)
+                        yield chunk_dict
+
+                        res_text += chunk_dict["data"]
+                        if chunk.choices[0].finish_reason == "length":
+                            warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
+
+                # Postprocess response
+                res_dict = self.config.postprocess_response(res_text)
+                # Write to messages log
+                if messages_logger:
+                    processed_messages.append({"role": "assistant",
+                                               "content": res_dict.get("response", ""),
+                                               "reasoning": res_dict.get("reasoning", "")})
+                    messages_logger.log_messages(processed_messages)
+
+            return self.config.postprocess_response(_stream_generator())
+
+        elif verbose:
+            response = self.client.chat.completions.create(
+                model=self.model,
+                messages=processed_messages,
+                stream=True,
+                **self.formatted_params
+            )
+            res = {"reasoning": "", "response": ""}
+            phase = ""
+            for chunk in response:
+                if len(chunk.choices) > 0:
+                    chunk_dict = self._format_response(chunk)
+                    chunk_text = chunk_dict["data"]
+                    res[chunk_dict["type"]] += chunk_text
+                    if phase != chunk_dict["type"] and chunk_text != "":
+                        print(f"\n--- {chunk_dict['type'].capitalize()} ---")
+                        phase = chunk_dict["type"]
+
+                    print(chunk_text, end="", flush=True)
+                    if chunk.choices[0].finish_reason == "length":
+                        warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
+
+            print('\n')
+
+        else:
+            response = self.client.chat.completions.create(
+                model=self.model,
+                messages=processed_messages,
+                stream=False,
+                **self.formatted_params
+            )
+            res = self._format_response(response)
+
+            if response.choices[0].finish_reason == "length":
+                warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
+
+        # Postprocess response
+        res_dict = self.config.postprocess_response(res)
+        # Write to messages log
+        if messages_logger:
+            processed_messages.append({"role": "assistant",
+                                       "content": res_dict.get("response", ""),
+                                       "reasoning": res_dict.get("reasoning", "")})
+            messages_logger.log_messages(processed_messages)
+
+        return res_dict
+
+
+    async def chat_async(self, messages:List[Dict[str,str]], messages_logger:MessagesLogger=None) -> Dict[str,str]:
+        """
+        Async version of chat method. Streaming is not supported.
+        """
+        processed_messages = self.config.preprocess_messages(messages)
+
+        response = await self.async_client.chat.completions.create(
+            model=self.model,
+            messages=processed_messages,
+            stream=False,
+            **self.formatted_params
+        )

+        if response.choices[0].finish_reason == "length":
+            warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
+
+        res = self._format_response(response)
+
+        # Postprocess response
+        res_dict = self.config.postprocess_response(res)
+        # Write to messages log
+        if messages_logger:
+            processed_messages.append({"role": "assistant",
+                                       "content": res_dict.get("response", ""),
+                                       "reasoning": res_dict.get("reasoning", "")})
+            messages_logger.log_messages(processed_messages)
+
+        return res_dict
+
+
+class VLLMInferenceEngine(OpenAICompatibleInferenceEngine):
+    def __init__(self, model:str, api_key:str="", base_url:str="http://localhost:8000/v1", config:LLMConfig=None, **kwrs):
+        """
+        vLLM OpenAI compatible server inference engine.
+        https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html
+
+        For parameters and documentation, refer to https://platform.openai.com/docs/api-reference/introduction
+
+        Parameters:
+        ----------
+        model_name : str
+            model name as shown in the vLLM server
+        api_key : str, Optional
+            the API key for the vLLM server.
+        base_url : str, Optional
+            the base url for the vLLM server.
+        config : LLMConfig
+            the LLM configuration.
+        """
+        super().__init__(model, api_key, base_url, config, **kwrs)
+
+
+    def _format_response(self, response: Any) -> Dict[str, str]:
+        """
+        This method format the response from OpenAI API to a dict with keys "type" and "data".
+
+        Parameters:
+        ----------
+        response : Any
+            the response from OpenAI-compatible API. Could be a dict, generator, or object.
+        """
+        if isinstance(response, self.ChatCompletionChunk):
+            if hasattr(response.choices[0].delta, "reasoning_content") and getattr(response.choices[0].delta, "reasoning_content") is not None:
+                chunk_text = getattr(response.choices[0].delta, "reasoning_content", "")
+                if chunk_text is None:
+                    chunk_text = ""
+                return {"type": "reasoning", "data": chunk_text}
+            else:
+                chunk_text = getattr(response.choices[0].delta, "content", "")
+                if chunk_text is None:
+                    chunk_text = ""
+                return {"type": "response", "data": chunk_text}
+
+        return {"reasoning": getattr(response.choices[0].message, "reasoning_content", ""),
+                "response": getattr(response.choices[0].message, "content", "")}
+
+
+class OpenRouterInferenceEngine(OpenAICompatibleInferenceEngine):
+    def __init__(self, model:str, api_key:str=None, base_url:str="https://openrouter.ai/api/v1", config:LLMConfig=None, **kwrs):
+        """
+        OpenRouter OpenAI-compatible server inference engine.
+
+        Parameters:
+        ----------
+        model_name : str
+            model name as shown in the vLLM server
+        api_key : str, Optional
+            the API key for the vLLM server. If None, will use the key in os.environ['OPENROUTER_API_KEY'].
+        base_url : str, Optional
+            the base url for the vLLM server.
+        config : LLMConfig
+            the LLM configuration.
+        """
+        self.api_key = api_key
+        if self.api_key is None:
+            self.api_key = os.getenv("OPENROUTER_API_KEY")
+        super().__init__(model, self.api_key, base_url, config, **kwrs)
+
+    def _format_response(self, response: Any) -> Dict[str, str]:
+        """
+        This method format the response from OpenAI API to a dict with keys "type" and "data".
+
+        Parameters:
+        ----------
+        response : Any
+            the response from OpenAI-compatible API. Could be a dict, generator, or object.
+        """
+        if isinstance(response, self.ChatCompletionChunk):
+            if hasattr(response.choices[0].delta, "reasoning") and getattr(response.choices[0].delta, "reasoning") is not None:
+                chunk_text = getattr(response.choices[0].delta, "reasoning", "")
+                if chunk_text is None:
+                    chunk_text = ""
+                return {"type": "reasoning", "data": chunk_text}
+            else:
+                chunk_text = getattr(response.choices[0].delta, "content", "")
+                if chunk_text is None:
+                    chunk_text = ""
+                return {"type": "response", "data": chunk_text}
+
+        return {"reasoning": getattr(response.choices[0].message, "reasoning", ""),
+                "response": getattr(response.choices[0].message, "content", "")}
+

 class OpenAIInferenceEngine(InferenceEngine):
     def __init__(self, model:str, config:LLMConfig=None, **kwrs):
         """
-        The OpenAI API inference engine.
-        - vLLM OpenAI compatible server (https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html)
-        - Llama.cpp OpenAI compatible server (https://llama-cpp-python.readthedocs.io/en/latest/server/)
-
+        The OpenAI API inference engine.
         For parameters and documentation, refer to https://platform.openai.com/docs/api-reference/introduction

         Parameters:
```
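The headline addition of this release: a generic `OpenAICompatibleInferenceEngine` plus two concrete subclasses that differ only in how `_format_response` reads the reasoning field (`reasoning_content` for vLLM, `reasoning` for OpenRouter). A hedged instantiation sketch (model names and the local server URL are placeholders; the constructor signatures follow the diff above):

```python
from llm_ie.engines import (
    VLLMInferenceEngine, OpenRouterInferenceEngine, ReasoningLLMConfig, MessagesLogger
)

# vLLM OpenAI-compatible server; base_url defaults to http://localhost:8000/v1
vllm_engine = VLLMInferenceEngine(model="Qwen/Qwen3-8B", config=ReasoningLLMConfig())

# OpenRouter; api_key=None falls back to os.environ["OPENROUTER_API_KEY"]
router_engine = OpenRouterInferenceEngine(model="deepseek/deepseek-r1", config=ReasoningLLMConfig())

logger = MessagesLogger()
result = vllm_engine.chat(
    messages=[{"role": "user", "content": "Extract the diagnosis from: ..."}],
    messages_logger=logger,
)
# With a ReasoningLLMConfig, the non-streaming path returns both fields.
print(result["reasoning"], result["response"], sep="\n---\n")
```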
```diff
@@ -670,6 +1122,7 @@ class OpenAIInferenceEngine(InferenceEngine):
             raise ImportError("OpenAI Python API library not found. Please install OpanAI (```pip install openai```).")

         from openai import OpenAI, AsyncOpenAI
+        super().__init__(config)
         self.client = OpenAI(**kwrs)
         self.async_client = AsyncOpenAI(**kwrs)
         self.model = model
@@ -687,7 +1140,7 @@ class OpenAIInferenceEngine(InferenceEngine):

         return formatted_params

-    def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False) -> Union[Dict[str, str], Generator[Dict[str, str], None, None]]:
+    def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False, messages_logger:MessagesLogger=None) -> Union[Dict[str, str], Generator[Dict[str, str], None, None]]:
         """
         This method inputs chat messages and outputs LLM generated text.

@@ -699,6 +1152,8 @@ class OpenAIInferenceEngine(InferenceEngine):
             if True, VLM generated text will be printed in terminal in real-time.
         stream : bool, Optional
             if True, returns a generator that yields the output in real-time.
+        messages_logger : MessagesLogger, Optional
+            the message logger that logs the chat messages.

         Returns:
         -------
@@ -715,13 +1170,25 @@ class OpenAIInferenceEngine(InferenceEngine):
                     stream=True,
                     **self.formatted_params
                 )
+                res_text = ""
                 for chunk in response_stream:
                     if len(chunk.choices) > 0:
-
-
+                        chunk_text = chunk.choices[0].delta.content
+                        if chunk_text is not None:
+                            res_text += chunk_text
+                            yield chunk_text
                         if chunk.choices[0].finish_reason == "length":
                             warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)

+                # Postprocess response
+                res_dict = self.config.postprocess_response(res_text)
+                # Write to messages log
+                if messages_logger:
+                    processed_messages.append({"role": "assistant",
+                                               "content": res_dict.get("response", ""),
+                                               "reasoning": res_dict.get("reasoning", "")})
+                    messages_logger.log_messages(processed_messages)
+
             return self.config.postprocess_response(_stream_generator())

         elif verbose:
@@ -741,7 +1208,7 @@ class OpenAIInferenceEngine(InferenceEngine):
                         warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)

             print('\n')
-
+
         else:
             response = self.client.chat.completions.create(
                 model=self.model,
@@ -750,10 +1217,20 @@ class OpenAIInferenceEngine(InferenceEngine):
                 **self.formatted_params
             )
             res = response.choices[0].message.content
-
+
+        # Postprocess response
+        res_dict = self.config.postprocess_response(res)
+        # Write to messages log
+        if messages_logger:
+            processed_messages.append({"role": "assistant",
+                                       "content": res_dict.get("response", ""),
+                                       "reasoning": res_dict.get("reasoning", "")})
+            messages_logger.log_messages(processed_messages)
+
+        return res_dict


-    async def chat_async(self, messages:List[Dict[str,str]]) -> Dict[str,str]:
+    async def chat_async(self, messages:List[Dict[str,str]], messages_logger:MessagesLogger=None) -> Dict[str,str]:
         """
         Async version of chat method. Streaming is not supported.
         """
@@ -770,7 +1247,16 @@ class OpenAIInferenceEngine(InferenceEngine):
             warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)

         res = response.choices[0].message.content
-
+        # Postprocess response
+        res_dict = self.config.postprocess_response(res)
+        # Write to messages log
+        if messages_logger:
+            processed_messages.append({"role": "assistant",
+                                       "content": res_dict.get("response", ""),
+                                       "reasoning": res_dict.get("reasoning", "")})
+            messages_logger.log_messages(processed_messages)
+
+        return res_dict


 class AzureOpenAIInferenceEngine(OpenAIInferenceEngine):
@@ -825,6 +1311,7 @@ class LiteLLMInferenceEngine(InferenceEngine):
             raise ImportError("litellm not found. Please install litellm (```pip install litellm```).")

         import litellm
+        super().__init__(config)
         self.litellm = litellm
         self.model = model
         self.base_url = base_url
@@ -843,7 +1330,7 @@ class LiteLLMInferenceEngine(InferenceEngine):

         return formatted_params

-    def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False) -> Union[Dict[str,str], Generator[Dict[str, str], None, None]]:
+    def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False, messages_logger:MessagesLogger=None) -> Union[Dict[str,str], Generator[Dict[str, str], None, None]]:
         """
         This method inputs chat messages and outputs LLM generated text.

@@ -855,6 +1342,8 @@ class LiteLLMInferenceEngine(InferenceEngine):
             if True, VLM generated text will be printed in terminal in real-time.
         stream : bool, Optional
             if True, returns a generator that yields the output in real-time.
+        messages_logger: MessagesLogger, Optional
+            a messages logger that logs the messages.

         Returns:
         -------
@@ -873,12 +1362,22 @@ class LiteLLMInferenceEngine(InferenceEngine):
                     api_key=self.api_key,
                     **self.formatted_params
                 )
-
+                res_text = ""
                 for chunk in response_stream:
                     chunk_content = chunk.get('choices')[0].get('delta').get('content')
                     if chunk_content:
+                        res_text += chunk_content
                         yield chunk_content

+                # Postprocess response
+                res_dict = self.config.postprocess_response(res_text)
+                # Write to messages log
+                if messages_logger:
+                    processed_messages.append({"role": "assistant",
+                                               "content": res_dict.get("response", ""),
+                                               "reasoning": res_dict.get("reasoning", "")})
+                    messages_logger.log_messages(processed_messages)
+
             return self.config.postprocess_response(_stream_generator())

         elif verbose:
@@ -897,8 +1396,6 @@ class LiteLLMInferenceEngine(InferenceEngine):
                 if chunk_content:
                     res += chunk_content
                     print(chunk_content, end='', flush=True)
-
-            return self.config.postprocess_response(res)

         else:
             response = self.litellm.completion(
@@ -910,9 +1407,19 @@ class LiteLLMInferenceEngine(InferenceEngine):
                 **self.formatted_params
             )
             res = response.choices[0].message.content
-
+
+        # Postprocess response
+        res_dict = self.config.postprocess_response(res)
+        # Write to messages log
+        if messages_logger:
+            processed_messages.append({"role": "assistant",
+                                       "content": res_dict.get("response", ""),
+                                       "reasoning": res_dict.get("reasoning", "")})
+            messages_logger.log_messages(processed_messages)
+
+        return res_dict

-    async def chat_async(self, messages:List[Dict[str,str]]) -> Dict[str,str]:
+    async def chat_async(self, messages:List[Dict[str,str]], messages_logger:MessagesLogger=None) -> Dict[str,str]:
         """
         Async version of chat method. Streaming is not supported.
         """
@@ -928,4 +1435,13 @@ class LiteLLMInferenceEngine(InferenceEngine):
         )

         res = response.get('choices')[0].get('message').get('content')
-
+
+        # Postprocess response
+        res_dict = self.config.postprocess_response(res)
+        # Write to messages log
+        if messages_logger:
+            processed_messages.append({"role": "assistant",
+                                       "content": res_dict.get("response", ""),
+                                       "reasoning": res_dict.get("reasoning", "")})
+            messages_logger.log_messages(processed_messages)
+        return res_dict
```