llm-ie 1.0.0-py3-none-any.whl → 1.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_ie/__init__.py +4 -4
- llm_ie/asset/prompt_guide/AttributeExtractor_prompt_guide.txt +52 -0
- llm_ie/engines.py +497 -250
- llm_ie/extractors.py +479 -681
- llm_ie/prompt_editor.py +13 -13
- {llm_ie-1.0.0.dist-info → llm_ie-1.2.0.dist-info}/METADATA +2 -2
- {llm_ie-1.0.0.dist-info → llm_ie-1.2.0.dist-info}/RECORD +8 -7
- {llm_ie-1.0.0.dist-info → llm_ie-1.2.0.dist-info}/WHEEL +0 -0
llm_ie/engines.py
CHANGED
```diff
@@ -1,22 +1,290 @@
 import abc
+import re
 import warnings
-import importlib
-from typing import List, Dict, Union, Generator
+import importlib.util
+from typing import Any, Tuple, List, Dict, Union, Generator
+
+
+class LLMConfig(abc.ABC):
+    def __init__(self, **kwargs):
+        """
+        This is an abstract class to provide interfaces for LLM configuration.
+        Children classes that inherts this class can be used in extrators and prompt editor.
+        Common LLM parameters: max_new_tokens, temperature, top_p, top_k, min_p.
+        """
+        self.params = kwargs.copy()
+
+
+    @abc.abstractmethod
+    def preprocess_messages(self, messages:List[Dict[str,str]]) -> List[Dict[str,str]]:
+        """
+        This method preprocesses the input messages before passing them to the LLM.
+
+        Parameters:
+        ----------
+        messages : List[Dict[str,str]]
+            a list of dict with role and content. role must be one of {"system", "user", "assistant"}
+
+        Returns:
+        -------
+        messages : List[Dict[str,str]]
+            a list of dict with role and content. role must be one of {"system", "user", "assistant"}
+        """
+        return NotImplemented
+
+    @abc.abstractmethod
+    def postprocess_response(self, response:Union[str, Generator[str, None, None]]) -> Union[str, Generator[str, None, None]]:
+        """
+        This method postprocesses the LLM response after it is generated.
+
+        Parameters:
+        ----------
+        response : Union[str, Generator[str, None, None]]
+            the LLM response. Can be a string or a generator.
+
+        Returns:
+        -------
+        response : str
+            the postprocessed LLM response
+        """
+        return NotImplemented
+
+
+class BasicLLMConfig(LLMConfig):
+    def __init__(self, max_new_tokens:int=2048, temperature:float=0.0, **kwargs):
+        """
+        The basic LLM configuration for most non-reasoning models.
+        """
+        super().__init__(**kwargs)
+        self.max_new_tokens = max_new_tokens
+        self.temperature = temperature
+        self.params["max_new_tokens"] = self.max_new_tokens
+        self.params["temperature"] = self.temperature
+
+    def preprocess_messages(self, messages:List[Dict[str,str]]) -> List[Dict[str,str]]:
+        """
+        This method preprocesses the input messages before passing them to the LLM.
+
+        Parameters:
+        ----------
+        messages : List[Dict[str,str]]
+            a list of dict with role and content. role must be one of {"system", "user", "assistant"}
+
+        Returns:
+        -------
+        messages : List[Dict[str,str]]
+            a list of dict with role and content. role must be one of {"system", "user", "assistant"}
+        """
+        return messages
+
+    def postprocess_response(self, response:Union[str, Generator[str, None, None]]) -> Union[str, Generator[Dict[str, str], None, None]]:
+        """
+        This method postprocesses the LLM response after it is generated.
+
+        Parameters:
+        ----------
+        response : Union[str, Generator[str, None, None]]
+            the LLM response. Can be a string or a generator.
+
+        Returns: Union[str, Generator[Dict[str, str], None, None]]
+            the postprocessed LLM response.
+            if input is a generator, the output will be a generator {"data": <content>}.
+        """
+        if isinstance(response, str):
+            return response
+
+        def _process_stream():
+            for chunk in response:
+                yield {"type": "response", "data": chunk}
+
+        return _process_stream()
+
+class Qwen3LLMConfig(LLMConfig):
+    def __init__(self, thinking_mode:bool=True, **kwargs):
+        """
+        The Qwen3 LLM configuration for reasoning models.
+
+        Parameters:
+        ----------
+        thinking_mode : bool, Optional
+            if True, a special token "/think" will be placed after each system and user prompt. Otherwise, "/no_think" will be placed.
+        """
+        super().__init__(**kwargs)
+        self.thinking_mode = thinking_mode
+
+    def preprocess_messages(self, messages:List[Dict[str,str]]) -> List[Dict[str,str]]:
+        """
+        Append a special token to the system and user prompts.
+        The token is "/think" if thinking_mode is True, otherwise "/no_think".
+
+        Parameters:
+        ----------
+        messages : List[Dict[str,str]]
+            a list of dict with role and content. role must be one of {"system", "user", "assistant"}
+
+        Returns:
+        -------
+        messages : List[Dict[str,str]]
+            a list of dict with role and content. role must be one of {"system", "user", "assistant"}
+        """
+        thinking_token = "/think" if self.thinking_mode else "/no_think"
+        new_messages = []
+        for message in messages:
+            if message['role'] in ['system', 'user']:
+                new_message = {'role': message['role'], 'content': f"{message['content']} {thinking_token}"}
+            else:
+                new_message = {'role': message['role'], 'content': message['content']}
+
+            new_messages.append(new_message)
+
+        return new_messages
+
+    def postprocess_response(self, response:Union[str, Generator[str, None, None]]) -> Union[str, Generator[Dict[str,str], None, None]]:
+        """
+        If input is a generator, tag contents in <think> and </think> as {"type": "reasoning", "data": <content>},
+        and the rest as {"type": "response", "data": <content>}.
+        If input is a string, drop contents in <think> and </think>.
+
+        Parameters:
+        ----------
+        response : Union[str, Generator[str, None, None]]
+            the LLM response. Can be a string or a generator.
+
+        Returns:
+        -------
+        response : Union[str, Generator[str, None, None]]
+            the postprocessed LLM response.
+            if input is a generator, the output will be a generator {"type": <reasoning or response>, "data": <content>}.
+        """
+        if isinstance(response, str):
+            return re.sub(r"<think>.*?</think>\s*", "", response, flags=re.DOTALL).strip()
+
+        if isinstance(response, Generator):
+            def _process_stream():
+                think_flag = False
+                buffer = ""
+                for chunk in response:
+                    if isinstance(chunk, str):
+                        buffer += chunk
+                        # switch between reasoning and response
+                        if "<think>" in buffer:
+                            think_flag = True
+                            buffer = buffer.replace("<think>", "")
+                        elif "</think>" in buffer:
+                            think_flag = False
+                            buffer = buffer.replace("</think>", "")
+
+                        # if chunk is in thinking block, tag it as reasoning; else tag it as response
+                        if chunk not in ["<think>", "</think>"]:
+                            if think_flag:
+                                yield {"type": "reasoning", "data": chunk}
+                            else:
+                                yield {"type": "response", "data": chunk}
+
+            return _process_stream()
+
+
+class OpenAIReasoningLLMConfig(LLMConfig):
+    def __init__(self, reasoning_effort:str="low", **kwargs):
+        """
+        The OpenAI "o" series configuration.
+        1. The reasoning effort is set to "low" by default.
+        2. The temperature parameter is not supported and will be ignored.
+        3. The system prompt is not supported and will be concatenated to the next user prompt.
+
+        Parameters:
+        ----------
+        reasoning_effort : str, Optional
+            the reasoning effort. Must be one of {"low", "medium", "high"}. Default is "low".
+        """
+        super().__init__(**kwargs)
+        if reasoning_effort not in ["low", "medium", "high"]:
+            raise ValueError("reasoning_effort must be one of {'low', 'medium', 'high'}.")
+
+        self.reasoning_effort = reasoning_effort
+        self.params["reasoning_effort"] = self.reasoning_effort
+
+        if "temperature" in self.params:
+            warnings.warn("Reasoning models do not support temperature parameter. Will be ignored.", UserWarning)
+            self.params.pop("temperature")
+
+    def preprocess_messages(self, messages:List[Dict[str,str]]) -> List[Dict[str,str]]:
+        """
+        Concatenate system prompts to the next user prompt.
+
+        Parameters:
+        ----------
+        messages : List[Dict[str,str]]
+            a list of dict with role and content. role must be one of {"system", "user", "assistant"}
+
+        Returns:
+        -------
+        messages : List[Dict[str,str]]
+            a list of dict with role and content. role must be one of {"system", "user", "assistant"}
+        """
+        system_prompt_holder = ""
+        new_messages = []
+        for i, message in enumerate(messages):
+            # if system prompt, store it in system_prompt_holder
+            if message['role'] == 'system':
+                system_prompt_holder = message['content']
+            # if user prompt, concatenate it with system_prompt_holder
+            elif message['role'] == 'user':
+                if system_prompt_holder:
+                    new_message = {'role': message['role'], 'content': f"{system_prompt_holder} {message['content']}"}
+                    system_prompt_holder = ""
+                else:
+                    new_message = {'role': message['role'], 'content': message['content']}
+
+                new_messages.append(new_message)
+            # if assistant/other prompt, do nothing
+            else:
+                new_message = {'role': message['role'], 'content': message['content']}
+                new_messages.append(new_message)
+
+        return new_messages
+
+    def postprocess_response(self, response:Union[str, Generator[str, None, None]]) -> Union[str, Generator[Dict[str, str], None, None]]:
+        """
+        This method postprocesses the LLM response after it is generated.
+
+        Parameters:
+        ----------
+        response : Union[str, Generator[str, None, None]]
+            the LLM response. Can be a string or a generator.
+
+        Returns: Union[str, Generator[Dict[str, str], None, None]]
+            the postprocessed LLM response.
+            if input is a generator, the output will be a generator {"type": "response", "data": <content>}.
+        """
+        if isinstance(response, str):
+            return response
+
+        def _process_stream():
+            for chunk in response:
+                yield {"type": "response", "data": chunk}
+
+        return _process_stream()
 
 
 class InferenceEngine:
     @abc.abstractmethod
-    def __init__(self):
+    def __init__(self, config:LLMConfig, **kwrs):
         """
         This is an abstract class to provide interfaces for LLM inference engines.
         Children classes that inherts this class can be used in extrators. Must implement chat() method.
+
+        Parameters:
+        ----------
+        config : LLMConfig
+            the LLM configuration. Must be a child class of LLMConfig.
         """
         return NotImplemented
 
 
     @abc.abstractmethod
-    def chat(self, messages:List[Dict[str,str]],
-             verbose:bool=False, stream:bool=False
+    def chat(self, messages:List[Dict[str,str]],
+             verbose:bool=False, stream:bool=False) -> Union[str, Generator[Dict[str, str], None, None]]:
         """
         This method inputs chat messages and outputs LLM generated text.
 
```
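The 1.2.0 API above moves prompt and response shaping into `LLMConfig` subclasses. A minimal sketch of a custom config written against that interface (the `<final>` tag handling is an invented example, not part of llm-ie):

```python
# Sketch: a custom LLMConfig that keeps only a hypothetical <final>...</final> span.
# The base class and hook names come from the hunk above; the tag itself is invented.
import re
from typing import Dict, Generator, List, Union

from llm_ie.engines import LLMConfig


class FinalTagLLMConfig(LLMConfig):
    def preprocess_messages(self, messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
        # No prompt rewriting needed for this example.
        return messages

    def postprocess_response(self, response: Union[str, Generator[str, None, None]]):
        if isinstance(response, str):
            # Keep only the text inside <final>...</final>, if present.
            match = re.search(r"<final>(.*?)</final>", response, flags=re.DOTALL)
            return match.group(1).strip() if match else response

        # For streams, tag every chunk as response data, mirroring BasicLLMConfig.
        def _process_stream():
            for chunk in response:
                yield {"type": "response", "data": chunk}
        return _process_stream()
```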
```diff
@@ -24,10 +292,6 @@ class InferenceEngine:
         ----------
         messages : List[Dict[str,str]]
             a list of dict with role and content. role must be one of {"system", "user", "assistant"}
-        max_new_tokens : str, Optional
-            the max number of new tokens LLM can generate.
-        temperature : float, Optional
-            the temperature for token sampling.
         verbose : bool, Optional
             if True, LLM generated text will be printed in terminal in real-time.
         stream : bool, Optional
```
```diff
@@ -35,9 +299,18 @@ class InferenceEngine:
         """
         return NotImplemented
 
+    def _format_config(self) -> Dict[str, Any]:
+        """
+        This method format the LLM configuration with the correct key for the inference engine.
+
+        Return : Dict[str, Any]
+            the config parameters.
+        """
+        return NotImplemented
+
 
 class LlamaCppInferenceEngine(InferenceEngine):
-    def __init__(self, repo_id:str, gguf_filename:str, n_ctx:int=4096, n_gpu_layers:int=-1, **kwrs):
+    def __init__(self, repo_id:str, gguf_filename:str, n_ctx:int=4096, n_gpu_layers:int=-1, config:LLMConfig=None, **kwrs):
         """
         The Llama.cpp inference engine.
 
```
```diff
@@ -52,12 +325,16 @@ class LlamaCppInferenceEngine(InferenceEngine):
             context length that LLM will evaluate.
         n_gpu_layers : int, Optional
             number of layers to offload to GPU. Default is all layers (-1).
+        config : LLMConfig
+            the LLM configuration.
         """
         from llama_cpp import Llama
         self.repo_id = repo_id
         self.gguf_filename = gguf_filename
         self.n_ctx = n_ctx
         self.n_gpu_layers = n_gpu_layers
+        self.config = config if config else BasicLLMConfig()
+        self.formatted_params = self._format_config()
 
         self.model = Llama.from_pretrained(
             repo_id=self.repo_id,
```
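A minimal construction sketch for the updated `LlamaCppInferenceEngine` signature (the repo and GGUF filename below are placeholders, not a recommendation from this package):

```python
# Sketch: wiring a config into the llama.cpp engine, per the 1.2.0 signature above.
from llm_ie.engines import LlamaCppInferenceEngine, BasicLLMConfig

engine = LlamaCppInferenceEngine(
    repo_id="Qwen/Qwen2.5-7B-Instruct-GGUF",          # placeholder repo
    gguf_filename="qwen2.5-7b-instruct-q4_k_m.gguf",  # placeholder file
    n_ctx=4096,
    n_gpu_layers=-1,
    config=BasicLLMConfig(max_new_tokens=512, temperature=0.0),
)

reply = engine.chat(
    [{"role": "system", "content": "You extract dates."},
     {"role": "user", "content": "Visit on 2024-05-01 for follow-up."}],
    verbose=True,
)
print(reply)
```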
```diff
@@ -73,8 +350,18 @@ class LlamaCppInferenceEngine(InferenceEngine):
         """
         del self.model
 
+    def _format_config(self) -> Dict[str, Any]:
+        """
+        This method format the LLM configuration with the correct key for the inference engine.
+        """
+        formatted_params = self.config.params.copy()
+        if "max_new_tokens" in formatted_params:
+            formatted_params["max_tokens"] = formatted_params["max_new_tokens"]
+            formatted_params.pop("max_new_tokens")
+
+        return formatted_params
 
-    def chat(self, messages:List[Dict[str,str]],
+    def chat(self, messages:List[Dict[str,str]], verbose:bool=False) -> str:
         """
         This method inputs chat messages and outputs LLM generated text.
 
```
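`_format_config()` only renames keys for the target backend. A standalone illustration of the mapping logic shown above (plain Python, no llama.cpp required):

```python
# Illustration of the llama.cpp key mapping: max_new_tokens -> max_tokens.
config_params = {"max_new_tokens": 1024, "temperature": 0.2, "top_p": 0.9}

formatted = config_params.copy()
if "max_new_tokens" in formatted:
    formatted["max_tokens"] = formatted.pop("max_new_tokens")

print(formatted)  # {'temperature': 0.2, 'top_p': 0.9, 'max_tokens': 1024}
```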
```diff
@@ -82,19 +369,15 @@ class LlamaCppInferenceEngine(InferenceEngine):
         ----------
         messages : List[Dict[str,str]]
             a list of dict with role and content. role must be one of {"system", "user", "assistant"}
-        max_new_tokens : str, Optional
-            the max number of new tokens LLM can generate.
-        temperature : float, Optional
-            the temperature for token sampling.
         verbose : bool, Optional
             if True, LLM generated text will be printed in terminal in real-time.
         """
+        processed_messages = self.config.preprocess_messages(messages)
+
         response = self.model.create_chat_completion(
-            messages=
-            max_tokens=max_new_tokens,
-            temperature=temperature,
+            messages=processed_messages,
             stream=verbose,
-            **
+            **self.formatted_params
         )
 
         if verbose:
```
```diff
@@ -105,13 +388,14 @@ class LlamaCppInferenceEngine(InferenceEngine):
                     res += out_dict['content']
                     print(out_dict['content'], end='', flush=True)
             print('\n')
-            return res
+            return self.config.postprocess_response(res)
 
-
+        res = response['choices'][0]['message']['content']
+        return self.config.postprocess_response(res)
 
 
 class OllamaInferenceEngine(InferenceEngine):
-    def __init__(self, model_name:str, num_ctx:int=4096, keep_alive:int=300, **kwrs):
+    def __init__(self, model_name:str, num_ctx:int=4096, keep_alive:int=300, config:LLMConfig=None, **kwrs):
         """
         The Ollama inference engine.
 
```
```diff
@@ -123,6 +407,8 @@ class OllamaInferenceEngine(InferenceEngine):
             context length that LLM will evaluate.
         keep_alive : int, Optional
             seconds to hold the LLM after the last API call.
+        config : LLMConfig
+            the LLM configuration.
         """
         if importlib.util.find_spec("ollama") is None:
             raise ImportError("ollama-python not found. Please install ollama-python (```pip install ollama```).")
```
```diff
@@ -133,9 +419,22 @@ class OllamaInferenceEngine(InferenceEngine):
         self.model_name = model_name
         self.num_ctx = num_ctx
         self.keep_alive = keep_alive
+        self.config = config if config else BasicLLMConfig()
+        self.formatted_params = self._format_config()
 
-    def
-
+    def _format_config(self) -> Dict[str, Any]:
+        """
+        This method format the LLM configuration with the correct key for the inference engine.
+        """
+        formatted_params = self.config.params.copy()
+        if "max_new_tokens" in formatted_params:
+            formatted_params["num_predict"] = formatted_params["max_new_tokens"]
+            formatted_params.pop("max_new_tokens")
+
+        return formatted_params
+
+    def chat(self, messages:List[Dict[str,str]],
+             verbose:bool=False, stream:bool=False) -> Union[str, Generator[Dict[str, str], None, None]]:
         """
         This method inputs chat messages and outputs VLM generated text.
 
```
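A usage sketch for the Ollama engine with the new config plumbing; `max_new_tokens` is renamed to `num_predict` by `_format_config()` above (the model tag is a placeholder and must already be pulled in Ollama):

```python
# Sketch: Ollama engine with the Qwen3 reasoning config (names per this diff).
from llm_ie.engines import OllamaInferenceEngine, Qwen3LLMConfig

engine = OllamaInferenceEngine(
    model_name="qwen3:8b",  # placeholder model tag
    num_ctx=4096,
    config=Qwen3LLMConfig(thinking_mode=True, max_new_tokens=1024),
)

# Non-streaming: the config strips the <think>...</think> block from the returned string.
answer = engine.chat([{"role": "user", "content": "List the medications in: 'aspirin 81 mg daily'."}])
print(answer)
```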
```diff
@@ -143,21 +442,19 @@ class OllamaInferenceEngine(InferenceEngine):
         ----------
         messages : List[Dict[str,str]]
             a list of dict with role and content. role must be one of {"system", "user", "assistant"}
-        max_new_tokens : str, Optional
-            the max number of new tokens VLM can generate.
-        temperature : float, Optional
-            the temperature for token sampling.
         verbose : bool, Optional
             if True, VLM generated text will be printed in terminal in real-time.
         stream : bool, Optional
             if True, returns a generator that yields the output in real-time.
         """
-
+        processed_messages = self.config.preprocess_messages(messages)
+
+        options={'num_ctx': self.num_ctx, **self.formatted_params}
         if stream:
             def _stream_generator():
                 response_stream = self.client.chat(
                     model=self.model_name,
-                    messages=
+                    messages=processed_messages,
                     options=options,
                     stream=True,
                     keep_alive=self.keep_alive
```
```diff
@@ -167,12 +464,12 @@ class OllamaInferenceEngine(InferenceEngine):
                     if content_chunk:
                         yield content_chunk
 
-            return _stream_generator()
+            return self.config.postprocess_response(_stream_generator())
 
         elif verbose:
             response = self.client.chat(
                 model=self.model_name,
-                messages=
+                messages=processed_messages,
                 options=options,
                 stream=True,
                 keep_alive=self.keep_alive
```
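With `stream=True`, the generator returned above has already passed through `postprocess_response`, so it yields tagged dicts rather than raw text. A consumption sketch (the model tag is a placeholder):

```python
# Sketch: consuming the stream=True generator, which yields
# {"type": "reasoning" or "response", "data": <text chunk>} after config postprocessing.
from llm_ie.engines import OllamaInferenceEngine, Qwen3LLMConfig

engine = OllamaInferenceEngine(model_name="qwen3:8b",  # placeholder model tag
                               config=Qwen3LLMConfig(thinking_mode=True))

for chunk in engine.chat([{"role": "user", "content": "What day follows Monday?"}], stream=True):
    prefix = "[think] " if chunk["type"] == "reasoning" else ""
    print(f"{prefix}{chunk['data']}", end="", flush=True)
print()
```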
```diff
@@ -184,48 +481,82 @@ class OllamaInferenceEngine(InferenceEngine):
                     print(content_chunk, end='', flush=True)
                     res += content_chunk
             print('\n')
-            return res
+            return self.config.postprocess_response(res)
 
         else:
             response = self.client.chat(
                 model=self.model_name,
-                messages=
+                messages=processed_messages,
                 options=options,
                 stream=False,
                 keep_alive=self.keep_alive
             )
-
+            res = response.get('message', {}).get('content')
+            return self.config.postprocess_response(res)
+
 
-    async def chat_async(self, messages:List[Dict[str,str]]
+    async def chat_async(self, messages:List[Dict[str,str]]) -> str:
         """
         Async version of chat method. Streaming is not supported.
         """
+        processed_messages = self.config.preprocess_messages(messages)
+
         response = await self.async_client.chat(
             model=self.model_name,
-            messages=
-            options={'
+            messages=processed_messages,
+            options={'num_ctx': self.num_ctx, **self.formatted_params},
             stream=False,
             keep_alive=self.keep_alive
         )
 
-
+        res = response['message']['content']
+        return self.config.postprocess_response(res)
 
 
 class HuggingFaceHubInferenceEngine(InferenceEngine):
-    def __init__(self, model:str=None, token:Union[str, bool]=None, base_url:str=None, api_key:str=None, **kwrs):
+    def __init__(self, model:str=None, token:Union[str, bool]=None, base_url:str=None, api_key:str=None, config:LLMConfig=None, **kwrs):
         """
         The Huggingface_hub InferenceClient inference engine.
         For parameters and documentation, refer to https://huggingface.co/docs/huggingface_hub/en/package_reference/inference_client
+
+        Parameters:
+        ----------
+        model : str
+            the model name exactly as shown in Huggingface repo
+        token : str, Optional
+            the Huggingface token. If None, will use the token in os.environ['HF_TOKEN'].
+        base_url : str, Optional
+            the base url for the LLM server. If None, will use the default Huggingface Hub URL.
+        api_key : str, Optional
+            the API key for the LLM server.
+        config : LLMConfig
+            the LLM configuration.
         """
         if importlib.util.find_spec("huggingface_hub") is None:
             raise ImportError("huggingface-hub not found. Please install huggingface-hub (```pip install huggingface-hub```).")
 
         from huggingface_hub import InferenceClient, AsyncInferenceClient
+        self.model = model
+        self.base_url = base_url
         self.client = InferenceClient(model=model, token=token, base_url=base_url, api_key=api_key, **kwrs)
         self.client_async = AsyncInferenceClient(model=model, token=token, base_url=base_url, api_key=api_key, **kwrs)
+        self.config = config if config else BasicLLMConfig()
+        self.formatted_params = self._format_config()
 
-    def
-
+    def _format_config(self) -> Dict[str, Any]:
+        """
+        This method format the LLM configuration with the correct key for the inference engine.
+        """
+        formatted_params = self.config.params.copy()
+        if "max_new_tokens" in formatted_params:
+            formatted_params["max_tokens"] = formatted_params["max_new_tokens"]
+            formatted_params.pop("max_new_tokens")
+
+        return formatted_params
+
+
+    def chat(self, messages:List[Dict[str,str]],
+             verbose:bool=False, stream:bool=False) -> Union[str, Generator[Dict[str, str], None, None]]:
         """
         This method inputs chat messages and outputs LLM generated text.
 
```
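A sketch of the async path added above; `chat_async` does not support streaming (the model tag is a placeholder):

```python
# Sketch: awaiting a single async chat call on the Ollama engine.
import asyncio
from llm_ie.engines import OllamaInferenceEngine, BasicLLMConfig

async def main():
    engine = OllamaInferenceEngine(model_name="llama3.1:8b",  # placeholder model tag
                                   config=BasicLLMConfig(max_new_tokens=256))
    text = await engine.chat_async([{"role": "user", "content": "Name one weekday."}])
    print(text)

asyncio.run(main())
```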
```diff
@@ -233,38 +564,32 @@ class HuggingFaceHubInferenceEngine(InferenceEngine):
         ----------
         messages : List[Dict[str,str]]
             a list of dict with role and content. role must be one of {"system", "user", "assistant"}
-        max_new_tokens : str, Optional
-            the max number of new tokens LLM can generate.
-        temperature : float, Optional
-            the temperature for token sampling.
         verbose : bool, Optional
             if True, VLM generated text will be printed in terminal in real-time.
         stream : bool, Optional
             if True, returns a generator that yields the output in real-time.
         """
+        processed_messages = self.config.preprocess_messages(messages)
+
         if stream:
             def _stream_generator():
                 response_stream = self.client.chat.completions.create(
-                    messages=
-                    max_tokens=max_new_tokens,
-                    temperature=temperature,
+                    messages=processed_messages,
                     stream=True,
-                    **
+                    **self.formatted_params
                 )
                 for chunk in response_stream:
                     content_chunk = chunk.get('choices')[0].get('delta').get('content')
                     if content_chunk:
                         yield content_chunk
 
-            return _stream_generator()
+            return self.config.postprocess_response(_stream_generator())
 
         elif verbose:
             response = self.client.chat.completions.create(
-                messages=
-                max_tokens=max_new_tokens,
-                temperature=temperature,
+                messages=processed_messages,
                 stream=True,
-                **
+                **self.formatted_params
             )
 
             res = ''
```
```diff
@@ -273,35 +598,35 @@ class HuggingFaceHubInferenceEngine(InferenceEngine):
                 if content_chunk:
                     res += content_chunk
                     print(content_chunk, end='', flush=True)
-            return res
+            return self.config.postprocess_response(res)
 
         else:
             response = self.client.chat.completions.create(
-                messages=
-                max_tokens=max_new_tokens,
-                temperature=temperature,
+                messages=processed_messages,
                 stream=False,
-                **
+                **self.formatted_params
             )
-
+            res = response.choices[0].message.content
+            return self.config.postprocess_response(res)
 
-    async def chat_async(self, messages:List[Dict[str,str]]
+    async def chat_async(self, messages:List[Dict[str,str]]) -> str:
         """
         Async version of chat method. Streaming is not supported.
         """
+        processed_messages = self.config.preprocess_messages(messages)
+
         response = await self.client_async.chat.completions.create(
-            messages=
-            max_tokens=max_new_tokens,
-            temperature=temperature,
+            messages=processed_messages,
             stream=False,
-            **
+            **self.formatted_params
         )
 
-
+        res = response.choices[0].message.content
+        return self.config.postprocess_response(res)
 
 
 class OpenAIInferenceEngine(InferenceEngine):
-    def __init__(self, model:str,
+    def __init__(self, model:str, config:LLMConfig=None, **kwrs):
         """
         The OpenAI API inference engine. Supports OpenAI models and OpenAI compatible servers:
         - vLLM OpenAI compatible server (https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html)
```
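A construction sketch for the Hugging Face Hub engine against a self-hosted endpoint (the base_url is a placeholder; token/api_key handling follows huggingface_hub conventions):

```python
# Sketch: Hugging Face Hub engine pointed at a local TGI/compatible server.
from llm_ie.engines import HuggingFaceHubInferenceEngine, BasicLLMConfig

engine = HuggingFaceHubInferenceEngine(
    base_url="http://localhost:8080",  # placeholder endpoint
    config=BasicLLMConfig(max_new_tokens=512, temperature=0.0),
)
print(engine.chat([{"role": "user", "content": "Return the word 'ready'."}]))
```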
```diff
@@ -313,8 +638,6 @@ class OpenAIInferenceEngine(InferenceEngine):
         ----------
         model_name : str
             model name as described in https://platform.openai.com/docs/models
-        reasoning_model : bool, Optional
-            indicator for OpenAI reasoning models ("o" series).
         """
         if importlib.util.find_spec("openai") is None:
             raise ImportError("OpenAI Python API library not found. Please install OpanAI (```pip install openai```).")
```
```diff
@@ -323,10 +646,21 @@ class OpenAIInferenceEngine(InferenceEngine):
         self.client = OpenAI(**kwrs)
         self.async_client = AsyncOpenAI(**kwrs)
         self.model = model
-        self.
+        self.config = config if config else BasicLLMConfig()
+        self.formatted_params = self._format_config()
+
+    def _format_config(self) -> Dict[str, Any]:
+        """
+        This method format the LLM configuration with the correct key for the inference engine.
+        """
+        formatted_params = self.config.params.copy()
+        if "max_new_tokens" in formatted_params:
+            formatted_params["max_completion_tokens"] = formatted_params["max_new_tokens"]
+            formatted_params.pop("max_new_tokens")
+
+        return formatted_params
 
-    def chat(self, messages:List[Dict[str,str]],
-             verbose:bool=False, stream:bool=False, **kwrs) -> Union[str, Generator[str, None, None]]:
+    def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False) -> Union[str, Generator[Dict[str, str], None, None]]:
         """
         This method inputs chat messages and outputs LLM generated text.
 
```
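A sketch pairing the OpenAI engine with `OpenAIReasoningLLMConfig`; `max_new_tokens` becomes `max_completion_tokens` via `_format_config()` above (the model name is a placeholder, and the API key is read from the environment by the openai client):

```python
# Sketch: OpenAI engine + reasoning config (names per this diff).
from llm_ie.engines import OpenAIInferenceEngine, OpenAIReasoningLLMConfig

engine = OpenAIInferenceEngine(
    model="o3-mini",  # placeholder reasoning model name
    config=OpenAIReasoningLLMConfig(reasoning_effort="medium", max_new_tokens=2048),
)
print(engine.chat([
    {"role": "system", "content": "You are a terse clinical IE assistant."},
    {"role": "user", "content": "Extract the diagnosis from: 'Assessment: type 2 diabetes.'"},
]))
```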
```diff
@@ -334,177 +668,81 @@ class OpenAIInferenceEngine(InferenceEngine):
         ----------
         messages : List[Dict[str,str]]
             a list of dict with role and content. role must be one of {"system", "user", "assistant"}
-        max_new_tokens : str, Optional
-            the max number of new tokens LLM can generate.
-        temperature : float, Optional
-            the temperature for token sampling.
         verbose : bool, Optional
             if True, VLM generated text will be printed in terminal in real-time.
         stream : bool, Optional
             if True, returns a generator that yields the output in real-time.
         """
-
-        if self.reasoning_model:
-            # Reasoning models do not support temperature parameter
-            if temperature != 0.0:
-                warnings.warn("Reasoning models do not support temperature parameter. Will be ignored.", UserWarning)
-
-            # Reasoning models do not support system prompts
-            if any(msg['role'] == 'system' for msg in messages):
-                warnings.warn("Reasoning models do not support system prompts. Will be ignored.", UserWarning)
-                messages = [msg for msg in messages if msg['role'] != 'system']
-
-
-            if stream:
-                def _stream_generator():
-                    response_stream = self.client.chat.completions.create(
-                        model=self.model,
-                        messages=messages,
-                        max_completion_tokens=max_new_tokens,
-                        stream=True,
-                        **kwrs
-                    )
-                    for chunk in response_stream:
-                        if len(chunk.choices) > 0:
-                            if chunk.choices[0].delta.content is not None:
-                                yield chunk.choices[0].delta.content
-                            if chunk.choices[0].finish_reason == "length":
-                                warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
-                                if self.reasoning_model:
-                                    warnings.warn("max_new_tokens includes reasoning tokens and output tokens.", UserWarning)
-                return _stream_generator()
-
-            elif verbose:
-                response = self.client.chat.completions.create(
-                    model=self.model,
-                    messages=messages,
-                    max_completion_tokens=max_new_tokens,
-                    stream=True,
-                    **kwrs
-                )
-                res = ''
-                for chunk in response:
-                    if len(chunk.choices) > 0:
-                        if chunk.choices[0].delta.content is not None:
-                            res += chunk.choices[0].delta.content
-                            print(chunk.choices[0].delta.content, end="", flush=True)
-                        if chunk.choices[0].finish_reason == "length":
-                            warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
-                            if self.reasoning_model:
-                                warnings.warn("max_new_tokens includes reasoning tokens and output tokens.", UserWarning)
+        processed_messages = self.config.preprocess_messages(messages)
 
-
-
-
-
-
-
-
-
-            )
-            return response.choices[0].message.content
-
-        # For non-reasoning models
-        else:
-            if stream:
-                def _stream_generator():
-                    response_stream = self.client.chat.completions.create(
-                        model=self.model,
-                        messages=messages,
-                        max_tokens=max_new_tokens,
-                        temperature=temperature,
-                        stream=True,
-                        **kwrs
-                    )
-                    for chunk in response_stream:
-                        if len(chunk.choices) > 0:
-                            if chunk.choices[0].delta.content is not None:
-                                yield chunk.choices[0].delta.content
-                            if chunk.choices[0].finish_reason == "length":
-                                warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
-                                if self.reasoning_model:
-                                    warnings.warn("max_new_tokens includes reasoning tokens and output tokens.", UserWarning)
-                return _stream_generator()
-
-            elif verbose:
-                response = self.client.chat.completions.create(
-                    model=self.model,
-                    messages=messages,
-                    max_tokens=max_new_tokens,
-                    temperature=temperature,
-                    stream=True,
-                    **kwrs
-                )
-                res = ''
-                for chunk in response:
+        if stream:
+            def _stream_generator():
+                response_stream = self.client.chat.completions.create(
+                    model=self.model,
+                    messages=processed_messages,
+                    stream=True,
+                    **self.formatted_params
+                )
+                for chunk in response_stream:
                     if len(chunk.choices) > 0:
                         if chunk.choices[0].delta.content is not None:
-
-                            print(chunk.choices[0].delta.content, end="", flush=True)
+                            yield chunk.choices[0].delta.content
                         if chunk.choices[0].finish_reason == "length":
                             warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
-                            if self.reasoning_model:
-                                warnings.warn("max_new_tokens includes reasoning tokens and output tokens.", UserWarning)
-
-                print('\n')
-                return res
-
-            else:
-                response = self.client.chat.completions.create(
-                    model=self.model,
-                    messages=messages,
-                    max_tokens=max_new_tokens,
-                    temperature=temperature,
-                    stream=False,
-                    **kwrs
-                )
-
-                return response.choices[0].message.content
-
-
-    async def chat_async(self, messages:List[Dict[str,str]], max_new_tokens:int=4096, temperature:float=0.0, **kwrs) -> str:
-        """
-        Async version of chat method. Streaming is not supported.
-        """
-        if self.reasoning_model:
-            # Reasoning models do not support temperature parameter
-            if temperature != 0.0:
-                warnings.warn("Reasoning models do not support temperature parameter. Will be ignored.", UserWarning)
 
-
-        if any(msg['role'] == 'system' for msg in messages):
-            warnings.warn("Reasoning models do not support system prompts. Will be ignored.", UserWarning)
-            messages = [msg for msg in messages if msg['role'] != 'system']
+            return self.config.postprocess_response(_stream_generator())
 
-
+        elif verbose:
+            response = self.client.chat.completions.create(
                 model=self.model,
-            messages=
-
-
-            **kwrs
+                messages=processed_messages,
+                stream=True,
+                **self.formatted_params
             )
+            res = ''
+            for chunk in response:
+                if len(chunk.choices) > 0:
+                    if chunk.choices[0].delta.content is not None:
+                        res += chunk.choices[0].delta.content
+                        print(chunk.choices[0].delta.content, end="", flush=True)
+                    if chunk.choices[0].finish_reason == "length":
+                        warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
 
+            print('\n')
+            return self.config.postprocess_response(res)
         else:
-            response =
+            response = self.client.chat.completions.create(
                 model=self.model,
-                messages=
-                max_tokens=max_new_tokens,
-                temperature=temperature,
+                messages=processed_messages,
                 stream=False,
-                **
+                **self.formatted_params
             )
+            res = response.choices[0].message.content
+            return self.config.postprocess_response(res)
+
+
+    async def chat_async(self, messages:List[Dict[str,str]]) -> str:
+        """
+        Async version of chat method. Streaming is not supported.
+        """
+        processed_messages = self.config.preprocess_messages(messages)
+
+        response = await self.async_client.chat.completions.create(
+            model=self.model,
+            messages=processed_messages,
+            stream=False,
+            **self.formatted_params
+        )
 
         if response.choices[0].finish_reason == "length":
             warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
-            if self.reasoning_model:
-                warnings.warn("max_new_tokens includes reasoning tokens and output tokens.", UserWarning)
 
-
+        res = response.choices[0].message.content
+        return self.config.postprocess_response(res)
 
 
 class AzureOpenAIInferenceEngine(OpenAIInferenceEngine):
-    def __init__(self, model:str, api_version:str,
+    def __init__(self, model:str, api_version:str, config:LLMConfig=None, **kwrs):
         """
         The Azure OpenAI API inference engine.
         For parameters and documentation, refer to
```
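The system-prompt folding that `OpenAIReasoningLLMConfig.preprocess_messages` performs (used by the rewritten `chat` above) can be checked standalone:

```python
# Standalone check of the message preprocessing shown earlier in this diff:
# the system prompt is folded into the next user turn for "o" series models.
from llm_ie.engines import OpenAIReasoningLLMConfig

config = OpenAIReasoningLLMConfig(reasoning_effort="low")
messages = [
    {"role": "system", "content": "Answer with one word."},
    {"role": "user", "content": "Capital of France?"},
]
print(config.preprocess_messages(messages))
# [{'role': 'user', 'content': 'Answer with one word. Capital of France?'}]
```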
```diff
@@ -517,8 +755,8 @@ class AzureOpenAIInferenceEngine(OpenAIInferenceEngine):
             model name as described in https://platform.openai.com/docs/models
         api_version : str
             the Azure OpenAI API version
-
-
+        config : LLMConfig
+            the LLM configuration.
         """
         if importlib.util.find_spec("openai") is None:
             raise ImportError("OpenAI Python API library not found. Please install OpanAI (```pip install openai```).")
```
```diff
@@ -530,11 +768,12 @@ class AzureOpenAIInferenceEngine(OpenAIInferenceEngine):
                                    **kwrs)
         self.async_client = AsyncAzureOpenAI(api_version=self.api_version,
                                              **kwrs)
-        self.
+        self.config = config if config else BasicLLMConfig()
+        self.formatted_params = self._format_config()
 
 
 class LiteLLMInferenceEngine(InferenceEngine):
-    def __init__(self, model:str=None, base_url:str=None, api_key:str=None):
+    def __init__(self, model:str=None, base_url:str=None, api_key:str=None, config:LLMConfig=None):
         """
         The LiteLLM inference engine.
         For parameters and documentation, refer to https://github.com/BerriAI/litellm?tab=readme-ov-file
```
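A construction sketch for the Azure engine; the deployment name, API version, endpoint, and key below are placeholders passed through `**kwrs` to the Azure client:

```python
# Sketch: Azure OpenAI engine construction (per the 1.2.0 signature above).
import os
from llm_ie.engines import AzureOpenAIInferenceEngine, BasicLLMConfig

engine = AzureOpenAIInferenceEngine(
    model="my-gpt-4o-deployment",                      # placeholder deployment name
    api_version="2024-06-01",                          # placeholder API version
    azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"],
    api_key=os.environ["AZURE_OPENAI_API_KEY"],
    config=BasicLLMConfig(max_new_tokens=512),
)
print(engine.chat([{"role": "user", "content": "Say 'ok'."}]))
```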
```diff
@@ -547,6 +786,8 @@ class LiteLLMInferenceEngine(InferenceEngine):
             the base url for the LLM server
         api_key : str, Optional
             the API key for the LLM server
+        config : LLMConfig
+            the LLM configuration.
         """
         if importlib.util.find_spec("litellm") is None:
             raise ImportError("litellm not found. Please install litellm (```pip install litellm```).")
```
```diff
@@ -556,36 +797,44 @@ class LiteLLMInferenceEngine(InferenceEngine):
         self.model = model
         self.base_url = base_url
         self.api_key = api_key
+        self.config = config if config else BasicLLMConfig()
+        self.formatted_params = self._format_config()
 
-    def
-
+    def _format_config(self) -> Dict[str, Any]:
+        """
+        This method format the LLM configuration with the correct key for the inference engine.
+        """
+        formatted_params = self.config.params.copy()
+        if "max_new_tokens" in formatted_params:
+            formatted_params["max_tokens"] = formatted_params["max_new_tokens"]
+            formatted_params.pop("max_new_tokens")
+
+        return formatted_params
+
+    def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False) -> Union[str, Generator[Dict[str, str], None, None]]:
         """
         This method inputs chat messages and outputs LLM generated text.
 
         Parameters:
         ----------
         messages : List[Dict[str,str]]
-            a list of dict with role and content. role must be one of {"system", "user", "assistant"}
-        max_new_tokens : str, Optional
-            the max number of new tokens LLM can generate.
-        temperature : float, Optional
-            the temperature for token sampling.
+            a list of dict with role and content. role must be one of {"system", "user", "assistant"}
         verbose : bool, Optional
             if True, VLM generated text will be printed in terminal in real-time.
         stream : bool, Optional
             if True, returns a generator that yields the output in real-time.
         """
+        processed_messages = self.config.preprocess_messages(messages)
+
         if stream:
             def _stream_generator():
                 response_stream = self.litellm.completion(
                     model=self.model,
-                    messages=
-                    max_tokens=max_new_tokens,
-                    temperature=temperature,
+                    messages=processed_messages,
                     stream=True,
                     base_url=self.base_url,
                     api_key=self.api_key,
-                    **
+                    **self.formatted_params
                 )
 
                 for chunk in response_stream:
```
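A construction sketch for the LiteLLM engine; the provider-prefixed model string and local base_url below are placeholders:

```python
# Sketch: LiteLLM as a thin router in front of a local Ollama server.
from llm_ie.engines import LiteLLMInferenceEngine, BasicLLMConfig

engine = LiteLLMInferenceEngine(
    model="ollama/llama3.1",              # placeholder LiteLLM model string
    base_url="http://localhost:11434",    # placeholder server URL
    config=BasicLLMConfig(max_new_tokens=256, temperature=0.0),
)
print(engine.chat([{"role": "user", "content": "Reply with 'pong'."}]))
```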
```diff
@@ -593,18 +842,16 @@ class LiteLLMInferenceEngine(InferenceEngine):
                     if chunk_content:
                         yield chunk_content
 
-            return _stream_generator()
+            return self.config.postprocess_response(_stream_generator())
 
         elif verbose:
             response = self.litellm.completion(
                 model=self.model,
-                messages=
-                max_tokens=max_new_tokens,
-                temperature=temperature,
+                messages=processed_messages,
                 stream=True,
                 base_url=self.base_url,
                 api_key=self.api_key,
-                **
+                **self.formatted_params
             )
 
             res = ''
```
```diff
@@ -614,34 +861,34 @@ class LiteLLMInferenceEngine(InferenceEngine):
                     res += chunk_content
                     print(chunk_content, end='', flush=True)
 
-            return res
+            return self.config.postprocess_response(res)
 
         else:
             response = self.litellm.completion(
                 model=self.model,
-                messages=
-                max_tokens=max_new_tokens,
-                temperature=temperature,
+                messages=processed_messages,
                 stream=False,
                 base_url=self.base_url,
                 api_key=self.api_key,
-                **
+                **self.formatted_params
             )
-
+            res = response.choices[0].message.content
+            return self.config.postprocess_response(res)
 
-    async def chat_async(self, messages:List[Dict[str,str]]
+    async def chat_async(self, messages:List[Dict[str,str]]) -> str:
         """
         Async version of chat method. Streaming is not supported.
         """
+        processed_messages = self.config.preprocess_messages(messages)
+
         response = await self.litellm.acompletion(
             model=self.model,
-            messages=
-            max_tokens=max_new_tokens,
-            temperature=temperature,
+            messages=processed_messages,
             stream=False,
             base_url=self.base_url,
             api_key=self.api_key,
-            **
+            **self.formatted_params
         )
 
-
+        res = response.get('choices')[0].get('message').get('content')
+        return self.config.postprocess_response(res)
```