llm-ie 0.4.6__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
- llm_ie/__init__.py +4 -2
- llm_ie/asset/default_prompts/BasicReviewFrameExtractor_addition_review_prompt.txt +3 -0
- llm_ie/asset/default_prompts/BasicReviewFrameExtractor_revision_review_prompt.txt +2 -0
- llm_ie/asset/default_prompts/ReviewFrameExtractor_addition_review_prompt.txt +2 -1
- llm_ie/asset/default_prompts/ReviewFrameExtractor_revision_review_prompt.txt +2 -1
- llm_ie/asset/prompt_guide/BasicFrameExtractor_prompt_guide.txt +104 -86
- llm_ie/asset/prompt_guide/BasicReviewFrameExtractor_prompt_guide.txt +163 -0
- llm_ie/asset/prompt_guide/DirectFrameExtractor_prompt_guide.txt +163 -0
- llm_ie/asset/prompt_guide/ReviewFrameExtractor_prompt_guide.txt +103 -85
- llm_ie/asset/prompt_guide/SentenceFrameExtractor_prompt_guide.txt +103 -86
- llm_ie/asset/prompt_guide/SentenceReviewFrameExtractor_prompt_guide.txt +103 -86
- llm_ie/chunkers.py +191 -0
- llm_ie/data_types.py +75 -1
- llm_ie/engines.py +274 -183
- llm_ie/extractors.py +1062 -727
- llm_ie/prompt_editor.py +39 -6
- llm_ie-1.0.0.dist-info/METADATA +18 -0
- llm_ie-1.0.0.dist-info/RECORD +27 -0
- llm_ie/asset/prompt_guide/SentenceCoTFrameExtractor_prompt_guide.txt +0 -217
- llm_ie-0.4.6.dist-info/METADATA +0 -1215
- llm_ie-0.4.6.dist-info/RECORD +0 -23
- {llm_ie-0.4.6.dist-info → llm_ie-1.0.0.dist-info}/WHEEL +0 -0
llm_ie/engines.py
CHANGED
@@ -1,7 +1,7 @@
 import abc
 import warnings
 import importlib
-from typing import List, Dict, Union
+from typing import List, Dict, Union, Generator
 
 
 class InferenceEngine:
@@ -15,7 +15,8 @@ class InferenceEngine:
 
 
     @abc.abstractmethod
-    def chat(self, messages:List[Dict[str,str]], max_new_tokens:int=2048, temperature:float=0.0,
+    def chat(self, messages:List[Dict[str,str]], max_new_tokens:int=2048, temperature:float=0.0,
+             verbose:bool=False, stream:bool=False, **kwrs) -> Union[str, Generator[str, None, None]]:
         """
         This method inputs chat messages and outputs LLM generated text.
 
@@ -27,8 +28,10 @@
             the max number of new tokens LLM can generate.
         temperature : float, Optional
             the temperature for token sampling.
+        verbose : bool, Optional
+            if True, LLM generated text will be printed in terminal in real-time.
         stream : bool, Optional
-            if True,
+            if True, returns a generator that yields the output in real-time.
         """
         return NotImplemented
 
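The new abstract signature means every engine's `chat` can return either a finished string or a generator of text chunks, depending on `stream`. A minimal consumer-side sketch (`engine` stands for any concrete `InferenceEngine` subclass; the message content is illustrative):

```python
# `engine` is any concrete InferenceEngine subclass from llm_ie.engines.
messages = [
    {"role": "system", "content": "You are an information extraction assistant."},
    {"role": "user", "content": "List the medications mentioned in the note."},
]

# Default: blocks and returns the full completion as a string.
text = engine.chat(messages, max_new_tokens=512)

# verbose=True also returns the full string, printing chunks to the terminal as they arrive.
text = engine.chat(messages, max_new_tokens=512, verbose=True)

# stream=True returns a generator that yields text chunks in real time.
for chunk in engine.chat(messages, max_new_tokens=512, stream=True):
    print(chunk, end="", flush=True)
```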
@@ -71,7 +74,7 @@ class LlamaCppInferenceEngine(InferenceEngine):
         del self.model
 
 
-    def chat(self, messages:List[Dict[str,str]], max_new_tokens:int=2048, temperature:float=0.0,
+    def chat(self, messages:List[Dict[str,str]], max_new_tokens:int=2048, temperature:float=0.0, verbose:bool=False, **kwrs) -> str:
         """
         This method inputs chat messages and outputs LLM generated text.
 
@@ -83,18 +86,18 @@ class LlamaCppInferenceEngine(InferenceEngine):
             the max number of new tokens LLM can generate.
         temperature : float, Optional
             the temperature for token sampling.
-
+        verbose : bool, Optional
             if True, LLM generated text will be printed in terminal in real-time.
         """
         response = self.model.create_chat_completion(
             messages=messages,
             max_tokens=max_new_tokens,
             temperature=temperature,
-            stream=
+            stream=verbose,
             **kwrs
         )
 
-        if
+        if verbose:
             res = ''
             for chunk in response:
                 out_dict = chunk['choices'][0]['delta']
@@ -107,9 +110,6 @@ class LlamaCppInferenceEngine(InferenceEngine):
         return response['choices'][0]['message']['content']
 
 
-
-
-
 class OllamaInferenceEngine(InferenceEngine):
     def __init__(self, model_name:str, num_ctx:int=4096, keep_alive:int=300, **kwrs):
         """
@@ -133,39 +133,68 @@ class OllamaInferenceEngine(InferenceEngine):
         self.model_name = model_name
         self.num_ctx = num_ctx
         self.keep_alive = keep_alive
-
-    def chat(self, messages:List[Dict[str,str]], max_new_tokens:int=2048, temperature:float=0.0,
+
+    def chat(self, messages:List[Dict[str,str]], max_new_tokens:int=2048, temperature:float=0.0,
+             verbose:bool=False, stream:bool=False, **kwrs) -> Union[str, Generator[str, None, None]]:
         """
-        This method inputs chat messages and outputs
+        This method inputs chat messages and outputs VLM generated text.
 
         Parameters:
         ----------
         messages : List[Dict[str,str]]
             a list of dict with role and content. role must be one of {"system", "user", "assistant"}
         max_new_tokens : str, Optional
-            the max number of new tokens
+            the max number of new tokens VLM can generate.
         temperature : float, Optional
             the temperature for token sampling.
+        verbose : bool, Optional
+            if True, VLM generated text will be printed in terminal in real-time.
         stream : bool, Optional
-            if True,
+            if True, returns a generator that yields the output in real-time.
         """
-
+        options={'temperature':temperature, 'num_ctx': self.num_ctx, 'num_predict': max_new_tokens, **kwrs}
+        if stream:
+            def _stream_generator():
+                response_stream = self.client.chat(
+                    model=self.model_name,
+                    messages=messages,
+                    options=options,
+                    stream=True,
+                    keep_alive=self.keep_alive
+                )
+                for chunk in response_stream:
+                    content_chunk = chunk.get('message', {}).get('content')
+                    if content_chunk:
+                        yield content_chunk
+
+            return _stream_generator()
+
+        elif verbose:
+            response = self.client.chat(
                 model=self.model_name,
                 messages=messages,
-                options=
-                stream=
+                options=options,
+                stream=True,
                 keep_alive=self.keep_alive
             )
-
+
             res = ''
             for chunk in response:
-
-                print(
+                content_chunk = chunk.get('message', {}).get('content')
+                print(content_chunk, end='', flush=True)
+                res += content_chunk
             print('\n')
             return res
 
-
-
+        else:
+            response = self.client.chat(
+                model=self.model_name,
+                messages=messages,
+                options=options,
+                stream=False,
+                keep_alive=self.keep_alive
+            )
+            return response.get('message', {}).get('content')
 
     async def chat_async(self, messages:List[Dict[str,str]], max_new_tokens:int=2048, temperature:float=0.0, **kwrs) -> str:
         """
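The rewritten Ollama `chat` builds a single `options` dict and then dispatches on `stream` and `verbose`. A hedged usage sketch (the model name is a placeholder for whatever model is available on the local Ollama server):

```python
from llm_ie.engines import OllamaInferenceEngine

# "llama3.1:8b" is an illustrative model name, not one required by llm-ie.
engine = OllamaInferenceEngine(model_name="llama3.1:8b", num_ctx=4096)

messages = [{"role": "user", "content": "Summarize this note in one sentence."}]

# stream=True: iterate over chunks as the model generates them.
for chunk in engine.chat(messages, max_new_tokens=256, stream=True):
    print(chunk, end="", flush=True)

# verbose=True: prints chunks in the terminal and returns the assembled string.
text = engine.chat(messages, max_new_tokens=256, verbose=True)

# Default: a single blocking call that returns the full response.
text = engine.chat(messages, max_new_tokens=256)
```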
@@ -195,7 +224,8 @@ class HuggingFaceHubInferenceEngine(InferenceEngine):
         self.client = InferenceClient(model=model, token=token, base_url=base_url, api_key=api_key, **kwrs)
         self.client_async = AsyncInferenceClient(model=model, token=token, base_url=base_url, api_key=api_key, **kwrs)
 
-    def chat(self, messages:List[Dict[str,str]], max_new_tokens:int=2048, temperature:float=0.0,
+    def chat(self, messages:List[Dict[str,str]], max_new_tokens:int=2048, temperature:float=0.0,
+             verbose:bool=False, stream:bool=False, **kwrs) -> Union[str, Generator[str, None, None]]:
         """
         This method inputs chat messages and outputs LLM generated text.
 
@@ -207,25 +237,53 @@ class HuggingFaceHubInferenceEngine(InferenceEngine):
             the max number of new tokens LLM can generate.
         temperature : float, Optional
             the temperature for token sampling.
+        verbose : bool, Optional
+            if True, VLM generated text will be printed in terminal in real-time.
         stream : bool, Optional
-            if True,
+            if True, returns a generator that yields the output in real-time.
         """
-        response = self.client.chat.completions.create(
-            messages=messages,
-            max_tokens=max_new_tokens,
-            temperature=temperature,
-            stream=stream,
-            **kwrs
-        )
-
         if stream:
+            def _stream_generator():
+                response_stream = self.client.chat.completions.create(
+                    messages=messages,
+                    max_tokens=max_new_tokens,
+                    temperature=temperature,
+                    stream=True,
+                    **kwrs
+                )
+                for chunk in response_stream:
+                    content_chunk = chunk.get('choices')[0].get('delta').get('content')
+                    if content_chunk:
+                        yield content_chunk
+
+            return _stream_generator()
+
+        elif verbose:
+            response = self.client.chat.completions.create(
+                messages=messages,
+                max_tokens=max_new_tokens,
+                temperature=temperature,
+                stream=True,
+                **kwrs
+            )
+
             res = ''
             for chunk in response:
-
-
+                content_chunk = chunk.get('choices')[0].get('delta').get('content')
+                if content_chunk:
+                    res += content_chunk
+                    print(content_chunk, end='', flush=True)
             return res
 
-
+        else:
+            response = self.client.chat.completions.create(
+                messages=messages,
+                max_tokens=max_new_tokens,
+                temperature=temperature,
+                stream=False,
+                **kwrs
+            )
+            return response.choices[0].message.content
 
     async def chat_async(self, messages:List[Dict[str,str]], max_new_tokens:int=2048, temperature:float=0.0, **kwrs) -> str:
         """
@@ -267,7 +325,8 @@ class OpenAIInferenceEngine(InferenceEngine):
         self.model = model
         self.reasoning_model = reasoning_model
 
-    def chat(self, messages:List[Dict[str,str]], max_new_tokens:int=2048, temperature:float=0.0,
+    def chat(self, messages:List[Dict[str,str]], max_new_tokens:int=2048, temperature:float=0.0,
+             verbose:bool=False, stream:bool=False, **kwrs) -> Union[str, Generator[str, None, None]]:
         """
         This method inputs chat messages and outputs LLM generated text.
 
@@ -279,60 +338,145 @@ class OpenAIInferenceEngine(InferenceEngine):
             the max number of new tokens LLM can generate.
         temperature : float, Optional
             the temperature for token sampling.
+        verbose : bool, Optional
+            if True, VLM generated text will be printed in terminal in real-time.
         stream : bool, Optional
-            if True,
+            if True, returns a generator that yields the output in real-time.
         """
+        # For reasoning models
         if self.reasoning_model:
+            # Reasoning models do not support temperature parameter
            if temperature != 0.0:
                warnings.warn("Reasoning models do not support temperature parameter. Will be ignored.", UserWarning)
+
+            # Reasoning models do not support system prompts
+            if any(msg['role'] == 'system' for msg in messages):
+                warnings.warn("Reasoning models do not support system prompts. Will be ignored.", UserWarning)
+                messages = [msg for msg in messages if msg['role'] != 'system']
 
-            response = self.client.chat.completions.create(
-                model=self.model,
-                messages=messages,
-                max_completion_tokens=max_new_tokens,
-                stream=stream,
-                **kwrs
-            )
 
-
-
-
-
-
-
-
-
+            if stream:
+                def _stream_generator():
+                    response_stream = self.client.chat.completions.create(
+                        model=self.model,
+                        messages=messages,
+                        max_completion_tokens=max_new_tokens,
+                        stream=True,
+                        **kwrs
+                    )
+                    for chunk in response_stream:
+                        if len(chunk.choices) > 0:
+                            if chunk.choices[0].delta.content is not None:
+                                yield chunk.choices[0].delta.content
+                            if chunk.choices[0].finish_reason == "length":
+                                warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
+                                if self.reasoning_model:
+                                    warnings.warn("max_new_tokens includes reasoning tokens and output tokens.", UserWarning)
+                return _stream_generator()
+
+            elif verbose:
+                response = self.client.chat.completions.create(
+                    model=self.model,
+                    messages=messages,
+                    max_completion_tokens=max_new_tokens,
+                    stream=True,
+                    **kwrs
+                )
+                res = ''
+                for chunk in response:
+                    if len(chunk.choices) > 0:
+                        if chunk.choices[0].delta.content is not None:
+                            res += chunk.choices[0].delta.content
+                            print(chunk.choices[0].delta.content, end="", flush=True)
+                        if chunk.choices[0].finish_reason == "length":
+                            warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
+                            if self.reasoning_model:
+                                warnings.warn("max_new_tokens includes reasoning tokens and output tokens.", UserWarning)
+
+                print('\n')
+                return res
+            else:
+                response = self.client.chat.completions.create(
+                    model=self.model,
+                    messages=messages,
+                    max_completion_tokens=max_new_tokens,
+                    stream=False,
+                    **kwrs
+                )
+                return response.choices[0].message.content
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        # For non-reasoning models
+        else:
+            if stream:
+                def _stream_generator():
+                    response_stream = self.client.chat.completions.create(
+                        model=self.model,
+                        messages=messages,
+                        max_tokens=max_new_tokens,
+                        temperature=temperature,
+                        stream=True,
+                        **kwrs
+                    )
+                    for chunk in response_stream:
+                        if len(chunk.choices) > 0:
+                            if chunk.choices[0].delta.content is not None:
+                                yield chunk.choices[0].delta.content
+                            if chunk.choices[0].finish_reason == "length":
+                                warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
+                                if self.reasoning_model:
+                                    warnings.warn("max_new_tokens includes reasoning tokens and output tokens.", UserWarning)
+                return _stream_generator()
 
-
+            elif verbose:
+                response = self.client.chat.completions.create(
+                    model=self.model,
+                    messages=messages,
+                    max_tokens=max_new_tokens,
+                    temperature=temperature,
+                    stream=True,
+                    **kwrs
+                )
+                res = ''
+                for chunk in response:
+                    if len(chunk.choices) > 0:
+                        if chunk.choices[0].delta.content is not None:
+                            res += chunk.choices[0].delta.content
+                            print(chunk.choices[0].delta.content, end="", flush=True)
+                        if chunk.choices[0].finish_reason == "length":
+                            warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
+                            if self.reasoning_model:
+                                warnings.warn("max_new_tokens includes reasoning tokens and output tokens.", UserWarning)
+
+                print('\n')
+                return res
+
+            else:
+                response = self.client.chat.completions.create(
+                    model=self.model,
+                    messages=messages,
+                    max_tokens=max_new_tokens,
+                    temperature=temperature,
+                    stream=False,
+                    **kwrs
+                )
+
+                return response.choices[0].message.content
 
 
-    async def chat_async(self, messages:List[Dict[str,str]], max_new_tokens:int=
+    async def chat_async(self, messages:List[Dict[str,str]], max_new_tokens:int=4096, temperature:float=0.0, **kwrs) -> str:
         """
         Async version of chat method. Streaming is not supported.
         """
         if self.reasoning_model:
+            # Reasoning models do not support temperature parameter
             if temperature != 0.0:
                 warnings.warn("Reasoning models do not support temperature parameter. Will be ignored.", UserWarning)
-
+
+            # Reasoning models do not support system prompts
+            if any(msg['role'] == 'system' for msg in messages):
+                warnings.warn("Reasoning models do not support system prompts. Will be ignored.", UserWarning)
+                messages = [msg for msg in messages if msg['role'] != 'system']
+
             response = await self.async_client.chat.completions.create(
                 model=self.model,
                 messages=messages,
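Besides streaming, the hunk above adds reasoning-model handling to `OpenAIInferenceEngine.chat`: the temperature parameter is ignored with a warning, and system messages are warned about and dropped before the request is sent. The filtering step in isolation, for illustration:

```python
import warnings

# Illustrative message list; the system message will be removed for reasoning models.
messages = [
    {"role": "system", "content": "You are a careful extractor."},
    {"role": "user", "content": "List the lab values in the note."},
]

# Same check the diff adds inside OpenAIInferenceEngine.chat and chat_async.
if any(msg['role'] == 'system' for msg in messages):
    warnings.warn("Reasoning models do not support system prompts. Will be ignored.", UserWarning)
    messages = [msg for msg in messages if msg['role'] != 'system']

print(messages)  # only the user message remains
```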
@@ -340,6 +484,7 @@ class OpenAIInferenceEngine(InferenceEngine):
                 stream=False,
                 **kwrs
             )
+
         else:
             response = await self.async_client.chat.completions.create(
                 model=self.model,
@@ -358,7 +503,7 @@ class OpenAIInferenceEngine(InferenceEngine):
         return response.choices[0].message.content
 
 
-class AzureOpenAIInferenceEngine(
+class AzureOpenAIInferenceEngine(OpenAIInferenceEngine):
     def __init__(self, model:str, api_version:str, reasoning_model:bool=False, **kwrs):
         """
         The Azure OpenAI API inference engine.
@@ -387,96 +532,6 @@ class AzureOpenAIInferenceEngine(InferenceEngine):
             **kwrs)
         self.reasoning_model = reasoning_model
 
-    def chat(self, messages:List[Dict[str,str]], max_new_tokens:int=2048, temperature:float=0.0, stream:bool=False, **kwrs) -> str:
-        """
-        This method inputs chat messages and outputs LLM generated text.
-
-        Parameters:
-        ----------
-        messages : List[Dict[str,str]]
-            a list of dict with role and content. role must be one of {"system", "user", "assistant"}
-        max_new_tokens : str, Optional
-            the max number of new tokens LLM can generate.
-        temperature : float, Optional
-            the temperature for token sampling.
-        stream : bool, Optional
-            if True, LLM generated text will be printed in terminal in real-time.
-        """
-        if self.reasoning_model:
-            if temperature != 0.0:
-                warnings.warn("Reasoning models do not support temperature parameter. Will be ignored.", UserWarning)
-
-            response = self.client.chat.completions.create(
-                model=self.model,
-                messages=messages,
-                max_completion_tokens=max_new_tokens,
-                stream=stream,
-                **kwrs
-            )
-
-        else:
-            response = self.client.chat.completions.create(
-                model=self.model,
-                messages=messages,
-                max_tokens=max_new_tokens,
-                temperature=temperature,
-                stream=stream,
-                **kwrs
-            )
-
-        if stream:
-            res = ''
-            for chunk in response:
-                if len(chunk.choices) > 0:
-                    if chunk.choices[0].delta.content is not None:
-                        res += chunk.choices[0].delta.content
-                        print(chunk.choices[0].delta.content, end="", flush=True)
-                if chunk.choices[0].finish_reason == "length":
-                    warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
-                    if self.reasoning_model:
-                        warnings.warn("max_new_tokens includes reasoning tokens and output tokens.", UserWarning)
-            return res
-
-        if response.choices[0].finish_reason == "length":
-            warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
-            if self.reasoning_model:
-                warnings.warn("max_new_tokens includes reasoning tokens and output tokens.", UserWarning)
-
-        return response.choices[0].message.content
-
-
-    async def chat_async(self, messages:List[Dict[str,str]], max_new_tokens:int=2048, temperature:float=0.0, **kwrs) -> str:
-        """
-        Async version of chat method. Streaming is not supported.
-        """
-        if self.reasoning_model:
-            if temperature != 0.0:
-                warnings.warn("Reasoning models do not support temperature parameter. Will be ignored.", UserWarning)
-
-            response = await self.async_client.chat.completions.create(
-                model=self.model,
-                messages=messages,
-                max_completion_tokens=max_new_tokens,
-                stream=False,
-                **kwrs
-            )
-        else:
-            response = await self.async_client.chat.completions.create(
-                model=self.model,
-                messages=messages,
-                max_tokens=max_new_tokens,
-                temperature=temperature,
-                stream=False,
-                **kwrs
-            )
-
-        if response.choices[0].finish_reason == "length":
-            warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
-            if self.reasoning_model:
-                warnings.warn("max_new_tokens includes reasoning tokens and output tokens.", UserWarning)
-
-        return response.choices[0].message.content
-
 
 class LiteLLMInferenceEngine(InferenceEngine):
     def __init__(self, model:str=None, base_url:str=None, api_key:str=None):
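With `AzureOpenAIInferenceEngine` now subclassing `OpenAIInferenceEngine` (see the class-signature hunk above), its duplicated `chat` and `chat_async` overrides are removed here and the Azure engine inherits the new stream/verbose behavior. A hedged construction sketch (deployment name and API version are placeholders; endpoint and key are assumed to come from the Azure OpenAI SDK's standard environment variables or be passed through `**kwrs`):

```python
from llm_ie.engines import AzureOpenAIInferenceEngine

# "my-gpt-4o-deployment" and the api_version value are placeholders.
engine = AzureOpenAIInferenceEngine(
    model="my-gpt-4o-deployment",
    api_version="2024-02-01",
    reasoning_model=False,
)

# chat() is inherited from OpenAIInferenceEngine, so the same options apply.
reply = engine.chat(
    [{"role": "user", "content": "Extract the diagnosis from: 'Pt with T2DM.'"}],
    max_new_tokens=128,
)
```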
@@ -502,7 +557,8 @@ class LiteLLMInferenceEngine(InferenceEngine):
         self.base_url = base_url
         self.api_key = api_key
 
-    def chat(self, messages:List[Dict[str,str]], max_new_tokens:int=2048, temperature:float=0.0,
+    def chat(self, messages:List[Dict[str,str]], max_new_tokens:int=2048, temperature:float=0.0,
+             verbose:bool=False, stream:bool=False, **kwrs) -> Union[str, Generator[str, None, None]]:
         """
         This method inputs chat messages and outputs LLM generated text.
 
@@ -514,29 +570,64 @@ class LiteLLMInferenceEngine(InferenceEngine):
             the max number of new tokens LLM can generate.
         temperature : float, Optional
             the temperature for token sampling.
+        verbose : bool, Optional
+            if True, VLM generated text will be printed in terminal in real-time.
         stream : bool, Optional
-            if True,
+            if True, returns a generator that yields the output in real-time.
         """
-        response = self.litellm.completion(
-            model=self.model,
-            messages=messages,
-            max_tokens=max_new_tokens,
-            temperature=temperature,
-            stream=stream,
-            base_url=self.base_url,
-            api_key=self.api_key,
-            **kwrs
-        )
-
         if stream:
+            def _stream_generator():
+                response_stream = self.litellm.completion(
+                    model=self.model,
+                    messages=messages,
+                    max_tokens=max_new_tokens,
+                    temperature=temperature,
+                    stream=True,
+                    base_url=self.base_url,
+                    api_key=self.api_key,
+                    **kwrs
+                )
+
+                for chunk in response_stream:
+                    chunk_content = chunk.get('choices')[0].get('delta').get('content')
+                    if chunk_content:
+                        yield chunk_content
+
+            return _stream_generator()
+
+        elif verbose:
+            response = self.litellm.completion(
+                model=self.model,
+                messages=messages,
+                max_tokens=max_new_tokens,
+                temperature=temperature,
+                stream=True,
+                base_url=self.base_url,
+                api_key=self.api_key,
+                **kwrs
+            )
+
             res = ''
             for chunk in response:
-
-
-
+                chunk_content = chunk.get('choices')[0].get('delta').get('content')
+                if chunk_content:
+                    res += chunk_content
+                    print(chunk_content, end='', flush=True)
+
             return res
 
-
+        else:
+            response = self.litellm.completion(
+                model=self.model,
+                messages=messages,
+                max_tokens=max_new_tokens,
+                temperature=temperature,
+                stream=False,
+                base_url=self.base_url,
+                api_key=self.api_key,
+                **kwrs
+            )
+            return response.choices[0].message.content
 
     async def chat_async(self, messages:List[Dict[str,str]], max_new_tokens:int=2048, temperature:float=0.0, **kwrs) -> str:
         """
@@ -553,4 +644,4 @@ class LiteLLMInferenceEngine(InferenceEngine):
             **kwrs
         )
 
-        return response.choices[0].message.content
+        return response.get('choices')[0].get('message').get('content')