llm-ie 1.3.0__py3-none-any.whl → 1.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
llm_ie/engines.py
CHANGED
|
@@ -1,1491 +1,37 @@
|
|
|
1
|
-
import
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
----------
|
|
26
|
-
messages : List[Dict[str,str]]
|
|
27
|
-
a list of dict with role and content. role must be one of {"system", "user", "assistant"}
|
|
28
|
-
|
|
29
|
-
Returns:
|
|
30
|
-
-------
|
|
31
|
-
messages : List[Dict[str,str]]
|
|
32
|
-
a list of dict with role and content. role must be one of {"system", "user", "assistant"}
|
|
33
|
-
"""
|
|
34
|
-
return NotImplemented
|
|
35
|
-
|
|
36
|
-
@abc.abstractmethod
|
|
37
|
-
def postprocess_response(self, response:Union[str, Dict[str, str], Generator[str, None, None]]) -> Union[Dict[str,str], Generator[Dict[str, str], None, None]]:
|
|
38
|
-
"""
|
|
39
|
-
This method postprocesses the LLM response after it is generated.
|
|
40
|
-
|
|
41
|
-
Parameters:
|
|
42
|
-
----------
|
|
43
|
-
response : Union[str, Dict[str, str], Generator[Dict[str, str], None, None]]
|
|
44
|
-
the LLM response. Can be a dict or a generator.
|
|
45
|
-
|
|
46
|
-
Returns:
|
|
47
|
-
-------
|
|
48
|
-
response : Union[Dict[str,str], Generator[Dict[str, str], None, None]]
|
|
49
|
-
the postprocessed LLM response
|
|
50
|
-
"""
|
|
51
|
-
return NotImplemented
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
class BasicLLMConfig(LLMConfig):
|
|
55
|
-
def __init__(self, max_new_tokens:int=2048, temperature:float=0.0, **kwargs):
|
|
56
|
-
"""
|
|
57
|
-
The basic LLM configuration for most non-reasoning models.
|
|
58
|
-
"""
|
|
59
|
-
super().__init__(**kwargs)
|
|
60
|
-
self.max_new_tokens = max_new_tokens
|
|
61
|
-
self.temperature = temperature
|
|
62
|
-
self.params["max_new_tokens"] = self.max_new_tokens
|
|
63
|
-
self.params["temperature"] = self.temperature
|
|
64
|
-
|
|
65
|
-
def preprocess_messages(self, messages:List[Dict[str,str]]) -> List[Dict[str,str]]:
|
|
66
|
-
"""
|
|
67
|
-
This method preprocesses the input messages before passing them to the LLM.
|
|
68
|
-
|
|
69
|
-
Parameters:
|
|
70
|
-
----------
|
|
71
|
-
messages : List[Dict[str,str]]
|
|
72
|
-
a list of dict with role and content. role must be one of {"system", "user", "assistant"}
|
|
73
|
-
|
|
74
|
-
Returns:
|
|
75
|
-
-------
|
|
76
|
-
messages : List[Dict[str,str]]
|
|
77
|
-
a list of dict with role and content. role must be one of {"system", "user", "assistant"}
|
|
78
|
-
"""
|
|
79
|
-
return messages.copy()
|
|
80
|
-
|
|
81
|
-
def postprocess_response(self, response:Union[str, Dict[str, str], Generator[str, None, None]]) -> Union[Dict[str,str], Generator[Dict[str, str], None, None]]:
|
|
82
|
-
"""
|
|
83
|
-
This method postprocesses the LLM response after it is generated.
|
|
84
|
-
|
|
85
|
-
Parameters:
|
|
86
|
-
----------
|
|
87
|
-
response : Union[str, Dict[str, str], Generator[str, None, None]]
|
|
88
|
-
the LLM response. Can be a string or a generator.
|
|
89
|
-
|
|
90
|
-
Returns: Union[Dict[str,str], Generator[Dict[str, str], None, None]]
|
|
91
|
-
the postprocessed LLM response.
|
|
92
|
-
If input is a string, the output will be a dict {"response": <response>}.
|
|
93
|
-
if input is a generator, the output will be a generator {"type": "response", "data": <content>}.
|
|
94
|
-
"""
|
|
95
|
-
if isinstance(response, str):
|
|
96
|
-
return {"response": response}
|
|
97
|
-
|
|
98
|
-
elif isinstance(response, dict):
|
|
99
|
-
if "response" in response:
|
|
100
|
-
return response
|
|
101
|
-
else:
|
|
102
|
-
warnings.warn(f"Invalid response dict keys: {response.keys()}. Returning default empty dict.", UserWarning)
|
|
103
|
-
return {"response": ""}
|
|
104
|
-
|
|
105
|
-
elif isinstance(response, Generator):
|
|
106
|
-
def _process_stream():
|
|
107
|
-
for chunk in response:
|
|
108
|
-
if isinstance(chunk, dict):
|
|
109
|
-
yield chunk
|
|
110
|
-
elif isinstance(chunk, str):
|
|
111
|
-
yield {"type": "response", "data": chunk}
|
|
112
|
-
|
|
113
|
-
return _process_stream()
|
|
114
|
-
|
|
115
|
-
else:
|
|
116
|
-
warnings.warn(f"Invalid response type: {type(response)}. Returning default empty dict.", UserWarning)
|
|
117
|
-
return {"response": ""}
|
|
118
|
-
|
|
119
|
-
class ReasoningLLMConfig(LLMConfig):
|
|
120
|
-
def __init__(self, thinking_token_start="<think>", thinking_token_end="</think>", **kwargs):
|
|
121
|
-
"""
|
|
122
|
-
The general LLM configuration for reasoning models.
|
|
123
|
-
"""
|
|
124
|
-
super().__init__(**kwargs)
|
|
125
|
-
self.thinking_token_start = thinking_token_start
|
|
126
|
-
self.thinking_token_end = thinking_token_end
|
|
127
|
-
|
|
128
|
-
def preprocess_messages(self, messages:List[Dict[str,str]]) -> List[Dict[str,str]]:
|
|
129
|
-
"""
|
|
130
|
-
This method preprocesses the input messages before passing them to the LLM.
|
|
131
|
-
|
|
132
|
-
Parameters:
|
|
133
|
-
----------
|
|
134
|
-
messages : List[Dict[str,str]]
|
|
135
|
-
a list of dict with role and content. role must be one of {"system", "user", "assistant"}
|
|
136
|
-
|
|
137
|
-
Returns:
|
|
138
|
-
-------
|
|
139
|
-
messages : List[Dict[str,str]]
|
|
140
|
-
a list of dict with role and content. role must be one of {"system", "user", "assistant"}
|
|
141
|
-
"""
|
|
142
|
-
return messages.copy()
|
|
143
|
-
|
|
144
|
-
def postprocess_response(self, response:Union[str, Dict[str, str], Generator[str, None, None]]) -> Union[Dict[str,str], Generator[Dict[str,str], None, None]]:
|
|
145
|
-
"""
|
|
146
|
-
This method postprocesses the LLM response after it is generated.
|
|
147
|
-
1. If input is a string, it will extract the reasoning and response based on the thinking tokens.
|
|
148
|
-
2. If input is a dict, it should contain keys "reasoning" and "response". This is for inference engines that already parse reasoning and response.
|
|
149
|
-
3. If input is a generator,
|
|
150
|
-
a. if the chunk is a dict, it should contain keys "type" and "data". This is for inference engines that already parse reasoning and response.
|
|
151
|
-
b. if the chunk is a string, it will yield dicts with keys "type" and "data" based on the thinking tokens.
|
|
152
|
-
|
|
153
|
-
Parameters:
|
|
154
|
-
----------
|
|
155
|
-
response : Union[str, Generator[str, None, None]]
|
|
156
|
-
the LLM response. Can be a string or a generator.
|
|
157
|
-
|
|
158
|
-
Returns:
|
|
159
|
-
-------
|
|
160
|
-
response : Union[str, Generator[str, None, None]]
|
|
161
|
-
the postprocessed LLM response as a dict {"reasoning": <reasoning>, "response": <content>}
|
|
162
|
-
if input is a generator, the output will be a generator {"type": <reasoning or response>, "data": <content>}.
|
|
163
|
-
"""
|
|
164
|
-
if isinstance(response, str):
|
|
165
|
-
# get contents between thinking_token_start and thinking_token_end
|
|
166
|
-
pattern = f"{re.escape(self.thinking_token_start)}(.*?){re.escape(self.thinking_token_end)}"
|
|
167
|
-
match = re.search(pattern, response, re.DOTALL)
|
|
168
|
-
reasoning = match.group(1) if match else ""
|
|
169
|
-
# get response AFTER thinking_token_end
|
|
170
|
-
response = re.sub(f".*?{self.thinking_token_end}", "", response, flags=re.DOTALL).strip()
|
|
171
|
-
return {"reasoning": reasoning, "response": response}
|
|
172
|
-
|
|
173
|
-
elif isinstance(response, dict):
|
|
174
|
-
if "reasoning" in response and "response" in response:
|
|
175
|
-
return response
|
|
176
|
-
else:
|
|
177
|
-
warnings.warn(f"Invalid response dict keys: {response.keys()}. Returning default empty dict.", UserWarning)
|
|
178
|
-
return {"reasoning": "", "response": ""}
|
|
179
|
-
|
|
180
|
-
elif isinstance(response, Generator):
|
|
181
|
-
def _process_stream():
|
|
182
|
-
think_flag = False
|
|
183
|
-
buffer = ""
|
|
184
|
-
for chunk in response:
|
|
185
|
-
if isinstance(chunk, dict):
|
|
186
|
-
yield chunk
|
|
187
|
-
|
|
188
|
-
elif isinstance(chunk, str):
|
|
189
|
-
buffer += chunk
|
|
190
|
-
# switch between reasoning and response
|
|
191
|
-
if self.thinking_token_start in buffer:
|
|
192
|
-
think_flag = True
|
|
193
|
-
buffer = buffer.replace(self.thinking_token_start, "")
|
|
194
|
-
elif self.thinking_token_end in buffer:
|
|
195
|
-
think_flag = False
|
|
196
|
-
buffer = buffer.replace(self.thinking_token_end, "")
|
|
197
|
-
|
|
198
|
-
# if chunk is in thinking block, tag it as reasoning; else tag it as response
|
|
199
|
-
if chunk not in [self.thinking_token_start, self.thinking_token_end]:
|
|
200
|
-
if think_flag:
|
|
201
|
-
yield {"type": "reasoning", "data": chunk}
|
|
202
|
-
else:
|
|
203
|
-
yield {"type": "response", "data": chunk}
|
|
204
|
-
|
|
205
|
-
return _process_stream()
|
|
206
|
-
|
|
207
|
-
else:
|
|
208
|
-
warnings.warn(f"Invalid response type: {type(response)}. Returning default empty dict.", UserWarning)
|
|
209
|
-
return {"reasoning": "", "response": ""}
|
|
210
|
-
|
|
211
|
-
class Qwen3LLMConfig(ReasoningLLMConfig):
|
|
212
|
-
def __init__(self, thinking_mode:bool=True, **kwargs):
|
|
213
|
-
"""
|
|
214
|
-
The Qwen3 **hybrid thinking** LLM configuration.
|
|
215
|
-
For Qwen3 thinking 2507, use ReasoningLLMConfig instead; for Qwen3 Instruct, use BasicLLMConfig instead.
|
|
216
|
-
|
|
217
|
-
Parameters:
|
|
218
|
-
----------
|
|
219
|
-
thinking_mode : bool, Optional
|
|
220
|
-
if True, a special token "/think" will be placed after each system and user prompt. Otherwise, "/no_think" will be placed.
|
|
221
|
-
"""
|
|
222
|
-
super().__init__(**kwargs)
|
|
223
|
-
self.thinking_mode = thinking_mode
|
|
224
|
-
|
|
225
|
-
def preprocess_messages(self, messages:List[Dict[str,str]]) -> List[Dict[str,str]]:
|
|
226
|
-
"""
|
|
227
|
-
Append a special token to the system and user prompts.
|
|
228
|
-
The token is "/think" if thinking_mode is True, otherwise "/no_think".
|
|
229
|
-
|
|
230
|
-
Parameters:
|
|
231
|
-
----------
|
|
232
|
-
messages : List[Dict[str,str]]
|
|
233
|
-
a list of dict with role and content. role must be one of {"system", "user", "assistant"}
|
|
234
|
-
|
|
235
|
-
Returns:
|
|
236
|
-
-------
|
|
237
|
-
messages : List[Dict[str,str]]
|
|
238
|
-
a list of dict with role and content. role must be one of {"system", "user", "assistant"}
|
|
239
|
-
"""
|
|
240
|
-
thinking_token = "/think" if self.thinking_mode else "/no_think"
|
|
241
|
-
new_messages = []
|
|
242
|
-
for message in messages:
|
|
243
|
-
if message['role'] in ['system', 'user']:
|
|
244
|
-
new_message = {'role': message['role'], 'content': f"{message['content']} {thinking_token}"}
|
|
245
|
-
else:
|
|
246
|
-
new_message = {'role': message['role'], 'content': message['content']}
|
|
247
|
-
|
|
248
|
-
new_messages.append(new_message)
|
|
249
|
-
|
|
250
|
-
return new_messages
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
class OpenAIReasoningLLMConfig(ReasoningLLMConfig):
|
|
254
|
-
def __init__(self, reasoning_effort:str=None, **kwargs):
|
|
255
|
-
"""
|
|
256
|
-
The OpenAI "o" series configuration.
|
|
257
|
-
1. The reasoning effort as one of {"low", "medium", "high"}.
|
|
258
|
-
For models that do not support setting reasoning effort (e.g., o1-mini, o1-preview), set to None.
|
|
259
|
-
2. The temperature parameter is not supported and will be ignored.
|
|
260
|
-
3. The system prompt is not supported and will be concatenated to the next user prompt.
|
|
261
|
-
|
|
262
|
-
Parameters:
|
|
263
|
-
----------
|
|
264
|
-
reasoning_effort : str, Optional
|
|
265
|
-
the reasoning effort. Must be one of {"low", "medium", "high"}. Default is "low".
|
|
266
|
-
"""
|
|
267
|
-
super().__init__(**kwargs)
|
|
268
|
-
if reasoning_effort is not None:
|
|
269
|
-
if reasoning_effort not in ["low", "medium", "high"]:
|
|
270
|
-
raise ValueError("reasoning_effort must be one of {'low', 'medium', 'high'}.")
|
|
271
|
-
|
|
272
|
-
self.reasoning_effort = reasoning_effort
|
|
273
|
-
self.params["reasoning_effort"] = self.reasoning_effort
|
|
274
|
-
|
|
275
|
-
if "temperature" in self.params:
|
|
276
|
-
warnings.warn("Reasoning models do not support temperature parameter. Will be ignored.", UserWarning)
|
|
277
|
-
self.params.pop("temperature")
|
|
278
|
-
|
|
279
|
-
def preprocess_messages(self, messages:List[Dict[str,str]]) -> List[Dict[str,str]]:
|
|
280
|
-
"""
|
|
281
|
-
Concatenate system prompts to the next user prompt.
|
|
282
|
-
|
|
283
|
-
Parameters:
|
|
284
|
-
----------
|
|
285
|
-
messages : List[Dict[str,str]]
|
|
286
|
-
a list of dict with role and content. role must be one of {"system", "user", "assistant"}
|
|
287
|
-
|
|
288
|
-
Returns:
|
|
289
|
-
-------
|
|
290
|
-
messages : List[Dict[str,str]]
|
|
291
|
-
a list of dict with role and content. role must be one of {"system", "user", "assistant"}
|
|
292
|
-
"""
|
|
293
|
-
system_prompt_holder = ""
|
|
294
|
-
new_messages = []
|
|
295
|
-
for i, message in enumerate(messages):
|
|
296
|
-
# if system prompt, store it in system_prompt_holder
|
|
297
|
-
if message['role'] == 'system':
|
|
298
|
-
system_prompt_holder = message['content']
|
|
299
|
-
# if user prompt, concatenate it with system_prompt_holder
|
|
300
|
-
elif message['role'] == 'user':
|
|
301
|
-
if system_prompt_holder:
|
|
302
|
-
new_message = {'role': message['role'], 'content': f"{system_prompt_holder} {message['content']}"}
|
|
303
|
-
system_prompt_holder = ""
|
|
304
|
-
else:
|
|
305
|
-
new_message = {'role': message['role'], 'content': message['content']}
|
|
306
|
-
|
|
307
|
-
new_messages.append(new_message)
|
|
308
|
-
# if assistant/other prompt, do nothing
|
|
309
|
-
else:
|
|
310
|
-
new_message = {'role': message['role'], 'content': message['content']}
|
|
311
|
-
new_messages.append(new_message)
|
|
312
|
-
|
|
313
|
-
return new_messages
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
class MessagesLogger:
|
|
317
|
-
def __init__(self):
|
|
318
|
-
"""
|
|
319
|
-
This class is used to log the messages for InferenceEngine.chat().
|
|
320
|
-
"""
|
|
321
|
-
self.messages_log = []
|
|
322
|
-
|
|
323
|
-
def log_messages(self, messages : List[Dict[str,str]]):
|
|
324
|
-
"""
|
|
325
|
-
This method logs the messages to a list.
|
|
326
|
-
"""
|
|
327
|
-
self.messages_log.append(messages)
|
|
328
|
-
|
|
329
|
-
def get_messages_log(self) -> List[List[Dict[str,str]]]:
|
|
330
|
-
"""
|
|
331
|
-
This method returns a copy of the current messages log
|
|
332
|
-
"""
|
|
333
|
-
return self.messages_log.copy()
|
|
334
|
-
|
|
335
|
-
def clear_messages_log(self):
|
|
336
|
-
"""
|
|
337
|
-
This method clears the current messages log
|
|
338
|
-
"""
|
|
339
|
-
self.messages_log.clear()
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
class InferenceEngine:
|
|
343
|
-
@abc.abstractmethod
|
|
344
|
-
def __init__(self, config:LLMConfig, **kwrs):
|
|
345
|
-
"""
|
|
346
|
-
This is an abstract class to provide interfaces for LLM inference engines.
|
|
347
|
-
Children classes that inherts this class can be used in extrators. Must implement chat() method.
|
|
348
|
-
|
|
349
|
-
Parameters:
|
|
350
|
-
----------
|
|
351
|
-
config : LLMConfig
|
|
352
|
-
the LLM configuration. Must be a child class of LLMConfig.
|
|
353
|
-
"""
|
|
354
|
-
return NotImplemented
|
|
355
|
-
|
|
356
|
-
def get_messages_log(self) -> List[List[Dict[str,str]]]:
|
|
357
|
-
return self.messages_log.copy()
|
|
358
|
-
|
|
359
|
-
def clear_messages_log(self):
|
|
360
|
-
self.messages_log = []
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
@abc.abstractmethod
|
|
364
|
-
def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False,
|
|
365
|
-
messages_logger:MessagesLogger=None) -> Union[Dict[str,str], Generator[Dict[str, str], None, None]]:
|
|
366
|
-
"""
|
|
367
|
-
This method inputs chat messages and outputs LLM generated text.
|
|
368
|
-
|
|
369
|
-
Parameters:
|
|
370
|
-
----------
|
|
371
|
-
messages : List[Dict[str,str]]
|
|
372
|
-
a list of dict with role and content. role must be one of {"system", "user", "assistant"}
|
|
373
|
-
verbose : bool, Optional
|
|
374
|
-
if True, LLM generated text will be printed in terminal in real-time.
|
|
375
|
-
stream : bool, Optional
|
|
376
|
-
if True, returns a generator that yields the output in real-time.
|
|
377
|
-
Messages_logger : MessagesLogger, Optional
|
|
378
|
-
the message logger that logs the chat messages.
|
|
379
|
-
|
|
380
|
-
Returns:
|
|
381
|
-
-------
|
|
382
|
-
response : Union[Dict[str,str], Generator[Dict[str, str], None, None]]
|
|
383
|
-
a dict {"reasoning": <reasoning>, "response": <response>} or Generator {"type": <reasoning or response>, "data": <content>}
|
|
384
|
-
"""
|
|
385
|
-
return NotImplemented
|
|
386
|
-
|
|
387
|
-
def _format_config(self) -> Dict[str, Any]:
|
|
388
|
-
"""
|
|
389
|
-
This method format the LLM configuration with the correct key for the inference engine.
|
|
390
|
-
|
|
391
|
-
Return : Dict[str, Any]
|
|
392
|
-
the config parameters.
|
|
393
|
-
"""
|
|
394
|
-
return NotImplemented
|
|
395
|
-
|
|
1
|
+
from llm_inference_engine import (
|
|
2
|
+
# Configs
|
|
3
|
+
LLMConfig,
|
|
4
|
+
BasicLLMConfig,
|
|
5
|
+
ReasoningLLMConfig,
|
|
6
|
+
Qwen3LLMConfig,
|
|
7
|
+
OpenAIReasoningLLMConfig,
|
|
8
|
+
|
|
9
|
+
# Base Engine
|
|
10
|
+
InferenceEngine,
|
|
11
|
+
|
|
12
|
+
# Concrete Engines
|
|
13
|
+
OllamaInferenceEngine,
|
|
14
|
+
OpenAIInferenceEngine,
|
|
15
|
+
HuggingFaceHubInferenceEngine,
|
|
16
|
+
AzureOpenAIInferenceEngine,
|
|
17
|
+
LiteLLMInferenceEngine,
|
|
18
|
+
OpenAICompatibleInferenceEngine,
|
|
19
|
+
VLLMInferenceEngine,
|
|
20
|
+
SGLangInferenceEngine,
|
|
21
|
+
OpenRouterInferenceEngine
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
from llm_inference_engine.utils import MessagesLogger
|
|
396
25
|
|
|
397
26
|
class LlamaCppInferenceEngine(InferenceEngine):
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
the exact name as shown on Huggingface repo
|
|
406
|
-
gguf_filename : str
|
|
407
|
-
the exact name as shown in Huggingface repo -> Files and versions.
|
|
408
|
-
If multiple gguf files are needed, use the first.
|
|
409
|
-
n_ctx : int, Optional
|
|
410
|
-
context length that LLM will evaluate.
|
|
411
|
-
n_gpu_layers : int, Optional
|
|
412
|
-
number of layers to offload to GPU. Default is all layers (-1).
|
|
413
|
-
config : LLMConfig
|
|
414
|
-
the LLM configuration.
|
|
415
|
-
"""
|
|
416
|
-
from llama_cpp import Llama
|
|
417
|
-
super().__init__(config)
|
|
418
|
-
self.repo_id = repo_id
|
|
419
|
-
self.gguf_filename = gguf_filename
|
|
420
|
-
self.n_ctx = n_ctx
|
|
421
|
-
self.n_gpu_layers = n_gpu_layers
|
|
422
|
-
self.config = config if config else BasicLLMConfig()
|
|
423
|
-
self.formatted_params = self._format_config()
|
|
424
|
-
|
|
425
|
-
self.model = Llama.from_pretrained(
|
|
426
|
-
repo_id=self.repo_id,
|
|
427
|
-
filename=self.gguf_filename,
|
|
428
|
-
n_gpu_layers=n_gpu_layers,
|
|
429
|
-
n_ctx=n_ctx,
|
|
430
|
-
**kwrs
|
|
431
|
-
)
|
|
432
|
-
|
|
433
|
-
def __del__(self):
|
|
434
|
-
"""
|
|
435
|
-
When the inference engine is deleted, release memory for model.
|
|
436
|
-
"""
|
|
437
|
-
del self.model
|
|
438
|
-
|
|
439
|
-
def _format_config(self) -> Dict[str, Any]:
|
|
440
|
-
"""
|
|
441
|
-
This method format the LLM configuration with the correct key for the inference engine.
|
|
442
|
-
"""
|
|
443
|
-
formatted_params = self.config.params.copy()
|
|
444
|
-
if "max_new_tokens" in formatted_params:
|
|
445
|
-
formatted_params["max_tokens"] = formatted_params["max_new_tokens"]
|
|
446
|
-
formatted_params.pop("max_new_tokens")
|
|
447
|
-
|
|
448
|
-
return formatted_params
|
|
449
|
-
|
|
450
|
-
def chat(self, messages:List[Dict[str,str]], verbose:bool=False, messages_logger:MessagesLogger=None) -> Dict[str,str]:
|
|
451
|
-
"""
|
|
452
|
-
This method inputs chat messages and outputs LLM generated text.
|
|
453
|
-
|
|
454
|
-
Parameters:
|
|
455
|
-
----------
|
|
456
|
-
messages : List[Dict[str,str]]
|
|
457
|
-
a list of dict with role and content. role must be one of {"system", "user", "assistant"}
|
|
458
|
-
verbose : bool, Optional
|
|
459
|
-
if True, LLM generated text will be printed in terminal in real-time.
|
|
460
|
-
messages_logger : MessagesLogger, Optional
|
|
461
|
-
the message logger that logs the chat messages.
|
|
462
|
-
"""
|
|
463
|
-
# Preprocess messages
|
|
464
|
-
processed_messages = self.config.preprocess_messages(messages)
|
|
465
|
-
# Generate response
|
|
466
|
-
response = self.model.create_chat_completion(
|
|
467
|
-
messages=processed_messages,
|
|
468
|
-
stream=verbose,
|
|
469
|
-
**self.formatted_params
|
|
470
|
-
)
|
|
471
|
-
|
|
472
|
-
if verbose:
|
|
473
|
-
res = ''
|
|
474
|
-
for chunk in response:
|
|
475
|
-
out_dict = chunk['choices'][0]['delta']
|
|
476
|
-
if 'content' in out_dict:
|
|
477
|
-
res += out_dict['content']
|
|
478
|
-
print(out_dict['content'], end='', flush=True)
|
|
479
|
-
print('\n')
|
|
480
|
-
return self.config.postprocess_response(res)
|
|
481
|
-
|
|
482
|
-
res = response['choices'][0]['message']['content']
|
|
483
|
-
# Postprocess response
|
|
484
|
-
res_dict = self.config.postprocess_response(res)
|
|
485
|
-
# Write to messages log
|
|
486
|
-
if messages_logger:
|
|
487
|
-
processed_messages.append({"role": "assistant",
|
|
488
|
-
"content": res_dict.get("response", ""),
|
|
489
|
-
"reasoning": res_dict.get("reasoning", "")})
|
|
490
|
-
messages_logger.log_messages(processed_messages)
|
|
491
|
-
|
|
492
|
-
return res_dict
|
|
493
|
-
|
|
494
|
-
|
|
495
|
-
class OllamaInferenceEngine(InferenceEngine):
|
|
496
|
-
def __init__(self, model_name:str, num_ctx:int=4096, keep_alive:int=300, config:LLMConfig=None, **kwrs):
|
|
497
|
-
"""
|
|
498
|
-
The Ollama inference engine.
|
|
499
|
-
|
|
500
|
-
Parameters:
|
|
501
|
-
----------
|
|
502
|
-
model_name : str
|
|
503
|
-
the model name exactly as shown in >> ollama ls
|
|
504
|
-
num_ctx : int, Optional
|
|
505
|
-
context length that LLM will evaluate.
|
|
506
|
-
keep_alive : int, Optional
|
|
507
|
-
seconds to hold the LLM after the last API call.
|
|
508
|
-
config : LLMConfig
|
|
509
|
-
the LLM configuration.
|
|
510
|
-
"""
|
|
511
|
-
if importlib.util.find_spec("ollama") is None:
|
|
512
|
-
raise ImportError("ollama-python not found. Please install ollama-python (```pip install ollama```).")
|
|
513
|
-
|
|
514
|
-
from ollama import Client, AsyncClient
|
|
515
|
-
super().__init__(config)
|
|
516
|
-
self.client = Client(**kwrs)
|
|
517
|
-
self.async_client = AsyncClient(**kwrs)
|
|
518
|
-
self.model_name = model_name
|
|
519
|
-
self.num_ctx = num_ctx
|
|
520
|
-
self.keep_alive = keep_alive
|
|
521
|
-
self.config = config if config else BasicLLMConfig()
|
|
522
|
-
self.formatted_params = self._format_config()
|
|
523
|
-
|
|
524
|
-
def _format_config(self) -> Dict[str, Any]:
|
|
525
|
-
"""
|
|
526
|
-
This method format the LLM configuration with the correct key for the inference engine.
|
|
527
|
-
"""
|
|
528
|
-
formatted_params = self.config.params.copy()
|
|
529
|
-
if "max_new_tokens" in formatted_params:
|
|
530
|
-
formatted_params["num_predict"] = formatted_params["max_new_tokens"]
|
|
531
|
-
formatted_params.pop("max_new_tokens")
|
|
532
|
-
|
|
533
|
-
return formatted_params
|
|
534
|
-
|
|
535
|
-
def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False,
|
|
536
|
-
messages_logger:MessagesLogger=None) -> Union[Dict[str,str], Generator[Dict[str, str], None, None]]:
|
|
537
|
-
"""
|
|
538
|
-
This method inputs chat messages and outputs VLM generated text.
|
|
539
|
-
|
|
540
|
-
Parameters:
|
|
541
|
-
----------
|
|
542
|
-
messages : List[Dict[str,str]]
|
|
543
|
-
a list of dict with role and content. role must be one of {"system", "user", "assistant"}
|
|
544
|
-
verbose : bool, Optional
|
|
545
|
-
if True, VLM generated text will be printed in terminal in real-time.
|
|
546
|
-
stream : bool, Optional
|
|
547
|
-
if True, returns a generator that yields the output in real-time.
|
|
548
|
-
Messages_logger : MessagesLogger, Optional
|
|
549
|
-
the message logger that logs the chat messages.
|
|
550
|
-
|
|
551
|
-
Returns:
|
|
552
|
-
-------
|
|
553
|
-
response : Union[Dict[str,str], Generator[Dict[str, str], None, None]]
|
|
554
|
-
a dict {"reasoning": <reasoning>, "response": <response>} or Generator {"type": <reasoning or response>, "data": <content>}
|
|
555
|
-
"""
|
|
556
|
-
processed_messages = self.config.preprocess_messages(messages)
|
|
557
|
-
|
|
558
|
-
options={'num_ctx': self.num_ctx, **self.formatted_params}
|
|
559
|
-
if stream:
|
|
560
|
-
def _stream_generator():
|
|
561
|
-
response_stream = self.client.chat(
|
|
562
|
-
model=self.model_name,
|
|
563
|
-
messages=processed_messages,
|
|
564
|
-
options=options,
|
|
565
|
-
stream=True,
|
|
566
|
-
keep_alive=self.keep_alive
|
|
567
|
-
)
|
|
568
|
-
res = {"reasoning": "", "response": ""}
|
|
569
|
-
for chunk in response_stream:
|
|
570
|
-
if hasattr(chunk.message, 'thinking') and chunk.message.thinking:
|
|
571
|
-
content_chunk = getattr(getattr(chunk, 'message', {}), 'thinking', '')
|
|
572
|
-
res["reasoning"] += content_chunk
|
|
573
|
-
yield {"type": "reasoning", "data": content_chunk}
|
|
574
|
-
else:
|
|
575
|
-
content_chunk = getattr(getattr(chunk, 'message', {}), 'content', '')
|
|
576
|
-
res["response"] += content_chunk
|
|
577
|
-
yield {"type": "response", "data": content_chunk}
|
|
578
|
-
|
|
579
|
-
if chunk.done_reason == "length":
|
|
580
|
-
warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
|
|
581
|
-
|
|
582
|
-
# Postprocess response
|
|
583
|
-
res_dict = self.config.postprocess_response(res)
|
|
584
|
-
# Write to messages log
|
|
585
|
-
if messages_logger:
|
|
586
|
-
processed_messages.append({"role": "assistant",
|
|
587
|
-
"content": res_dict.get("response", ""),
|
|
588
|
-
"reasoning": res_dict.get("reasoning", "")})
|
|
589
|
-
messages_logger.log_messages(processed_messages)
|
|
590
|
-
|
|
591
|
-
return self.config.postprocess_response(_stream_generator())
|
|
592
|
-
|
|
593
|
-
elif verbose:
|
|
594
|
-
response = self.client.chat(
|
|
595
|
-
model=self.model_name,
|
|
596
|
-
messages=processed_messages,
|
|
597
|
-
options=options,
|
|
598
|
-
stream=True,
|
|
599
|
-
keep_alive=self.keep_alive
|
|
600
|
-
)
|
|
601
|
-
|
|
602
|
-
res = {"reasoning": "", "response": ""}
|
|
603
|
-
phase = ""
|
|
604
|
-
for chunk in response:
|
|
605
|
-
if hasattr(chunk.message, 'thinking') and chunk.message.thinking:
|
|
606
|
-
if phase != "reasoning":
|
|
607
|
-
print("\n--- Reasoning ---")
|
|
608
|
-
phase = "reasoning"
|
|
609
|
-
|
|
610
|
-
content_chunk = getattr(getattr(chunk, 'message', {}), 'thinking', '')
|
|
611
|
-
res["reasoning"] += content_chunk
|
|
612
|
-
else:
|
|
613
|
-
if phase != "response":
|
|
614
|
-
print("\n--- Response ---")
|
|
615
|
-
phase = "response"
|
|
616
|
-
content_chunk = getattr(getattr(chunk, 'message', {}), 'content', '')
|
|
617
|
-
res["response"] += content_chunk
|
|
618
|
-
|
|
619
|
-
print(content_chunk, end='', flush=True)
|
|
620
|
-
|
|
621
|
-
if chunk.done_reason == "length":
|
|
622
|
-
warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
|
|
623
|
-
print('\n')
|
|
624
|
-
|
|
625
|
-
else:
|
|
626
|
-
response = self.client.chat(
|
|
627
|
-
model=self.model_name,
|
|
628
|
-
messages=processed_messages,
|
|
629
|
-
options=options,
|
|
630
|
-
stream=False,
|
|
631
|
-
keep_alive=self.keep_alive
|
|
632
|
-
)
|
|
633
|
-
res = {"reasoning": getattr(getattr(response, 'message', {}), 'thinking', ''),
|
|
634
|
-
"response": getattr(getattr(response, 'message', {}), 'content', '')}
|
|
635
|
-
|
|
636
|
-
if response.done_reason == "length":
|
|
637
|
-
warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
|
|
638
|
-
|
|
639
|
-
# Postprocess response
|
|
640
|
-
res_dict = self.config.postprocess_response(res)
|
|
641
|
-
# Write to messages log
|
|
642
|
-
if messages_logger:
|
|
643
|
-
processed_messages.append({"role": "assistant",
|
|
644
|
-
"content": res_dict.get("response", ""),
|
|
645
|
-
"reasoning": res_dict.get("reasoning", "")})
|
|
646
|
-
messages_logger.log_messages(processed_messages)
|
|
647
|
-
|
|
648
|
-
return res_dict
|
|
649
|
-
|
|
650
|
-
|
|
651
|
-
async def chat_async(self, messages:List[Dict[str,str]], messages_logger:MessagesLogger=None) -> Dict[str,str]:
|
|
652
|
-
"""
|
|
653
|
-
Async version of chat method. Streaming is not supported.
|
|
654
|
-
"""
|
|
655
|
-
processed_messages = self.config.preprocess_messages(messages)
|
|
656
|
-
|
|
657
|
-
response = await self.async_client.chat(
|
|
658
|
-
model=self.model_name,
|
|
659
|
-
messages=processed_messages,
|
|
660
|
-
options={'num_ctx': self.num_ctx, **self.formatted_params},
|
|
661
|
-
stream=False,
|
|
662
|
-
keep_alive=self.keep_alive
|
|
663
|
-
)
|
|
664
|
-
|
|
665
|
-
res = {"reasoning": getattr(getattr(response, 'message', {}), 'thinking', ''),
|
|
666
|
-
"response": getattr(getattr(response, 'message', {}), 'content', '')}
|
|
667
|
-
|
|
668
|
-
if response.done_reason == "length":
|
|
669
|
-
warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
|
|
670
|
-
# Postprocess response
|
|
671
|
-
res_dict = self.config.postprocess_response(res)
|
|
672
|
-
# Write to messages log
|
|
673
|
-
if messages_logger:
|
|
674
|
-
processed_messages.append({"role": "assistant",
|
|
675
|
-
"content": res_dict.get("response", ""),
|
|
676
|
-
"reasoning": res_dict.get("reasoning", "")})
|
|
677
|
-
messages_logger.log_messages(processed_messages)
|
|
678
|
-
|
|
679
|
-
return res_dict
|
|
680
|
-
|
|
681
|
-
|
|
682
|
-
class HuggingFaceHubInferenceEngine(InferenceEngine):
|
|
683
|
-
def __init__(self, model:str=None, token:Union[str, bool]=None, base_url:str=None, api_key:str=None, config:LLMConfig=None, **kwrs):
|
|
684
|
-
"""
|
|
685
|
-
The Huggingface_hub InferenceClient inference engine.
|
|
686
|
-
For parameters and documentation, refer to https://huggingface.co/docs/huggingface_hub/en/package_reference/inference_client
|
|
687
|
-
|
|
688
|
-
Parameters:
|
|
689
|
-
----------
|
|
690
|
-
model : str
|
|
691
|
-
the model name exactly as shown in Huggingface repo
|
|
692
|
-
token : str, Optional
|
|
693
|
-
the Huggingface token. If None, will use the token in os.environ['HF_TOKEN'].
|
|
694
|
-
base_url : str, Optional
|
|
695
|
-
the base url for the LLM server. If None, will use the default Huggingface Hub URL.
|
|
696
|
-
api_key : str, Optional
|
|
697
|
-
the API key for the LLM server.
|
|
698
|
-
config : LLMConfig
|
|
699
|
-
the LLM configuration.
|
|
700
|
-
"""
|
|
701
|
-
if importlib.util.find_spec("huggingface_hub") is None:
|
|
702
|
-
raise ImportError("huggingface-hub not found. Please install huggingface-hub (```pip install huggingface-hub```).")
|
|
703
|
-
|
|
704
|
-
from huggingface_hub import InferenceClient, AsyncInferenceClient
|
|
705
|
-
super().__init__(config)
|
|
706
|
-
self.model = model
|
|
707
|
-
self.base_url = base_url
|
|
708
|
-
self.client = InferenceClient(model=model, token=token, base_url=base_url, api_key=api_key, **kwrs)
|
|
709
|
-
self.client_async = AsyncInferenceClient(model=model, token=token, base_url=base_url, api_key=api_key, **kwrs)
|
|
710
|
-
self.config = config if config else BasicLLMConfig()
|
|
711
|
-
self.formatted_params = self._format_config()
|
|
712
|
-
|
|
713
|
-
def _format_config(self) -> Dict[str, Any]:
|
|
714
|
-
"""
|
|
715
|
-
This method format the LLM configuration with the correct key for the inference engine.
|
|
716
|
-
"""
|
|
717
|
-
formatted_params = self.config.params.copy()
|
|
718
|
-
if "max_new_tokens" in formatted_params:
|
|
719
|
-
formatted_params["max_tokens"] = formatted_params["max_new_tokens"]
|
|
720
|
-
formatted_params.pop("max_new_tokens")
|
|
721
|
-
|
|
722
|
-
return formatted_params
|
|
723
|
-
|
|
724
|
-
|
|
725
|
-
def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False,
|
|
726
|
-
messages_logger:MessagesLogger=None) -> Union[Dict[str,str], Generator[Dict[str, str], None, None]]:
|
|
727
|
-
"""
|
|
728
|
-
This method inputs chat messages and outputs LLM generated text.
|
|
729
|
-
|
|
730
|
-
Parameters:
|
|
731
|
-
----------
|
|
732
|
-
messages : List[Dict[str,str]]
|
|
733
|
-
a list of dict with role and content. role must be one of {"system", "user", "assistant"}
|
|
734
|
-
verbose : bool, Optional
|
|
735
|
-
if True, VLM generated text will be printed in terminal in real-time.
|
|
736
|
-
stream : bool, Optional
|
|
737
|
-
if True, returns a generator that yields the output in real-time.
|
|
738
|
-
messages_logger : MessagesLogger, Optional
|
|
739
|
-
the message logger that logs the chat messages.
|
|
740
|
-
|
|
741
|
-
Returns:
|
|
742
|
-
-------
|
|
743
|
-
response : Union[Dict[str,str], Generator[Dict[str, str], None, None]]
|
|
744
|
-
a dict {"reasoning": <reasoning>, "response": <response>} or Generator {"type": <reasoning or response>, "data": <content>}
|
|
745
|
-
"""
|
|
746
|
-
processed_messages = self.config.preprocess_messages(messages)
|
|
747
|
-
|
|
748
|
-
if stream:
|
|
749
|
-
def _stream_generator():
|
|
750
|
-
response_stream = self.client.chat.completions.create(
|
|
751
|
-
messages=processed_messages,
|
|
752
|
-
stream=True,
|
|
753
|
-
**self.formatted_params
|
|
754
|
-
)
|
|
755
|
-
res_text = ""
|
|
756
|
-
for chunk in response_stream:
|
|
757
|
-
content_chunk = chunk.get('choices')[0].get('delta').get('content')
|
|
758
|
-
if content_chunk:
|
|
759
|
-
res_text += content_chunk
|
|
760
|
-
yield content_chunk
|
|
761
|
-
|
|
762
|
-
# Postprocess response
|
|
763
|
-
res_dict = self.config.postprocess_response(res_text)
|
|
764
|
-
# Write to messages log
|
|
765
|
-
if messages_logger:
|
|
766
|
-
processed_messages.append({"role": "assistant",
|
|
767
|
-
"content": res_dict.get("response", ""),
|
|
768
|
-
"reasoning": res_dict.get("reasoning", "")})
|
|
769
|
-
messages_logger.log_messages(processed_messages)
|
|
770
|
-
|
|
771
|
-
return self.config.postprocess_response(_stream_generator())
|
|
772
|
-
|
|
773
|
-
elif verbose:
|
|
774
|
-
response = self.client.chat.completions.create(
|
|
775
|
-
messages=processed_messages,
|
|
776
|
-
stream=True,
|
|
777
|
-
**self.formatted_params
|
|
778
|
-
)
|
|
779
|
-
|
|
780
|
-
res = ''
|
|
781
|
-
for chunk in response:
|
|
782
|
-
content_chunk = chunk.get('choices')[0].get('delta').get('content')
|
|
783
|
-
if content_chunk:
|
|
784
|
-
res += content_chunk
|
|
785
|
-
print(content_chunk, end='', flush=True)
|
|
786
|
-
|
|
787
|
-
|
|
788
|
-
else:
|
|
789
|
-
response = self.client.chat.completions.create(
|
|
790
|
-
messages=processed_messages,
|
|
791
|
-
stream=False,
|
|
792
|
-
**self.formatted_params
|
|
793
|
-
)
|
|
794
|
-
res = response.choices[0].message.content
|
|
795
|
-
|
|
796
|
-
# Postprocess response
|
|
797
|
-
res_dict = self.config.postprocess_response(res)
|
|
798
|
-
# Write to messages log
|
|
799
|
-
if messages_logger:
|
|
800
|
-
processed_messages.append({"role": "assistant",
|
|
801
|
-
"content": res_dict.get("response", ""),
|
|
802
|
-
"reasoning": res_dict.get("reasoning", "")})
|
|
803
|
-
messages_logger.log_messages(processed_messages)
|
|
804
|
-
|
|
805
|
-
return res_dict
|
|
806
|
-
|
|
807
|
-
|
|
808
|
-
async def chat_async(self, messages:List[Dict[str,str]], messages_logger:MessagesLogger=None) -> Dict[str,str]:
|
|
809
|
-
"""
|
|
810
|
-
Async version of chat method. Streaming is not supported.
|
|
811
|
-
"""
|
|
812
|
-
processed_messages = self.config.preprocess_messages(messages)
|
|
813
|
-
|
|
814
|
-
response = await self.client_async.chat.completions.create(
|
|
815
|
-
messages=processed_messages,
|
|
816
|
-
stream=False,
|
|
817
|
-
**self.formatted_params
|
|
818
|
-
)
|
|
819
|
-
|
|
820
|
-
res = response.choices[0].message.content
|
|
821
|
-
# Postprocess response
|
|
822
|
-
res_dict = self.config.postprocess_response(res)
|
|
823
|
-
# Write to messages log
|
|
824
|
-
if messages_logger:
|
|
825
|
-
processed_messages.append({"role": "assistant",
|
|
826
|
-
"content": res_dict.get("response", ""),
|
|
827
|
-
"reasoning": res_dict.get("reasoning", "")})
|
|
828
|
-
messages_logger.log_messages(processed_messages)
|
|
829
|
-
|
|
830
|
-
return res_dict
|
|
831
|
-
|
|
832
|
-
|
|
833
|
-
class OpenAICompatibleInferenceEngine(InferenceEngine):
|
|
834
|
-
def __init__(self, model:str, api_key:str, base_url:str, config:LLMConfig=None, **kwrs):
|
|
835
|
-
"""
|
|
836
|
-
General OpenAI-compatible server inference engine.
|
|
837
|
-
https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html
|
|
838
|
-
|
|
839
|
-
For parameters and documentation, refer to https://platform.openai.com/docs/api-reference/introduction
|
|
840
|
-
|
|
841
|
-
Parameters:
|
|
842
|
-
----------
|
|
843
|
-
model_name : str
|
|
844
|
-
model name as shown in the vLLM server
|
|
845
|
-
api_key : str
|
|
846
|
-
the API key for the vLLM server.
|
|
847
|
-
base_url : str
|
|
848
|
-
the base url for the vLLM server.
|
|
849
|
-
config : LLMConfig
|
|
850
|
-
the LLM configuration.
|
|
851
|
-
"""
|
|
852
|
-
if importlib.util.find_spec("openai") is None:
|
|
853
|
-
raise ImportError("OpenAI Python API library not found. Please install OpanAI (```pip install openai```).")
|
|
854
|
-
|
|
855
|
-
from openai import OpenAI, AsyncOpenAI
|
|
856
|
-
from openai.types.chat import ChatCompletionChunk
|
|
857
|
-
self.ChatCompletionChunk = ChatCompletionChunk
|
|
858
|
-
super().__init__(config)
|
|
859
|
-
self.client = OpenAI(api_key=api_key, base_url=base_url, **kwrs)
|
|
860
|
-
self.async_client = AsyncOpenAI(api_key=api_key, base_url=base_url, **kwrs)
|
|
861
|
-
self.model = model
|
|
862
|
-
self.config = config if config else BasicLLMConfig()
|
|
863
|
-
self.formatted_params = self._format_config()
|
|
864
|
-
|
|
865
|
-
def _format_config(self) -> Dict[str, Any]:
|
|
866
|
-
"""
|
|
867
|
-
This method format the LLM configuration with the correct key for the inference engine.
|
|
868
|
-
"""
|
|
869
|
-
formatted_params = self.config.params.copy()
|
|
870
|
-
if "max_new_tokens" in formatted_params:
|
|
871
|
-
formatted_params["max_completion_tokens"] = formatted_params["max_new_tokens"]
|
|
872
|
-
formatted_params.pop("max_new_tokens")
|
|
873
|
-
|
|
874
|
-
return formatted_params
|
|
875
|
-
|
|
876
|
-
@abc.abstractmethod
|
|
877
|
-
def _format_response(self, response: Any) -> Dict[str, str]:
|
|
878
|
-
"""
|
|
879
|
-
This method format the response from OpenAI API to a dict with keys "type" and "data".
|
|
880
|
-
|
|
881
|
-
Parameters:
|
|
882
|
-
----------
|
|
883
|
-
response : Any
|
|
884
|
-
the response from OpenAI-compatible API. Could be a dict, generator, or object.
|
|
885
|
-
"""
|
|
886
|
-
return NotImplemented
|
|
887
|
-
|
|
888
|
-
def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False, messages_logger:MessagesLogger=None) -> Union[Dict[str, str], Generator[Dict[str, str], None, None]]:
|
|
889
|
-
"""
|
|
890
|
-
This method inputs chat messages and outputs LLM generated text.
|
|
891
|
-
|
|
892
|
-
Parameters:
|
|
893
|
-
----------
|
|
894
|
-
messages : List[Dict[str,str]]
|
|
895
|
-
a list of dict with role and content. role must be one of {"system", "user", "assistant"}
|
|
896
|
-
verbose : bool, Optional
|
|
897
|
-
if True, VLM generated text will be printed in terminal in real-time.
|
|
898
|
-
stream : bool, Optional
|
|
899
|
-
if True, returns a generator that yields the output in real-time.
|
|
900
|
-
messages_logger : MessagesLogger, Optional
|
|
901
|
-
the message logger that logs the chat messages.
|
|
902
|
-
|
|
903
|
-
Returns:
|
|
904
|
-
-------
|
|
905
|
-
response : Union[Dict[str,str], Generator[Dict[str, str], None, None]]
|
|
906
|
-
a dict {"reasoning": <reasoning>, "response": <response>} or Generator {"type": <reasoning or response>, "data": <content>}
|
|
907
|
-
"""
|
|
908
|
-
processed_messages = self.config.preprocess_messages(messages)
|
|
909
|
-
|
|
910
|
-
if stream:
|
|
911
|
-
def _stream_generator():
|
|
912
|
-
response_stream = self.client.chat.completions.create(
|
|
913
|
-
model=self.model,
|
|
914
|
-
messages=processed_messages,
|
|
915
|
-
stream=True,
|
|
916
|
-
**self.formatted_params
|
|
917
|
-
)
|
|
918
|
-
res_text = ""
|
|
919
|
-
for chunk in response_stream:
|
|
920
|
-
if len(chunk.choices) > 0:
|
|
921
|
-
chunk_dict = self._format_response(chunk)
|
|
922
|
-
yield chunk_dict
|
|
923
|
-
|
|
924
|
-
res_text += chunk_dict["data"]
|
|
925
|
-
if chunk.choices[0].finish_reason == "length":
|
|
926
|
-
warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
|
|
927
|
-
|
|
928
|
-
# Postprocess response
|
|
929
|
-
res_dict = self.config.postprocess_response(res_text)
|
|
930
|
-
# Write to messages log
|
|
931
|
-
if messages_logger:
|
|
932
|
-
processed_messages.append({"role": "assistant",
|
|
933
|
-
"content": res_dict.get("response", ""),
|
|
934
|
-
"reasoning": res_dict.get("reasoning", "")})
|
|
935
|
-
messages_logger.log_messages(processed_messages)
|
|
936
|
-
|
|
937
|
-
return self.config.postprocess_response(_stream_generator())
|
|
938
|
-
|
|
939
|
-
elif verbose:
|
|
940
|
-
response = self.client.chat.completions.create(
|
|
941
|
-
model=self.model,
|
|
942
|
-
messages=processed_messages,
|
|
943
|
-
stream=True,
|
|
944
|
-
**self.formatted_params
|
|
945
|
-
)
|
|
946
|
-
res = {"reasoning": "", "response": ""}
|
|
947
|
-
phase = ""
|
|
948
|
-
for chunk in response:
|
|
949
|
-
if len(chunk.choices) > 0:
|
|
950
|
-
chunk_dict = self._format_response(chunk)
|
|
951
|
-
chunk_text = chunk_dict["data"]
|
|
952
|
-
res[chunk_dict["type"]] += chunk_text
|
|
953
|
-
if phase != chunk_dict["type"] and chunk_text != "":
|
|
954
|
-
print(f"\n--- {chunk_dict['type'].capitalize()} ---")
|
|
955
|
-
phase = chunk_dict["type"]
|
|
956
|
-
|
|
957
|
-
print(chunk_text, end="", flush=True)
|
|
958
|
-
if chunk.choices[0].finish_reason == "length":
|
|
959
|
-
warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
|
|
960
|
-
|
|
961
|
-
print('\n')
|
|
962
|
-
|
|
963
|
-
else:
|
|
964
|
-
response = self.client.chat.completions.create(
|
|
965
|
-
model=self.model,
|
|
966
|
-
messages=processed_messages,
|
|
967
|
-
stream=False,
|
|
968
|
-
**self.formatted_params
|
|
969
|
-
)
|
|
970
|
-
res = self._format_response(response)
|
|
971
|
-
|
|
972
|
-
if response.choices[0].finish_reason == "length":
|
|
973
|
-
warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
|
|
974
|
-
|
|
975
|
-
# Postprocess response
|
|
976
|
-
res_dict = self.config.postprocess_response(res)
|
|
977
|
-
# Write to messages log
|
|
978
|
-
if messages_logger:
|
|
979
|
-
processed_messages.append({"role": "assistant",
|
|
980
|
-
"content": res_dict.get("response", ""),
|
|
981
|
-
"reasoning": res_dict.get("reasoning", "")})
|
|
982
|
-
messages_logger.log_messages(processed_messages)
|
|
983
|
-
|
|
984
|
-
return res_dict
|
|
985
|
-
|
|
986
|
-
|
|
987
|
-
async def chat_async(self, messages:List[Dict[str,str]], messages_logger:MessagesLogger=None) -> Dict[str,str]:
|
|
988
|
-
"""
|
|
989
|
-
Async version of chat method. Streaming is not supported.
|
|
990
|
-
"""
|
|
991
|
-
processed_messages = self.config.preprocess_messages(messages)
|
|
992
|
-
|
|
993
|
-
response = await self.async_client.chat.completions.create(
|
|
994
|
-
model=self.model,
|
|
995
|
-
messages=processed_messages,
|
|
996
|
-
stream=False,
|
|
997
|
-
**self.formatted_params
|
|
998
|
-
)
|
|
999
|
-
|
|
1000
|
-
if response.choices[0].finish_reason == "length":
|
|
1001
|
-
warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
|
|
1002
|
-
|
|
1003
|
-
res = self._format_response(response)
|
|
1004
|
-
|
|
1005
|
-
# Postprocess response
|
|
1006
|
-
res_dict = self.config.postprocess_response(res)
|
|
1007
|
-
# Write to messages log
|
|
1008
|
-
if messages_logger:
|
|
1009
|
-
processed_messages.append({"role": "assistant",
|
|
1010
|
-
"content": res_dict.get("response", ""),
|
|
1011
|
-
"reasoning": res_dict.get("reasoning", "")})
|
|
1012
|
-
messages_logger.log_messages(processed_messages)
|
|
1013
|
-
|
|
1014
|
-
return res_dict
|
|
1015
|
-
|
|
1016
|
-
|
|
1017
|
-
class VLLMInferenceEngine(OpenAICompatibleInferenceEngine):
|
|
1018
|
-
def __init__(self, model:str, api_key:str="", base_url:str="http://localhost:8000/v1", config:LLMConfig=None, **kwrs):
|
|
1019
|
-
"""
|
|
1020
|
-
vLLM OpenAI compatible server inference engine.
|
|
1021
|
-
https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html
|
|
1022
|
-
|
|
1023
|
-
For parameters and documentation, refer to https://platform.openai.com/docs/api-reference/introduction
|
|
1024
|
-
|
|
1025
|
-
Parameters:
|
|
1026
|
-
----------
|
|
1027
|
-
model_name : str
|
|
1028
|
-
model name as shown in the vLLM server
|
|
1029
|
-
api_key : str, Optional
|
|
1030
|
-
the API key for the vLLM server.
|
|
1031
|
-
base_url : str, Optional
|
|
1032
|
-
the base url for the vLLM server.
|
|
1033
|
-
config : LLMConfig
|
|
1034
|
-
the LLM configuration.
|
|
1035
|
-
"""
|
|
1036
|
-
super().__init__(model, api_key, base_url, config, **kwrs)
|
|
1037
|
-
|
|
1038
|
-
|
|
1039
|
-
def _format_response(self, response: Any) -> Dict[str, str]:
|
|
1040
|
-
"""
|
|
1041
|
-
This method format the response from OpenAI API to a dict with keys "type" and "data".
|
|
1042
|
-
|
|
1043
|
-
Parameters:
|
|
1044
|
-
----------
|
|
1045
|
-
response : Any
|
|
1046
|
-
the response from OpenAI-compatible API. Could be a dict, generator, or object.
|
|
1047
|
-
"""
|
|
1048
|
-
if isinstance(response, self.ChatCompletionChunk):
|
|
1049
|
-
if hasattr(response.choices[0].delta, "reasoning_content") and getattr(response.choices[0].delta, "reasoning_content") is not None:
|
|
1050
|
-
chunk_text = getattr(response.choices[0].delta, "reasoning_content", "")
|
|
1051
|
-
if chunk_text is None:
|
|
1052
|
-
chunk_text = ""
|
|
1053
|
-
return {"type": "reasoning", "data": chunk_text}
|
|
1054
|
-
else:
|
|
1055
|
-
chunk_text = getattr(response.choices[0].delta, "content", "")
|
|
1056
|
-
if chunk_text is None:
|
|
1057
|
-
chunk_text = ""
|
|
1058
|
-
return {"type": "response", "data": chunk_text}
|
|
1059
|
-
|
|
1060
|
-
return {"reasoning": getattr(response.choices[0].message, "reasoning_content", ""),
|
|
1061
|
-
"response": getattr(response.choices[0].message, "content", "")}
|
|
1062
|
-
|
|
1063
|
-
class SGLangInferenceEngine(OpenAICompatibleInferenceEngine):
|
|
1064
|
-
def __init__(self, model:str, api_key:str="", base_url:str="http://localhost:30000/v1", config:LLMConfig=None, **kwrs):
|
|
1065
|
-
"""
|
|
1066
|
-
SGLang OpenAI compatible API inference engine.
|
|
1067
|
-
https://docs.sglang.ai/basic_usage/openai_api.html
|
|
1068
|
-
|
|
1069
|
-
Parameters:
|
|
1070
|
-
----------
|
|
1071
|
-
model_name : str
|
|
1072
|
-
model name as shown in the vLLM server
|
|
1073
|
-
api_key : str, Optional
|
|
1074
|
-
the API key for the vLLM server.
|
|
1075
|
-
base_url : str, Optional
|
|
1076
|
-
the base url for the vLLM server.
|
|
1077
|
-
config : LLMConfig
|
|
1078
|
-
the LLM configuration.
|
|
1079
|
-
"""
|
|
1080
|
-
super().__init__(model, api_key, base_url, config, **kwrs)
|
|
1081
|
-
|
|
1082
|
-
|
|
1083
|
-
def _format_response(self, response: Any) -> Dict[str, str]:
|
|
1084
|
-
"""
|
|
1085
|
-
This method format the response from OpenAI API to a dict with keys "type" and "data".
|
|
1086
|
-
|
|
1087
|
-
Parameters:
|
|
1088
|
-
----------
|
|
1089
|
-
response : Any
|
|
1090
|
-
the response from OpenAI-compatible API. Could be a dict, generator, or object.
|
|
1091
|
-
"""
|
|
1092
|
-
if isinstance(response, self.ChatCompletionChunk):
|
|
1093
|
-
if hasattr(response.choices[0].delta, "reasoning_content") and getattr(response.choices[0].delta, "reasoning_content") is not None:
|
|
1094
|
-
chunk_text = getattr(response.choices[0].delta, "reasoning_content", "")
|
|
1095
|
-
if chunk_text is None:
|
|
1096
|
-
chunk_text = ""
|
|
1097
|
-
return {"type": "reasoning", "data": chunk_text}
|
|
1098
|
-
else:
|
|
1099
|
-
chunk_text = getattr(response.choices[0].delta, "content", "")
|
|
1100
|
-
if chunk_text is None:
|
|
1101
|
-
chunk_text = ""
|
|
1102
|
-
return {"type": "response", "data": chunk_text}
|
|
1103
|
-
|
|
1104
|
-
return {"reasoning": getattr(response.choices[0].message, "reasoning_content", ""),
|
|
1105
|
-
"response": getattr(response.choices[0].message, "content", "")}
|
|
1106
|
-
|
|
1107
|
-
|
|
1108
|
-
class OpenRouterInferenceEngine(OpenAICompatibleInferenceEngine):
|
|
1109
|
-
def __init__(self, model:str, api_key:str=None, base_url:str="https://openrouter.ai/api/v1", config:LLMConfig=None, **kwrs):
|
|
1110
|
-
"""
|
|
1111
|
-
OpenRouter OpenAI-compatible server inference engine.
|
|
1112
|
-
|
|
1113
|
-
Parameters:
|
|
1114
|
-
----------
|
|
1115
|
-
model_name : str
|
|
1116
|
-
model name as shown in the vLLM server
|
|
1117
|
-
api_key : str, Optional
|
|
1118
|
-
the API key for the vLLM server. If None, will use the key in os.environ['OPENROUTER_API_KEY'].
|
|
1119
|
-
base_url : str, Optional
|
|
1120
|
-
the base url for the vLLM server.
|
|
1121
|
-
config : LLMConfig
|
|
1122
|
-
the LLM configuration.
|
|
1123
|
-
"""
|
|
1124
|
-
self.api_key = api_key
|
|
1125
|
-
if self.api_key is None:
|
|
1126
|
-
self.api_key = os.getenv("OPENROUTER_API_KEY")
|
|
1127
|
-
super().__init__(model, self.api_key, base_url, config, **kwrs)
|
|
1128
|
-
|
|
1129
|
-
def _format_response(self, response: Any) -> Dict[str, str]:
|
|
1130
|
-
"""
|
|
1131
|
-
This method format the response from OpenAI API to a dict with keys "type" and "data".
|
|
1132
|
-
|
|
1133
|
-
Parameters:
|
|
1134
|
-
----------
|
|
1135
|
-
response : Any
|
|
1136
|
-
the response from OpenAI-compatible API. Could be a dict, generator, or object.
|
|
1137
|
-
"""
|
|
1138
|
-
if isinstance(response, self.ChatCompletionChunk):
|
|
1139
|
-
if hasattr(response.choices[0].delta, "reasoning") and getattr(response.choices[0].delta, "reasoning") is not None:
|
|
1140
|
-
chunk_text = getattr(response.choices[0].delta, "reasoning", "")
|
|
1141
|
-
if chunk_text is None:
|
|
1142
|
-
chunk_text = ""
|
|
1143
|
-
return {"type": "reasoning", "data": chunk_text}
|
|
1144
|
-
else:
|
|
1145
|
-
chunk_text = getattr(response.choices[0].delta, "content", "")
|
|
1146
|
-
if chunk_text is None:
|
|
1147
|
-
chunk_text = ""
|
|
1148
|
-
return {"type": "response", "data": chunk_text}
|
|
1149
|
-
|
|
1150
|
-
return {"reasoning": getattr(response.choices[0].message, "reasoning", ""),
|
|
1151
|
-
"response": getattr(response.choices[0].message, "content", "")}
|
|
1152
|
-
|
|
1153
|
-
|
|
1154
|
-
class OpenAIInferenceEngine(InferenceEngine):
|
|
1155
|
-
def __init__(self, model:str, config:LLMConfig=None, **kwrs):
|
|
1156
|
-
"""
|
|
1157
|
-
The OpenAI API inference engine.
|
|
1158
|
-
For parameters and documentation, refer to https://platform.openai.com/docs/api-reference/introduction
|
|
1159
|
-
|
|
1160
|
-
Parameters:
|
|
1161
|
-
----------
|
|
1162
|
-
model_name : str
|
|
1163
|
-
model name as described in https://platform.openai.com/docs/models
|
|
1164
|
-
"""
|
|
1165
|
-
if importlib.util.find_spec("openai") is None:
|
|
1166
|
-
raise ImportError("OpenAI Python API library not found. Please install OpanAI (```pip install openai```).")
|
|
1167
|
-
|
|
1168
|
-
from openai import OpenAI, AsyncOpenAI
|
|
1169
|
-
super().__init__(config)
|
|
1170
|
-
self.client = OpenAI(**kwrs)
|
|
1171
|
-
self.async_client = AsyncOpenAI(**kwrs)
|
|
1172
|
-
self.model = model
|
|
1173
|
-
self.config = config if config else BasicLLMConfig()
|
|
1174
|
-
self.formatted_params = self._format_config()
|
|
1175
|
-
|
|
1176
|
-
def _format_config(self) -> Dict[str, Any]:
|
|
1177
|
-
"""
|
|
1178
|
-
This method format the LLM configuration with the correct key for the inference engine.
|
|
1179
|
-
"""
|
|
1180
|
-
formatted_params = self.config.params.copy()
|
|
1181
|
-
if "max_new_tokens" in formatted_params:
|
|
1182
|
-
formatted_params["max_completion_tokens"] = formatted_params["max_new_tokens"]
|
|
1183
|
-
formatted_params.pop("max_new_tokens")
|
|
1184
|
-
|
|
1185
|
-
return formatted_params
|
|
1186
|
-
|
|
1187
|
-
def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False, messages_logger:MessagesLogger=None) -> Union[Dict[str, str], Generator[Dict[str, str], None, None]]:
|
|
1188
|
-
"""
|
|
1189
|
-
This method inputs chat messages and outputs LLM generated text.
|
|
1190
|
-
|
|
1191
|
-
Parameters:
|
|
1192
|
-
----------
|
|
1193
|
-
messages : List[Dict[str,str]]
|
|
1194
|
-
a list of dict with role and content. role must be one of {"system", "user", "assistant"}
|
|
1195
|
-
verbose : bool, Optional
|
|
1196
|
-
if True, VLM generated text will be printed in terminal in real-time.
|
|
1197
|
-
stream : bool, Optional
|
|
1198
|
-
if True, returns a generator that yields the output in real-time.
|
|
1199
|
-
messages_logger : MessagesLogger, Optional
|
|
1200
|
-
the message logger that logs the chat messages.
|
|
1201
|
-
|
|
1202
|
-
Returns:
|
|
1203
|
-
-------
|
|
1204
|
-
response : Union[Dict[str,str], Generator[Dict[str, str], None, None]]
|
|
1205
|
-
a dict {"reasoning": <reasoning>, "response": <response>} or Generator {"type": <reasoning or response>, "data": <content>}
|
|
1206
|
-
"""
|
|
1207
|
-
processed_messages = self.config.preprocess_messages(messages)
|
|
1208
|
-
|
|
1209
|
-
if stream:
|
|
1210
|
-
def _stream_generator():
|
|
1211
|
-
response_stream = self.client.chat.completions.create(
|
|
1212
|
-
model=self.model,
|
|
1213
|
-
messages=processed_messages,
|
|
1214
|
-
stream=True,
|
|
1215
|
-
**self.formatted_params
|
|
1216
|
-
)
|
|
1217
|
-
res_text = ""
|
|
1218
|
-
for chunk in response_stream:
|
|
1219
|
-
if len(chunk.choices) > 0:
|
|
1220
|
-
chunk_text = chunk.choices[0].delta.content
|
|
1221
|
-
if chunk_text is not None:
|
|
1222
|
-
res_text += chunk_text
|
|
1223
|
-
yield chunk_text
|
|
1224
|
-
if chunk.choices[0].finish_reason == "length":
|
|
1225
|
-
warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
|
|
1226
|
-
|
|
1227
|
-
# Postprocess response
|
|
1228
|
-
res_dict = self.config.postprocess_response(res_text)
|
|
1229
|
-
# Write to messages log
|
|
1230
|
-
if messages_logger:
|
|
1231
|
-
processed_messages.append({"role": "assistant",
|
|
1232
|
-
"content": res_dict.get("response", ""),
|
|
1233
|
-
"reasoning": res_dict.get("reasoning", "")})
|
|
1234
|
-
messages_logger.log_messages(processed_messages)
|
|
1235
|
-
|
|
1236
|
-
return self.config.postprocess_response(_stream_generator())
|
|
1237
|
-
|
|
1238
|
-
elif verbose:
|
|
1239
|
-
response = self.client.chat.completions.create(
|
|
1240
|
-
model=self.model,
|
|
1241
|
-
messages=processed_messages,
|
|
1242
|
-
stream=True,
|
|
1243
|
-
**self.formatted_params
|
|
1244
|
-
)
|
|
1245
|
-
res = ''
|
|
1246
|
-
for chunk in response:
|
|
1247
|
-
if len(chunk.choices) > 0:
|
|
1248
|
-
if chunk.choices[0].delta.content is not None:
|
|
1249
|
-
res += chunk.choices[0].delta.content
|
|
1250
|
-
print(chunk.choices[0].delta.content, end="", flush=True)
|
|
1251
|
-
if chunk.choices[0].finish_reason == "length":
|
|
1252
|
-
warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
|
|
1253
|
-
|
|
1254
|
-
print('\n')
|
|
1255
|
-
|
|
1256
|
-
else:
|
|
1257
|
-
response = self.client.chat.completions.create(
|
|
1258
|
-
model=self.model,
|
|
1259
|
-
messages=processed_messages,
|
|
1260
|
-
stream=False,
|
|
1261
|
-
**self.formatted_params
|
|
1262
|
-
)
|
|
1263
|
-
res = response.choices[0].message.content
|
|
1264
|
-
|
|
1265
|
-
# Postprocess response
|
|
1266
|
-
res_dict = self.config.postprocess_response(res)
|
|
1267
|
-
# Write to messages log
|
|
1268
|
-
if messages_logger:
|
|
1269
|
-
processed_messages.append({"role": "assistant",
|
|
1270
|
-
"content": res_dict.get("response", ""),
|
|
1271
|
-
"reasoning": res_dict.get("reasoning", "")})
|
|
1272
|
-
messages_logger.log_messages(processed_messages)
|
|
1273
|
-
|
|
1274
|
-
return res_dict
|
|
1275
|
-
|
|
1276
|
-
|
|
1277
|
-
async def chat_async(self, messages:List[Dict[str,str]], messages_logger:MessagesLogger=None) -> Dict[str,str]:
|
|
1278
|
-
"""
|
|
1279
|
-
Async version of chat method. Streaming is not supported.
|
|
1280
|
-
"""
|
|
1281
|
-
processed_messages = self.config.preprocess_messages(messages)
|
|
1282
|
-
|
|
1283
|
-
response = await self.async_client.chat.completions.create(
|
|
1284
|
-
model=self.model,
|
|
1285
|
-
messages=processed_messages,
|
|
1286
|
-
stream=False,
|
|
1287
|
-
**self.formatted_params
|
|
1288
|
-
)
|
|
1289
|
-
|
|
1290
|
-
if response.choices[0].finish_reason == "length":
|
|
1291
|
-
warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
|
|
1292
|
-
|
|
1293
|
-
res = response.choices[0].message.content
|
|
1294
|
-
# Postprocess response
|
|
1295
|
-
res_dict = self.config.postprocess_response(res)
|
|
1296
|
-
# Write to messages log
|
|
1297
|
-
if messages_logger:
|
|
1298
|
-
processed_messages.append({"role": "assistant",
|
|
1299
|
-
"content": res_dict.get("response", ""),
|
|
1300
|
-
"reasoning": res_dict.get("reasoning", "")})
|
|
1301
|
-
messages_logger.log_messages(processed_messages)
|
|
1302
|
-
|
|
1303
|
-
return res_dict
|
|
1304
|
-
|
|
1305
|
-
|
|
1306
|
-
class AzureOpenAIInferenceEngine(OpenAIInferenceEngine):
|
|
1307
|
-
def __init__(self, model:str, api_version:str, config:LLMConfig=None, **kwrs):
|
|
1308
|
-
"""
|
|
1309
|
-
The Azure OpenAI API inference engine.
|
|
1310
|
-
For parameters and documentation, refer to
|
|
1311
|
-
- https://azure.microsoft.com/en-us/products/ai-services/openai-service
|
|
1312
|
-
- https://learn.microsoft.com/en-us/azure/ai-services/openai/quickstart
|
|
1313
|
-
|
|
1314
|
-
Parameters:
|
|
1315
|
-
----------
|
|
1316
|
-
model : str
|
|
1317
|
-
model name as described in https://platform.openai.com/docs/models
|
|
1318
|
-
api_version : str
|
|
1319
|
-
the Azure OpenAI API version
|
|
1320
|
-
config : LLMConfig
|
|
1321
|
-
the LLM configuration.
|
|
1322
|
-
"""
|
|
1323
|
-
if importlib.util.find_spec("openai") is None:
|
|
1324
|
-
raise ImportError("OpenAI Python API library not found. Please install OpanAI (```pip install openai```).")
|
|
1325
|
-
|
|
1326
|
-
from openai import AzureOpenAI, AsyncAzureOpenAI
|
|
1327
|
-
self.model = model
|
|
1328
|
-
self.api_version = api_version
|
|
1329
|
-
self.client = AzureOpenAI(api_version=self.api_version,
|
|
1330
|
-
**kwrs)
|
|
1331
|
-
self.async_client = AsyncAzureOpenAI(api_version=self.api_version,
|
|
1332
|
-
**kwrs)
|
|
1333
|
-
self.config = config if config else BasicLLMConfig()
|
|
1334
|
-
self.formatted_params = self._format_config()
|
|
1335
|
-
|
|
1336
|
-
|
|
1337
|
-
class LiteLLMInferenceEngine(InferenceEngine):
|
|
1338
|
-
def __init__(self, model:str=None, base_url:str=None, api_key:str=None, config:LLMConfig=None):
|
|
1339
|
-
"""
|
|
1340
|
-
The LiteLLM inference engine.
|
|
1341
|
-
For parameters and documentation, refer to https://github.com/BerriAI/litellm?tab=readme-ov-file
|
|
1342
|
-
|
|
1343
|
-
Parameters:
|
|
1344
|
-
----------
|
|
1345
|
-
model : str
|
|
1346
|
-
the model name
|
|
1347
|
-
base_url : str, Optional
|
|
1348
|
-
the base url for the LLM server
|
|
1349
|
-
api_key : str, Optional
|
|
1350
|
-
the API key for the LLM server
|
|
1351
|
-
config : LLMConfig
|
|
1352
|
-
the LLM configuration.
|
|
1353
|
-
"""
|
|
1354
|
-
if importlib.util.find_spec("litellm") is None:
|
|
1355
|
-
raise ImportError("litellm not found. Please install litellm (```pip install litellm```).")
|
|
1356
|
-
|
|
1357
|
-
import litellm
|
|
1358
|
-
super().__init__(config)
|
|
1359
|
-
self.litellm = litellm
|
|
1360
|
-
self.model = model
|
|
1361
|
-
self.base_url = base_url
|
|
1362
|
-
self.api_key = api_key
|
|
1363
|
-
self.config = config if config else BasicLLMConfig()
|
|
1364
|
-
self.formatted_params = self._format_config()
|
|
1365
|
-
|
|
1366
|
-
def _format_config(self) -> Dict[str, Any]:
|
|
1367
|
-
"""
|
|
1368
|
-
This method format the LLM configuration with the correct key for the inference engine.
|
|
1369
|
-
"""
|
|
1370
|
-
formatted_params = self.config.params.copy()
|
|
1371
|
-
if "max_new_tokens" in formatted_params:
|
|
1372
|
-
formatted_params["max_tokens"] = formatted_params["max_new_tokens"]
|
|
1373
|
-
formatted_params.pop("max_new_tokens")
|
|
1374
|
-
|
|
1375
|
-
return formatted_params
|
|
1376
|
-
|
|
1377
|
-
def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False, messages_logger:MessagesLogger=None) -> Union[Dict[str,str], Generator[Dict[str, str], None, None]]:
|
|
1378
|
-
"""
|
|
1379
|
-
This method inputs chat messages and outputs LLM generated text.
|
|
1380
|
-
|
|
1381
|
-
Parameters:
|
|
1382
|
-
----------
|
|
1383
|
-
messages : List[Dict[str,str]]
|
|
1384
|
-
a list of dict with role and content. role must be one of {"system", "user", "assistant"}
|
|
1385
|
-
verbose : bool, Optional
|
|
1386
|
-
if True, VLM generated text will be printed in terminal in real-time.
|
|
1387
|
-
stream : bool, Optional
|
|
1388
|
-
if True, returns a generator that yields the output in real-time.
|
|
1389
|
-
messages_logger: MessagesLogger, Optional
|
|
1390
|
-
a messages logger that logs the messages.
|
|
1391
|
-
|
|
1392
|
-
Returns:
|
|
1393
|
-
-------
|
|
1394
|
-
response : Union[Dict[str,str], Generator[Dict[str, str], None, None]]
|
|
1395
|
-
a dict {"reasoning": <reasoning>, "response": <response>} or Generator {"type": <reasoning or response>, "data": <content>}
|
|
1396
|
-
"""
|
|
1397
|
-
processed_messages = self.config.preprocess_messages(messages)
|
|
1398
|
-
|
|
1399
|
-
if stream:
|
|
1400
|
-
def _stream_generator():
|
|
1401
|
-
response_stream = self.litellm.completion(
|
|
1402
|
-
model=self.model,
|
|
1403
|
-
messages=processed_messages,
|
|
1404
|
-
stream=True,
|
|
1405
|
-
base_url=self.base_url,
|
|
1406
|
-
api_key=self.api_key,
|
|
1407
|
-
**self.formatted_params
|
|
1408
|
-
)
|
|
1409
|
-
res_text = ""
|
|
1410
|
-
for chunk in response_stream:
|
|
1411
|
-
chunk_content = chunk.get('choices')[0].get('delta').get('content')
|
|
1412
|
-
if chunk_content:
|
|
1413
|
-
res_text += chunk_content
|
|
1414
|
-
yield chunk_content
|
|
1415
|
-
|
|
1416
|
-
# Postprocess response
|
|
1417
|
-
res_dict = self.config.postprocess_response(res_text)
|
|
1418
|
-
# Write to messages log
|
|
1419
|
-
if messages_logger:
|
|
1420
|
-
processed_messages.append({"role": "assistant",
|
|
1421
|
-
"content": res_dict.get("response", ""),
|
|
1422
|
-
"reasoning": res_dict.get("reasoning", "")})
|
|
1423
|
-
messages_logger.log_messages(processed_messages)
|
|
1424
|
-
|
|
1425
|
-
return self.config.postprocess_response(_stream_generator())
|
|
1426
|
-
|
|
1427
|
-
elif verbose:
|
|
1428
|
-
response = self.litellm.completion(
|
|
1429
|
-
model=self.model,
|
|
1430
|
-
messages=processed_messages,
|
|
1431
|
-
stream=True,
|
|
1432
|
-
base_url=self.base_url,
|
|
1433
|
-
api_key=self.api_key,
|
|
1434
|
-
**self.formatted_params
|
|
1435
|
-
)
|
|
1436
|
-
|
|
1437
|
-
res = ''
|
|
1438
|
-
for chunk in response:
|
|
1439
|
-
chunk_content = chunk.get('choices')[0].get('delta').get('content')
|
|
1440
|
-
if chunk_content:
|
|
1441
|
-
res += chunk_content
|
|
1442
|
-
print(chunk_content, end='', flush=True)
|
|
1443
|
-
|
|
1444
|
-
else:
|
|
1445
|
-
response = self.litellm.completion(
|
|
1446
|
-
model=self.model,
|
|
1447
|
-
messages=processed_messages,
|
|
1448
|
-
stream=False,
|
|
1449
|
-
base_url=self.base_url,
|
|
1450
|
-
api_key=self.api_key,
|
|
1451
|
-
**self.formatted_params
|
|
1452
|
-
)
|
|
1453
|
-
res = response.choices[0].message.content
|
|
1454
|
-
|
|
1455
|
-
# Postprocess response
|
|
1456
|
-
res_dict = self.config.postprocess_response(res)
|
|
1457
|
-
# Write to messages log
|
|
1458
|
-
if messages_logger:
|
|
1459
|
-
processed_messages.append({"role": "assistant",
|
|
1460
|
-
"content": res_dict.get("response", ""),
|
|
1461
|
-
"reasoning": res_dict.get("reasoning", "")})
|
|
1462
|
-
messages_logger.log_messages(processed_messages)
|
|
1463
|
-
|
|
1464
|
-
return res_dict
|
|
1465
|
-
|
|
1466
|
-
async def chat_async(self, messages:List[Dict[str,str]], messages_logger:MessagesLogger=None) -> Dict[str,str]:
|
|
1467
|
-
"""
|
|
1468
|
-
Async version of chat method. Streaming is not supported.
|
|
1469
|
-
"""
|
|
1470
|
-
processed_messages = self.config.preprocess_messages(messages)
|
|
1471
|
-
|
|
1472
|
-
response = await self.litellm.acompletion(
|
|
1473
|
-
model=self.model,
|
|
1474
|
-
messages=processed_messages,
|
|
1475
|
-
stream=False,
|
|
1476
|
-
base_url=self.base_url,
|
|
1477
|
-
api_key=self.api_key,
|
|
1478
|
-
**self.formatted_params
|
|
27
|
+
"""
|
|
28
|
+
Deprecated: This engine is no longer supported. Please run llama.cpp as a server and use OpenAICompatibleInferenceEngine instead.
|
|
29
|
+
"""
|
|
30
|
+
def __init__(self, *args, **kwargs):
|
|
31
|
+
raise NotImplementedError(
|
|
32
|
+
"LlamaCppInferenceEngine has been deprecated. "
|
|
33
|
+
"Please run llama.cpp as a server and use OpenAICompatibleInferenceEngine."
|
|
1479
34
|
)
|
|
1480
|
-
|
|
1481
|
-
res = response.get('choices')[0].get('message').get('content')
|
|
1482
35
|
|
|
1483
|
-
|
|
1484
|
-
|
|
1485
|
-
# Write to messages log
|
|
1486
|
-
if messages_logger:
|
|
1487
|
-
processed_messages.append({"role": "assistant",
|
|
1488
|
-
"content": res_dict.get("response", ""),
|
|
1489
|
-
"reasoning": res_dict.get("reasoning", "")})
|
|
1490
|
-
messages_logger.log_messages(processed_messages)
|
|
1491
|
-
return res_dict
|
|
36
|
+
def chat(self, *args, **kwargs):
|
|
37
|
+
raise NotImplementedError("This engine is deprecated.")
|