llm-ie 0.4.7__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
llm_ie/engines.py CHANGED
@@ -1,21 +1,290 @@
1
1
  import abc
2
+ import re
2
3
  import warnings
3
- import importlib
4
- from typing import List, Dict, Union
4
+ import importlib.util
5
+ from typing import Any, Tuple, List, Dict, Union, Generator
6
+
7
+
8
+ class LLMConfig(abc.ABC):
9
+ def __init__(self, **kwargs):
10
+ """
11
+ This is an abstract class to provide interfaces for LLM configuration.
12
+ Child classes that inherit this class can be used in extractors and the prompt editor.
13
+ Common LLM parameters: max_new_tokens, temperature, top_p, top_k, min_p.
14
+ """
15
+ self.params = kwargs.copy()
16
+
17
+
18
+ @abc.abstractmethod
19
+ def preprocess_messages(self, messages:List[Dict[str,str]]) -> List[Dict[str,str]]:
20
+ """
21
+ This method preprocesses the input messages before passing them to the LLM.
22
+
23
+ Parameters:
24
+ ----------
25
+ messages : List[Dict[str,str]]
26
+ a list of dict with role and content. role must be one of {"system", "user", "assistant"}
27
+
28
+ Returns:
29
+ -------
30
+ messages : List[Dict[str,str]]
31
+ a list of dict with role and content. role must be one of {"system", "user", "assistant"}
32
+ """
33
+ return NotImplemented
34
+
35
+ @abc.abstractmethod
36
+ def postprocess_response(self, response:Union[str, Generator[str, None, None]]) -> Union[str, Generator[str, None, None]]:
37
+ """
38
+ This method postprocesses the LLM response after it is generated.
39
+
40
+ Parameters:
41
+ ----------
42
+ response : Union[str, Generator[str, None, None]]
43
+ the LLM response. Can be a string or a generator.
44
+
45
+ Returns:
46
+ -------
47
+ response : str
48
+ the postprocessed LLM response
49
+ """
50
+ return NotImplemented
51
+
52
+
53
+ class BasicLLMConfig(LLMConfig):
54
+ def __init__(self, max_new_tokens:int=2048, temperature:float=0.0, **kwargs):
55
+ """
56
+ The basic LLM configuration for most non-reasoning models.
57
+ """
58
+ super().__init__(**kwargs)
59
+ self.max_new_tokens = max_new_tokens
60
+ self.temperature = temperature
61
+ self.params["max_new_tokens"] = self.max_new_tokens
62
+ self.params["temperature"] = self.temperature
63
+
64
+ def preprocess_messages(self, messages:List[Dict[str,str]]) -> List[Dict[str,str]]:
65
+ """
66
+ This method preprocesses the input messages before passing them to the LLM.
67
+
68
+ Parameters:
69
+ ----------
70
+ messages : List[Dict[str,str]]
71
+ a list of dict with role and content. role must be one of {"system", "user", "assistant"}
72
+
73
+ Returns:
74
+ -------
75
+ messages : List[Dict[str,str]]
76
+ a list of dict with role and content. role must be one of {"system", "user", "assistant"}
77
+ """
78
+ return messages
79
+
80
+ def postprocess_response(self, response:Union[str, Generator[str, None, None]]) -> Union[str, Generator[Dict[str, str], None, None]]:
81
+ """
82
+ This method postprocesses the LLM response after it is generated.
83
+
84
+ Parameters:
85
+ ----------
86
+ response : Union[str, Generator[str, None, None]]
87
+ the LLM response. Can be a string or a generator.
88
+
89
+ Returns: Union[str, Generator[Dict[str, str], None, None]]
90
+ the postprocessed LLM response.
91
+ if input is a generator, the output will be a generator of {"type": "response", "data": <content>}.
92
+ """
93
+ if isinstance(response, str):
94
+ return response
95
+
96
+ def _process_stream():
97
+ for chunk in response:
98
+ yield {"type": "response", "data": chunk}
99
+
100
+ return _process_stream()
101
+
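For orientation, a minimal usage sketch of the new config-based API: generation parameters that were previously passed to chat() now live on the config object. The Ollama model tag below is an assumption, not taken from this diff.

from llm_ie.engines import OllamaInferenceEngine, BasicLLMConfig

# Sketch: generation settings are supplied once, via the config (model tag assumed).
engine = OllamaInferenceEngine(
    model_name="llama3.1:8b",
    config=BasicLLMConfig(max_new_tokens=1024, temperature=0.0),
)
messages = [{"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello"}]
print(engine.chat(messages))  # returns a plain string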
102
+ class Qwen3LLMConfig(LLMConfig):
103
+ def __init__(self, thinking_mode:bool=True, **kwargs):
104
+ """
105
+ The Qwen3 LLM configuration for reasoning models.
106
+
107
+ Parameters:
108
+ ----------
109
+ thinking_mode : bool, Optional
110
+ if True, a special token "/think" will be placed after each system and user prompt. Otherwise, "/no_think" will be placed.
111
+ """
112
+ super().__init__(**kwargs)
113
+ self.thinking_mode = thinking_mode
114
+
115
+ def preprocess_messages(self, messages:List[Dict[str,str]]) -> List[Dict[str,str]]:
116
+ """
117
+ Append a special token to the system and user prompts.
118
+ The token is "/think" if thinking_mode is True, otherwise "/no_think".
119
+
120
+ Parameters:
121
+ ----------
122
+ messages : List[Dict[str,str]]
123
+ a list of dict with role and content. role must be one of {"system", "user", "assistant"}
124
+
125
+ Returns:
126
+ -------
127
+ messages : List[Dict[str,str]]
128
+ a list of dict with role and content. role must be one of {"system", "user", "assistant"}
129
+ """
130
+ thinking_token = "/think" if self.thinking_mode else "/no_think"
131
+ new_messages = []
132
+ for message in messages:
133
+ if message['role'] in ['system', 'user']:
134
+ new_message = {'role': message['role'], 'content': f"{message['content']} {thinking_token}"}
135
+ else:
136
+ new_message = {'role': message['role'], 'content': message['content']}
137
+
138
+ new_messages.append(new_message)
139
+
140
+ return new_messages
141
+
142
+ def postprocess_response(self, response:Union[str, Generator[str, None, None]]) -> Union[str, Generator[Dict[str,str], None, None]]:
143
+ """
144
+ If input is a generator, tag contents in <think> and </think> as {"type": "reasoning", "data": <content>},
145
+ and the rest as {"type": "response", "data": <content>}.
146
+ If input is a string, drop contents in <think> and </think>.
147
+
148
+ Parameters:
149
+ ----------
150
+ response : Union[str, Generator[str, None, None]]
151
+ the LLM response. Can be a string or a generator.
152
+
153
+ Returns:
154
+ -------
155
+ response : Union[str, Generator[str, None, None]]
156
+ the postprocessed LLM response.
157
+ if input is a generator, the output will be a generator {"type": <reasoning or response>, "data": <content>}.
158
+ """
159
+ if isinstance(response, str):
160
+ return re.sub(r"<think>.*?</think>\s*", "", response, flags=re.DOTALL).strip()
161
+
162
+ if isinstance(response, Generator):
163
+ def _process_stream():
164
+ think_flag = False
165
+ buffer = ""
166
+ for chunk in response:
167
+ if isinstance(chunk, str):
168
+ buffer += chunk
169
+ # switch between reasoning and response
170
+ if "<think>" in buffer:
171
+ think_flag = True
172
+ buffer = buffer.replace("<think>", "")
173
+ elif "</think>" in buffer:
174
+ think_flag = False
175
+ buffer = buffer.replace("</think>", "")
176
+
177
+ # if chunk is in thinking block, tag it as reasoning; else tag it as response
178
+ if chunk not in ["<think>", "</think>"]:
179
+ if think_flag:
180
+ yield {"type": "reasoning", "data": chunk}
181
+ else:
182
+ yield {"type": "response", "data": chunk}
183
+
184
+ return _process_stream()
185
+
186
+
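A sketch of consuming the tagged stream that Qwen3LLMConfig produces; the engine choice and model tag are assumptions.

engine = OllamaInferenceEngine(model_name="qwen3:8b",
                               config=Qwen3LLMConfig(thinking_mode=True))
messages = [{"role": "user", "content": "2 + 2 = ?"}]
for chunk in engine.chat(messages, stream=True):
    if chunk["type"] == "reasoning":
        pass  # <think> content; hide it, log it, or display it separately
    else:
        print(chunk["data"], end="", flush=True)  # final answer tokens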
187
+ class OpenAIReasoningLLMConfig(LLMConfig):
188
+ def __init__(self, reasoning_effort:str="low", **kwargs):
189
+ """
190
+ The OpenAI "o" series configuration.
191
+ 1. The reasoning effort is set to "low" by default.
192
+ 2. The temperature parameter is not supported and will be ignored.
193
+ 3. The system prompt is not supported and will be concatenated to the next user prompt.
194
+
195
+ Parameters:
196
+ ----------
197
+ reasoning_effort : str, Optional
198
+ the reasoning effort. Must be one of {"low", "medium", "high"}. Default is "low".
199
+ """
200
+ super().__init__(**kwargs)
201
+ if reasoning_effort not in ["low", "medium", "high"]:
202
+ raise ValueError("reasoning_effort must be one of {'low', 'medium', 'high'}.")
203
+
204
+ self.reasoning_effort = reasoning_effort
205
+ self.params["reasoning_effort"] = self.reasoning_effort
206
+
207
+ if "temperature" in self.params:
208
+ warnings.warn("Reasoning models do not support temperature parameter. Will be ignored.", UserWarning)
209
+ self.params.pop("temperature")
210
+
211
+ def preprocess_messages(self, messages:List[Dict[str,str]]) -> List[Dict[str,str]]:
212
+ """
213
+ Concatenate system prompts to the next user prompt.
214
+
215
+ Parameters:
216
+ ----------
217
+ messages : List[Dict[str,str]]
218
+ a list of dict with role and content. role must be one of {"system", "user", "assistant"}
219
+
220
+ Returns:
221
+ -------
222
+ messages : List[Dict[str,str]]
223
+ a list of dict with role and content. role must be one of {"system", "user", "assistant"}
224
+ """
225
+ system_prompt_holder = ""
226
+ new_messages = []
227
+ for i, message in enumerate(messages):
228
+ # if system prompt, store it in system_prompt_holder
229
+ if message['role'] == 'system':
230
+ system_prompt_holder = message['content']
231
+ # if user prompt, concatenate it with system_prompt_holder
232
+ elif message['role'] == 'user':
233
+ if system_prompt_holder:
234
+ new_message = {'role': message['role'], 'content': f"{system_prompt_holder} {message['content']}"}
235
+ system_prompt_holder = ""
236
+ else:
237
+ new_message = {'role': message['role'], 'content': message['content']}
238
+
239
+ new_messages.append(new_message)
240
+ # if assistant/other prompt, do nothing
241
+ else:
242
+ new_message = {'role': message['role'], 'content': message['content']}
243
+ new_messages.append(new_message)
244
+
245
+ return new_messages
246
+
247
+ def postprocess_response(self, response:Union[str, Generator[str, None, None]]) -> Union[str, Generator[Dict[str, str], None, None]]:
248
+ """
249
+ This method postprocesses the LLM response after it is generated.
250
+
251
+ Parameters:
252
+ ----------
253
+ response : Union[str, Generator[str, None, None]]
254
+ the LLM response. Can be a string or a generator.
255
+
256
+ Returns: Union[str, Generator[Dict[str, str], None, None]]
257
+ the postprocessed LLM response.
258
+ if input is a generator, the output will be a generator {"type": "response", "data": <content>}.
259
+ """
260
+ if isinstance(response, str):
261
+ return response
262
+
263
+ def _process_stream():
264
+ for chunk in response:
265
+ yield {"type": "response", "data": chunk}
266
+
267
+ return _process_stream()
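A sketch of pairing this config with the OpenAIInferenceEngine defined later in this file; the model name is illustrative, and the OpenAI client is assumed to pick up OPENAI_API_KEY from the environment.

engine = OpenAIInferenceEngine(model="o3-mini",
                               config=OpenAIReasoningLLMConfig(reasoning_effort="medium"))
# The system turn below is folded into the next user turn by preprocess_messages().
answer = engine.chat([
    {"role": "system", "content": "You extract diagnoses from clinical notes."},
    {"role": "user", "content": "Note: patient reports chest pain ..."},
])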
5
268
 
6
269
 
7
270
  class InferenceEngine:
8
271
  @abc.abstractmethod
9
- def __init__(self):
272
+ def __init__(self, config:LLMConfig, **kwrs):
10
273
  """
11
274
  This is an abstract class to provide interfaces for LLM inference engines.
12
275
  Child classes that inherit this class can be used in extractors. Must implement the chat() method.
276
+
277
+ Parameters:
278
+ ----------
279
+ config : LLMConfig
280
+ the LLM configuration. Must be a child class of LLMConfig.
13
281
  """
14
282
  return NotImplemented
15
283
 
16
284
 
17
285
  @abc.abstractmethod
18
- def chat(self, messages:List[Dict[str,str]], max_new_tokens:int=2048, temperature:float=0.0, stream:bool=False, **kwrs) -> str:
286
+ def chat(self, messages:List[Dict[str,str]],
287
+ verbose:bool=False, stream:bool=False) -> Union[str, Generator[Dict[str, str], None, None]]:
19
288
  """
20
289
  This method inputs chat messages and outputs LLM generated text.
21
290
 
@@ -23,18 +292,25 @@ class InferenceEngine:
23
292
  ----------
24
293
  messages : List[Dict[str,str]]
25
294
  a list of dict with role and content. role must be one of {"system", "user", "assistant"}
26
- max_new_tokens : str, Optional
27
- the max number of new tokens LLM can generate.
28
- temperature : float, Optional
29
- the temperature for token sampling.
295
+ verbose : bool, Optional
296
+ if True, LLM generated text will be printed in terminal in real-time.
30
297
  stream : bool, Optional
31
- if True, LLM generated text will be printed in terminal in real-time.
298
+ if True, returns a generator that yields the output in real-time.
299
+ """
300
+ return NotImplemented
301
+
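In practice, the reworked chat() interface has three return modes; a sketch, assuming an engine and messages constructed as elsewhere in this file:

text = engine.chat(messages)                       # plain string
text = engine.chat(messages, verbose=True)         # plain string, also printed while generating
for chunk in engine.chat(messages, stream=True):   # generator of {"type": ..., "data": ...} dicts
    print(chunk["data"], end="", flush=True)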
302
+ def _format_config(self) -> Dict[str, Any]:
303
+ """
304
+ This method formats the LLM configuration with the correct keys for the inference engine.
305
+
306
+ Returns : Dict[str, Any]
307
+ the config parameters.
32
308
  """
33
309
  return NotImplemented
34
310
 
35
311
 
36
312
  class LlamaCppInferenceEngine(InferenceEngine):
37
- def __init__(self, repo_id:str, gguf_filename:str, n_ctx:int=4096, n_gpu_layers:int=-1, **kwrs):
313
+ def __init__(self, repo_id:str, gguf_filename:str, n_ctx:int=4096, n_gpu_layers:int=-1, config:LLMConfig=None, **kwrs):
38
314
  """
39
315
  The Llama.cpp inference engine.
40
316
 
@@ -49,12 +325,16 @@ class LlamaCppInferenceEngine(InferenceEngine):
49
325
  context length that LLM will evaluate.
50
326
  n_gpu_layers : int, Optional
51
327
  number of layers to offload to GPU. Default is all layers (-1).
328
+ config : LLMConfig
329
+ the LLM configuration.
52
330
  """
53
331
  from llama_cpp import Llama
54
332
  self.repo_id = repo_id
55
333
  self.gguf_filename = gguf_filename
56
334
  self.n_ctx = n_ctx
57
335
  self.n_gpu_layers = n_gpu_layers
336
+ self.config = config if config else BasicLLMConfig()
337
+ self.formatted_params = self._format_config()
58
338
 
59
339
  self.model = Llama.from_pretrained(
60
340
  repo_id=self.repo_id,
@@ -70,8 +350,18 @@ class LlamaCppInferenceEngine(InferenceEngine):
70
350
  """
71
351
  del self.model
72
352
 
353
+ def _format_config(self) -> Dict[str, Any]:
354
+ """
355
+ This method formats the LLM configuration with the correct keys for the inference engine.
356
+ """
357
+ formatted_params = self.config.params.copy()
358
+ if "max_new_tokens" in formatted_params:
359
+ formatted_params["max_tokens"] = formatted_params["max_new_tokens"]
360
+ formatted_params.pop("max_new_tokens")
361
+
362
+ return formatted_params
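For reference, each engine's _format_config() renames the generic max_new_tokens key to its backend-specific equivalent; a sketch with assumed values:

cfg = BasicLLMConfig(max_new_tokens=512, temperature=0.2)
# llama.cpp / HuggingFace Hub / LiteLLM -> {"max_tokens": 512, "temperature": 0.2}
# Ollama                                -> {"num_predict": 512, "temperature": 0.2}
# OpenAI / Azure OpenAI                 -> {"max_completion_tokens": 512, "temperature": 0.2}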
73
363
 
74
- def chat(self, messages:List[Dict[str,str]], max_new_tokens:int=2048, temperature:float=0.0, stream:bool=False, **kwrs) -> str:
364
+ def chat(self, messages:List[Dict[str,str]], verbose:bool=False) -> str:
75
365
  """
76
366
  This method inputs chat messages and outputs LLM generated text.
77
367
 
@@ -79,22 +369,18 @@ class LlamaCppInferenceEngine(InferenceEngine):
79
369
  ----------
80
370
  messages : List[Dict[str,str]]
81
371
  a list of dict with role and content. role must be one of {"system", "user", "assistant"}
82
- max_new_tokens : str, Optional
83
- the max number of new tokens LLM can generate.
84
- temperature : float, Optional
85
- the temperature for token sampling.
86
- stream : bool, Optional
372
+ verbose : bool, Optional
87
373
  if True, LLM generated text will be printed in terminal in real-time.
88
374
  """
375
+ processed_messages = self.config.preprocess_messages(messages)
376
+
89
377
  response = self.model.create_chat_completion(
90
- messages=messages,
91
- max_tokens=max_new_tokens,
92
- temperature=temperature,
93
- stream=stream,
94
- **kwrs
378
+ messages=processed_messages,
379
+ stream=verbose,
380
+ **self.formatted_params
95
381
  )
96
382
 
97
- if stream:
383
+ if verbose:
98
384
  res = ''
99
385
  for chunk in response:
100
386
  out_dict = chunk['choices'][0]['delta']
@@ -102,16 +388,14 @@ class LlamaCppInferenceEngine(InferenceEngine):
102
388
  res += out_dict['content']
103
389
  print(out_dict['content'], end='', flush=True)
104
390
  print('\n')
105
- return res
391
+ return self.config.postprocess_response(res)
106
392
 
107
- return response['choices'][0]['message']['content']
108
-
109
-
110
-
393
+ res = response['choices'][0]['message']['content']
394
+ return self.config.postprocess_response(res)
111
395
 
112
396
 
113
397
  class OllamaInferenceEngine(InferenceEngine):
114
- def __init__(self, model_name:str, num_ctx:int=4096, keep_alive:int=300, **kwrs):
398
+ def __init__(self, model_name:str, num_ctx:int=4096, keep_alive:int=300, config:LLMConfig=None, **kwrs):
115
399
  """
116
400
  The Ollama inference engine.
117
401
 
@@ -123,6 +407,8 @@ class OllamaInferenceEngine(InferenceEngine):
123
407
  context length that LLM will evaluate.
124
408
  keep_alive : int, Optional
125
409
  seconds to hold the LLM after the last API call.
410
+ config : LLMConfig
411
+ the LLM configuration.
126
412
  """
127
413
  if importlib.util.find_spec("ollama") is None:
128
414
  raise ImportError("ollama-python not found. Please install ollama-python (```pip install ollama```).")
@@ -133,69 +419,144 @@ class OllamaInferenceEngine(InferenceEngine):
133
419
  self.model_name = model_name
134
420
  self.num_ctx = num_ctx
135
421
  self.keep_alive = keep_alive
422
+ self.config = config if config else BasicLLMConfig()
423
+ self.formatted_params = self._format_config()
424
+
425
+ def _format_config(self) -> Dict[str, Any]:
426
+ """
427
+ This method formats the LLM configuration with the correct keys for the inference engine.
428
+ """
429
+ formatted_params = self.config.params.copy()
430
+ if "max_new_tokens" in formatted_params:
431
+ formatted_params["num_predict"] = formatted_params["max_new_tokens"]
432
+ formatted_params.pop("max_new_tokens")
136
433
 
137
- def chat(self, messages:List[Dict[str,str]], max_new_tokens:int=2048, temperature:float=0.0, stream:bool=False, **kwrs) -> str:
434
+ return formatted_params
435
+
436
+ def chat(self, messages:List[Dict[str,str]],
437
+ verbose:bool=False, stream:bool=False) -> Union[str, Generator[Dict[str, str], None, None]]:
138
438
  """
139
- This method inputs chat messages and outputs LLM generated text.
439
+ This method inputs chat messages and outputs VLM generated text.
140
440
 
141
441
  Parameters:
142
442
  ----------
143
443
  messages : List[Dict[str,str]]
144
444
  a list of dict with role and content. role must be one of {"system", "user", "assistant"}
145
- max_new_tokens : str, Optional
146
- the max number of new tokens LLM can generate.
147
- temperature : float, Optional
148
- the temperature for token sampling.
445
+ verbose : bool, Optional
446
+ if True, LLM generated text will be printed in terminal in real-time.
149
447
  stream : bool, Optional
150
- if True, LLM generated text will be printed in terminal in real-time.
448
+ if True, returns a generator that yields the output in real-time.
151
449
  """
152
- response = self.client.chat(
450
+ processed_messages = self.config.preprocess_messages(messages)
451
+
452
+ options={'num_ctx': self.num_ctx, **self.formatted_params}
453
+ if stream:
454
+ def _stream_generator():
455
+ response_stream = self.client.chat(
456
+ model=self.model_name,
457
+ messages=processed_messages,
458
+ options=options,
459
+ stream=True,
460
+ keep_alive=self.keep_alive
461
+ )
462
+ for chunk in response_stream:
463
+ content_chunk = chunk.get('message', {}).get('content')
464
+ if content_chunk:
465
+ yield content_chunk
466
+
467
+ return self.config.postprocess_response(_stream_generator())
468
+
469
+ elif verbose:
470
+ response = self.client.chat(
153
471
  model=self.model_name,
154
- messages=messages,
155
- options={'temperature':temperature, 'num_ctx': self.num_ctx, 'num_predict': max_new_tokens, **kwrs},
156
- stream=stream,
472
+ messages=processed_messages,
473
+ options=options,
474
+ stream=True,
157
475
  keep_alive=self.keep_alive
158
476
  )
159
- if stream:
477
+
160
478
  res = ''
161
479
  for chunk in response:
162
- res += chunk['message']['content']
163
- print(chunk['message']['content'], end='', flush=True)
480
+ content_chunk = chunk.get('message', {}).get('content')
481
+ print(content_chunk, end='', flush=True)
482
+ res += content_chunk
164
483
  print('\n')
165
- return res
484
+ return self.config.postprocess_response(res)
485
+
486
+ else:
487
+ response = self.client.chat(
488
+ model=self.model_name,
489
+ messages=processed_messages,
490
+ options=options,
491
+ stream=False,
492
+ keep_alive=self.keep_alive
493
+ )
494
+ res = response.get('message', {}).get('content')
495
+ return self.config.postprocess_response(res)
166
496
 
167
- return response['message']['content']
168
-
169
497
 
170
- async def chat_async(self, messages:List[Dict[str,str]], max_new_tokens:int=2048, temperature:float=0.0, **kwrs) -> str:
498
+ async def chat_async(self, messages:List[Dict[str,str]]) -> str:
171
499
  """
172
500
  Async version of chat method. Streaming is not supported.
173
501
  """
502
+ processed_messages = self.config.preprocess_messages(messages)
503
+
174
504
  response = await self.async_client.chat(
175
505
  model=self.model_name,
176
- messages=messages,
177
- options={'temperature':temperature, 'num_ctx': self.num_ctx, 'num_predict': max_new_tokens, **kwrs},
506
+ messages=processed_messages,
507
+ options={'num_ctx': self.num_ctx, **self.formatted_params},
178
508
  stream=False,
179
509
  keep_alive=self.keep_alive
180
510
  )
181
511
 
182
- return response['message']['content']
512
+ res = response['message']['content']
513
+ return self.config.postprocess_response(res)
183
514
 
184
515
 
185
516
  class HuggingFaceHubInferenceEngine(InferenceEngine):
186
- def __init__(self, model:str=None, token:Union[str, bool]=None, base_url:str=None, api_key:str=None, **kwrs):
517
+ def __init__(self, model:str=None, token:Union[str, bool]=None, base_url:str=None, api_key:str=None, config:LLMConfig=None, **kwrs):
187
518
  """
188
519
  The Huggingface_hub InferenceClient inference engine.
189
520
  For parameters and documentation, refer to https://huggingface.co/docs/huggingface_hub/en/package_reference/inference_client
521
+
522
+ Parameters:
523
+ ----------
524
+ model : str
525
+ the model name exactly as shown in Huggingface repo
526
+ token : str, Optional
527
+ the Huggingface token. If None, will use the token in os.environ['HF_TOKEN'].
528
+ base_url : str, Optional
529
+ the base url for the LLM server. If None, will use the default Huggingface Hub URL.
530
+ api_key : str, Optional
531
+ the API key for the LLM server.
532
+ config : LLMConfig
533
+ the LLM configuration.
190
534
  """
191
535
  if importlib.util.find_spec("huggingface_hub") is None:
192
536
  raise ImportError("huggingface-hub not found. Please install huggingface-hub (```pip install huggingface-hub```).")
193
537
 
194
538
  from huggingface_hub import InferenceClient, AsyncInferenceClient
539
+ self.model = model
540
+ self.base_url = base_url
195
541
  self.client = InferenceClient(model=model, token=token, base_url=base_url, api_key=api_key, **kwrs)
196
542
  self.client_async = AsyncInferenceClient(model=model, token=token, base_url=base_url, api_key=api_key, **kwrs)
543
+ self.config = config if config else BasicLLMConfig()
544
+ self.formatted_params = self._format_config()
197
545
 
198
- def chat(self, messages:List[Dict[str,str]], max_new_tokens:int=2048, temperature:float=0.0, stream:bool=False, **kwrs) -> str:
546
+ def _format_config(self) -> Dict[str, Any]:
547
+ """
548
+ This method formats the LLM configuration with the correct keys for the inference engine.
549
+ """
550
+ formatted_params = self.config.params.copy()
551
+ if "max_new_tokens" in formatted_params:
552
+ formatted_params["max_tokens"] = formatted_params["max_new_tokens"]
553
+ formatted_params.pop("max_new_tokens")
554
+
555
+ return formatted_params
556
+
557
+
558
+ def chat(self, messages:List[Dict[str,str]],
559
+ verbose:bool=False, stream:bool=False) -> Union[str, Generator[Dict[str, str], None, None]]:
199
560
  """
200
561
  This method inputs chat messages and outputs LLM generated text.
201
562
 
@@ -203,47 +564,69 @@ class HuggingFaceHubInferenceEngine(InferenceEngine):
203
564
  ----------
204
565
  messages : List[Dict[str,str]]
205
566
  a list of dict with role and content. role must be one of {"system", "user", "assistant"}
206
- max_new_tokens : str, Optional
207
- the max number of new tokens LLM can generate.
208
- temperature : float, Optional
209
- the temperature for token sampling.
567
+ verbose : bool, Optional
568
+ if True, LLM generated text will be printed in terminal in real-time.
210
569
  stream : bool, Optional
211
- if True, LLM generated text will be printed in terminal in real-time.
570
+ if True, returns a generator that yields the output in real-time.
212
571
  """
213
- response = self.client.chat.completions.create(
214
- messages=messages,
215
- max_tokens=max_new_tokens,
216
- temperature=temperature,
217
- stream=stream,
218
- **kwrs
219
- )
220
-
572
+ processed_messages = self.config.preprocess_messages(messages)
573
+
221
574
  if stream:
575
+ def _stream_generator():
576
+ response_stream = self.client.chat.completions.create(
577
+ messages=processed_messages,
578
+ stream=True,
579
+ **self.formatted_params
580
+ )
581
+ for chunk in response_stream:
582
+ content_chunk = chunk.get('choices')[0].get('delta').get('content')
583
+ if content_chunk:
584
+ yield content_chunk
585
+
586
+ return self.config.postprocess_response(_stream_generator())
587
+
588
+ elif verbose:
589
+ response = self.client.chat.completions.create(
590
+ messages=processed_messages,
591
+ stream=True,
592
+ **self.formatted_params
593
+ )
594
+
222
595
  res = ''
223
596
  for chunk in response:
224
- res += chunk.choices[0].delta.content
225
- print(chunk.choices[0].delta.content, end='', flush=True)
226
- return res
597
+ content_chunk = chunk.get('choices')[0].get('delta').get('content')
598
+ if content_chunk:
599
+ res += content_chunk
600
+ print(content_chunk, end='', flush=True)
601
+ return self.config.postprocess_response(res)
227
602
 
228
- return response.choices[0].message.content
603
+ else:
604
+ response = self.client.chat.completions.create(
605
+ messages=processed_messages,
606
+ stream=False,
607
+ **self.formatted_params
608
+ )
609
+ res = response.choices[0].message.content
610
+ return self.config.postprocess_response(res)
229
611
 
230
- async def chat_async(self, messages:List[Dict[str,str]], max_new_tokens:int=2048, temperature:float=0.0, **kwrs) -> str:
612
+ async def chat_async(self, messages:List[Dict[str,str]]) -> str:
231
613
  """
232
614
  Async version of chat method. Streaming is not supported.
233
615
  """
616
+ processed_messages = self.config.preprocess_messages(messages)
617
+
234
618
  response = await self.client_async.chat.completions.create(
235
- messages=messages,
236
- max_tokens=max_new_tokens,
237
- temperature=temperature,
619
+ messages=processed_messages,
238
620
  stream=False,
239
- **kwrs
621
+ **self.formatted_params
240
622
  )
241
623
 
242
- return response.choices[0].message.content
624
+ res = response.choices[0].message.content
625
+ return self.config.postprocess_response(res)
243
626
 
244
627
 
245
628
  class OpenAIInferenceEngine(InferenceEngine):
246
- def __init__(self, model:str, reasoning_model:bool=False, **kwrs):
629
+ def __init__(self, model:str, config:LLMConfig=None, **kwrs):
247
630
  """
248
631
  The OpenAI API inference engine. Supports OpenAI models and OpenAI compatible servers:
249
632
  - vLLM OpenAI compatible server (https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html)
@@ -255,8 +638,6 @@ class OpenAIInferenceEngine(InferenceEngine):
255
638
  ----------
256
639
  model_name : str
257
640
  model name as described in https://platform.openai.com/docs/models
258
- reasoning_model : bool, Optional
259
- indicator for OpenAI reasoning models ("o" series).
260
641
  """
261
642
  if importlib.util.find_spec("openai") is None:
262
643
  raise ImportError("OpenAI Python API library not found. Please install OpanAI (```pip install openai```).")
@@ -265,9 +646,21 @@ class OpenAIInferenceEngine(InferenceEngine):
265
646
  self.client = OpenAI(**kwrs)
266
647
  self.async_client = AsyncOpenAI(**kwrs)
267
648
  self.model = model
268
- self.reasoning_model = reasoning_model
649
+ self.config = config if config else BasicLLMConfig()
650
+ self.formatted_params = self._format_config()
269
651
 
270
- def chat(self, messages:List[Dict[str,str]], max_new_tokens:int=2048, temperature:float=0.0, stream:bool=False, **kwrs) -> str:
652
+ def _format_config(self) -> Dict[str, Any]:
653
+ """
654
+ This method formats the LLM configuration with the correct keys for the inference engine.
655
+ """
656
+ formatted_params = self.config.params.copy()
657
+ if "max_new_tokens" in formatted_params:
658
+ formatted_params["max_completion_tokens"] = formatted_params["max_new_tokens"]
659
+ formatted_params.pop("max_new_tokens")
660
+
661
+ return formatted_params
662
+
663
+ def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False) -> Union[str, Generator[Dict[str, str], None, None]]:
271
664
  """
272
665
  This method inputs chat messages and outputs LLM generated text.
273
666
 
@@ -275,36 +668,37 @@ class OpenAIInferenceEngine(InferenceEngine):
275
668
  ----------
276
669
  messages : List[Dict[str,str]]
277
670
  a list of dict with role and content. role must be one of {"system", "user", "assistant"}
278
- max_new_tokens : str, Optional
279
- the max number of new tokens LLM can generate.
280
- temperature : float, Optional
281
- the temperature for token sampling.
671
+ verbose : bool, Optional
672
+ if True, LLM generated text will be printed in terminal in real-time.
282
673
  stream : bool, Optional
283
- if True, LLM generated text will be printed in terminal in real-time.
674
+ if True, returns a generator that yields the output in real-time.
284
675
  """
285
- if self.reasoning_model:
286
- if temperature != 0.0:
287
- warnings.warn("Reasoning models do not support temperature parameter. Will be ignored.", UserWarning)
288
-
289
- response = self.client.chat.completions.create(
290
- model=self.model,
291
- messages=messages,
292
- max_completion_tokens=max_new_tokens,
293
- stream=stream,
294
- **kwrs
295
- )
676
+ processed_messages = self.config.preprocess_messages(messages)
296
677
 
297
- else:
678
+ if stream:
679
+ def _stream_generator():
680
+ response_stream = self.client.chat.completions.create(
681
+ model=self.model,
682
+ messages=processed_messages,
683
+ stream=True,
684
+ **self.formatted_params
685
+ )
686
+ for chunk in response_stream:
687
+ if len(chunk.choices) > 0:
688
+ if chunk.choices[0].delta.content is not None:
689
+ yield chunk.choices[0].delta.content
690
+ if chunk.choices[0].finish_reason == "length":
691
+ warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
692
+
693
+ return self.config.postprocess_response(_stream_generator())
694
+
695
+ elif verbose:
298
696
  response = self.client.chat.completions.create(
299
697
  model=self.model,
300
- messages=messages,
301
- max_tokens=max_new_tokens,
302
- temperature=temperature,
303
- stream=stream,
304
- **kwrs
698
+ messages=processed_messages,
699
+ stream=True,
700
+ **self.formatted_params
305
701
  )
306
-
307
- if stream:
308
702
  res = ''
309
703
  for chunk in response:
310
704
  if len(chunk.choices) > 0:
@@ -313,53 +707,42 @@ class OpenAIInferenceEngine(InferenceEngine):
313
707
  print(chunk.choices[0].delta.content, end="", flush=True)
314
708
  if chunk.choices[0].finish_reason == "length":
315
709
  warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
316
- if self.reasoning_model:
317
- warnings.warn("max_new_tokens includes reasoning tokens and output tokens.", UserWarning)
318
- return res
319
-
320
- if response.choices[0].finish_reason == "length":
321
- warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
322
- if self.reasoning_model:
323
- warnings.warn("max_new_tokens includes reasoning tokens and output tokens.", UserWarning)
324
-
325
- return response.choices[0].message.content
326
-
327
710
 
328
- async def chat_async(self, messages:List[Dict[str,str]], max_new_tokens:int=2048, temperature:float=0.0, **kwrs) -> str:
329
- """
330
- Async version of chat method. Streaming is not supported.
331
- """
332
- if self.reasoning_model:
333
- if temperature != 0.0:
334
- warnings.warn("Reasoning models do not support temperature parameter. Will be ignored.", UserWarning)
335
-
336
- response = await self.async_client.chat.completions.create(
337
- model=self.model,
338
- messages=messages,
339
- max_completion_tokens=max_new_tokens,
340
- stream=False,
341
- **kwrs
342
- )
711
+ print('\n')
712
+ return self.config.postprocess_response(res)
343
713
  else:
344
- response = await self.async_client.chat.completions.create(
714
+ response = self.client.chat.completions.create(
345
715
  model=self.model,
346
- messages=messages,
347
- max_tokens=max_new_tokens,
348
- temperature=temperature,
716
+ messages=processed_messages,
349
717
  stream=False,
350
- **kwrs
718
+ **self.formatted_params
351
719
  )
720
+ res = response.choices[0].message.content
721
+ return self.config.postprocess_response(res)
722
+
723
+
724
+ async def chat_async(self, messages:List[Dict[str,str]]) -> str:
725
+ """
726
+ Async version of chat method. Streaming is not supported.
727
+ """
728
+ processed_messages = self.config.preprocess_messages(messages)
729
+
730
+ response = await self.async_client.chat.completions.create(
731
+ model=self.model,
732
+ messages=processed_messages,
733
+ stream=False,
734
+ **self.formatted_params
735
+ )
352
736
 
353
737
  if response.choices[0].finish_reason == "length":
354
738
  warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
355
- if self.reasoning_model:
356
- warnings.warn("max_new_tokens includes reasoning tokens and output tokens.", UserWarning)
357
739
 
358
- return response.choices[0].message.content
740
+ res = response.choices[0].message.content
741
+ return self.config.postprocess_response(res)
359
742
 
360
743
 
361
- class AzureOpenAIInferenceEngine(InferenceEngine):
362
- def __init__(self, model:str, api_version:str, reasoning_model:bool=False, **kwrs):
744
+ class AzureOpenAIInferenceEngine(OpenAIInferenceEngine):
745
+ def __init__(self, model:str, api_version:str, config:LLMConfig=None, **kwrs):
363
746
  """
364
747
  The Azure OpenAI API inference engine.
365
748
  For parameters and documentation, refer to
@@ -372,8 +755,8 @@ class AzureOpenAIInferenceEngine(InferenceEngine):
372
755
  model name as described in https://platform.openai.com/docs/models
373
756
  api_version : str
374
757
  the Azure OpenAI API version
375
- reasoning_model : bool, Optional
376
- indicator for OpenAI reasoning models ("o" series).
758
+ config : LLMConfig
759
+ the LLM configuration.
377
760
  """
378
761
  if importlib.util.find_spec("openai") is None:
379
762
  raise ImportError("OpenAI Python API library not found. Please install OpanAI (```pip install openai```).")
@@ -385,101 +768,12 @@ class AzureOpenAIInferenceEngine(InferenceEngine):
385
768
  **kwrs)
386
769
  self.async_client = AsyncAzureOpenAI(api_version=self.api_version,
387
770
  **kwrs)
388
- self.reasoning_model = reasoning_model
389
-
390
- def chat(self, messages:List[Dict[str,str]], max_new_tokens:int=2048, temperature:float=0.0, stream:bool=False, **kwrs) -> str:
391
- """
392
- This method inputs chat messages and outputs LLM generated text.
393
-
394
- Parameters:
395
- ----------
396
- messages : List[Dict[str,str]]
397
- a list of dict with role and content. role must be one of {"system", "user", "assistant"}
398
- max_new_tokens : str, Optional
399
- the max number of new tokens LLM can generate.
400
- temperature : float, Optional
401
- the temperature for token sampling.
402
- stream : bool, Optional
403
- if True, LLM generated text will be printed in terminal in real-time.
404
- """
405
- if self.reasoning_model:
406
- if temperature != 0.0:
407
- warnings.warn("Reasoning models do not support temperature parameter. Will be ignored.", UserWarning)
408
-
409
- response = self.client.chat.completions.create(
410
- model=self.model,
411
- messages=messages,
412
- max_completion_tokens=max_new_tokens,
413
- stream=stream,
414
- **kwrs
415
- )
416
-
417
- else:
418
- response = self.client.chat.completions.create(
419
- model=self.model,
420
- messages=messages,
421
- max_tokens=max_new_tokens,
422
- temperature=temperature,
423
- stream=stream,
424
- **kwrs
425
- )
426
-
427
- if stream:
428
- res = ''
429
- for chunk in response:
430
- if len(chunk.choices) > 0:
431
- if chunk.choices[0].delta.content is not None:
432
- res += chunk.choices[0].delta.content
433
- print(chunk.choices[0].delta.content, end="", flush=True)
434
- if chunk.choices[0].finish_reason == "length":
435
- warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
436
- if self.reasoning_model:
437
- warnings.warn("max_new_tokens includes reasoning tokens and output tokens.", UserWarning)
438
- return res
439
-
440
- if response.choices[0].finish_reason == "length":
441
- warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
442
- if self.reasoning_model:
443
- warnings.warn("max_new_tokens includes reasoning tokens and output tokens.", UserWarning)
444
-
445
- return response.choices[0].message.content
446
-
447
-
448
- async def chat_async(self, messages:List[Dict[str,str]], max_new_tokens:int=2048, temperature:float=0.0, **kwrs) -> str:
449
- """
450
- Async version of chat method. Streaming is not supported.
451
- """
452
- if self.reasoning_model:
453
- if temperature != 0.0:
454
- warnings.warn("Reasoning models do not support temperature parameter. Will be ignored.", UserWarning)
455
-
456
- response = await self.async_client.chat.completions.create(
457
- model=self.model,
458
- messages=messages,
459
- max_completion_tokens=max_new_tokens,
460
- stream=False,
461
- **kwrs
462
- )
463
- else:
464
- response = await self.async_client.chat.completions.create(
465
- model=self.model,
466
- messages=messages,
467
- max_tokens=max_new_tokens,
468
- temperature=temperature,
469
- stream=False,
470
- **kwrs
471
- )
472
-
473
- if response.choices[0].finish_reason == "length":
474
- warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
475
- if self.reasoning_model:
476
- warnings.warn("max_new_tokens includes reasoning tokens and output tokens.", UserWarning)
477
-
478
- return response.choices[0].message.content
771
+ self.config = config if config else BasicLLMConfig()
772
+ self.formatted_params = self._format_config()
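A construction sketch for the refactored Azure engine; azure_endpoint and api_key are standard AzureOpenAI client arguments forwarded through **kwrs, and the deployment name and API version are assumptions.

engine = AzureOpenAIInferenceEngine(
    model="my-gpt-4o-deployment",            # Azure deployment name (assumed)
    api_version="2024-02-01",                # assumed API version
    azure_endpoint="https://<resource>.openai.azure.com",
    api_key="<key>",
    config=BasicLLMConfig(max_new_tokens=1024),
)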
479
773
 
480
774
 
481
775
  class LiteLLMInferenceEngine(InferenceEngine):
482
- def __init__(self, model:str=None, base_url:str=None, api_key:str=None):
776
+ def __init__(self, model:str=None, base_url:str=None, api_key:str=None, config:LLMConfig=None):
483
777
  """
484
778
  The LiteLLM inference engine.
485
779
  For parameters and documentation, refer to https://github.com/BerriAI/litellm?tab=readme-ov-file
@@ -492,6 +786,8 @@ class LiteLLMInferenceEngine(InferenceEngine):
492
786
  the base url for the LLM server
493
787
  api_key : str, Optional
494
788
  the API key for the LLM server
789
+ config : LLMConfig
790
+ the LLM configuration.
495
791
  """
496
792
  if importlib.util.find_spec("litellm") is None:
497
793
  raise ImportError("litellm not found. Please install litellm (```pip install litellm```).")
@@ -501,56 +797,98 @@ class LiteLLMInferenceEngine(InferenceEngine):
501
797
  self.model = model
502
798
  self.base_url = base_url
503
799
  self.api_key = api_key
800
+ self.config = config if config else BasicLLMConfig()
801
+ self.formatted_params = self._format_config()
504
802
 
505
- def chat(self, messages:List[Dict[str,str]], max_new_tokens:int=2048, temperature:float=0.0, stream:bool=False, **kwrs) -> str:
803
+ def _format_config(self) -> Dict[str, Any]:
804
+ """
805
+ This method formats the LLM configuration with the correct keys for the inference engine.
806
+ """
807
+ formatted_params = self.config.params.copy()
808
+ if "max_new_tokens" in formatted_params:
809
+ formatted_params["max_tokens"] = formatted_params["max_new_tokens"]
810
+ formatted_params.pop("max_new_tokens")
811
+
812
+ return formatted_params
813
+
814
+ def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False) -> Union[str, Generator[Dict[str, str], None, None]]:
506
815
  """
507
816
  This method inputs chat messages and outputs LLM generated text.
508
817
 
509
818
  Parameters:
510
819
  ----------
511
820
  messages : List[Dict[str,str]]
512
- a list of dict with role and content. role must be one of {"system", "user", "assistant"}
513
- max_new_tokens : str, Optional
514
- the max number of new tokens LLM can generate.
515
- temperature : float, Optional
516
- the temperature for token sampling.
821
+ a list of dict with role and content. role must be one of {"system", "user", "assistant"}
822
+ verbose : bool, Optional
823
+ if True, LLM generated text will be printed in terminal in real-time.
517
824
  stream : bool, Optional
518
- if True, LLM generated text will be printed in terminal in real-time.
825
+ if True, returns a generator that yields the output in real-time.
519
826
  """
520
- response = self.litellm.completion(
521
- model=self.model,
522
- messages=messages,
523
- max_tokens=max_new_tokens,
524
- temperature=temperature,
525
- stream=stream,
526
- base_url=self.base_url,
527
- api_key=self.api_key,
528
- **kwrs
529
- )
530
-
827
+ processed_messages = self.config.preprocess_messages(messages)
828
+
531
829
  if stream:
830
+ def _stream_generator():
831
+ response_stream = self.litellm.completion(
832
+ model=self.model,
833
+ messages=processed_messages,
834
+ stream=True,
835
+ base_url=self.base_url,
836
+ api_key=self.api_key,
837
+ **self.formatted_params
838
+ )
839
+
840
+ for chunk in response_stream:
841
+ chunk_content = chunk.get('choices')[0].get('delta').get('content')
842
+ if chunk_content:
843
+ yield chunk_content
844
+
845
+ return self.config.postprocess_response(_stream_generator())
846
+
847
+ elif verbose:
848
+ response = self.litellm.completion(
849
+ model=self.model,
850
+ messages=processed_messages,
851
+ stream=True,
852
+ base_url=self.base_url,
853
+ api_key=self.api_key,
854
+ **self.formatted_params
855
+ )
856
+
532
857
  res = ''
533
858
  for chunk in response:
534
- if chunk.choices[0].delta.content is not None:
535
- res += chunk.choices[0].delta.content
536
- print(chunk.choices[0].delta.content, end="", flush=True)
537
- return res
859
+ chunk_content = chunk.get('choices')[0].get('delta').get('content')
860
+ if chunk_content:
861
+ res += chunk_content
862
+ print(chunk_content, end='', flush=True)
863
+
864
+ return self.config.postprocess_response(res)
538
865
 
539
- return response.choices[0].message.content
866
+ else:
867
+ response = self.litellm.completion(
868
+ model=self.model,
869
+ messages=processed_messages,
870
+ stream=False,
871
+ base_url=self.base_url,
872
+ api_key=self.api_key,
873
+ **self.formatted_params
874
+ )
875
+ res = response.choices[0].message.content
876
+ return self.config.postprocess_response(res)
540
877
 
541
- async def chat_async(self, messages:List[Dict[str,str]], max_new_tokens:int=2048, temperature:float=0.0, **kwrs) -> str:
878
+ async def chat_async(self, messages:List[Dict[str,str]]) -> str:
542
879
  """
543
880
  Async version of chat method. Streaming is not supported.
544
881
  """
882
+ processed_messages = self.config.preprocess_messages(messages)
883
+
545
884
  response = await self.litellm.acompletion(
546
885
  model=self.model,
547
- messages=messages,
548
- max_tokens=max_new_tokens,
549
- temperature=temperature,
886
+ messages=processed_messages,
550
887
  stream=False,
551
888
  base_url=self.base_url,
552
889
  api_key=self.api_key,
553
- **kwrs
890
+ **self.formatted_params
554
891
  )
555
892
 
556
- return response.choices[0].message.content
893
+ res = response.get('choices')[0].get('message').get('content')
894
+ return self.config.postprocess_response(res)