llm-ie 1.2.1-py3-none-any.whl → 1.2.3-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_ie/__init__.py +5 -4
- llm_ie/chunkers.py +78 -4
- llm_ie/data_types.py +23 -37
- llm_ie/engines.py +663 -112
- llm_ie/extractors.py +357 -206
- llm_ie/prompt_editor.py +4 -4
- {llm_ie-1.2.1.dist-info → llm_ie-1.2.3.dist-info}/METADATA +1 -1
- {llm_ie-1.2.1.dist-info → llm_ie-1.2.3.dist-info}/RECORD +9 -9
- {llm_ie-1.2.1.dist-info → llm_ie-1.2.3.dist-info}/WHEEL +0 -0
llm_ie/engines.py
CHANGED
@@ -1,4 +1,5 @@
 import abc
+import os
 import re
 import warnings
 import importlib.util
@@ -33,18 +34,18 @@ class LLMConfig(abc.ABC):
         return NotImplemented

     @abc.abstractmethod
-    def postprocess_response(self, response:Union[str, Generator[str, None, None]]) -> Union[str, Generator[str, None, None]]:
+    def postprocess_response(self, response:Union[str, Dict[str, str], Generator[str, None, None]]) -> Union[Dict[str,str], Generator[Dict[str, str], None, None]]:
         """
         This method postprocesses the LLM response after it is generated.

         Parameters:
         ----------
-        response : Union[str, Generator[str, None, None]]
-            the LLM response. Can be a
+        response : Union[str, Dict[str, str], Generator[Dict[str, str], None, None]]
+            the LLM response. Can be a dict or a generator.

         Returns:
         -------
-        response : str
+        response : Union[Dict[str,str], Generator[Dict[str, str], None, None]]
             the postprocessed LLM response
         """
         return NotImplemented
@@ -75,47 +76,58 @@ class BasicLLMConfig(LLMConfig):
         messages : List[Dict[str,str]]
             a list of dict with role and content. role must be one of {"system", "user", "assistant"}
         """
-        return messages
+        return messages.copy()

-    def postprocess_response(self, response:Union[str, Generator[str, None, None]]) -> Union[str, Generator[Dict[str, str], None, None]]:
+    def postprocess_response(self, response:Union[str, Dict[str, str], Generator[str, None, None]]) -> Union[Dict[str,str], Generator[Dict[str, str], None, None]]:
         """
         This method postprocesses the LLM response after it is generated.

         Parameters:
         ----------
-        response : Union[str, Generator[str, None, None]]
+        response : Union[str, Dict[str, str], Generator[str, None, None]]
             the LLM response. Can be a string or a generator.

-        Returns: Union[str, Generator[Dict[str, str], None, None]]
+        Returns: Union[Dict[str,str], Generator[Dict[str, str], None, None]]
             the postprocessed LLM response.
-
+            If input is a string, the output will be a dict {"response": <response>}.
+            if input is a generator, the output will be a generator {"type": "response", "data": <content>}.
         """
         if isinstance(response, str):
-            return response
+            return {"response": response}
+
+        elif isinstance(response, dict):
+            if "response" in response:
+                return response
+            else:
+                warnings.warn(f"Invalid response dict keys: {response.keys()}. Returning default empty dict.", UserWarning)
+                return {"response": ""}

-
-
-
+        elif isinstance(response, Generator):
+            def _process_stream():
+                for chunk in response:
+                    if isinstance(chunk, dict):
+                        yield chunk
+                    elif isinstance(chunk, str):
+                        yield {"type": "response", "data": chunk}
+
+            return _process_stream()

-
+        else:
+            warnings.warn(f"Invalid response type: {type(response)}. Returning default empty dict.", UserWarning)
+            return {"response": ""}

-class
-    def __init__(self,
+class ReasoningLLMConfig(LLMConfig):
+    def __init__(self, thinking_token_start="<think>", thinking_token_end="</think>", **kwargs):
         """
-        The
-
-        Parameters:
-        ----------
-        thinking_mode : bool, Optional
-            if True, a special token "/think" will be placed after each system and user prompt. Otherwise, "/no_think" will be placed.
+        The general LLM configuration for reasoning models.
         """
         super().__init__(**kwargs)
-        self.
+        self.thinking_token_start = thinking_token_start
+        self.thinking_token_end = thinking_token_end

     def preprocess_messages(self, messages:List[Dict[str,str]]) -> List[Dict[str,str]]:
         """
-
-        The token is "/think" if thinking_mode is True, otherwise "/no_think".
+        This method preprocesses the input messages before passing them to the LLM.

         Parameters:
         ----------
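A minimal sketch of the new postprocess_response() contract shown in the hunk above; the import path mirrors the file being diffed, and the generator expression stands in for a real streamed response (both are assumptions for illustration).

from llm_ie.engines import BasicLLMConfig

config = BasicLLMConfig()

# A plain string is now wrapped in a dict instead of being returned as-is.
print(config.postprocess_response("Hello"))            # {"response": "Hello"}

# A generator of string chunks is re-emitted as {"type": "response", "data": ...} dicts.
stream = (chunk for chunk in ["He", "llo"])
for item in config.postprocess_response(stream):
    print(item)                                         # {"type": "response", "data": "He"} ...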
@@ -127,23 +139,16 @@ class Qwen3LLMConfig(LLMConfig):
         messages : List[Dict[str,str]]
             a list of dict with role and content. role must be one of {"system", "user", "assistant"}
         """
-
-        new_messages = []
-        for message in messages:
-            if message['role'] in ['system', 'user']:
-                new_message = {'role': message['role'], 'content': f"{message['content']} {thinking_token}"}
-            else:
-                new_message = {'role': message['role'], 'content': message['content']}
+        return messages.copy()

-
-
-        return new_messages
-
-    def postprocess_response(self, response:Union[str, Generator[str, None, None]]) -> Union[str, Generator[Dict[str,str], None, None]]:
+    def postprocess_response(self, response:Union[str, Dict[str, str], Generator[str, None, None]]) -> Union[Dict[str,str], Generator[Dict[str,str], None, None]]:
         """
-
-
-        If input is a
+        This method postprocesses the LLM response after it is generated.
+        1. If input is a string, it will extract the reasoning and response based on the thinking tokens.
+        2. If input is a dict, it should contain keys "reasoning" and "response". This is for inference engines that already parse reasoning and response.
+        3. If input is a generator,
+            a. if the chunk is a dict, it should contain keys "type" and "data". This is for inference engines that already parse reasoning and response.
+            b. if the chunk is a string, it will yield dicts with keys "type" and "data" based on the thinking tokens.

         Parameters:
         ----------
@@ -153,38 +158,99 @@ class Qwen3LLMConfig(LLMConfig):
         Returns:
         -------
         response : Union[str, Generator[str, None, None]]
-            the postprocessed LLM response
+            the postprocessed LLM response as a dict {"reasoning": <reasoning>, "response": <content>}
             if input is a generator, the output will be a generator {"type": <reasoning or response>, "data": <content>}.
         """
         if isinstance(response, str):
-
+            # get contents between thinking_token_start and thinking_token_end
+            pattern = f"{re.escape(self.thinking_token_start)}(.*?){re.escape(self.thinking_token_end)}"
+            match = re.search(pattern, response, re.DOTALL)
+            reasoning = match.group(1) if match else ""
+            # get response AFTER thinking_token_end
+            response = re.sub(f".*?{self.thinking_token_end}", "", response, flags=re.DOTALL).strip()
+            return {"reasoning": reasoning, "response": response}
+
+        elif isinstance(response, dict):
+            if "reasoning" in response and "response" in response:
+                return response
+            else:
+                warnings.warn(f"Invalid response dict keys: {response.keys()}. Returning default empty dict.", UserWarning)
+                return {"reasoning": "", "response": ""}

-
+        elif isinstance(response, Generator):
             def _process_stream():
                 think_flag = False
                 buffer = ""
                 for chunk in response:
-                    if isinstance(chunk,
+                    if isinstance(chunk, dict):
+                        yield chunk
+
+                    elif isinstance(chunk, str):
                         buffer += chunk
                         # switch between reasoning and response
-                        if
+                        if self.thinking_token_start in buffer:
                             think_flag = True
-                            buffer = buffer.replace(
-                        elif
+                            buffer = buffer.replace(self.thinking_token_start, "")
+                        elif self.thinking_token_end in buffer:
                             think_flag = False
-                            buffer = buffer.replace(
+                            buffer = buffer.replace(self.thinking_token_end, "")

                         # if chunk is in thinking block, tag it as reasoning; else tag it as response
-                        if chunk not in [
+                        if chunk not in [self.thinking_token_start, self.thinking_token_end]:
                             if think_flag:
                                 yield {"type": "reasoning", "data": chunk}
                             else:
                                 yield {"type": "response", "data": chunk}

             return _process_stream()
+
+        else:
+            warnings.warn(f"Invalid response type: {type(response)}. Returning default empty dict.", UserWarning)
+            return {"reasoning": "", "response": ""}
+
+class Qwen3LLMConfig(ReasoningLLMConfig):
+    def __init__(self, thinking_mode:bool=True, **kwargs):
+        """
+        The Qwen3 **hybrid thinking** LLM configuration.
+        For Qwen3 thinking 2507, use ReasoningLLMConfig instead; for Qwen3 Instruct, use BasicLLMConfig instead.
+
+        Parameters:
+        ----------
+        thinking_mode : bool, Optional
+            if True, a special token "/think" will be placed after each system and user prompt. Otherwise, "/no_think" will be placed.
+        """
+        super().__init__(**kwargs)
+        self.thinking_mode = thinking_mode

+    def preprocess_messages(self, messages:List[Dict[str,str]]) -> List[Dict[str,str]]:
+        """
+        Append a special token to the system and user prompts.
+        The token is "/think" if thinking_mode is True, otherwise "/no_think".

-
+        Parameters:
+        ----------
+        messages : List[Dict[str,str]]
+            a list of dict with role and content. role must be one of {"system", "user", "assistant"}
+
+        Returns:
+        -------
+        messages : List[Dict[str,str]]
+            a list of dict with role and content. role must be one of {"system", "user", "assistant"}
+        """
+        thinking_token = "/think" if self.thinking_mode else "/no_think"
+        new_messages = []
+        for message in messages:
+            if message['role'] in ['system', 'user']:
+                new_message = {'role': message['role'], 'content': f"{message['content']} {thinking_token}"}
+            else:
+                new_message = {'role': message['role'], 'content': message['content']}
+
+            new_messages.append(new_message)
+
+        return new_messages
+
+
+class OpenAIReasoningLLMConfig(ReasoningLLMConfig):
     def __init__(self, reasoning_effort:str=None, **kwargs):
         """
         The OpenAI "o" series configuration.
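A small usage sketch of the ReasoningLLMConfig / Qwen3LLMConfig split introduced above, assuming the default <think>...</think> tokens and no-argument constructors as shown in this diff.

from llm_ie.engines import ReasoningLLMConfig, Qwen3LLMConfig

config = ReasoningLLMConfig()
out = config.postprocess_response("<think>plan the answer</think>The answer is 42.")
# -> {"reasoning": "plan the answer", "response": "The answer is 42."}

# Qwen3LLMConfig now only adds the hybrid-thinking flag on top of ReasoningLLMConfig.
qwen = Qwen3LLMConfig(thinking_mode=True)
print(qwen.preprocess_messages([{"role": "user", "content": "Hi"}]))
# -> [{'role': 'user', 'content': 'Hi /think'}]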
@@ -246,27 +312,31 @@ class OpenAIReasoningLLMConfig(LLMConfig):

         return new_messages

-    def postprocess_response(self, response:Union[str, Generator[str, None, None]]) -> Union[str, Generator[Dict[str, str], None, None]]:
-        """
-        This method postprocesses the LLM response after it is generated.

-
-
-        response : Union[str, Generator[str, None, None]]
-            the LLM response. Can be a string or a generator.
-
-        Returns: Union[str, Generator[Dict[str, str], None, None]]
-            the postprocessed LLM response.
-            if input is a generator, the output will be a generator {"type": "response", "data": <content>}.
+class MessagesLogger:
+    def __init__(self):
         """
-
-
+        This class is used to log the messages for InferenceEngine.chat().
+        """
+        self.messages_log = []

-
-
-
+    def log_messages(self, messages : List[Dict[str,str]]):
+        """
+        This method logs the messages to a list.
+        """
+        self.messages_log.append(messages)

-
+    def get_messages_log(self) -> List[List[Dict[str,str]]]:
+        """
+        This method returns a copy of the current messages log
+        """
+        return self.messages_log.copy()
+
+    def clear_messages_log(self):
+        """
+        This method clears the current messages log
+        """
+        self.messages_log.clear()


 class InferenceEngine:
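A sketch of the new MessagesLogger hook; the Ollama engine constructor arguments and model name are illustrative assumptions, while the messages_logger parameter and the dict return shape follow the code in this diff.

from llm_ie.engines import OllamaInferenceEngine, MessagesLogger

engine = OllamaInferenceEngine(model_name="llama3.1:8b")   # assumed model name / local server
logger = MessagesLogger()

result = engine.chat(
    messages=[{"role": "user", "content": "Hello"}],
    messages_logger=logger,
)
print(result["response"])          # chat() now returns a dict instead of a plain string
print(logger.get_messages_log())   # logged conversations, with the assistant turn appended
logger.clear_messages_log()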
@@ -283,10 +353,16 @@ class InferenceEngine:
         """
         return NotImplemented

+    def get_messages_log(self) -> List[List[Dict[str,str]]]:
+        return self.messages_log.copy()
+
+    def clear_messages_log(self):
+        self.messages_log = []
+

     @abc.abstractmethod
-    def chat(self, messages:List[Dict[str,str]],
-
+    def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False,
+             messages_logger:MessagesLogger=None) -> Union[Dict[str,str], Generator[Dict[str, str], None, None]]:
         """
         This method inputs chat messages and outputs LLM generated text.

@@ -298,6 +374,13 @@ class InferenceEngine:
             if True, LLM generated text will be printed in terminal in real-time.
         stream : bool, Optional
             if True, returns a generator that yields the output in real-time.
+        Messages_logger : MessagesLogger, Optional
+            the message logger that logs the chat messages.
+
+        Returns:
+        -------
+        response : Union[Dict[str,str], Generator[Dict[str, str], None, None]]
+            a dict {"reasoning": <reasoning>, "response": <response>} or Generator {"type": <reasoning or response>, "data": <content>}
         """
         return NotImplemented

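A helper sketch for consuming the stream shape documented above; it only assumes the {"type": ..., "data": ...} chunks that chat(stream=True) yields (the engine and messages objects are left as assumptions).

from typing import Dict, Generator

def collect_stream(chunks: Generator[Dict[str, str], None, None]) -> Dict[str, str]:
    # Fold streamed {"type": "reasoning"|"response", "data": ...} chunks back into
    # the {"reasoning": ..., "response": ...} dict that the non-streaming call returns.
    out = {"reasoning": "", "response": ""}
    for chunk in chunks:
        key = "reasoning" if chunk["type"] == "reasoning" else "response"
        out[key] += chunk["data"]
    return out

# Usage (engine and messages assumed): collect_stream(engine.chat(messages, stream=True))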
@@ -331,6 +414,7 @@ class LlamaCppInferenceEngine(InferenceEngine):
             the LLM configuration.
         """
         from llama_cpp import Llama
+        super().__init__(config)
         self.repo_id = repo_id
         self.gguf_filename = gguf_filename
         self.n_ctx = n_ctx
@@ -363,7 +447,7 @@ class LlamaCppInferenceEngine(InferenceEngine):

         return formatted_params

-    def chat(self, messages:List[Dict[str,str]], verbose:bool=False) -> str:
+    def chat(self, messages:List[Dict[str,str]], verbose:bool=False, messages_logger:MessagesLogger=None) -> Dict[str,str]:
         """
         This method inputs chat messages and outputs LLM generated text.

@@ -373,15 +457,18 @@ class LlamaCppInferenceEngine(InferenceEngine):
             a list of dict with role and content. role must be one of {"system", "user", "assistant"}
         verbose : bool, Optional
             if True, LLM generated text will be printed in terminal in real-time.
+        messages_logger : MessagesLogger, Optional
+            the message logger that logs the chat messages.
         """
+        # Preprocess messages
         processed_messages = self.config.preprocess_messages(messages)
-
+        # Generate response
         response = self.model.create_chat_completion(
             messages=processed_messages,
             stream=verbose,
             **self.formatted_params
         )
-
+
         if verbose:
             res = ''
             for chunk in response:
@@ -393,7 +480,16 @@ class LlamaCppInferenceEngine(InferenceEngine):
             return self.config.postprocess_response(res)

         res = response['choices'][0]['message']['content']
-
+        # Postprocess response
+        res_dict = self.config.postprocess_response(res)
+        # Write to messages log
+        if messages_logger:
+            processed_messages.append({"role": "assistant",
+                                       "content": res_dict.get("response", ""),
+                                       "reasoning": res_dict.get("reasoning", "")})
+            messages_logger.log_messages(processed_messages)
+
+        return res_dict


 class OllamaInferenceEngine(InferenceEngine):
@@ -416,6 +512,7 @@ class OllamaInferenceEngine(InferenceEngine):
            raise ImportError("ollama-python not found. Please install ollama-python (```pip install ollama```).")

         from ollama import Client, AsyncClient
+        super().__init__(config)
         self.client = Client(**kwrs)
         self.async_client = AsyncClient(**kwrs)
         self.model_name = model_name
@@ -435,8 +532,8 @@ class OllamaInferenceEngine(InferenceEngine):

         return formatted_params

-    def chat(self, messages:List[Dict[str,str]],
-
+    def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False,
+             messages_logger:MessagesLogger=None) -> Union[Dict[str,str], Generator[Dict[str, str], None, None]]:
         """
         This method inputs chat messages and outputs VLM generated text.

@@ -448,6 +545,13 @@ class OllamaInferenceEngine(InferenceEngine):
             if True, VLM generated text will be printed in terminal in real-time.
         stream : bool, Optional
             if True, returns a generator that yields the output in real-time.
+        Messages_logger : MessagesLogger, Optional
+            the message logger that logs the chat messages.
+
+        Returns:
+        -------
+        response : Union[Dict[str,str], Generator[Dict[str, str], None, None]]
+            a dict {"reasoning": <reasoning>, "response": <response>} or Generator {"type": <reasoning or response>, "data": <content>}
         """
         processed_messages = self.config.preprocess_messages(messages)

@@ -461,10 +565,28 @@ class OllamaInferenceEngine(InferenceEngine):
                    stream=True,
                    keep_alive=self.keep_alive
                )
+                res = {"reasoning": "", "response": ""}
                for chunk in response_stream:
-
-
-
+                    if hasattr(chunk.message, 'thinking') and chunk.message.thinking:
+                        content_chunk = getattr(getattr(chunk, 'message', {}), 'thinking', '')
+                        res["reasoning"] += content_chunk
+                        yield {"type": "reasoning", "data": content_chunk}
+                    else:
+                        content_chunk = getattr(getattr(chunk, 'message', {}), 'content', '')
+                        res["response"] += content_chunk
+                        yield {"type": "response", "data": content_chunk}
+
+                    if chunk.done_reason == "length":
+                        warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
+
+                # Postprocess response
+                res_dict = self.config.postprocess_response(res)
+                # Write to messages log
+                if messages_logger:
+                    processed_messages.append({"role": "assistant",
+                                               "content": res_dict.get("response", ""),
+                                               "reasoning": res_dict.get("reasoning", "")})
+                    messages_logger.log_messages(processed_messages)

             return self.config.postprocess_response(_stream_generator())

@@ -477,14 +599,29 @@ class OllamaInferenceEngine(InferenceEngine):
                keep_alive=self.keep_alive
            )

-            res =
+            res = {"reasoning": "", "response": ""}
+            phase = ""
            for chunk in response:
-
+                if hasattr(chunk.message, 'thinking') and chunk.message.thinking:
+                    if phase != "reasoning":
+                        print("\n--- Reasoning ---")
+                        phase = "reasoning"
+
+                    content_chunk = getattr(getattr(chunk, 'message', {}), 'thinking', '')
+                    res["reasoning"] += content_chunk
+                else:
+                    if phase != "response":
+                        print("\n--- Response ---")
+                        phase = "response"
+                    content_chunk = getattr(getattr(chunk, 'message', {}), 'content', '')
+                    res["response"] += content_chunk
+
                print(content_chunk, end='', flush=True)
-
+
+                if chunk.done_reason == "length":
+                    warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
            print('\n')
-
-
+
        else:
            response = self.client.chat(
                model=self.model_name,
@@ -493,11 +630,25 @@ class OllamaInferenceEngine(InferenceEngine):
                stream=False,
                keep_alive=self.keep_alive
            )
-            res = response
-
+            res = {"reasoning": getattr(getattr(response, 'message', {}), 'thinking', ''),
+                   "response": getattr(getattr(response, 'message', {}), 'content', '')}
+
+            if response.done_reason == "length":
+                warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
+
+            # Postprocess response
+            res_dict = self.config.postprocess_response(res)
+            # Write to messages log
+            if messages_logger:
+                processed_messages.append({"role": "assistant",
+                                           "content": res_dict.get("response", ""),
+                                           "reasoning": res_dict.get("reasoning", "")})
+                messages_logger.log_messages(processed_messages)
+
+            return res_dict


-    async def chat_async(self, messages:List[Dict[str,str]]) -> str:
+    async def chat_async(self, messages:List[Dict[str,str]], messages_logger:MessagesLogger=None) -> Dict[str,str]:
         """
         Async version of chat method. Streaming is not supported.
         """
@@ -511,8 +662,21 @@ class OllamaInferenceEngine(InferenceEngine):
            keep_alive=self.keep_alive
        )

-        res = response
-
+        res = {"reasoning": getattr(getattr(response, 'message', {}), 'thinking', ''),
+               "response": getattr(getattr(response, 'message', {}), 'content', '')}
+
+        if response.done_reason == "length":
+            warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
+        # Postprocess response
+        res_dict = self.config.postprocess_response(res)
+        # Write to messages log
+        if messages_logger:
+            processed_messages.append({"role": "assistant",
+                                       "content": res_dict.get("response", ""),
+                                       "reasoning": res_dict.get("reasoning", "")})
+            messages_logger.log_messages(processed_messages)
+
+        return res_dict


 class HuggingFaceHubInferenceEngine(InferenceEngine):
@@ -538,6 +702,7 @@ class HuggingFaceHubInferenceEngine(InferenceEngine):
            raise ImportError("huggingface-hub not found. Please install huggingface-hub (```pip install huggingface-hub```).")

        from huggingface_hub import InferenceClient, AsyncInferenceClient
+        super().__init__(config)
        self.model = model
        self.base_url = base_url
        self.client = InferenceClient(model=model, token=token, base_url=base_url, api_key=api_key, **kwrs)
@@ -557,8 +722,8 @@ class HuggingFaceHubInferenceEngine(InferenceEngine):
        return formatted_params


-    def chat(self, messages:List[Dict[str,str]],
-
+    def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False,
+             messages_logger:MessagesLogger=None) -> Union[Dict[str,str], Generator[Dict[str, str], None, None]]:
        """
        This method inputs chat messages and outputs LLM generated text.

@@ -570,6 +735,13 @@ class HuggingFaceHubInferenceEngine(InferenceEngine):
            if True, VLM generated text will be printed in terminal in real-time.
        stream : bool, Optional
            if True, returns a generator that yields the output in real-time.
+        messages_logger : MessagesLogger, Optional
+            the message logger that logs the chat messages.
+
+        Returns:
+        -------
+        response : Union[Dict[str,str], Generator[Dict[str, str], None, None]]
+            a dict {"reasoning": <reasoning>, "response": <response>} or Generator {"type": <reasoning or response>, "data": <content>}
        """
        processed_messages = self.config.preprocess_messages(messages)

@@ -580,11 +752,22 @@ class HuggingFaceHubInferenceEngine(InferenceEngine):
                    stream=True,
                    **self.formatted_params
                )
+                res_text = ""
                for chunk in response_stream:
                    content_chunk = chunk.get('choices')[0].get('delta').get('content')
                    if content_chunk:
+                        res_text += content_chunk
                        yield content_chunk

+                # Postprocess response
+                res_dict = self.config.postprocess_response(res_text)
+                # Write to messages log
+                if messages_logger:
+                    processed_messages.append({"role": "assistant",
+                                               "content": res_dict.get("response", ""),
+                                               "reasoning": res_dict.get("reasoning", "")})
+                    messages_logger.log_messages(processed_messages)
+
            return self.config.postprocess_response(_stream_generator())

        elif verbose:
@@ -600,7 +783,7 @@ class HuggingFaceHubInferenceEngine(InferenceEngine):
                if content_chunk:
                    res += content_chunk
                    print(content_chunk, end='', flush=True)
-
+

        else:
            response = self.client.chat.completions.create(
@@ -609,9 +792,20 @@ class HuggingFaceHubInferenceEngine(InferenceEngine):
                **self.formatted_params
            )
            res = response.choices[0].message.content
-
+
+            # Postprocess response
+            res_dict = self.config.postprocess_response(res)
+            # Write to messages log
+            if messages_logger:
+                processed_messages.append({"role": "assistant",
+                                           "content": res_dict.get("response", ""),
+                                           "reasoning": res_dict.get("reasoning", "")})
+                messages_logger.log_messages(processed_messages)
+
+            return res_dict

-
+
+    async def chat_async(self, messages:List[Dict[str,str]], messages_logger:MessagesLogger=None) -> Dict[str,str]:
        """
        Async version of chat method. Streaming is not supported.
        """
@@ -624,16 +818,299 @@ class HuggingFaceHubInferenceEngine(InferenceEngine):
        )

        res = response.choices[0].message.content
-
+        # Postprocess response
+        res_dict = self.config.postprocess_response(res)
+        # Write to messages log
+        if messages_logger:
+            processed_messages.append({"role": "assistant",
+                                       "content": res_dict.get("response", ""),
+                                       "reasoning": res_dict.get("reasoning", "")})
+            messages_logger.log_messages(processed_messages)
+
+        return res_dict
+
+
+class OpenAICompatibleInferenceEngine(InferenceEngine):
+    def __init__(self, model:str, api_key:str, base_url:str, config:LLMConfig=None, **kwrs):
+        """
+        General OpenAI-compatible server inference engine.
+        https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html
+
+        For parameters and documentation, refer to https://platform.openai.com/docs/api-reference/introduction
+
+        Parameters:
+        ----------
+        model_name : str
+            model name as shown in the vLLM server
+        api_key : str
+            the API key for the vLLM server.
+        base_url : str
+            the base url for the vLLM server.
+        config : LLMConfig
+            the LLM configuration.
+        """
+        if importlib.util.find_spec("openai") is None:
+            raise ImportError("OpenAI Python API library not found. Please install OpanAI (```pip install openai```).")

+        from openai import OpenAI, AsyncOpenAI
+        from openai.types.chat import ChatCompletionChunk
+        self.ChatCompletionChunk = ChatCompletionChunk
+        super().__init__(config)
+        self.client = OpenAI(api_key=api_key, base_url=base_url, **kwrs)
+        self.async_client = AsyncOpenAI(api_key=api_key, base_url=base_url, **kwrs)
+        self.model = model
+        self.config = config if config else BasicLLMConfig()
+        self.formatted_params = self._format_config()
+
+    def _format_config(self) -> Dict[str, Any]:
+        """
+        This method format the LLM configuration with the correct key for the inference engine.
+        """
+        formatted_params = self.config.params.copy()
+        if "max_new_tokens" in formatted_params:
+            formatted_params["max_completion_tokens"] = formatted_params["max_new_tokens"]
+            formatted_params.pop("max_new_tokens")
+
+        return formatted_params
+
+    @abc.abstractmethod
+    def _format_response(self, response: Any) -> Dict[str, str]:
+        """
+        This method format the response from OpenAI API to a dict with keys "type" and "data".
+
+        Parameters:
+        ----------
+        response : Any
+            the response from OpenAI-compatible API. Could be a dict, generator, or object.
+        """
+        return NotImplemented
+
+    def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False, messages_logger:MessagesLogger=None) -> Union[Dict[str, str], Generator[Dict[str, str], None, None]]:
+        """
+        This method inputs chat messages and outputs LLM generated text.
+
+        Parameters:
+        ----------
+        messages : List[Dict[str,str]]
+            a list of dict with role and content. role must be one of {"system", "user", "assistant"}
+        verbose : bool, Optional
+            if True, VLM generated text will be printed in terminal in real-time.
+        stream : bool, Optional
+            if True, returns a generator that yields the output in real-time.
+        messages_logger : MessagesLogger, Optional
+            the message logger that logs the chat messages.
+
+        Returns:
+        -------
+        response : Union[Dict[str,str], Generator[Dict[str, str], None, None]]
+            a dict {"reasoning": <reasoning>, "response": <response>} or Generator {"type": <reasoning or response>, "data": <content>}
+        """
+        processed_messages = self.config.preprocess_messages(messages)
+
+        if stream:
+            def _stream_generator():
+                response_stream = self.client.chat.completions.create(
+                    model=self.model,
+                    messages=processed_messages,
+                    stream=True,
+                    **self.formatted_params
+                )
+                res_text = ""
+                for chunk in response_stream:
+                    if len(chunk.choices) > 0:
+                        chunk_dict = self._format_response(chunk)
+                        yield chunk_dict
+
+                        res_text += chunk_dict["data"]
+                        if chunk.choices[0].finish_reason == "length":
+                            warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
+
+                # Postprocess response
+                res_dict = self.config.postprocess_response(res_text)
+                # Write to messages log
+                if messages_logger:
+                    processed_messages.append({"role": "assistant",
+                                               "content": res_dict.get("response", ""),
+                                               "reasoning": res_dict.get("reasoning", "")})
+                    messages_logger.log_messages(processed_messages)
+
+            return self.config.postprocess_response(_stream_generator())
+
+        elif verbose:
+            response = self.client.chat.completions.create(
+                model=self.model,
+                messages=processed_messages,
+                stream=True,
+                **self.formatted_params
+            )
+            res = {"reasoning": "", "response": ""}
+            phase = ""
+            for chunk in response:
+                if len(chunk.choices) > 0:
+                    chunk_dict = self._format_response(chunk)
+                    chunk_text = chunk_dict["data"]
+                    res[chunk_dict["type"]] += chunk_text
+                    if phase != chunk_dict["type"] and chunk_text != "":
+                        print(f"\n--- {chunk_dict['type'].capitalize()} ---")
+                        phase = chunk_dict["type"]
+
+                    print(chunk_text, end="", flush=True)
+                    if chunk.choices[0].finish_reason == "length":
+                        warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
+
+            print('\n')
+
+        else:
+            response = self.client.chat.completions.create(
+                model=self.model,
+                messages=processed_messages,
+                stream=False,
+                **self.formatted_params
+            )
+            res = self._format_response(response)
+
+            if response.choices[0].finish_reason == "length":
+                warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
+
+            # Postprocess response
+            res_dict = self.config.postprocess_response(res)
+            # Write to messages log
+            if messages_logger:
+                processed_messages.append({"role": "assistant",
+                                           "content": res_dict.get("response", ""),
+                                           "reasoning": res_dict.get("reasoning", "")})
+                messages_logger.log_messages(processed_messages)
+
+            return res_dict
+
+
+    async def chat_async(self, messages:List[Dict[str,str]], messages_logger:MessagesLogger=None) -> Dict[str,str]:
+        """
+        Async version of chat method. Streaming is not supported.
+        """
+        processed_messages = self.config.preprocess_messages(messages)
+
+        response = await self.async_client.chat.completions.create(
+            model=self.model,
+            messages=processed_messages,
+            stream=False,
+            **self.formatted_params
+        )
+
+        if response.choices[0].finish_reason == "length":
+            warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
+
+        res = self._format_response(response)
+
+        # Postprocess response
+        res_dict = self.config.postprocess_response(res)
+        # Write to messages log
+        if messages_logger:
+            processed_messages.append({"role": "assistant",
+                                       "content": res_dict.get("response", ""),
+                                       "reasoning": res_dict.get("reasoning", "")})
+            messages_logger.log_messages(processed_messages)
+
+        return res_dict
+
+
+class VLLMInferenceEngine(OpenAICompatibleInferenceEngine):
+    def __init__(self, model:str, api_key:str="", base_url:str="http://localhost:8000/v1", config:LLMConfig=None, **kwrs):
+        """
+        vLLM OpenAI compatible server inference engine.
+        https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html
+
+        For parameters and documentation, refer to https://platform.openai.com/docs/api-reference/introduction
+
+        Parameters:
+        ----------
+        model_name : str
+            model name as shown in the vLLM server
+        api_key : str, Optional
+            the API key for the vLLM server.
+        base_url : str, Optional
+            the base url for the vLLM server.
+        config : LLMConfig
+            the LLM configuration.
+        """
+        super().__init__(model, api_key, base_url, config, **kwrs)
+
+
+    def _format_response(self, response: Any) -> Dict[str, str]:
+        """
+        This method format the response from OpenAI API to a dict with keys "type" and "data".
+
+        Parameters:
+        ----------
+        response : Any
+            the response from OpenAI-compatible API. Could be a dict, generator, or object.
+        """
+        if isinstance(response, self.ChatCompletionChunk):
+            if hasattr(response.choices[0].delta, "reasoning_content") and getattr(response.choices[0].delta, "reasoning_content") is not None:
+                chunk_text = getattr(response.choices[0].delta, "reasoning_content", "")
+                if chunk_text is None:
+                    chunk_text = ""
+                return {"type": "reasoning", "data": chunk_text}
+            else:
+                chunk_text = getattr(response.choices[0].delta, "content", "")
+                if chunk_text is None:
+                    chunk_text = ""
+                return {"type": "response", "data": chunk_text}
+
+        return {"reasoning": getattr(response.choices[0].message, "reasoning_content", ""),
+                "response": getattr(response.choices[0].message, "content", "")}
+
+
+class OpenRouterInferenceEngine(OpenAICompatibleInferenceEngine):
+    def __init__(self, model:str, api_key:str=None, base_url:str="https://openrouter.ai/api/v1", config:LLMConfig=None, **kwrs):
+        """
+        OpenRouter OpenAI-compatible server inference engine.
+
+        Parameters:
+        ----------
+        model_name : str
+            model name as shown in the vLLM server
+        api_key : str, Optional
+            the API key for the vLLM server. If None, will use the key in os.environ['OPENROUTER_API_KEY'].
+        base_url : str, Optional
+            the base url for the vLLM server.
+        config : LLMConfig
+            the LLM configuration.
+        """
+        self.api_key = api_key
+        if self.api_key is None:
+            self.api_key = os.getenv("OPENROUTER_API_KEY")
+        super().__init__(model, self.api_key, base_url, config, **kwrs)
+
+    def _format_response(self, response: Any) -> Dict[str, str]:
+        """
+        This method format the response from OpenAI API to a dict with keys "type" and "data".
+
+        Parameters:
+        ----------
+        response : Any
+            the response from OpenAI-compatible API. Could be a dict, generator, or object.
+        """
+        if isinstance(response, self.ChatCompletionChunk):
+            if hasattr(response.choices[0].delta, "reasoning") and getattr(response.choices[0].delta, "reasoning") is not None:
+                chunk_text = getattr(response.choices[0].delta, "reasoning", "")
+                if chunk_text is None:
+                    chunk_text = ""
+                return {"type": "reasoning", "data": chunk_text}
+            else:
+                chunk_text = getattr(response.choices[0].delta, "content", "")
+                if chunk_text is None:
+                    chunk_text = ""
+                return {"type": "response", "data": chunk_text}
+
+        return {"reasoning": getattr(response.choices[0].message, "reasoning", ""),
+                "response": getattr(response.choices[0].message, "content", "")}
+

 class OpenAIInferenceEngine(InferenceEngine):
     def __init__(self, model:str, config:LLMConfig=None, **kwrs):
        """
-        The OpenAI API inference engine.
-        - vLLM OpenAI compatible server (https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html)
-        - Llama.cpp OpenAI compatible server (https://llama-cpp-python.readthedocs.io/en/latest/server/)
-
+        The OpenAI API inference engine.
        For parameters and documentation, refer to https://platform.openai.com/docs/api-reference/introduction

        Parameters:
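A sketch of the OpenAI-compatible engines added above; the model names and the running servers are assumptions, while the constructor signatures and the OPENROUTER_API_KEY fallback come from this diff.

from llm_ie.engines import VLLMInferenceEngine, OpenRouterInferenceEngine, ReasoningLLMConfig

# vLLM server on the default base_url, reasoning parsed from reasoning_content deltas:
vllm_engine = VLLMInferenceEngine(
    model="Qwen/Qwen3-8B",                 # assumed model name on the server
    config=ReasoningLLMConfig(),
)

# OpenRouter; api_key=None falls back to os.environ["OPENROUTER_API_KEY"]:
router_engine = OpenRouterInferenceEngine(model="deepseek/deepseek-r1")   # assumed model id

res = vllm_engine.chat([{"role": "user", "content": "Hello"}])
print(res["reasoning"])
print(res["response"])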
@@ -645,6 +1122,7 @@ class OpenAIInferenceEngine(InferenceEngine):
            raise ImportError("OpenAI Python API library not found. Please install OpanAI (```pip install openai```).")

        from openai import OpenAI, AsyncOpenAI
+        super().__init__(config)
        self.client = OpenAI(**kwrs)
        self.async_client = AsyncOpenAI(**kwrs)
        self.model = model
@@ -662,7 +1140,7 @@ class OpenAIInferenceEngine(InferenceEngine):

        return formatted_params

-    def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False) -> Union[str, Generator[Dict[str, str], None, None]]:
+    def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False, messages_logger:MessagesLogger=None) -> Union[Dict[str, str], Generator[Dict[str, str], None, None]]:
        """
        This method inputs chat messages and outputs LLM generated text.

@@ -674,6 +1152,13 @@ class OpenAIInferenceEngine(InferenceEngine):
            if True, VLM generated text will be printed in terminal in real-time.
        stream : bool, Optional
            if True, returns a generator that yields the output in real-time.
+        messages_logger : MessagesLogger, Optional
+            the message logger that logs the chat messages.
+
+        Returns:
+        -------
+        response : Union[Dict[str,str], Generator[Dict[str, str], None, None]]
+            a dict {"reasoning": <reasoning>, "response": <response>} or Generator {"type": <reasoning or response>, "data": <content>}
        """
        processed_messages = self.config.preprocess_messages(messages)

@@ -685,13 +1170,25 @@ class OpenAIInferenceEngine(InferenceEngine):
                    stream=True,
                    **self.formatted_params
                )
+                res_text = ""
                for chunk in response_stream:
                    if len(chunk.choices) > 0:
-
-
+                        chunk_text = chunk.choices[0].delta.content
+                        if chunk_text is not None:
+                            res_text += chunk_text
+                            yield chunk_text
                        if chunk.choices[0].finish_reason == "length":
                            warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)

+                # Postprocess response
+                res_dict = self.config.postprocess_response(res_text)
+                # Write to messages log
+                if messages_logger:
+                    processed_messages.append({"role": "assistant",
+                                               "content": res_dict.get("response", ""),
+                                               "reasoning": res_dict.get("reasoning", "")})
+                    messages_logger.log_messages(processed_messages)
+
            return self.config.postprocess_response(_stream_generator())

        elif verbose:
@@ -711,7 +1208,7 @@ class OpenAIInferenceEngine(InferenceEngine):
                        warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)

            print('\n')
-
+
        else:
            response = self.client.chat.completions.create(
                model=self.model,
@@ -720,10 +1217,20 @@ class OpenAIInferenceEngine(InferenceEngine):
                **self.formatted_params
            )
            res = response.choices[0].message.content
-
+
+            # Postprocess response
+            res_dict = self.config.postprocess_response(res)
+            # Write to messages log
+            if messages_logger:
+                processed_messages.append({"role": "assistant",
+                                           "content": res_dict.get("response", ""),
+                                           "reasoning": res_dict.get("reasoning", "")})
+                messages_logger.log_messages(processed_messages)
+
+            return res_dict


-    async def chat_async(self, messages:List[Dict[str,str]]) -> str:
+    async def chat_async(self, messages:List[Dict[str,str]], messages_logger:MessagesLogger=None) -> Dict[str,str]:
        """
        Async version of chat method. Streaming is not supported.
        """
@@ -740,7 +1247,16 @@ class OpenAIInferenceEngine(InferenceEngine):
            warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)

        res = response.choices[0].message.content
-
+        # Postprocess response
+        res_dict = self.config.postprocess_response(res)
+        # Write to messages log
+        if messages_logger:
+            processed_messages.append({"role": "assistant",
+                                       "content": res_dict.get("response", ""),
+                                       "reasoning": res_dict.get("reasoning", "")})
+            messages_logger.log_messages(processed_messages)
+
+        return res_dict


 class AzureOpenAIInferenceEngine(OpenAIInferenceEngine):
@@ -795,6 +1311,7 @@ class LiteLLMInferenceEngine(InferenceEngine):
            raise ImportError("litellm not found. Please install litellm (```pip install litellm```).")

        import litellm
+        super().__init__(config)
        self.litellm = litellm
        self.model = model
        self.base_url = base_url
@@ -813,7 +1330,7 @@ class LiteLLMInferenceEngine(InferenceEngine):

        return formatted_params

-    def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False) -> Union[str, Generator[Dict[str, str], None, None]]:
+    def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False, messages_logger:MessagesLogger=None) -> Union[Dict[str,str], Generator[Dict[str, str], None, None]]:
        """
        This method inputs chat messages and outputs LLM generated text.

@@ -825,6 +1342,13 @@ class LiteLLMInferenceEngine(InferenceEngine):
            if True, VLM generated text will be printed in terminal in real-time.
        stream : bool, Optional
            if True, returns a generator that yields the output in real-time.
+        messages_logger: MessagesLogger, Optional
+            a messages logger that logs the messages.
+
+        Returns:
+        -------
+        response : Union[Dict[str,str], Generator[Dict[str, str], None, None]]
+            a dict {"reasoning": <reasoning>, "response": <response>} or Generator {"type": <reasoning or response>, "data": <content>}
        """
        processed_messages = self.config.preprocess_messages(messages)

@@ -838,12 +1362,22 @@ class LiteLLMInferenceEngine(InferenceEngine):
                    api_key=self.api_key,
                    **self.formatted_params
                )
-
+                res_text = ""
                for chunk in response_stream:
                    chunk_content = chunk.get('choices')[0].get('delta').get('content')
                    if chunk_content:
+                        res_text += chunk_content
                        yield chunk_content

+                # Postprocess response
+                res_dict = self.config.postprocess_response(res_text)
+                # Write to messages log
+                if messages_logger:
+                    processed_messages.append({"role": "assistant",
+                                               "content": res_dict.get("response", ""),
+                                               "reasoning": res_dict.get("reasoning", "")})
+                    messages_logger.log_messages(processed_messages)
+
            return self.config.postprocess_response(_stream_generator())

        elif verbose:
@@ -862,8 +1396,6 @@ class LiteLLMInferenceEngine(InferenceEngine):
                if chunk_content:
                    res += chunk_content
                    print(chunk_content, end='', flush=True)
-
-            return self.config.postprocess_response(res)

        else:
            response = self.litellm.completion(
@@ -875,9 +1407,19 @@ class LiteLLMInferenceEngine(InferenceEngine):
                **self.formatted_params
            )
            res = response.choices[0].message.content
-
+
+            # Postprocess response
+            res_dict = self.config.postprocess_response(res)
+            # Write to messages log
+            if messages_logger:
+                processed_messages.append({"role": "assistant",
+                                           "content": res_dict.get("response", ""),
+                                           "reasoning": res_dict.get("reasoning", "")})
+                messages_logger.log_messages(processed_messages)
+
+            return res_dict

-    async def chat_async(self, messages:List[Dict[str,str]]) -> str:
+    async def chat_async(self, messages:List[Dict[str,str]], messages_logger:MessagesLogger=None) -> Dict[str,str]:
        """
        Async version of chat method. Streaming is not supported.
        """
@@ -893,4 +1435,13 @@ class LiteLLMInferenceEngine(InferenceEngine):
        )

        res = response.get('choices')[0].get('message').get('content')
-
+
+        # Postprocess response
+        res_dict = self.config.postprocess_response(res)
+        # Write to messages log
+        if messages_logger:
+            processed_messages.append({"role": "assistant",
+                                       "content": res_dict.get("response", ""),
+                                       "reasoning": res_dict.get("reasoning", "")})
+            messages_logger.log_messages(processed_messages)
+        return res_dict