llm-ie 0.4.7__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_ie/__init__.py +6 -4
- llm_ie/asset/default_prompts/BasicReviewFrameExtractor_addition_review_prompt.txt +3 -0
- llm_ie/asset/default_prompts/BasicReviewFrameExtractor_revision_review_prompt.txt +2 -0
- llm_ie/asset/default_prompts/ReviewFrameExtractor_addition_review_prompt.txt +2 -1
- llm_ie/asset/default_prompts/ReviewFrameExtractor_revision_review_prompt.txt +2 -1
- llm_ie/asset/prompt_guide/BasicFrameExtractor_prompt_guide.txt +104 -86
- llm_ie/asset/prompt_guide/BasicReviewFrameExtractor_prompt_guide.txt +163 -0
- llm_ie/asset/prompt_guide/DirectFrameExtractor_prompt_guide.txt +163 -0
- llm_ie/asset/prompt_guide/ReviewFrameExtractor_prompt_guide.txt +103 -85
- llm_ie/asset/prompt_guide/SentenceFrameExtractor_prompt_guide.txt +103 -86
- llm_ie/asset/prompt_guide/SentenceReviewFrameExtractor_prompt_guide.txt +103 -86
- llm_ie/chunkers.py +191 -0
- llm_ie/data_types.py +75 -1
- llm_ie/engines.py +600 -262
- llm_ie/extractors.py +859 -899
- llm_ie/prompt_editor.py +45 -12
- llm_ie-1.1.0.dist-info/METADATA +18 -0
- llm_ie-1.1.0.dist-info/RECORD +27 -0
- llm_ie/asset/prompt_guide/SentenceCoTFrameExtractor_prompt_guide.txt +0 -217
- llm_ie-0.4.7.dist-info/METADATA +0 -1219
- llm_ie-0.4.7.dist-info/RECORD +0 -23
- {llm_ie-0.4.7.dist-info → llm_ie-1.1.0.dist-info}/WHEEL +0 -0
llm_ie/engines.py
CHANGED
@@ -1,21 +1,290 @@
 import abc
+import re
 import warnings
-import importlib
-from typing import List, Dict, Union
+import importlib.util
+from typing import Any, Tuple, List, Dict, Union, Generator
+
+
+class LLMConfig(abc.ABC):
+    def __init__(self, **kwargs):
+        """
+        This is an abstract class to provide interfaces for LLM configuration.
+        Children classes that inherts this class can be used in extrators and prompt editor.
+        Common LLM parameters: max_new_tokens, temperature, top_p, top_k, min_p.
+        """
+        self.params = kwargs.copy()
+
+
+    @abc.abstractmethod
+    def preprocess_messages(self, messages:List[Dict[str,str]]) -> List[Dict[str,str]]:
+        """
+        This method preprocesses the input messages before passing them to the LLM.
+
+        Parameters:
+        ----------
+        messages : List[Dict[str,str]]
+            a list of dict with role and content. role must be one of {"system", "user", "assistant"}
+
+        Returns:
+        -------
+        messages : List[Dict[str,str]]
+            a list of dict with role and content. role must be one of {"system", "user", "assistant"}
+        """
+        return NotImplemented
+
+    @abc.abstractmethod
+    def postprocess_response(self, response:Union[str, Generator[str, None, None]]) -> Union[str, Generator[str, None, None]]:
+        """
+        This method postprocesses the LLM response after it is generated.
+
+        Parameters:
+        ----------
+        response : Union[str, Generator[str, None, None]]
+            the LLM response. Can be a string or a generator.
+
+        Returns:
+        -------
+        response : str
+            the postprocessed LLM response
+        """
+        return NotImplemented
+
+
+class BasicLLMConfig(LLMConfig):
+    def __init__(self, max_new_tokens:int=2048, temperature:float=0.0, **kwargs):
+        """
+        The basic LLM configuration for most non-reasoning models.
+        """
+        super().__init__(**kwargs)
+        self.max_new_tokens = max_new_tokens
+        self.temperature = temperature
+        self.params["max_new_tokens"] = self.max_new_tokens
+        self.params["temperature"] = self.temperature
+
+    def preprocess_messages(self, messages:List[Dict[str,str]]) -> List[Dict[str,str]]:
+        """
+        This method preprocesses the input messages before passing them to the LLM.
+
+        Parameters:
+        ----------
+        messages : List[Dict[str,str]]
+            a list of dict with role and content. role must be one of {"system", "user", "assistant"}
+
+        Returns:
+        -------
+        messages : List[Dict[str,str]]
+            a list of dict with role and content. role must be one of {"system", "user", "assistant"}
+        """
+        return messages
+
+    def postprocess_response(self, response:Union[str, Generator[str, None, None]]) -> Union[str, Generator[Dict[str, str], None, None]]:
+        """
+        This method postprocesses the LLM response after it is generated.
+
+        Parameters:
+        ----------
+        response : Union[str, Generator[str, None, None]]
+            the LLM response. Can be a string or a generator.
+
+        Returns: Union[str, Generator[Dict[str, str], None, None]]
+            the postprocessed LLM response.
+            if input is a generator, the output will be a generator {"data": <content>}.
+        """
+        if isinstance(response, str):
+            return response
+
+        def _process_stream():
+            for chunk in response:
+                yield {"type": "response", "data": chunk}
+
+        return _process_stream()
+
+class Qwen3LLMConfig(LLMConfig):
+    def __init__(self, thinking_mode:bool=True, **kwargs):
+        """
+        The Qwen3 LLM configuration for reasoning models.
+
+        Parameters:
+        ----------
+        thinking_mode : bool, Optional
+            if True, a special token "/think" will be placed after each system and user prompt. Otherwise, "/no_think" will be placed.
+        """
+        super().__init__(**kwargs)
+        self.thinking_mode = thinking_mode
+
+    def preprocess_messages(self, messages:List[Dict[str,str]]) -> List[Dict[str,str]]:
+        """
+        Append a special token to the system and user prompts.
+        The token is "/think" if thinking_mode is True, otherwise "/no_think".
+
+        Parameters:
+        ----------
+        messages : List[Dict[str,str]]
+            a list of dict with role and content. role must be one of {"system", "user", "assistant"}
+
+        Returns:
+        -------
+        messages : List[Dict[str,str]]
+            a list of dict with role and content. role must be one of {"system", "user", "assistant"}
+        """
+        thinking_token = "/think" if self.thinking_mode else "/no_think"
+        new_messages = []
+        for message in messages:
+            if message['role'] in ['system', 'user']:
+                new_message = {'role': message['role'], 'content': f"{message['content']} {thinking_token}"}
+            else:
+                new_message = {'role': message['role'], 'content': message['content']}
+
+            new_messages.append(new_message)
+
+        return new_messages
+
+    def postprocess_response(self, response:Union[str, Generator[str, None, None]]) -> Union[str, Generator[Dict[str,str], None, None]]:
+        """
+        If input is a generator, tag contents in <think> and </think> as {"type": "reasoning", "data": <content>},
+        and the rest as {"type": "response", "data": <content>}.
+        If input is a string, drop contents in <think> and </think>.
+
+        Parameters:
+        ----------
+        response : Union[str, Generator[str, None, None]]
+            the LLM response. Can be a string or a generator.
+
+        Returns:
+        -------
+        response : Union[str, Generator[str, None, None]]
+            the postprocessed LLM response.
+            if input is a generator, the output will be a generator {"type": <reasoning or response>, "data": <content>}.
+        """
+        if isinstance(response, str):
+            return re.sub(r"<think>.*?</think>\s*", "", response, flags=re.DOTALL).strip()
+
+        if isinstance(response, Generator):
+            def _process_stream():
+                think_flag = False
+                buffer = ""
+                for chunk in response:
+                    if isinstance(chunk, str):
+                        buffer += chunk
+                        # switch between reasoning and response
+                        if "<think>" in buffer:
+                            think_flag = True
+                            buffer = buffer.replace("<think>", "")
+                        elif "</think>" in buffer:
+                            think_flag = False
+                            buffer = buffer.replace("</think>", "")
+
+                    # if chunk is in thinking block, tag it as reasoning; else tag it as response
+                    if chunk not in ["<think>", "</think>"]:
+                        if think_flag:
+                            yield {"type": "reasoning", "data": chunk}
+                        else:
+                            yield {"type": "response", "data": chunk}
+
+            return _process_stream()
+
+
+class OpenAIReasoningLLMConfig(LLMConfig):
+    def __init__(self, reasoning_effort:str="low", **kwargs):
+        """
+        The OpenAI "o" series configuration.
+        1. The reasoning effort is set to "low" by default.
+        2. The temperature parameter is not supported and will be ignored.
+        3. The system prompt is not supported and will be concatenated to the next user prompt.
+
+        Parameters:
+        ----------
+        reasoning_effort : str, Optional
+            the reasoning effort. Must be one of {"low", "medium", "high"}. Default is "low".
+        """
+        super().__init__(**kwargs)
+        if reasoning_effort not in ["low", "medium", "high"]:
+            raise ValueError("reasoning_effort must be one of {'low', 'medium', 'high'}.")
+
+        self.reasoning_effort = reasoning_effort
+        self.params["reasoning_effort"] = self.reasoning_effort
+
+        if "temperature" in self.params:
+            warnings.warn("Reasoning models do not support temperature parameter. Will be ignored.", UserWarning)
+            self.params.pop("temperature")
+
+    def preprocess_messages(self, messages:List[Dict[str,str]]) -> List[Dict[str,str]]:
+        """
+        Concatenate system prompts to the next user prompt.
+
+        Parameters:
+        ----------
+        messages : List[Dict[str,str]]
+            a list of dict with role and content. role must be one of {"system", "user", "assistant"}
+
+        Returns:
+        -------
+        messages : List[Dict[str,str]]
+            a list of dict with role and content. role must be one of {"system", "user", "assistant"}
+        """
+        system_prompt_holder = ""
+        new_messages = []
+        for i, message in enumerate(messages):
+            # if system prompt, store it in system_prompt_holder
+            if message['role'] == 'system':
+                system_prompt_holder = message['content']
+            # if user prompt, concatenate it with system_prompt_holder
+            elif message['role'] == 'user':
+                if system_prompt_holder:
+                    new_message = {'role': message['role'], 'content': f"{system_prompt_holder} {message['content']}"}
+                    system_prompt_holder = ""
+                else:
+                    new_message = {'role': message['role'], 'content': message['content']}
+
+                new_messages.append(new_message)
+            # if assistant/other prompt, do nothing
+            else:
+                new_message = {'role': message['role'], 'content': message['content']}
+                new_messages.append(new_message)
+
+        return new_messages
+
+    def postprocess_response(self, response:Union[str, Generator[str, None, None]]) -> Union[str, Generator[Dict[str, str], None, None]]:
+        """
+        This method postprocesses the LLM response after it is generated.
+
+        Parameters:
+        ----------
+        response : Union[str, Generator[str, None, None]]
+            the LLM response. Can be a string or a generator.
+
+        Returns: Union[str, Generator[Dict[str, str], None, None]]
+            the postprocessed LLM response.
+            if input is a generator, the output will be a generator {"type": "response", "data": <content>}.
+        """
+        if isinstance(response, str):
+            return response
+
+        def _process_stream():
+            for chunk in response:
+                yield {"type": "response", "data": chunk}
+
+        return _process_stream()
 
 
 class InferenceEngine:
     @abc.abstractmethod
-    def __init__(self):
+    def __init__(self, config:LLMConfig, **kwrs):
         """
         This is an abstract class to provide interfaces for LLM inference engines.
         Children classes that inherts this class can be used in extrators. Must implement chat() method.
+
+        Parameters:
+        ----------
+        config : LLMConfig
+            the LLM configuration. Must be a child class of LLMConfig.
         """
         return NotImplemented
 
 
     @abc.abstractmethod
-    def chat(self, messages:List[Dict[str,str]],
+    def chat(self, messages:List[Dict[str,str]],
+             verbose:bool=False, stream:bool=False) -> Union[str, Generator[Dict[str, str], None, None]]:
         """
         This method inputs chat messages and outputs LLM generated text.
 
@@ -23,18 +292,25 @@ class InferenceEngine:
         ----------
         messages : List[Dict[str,str]]
             a list of dict with role and content. role must be one of {"system", "user", "assistant"}
-
-
-        temperature : float, Optional
-            the temperature for token sampling.
+        verbose : bool, Optional
+            if True, LLM generated text will be printed in terminal in real-time.
         stream : bool, Optional
-            if True,
+            if True, returns a generator that yields the output in real-time.
+        """
+        return NotImplemented
+
+    def _format_config(self) -> Dict[str, Any]:
+        """
+        This method format the LLM configuration with the correct key for the inference engine.
+
+        Return : Dict[str, Any]
+            the config parameters.
         """
         return NotImplemented
 
 
 class LlamaCppInferenceEngine(InferenceEngine):
-    def __init__(self, repo_id:str, gguf_filename:str, n_ctx:int=4096, n_gpu_layers:int=-1, **kwrs):
+    def __init__(self, repo_id:str, gguf_filename:str, n_ctx:int=4096, n_gpu_layers:int=-1, config:LLMConfig=None, **kwrs):
         """
         The Llama.cpp inference engine.
 
@@ -49,12 +325,16 @@ class LlamaCppInferenceEngine(InferenceEngine):
             context length that LLM will evaluate.
         n_gpu_layers : int, Optional
             number of layers to offload to GPU. Default is all layers (-1).
+        config : LLMConfig
+            the LLM configuration.
         """
         from llama_cpp import Llama
         self.repo_id = repo_id
         self.gguf_filename = gguf_filename
         self.n_ctx = n_ctx
         self.n_gpu_layers = n_gpu_layers
+        self.config = config if config else BasicLLMConfig()
+        self.formatted_params = self._format_config()
 
         self.model = Llama.from_pretrained(
             repo_id=self.repo_id,
@@ -70,8 +350,18 @@ class LlamaCppInferenceEngine(InferenceEngine):
         """
         del self.model
 
+    def _format_config(self) -> Dict[str, Any]:
+        """
+        This method format the LLM configuration with the correct key for the inference engine.
+        """
+        formatted_params = self.config.params.copy()
+        if "max_new_tokens" in formatted_params:
+            formatted_params["max_tokens"] = formatted_params["max_new_tokens"]
+            formatted_params.pop("max_new_tokens")
+
+        return formatted_params
 
-    def chat(self, messages:List[Dict[str,str]],
+    def chat(self, messages:List[Dict[str,str]], verbose:bool=False) -> str:
         """
         This method inputs chat messages and outputs LLM generated text.
 
@@ -79,22 +369,18 @@ class LlamaCppInferenceEngine(InferenceEngine):
         ----------
         messages : List[Dict[str,str]]
             a list of dict with role and content. role must be one of {"system", "user", "assistant"}
-
-            the max number of new tokens LLM can generate.
-        temperature : float, Optional
-            the temperature for token sampling.
-        stream : bool, Optional
+        verbose : bool, Optional
             if True, LLM generated text will be printed in terminal in real-time.
         """
+        processed_messages = self.config.preprocess_messages(messages)
+
         response = self.model.create_chat_completion(
-            messages=
-
-
-            stream=stream,
-            **kwrs
+            messages=processed_messages,
+            stream=verbose,
+            **self.formatted_params
         )
 
-        if
+        if verbose:
             res = ''
             for chunk in response:
                 out_dict = chunk['choices'][0]['delta']
@@ -102,16 +388,14 @@ class LlamaCppInferenceEngine(InferenceEngine):
                     res += out_dict['content']
                     print(out_dict['content'], end='', flush=True)
             print('\n')
-            return res
+            return self.config.postprocess_response(res)
 
-
-
-
-
+        res = response['choices'][0]['message']['content']
+        return self.config.postprocess_response(res)
 
 
 class OllamaInferenceEngine(InferenceEngine):
-    def __init__(self, model_name:str, num_ctx:int=4096, keep_alive:int=300, **kwrs):
+    def __init__(self, model_name:str, num_ctx:int=4096, keep_alive:int=300, config:LLMConfig=None, **kwrs):
         """
         The Ollama inference engine.
 
@@ -123,6 +407,8 @@ class OllamaInferenceEngine(InferenceEngine):
             context length that LLM will evaluate.
         keep_alive : int, Optional
             seconds to hold the LLM after the last API call.
+        config : LLMConfig
+            the LLM configuration.
         """
         if importlib.util.find_spec("ollama") is None:
             raise ImportError("ollama-python not found. Please install ollama-python (```pip install ollama```).")
@@ -133,69 +419,144 @@ class OllamaInferenceEngine(InferenceEngine):
         self.model_name = model_name
         self.num_ctx = num_ctx
         self.keep_alive = keep_alive
+        self.config = config if config else BasicLLMConfig()
+        self.formatted_params = self._format_config()
+
+    def _format_config(self) -> Dict[str, Any]:
+        """
+        This method format the LLM configuration with the correct key for the inference engine.
+        """
+        formatted_params = self.config.params.copy()
+        if "max_new_tokens" in formatted_params:
+            formatted_params["num_predict"] = formatted_params["max_new_tokens"]
+            formatted_params.pop("max_new_tokens")
 
-
+        return formatted_params
+
+    def chat(self, messages:List[Dict[str,str]],
+             verbose:bool=False, stream:bool=False) -> Union[str, Generator[Dict[str, str], None, None]]:
         """
-        This method inputs chat messages and outputs
+        This method inputs chat messages and outputs VLM generated text.
 
         Parameters:
         ----------
         messages : List[Dict[str,str]]
             a list of dict with role and content. role must be one of {"system", "user", "assistant"}
-
-
-        temperature : float, Optional
-            the temperature for token sampling.
+        verbose : bool, Optional
+            if True, VLM generated text will be printed in terminal in real-time.
         stream : bool, Optional
-            if True,
+            if True, returns a generator that yields the output in real-time.
         """
-
+        processed_messages = self.config.preprocess_messages(messages)
+
+        options={'num_ctx': self.num_ctx, **self.formatted_params}
+        if stream:
+            def _stream_generator():
+                response_stream = self.client.chat(
+                    model=self.model_name,
+                    messages=processed_messages,
+                    options=options,
+                    stream=True,
+                    keep_alive=self.keep_alive
+                )
+                for chunk in response_stream:
+                    content_chunk = chunk.get('message', {}).get('content')
+                    if content_chunk:
+                        yield content_chunk
+
+            return self.config.postprocess_response(_stream_generator())
+
+        elif verbose:
+            response = self.client.chat(
                 model=self.model_name,
-                messages=
-                options=
-                stream=
+                messages=processed_messages,
+                options=options,
+                stream=True,
                 keep_alive=self.keep_alive
             )
-
+
             res = ''
             for chunk in response:
-
-                print(
+                content_chunk = chunk.get('message', {}).get('content')
+                print(content_chunk, end='', flush=True)
+                res += content_chunk
             print('\n')
-            return res
+            return self.config.postprocess_response(res)
+
+        else:
+            response = self.client.chat(
+                model=self.model_name,
+                messages=processed_messages,
+                options=options,
+                stream=False,
+                keep_alive=self.keep_alive
+            )
+            res = response.get('message', {}).get('content')
+            return self.config.postprocess_response(res)
 
-        return response['message']['content']
-
 
-    async def chat_async(self, messages:List[Dict[str,str]]
+    async def chat_async(self, messages:List[Dict[str,str]]) -> str:
         """
         Async version of chat method. Streaming is not supported.
         """
+        processed_messages = self.config.preprocess_messages(messages)
+
         response = await self.async_client.chat(
             model=self.model_name,
-            messages=
-            options={'
+            messages=processed_messages,
+            options={'num_ctx': self.num_ctx, **self.formatted_params},
             stream=False,
             keep_alive=self.keep_alive
         )
 
-
+        res = response['message']['content']
+        return self.config.postprocess_response(res)
 
 
 class HuggingFaceHubInferenceEngine(InferenceEngine):
-    def __init__(self, model:str=None, token:Union[str, bool]=None, base_url:str=None, api_key:str=None, **kwrs):
+    def __init__(self, model:str=None, token:Union[str, bool]=None, base_url:str=None, api_key:str=None, config:LLMConfig=None, **kwrs):
         """
         The Huggingface_hub InferenceClient inference engine.
         For parameters and documentation, refer to https://huggingface.co/docs/huggingface_hub/en/package_reference/inference_client
+
+        Parameters:
+        ----------
+        model : str
+            the model name exactly as shown in Huggingface repo
+        token : str, Optional
+            the Huggingface token. If None, will use the token in os.environ['HF_TOKEN'].
+        base_url : str, Optional
+            the base url for the LLM server. If None, will use the default Huggingface Hub URL.
+        api_key : str, Optional
+            the API key for the LLM server.
+        config : LLMConfig
+            the LLM configuration.
         """
         if importlib.util.find_spec("huggingface_hub") is None:
             raise ImportError("huggingface-hub not found. Please install huggingface-hub (```pip install huggingface-hub```).")
 
         from huggingface_hub import InferenceClient, AsyncInferenceClient
+        self.model = model
+        self.base_url = base_url
         self.client = InferenceClient(model=model, token=token, base_url=base_url, api_key=api_key, **kwrs)
         self.client_async = AsyncInferenceClient(model=model, token=token, base_url=base_url, api_key=api_key, **kwrs)
+        self.config = config if config else BasicLLMConfig()
+        self.formatted_params = self._format_config()
 
-    def
+    def _format_config(self) -> Dict[str, Any]:
+        """
+        This method format the LLM configuration with the correct key for the inference engine.
+        """
+        formatted_params = self.config.params.copy()
+        if "max_new_tokens" in formatted_params:
+            formatted_params["max_tokens"] = formatted_params["max_new_tokens"]
+            formatted_params.pop("max_new_tokens")
+
+        return formatted_params
+
+
+    def chat(self, messages:List[Dict[str,str]],
+             verbose:bool=False, stream:bool=False) -> Union[str, Generator[Dict[str, str], None, None]]:
         """
         This method inputs chat messages and outputs LLM generated text.
 
@@ -203,47 +564,69 @@ class HuggingFaceHubInferenceEngine(InferenceEngine):
         ----------
         messages : List[Dict[str,str]]
             a list of dict with role and content. role must be one of {"system", "user", "assistant"}
-
-
-        temperature : float, Optional
-            the temperature for token sampling.
+        verbose : bool, Optional
+            if True, VLM generated text will be printed in terminal in real-time.
         stream : bool, Optional
-            if True,
+            if True, returns a generator that yields the output in real-time.
         """
-
-
-            max_tokens=max_new_tokens,
-            temperature=temperature,
-            stream=stream,
-            **kwrs
-        )
-
+        processed_messages = self.config.preprocess_messages(messages)
+
         if stream:
+            def _stream_generator():
+                response_stream = self.client.chat.completions.create(
+                    messages=processed_messages,
+                    stream=True,
+                    **self.formatted_params
+                )
+                for chunk in response_stream:
+                    content_chunk = chunk.get('choices')[0].get('delta').get('content')
+                    if content_chunk:
+                        yield content_chunk
+
+            return self.config.postprocess_response(_stream_generator())
+
+        elif verbose:
+            response = self.client.chat.completions.create(
+                messages=processed_messages,
+                stream=True,
+                **self.formatted_params
+            )
+
             res = ''
             for chunk in response:
-
-
-
+                content_chunk = chunk.get('choices')[0].get('delta').get('content')
+                if content_chunk:
+                    res += content_chunk
+                    print(content_chunk, end='', flush=True)
+            return self.config.postprocess_response(res)
 
-
+        else:
+            response = self.client.chat.completions.create(
+                messages=processed_messages,
+                stream=False,
+                **self.formatted_params
+            )
+            res = response.choices[0].message.content
+            return self.config.postprocess_response(res)
 
-    async def chat_async(self, messages:List[Dict[str,str]]
+    async def chat_async(self, messages:List[Dict[str,str]]) -> str:
         """
         Async version of chat method. Streaming is not supported.
         """
+        processed_messages = self.config.preprocess_messages(messages)
+
         response = await self.client_async.chat.completions.create(
-            messages=
-            max_tokens=max_new_tokens,
-            temperature=temperature,
+            messages=processed_messages,
             stream=False,
-            **
+            **self.formatted_params
         )
 
-
+        res = response.choices[0].message.content
+        return self.config.postprocess_response(res)
 
 
 class OpenAIInferenceEngine(InferenceEngine):
-    def __init__(self, model:str,
+    def __init__(self, model:str, config:LLMConfig=None, **kwrs):
         """
         The OpenAI API inference engine. Supports OpenAI models and OpenAI compatible servers:
         - vLLM OpenAI compatible server (https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html)
@@ -255,8 +638,6 @@ class OpenAIInferenceEngine(InferenceEngine):
         ----------
         model_name : str
             model name as described in https://platform.openai.com/docs/models
-        reasoning_model : bool, Optional
-            indicator for OpenAI reasoning models ("o" series).
         """
         if importlib.util.find_spec("openai") is None:
             raise ImportError("OpenAI Python API library not found. Please install OpanAI (```pip install openai```).")
@@ -265,9 +646,21 @@ class OpenAIInferenceEngine(InferenceEngine):
         self.client = OpenAI(**kwrs)
         self.async_client = AsyncOpenAI(**kwrs)
         self.model = model
-        self.
+        self.config = config if config else BasicLLMConfig()
+        self.formatted_params = self._format_config()
 
-    def
+    def _format_config(self) -> Dict[str, Any]:
+        """
+        This method format the LLM configuration with the correct key for the inference engine.
+        """
+        formatted_params = self.config.params.copy()
+        if "max_new_tokens" in formatted_params:
+            formatted_params["max_completion_tokens"] = formatted_params["max_new_tokens"]
+            formatted_params.pop("max_new_tokens")
+
+        return formatted_params
+
+    def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False) -> Union[str, Generator[Dict[str, str], None, None]]:
         """
         This method inputs chat messages and outputs LLM generated text.
 
@@ -275,36 +668,37 @@ class OpenAIInferenceEngine(InferenceEngine):
         ----------
         messages : List[Dict[str,str]]
             a list of dict with role and content. role must be one of {"system", "user", "assistant"}
-
-
-        temperature : float, Optional
-            the temperature for token sampling.
+        verbose : bool, Optional
+            if True, VLM generated text will be printed in terminal in real-time.
         stream : bool, Optional
-            if True,
+            if True, returns a generator that yields the output in real-time.
         """
-
-            if temperature != 0.0:
-                warnings.warn("Reasoning models do not support temperature parameter. Will be ignored.", UserWarning)
-
-            response = self.client.chat.completions.create(
-                model=self.model,
-                messages=messages,
-                max_completion_tokens=max_new_tokens,
-                stream=stream,
-                **kwrs
-            )
+        processed_messages = self.config.preprocess_messages(messages)
 
-
+        if stream:
+            def _stream_generator():
+                response_stream = self.client.chat.completions.create(
+                    model=self.model,
+                    messages=processed_messages,
+                    stream=True,
+                    **self.formatted_params
+                )
+                for chunk in response_stream:
+                    if len(chunk.choices) > 0:
+                        if chunk.choices[0].delta.content is not None:
+                            yield chunk.choices[0].delta.content
+                        if chunk.choices[0].finish_reason == "length":
+                            warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
+
+            return self.config.postprocess_response(_stream_generator())
+
+        elif verbose:
             response = self.client.chat.completions.create(
                 model=self.model,
-                messages=
-
-
-                stream=stream,
-                **kwrs
+                messages=processed_messages,
+                stream=True,
+                **self.formatted_params
             )
-
-        if stream:
             res = ''
             for chunk in response:
                 if len(chunk.choices) > 0:
@@ -313,53 +707,42 @@ class OpenAIInferenceEngine(InferenceEngine):
                         print(chunk.choices[0].delta.content, end="", flush=True)
                     if chunk.choices[0].finish_reason == "length":
                         warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
-            if self.reasoning_model:
-                warnings.warn("max_new_tokens includes reasoning tokens and output tokens.", UserWarning)
-            return res
-
-        if response.choices[0].finish_reason == "length":
-            warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
-        if self.reasoning_model:
-            warnings.warn("max_new_tokens includes reasoning tokens and output tokens.", UserWarning)
-
-        return response.choices[0].message.content
-
 
-
-
-        Async version of chat method. Streaming is not supported.
-        """
-        if self.reasoning_model:
-            if temperature != 0.0:
-                warnings.warn("Reasoning models do not support temperature parameter. Will be ignored.", UserWarning)
-
-            response = await self.async_client.chat.completions.create(
-                model=self.model,
-                messages=messages,
-                max_completion_tokens=max_new_tokens,
-                stream=False,
-                **kwrs
-            )
+            print('\n')
+            return self.config.postprocess_response(res)
         else:
-            response =
+            response = self.client.chat.completions.create(
                 model=self.model,
-                messages=
-                max_tokens=max_new_tokens,
-                temperature=temperature,
+                messages=processed_messages,
                 stream=False,
-                **
+                **self.formatted_params
             )
+            res = response.choices[0].message.content
+            return self.config.postprocess_response(res)
+
+
+    async def chat_async(self, messages:List[Dict[str,str]]) -> str:
+        """
+        Async version of chat method. Streaming is not supported.
+        """
+        processed_messages = self.config.preprocess_messages(messages)
+
+        response = await self.async_client.chat.completions.create(
+            model=self.model,
+            messages=processed_messages,
+            stream=False,
+            **self.formatted_params
+        )
 
         if response.choices[0].finish_reason == "length":
             warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
-        if self.reasoning_model:
-            warnings.warn("max_new_tokens includes reasoning tokens and output tokens.", UserWarning)
 
-
+        res = response.choices[0].message.content
+        return self.config.postprocess_response(res)
 
 
-class AzureOpenAIInferenceEngine(
-    def __init__(self, model:str, api_version:str,
+class AzureOpenAIInferenceEngine(OpenAIInferenceEngine):
+    def __init__(self, model:str, api_version:str, config:LLMConfig=None, **kwrs):
         """
         The Azure OpenAI API inference engine.
         For parameters and documentation, refer to
@@ -372,8 +755,8 @@ class AzureOpenAIInferenceEngine(InferenceEngine):
             model name as described in https://platform.openai.com/docs/models
         api_version : str
             the Azure OpenAI API version
-
-
+        config : LLMConfig
+            the LLM configuration.
         """
         if importlib.util.find_spec("openai") is None:
             raise ImportError("OpenAI Python API library not found. Please install OpanAI (```pip install openai```).")
@@ -385,101 +768,12 @@ class AzureOpenAIInferenceEngine(InferenceEngine):
                                   **kwrs)
         self.async_client = AsyncAzureOpenAI(api_version=self.api_version,
                                              **kwrs)
-        self.
-
-    def chat(self, messages:List[Dict[str,str]], max_new_tokens:int=2048, temperature:float=0.0, stream:bool=False, **kwrs) -> str:
-        """
-        This method inputs chat messages and outputs LLM generated text.
-
-        Parameters:
-        ----------
-        messages : List[Dict[str,str]]
-            a list of dict with role and content. role must be one of {"system", "user", "assistant"}
-        max_new_tokens : str, Optional
-            the max number of new tokens LLM can generate.
-        temperature : float, Optional
-            the temperature for token sampling.
-        stream : bool, Optional
-            if True, LLM generated text will be printed in terminal in real-time.
-        """
-        if self.reasoning_model:
-            if temperature != 0.0:
-                warnings.warn("Reasoning models do not support temperature parameter. Will be ignored.", UserWarning)
-
-            response = self.client.chat.completions.create(
-                model=self.model,
-                messages=messages,
-                max_completion_tokens=max_new_tokens,
-                stream=stream,
-                **kwrs
-            )
-
-        else:
-            response = self.client.chat.completions.create(
-                model=self.model,
-                messages=messages,
-                max_tokens=max_new_tokens,
-                temperature=temperature,
-                stream=stream,
-                **kwrs
-            )
-
-        if stream:
-            res = ''
-            for chunk in response:
-                if len(chunk.choices) > 0:
-                    if chunk.choices[0].delta.content is not None:
-                        res += chunk.choices[0].delta.content
-                        print(chunk.choices[0].delta.content, end="", flush=True)
-                    if chunk.choices[0].finish_reason == "length":
-                        warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
-            if self.reasoning_model:
-                warnings.warn("max_new_tokens includes reasoning tokens and output tokens.", UserWarning)
-            return res
-
-        if response.choices[0].finish_reason == "length":
-            warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
-        if self.reasoning_model:
-            warnings.warn("max_new_tokens includes reasoning tokens and output tokens.", UserWarning)
-
-        return response.choices[0].message.content
-
-
-    async def chat_async(self, messages:List[Dict[str,str]], max_new_tokens:int=2048, temperature:float=0.0, **kwrs) -> str:
-        """
-        Async version of chat method. Streaming is not supported.
-        """
-        if self.reasoning_model:
-            if temperature != 0.0:
-                warnings.warn("Reasoning models do not support temperature parameter. Will be ignored.", UserWarning)
-
-            response = await self.async_client.chat.completions.create(
-                model=self.model,
-                messages=messages,
-                max_completion_tokens=max_new_tokens,
-                stream=False,
-                **kwrs
-            )
-        else:
-            response = await self.async_client.chat.completions.create(
-                model=self.model,
-                messages=messages,
-                max_tokens=max_new_tokens,
-                temperature=temperature,
-                stream=False,
-                **kwrs
-            )
-
-        if response.choices[0].finish_reason == "length":
-            warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
-        if self.reasoning_model:
-            warnings.warn("max_new_tokens includes reasoning tokens and output tokens.", UserWarning)
-
-        return response.choices[0].message.content
+        self.config = config if config else BasicLLMConfig()
+        self.formatted_params = self._format_config()
 
 
 class LiteLLMInferenceEngine(InferenceEngine):
-    def __init__(self, model:str=None, base_url:str=None, api_key:str=None):
+    def __init__(self, model:str=None, base_url:str=None, api_key:str=None, config:LLMConfig=None):
         """
         The LiteLLM inference engine.
         For parameters and documentation, refer to https://github.com/BerriAI/litellm?tab=readme-ov-file
@@ -492,6 +786,8 @@ class LiteLLMInferenceEngine(InferenceEngine):
             the base url for the LLM server
         api_key : str, Optional
             the API key for the LLM server
+        config : LLMConfig
+            the LLM configuration.
         """
         if importlib.util.find_spec("litellm") is None:
             raise ImportError("litellm not found. Please install litellm (```pip install litellm```).")
@@ -501,56 +797,98 @@ class LiteLLMInferenceEngine(InferenceEngine):
         self.model = model
         self.base_url = base_url
         self.api_key = api_key
+        self.config = config if config else BasicLLMConfig()
+        self.formatted_params = self._format_config()
 
-    def
+    def _format_config(self) -> Dict[str, Any]:
+        """
+        This method format the LLM configuration with the correct key for the inference engine.
+        """
+        formatted_params = self.config.params.copy()
+        if "max_new_tokens" in formatted_params:
+            formatted_params["max_tokens"] = formatted_params["max_new_tokens"]
+            formatted_params.pop("max_new_tokens")
+
+        return formatted_params
+
+    def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False) -> Union[str, Generator[Dict[str, str], None, None]]:
         """
         This method inputs chat messages and outputs LLM generated text.
 
         Parameters:
         ----------
         messages : List[Dict[str,str]]
-            a list of dict with role and content. role must be one of {"system", "user", "assistant"}
-
-
-        temperature : float, Optional
-            the temperature for token sampling.
+            a list of dict with role and content. role must be one of {"system", "user", "assistant"}
+        verbose : bool, Optional
+            if True, VLM generated text will be printed in terminal in real-time.
         stream : bool, Optional
-            if True,
+            if True, returns a generator that yields the output in real-time.
         """
-
-
-            messages=messages,
-            max_tokens=max_new_tokens,
-            temperature=temperature,
-            stream=stream,
-            base_url=self.base_url,
-            api_key=self.api_key,
-            **kwrs
-        )
-
+        processed_messages = self.config.preprocess_messages(messages)
+
         if stream:
+            def _stream_generator():
+                response_stream = self.litellm.completion(
+                    model=self.model,
+                    messages=processed_messages,
+                    stream=True,
+                    base_url=self.base_url,
+                    api_key=self.api_key,
+                    **self.formatted_params
+                )
+
+                for chunk in response_stream:
+                    chunk_content = chunk.get('choices')[0].get('delta').get('content')
+                    if chunk_content:
+                        yield chunk_content
+
+            return self.config.postprocess_response(_stream_generator())
+
+        elif verbose:
+            response = self.litellm.completion(
+                model=self.model,
+                messages=processed_messages,
+                stream=True,
+                base_url=self.base_url,
+                api_key=self.api_key,
+                **self.formatted_params
+            )
+
             res = ''
             for chunk in response:
-
-
-
-
+                chunk_content = chunk.get('choices')[0].get('delta').get('content')
+                if chunk_content:
+                    res += chunk_content
+                    print(chunk_content, end='', flush=True)
+
+            return self.config.postprocess_response(res)
 
-
+        else:
+            response = self.litellm.completion(
+                model=self.model,
+                messages=processed_messages,
+                stream=False,
+                base_url=self.base_url,
+                api_key=self.api_key,
+                **self.formatted_params
+            )
+            res = response.choices[0].message.content
+            return self.config.postprocess_response(res)
 
-    async def chat_async(self, messages:List[Dict[str,str]]
+    async def chat_async(self, messages:List[Dict[str,str]]) -> str:
         """
         Async version of chat method. Streaming is not supported.
         """
+        processed_messages = self.config.preprocess_messages(messages)
+
         response = await self.litellm.acompletion(
             model=self.model,
-            messages=
-            max_tokens=max_new_tokens,
-            temperature=temperature,
+            messages=processed_messages,
             stream=False,
             base_url=self.base_url,
            api_key=self.api_key,
-            **
+            **self.formatted_params
        )
 
-
+        res = response.get('choices')[0].get('message').get('content')
+        return self.config.postprocess_response(res)