llm-ie 0.4.6__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to their respective public registries. It is provided for informational purposes only.
llm_ie/engines.py CHANGED
@@ -1,7 +1,7 @@
  import abc
  import warnings
  import importlib
- from typing import List, Dict, Union
+ from typing import List, Dict, Union, Generator
 
 
  class InferenceEngine:
@@ -15,7 +15,8 @@ class InferenceEngine:
 
 
  @abc.abstractmethod
- def chat(self, messages:List[Dict[str,str]], max_new_tokens:int=2048, temperature:float=0.0, stream:bool=False, **kwrs) -> str:
+ def chat(self, messages:List[Dict[str,str]], max_new_tokens:int=2048, temperature:float=0.0,
+ verbose:bool=False, stream:bool=False, **kwrs) -> Union[str, Generator[str, None, None]]:
  """
  This method inputs chat messages and outputs LLM generated text.
 
@@ -27,8 +28,10 @@ class InferenceEngine:
  the max number of new tokens LLM can generate.
  temperature : float, Optional
  the temperature for token sampling.
+ verbose : bool, Optional
+ if True, LLM generated text will be printed in terminal in real-time.
  stream : bool, Optional
- if True, LLM generated text will be printed in terminal in real-time.
+ if True, returns a generator that yields the output in real-time.
  """
  return NotImplemented
 
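The signature above is the core behavioral change in 1.0.0: stream=True now makes chat() return a generator of text chunks, while the old print-to-terminal behavior moves to the new verbose flag. A minimal sketch of the contract from the caller's side, using a toy engine rather than a real backend (the EchoEngine class and its canned chunks are illustrative only, not part of llm-ie):

    from typing import Dict, Generator, List, Union

    class EchoEngine:
        # Toy stand-in that follows the new chat() contract; no model is loaded.
        def chat(self, messages: List[Dict[str, str]], max_new_tokens: int = 2048,
                 temperature: float = 0.0, verbose: bool = False,
                 stream: bool = False, **kwrs) -> Union[str, Generator[str, None, None]]:
            chunks = ["Hello", ", ", "world", "!"]
            if stream:
                return (c for c in chunks)   # generator of text chunks, yielded lazily
            text = "".join(chunks)
            if verbose:
                print(text, flush=True)      # a real engine prints chunk-by-chunk
            return text

    engine = EchoEngine()
    msgs = [{"role": "user", "content": "Say hello."}]
    for piece in engine.chat(msgs, stream=True):   # stream mode: caller consumes the generator
        print(piece, end="", flush=True)
    print()
    print(engine.chat(msgs, verbose=True))         # verbose mode: prints, then returns the full str
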
@@ -71,7 +74,7 @@ class LlamaCppInferenceEngine(InferenceEngine):
  del self.model
 
 
- def chat(self, messages:List[Dict[str,str]], max_new_tokens:int=2048, temperature:float=0.0, stream:bool=False, **kwrs) -> str:
+ def chat(self, messages:List[Dict[str,str]], max_new_tokens:int=2048, temperature:float=0.0, verbose:bool=False, **kwrs) -> str:
  """
  This method inputs chat messages and outputs LLM generated text.
 
@@ -83,18 +86,18 @@ class LlamaCppInferenceEngine(InferenceEngine):
  the max number of new tokens LLM can generate.
  temperature : float, Optional
  the temperature for token sampling.
- stream : bool, Optional
+ verbose : bool, Optional
  if True, LLM generated text will be printed in terminal in real-time.
  """
  response = self.model.create_chat_completion(
  messages=messages,
  max_tokens=max_new_tokens,
  temperature=temperature,
- stream=stream,
+ stream=verbose,
  **kwrs
  )
 
- if stream:
+ if verbose:
  res = ''
  for chunk in response:
  out_dict = chunk['choices'][0]['delta']
@@ -107,9 +110,6 @@ class LlamaCppInferenceEngine(InferenceEngine):
  return response['choices'][0]['message']['content']
 
 
-
-
-
  class OllamaInferenceEngine(InferenceEngine):
  def __init__(self, model_name:str, num_ctx:int=4096, keep_alive:int=300, **kwrs):
  """
@@ -133,39 +133,68 @@ class OllamaInferenceEngine(InferenceEngine):
  self.model_name = model_name
  self.num_ctx = num_ctx
  self.keep_alive = keep_alive
-
- def chat(self, messages:List[Dict[str,str]], max_new_tokens:int=2048, temperature:float=0.0, stream:bool=False, **kwrs) -> str:
+
+ def chat(self, messages:List[Dict[str,str]], max_new_tokens:int=2048, temperature:float=0.0,
+ verbose:bool=False, stream:bool=False, **kwrs) -> Union[str, Generator[str, None, None]]:
  """
- This method inputs chat messages and outputs LLM generated text.
+ This method inputs chat messages and outputs VLM generated text.
 
  Parameters:
  ----------
  messages : List[Dict[str,str]]
  a list of dict with role and content. role must be one of {"system", "user", "assistant"}
  max_new_tokens : str, Optional
- the max number of new tokens LLM can generate.
+ the max number of new tokens VLM can generate.
  temperature : float, Optional
  the temperature for token sampling.
+ verbose : bool, Optional
+ if True, VLM generated text will be printed in terminal in real-time.
  stream : bool, Optional
- if True, LLM generated text will be printed in terminal in real-time.
+ if True, returns a generator that yields the output in real-time.
  """
- response = self.client.chat(
+ options={'temperature':temperature, 'num_ctx': self.num_ctx, 'num_predict': max_new_tokens, **kwrs}
+ if stream:
+ def _stream_generator():
+ response_stream = self.client.chat(
+ model=self.model_name,
+ messages=messages,
+ options=options,
+ stream=True,
+ keep_alive=self.keep_alive
+ )
+ for chunk in response_stream:
+ content_chunk = chunk.get('message', {}).get('content')
+ if content_chunk:
+ yield content_chunk
+
+ return _stream_generator()
+
+ elif verbose:
+ response = self.client.chat(
  model=self.model_name,
  messages=messages,
- options={'temperature':temperature, 'num_ctx': self.num_ctx, 'num_predict': max_new_tokens, **kwrs},
- stream=stream,
+ options=options,
+ stream=True,
  keep_alive=self.keep_alive
  )
- if stream:
+
  res = ''
  for chunk in response:
- res += chunk['message']['content']
- print(chunk['message']['content'], end='', flush=True)
+ content_chunk = chunk.get('message', {}).get('content')
+ print(content_chunk, end='', flush=True)
+ res += content_chunk
  print('\n')
  return res
 
- return response['message']['content']
-
+ else:
+ response = self.client.chat(
+ model=self.model_name,
+ messages=messages,
+ options=options,
+ stream=False,
+ keep_alive=self.keep_alive
+ )
+ return response.get('message', {}).get('content')
 
  async def chat_async(self, messages:List[Dict[str,str]], max_new_tokens:int=2048, temperature:float=0.0, **kwrs) -> str:
  """
@@ -195,7 +224,8 @@ class HuggingFaceHubInferenceEngine(InferenceEngine):
  self.client = InferenceClient(model=model, token=token, base_url=base_url, api_key=api_key, **kwrs)
  self.client_async = AsyncInferenceClient(model=model, token=token, base_url=base_url, api_key=api_key, **kwrs)
 
- def chat(self, messages:List[Dict[str,str]], max_new_tokens:int=2048, temperature:float=0.0, stream:bool=False, **kwrs) -> str:
+ def chat(self, messages:List[Dict[str,str]], max_new_tokens:int=2048, temperature:float=0.0,
+ verbose:bool=False, stream:bool=False, **kwrs) -> Union[str, Generator[str, None, None]]:
  """
  This method inputs chat messages and outputs LLM generated text.
 
@@ -207,25 +237,53 @@ class HuggingFaceHubInferenceEngine(InferenceEngine):
  the max number of new tokens LLM can generate.
  temperature : float, Optional
  the temperature for token sampling.
+ verbose : bool, Optional
+ if True, VLM generated text will be printed in terminal in real-time.
  stream : bool, Optional
- if True, LLM generated text will be printed in terminal in real-time.
+ if True, returns a generator that yields the output in real-time.
  """
- response = self.client.chat.completions.create(
- messages=messages,
- max_tokens=max_new_tokens,
- temperature=temperature,
- stream=stream,
- **kwrs
- )
-
  if stream:
+ def _stream_generator():
+ response_stream = self.client.chat.completions.create(
+ messages=messages,
+ max_tokens=max_new_tokens,
+ temperature=temperature,
+ stream=True,
+ **kwrs
+ )
+ for chunk in response_stream:
+ content_chunk = chunk.get('choices')[0].get('delta').get('content')
+ if content_chunk:
+ yield content_chunk
+
+ return _stream_generator()
+
+ elif verbose:
+ response = self.client.chat.completions.create(
+ messages=messages,
+ max_tokens=max_new_tokens,
+ temperature=temperature,
+ stream=True,
+ **kwrs
+ )
+
  res = ''
  for chunk in response:
- res += chunk.choices[0].delta.content
- print(chunk.choices[0].delta.content, end='', flush=True)
+ content_chunk = chunk.get('choices')[0].get('delta').get('content')
+ if content_chunk:
+ res += content_chunk
+ print(content_chunk, end='', flush=True)
  return res
 
- return response.choices[0].message.content
+ else:
+ response = self.client.chat.completions.create(
+ messages=messages,
+ max_tokens=max_new_tokens,
+ temperature=temperature,
+ stream=False,
+ **kwrs
+ )
+ return response.choices[0].message.content
 
  async def chat_async(self, messages:List[Dict[str,str]], max_new_tokens:int=2048, temperature:float=0.0, **kwrs) -> str:
  """
@@ -267,7 +325,8 @@ class OpenAIInferenceEngine(InferenceEngine):
  self.model = model
  self.reasoning_model = reasoning_model
 
- def chat(self, messages:List[Dict[str,str]], max_new_tokens:int=2048, temperature:float=0.0, stream:bool=False, **kwrs) -> str:
+ def chat(self, messages:List[Dict[str,str]], max_new_tokens:int=2048, temperature:float=0.0,
+ verbose:bool=False, stream:bool=False, **kwrs) -> Union[str, Generator[str, None, None]]:
  """
  This method inputs chat messages and outputs LLM generated text.
 
@@ -279,60 +338,145 @@ class OpenAIInferenceEngine(InferenceEngine):
  the max number of new tokens LLM can generate.
  temperature : float, Optional
  the temperature for token sampling.
+ verbose : bool, Optional
+ if True, VLM generated text will be printed in terminal in real-time.
  stream : bool, Optional
- if True, LLM generated text will be printed in terminal in real-time.
+ if True, returns a generator that yields the output in real-time.
  """
+ # For reasoning models
  if self.reasoning_model:
+ # Reasoning models do not support temperature parameter
  if temperature != 0.0:
  warnings.warn("Reasoning models do not support temperature parameter. Will be ignored.", UserWarning)
+
+ # Reasoning models do not support system prompts
+ if any(msg['role'] == 'system' for msg in messages):
+ warnings.warn("Reasoning models do not support system prompts. Will be ignored.", UserWarning)
+ messages = [msg for msg in messages if msg['role'] != 'system']
 
- response = self.client.chat.completions.create(
- model=self.model,
- messages=messages,
- max_completion_tokens=max_new_tokens,
- stream=stream,
- **kwrs
- )
 
- else:
- response = self.client.chat.completions.create(
- model=self.model,
- messages=messages,
- max_tokens=max_new_tokens,
- temperature=temperature,
- stream=stream,
- **kwrs
- )
+ if stream:
+ def _stream_generator():
+ response_stream = self.client.chat.completions.create(
+ model=self.model,
+ messages=messages,
+ max_completion_tokens=max_new_tokens,
+ stream=True,
+ **kwrs
+ )
+ for chunk in response_stream:
+ if len(chunk.choices) > 0:
+ if chunk.choices[0].delta.content is not None:
+ yield chunk.choices[0].delta.content
+ if chunk.choices[0].finish_reason == "length":
+ warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
+ if self.reasoning_model:
+ warnings.warn("max_new_tokens includes reasoning tokens and output tokens.", UserWarning)
+ return _stream_generator()
+
+ elif verbose:
+ response = self.client.chat.completions.create(
+ model=self.model,
+ messages=messages,
+ max_completion_tokens=max_new_tokens,
+ stream=True,
+ **kwrs
+ )
+ res = ''
+ for chunk in response:
+ if len(chunk.choices) > 0:
+ if chunk.choices[0].delta.content is not None:
+ res += chunk.choices[0].delta.content
+ print(chunk.choices[0].delta.content, end="", flush=True)
+ if chunk.choices[0].finish_reason == "length":
+ warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
+ if self.reasoning_model:
+ warnings.warn("max_new_tokens includes reasoning tokens and output tokens.", UserWarning)
+
+ print('\n')
+ return res
+ else:
+ response = self.client.chat.completions.create(
+ model=self.model,
+ messages=messages,
+ max_completion_tokens=max_new_tokens,
+ stream=False,
+ **kwrs
+ )
+ return response.choices[0].message.content
 
- if stream:
- res = ''
- for chunk in response:
- if len(chunk.choices) > 0:
- if chunk.choices[0].delta.content is not None:
- res += chunk.choices[0].delta.content
- print(chunk.choices[0].delta.content, end="", flush=True)
- if chunk.choices[0].finish_reason == "length":
- warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
- if self.reasoning_model:
- warnings.warn("max_new_tokens includes reasoning tokens and output tokens.", UserWarning)
- return res
-
- if response.choices[0].finish_reason == "length":
- warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
- if self.reasoning_model:
- warnings.warn("max_new_tokens includes reasoning tokens and output tokens.", UserWarning)
+ # For non-reasoning models
+ else:
+ if stream:
+ def _stream_generator():
+ response_stream = self.client.chat.completions.create(
+ model=self.model,
+ messages=messages,
+ max_tokens=max_new_tokens,
+ temperature=temperature,
+ stream=True,
+ **kwrs
+ )
+ for chunk in response_stream:
+ if len(chunk.choices) > 0:
+ if chunk.choices[0].delta.content is not None:
+ yield chunk.choices[0].delta.content
+ if chunk.choices[0].finish_reason == "length":
+ warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
+ if self.reasoning_model:
+ warnings.warn("max_new_tokens includes reasoning tokens and output tokens.", UserWarning)
+ return _stream_generator()
 
- return response.choices[0].message.content
+ elif verbose:
+ response = self.client.chat.completions.create(
+ model=self.model,
+ messages=messages,
+ max_tokens=max_new_tokens,
+ temperature=temperature,
+ stream=True,
+ **kwrs
+ )
+ res = ''
+ for chunk in response:
+ if len(chunk.choices) > 0:
+ if chunk.choices[0].delta.content is not None:
+ res += chunk.choices[0].delta.content
+ print(chunk.choices[0].delta.content, end="", flush=True)
+ if chunk.choices[0].finish_reason == "length":
+ warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
+ if self.reasoning_model:
+ warnings.warn("max_new_tokens includes reasoning tokens and output tokens.", UserWarning)
+
+ print('\n')
+ return res
+
+ else:
+ response = self.client.chat.completions.create(
+ model=self.model,
+ messages=messages,
+ max_tokens=max_new_tokens,
+ temperature=temperature,
+ stream=False,
+ **kwrs
+ )
+
+ return response.choices[0].message.content
 
 
- async def chat_async(self, messages:List[Dict[str,str]], max_new_tokens:int=2048, temperature:float=0.0, **kwrs) -> str:
+ async def chat_async(self, messages:List[Dict[str,str]], max_new_tokens:int=4096, temperature:float=0.0, **kwrs) -> str:
  """
  Async version of chat method. Streaming is not supported.
  """
  if self.reasoning_model:
+ # Reasoning models do not support temperature parameter
  if temperature != 0.0:
  warnings.warn("Reasoning models do not support temperature parameter. Will be ignored.", UserWarning)
-
+
+ # Reasoning models do not support system prompts
+ if any(msg['role'] == 'system' for msg in messages):
+ warnings.warn("Reasoning models do not support system prompts. Will be ignored.", UserWarning)
+ messages = [msg for msg in messages if msg['role'] != 'system']
+
  response = await self.async_client.chat.completions.create(
  model=self.model,
  messages=messages,
@@ -340,6 +484,7 @@ class OpenAIInferenceEngine:
  stream=False,
  **kwrs
  )
+
  else:
  response = await self.async_client.chat.completions.create(
  model=self.model,
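In both chat() and chat_async(), the reasoning-model branch now also strips system messages instead of sending them through. Taken in isolation, the added guard behaves like this self-contained snippet (the messages are made up for illustration):

    import warnings

    messages = [
        {"role": "system", "content": "You are a concise assistant."},
        {"role": "user", "content": "Extract the diagnosis from the note."},
    ]

    # Mirrors the added guard: warn, then drop system messages before the request is sent.
    if any(msg["role"] == "system" for msg in messages):
        warnings.warn("Reasoning models do not support system prompts. Will be ignored.", UserWarning)
        messages = [msg for msg in messages if msg["role"] != "system"]

    print(messages)  # only the user message remains
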
@@ -358,7 +503,7 @@ class OpenAIInferenceEngine(InferenceEngine):
  return response.choices[0].message.content
 
 
- class AzureOpenAIInferenceEngine(InferenceEngine):
+ class AzureOpenAIInferenceEngine(OpenAIInferenceEngine):
  def __init__(self, model:str, api_version:str, reasoning_model:bool=False, **kwrs):
  """
  The Azure OpenAI API inference engine.
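Because AzureOpenAIInferenceEngine now subclasses OpenAIInferenceEngine, and its duplicated chat()/chat_async() overrides are deleted in the next hunk, it inherits the verbose/stream behavior from the parent class. A usage sketch, under the assumption that Azure credentials come from the environment variables the OpenAI SDK reads (AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT); the deployment name and API version are placeholders:

    from llm_ie.engines import AzureOpenAIInferenceEngine

    # Assumes AZURE_OPENAI_API_KEY / AZURE_OPENAI_ENDPOINT are set in the environment.
    engine = AzureOpenAIInferenceEngine(model="my-gpt-4o-deployment",
                                        api_version="2024-02-01")

    # chat() is inherited from OpenAIInferenceEngine, so the same flags apply.
    reply = engine.chat([{"role": "user", "content": "Hello"}], verbose=True)
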
@@ -387,96 +532,6 @@ class AzureOpenAIInferenceEngine(InferenceEngine):
  **kwrs)
  self.reasoning_model = reasoning_model
 
- def chat(self, messages:List[Dict[str,str]], max_new_tokens:int=2048, temperature:float=0.0, stream:bool=False, **kwrs) -> str:
- """
- This method inputs chat messages and outputs LLM generated text.
-
- Parameters:
- ----------
- messages : List[Dict[str,str]]
- a list of dict with role and content. role must be one of {"system", "user", "assistant"}
- max_new_tokens : str, Optional
- the max number of new tokens LLM can generate.
- temperature : float, Optional
- the temperature for token sampling.
- stream : bool, Optional
- if True, LLM generated text will be printed in terminal in real-time.
- """
- if self.reasoning_model:
- if temperature != 0.0:
- warnings.warn("Reasoning models do not support temperature parameter. Will be ignored.", UserWarning)
-
- response = self.client.chat.completions.create(
- model=self.model,
- messages=messages,
- max_completion_tokens=max_new_tokens,
- stream=stream,
- **kwrs
- )
-
- else:
- response = self.client.chat.completions.create(
- model=self.model,
- messages=messages,
- max_tokens=max_new_tokens,
- temperature=temperature,
- stream=stream,
- **kwrs
- )
-
- if stream:
- res = ''
- for chunk in response:
- if len(chunk.choices) > 0:
- if chunk.choices[0].delta.content is not None:
- res += chunk.choices[0].delta.content
- print(chunk.choices[0].delta.content, end="", flush=True)
- if chunk.choices[0].finish_reason == "length":
- warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
- if self.reasoning_model:
- warnings.warn("max_new_tokens includes reasoning tokens and output tokens.", UserWarning)
- return res
-
- if response.choices[0].finish_reason == "length":
- warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
- if self.reasoning_model:
- warnings.warn("max_new_tokens includes reasoning tokens and output tokens.", UserWarning)
-
- return response.choices[0].message.content
-
-
- async def chat_async(self, messages:List[Dict[str,str]], max_new_tokens:int=2048, temperature:float=0.0, **kwrs) -> str:
- """
- Async version of chat method. Streaming is not supported.
- """
- if self.reasoning_model:
- if temperature != 0.0:
- warnings.warn("Reasoning models do not support temperature parameter. Will be ignored.", UserWarning)
-
- response = await self.async_client.chat.completions.create(
- model=self.model,
- messages=messages,
- max_completion_tokens=max_new_tokens,
- stream=False,
- **kwrs
- )
- else:
- response = await self.async_client.chat.completions.create(
- model=self.model,
- messages=messages,
- max_tokens=max_new_tokens,
- temperature=temperature,
- stream=False,
- **kwrs
- )
-
- if response.choices[0].finish_reason == "length":
- warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
- if self.reasoning_model:
- warnings.warn("max_new_tokens includes reasoning tokens and output tokens.", UserWarning)
-
- return response.choices[0].message.content
-
 
  class LiteLLMInferenceEngine(InferenceEngine):
  def __init__(self, model:str=None, base_url:str=None, api_key:str=None):
@@ -502,7 +557,8 @@ class LiteLLMInferenceEngine(InferenceEngine):
  self.base_url = base_url
  self.api_key = api_key
 
- def chat(self, messages:List[Dict[str,str]], max_new_tokens:int=2048, temperature:float=0.0, stream:bool=False, **kwrs) -> str:
+ def chat(self, messages:List[Dict[str,str]], max_new_tokens:int=2048, temperature:float=0.0,
+ verbose:bool=False, stream:bool=False, **kwrs) -> Union[str, Generator[str, None, None]]:
  """
  This method inputs chat messages and outputs LLM generated text.
 
@@ -514,29 +570,64 @@ class LiteLLMInferenceEngine(InferenceEngine):
  the max number of new tokens LLM can generate.
  temperature : float, Optional
  the temperature for token sampling.
+ verbose : bool, Optional
+ if True, VLM generated text will be printed in terminal in real-time.
  stream : bool, Optional
- if True, LLM generated text will be printed in terminal in real-time.
+ if True, returns a generator that yields the output in real-time.
  """
- response = self.litellm.completion(
- model=self.model,
- messages=messages,
- max_tokens=max_new_tokens,
- temperature=temperature,
- stream=stream,
- base_url=self.base_url,
- api_key=self.api_key,
- **kwrs
- )
-
  if stream:
+ def _stream_generator():
+ response_stream = self.litellm.completion(
+ model=self.model,
+ messages=messages,
+ max_tokens=max_new_tokens,
+ temperature=temperature,
+ stream=True,
+ base_url=self.base_url,
+ api_key=self.api_key,
+ **kwrs
+ )
+
+ for chunk in response_stream:
+ chunk_content = chunk.get('choices')[0].get('delta').get('content')
+ if chunk_content:
+ yield chunk_content
+
+ return _stream_generator()
+
+ elif verbose:
+ response = self.litellm.completion(
+ model=self.model,
+ messages=messages,
+ max_tokens=max_new_tokens,
+ temperature=temperature,
+ stream=True,
+ base_url=self.base_url,
+ api_key=self.api_key,
+ **kwrs
+ )
+
  res = ''
  for chunk in response:
- if chunk.choices[0].delta.content is not None:
- res += chunk.choices[0].delta.content
- print(chunk.choices[0].delta.content, end="", flush=True)
+ chunk_content = chunk.get('choices')[0].get('delta').get('content')
+ if chunk_content:
+ res += chunk_content
+ print(chunk_content, end='', flush=True)
+
  return res
 
- return response.choices[0].message.content
+ else:
+ response = self.litellm.completion(
+ model=self.model,
+ messages=messages,
+ max_tokens=max_new_tokens,
+ temperature=temperature,
+ stream=False,
+ base_url=self.base_url,
+ api_key=self.api_key,
+ **kwrs
+ )
+ return response.choices[0].message.content
 
  async def chat_async(self, messages:List[Dict[str,str]], max_new_tokens:int=2048, temperature:float=0.0, **kwrs) -> str:
  """
@@ -553,4 +644,4 @@ class LiteLLMInferenceEngine(InferenceEngine):
  **kwrs
  )
 
- return response.choices[0].message.content
+ return response.get('choices')[0].get('message').get('content')
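LiteLLMInferenceEngine follows the same stream/verbose pattern, and chat_async() now reads the response through dictionary-style .get() access. A closing usage sketch, assuming litellm is installed and provider credentials are available in the environment; the model string is illustrative:

    import asyncio
    from llm_ie.engines import LiteLLMInferenceEngine

    engine = LiteLLMInferenceEngine(model="openai/gpt-4o-mini")

    # stream=True: chat() returns a generator of text chunks.
    for piece in engine.chat([{"role": "user", "content": "One-line summary of HTTP."}],
                             stream=True):
        print(piece, end="", flush=True)
    print()

    # chat_async(): non-streaming; returns the full reply as a string.
    text = asyncio.run(engine.chat_async([{"role": "user", "content": "Hi"}]))
    print(text)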