llm-ie 1.0.0__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
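The headline change in 1.1.0, as the diff below shows, is a new LLMConfig hierarchy: sampling parameters such as max_new_tokens and temperature move from per-call chat() arguments into a config object passed to the engine constructor. The following is a minimal usage sketch inferred from the diff; the Ollama model name and the sample messages are placeholders, not taken from the package.

    # llm-ie 1.0.0: sampling parameters were passed to chat() on every call, e.g.
    #   engine = OllamaInferenceEngine(model_name="llama3.1:8b")
    #   answer = engine.chat(messages, max_new_tokens=2048, temperature=0.0)

    # llm-ie 1.1.0: parameters live in an LLMConfig given to the engine constructor.
    from llm_ie.engines import OllamaInferenceEngine, BasicLLMConfig

    config = BasicLLMConfig(max_new_tokens=2048, temperature=0.0)
    engine = OllamaInferenceEngine(model_name="llama3.1:8b", config=config)  # placeholder model name

    messages = [
        {"role": "system", "content": "You are an information extraction assistant."},
        {"role": "user", "content": "List the medications mentioned in the note."},
    ]
    # config.preprocess_messages() and config.postprocess_response() run inside chat()
    answer = engine.chat(messages)
    print(answer)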
llm_ie/engines.py CHANGED
@@ -1,22 +1,290 @@
  import abc
+ import re
  import warnings
- import importlib
- from typing import List, Dict, Union, Generator
+ import importlib.util
+ from typing import Any, Tuple, List, Dict, Union, Generator
+
+
+ class LLMConfig(abc.ABC):
+ def __init__(self, **kwargs):
+ """
+ This is an abstract class to provide interfaces for LLM configuration.
+ Children classes that inherts this class can be used in extrators and prompt editor.
+ Common LLM parameters: max_new_tokens, temperature, top_p, top_k, min_p.
+ """
+ self.params = kwargs.copy()
+
+
+ @abc.abstractmethod
+ def preprocess_messages(self, messages:List[Dict[str,str]]) -> List[Dict[str,str]]:
+ """
+ This method preprocesses the input messages before passing them to the LLM.
+
+ Parameters:
+ ----------
+ messages : List[Dict[str,str]]
+ a list of dict with role and content. role must be one of {"system", "user", "assistant"}
+
+ Returns:
+ -------
+ messages : List[Dict[str,str]]
+ a list of dict with role and content. role must be one of {"system", "user", "assistant"}
+ """
+ return NotImplemented
+
+ @abc.abstractmethod
+ def postprocess_response(self, response:Union[str, Generator[str, None, None]]) -> Union[str, Generator[str, None, None]]:
+ """
+ This method postprocesses the LLM response after it is generated.
+
+ Parameters:
+ ----------
+ response : Union[str, Generator[str, None, None]]
+ the LLM response. Can be a string or a generator.
+
+ Returns:
+ -------
+ response : str
+ the postprocessed LLM response
+ """
+ return NotImplemented
+
+
+ class BasicLLMConfig(LLMConfig):
+ def __init__(self, max_new_tokens:int=2048, temperature:float=0.0, **kwargs):
+ """
+ The basic LLM configuration for most non-reasoning models.
+ """
+ super().__init__(**kwargs)
+ self.max_new_tokens = max_new_tokens
+ self.temperature = temperature
+ self.params["max_new_tokens"] = self.max_new_tokens
+ self.params["temperature"] = self.temperature
+
+ def preprocess_messages(self, messages:List[Dict[str,str]]) -> List[Dict[str,str]]:
+ """
+ This method preprocesses the input messages before passing them to the LLM.
+
+ Parameters:
+ ----------
+ messages : List[Dict[str,str]]
+ a list of dict with role and content. role must be one of {"system", "user", "assistant"}
+
+ Returns:
+ -------
+ messages : List[Dict[str,str]]
+ a list of dict with role and content. role must be one of {"system", "user", "assistant"}
+ """
+ return messages
+
+ def postprocess_response(self, response:Union[str, Generator[str, None, None]]) -> Union[str, Generator[Dict[str, str], None, None]]:
+ """
+ This method postprocesses the LLM response after it is generated.
+
+ Parameters:
+ ----------
+ response : Union[str, Generator[str, None, None]]
+ the LLM response. Can be a string or a generator.
+
+ Returns: Union[str, Generator[Dict[str, str], None, None]]
+ the postprocessed LLM response.
+ if input is a generator, the output will be a generator {"data": <content>}.
+ """
+ if isinstance(response, str):
+ return response
+
+ def _process_stream():
+ for chunk in response:
+ yield {"type": "response", "data": chunk}
+
+ return _process_stream()
+
+ class Qwen3LLMConfig(LLMConfig):
+ def __init__(self, thinking_mode:bool=True, **kwargs):
+ """
+ The Qwen3 LLM configuration for reasoning models.
+
+ Parameters:
+ ----------
+ thinking_mode : bool, Optional
+ if True, a special token "/think" will be placed after each system and user prompt. Otherwise, "/no_think" will be placed.
+ """
+ super().__init__(**kwargs)
+ self.thinking_mode = thinking_mode
+
+ def preprocess_messages(self, messages:List[Dict[str,str]]) -> List[Dict[str,str]]:
+ """
+ Append a special token to the system and user prompts.
+ The token is "/think" if thinking_mode is True, otherwise "/no_think".
+
+ Parameters:
+ ----------
+ messages : List[Dict[str,str]]
+ a list of dict with role and content. role must be one of {"system", "user", "assistant"}
+
+ Returns:
+ -------
+ messages : List[Dict[str,str]]
+ a list of dict with role and content. role must be one of {"system", "user", "assistant"}
+ """
+ thinking_token = "/think" if self.thinking_mode else "/no_think"
+ new_messages = []
+ for message in messages:
+ if message['role'] in ['system', 'user']:
+ new_message = {'role': message['role'], 'content': f"{message['content']} {thinking_token}"}
+ else:
+ new_message = {'role': message['role'], 'content': message['content']}
+
+ new_messages.append(new_message)
+
+ return new_messages
+
+ def postprocess_response(self, response:Union[str, Generator[str, None, None]]) -> Union[str, Generator[Dict[str,str], None, None]]:
+ """
+ If input is a generator, tag contents in <think> and </think> as {"type": "reasoning", "data": <content>},
+ and the rest as {"type": "response", "data": <content>}.
+ If input is a string, drop contents in <think> and </think>.
+
+ Parameters:
+ ----------
+ response : Union[str, Generator[str, None, None]]
+ the LLM response. Can be a string or a generator.
+
+ Returns:
+ -------
+ response : Union[str, Generator[str, None, None]]
+ the postprocessed LLM response.
+ if input is a generator, the output will be a generator {"type": <reasoning or response>, "data": <content>}.
+ """
+ if isinstance(response, str):
+ return re.sub(r"<think>.*?</think>\s*", "", response, flags=re.DOTALL).strip()
+
+ if isinstance(response, Generator):
+ def _process_stream():
+ think_flag = False
+ buffer = ""
+ for chunk in response:
+ if isinstance(chunk, str):
+ buffer += chunk
+ # switch between reasoning and response
+ if "<think>" in buffer:
+ think_flag = True
+ buffer = buffer.replace("<think>", "")
+ elif "</think>" in buffer:
+ think_flag = False
+ buffer = buffer.replace("</think>", "")
+
+ # if chunk is in thinking block, tag it as reasoning; else tag it as response
+ if chunk not in ["<think>", "</think>"]:
+ if think_flag:
+ yield {"type": "reasoning", "data": chunk}
+ else:
+ yield {"type": "response", "data": chunk}
+
+ return _process_stream()
+
+
+ class OpenAIReasoningLLMConfig(LLMConfig):
+ def __init__(self, reasoning_effort:str="low", **kwargs):
+ """
+ The OpenAI "o" series configuration.
+ 1. The reasoning effort is set to "low" by default.
+ 2. The temperature parameter is not supported and will be ignored.
+ 3. The system prompt is not supported and will be concatenated to the next user prompt.
+
+ Parameters:
+ ----------
+ reasoning_effort : str, Optional
+ the reasoning effort. Must be one of {"low", "medium", "high"}. Default is "low".
+ """
+ super().__init__(**kwargs)
+ if reasoning_effort not in ["low", "medium", "high"]:
+ raise ValueError("reasoning_effort must be one of {'low', 'medium', 'high'}.")
+
+ self.reasoning_effort = reasoning_effort
+ self.params["reasoning_effort"] = self.reasoning_effort
+
+ if "temperature" in self.params:
+ warnings.warn("Reasoning models do not support temperature parameter. Will be ignored.", UserWarning)
+ self.params.pop("temperature")
+
+ def preprocess_messages(self, messages:List[Dict[str,str]]) -> List[Dict[str,str]]:
+ """
+ Concatenate system prompts to the next user prompt.
+
+ Parameters:
+ ----------
+ messages : List[Dict[str,str]]
+ a list of dict with role and content. role must be one of {"system", "user", "assistant"}
+
+ Returns:
+ -------
+ messages : List[Dict[str,str]]
+ a list of dict with role and content. role must be one of {"system", "user", "assistant"}
+ """
+ system_prompt_holder = ""
+ new_messages = []
+ for i, message in enumerate(messages):
+ # if system prompt, store it in system_prompt_holder
+ if message['role'] == 'system':
+ system_prompt_holder = message['content']
+ # if user prompt, concatenate it with system_prompt_holder
+ elif message['role'] == 'user':
+ if system_prompt_holder:
+ new_message = {'role': message['role'], 'content': f"{system_prompt_holder} {message['content']}"}
+ system_prompt_holder = ""
+ else:
+ new_message = {'role': message['role'], 'content': message['content']}
+
+ new_messages.append(new_message)
+ # if assistant/other prompt, do nothing
+ else:
+ new_message = {'role': message['role'], 'content': message['content']}
+ new_messages.append(new_message)
+
+ return new_messages
+
+ def postprocess_response(self, response:Union[str, Generator[str, None, None]]) -> Union[str, Generator[Dict[str, str], None, None]]:
+ """
+ This method postprocesses the LLM response after it is generated.
+
+ Parameters:
+ ----------
+ response : Union[str, Generator[str, None, None]]
+ the LLM response. Can be a string or a generator.
+
+ Returns: Union[str, Generator[Dict[str, str], None, None]]
+ the postprocessed LLM response.
+ if input is a generator, the output will be a generator {"type": "response", "data": <content>}.
+ """
+ if isinstance(response, str):
+ return response
+
+ def _process_stream():
+ for chunk in response:
+ yield {"type": "response", "data": chunk}
+
+ return _process_stream()


  class InferenceEngine:
  @abc.abstractmethod
- def __init__(self):
+ def __init__(self, config:LLMConfig, **kwrs):
  """
  This is an abstract class to provide interfaces for LLM inference engines.
  Children classes that inherts this class can be used in extrators. Must implement chat() method.
+
+ Parameters:
+ ----------
+ config : LLMConfig
+ the LLM configuration. Must be a child class of LLMConfig.
  """
  return NotImplemented


  @abc.abstractmethod
- def chat(self, messages:List[Dict[str,str]], max_new_tokens:int=2048, temperature:float=0.0,
- verbose:bool=False, stream:bool=False, **kwrs) -> Union[str, Generator[str, None, None]]:
+ def chat(self, messages:List[Dict[str,str]],
+ verbose:bool=False, stream:bool=False) -> Union[str, Generator[Dict[str, str], None, None]]:
  """
  This method inputs chat messages and outputs LLM generated text.

@@ -24,10 +292,6 @@ class InferenceEngine:
  ----------
  messages : List[Dict[str,str]]
  a list of dict with role and content. role must be one of {"system", "user", "assistant"}
- max_new_tokens : str, Optional
- the max number of new tokens LLM can generate.
- temperature : float, Optional
- the temperature for token sampling.
  verbose : bool, Optional
  if True, LLM generated text will be printed in terminal in real-time.
  stream : bool, Optional
@@ -35,9 +299,18 @@ class InferenceEngine:
  """
  return NotImplemented

+ def _format_config(self) -> Dict[str, Any]:
+ """
+ This method format the LLM configuration with the correct key for the inference engine.
+
+ Return : Dict[str, Any]
+ the config parameters.
+ """
+ return NotImplemented
+

  class LlamaCppInferenceEngine(InferenceEngine):
- def __init__(self, repo_id:str, gguf_filename:str, n_ctx:int=4096, n_gpu_layers:int=-1, **kwrs):
+ def __init__(self, repo_id:str, gguf_filename:str, n_ctx:int=4096, n_gpu_layers:int=-1, config:LLMConfig=None, **kwrs):
  """
  The Llama.cpp inference engine.

@@ -52,12 +325,16 @@ class LlamaCppInferenceEngine(InferenceEngine):
  context length that LLM will evaluate.
  n_gpu_layers : int, Optional
  number of layers to offload to GPU. Default is all layers (-1).
+ config : LLMConfig
+ the LLM configuration.
  """
  from llama_cpp import Llama
  self.repo_id = repo_id
  self.gguf_filename = gguf_filename
  self.n_ctx = n_ctx
  self.n_gpu_layers = n_gpu_layers
+ self.config = config if config else BasicLLMConfig()
+ self.formatted_params = self._format_config()

  self.model = Llama.from_pretrained(
  repo_id=self.repo_id,
@@ -73,8 +350,18 @@ class LlamaCppInferenceEngine(InferenceEngine):
  """
  del self.model

+ def _format_config(self) -> Dict[str, Any]:
+ """
+ This method format the LLM configuration with the correct key for the inference engine.
+ """
+ formatted_params = self.config.params.copy()
+ if "max_new_tokens" in formatted_params:
+ formatted_params["max_tokens"] = formatted_params["max_new_tokens"]
+ formatted_params.pop("max_new_tokens")
+
+ return formatted_params

- def chat(self, messages:List[Dict[str,str]], max_new_tokens:int=2048, temperature:float=0.0, verbose:bool=False, **kwrs) -> str:
+ def chat(self, messages:List[Dict[str,str]], verbose:bool=False) -> str:
  """
  This method inputs chat messages and outputs LLM generated text.

@@ -82,19 +369,15 @@ class LlamaCppInferenceEngine(InferenceEngine):
  ----------
  messages : List[Dict[str,str]]
  a list of dict with role and content. role must be one of {"system", "user", "assistant"}
- max_new_tokens : str, Optional
- the max number of new tokens LLM can generate.
- temperature : float, Optional
- the temperature for token sampling.
  verbose : bool, Optional
  if True, LLM generated text will be printed in terminal in real-time.
  """
+ processed_messages = self.config.preprocess_messages(messages)
+
  response = self.model.create_chat_completion(
- messages=messages,
- max_tokens=max_new_tokens,
- temperature=temperature,
+ messages=processed_messages,
  stream=verbose,
- **kwrs
+ **self.formatted_params
  )

  if verbose:
@@ -105,13 +388,14 @@ class LlamaCppInferenceEngine(InferenceEngine):
  res += out_dict['content']
  print(out_dict['content'], end='', flush=True)
  print('\n')
- return res
+ return self.config.postprocess_response(res)

- return response['choices'][0]['message']['content']
+ res = response['choices'][0]['message']['content']
+ return self.config.postprocess_response(res)


  class OllamaInferenceEngine(InferenceEngine):
- def __init__(self, model_name:str, num_ctx:int=4096, keep_alive:int=300, **kwrs):
+ def __init__(self, model_name:str, num_ctx:int=4096, keep_alive:int=300, config:LLMConfig=None, **kwrs):
  """
  The Ollama inference engine.

@@ -123,6 +407,8 @@ class OllamaInferenceEngine(InferenceEngine):
  context length that LLM will evaluate.
  keep_alive : int, Optional
  seconds to hold the LLM after the last API call.
+ config : LLMConfig
+ the LLM configuration.
  """
  if importlib.util.find_spec("ollama") is None:
  raise ImportError("ollama-python not found. Please install ollama-python (```pip install ollama```).")
@@ -133,9 +419,22 @@ class OllamaInferenceEngine(InferenceEngine):
  self.model_name = model_name
  self.num_ctx = num_ctx
  self.keep_alive = keep_alive
+ self.config = config if config else BasicLLMConfig()
+ self.formatted_params = self._format_config()

- def chat(self, messages:List[Dict[str,str]], max_new_tokens:int=2048, temperature:float=0.0,
- verbose:bool=False, stream:bool=False, **kwrs) -> Union[str, Generator[str, None, None]]:
+ def _format_config(self) -> Dict[str, Any]:
+ """
+ This method format the LLM configuration with the correct key for the inference engine.
+ """
+ formatted_params = self.config.params.copy()
+ if "max_new_tokens" in formatted_params:
+ formatted_params["num_predict"] = formatted_params["max_new_tokens"]
+ formatted_params.pop("max_new_tokens")
+
+ return formatted_params
+
+ def chat(self, messages:List[Dict[str,str]],
+ verbose:bool=False, stream:bool=False) -> Union[str, Generator[Dict[str, str], None, None]]:
  """
  This method inputs chat messages and outputs VLM generated text.

@@ -143,21 +442,19 @@ class OllamaInferenceEngine(InferenceEngine):
  ----------
  messages : List[Dict[str,str]]
  a list of dict with role and content. role must be one of {"system", "user", "assistant"}
- max_new_tokens : str, Optional
- the max number of new tokens VLM can generate.
- temperature : float, Optional
- the temperature for token sampling.
  verbose : bool, Optional
  if True, VLM generated text will be printed in terminal in real-time.
  stream : bool, Optional
  if True, returns a generator that yields the output in real-time.
  """
- options={'temperature':temperature, 'num_ctx': self.num_ctx, 'num_predict': max_new_tokens, **kwrs}
+ processed_messages = self.config.preprocess_messages(messages)
+
+ options={'num_ctx': self.num_ctx, **self.formatted_params}
  if stream:
  def _stream_generator():
  response_stream = self.client.chat(
  model=self.model_name,
- messages=messages,
+ messages=processed_messages,
  options=options,
  stream=True,
  keep_alive=self.keep_alive
@@ -167,12 +464,12 @@ class OllamaInferenceEngine(InferenceEngine):
  if content_chunk:
  yield content_chunk

- return _stream_generator()
+ return self.config.postprocess_response(_stream_generator())

  elif verbose:
  response = self.client.chat(
  model=self.model_name,
- messages=messages,
+ messages=processed_messages,
  options=options,
  stream=True,
  keep_alive=self.keep_alive
@@ -184,48 +481,82 @@ class OllamaInferenceEngine(InferenceEngine):
  print(content_chunk, end='', flush=True)
  res += content_chunk
  print('\n')
- return res
+ return self.config.postprocess_response(res)

  else:
  response = self.client.chat(
  model=self.model_name,
- messages=messages,
+ messages=processed_messages,
  options=options,
  stream=False,
  keep_alive=self.keep_alive
  )
- return response.get('message', {}).get('content')
+ res = response.get('message', {}).get('content')
+ return self.config.postprocess_response(res)
+

- async def chat_async(self, messages:List[Dict[str,str]], max_new_tokens:int=2048, temperature:float=0.0, **kwrs) -> str:
+ async def chat_async(self, messages:List[Dict[str,str]]) -> str:
  """
  Async version of chat method. Streaming is not supported.
  """
+ processed_messages = self.config.preprocess_messages(messages)
+
  response = await self.async_client.chat(
  model=self.model_name,
- messages=messages,
- options={'temperature':temperature, 'num_ctx': self.num_ctx, 'num_predict': max_new_tokens, **kwrs},
+ messages=processed_messages,
+ options={'num_ctx': self.num_ctx, **self.formatted_params},
  stream=False,
  keep_alive=self.keep_alive
  )

- return response['message']['content']
+ res = response['message']['content']
+ return self.config.postprocess_response(res)


  class HuggingFaceHubInferenceEngine(InferenceEngine):
- def __init__(self, model:str=None, token:Union[str, bool]=None, base_url:str=None, api_key:str=None, **kwrs):
+ def __init__(self, model:str=None, token:Union[str, bool]=None, base_url:str=None, api_key:str=None, config:LLMConfig=None, **kwrs):
  """
  The Huggingface_hub InferenceClient inference engine.
  For parameters and documentation, refer to https://huggingface.co/docs/huggingface_hub/en/package_reference/inference_client
+
+ Parameters:
+ ----------
+ model : str
+ the model name exactly as shown in Huggingface repo
+ token : str, Optional
+ the Huggingface token. If None, will use the token in os.environ['HF_TOKEN'].
+ base_url : str, Optional
+ the base url for the LLM server. If None, will use the default Huggingface Hub URL.
+ api_key : str, Optional
+ the API key for the LLM server.
+ config : LLMConfig
+ the LLM configuration.
  """
  if importlib.util.find_spec("huggingface_hub") is None:
  raise ImportError("huggingface-hub not found. Please install huggingface-hub (```pip install huggingface-hub```).")

  from huggingface_hub import InferenceClient, AsyncInferenceClient
+ self.model = model
+ self.base_url = base_url
  self.client = InferenceClient(model=model, token=token, base_url=base_url, api_key=api_key, **kwrs)
  self.client_async = AsyncInferenceClient(model=model, token=token, base_url=base_url, api_key=api_key, **kwrs)
+ self.config = config if config else BasicLLMConfig()
+ self.formatted_params = self._format_config()

- def chat(self, messages:List[Dict[str,str]], max_new_tokens:int=2048, temperature:float=0.0,
- verbose:bool=False, stream:bool=False, **kwrs) -> Union[str, Generator[str, None, None]]:
+ def _format_config(self) -> Dict[str, Any]:
+ """
+ This method format the LLM configuration with the correct key for the inference engine.
+ """
+ formatted_params = self.config.params.copy()
+ if "max_new_tokens" in formatted_params:
+ formatted_params["max_tokens"] = formatted_params["max_new_tokens"]
+ formatted_params.pop("max_new_tokens")
+
+ return formatted_params
+
+
+ def chat(self, messages:List[Dict[str,str]],
+ verbose:bool=False, stream:bool=False) -> Union[str, Generator[Dict[str, str], None, None]]:
  """
  This method inputs chat messages and outputs LLM generated text.

@@ -233,38 +564,32 @@ class HuggingFaceHubInferenceEngine(InferenceEngine):
  ----------
  messages : List[Dict[str,str]]
  a list of dict with role and content. role must be one of {"system", "user", "assistant"}
- max_new_tokens : str, Optional
- the max number of new tokens LLM can generate.
- temperature : float, Optional
- the temperature for token sampling.
  verbose : bool, Optional
  if True, VLM generated text will be printed in terminal in real-time.
  stream : bool, Optional
  if True, returns a generator that yields the output in real-time.
  """
+ processed_messages = self.config.preprocess_messages(messages)
+
  if stream:
  def _stream_generator():
  response_stream = self.client.chat.completions.create(
- messages=messages,
- max_tokens=max_new_tokens,
- temperature=temperature,
+ messages=processed_messages,
  stream=True,
- **kwrs
+ **self.formatted_params
  )
  for chunk in response_stream:
  content_chunk = chunk.get('choices')[0].get('delta').get('content')
  if content_chunk:
  yield content_chunk

- return _stream_generator()
+ return self.config.postprocess_response(_stream_generator())

  elif verbose:
  response = self.client.chat.completions.create(
- messages=messages,
- max_tokens=max_new_tokens,
- temperature=temperature,
+ messages=processed_messages,
  stream=True,
- **kwrs
+ **self.formatted_params
  )

  res = ''
@@ -273,35 +598,35 @@ class HuggingFaceHubInferenceEngine(InferenceEngine):
  if content_chunk:
  res += content_chunk
  print(content_chunk, end='', flush=True)
- return res
+ return self.config.postprocess_response(res)

  else:
  response = self.client.chat.completions.create(
- messages=messages,
- max_tokens=max_new_tokens,
- temperature=temperature,
+ messages=processed_messages,
  stream=False,
- **kwrs
+ **self.formatted_params
  )
- return response.choices[0].message.content
+ res = response.choices[0].message.content
+ return self.config.postprocess_response(res)

- async def chat_async(self, messages:List[Dict[str,str]], max_new_tokens:int=2048, temperature:float=0.0, **kwrs) -> str:
+ async def chat_async(self, messages:List[Dict[str,str]]) -> str:
  """
  Async version of chat method. Streaming is not supported.
  """
+ processed_messages = self.config.preprocess_messages(messages)
+
  response = await self.client_async.chat.completions.create(
- messages=messages,
- max_tokens=max_new_tokens,
- temperature=temperature,
+ messages=processed_messages,
  stream=False,
- **kwrs
+ **self.formatted_params
  )

- return response.choices[0].message.content
+ res = response.choices[0].message.content
+ return self.config.postprocess_response(res)


  class OpenAIInferenceEngine(InferenceEngine):
- def __init__(self, model:str, reasoning_model:bool=False, **kwrs):
+ def __init__(self, model:str, config:LLMConfig=None, **kwrs):
  """
  The OpenAI API inference engine. Supports OpenAI models and OpenAI compatible servers:
  - vLLM OpenAI compatible server (https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html)
@@ -313,8 +638,6 @@ class OpenAIInferenceEngine(InferenceEngine):
  ----------
  model_name : str
  model name as described in https://platform.openai.com/docs/models
- reasoning_model : bool, Optional
- indicator for OpenAI reasoning models ("o" series).
  """
  if importlib.util.find_spec("openai") is None:
  raise ImportError("OpenAI Python API library not found. Please install OpanAI (```pip install openai```).")
@@ -323,10 +646,21 @@ class OpenAIInferenceEngine(InferenceEngine):
  self.client = OpenAI(**kwrs)
  self.async_client = AsyncOpenAI(**kwrs)
  self.model = model
- self.reasoning_model = reasoning_model
+ self.config = config if config else BasicLLMConfig()
+ self.formatted_params = self._format_config()
+
+ def _format_config(self) -> Dict[str, Any]:
+ """
+ This method format the LLM configuration with the correct key for the inference engine.
+ """
+ formatted_params = self.config.params.copy()
+ if "max_new_tokens" in formatted_params:
+ formatted_params["max_completion_tokens"] = formatted_params["max_new_tokens"]
+ formatted_params.pop("max_new_tokens")
+
+ return formatted_params

- def chat(self, messages:List[Dict[str,str]], max_new_tokens:int=2048, temperature:float=0.0,
- verbose:bool=False, stream:bool=False, **kwrs) -> Union[str, Generator[str, None, None]]:
+ def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False) -> Union[str, Generator[Dict[str, str], None, None]]:
  """
  This method inputs chat messages and outputs LLM generated text.

@@ -334,177 +668,81 @@ class OpenAIInferenceEngine(InferenceEngine):
  ----------
  messages : List[Dict[str,str]]
  a list of dict with role and content. role must be one of {"system", "user", "assistant"}
- max_new_tokens : str, Optional
- the max number of new tokens LLM can generate.
- temperature : float, Optional
- the temperature for token sampling.
  verbose : bool, Optional
  if True, VLM generated text will be printed in terminal in real-time.
  stream : bool, Optional
  if True, returns a generator that yields the output in real-time.
  """
- # For reasoning models
- if self.reasoning_model:
- # Reasoning models do not support temperature parameter
- if temperature != 0.0:
- warnings.warn("Reasoning models do not support temperature parameter. Will be ignored.", UserWarning)
-
- # Reasoning models do not support system prompts
- if any(msg['role'] == 'system' for msg in messages):
- warnings.warn("Reasoning models do not support system prompts. Will be ignored.", UserWarning)
- messages = [msg for msg in messages if msg['role'] != 'system']
-
-
- if stream:
- def _stream_generator():
- response_stream = self.client.chat.completions.create(
- model=self.model,
- messages=messages,
- max_completion_tokens=max_new_tokens,
- stream=True,
- **kwrs
- )
- for chunk in response_stream:
- if len(chunk.choices) > 0:
- if chunk.choices[0].delta.content is not None:
- yield chunk.choices[0].delta.content
- if chunk.choices[0].finish_reason == "length":
- warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
- if self.reasoning_model:
- warnings.warn("max_new_tokens includes reasoning tokens and output tokens.", UserWarning)
- return _stream_generator()
-
- elif verbose:
- response = self.client.chat.completions.create(
- model=self.model,
- messages=messages,
- max_completion_tokens=max_new_tokens,
- stream=True,
- **kwrs
- )
- res = ''
- for chunk in response:
- if len(chunk.choices) > 0:
- if chunk.choices[0].delta.content is not None:
- res += chunk.choices[0].delta.content
- print(chunk.choices[0].delta.content, end="", flush=True)
- if chunk.choices[0].finish_reason == "length":
- warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
- if self.reasoning_model:
- warnings.warn("max_new_tokens includes reasoning tokens and output tokens.", UserWarning)
+ processed_messages = self.config.preprocess_messages(messages)

- print('\n')
- return res
- else:
- response = self.client.chat.completions.create(
- model=self.model,
- messages=messages,
- max_completion_tokens=max_new_tokens,
- stream=False,
- **kwrs
- )
- return response.choices[0].message.content
-
- # For non-reasoning models
- else:
- if stream:
- def _stream_generator():
- response_stream = self.client.chat.completions.create(
- model=self.model,
- messages=messages,
- max_tokens=max_new_tokens,
- temperature=temperature,
- stream=True,
- **kwrs
- )
- for chunk in response_stream:
- if len(chunk.choices) > 0:
- if chunk.choices[0].delta.content is not None:
- yield chunk.choices[0].delta.content
- if chunk.choices[0].finish_reason == "length":
- warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
- if self.reasoning_model:
- warnings.warn("max_new_tokens includes reasoning tokens and output tokens.", UserWarning)
- return _stream_generator()
-
- elif verbose:
- response = self.client.chat.completions.create(
- model=self.model,
- messages=messages,
- max_tokens=max_new_tokens,
- temperature=temperature,
- stream=True,
- **kwrs
- )
- res = ''
- for chunk in response:
+ if stream:
+ def _stream_generator():
+ response_stream = self.client.chat.completions.create(
+ model=self.model,
+ messages=processed_messages,
+ stream=True,
+ **self.formatted_params
+ )
+ for chunk in response_stream:
  if len(chunk.choices) > 0:
  if chunk.choices[0].delta.content is not None:
- res += chunk.choices[0].delta.content
- print(chunk.choices[0].delta.content, end="", flush=True)
+ yield chunk.choices[0].delta.content
  if chunk.choices[0].finish_reason == "length":
  warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
- if self.reasoning_model:
- warnings.warn("max_new_tokens includes reasoning tokens and output tokens.", UserWarning)
-
- print('\n')
- return res
-
- else:
- response = self.client.chat.completions.create(
- model=self.model,
- messages=messages,
- max_tokens=max_new_tokens,
- temperature=temperature,
- stream=False,
- **kwrs
- )
-
- return response.choices[0].message.content
-
-
- async def chat_async(self, messages:List[Dict[str,str]], max_new_tokens:int=4096, temperature:float=0.0, **kwrs) -> str:
- """
- Async version of chat method. Streaming is not supported.
- """
- if self.reasoning_model:
- # Reasoning models do not support temperature parameter
- if temperature != 0.0:
- warnings.warn("Reasoning models do not support temperature parameter. Will be ignored.", UserWarning)

- # Reasoning models do not support system prompts
- if any(msg['role'] == 'system' for msg in messages):
- warnings.warn("Reasoning models do not support system prompts. Will be ignored.", UserWarning)
- messages = [msg for msg in messages if msg['role'] != 'system']
+ return self.config.postprocess_response(_stream_generator())

- response = await self.async_client.chat.completions.create(
+ elif verbose:
+ response = self.client.chat.completions.create(
  model=self.model,
- messages=messages,
- max_completion_tokens=max_new_tokens,
- stream=False,
- **kwrs
+ messages=processed_messages,
+ stream=True,
+ **self.formatted_params
  )
+ res = ''
+ for chunk in response:
+ if len(chunk.choices) > 0:
+ if chunk.choices[0].delta.content is not None:
+ res += chunk.choices[0].delta.content
+ print(chunk.choices[0].delta.content, end="", flush=True)
+ if chunk.choices[0].finish_reason == "length":
+ warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)

+ print('\n')
+ return self.config.postprocess_response(res)
  else:
- response = await self.async_client.chat.completions.create(
+ response = self.client.chat.completions.create(
  model=self.model,
- messages=messages,
- max_tokens=max_new_tokens,
- temperature=temperature,
+ messages=processed_messages,
  stream=False,
- **kwrs
+ **self.formatted_params
  )
+ res = response.choices[0].message.content
+ return self.config.postprocess_response(res)
+
+
+ async def chat_async(self, messages:List[Dict[str,str]]) -> str:
+ """
+ Async version of chat method. Streaming is not supported.
+ """
+ processed_messages = self.config.preprocess_messages(messages)
+
+ response = await self.async_client.chat.completions.create(
+ model=self.model,
+ messages=processed_messages,
+ stream=False,
+ **self.formatted_params
+ )

  if response.choices[0].finish_reason == "length":
  warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
- if self.reasoning_model:
- warnings.warn("max_new_tokens includes reasoning tokens and output tokens.", UserWarning)

- return response.choices[0].message.content
+ res = response.choices[0].message.content
+ return self.config.postprocess_response(res)


  class AzureOpenAIInferenceEngine(OpenAIInferenceEngine):
- def __init__(self, model:str, api_version:str, reasoning_model:bool=False, **kwrs):
+ def __init__(self, model:str, api_version:str, config:LLMConfig=None, **kwrs):
  """
  The Azure OpenAI API inference engine.
  For parameters and documentation, refer to
@@ -517,8 +755,8 @@ class AzureOpenAIInferenceEngine(OpenAIInferenceEngine):
  model name as described in https://platform.openai.com/docs/models
  api_version : str
  the Azure OpenAI API version
- reasoning_model : bool, Optional
- indicator for OpenAI reasoning models ("o" series).
+ config : LLMConfig
+ the LLM configuration.
  """
  if importlib.util.find_spec("openai") is None:
  raise ImportError("OpenAI Python API library not found. Please install OpanAI (```pip install openai```).")
@@ -530,11 +768,12 @@ class AzureOpenAIInferenceEngine(OpenAIInferenceEngine):
  **kwrs)
  self.async_client = AsyncAzureOpenAI(api_version=self.api_version,
  **kwrs)
- self.reasoning_model = reasoning_model
+ self.config = config if config else BasicLLMConfig()
+ self.formatted_params = self._format_config()


  class LiteLLMInferenceEngine(InferenceEngine):
- def __init__(self, model:str=None, base_url:str=None, api_key:str=None):
+ def __init__(self, model:str=None, base_url:str=None, api_key:str=None, config:LLMConfig=None):
  """
  The LiteLLM inference engine.
  For parameters and documentation, refer to https://github.com/BerriAI/litellm?tab=readme-ov-file
@@ -547,6 +786,8 @@ class LiteLLMInferenceEngine(InferenceEngine):
  the base url for the LLM server
  api_key : str, Optional
  the API key for the LLM server
+ config : LLMConfig
+ the LLM configuration.
  """
  if importlib.util.find_spec("litellm") is None:
  raise ImportError("litellm not found. Please install litellm (```pip install litellm```).")
@@ -556,36 +797,44 @@ class LiteLLMInferenceEngine(InferenceEngine):
  self.model = model
  self.base_url = base_url
  self.api_key = api_key
+ self.config = config if config else BasicLLMConfig()
+ self.formatted_params = self._format_config()

- def chat(self, messages:List[Dict[str,str]], max_new_tokens:int=2048, temperature:float=0.0,
- verbose:bool=False, stream:bool=False, **kwrs) -> Union[str, Generator[str, None, None]]:
+ def _format_config(self) -> Dict[str, Any]:
+ """
+ This method format the LLM configuration with the correct key for the inference engine.
+ """
+ formatted_params = self.config.params.copy()
+ if "max_new_tokens" in formatted_params:
+ formatted_params["max_tokens"] = formatted_params["max_new_tokens"]
+ formatted_params.pop("max_new_tokens")
+
+ return formatted_params
+
+ def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False) -> Union[str, Generator[Dict[str, str], None, None]]:
  """
  This method inputs chat messages and outputs LLM generated text.

  Parameters:
  ----------
  messages : List[Dict[str,str]]
- a list of dict with role and content. role must be one of {"system", "user", "assistant"}
- max_new_tokens : str, Optional
- the max number of new tokens LLM can generate.
- temperature : float, Optional
- the temperature for token sampling.
+ a list of dict with role and content. role must be one of {"system", "user", "assistant"}
  verbose : bool, Optional
  if True, VLM generated text will be printed in terminal in real-time.
  stream : bool, Optional
  if True, returns a generator that yields the output in real-time.
  """
+ processed_messages = self.config.preprocess_messages(messages)
+
  if stream:
  def _stream_generator():
  response_stream = self.litellm.completion(
  model=self.model,
- messages=messages,
- max_tokens=max_new_tokens,
- temperature=temperature,
+ messages=processed_messages,
  stream=True,
  base_url=self.base_url,
  api_key=self.api_key,
- **kwrs
+ **self.formatted_params
  )

  for chunk in response_stream:
@@ -593,18 +842,16 @@ class LiteLLMInferenceEngine(InferenceEngine):
  if chunk_content:
  yield chunk_content

- return _stream_generator()
+ return self.config.postprocess_response(_stream_generator())

  elif verbose:
  response = self.litellm.completion(
  model=self.model,
- messages=messages,
- max_tokens=max_new_tokens,
- temperature=temperature,
+ messages=processed_messages,
  stream=True,
  base_url=self.base_url,
  api_key=self.api_key,
- **kwrs
+ **self.formatted_params
  )

  res = ''
@@ -614,34 +861,34 @@ class LiteLLMInferenceEngine(InferenceEngine):
  res += chunk_content
  print(chunk_content, end='', flush=True)

- return res
+ return self.config.postprocess_response(res)

  else:
  response = self.litellm.completion(
  model=self.model,
- messages=messages,
- max_tokens=max_new_tokens,
- temperature=temperature,
+ messages=processed_messages,
  stream=False,
  base_url=self.base_url,
  api_key=self.api_key,
- **kwrs
+ **self.formatted_params
  )
- return response.choices[0].message.content
+ res = response.choices[0].message.content
+ return self.config.postprocess_response(res)

- async def chat_async(self, messages:List[Dict[str,str]], max_new_tokens:int=2048, temperature:float=0.0, **kwrs) -> str:
+ async def chat_async(self, messages:List[Dict[str,str]]) -> str:
  """
  Async version of chat method. Streaming is not supported.
  """
+ processed_messages = self.config.preprocess_messages(messages)
+
  response = await self.litellm.acompletion(
  model=self.model,
- messages=messages,
- max_tokens=max_new_tokens,
- temperature=temperature,
+ messages=processed_messages,
  stream=False,
  base_url=self.base_url,
  api_key=self.api_key,
- **kwrs
+ **self.formatted_params
  )

- return response.get('choices')[0].get('message').get('content')
+ res = response.get('choices')[0].get('message').get('content')
+ return self.config.postprocess_response(res)
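
The diff also replaces the old reasoning_model flag with dedicated reasoning configs. The sketch below shows how the new classes fit together as inferred from the code above; the model names ("qwen3:8b", "o3-mini") are placeholders and not taken from the package.

    from llm_ie.engines import (OllamaInferenceEngine, OpenAIInferenceEngine,
                                Qwen3LLMConfig, OpenAIReasoningLLMConfig)

    messages = [{"role": "user", "content": "Extract all diagnoses from the note."}]

    # Qwen3: "/think" (or "/no_think") is appended to system/user prompts; streamed
    # chunks arrive as {"type": "reasoning" | "response", "data": ...} dicts.
    qwen = OllamaInferenceEngine(model_name="qwen3:8b",  # placeholder model name
                                 config=Qwen3LLMConfig(thinking_mode=True, max_new_tokens=1024))
    for chunk in qwen.chat(messages, stream=True):
        if chunk["type"] == "response":
            print(chunk["data"], end="", flush=True)

    # OpenAI "o" series: temperature is dropped with a warning, system prompts are
    # concatenated into the next user prompt, and reasoning_effort is forwarded.
    o_series = OpenAIInferenceEngine(model="o3-mini",  # placeholder model name
                                     config=OpenAIReasoningLLMConfig(reasoning_effort="medium"))
    print(o_series.chat(messages))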