llm-ie 1.3.0-py3-none-any.whl → 1.4.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
llm_ie/engines.py CHANGED
@@ -1,1491 +1,37 @@
1
- import abc
2
- import os
3
- import re
4
- import warnings
5
- import importlib.util
6
- from typing import Any, Tuple, List, Dict, Union, Generator
7
-
8
-
9
- class LLMConfig(abc.ABC):
10
- def __init__(self, **kwargs):
11
- """
12
- This is an abstract class to provide interfaces for LLM configuration.
13
- Child classes that inherit this class can be used in extractors and the prompt editor.
14
- Common LLM parameters: max_new_tokens, temperature, top_p, top_k, min_p.
15
- """
16
- self.params = kwargs.copy()
17
-
18
-
19
- @abc.abstractmethod
20
- def preprocess_messages(self, messages:List[Dict[str,str]]) -> List[Dict[str,str]]:
21
- """
22
- This method preprocesses the input messages before passing them to the LLM.
23
-
24
- Parameters:
25
- ----------
26
- messages : List[Dict[str,str]]
27
- a list of dict with role and content. role must be one of {"system", "user", "assistant"}
28
-
29
- Returns:
30
- -------
31
- messages : List[Dict[str,str]]
32
- a list of dict with role and content. role must be one of {"system", "user", "assistant"}
33
- """
34
- return NotImplemented
35
-
36
- @abc.abstractmethod
37
- def postprocess_response(self, response:Union[str, Dict[str, str], Generator[str, None, None]]) -> Union[Dict[str,str], Generator[Dict[str, str], None, None]]:
38
- """
39
- This method postprocesses the LLM response after it is generated.
40
-
41
- Parameters:
42
- ----------
43
- response : Union[str, Dict[str, str], Generator[Dict[str, str], None, None]]
44
- the LLM response. Can be a dict or a generator.
45
-
46
- Returns:
47
- -------
48
- response : Union[Dict[str,str], Generator[Dict[str, str], None, None]]
49
- the postprocessed LLM response
50
- """
51
- return NotImplemented
52
-
53
-
54
- class BasicLLMConfig(LLMConfig):
55
- def __init__(self, max_new_tokens:int=2048, temperature:float=0.0, **kwargs):
56
- """
57
- The basic LLM configuration for most non-reasoning models.
58
- """
59
- super().__init__(**kwargs)
60
- self.max_new_tokens = max_new_tokens
61
- self.temperature = temperature
62
- self.params["max_new_tokens"] = self.max_new_tokens
63
- self.params["temperature"] = self.temperature
64
-
65
- def preprocess_messages(self, messages:List[Dict[str,str]]) -> List[Dict[str,str]]:
66
- """
67
- This method preprocesses the input messages before passing them to the LLM.
68
-
69
- Parameters:
70
- ----------
71
- messages : List[Dict[str,str]]
72
- a list of dict with role and content. role must be one of {"system", "user", "assistant"}
73
-
74
- Returns:
75
- -------
76
- messages : List[Dict[str,str]]
77
- a list of dict with role and content. role must be one of {"system", "user", "assistant"}
78
- """
79
- return messages.copy()
80
-
81
- def postprocess_response(self, response:Union[str, Dict[str, str], Generator[str, None, None]]) -> Union[Dict[str,str], Generator[Dict[str, str], None, None]]:
82
- """
83
- This method postprocesses the LLM response after it is generated.
84
-
85
- Parameters:
86
- ----------
87
- response : Union[str, Dict[str, str], Generator[str, None, None]]
88
- the LLM response. Can be a string or a generator.
89
-
90
- Returns: Union[Dict[str,str], Generator[Dict[str, str], None, None]]
91
- the postprocessed LLM response.
92
- If input is a string, the output will be a dict {"response": <response>}.
93
- if input is a generator, the output will be a generator {"type": "response", "data": <content>}.
94
- """
95
- if isinstance(response, str):
96
- return {"response": response}
97
-
98
- elif isinstance(response, dict):
99
- if "response" in response:
100
- return response
101
- else:
102
- warnings.warn(f"Invalid response dict keys: {response.keys()}. Returning default empty dict.", UserWarning)
103
- return {"response": ""}
104
-
105
- elif isinstance(response, Generator):
106
- def _process_stream():
107
- for chunk in response:
108
- if isinstance(chunk, dict):
109
- yield chunk
110
- elif isinstance(chunk, str):
111
- yield {"type": "response", "data": chunk}
112
-
113
- return _process_stream()
114
-
115
- else:
116
- warnings.warn(f"Invalid response type: {type(response)}. Returning default empty dict.", UserWarning)
117
- return {"response": ""}
118
-
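A minimal sketch of the `BasicLLMConfig` contract defined above (the import path works on 1.3.0 and, via the re-export, on 1.4.1; values are illustrative):

```python
from llm_ie.engines import BasicLLMConfig

config = BasicLLMConfig(max_new_tokens=512, temperature=0.0)
print(config.params)
# {'max_new_tokens': 512, 'temperature': 0.0}

# A plain string is wrapped into {"response": ...}; a stream of string chunks
# is wrapped into {"type": "response", "data": ...} dicts.
print(config.postprocess_response("Hello"))
# {'response': 'Hello'}
```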
119
- class ReasoningLLMConfig(LLMConfig):
120
- def __init__(self, thinking_token_start="<think>", thinking_token_end="</think>", **kwargs):
121
- """
122
- The general LLM configuration for reasoning models.
123
- """
124
- super().__init__(**kwargs)
125
- self.thinking_token_start = thinking_token_start
126
- self.thinking_token_end = thinking_token_end
127
-
128
- def preprocess_messages(self, messages:List[Dict[str,str]]) -> List[Dict[str,str]]:
129
- """
130
- This method preprocesses the input messages before passing them to the LLM.
131
-
132
- Parameters:
133
- ----------
134
- messages : List[Dict[str,str]]
135
- a list of dict with role and content. role must be one of {"system", "user", "assistant"}
136
-
137
- Returns:
138
- -------
139
- messages : List[Dict[str,str]]
140
- a list of dict with role and content. role must be one of {"system", "user", "assistant"}
141
- """
142
- return messages.copy()
143
-
144
- def postprocess_response(self, response:Union[str, Dict[str, str], Generator[str, None, None]]) -> Union[Dict[str,str], Generator[Dict[str,str], None, None]]:
145
- """
146
- This method postprocesses the LLM response after it is generated.
147
- 1. If input is a string, it will extract the reasoning and response based on the thinking tokens.
148
- 2. If input is a dict, it should contain keys "reasoning" and "response". This is for inference engines that already parse reasoning and response.
149
- 3. If input is a generator,
150
- a. if the chunk is a dict, it should contain keys "type" and "data". This is for inference engines that already parse reasoning and response.
151
- b. if the chunk is a string, it will yield dicts with keys "type" and "data" based on the thinking tokens.
152
-
153
- Parameters:
154
- ----------
155
- response : Union[str, Generator[str, None, None]]
156
- the LLM response. Can be a string or a generator.
157
-
158
- Returns:
159
- -------
160
- response : Union[str, Generator[str, None, None]]
161
- the postprocessed LLM response as a dict {"reasoning": <reasoning>, "response": <content>}
162
- if input is a generator, the output will be a generator {"type": <reasoning or response>, "data": <content>}.
163
- """
164
- if isinstance(response, str):
165
- # get contents between thinking_token_start and thinking_token_end
166
- pattern = f"{re.escape(self.thinking_token_start)}(.*?){re.escape(self.thinking_token_end)}"
167
- match = re.search(pattern, response, re.DOTALL)
168
- reasoning = match.group(1) if match else ""
169
- # get response AFTER thinking_token_end
170
- response = re.sub(f".*?{self.thinking_token_end}", "", response, flags=re.DOTALL).strip()
171
- return {"reasoning": reasoning, "response": response}
172
-
173
- elif isinstance(response, dict):
174
- if "reasoning" in response and "response" in response:
175
- return response
176
- else:
177
- warnings.warn(f"Invalid response dict keys: {response.keys()}. Returning default empty dict.", UserWarning)
178
- return {"reasoning": "", "response": ""}
179
-
180
- elif isinstance(response, Generator):
181
- def _process_stream():
182
- think_flag = False
183
- buffer = ""
184
- for chunk in response:
185
- if isinstance(chunk, dict):
186
- yield chunk
187
-
188
- elif isinstance(chunk, str):
189
- buffer += chunk
190
- # switch between reasoning and response
191
- if self.thinking_token_start in buffer:
192
- think_flag = True
193
- buffer = buffer.replace(self.thinking_token_start, "")
194
- elif self.thinking_token_end in buffer:
195
- think_flag = False
196
- buffer = buffer.replace(self.thinking_token_end, "")
197
-
198
- # if chunk is in thinking block, tag it as reasoning; else tag it as response
199
- if chunk not in [self.thinking_token_start, self.thinking_token_end]:
200
- if think_flag:
201
- yield {"type": "reasoning", "data": chunk}
202
- else:
203
- yield {"type": "response", "data": chunk}
204
-
205
- return _process_stream()
206
-
207
- else:
208
- warnings.warn(f"Invalid response type: {type(response)}. Returning default empty dict.", UserWarning)
209
- return {"reasoning": "", "response": ""}
210
-
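The string branch of `postprocess_response` above splits a raw completion on the thinking tokens; a short sketch using the default `<think>`/`</think>` markers:

```python
from llm_ie.engines import ReasoningLLMConfig

config = ReasoningLLMConfig()  # defaults: "<think>" / "</think>"
raw = "<think>Check the dosage units first.</think>The dose is 5 mg."
print(config.postprocess_response(raw))
# {'reasoning': 'Check the dosage units first.', 'response': 'The dose is 5 mg.'}
```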
211
- class Qwen3LLMConfig(ReasoningLLMConfig):
212
- def __init__(self, thinking_mode:bool=True, **kwargs):
213
- """
214
- The Qwen3 **hybrid thinking** LLM configuration.
215
- For Qwen3 thinking 2507, use ReasoningLLMConfig instead; for Qwen3 Instruct, use BasicLLMConfig instead.
216
-
217
- Parameters:
218
- ----------
219
- thinking_mode : bool, Optional
220
- if True, a special token "/think" will be placed after each system and user prompt. Otherwise, "/no_think" will be placed.
221
- """
222
- super().__init__(**kwargs)
223
- self.thinking_mode = thinking_mode
224
-
225
- def preprocess_messages(self, messages:List[Dict[str,str]]) -> List[Dict[str,str]]:
226
- """
227
- Append a special token to the system and user prompts.
228
- The token is "/think" if thinking_mode is True, otherwise "/no_think".
229
-
230
- Parameters:
231
- ----------
232
- messages : List[Dict[str,str]]
233
- a list of dict with role and content. role must be one of {"system", "user", "assistant"}
234
-
235
- Returns:
236
- -------
237
- messages : List[Dict[str,str]]
238
- a list of dict with role and content. role must be one of {"system", "user", "assistant"}
239
- """
240
- thinking_token = "/think" if self.thinking_mode else "/no_think"
241
- new_messages = []
242
- for message in messages:
243
- if message['role'] in ['system', 'user']:
244
- new_message = {'role': message['role'], 'content': f"{message['content']} {thinking_token}"}
245
- else:
246
- new_message = {'role': message['role'], 'content': message['content']}
247
-
248
- new_messages.append(new_message)
249
-
250
- return new_messages
251
-
252
-
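`Qwen3LLMConfig.preprocess_messages` only appends the soft switch to system and user turns; a sketch with thinking disabled:

```python
from llm_ie.engines import Qwen3LLMConfig

config = Qwen3LLMConfig(thinking_mode=False)
messages = [{"role": "system", "content": "You are concise."},
            {"role": "user", "content": "Summarize the note."}]
print(config.preprocess_messages(messages))
# [{'role': 'system', 'content': 'You are concise. /no_think'},
#  {'role': 'user', 'content': 'Summarize the note. /no_think'}]
```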
253
- class OpenAIReasoningLLMConfig(ReasoningLLMConfig):
254
- def __init__(self, reasoning_effort:str=None, **kwargs):
255
- """
256
- The OpenAI "o" series configuration.
257
- 1. The reasoning effort as one of {"low", "medium", "high"}.
258
- For models that do not support setting reasoning effort (e.g., o1-mini, o1-preview), set to None.
259
- 2. The temperature parameter is not supported and will be ignored.
260
- 3. The system prompt is not supported and will be concatenated to the next user prompt.
261
-
262
- Parameters:
263
- ----------
264
- reasoning_effort : str, Optional
265
- the reasoning effort. Must be one of {"low", "medium", "high"}. Default is None (the parameter is not sent).
266
- """
267
- super().__init__(**kwargs)
268
- if reasoning_effort is not None:
269
- if reasoning_effort not in ["low", "medium", "high"]:
270
- raise ValueError("reasoning_effort must be one of {'low', 'medium', 'high'}.")
271
-
272
- self.reasoning_effort = reasoning_effort
273
- self.params["reasoning_effort"] = self.reasoning_effort
274
-
275
- if "temperature" in self.params:
276
- warnings.warn("Reasoning models do not support temperature parameter. Will be ignored.", UserWarning)
277
- self.params.pop("temperature")
278
-
279
- def preprocess_messages(self, messages:List[Dict[str,str]]) -> List[Dict[str,str]]:
280
- """
281
- Concatenate system prompts to the next user prompt.
282
-
283
- Parameters:
284
- ----------
285
- messages : List[Dict[str,str]]
286
- a list of dict with role and content. role must be one of {"system", "user", "assistant"}
287
-
288
- Returns:
289
- -------
290
- messages : List[Dict[str,str]]
291
- a list of dict with role and content. role must be one of {"system", "user", "assistant"}
292
- """
293
- system_prompt_holder = ""
294
- new_messages = []
295
- for i, message in enumerate(messages):
296
- # if system prompt, store it in system_prompt_holder
297
- if message['role'] == 'system':
298
- system_prompt_holder = message['content']
299
- # if user prompt, concatenate it with system_prompt_holder
300
- elif message['role'] == 'user':
301
- if system_prompt_holder:
302
- new_message = {'role': message['role'], 'content': f"{system_prompt_holder} {message['content']}"}
303
- system_prompt_holder = ""
304
- else:
305
- new_message = {'role': message['role'], 'content': message['content']}
306
-
307
- new_messages.append(new_message)
308
- # if assistant/other prompt, do nothing
309
- else:
310
- new_message = {'role': message['role'], 'content': message['content']}
311
- new_messages.append(new_message)
312
-
313
- return new_messages
314
-
315
-
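Because these models do not accept system prompts, `preprocess_messages` folds each system prompt into the next user turn; a sketch:

```python
from llm_ie.engines import OpenAIReasoningLLMConfig

config = OpenAIReasoningLLMConfig(reasoning_effort="low")
messages = [{"role": "system", "content": "Extract diagnoses."},
            {"role": "user", "content": "Patient denies chest pain."}]
print(config.preprocess_messages(messages))
# [{'role': 'user', 'content': 'Extract diagnoses. Patient denies chest pain.'}]
```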
316
- class MessagesLogger:
317
- def __init__(self):
318
- """
319
- This class is used to log the messages for InferenceEngine.chat().
320
- """
321
- self.messages_log = []
322
-
323
- def log_messages(self, messages : List[Dict[str,str]]):
324
- """
325
- This method logs the messages to a list.
326
- """
327
- self.messages_log.append(messages)
328
-
329
- def get_messages_log(self) -> List[List[Dict[str,str]]]:
330
- """
331
- This method returns a copy of the current messages log
332
- """
333
- return self.messages_log.copy()
334
-
335
- def clear_messages_log(self):
336
- """
337
- This method clears the current messages log
338
- """
339
- self.messages_log.clear()
340
-
341
-
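`MessagesLogger` is a plain in-memory log; engines append the assistant turn and call `log_messages` when a logger is passed via `chat(..., messages_logger=...)`. A sketch:

```python
from llm_ie.engines import MessagesLogger

logger = MessagesLogger()
logger.log_messages([{"role": "user", "content": "Hello"},
                     {"role": "assistant", "content": "Hi there", "reasoning": ""}])
print(len(logger.get_messages_log()))  # 1
logger.clear_messages_log()
```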
342
- class InferenceEngine:
343
- @abc.abstractmethod
344
- def __init__(self, config:LLMConfig, **kwrs):
345
- """
346
- This is an abstract class to provide interfaces for LLM inference engines.
347
- Child classes that inherit this class can be used in extractors. Must implement the chat() method.
348
-
349
- Parameters:
350
- ----------
351
- config : LLMConfig
352
- the LLM configuration. Must be a child class of LLMConfig.
353
- """
354
- return NotImplemented
355
-
356
- def get_messages_log(self) -> List[List[Dict[str,str]]]:
357
- return self.messages_log.copy()
358
-
359
- def clear_messages_log(self):
360
- self.messages_log = []
361
-
362
-
363
- @abc.abstractmethod
364
- def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False,
365
- messages_logger:MessagesLogger=None) -> Union[Dict[str,str], Generator[Dict[str, str], None, None]]:
366
- """
367
- This method inputs chat messages and outputs LLM generated text.
368
-
369
- Parameters:
370
- ----------
371
- messages : List[Dict[str,str]]
372
- a list of dict with role and content. role must be one of {"system", "user", "assistant"}
373
- verbose : bool, Optional
374
- if True, LLM generated text will be printed in terminal in real-time.
375
- stream : bool, Optional
376
- if True, returns a generator that yields the output in real-time.
377
- messages_logger : MessagesLogger, Optional
378
- the message logger that logs the chat messages.
379
-
380
- Returns:
381
- -------
382
- response : Union[Dict[str,str], Generator[Dict[str, str], None, None]]
383
- a dict {"reasoning": <reasoning>, "response": <response>} or Generator {"type": <reasoning or response>, "data": <content>}
384
- """
385
- return NotImplemented
386
-
387
- def _format_config(self) -> Dict[str, Any]:
388
- """
389
- This method formats the LLM configuration with the correct key for the inference engine.
390
-
391
- Return : Dict[str, Any]
392
- the config parameters.
393
- """
394
- return NotImplemented
395
-
1
+ from llm_inference_engine import (
2
+ # Configs
3
+ LLMConfig,
4
+ BasicLLMConfig,
5
+ ReasoningLLMConfig,
6
+ Qwen3LLMConfig,
7
+ OpenAIReasoningLLMConfig,
8
+
9
+ # Base Engine
10
+ InferenceEngine,
11
+
12
+ # Concrete Engines
13
+ OllamaInferenceEngine,
14
+ OpenAIInferenceEngine,
15
+ HuggingFaceHubInferenceEngine,
16
+ AzureOpenAIInferenceEngine,
17
+ LiteLLMInferenceEngine,
18
+ OpenAICompatibleInferenceEngine,
19
+ VLLMInferenceEngine,
20
+ SGLangInferenceEngine,
21
+ OpenRouterInferenceEngine
22
+ )
23
+
24
+ from llm_inference_engine.utils import MessagesLogger
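Since these names are now re-exported, existing imports from `llm_ie.engines` keep working as long as the `llm_inference_engine` dependency is installed; a sketch:

```python
# Same import path on 1.3.0 and 1.4.1; on 1.4.1 the classes come from the
# external llm_inference_engine package.
from llm_ie.engines import OllamaInferenceEngine, BasicLLMConfig, MessagesLogger
```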
396
25
 
397
26
  class LlamaCppInferenceEngine(InferenceEngine):
398
- def __init__(self, repo_id:str, gguf_filename:str, n_ctx:int=4096, n_gpu_layers:int=-1, config:LLMConfig=None, **kwrs):
399
- """
400
- The Llama.cpp inference engine.
401
-
402
- Parameters:
403
- ----------
404
- repo_id : str
405
- the exact name as shown on Huggingface repo
406
- gguf_filename : str
407
- the exact name as shown in Huggingface repo -> Files and versions.
408
- If multiple gguf files are needed, use the first.
409
- n_ctx : int, Optional
410
- context length that LLM will evaluate.
411
- n_gpu_layers : int, Optional
412
- number of layers to offload to GPU. Default is all layers (-1).
413
- config : LLMConfig
414
- the LLM configuration.
415
- """
416
- from llama_cpp import Llama
417
- super().__init__(config)
418
- self.repo_id = repo_id
419
- self.gguf_filename = gguf_filename
420
- self.n_ctx = n_ctx
421
- self.n_gpu_layers = n_gpu_layers
422
- self.config = config if config else BasicLLMConfig()
423
- self.formatted_params = self._format_config()
424
-
425
- self.model = Llama.from_pretrained(
426
- repo_id=self.repo_id,
427
- filename=self.gguf_filename,
428
- n_gpu_layers=n_gpu_layers,
429
- n_ctx=n_ctx,
430
- **kwrs
431
- )
432
-
433
- def __del__(self):
434
- """
435
- When the inference engine is deleted, release memory for model.
436
- """
437
- del self.model
438
-
439
- def _format_config(self) -> Dict[str, Any]:
440
- """
441
- This method formats the LLM configuration with the correct key for the inference engine.
442
- """
443
- formatted_params = self.config.params.copy()
444
- if "max_new_tokens" in formatted_params:
445
- formatted_params["max_tokens"] = formatted_params["max_new_tokens"]
446
- formatted_params.pop("max_new_tokens")
447
-
448
- return formatted_params
449
-
450
- def chat(self, messages:List[Dict[str,str]], verbose:bool=False, messages_logger:MessagesLogger=None) -> Dict[str,str]:
451
- """
452
- This method inputs chat messages and outputs LLM generated text.
453
-
454
- Parameters:
455
- ----------
456
- messages : List[Dict[str,str]]
457
- a list of dict with role and content. role must be one of {"system", "user", "assistant"}
458
- verbose : bool, Optional
459
- if True, LLM generated text will be printed in terminal in real-time.
460
- messages_logger : MessagesLogger, Optional
461
- the message logger that logs the chat messages.
462
- """
463
- # Preprocess messages
464
- processed_messages = self.config.preprocess_messages(messages)
465
- # Generate response
466
- response = self.model.create_chat_completion(
467
- messages=processed_messages,
468
- stream=verbose,
469
- **self.formatted_params
470
- )
471
-
472
- if verbose:
473
- res = ''
474
- for chunk in response:
475
- out_dict = chunk['choices'][0]['delta']
476
- if 'content' in out_dict:
477
- res += out_dict['content']
478
- print(out_dict['content'], end='', flush=True)
479
- print('\n')
480
- return self.config.postprocess_response(res)
481
-
482
- res = response['choices'][0]['message']['content']
483
- # Postprocess response
484
- res_dict = self.config.postprocess_response(res)
485
- # Write to messages log
486
- if messages_logger:
487
- processed_messages.append({"role": "assistant",
488
- "content": res_dict.get("response", ""),
489
- "reasoning": res_dict.get("reasoning", "")})
490
- messages_logger.log_messages(processed_messages)
491
-
492
- return res_dict
493
-
494
-
495
- class OllamaInferenceEngine(InferenceEngine):
496
- def __init__(self, model_name:str, num_ctx:int=4096, keep_alive:int=300, config:LLMConfig=None, **kwrs):
497
- """
498
- The Ollama inference engine.
499
-
500
- Parameters:
501
- ----------
502
- model_name : str
503
- the model name exactly as shown in >> ollama ls
504
- num_ctx : int, Optional
505
- context length that LLM will evaluate.
506
- keep_alive : int, Optional
507
- seconds to hold the LLM after the last API call.
508
- config : LLMConfig
509
- the LLM configuration.
510
- """
511
- if importlib.util.find_spec("ollama") is None:
512
- raise ImportError("ollama-python not found. Please install ollama-python (```pip install ollama```).")
513
-
514
- from ollama import Client, AsyncClient
515
- super().__init__(config)
516
- self.client = Client(**kwrs)
517
- self.async_client = AsyncClient(**kwrs)
518
- self.model_name = model_name
519
- self.num_ctx = num_ctx
520
- self.keep_alive = keep_alive
521
- self.config = config if config else BasicLLMConfig()
522
- self.formatted_params = self._format_config()
523
-
524
- def _format_config(self) -> Dict[str, Any]:
525
- """
526
- This method formats the LLM configuration with the correct key for the inference engine.
527
- """
528
- formatted_params = self.config.params.copy()
529
- if "max_new_tokens" in formatted_params:
530
- formatted_params["num_predict"] = formatted_params["max_new_tokens"]
531
- formatted_params.pop("max_new_tokens")
532
-
533
- return formatted_params
534
-
535
- def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False,
536
- messages_logger:MessagesLogger=None) -> Union[Dict[str,str], Generator[Dict[str, str], None, None]]:
537
- """
538
- This method inputs chat messages and outputs LLM generated text.
539
-
540
- Parameters:
541
- ----------
542
- messages : List[Dict[str,str]]
543
- a list of dict with role and content. role must be one of {"system", "user", "assistant"}
544
- verbose : bool, Optional
545
- if True, LLM generated text will be printed in terminal in real-time.
546
- stream : bool, Optional
547
- if True, returns a generator that yields the output in real-time.
548
- messages_logger : MessagesLogger, Optional
549
- the message logger that logs the chat messages.
550
-
551
- Returns:
552
- -------
553
- response : Union[Dict[str,str], Generator[Dict[str, str], None, None]]
554
- a dict {"reasoning": <reasoning>, "response": <response>} or Generator {"type": <reasoning or response>, "data": <content>}
555
- """
556
- processed_messages = self.config.preprocess_messages(messages)
557
-
558
- options={'num_ctx': self.num_ctx, **self.formatted_params}
559
- if stream:
560
- def _stream_generator():
561
- response_stream = self.client.chat(
562
- model=self.model_name,
563
- messages=processed_messages,
564
- options=options,
565
- stream=True,
566
- keep_alive=self.keep_alive
567
- )
568
- res = {"reasoning": "", "response": ""}
569
- for chunk in response_stream:
570
- if hasattr(chunk.message, 'thinking') and chunk.message.thinking:
571
- content_chunk = getattr(getattr(chunk, 'message', {}), 'thinking', '')
572
- res["reasoning"] += content_chunk
573
- yield {"type": "reasoning", "data": content_chunk}
574
- else:
575
- content_chunk = getattr(getattr(chunk, 'message', {}), 'content', '')
576
- res["response"] += content_chunk
577
- yield {"type": "response", "data": content_chunk}
578
-
579
- if chunk.done_reason == "length":
580
- warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
581
-
582
- # Postprocess response
583
- res_dict = self.config.postprocess_response(res)
584
- # Write to messages log
585
- if messages_logger:
586
- processed_messages.append({"role": "assistant",
587
- "content": res_dict.get("response", ""),
588
- "reasoning": res_dict.get("reasoning", "")})
589
- messages_logger.log_messages(processed_messages)
590
-
591
- return self.config.postprocess_response(_stream_generator())
592
-
593
- elif verbose:
594
- response = self.client.chat(
595
- model=self.model_name,
596
- messages=processed_messages,
597
- options=options,
598
- stream=True,
599
- keep_alive=self.keep_alive
600
- )
601
-
602
- res = {"reasoning": "", "response": ""}
603
- phase = ""
604
- for chunk in response:
605
- if hasattr(chunk.message, 'thinking') and chunk.message.thinking:
606
- if phase != "reasoning":
607
- print("\n--- Reasoning ---")
608
- phase = "reasoning"
609
-
610
- content_chunk = getattr(getattr(chunk, 'message', {}), 'thinking', '')
611
- res["reasoning"] += content_chunk
612
- else:
613
- if phase != "response":
614
- print("\n--- Response ---")
615
- phase = "response"
616
- content_chunk = getattr(getattr(chunk, 'message', {}), 'content', '')
617
- res["response"] += content_chunk
618
-
619
- print(content_chunk, end='', flush=True)
620
-
621
- if chunk.done_reason == "length":
622
- warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
623
- print('\n')
624
-
625
- else:
626
- response = self.client.chat(
627
- model=self.model_name,
628
- messages=processed_messages,
629
- options=options,
630
- stream=False,
631
- keep_alive=self.keep_alive
632
- )
633
- res = {"reasoning": getattr(getattr(response, 'message', {}), 'thinking', ''),
634
- "response": getattr(getattr(response, 'message', {}), 'content', '')}
635
-
636
- if response.done_reason == "length":
637
- warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
638
-
639
- # Postprocess response
640
- res_dict = self.config.postprocess_response(res)
641
- # Write to messages log
642
- if messages_logger:
643
- processed_messages.append({"role": "assistant",
644
- "content": res_dict.get("response", ""),
645
- "reasoning": res_dict.get("reasoning", "")})
646
- messages_logger.log_messages(processed_messages)
647
-
648
- return res_dict
649
-
650
-
651
- async def chat_async(self, messages:List[Dict[str,str]], messages_logger:MessagesLogger=None) -> Dict[str,str]:
652
- """
653
- Async version of chat method. Streaming is not supported.
654
- """
655
- processed_messages = self.config.preprocess_messages(messages)
656
-
657
- response = await self.async_client.chat(
658
- model=self.model_name,
659
- messages=processed_messages,
660
- options={'num_ctx': self.num_ctx, **self.formatted_params},
661
- stream=False,
662
- keep_alive=self.keep_alive
663
- )
664
-
665
- res = {"reasoning": getattr(getattr(response, 'message', {}), 'thinking', ''),
666
- "response": getattr(getattr(response, 'message', {}), 'content', '')}
667
-
668
- if response.done_reason == "length":
669
- warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
670
- # Postprocess response
671
- res_dict = self.config.postprocess_response(res)
672
- # Write to messages log
673
- if messages_logger:
674
- processed_messages.append({"role": "assistant",
675
- "content": res_dict.get("response", ""),
676
- "reasoning": res_dict.get("reasoning", "")})
677
- messages_logger.log_messages(processed_messages)
678
-
679
- return res_dict
680
-
681
-
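A usage sketch for the engine above, assuming a local Ollama server and an already-pulled model (the model name is a placeholder):

```python
from llm_ie.engines import OllamaInferenceEngine, BasicLLMConfig

# Placeholder model name; requires a running Ollama server on this machine.
engine = OllamaInferenceEngine(model_name="llama3.1:8b", num_ctx=4096,
                               config=BasicLLMConfig(max_new_tokens=512))
out = engine.chat([{"role": "user", "content": "Say hello."}])
print(out["response"])
```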
682
- class HuggingFaceHubInferenceEngine(InferenceEngine):
683
- def __init__(self, model:str=None, token:Union[str, bool]=None, base_url:str=None, api_key:str=None, config:LLMConfig=None, **kwrs):
684
- """
685
- The Huggingface_hub InferenceClient inference engine.
686
- For parameters and documentation, refer to https://huggingface.co/docs/huggingface_hub/en/package_reference/inference_client
687
-
688
- Parameters:
689
- ----------
690
- model : str
691
- the model name exactly as shown in Huggingface repo
692
- token : str, Optional
693
- the Huggingface token. If None, will use the token in os.environ['HF_TOKEN'].
694
- base_url : str, Optional
695
- the base url for the LLM server. If None, will use the default Huggingface Hub URL.
696
- api_key : str, Optional
697
- the API key for the LLM server.
698
- config : LLMConfig
699
- the LLM configuration.
700
- """
701
- if importlib.util.find_spec("huggingface_hub") is None:
702
- raise ImportError("huggingface-hub not found. Please install huggingface-hub (```pip install huggingface-hub```).")
703
-
704
- from huggingface_hub import InferenceClient, AsyncInferenceClient
705
- super().__init__(config)
706
- self.model = model
707
- self.base_url = base_url
708
- self.client = InferenceClient(model=model, token=token, base_url=base_url, api_key=api_key, **kwrs)
709
- self.client_async = AsyncInferenceClient(model=model, token=token, base_url=base_url, api_key=api_key, **kwrs)
710
- self.config = config if config else BasicLLMConfig()
711
- self.formatted_params = self._format_config()
712
-
713
- def _format_config(self) -> Dict[str, Any]:
714
- """
715
- This method formats the LLM configuration with the correct key for the inference engine.
716
- """
717
- formatted_params = self.config.params.copy()
718
- if "max_new_tokens" in formatted_params:
719
- formatted_params["max_tokens"] = formatted_params["max_new_tokens"]
720
- formatted_params.pop("max_new_tokens")
721
-
722
- return formatted_params
723
-
724
-
725
- def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False,
726
- messages_logger:MessagesLogger=None) -> Union[Dict[str,str], Generator[Dict[str, str], None, None]]:
727
- """
728
- This method inputs chat messages and outputs LLM generated text.
729
-
730
- Parameters:
731
- ----------
732
- messages : List[Dict[str,str]]
733
- a list of dict with role and content. role must be one of {"system", "user", "assistant"}
734
- verbose : bool, Optional
735
- if True, LLM generated text will be printed in terminal in real-time.
736
- stream : bool, Optional
737
- if True, returns a generator that yields the output in real-time.
738
- messages_logger : MessagesLogger, Optional
739
- the message logger that logs the chat messages.
740
-
741
- Returns:
742
- -------
743
- response : Union[Dict[str,str], Generator[Dict[str, str], None, None]]
744
- a dict {"reasoning": <reasoning>, "response": <response>} or Generator {"type": <reasoning or response>, "data": <content>}
745
- """
746
- processed_messages = self.config.preprocess_messages(messages)
747
-
748
- if stream:
749
- def _stream_generator():
750
- response_stream = self.client.chat.completions.create(
751
- messages=processed_messages,
752
- stream=True,
753
- **self.formatted_params
754
- )
755
- res_text = ""
756
- for chunk in response_stream:
757
- content_chunk = chunk.get('choices')[0].get('delta').get('content')
758
- if content_chunk:
759
- res_text += content_chunk
760
- yield content_chunk
761
-
762
- # Postprocess response
763
- res_dict = self.config.postprocess_response(res_text)
764
- # Write to messages log
765
- if messages_logger:
766
- processed_messages.append({"role": "assistant",
767
- "content": res_dict.get("response", ""),
768
- "reasoning": res_dict.get("reasoning", "")})
769
- messages_logger.log_messages(processed_messages)
770
-
771
- return self.config.postprocess_response(_stream_generator())
772
-
773
- elif verbose:
774
- response = self.client.chat.completions.create(
775
- messages=processed_messages,
776
- stream=True,
777
- **self.formatted_params
778
- )
779
-
780
- res = ''
781
- for chunk in response:
782
- content_chunk = chunk.get('choices')[0].get('delta').get('content')
783
- if content_chunk:
784
- res += content_chunk
785
- print(content_chunk, end='', flush=True)
786
-
787
-
788
- else:
789
- response = self.client.chat.completions.create(
790
- messages=processed_messages,
791
- stream=False,
792
- **self.formatted_params
793
- )
794
- res = response.choices[0].message.content
795
-
796
- # Postprocess response
797
- res_dict = self.config.postprocess_response(res)
798
- # Write to messages log
799
- if messages_logger:
800
- processed_messages.append({"role": "assistant",
801
- "content": res_dict.get("response", ""),
802
- "reasoning": res_dict.get("reasoning", "")})
803
- messages_logger.log_messages(processed_messages)
804
-
805
- return res_dict
806
-
807
-
808
- async def chat_async(self, messages:List[Dict[str,str]], messages_logger:MessagesLogger=None) -> Dict[str,str]:
809
- """
810
- Async version of chat method. Streaming is not supported.
811
- """
812
- processed_messages = self.config.preprocess_messages(messages)
813
-
814
- response = await self.client_async.chat.completions.create(
815
- messages=processed_messages,
816
- stream=False,
817
- **self.formatted_params
818
- )
819
-
820
- res = response.choices[0].message.content
821
- # Postprocess response
822
- res_dict = self.config.postprocess_response(res)
823
- # Write to messages log
824
- if messages_logger:
825
- processed_messages.append({"role": "assistant",
826
- "content": res_dict.get("response", ""),
827
- "reasoning": res_dict.get("reasoning", "")})
828
- messages_logger.log_messages(processed_messages)
829
-
830
- return res_dict
831
-
832
-
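A usage sketch, assuming `HF_TOKEN` is set and the (placeholder) model is reachable through the inference endpoint configured for the client:

```python
from llm_ie.engines import HuggingFaceHubInferenceEngine, BasicLLMConfig

# Placeholder model id; requires a valid Hugging Face token in the environment.
engine = HuggingFaceHubInferenceEngine(model="meta-llama/Llama-3.1-8B-Instruct",
                                       config=BasicLLMConfig(max_new_tokens=256))
print(engine.chat([{"role": "user", "content": "Hello!"}])["response"])
```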
833
- class OpenAICompatibleInferenceEngine(InferenceEngine):
834
- def __init__(self, model:str, api_key:str, base_url:str, config:LLMConfig=None, **kwrs):
835
- """
836
- General OpenAI-compatible server inference engine.
837
- https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html
838
-
839
- For parameters and documentation, refer to https://platform.openai.com/docs/api-reference/introduction
840
-
841
- Parameters:
842
- ----------
843
- model : str
844
- the model name as served by the OpenAI-compatible server
845
- api_key : str
846
- the API key for the OpenAI-compatible server.
847
- base_url : str
848
- the base URL for the OpenAI-compatible server.
849
- config : LLMConfig
850
- the LLM configuration.
851
- """
852
- if importlib.util.find_spec("openai") is None:
853
- raise ImportError("OpenAI Python API library not found. Please install the OpenAI Python library (```pip install openai```).")
854
-
855
- from openai import OpenAI, AsyncOpenAI
856
- from openai.types.chat import ChatCompletionChunk
857
- self.ChatCompletionChunk = ChatCompletionChunk
858
- super().__init__(config)
859
- self.client = OpenAI(api_key=api_key, base_url=base_url, **kwrs)
860
- self.async_client = AsyncOpenAI(api_key=api_key, base_url=base_url, **kwrs)
861
- self.model = model
862
- self.config = config if config else BasicLLMConfig()
863
- self.formatted_params = self._format_config()
864
-
865
- def _format_config(self) -> Dict[str, Any]:
866
- """
867
- This method formats the LLM configuration with the correct key for the inference engine.
868
- """
869
- formatted_params = self.config.params.copy()
870
- if "max_new_tokens" in formatted_params:
871
- formatted_params["max_completion_tokens"] = formatted_params["max_new_tokens"]
872
- formatted_params.pop("max_new_tokens")
873
-
874
- return formatted_params
875
-
876
- @abc.abstractmethod
877
- def _format_response(self, response: Any) -> Dict[str, str]:
878
- """
879
- This method formats the response from the OpenAI API into a dict with keys "type" and "data".
880
-
881
- Parameters:
882
- ----------
883
- response : Any
884
- the response from OpenAI-compatible API. Could be a dict, generator, or object.
885
- """
886
- return NotImplemented
887
-
888
- def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False, messages_logger:MessagesLogger=None) -> Union[Dict[str, str], Generator[Dict[str, str], None, None]]:
889
- """
890
- This method inputs chat messages and outputs LLM generated text.
891
-
892
- Parameters:
893
- ----------
894
- messages : List[Dict[str,str]]
895
- a list of dict with role and content. role must be one of {"system", "user", "assistant"}
896
- verbose : bool, Optional
897
- if True, LLM generated text will be printed in terminal in real-time.
898
- stream : bool, Optional
899
- if True, returns a generator that yields the output in real-time.
900
- messages_logger : MessagesLogger, Optional
901
- the message logger that logs the chat messages.
902
-
903
- Returns:
904
- -------
905
- response : Union[Dict[str,str], Generator[Dict[str, str], None, None]]
906
- a dict {"reasoning": <reasoning>, "response": <response>} or Generator {"type": <reasoning or response>, "data": <content>}
907
- """
908
- processed_messages = self.config.preprocess_messages(messages)
909
-
910
- if stream:
911
- def _stream_generator():
912
- response_stream = self.client.chat.completions.create(
913
- model=self.model,
914
- messages=processed_messages,
915
- stream=True,
916
- **self.formatted_params
917
- )
918
- res_text = ""
919
- for chunk in response_stream:
920
- if len(chunk.choices) > 0:
921
- chunk_dict = self._format_response(chunk)
922
- yield chunk_dict
923
-
924
- res_text += chunk_dict["data"]
925
- if chunk.choices[0].finish_reason == "length":
926
- warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
927
-
928
- # Postprocess response
929
- res_dict = self.config.postprocess_response(res_text)
930
- # Write to messages log
931
- if messages_logger:
932
- processed_messages.append({"role": "assistant",
933
- "content": res_dict.get("response", ""),
934
- "reasoning": res_dict.get("reasoning", "")})
935
- messages_logger.log_messages(processed_messages)
936
-
937
- return self.config.postprocess_response(_stream_generator())
938
-
939
- elif verbose:
940
- response = self.client.chat.completions.create(
941
- model=self.model,
942
- messages=processed_messages,
943
- stream=True,
944
- **self.formatted_params
945
- )
946
- res = {"reasoning": "", "response": ""}
947
- phase = ""
948
- for chunk in response:
949
- if len(chunk.choices) > 0:
950
- chunk_dict = self._format_response(chunk)
951
- chunk_text = chunk_dict["data"]
952
- res[chunk_dict["type"]] += chunk_text
953
- if phase != chunk_dict["type"] and chunk_text != "":
954
- print(f"\n--- {chunk_dict['type'].capitalize()} ---")
955
- phase = chunk_dict["type"]
956
-
957
- print(chunk_text, end="", flush=True)
958
- if chunk.choices[0].finish_reason == "length":
959
- warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
960
-
961
- print('\n')
962
-
963
- else:
964
- response = self.client.chat.completions.create(
965
- model=self.model,
966
- messages=processed_messages,
967
- stream=False,
968
- **self.formatted_params
969
- )
970
- res = self._format_response(response)
971
-
972
- if response.choices[0].finish_reason == "length":
973
- warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
974
-
975
- # Postprocess response
976
- res_dict = self.config.postprocess_response(res)
977
- # Write to messages log
978
- if messages_logger:
979
- processed_messages.append({"role": "assistant",
980
- "content": res_dict.get("response", ""),
981
- "reasoning": res_dict.get("reasoning", "")})
982
- messages_logger.log_messages(processed_messages)
983
-
984
- return res_dict
985
-
986
-
987
- async def chat_async(self, messages:List[Dict[str,str]], messages_logger:MessagesLogger=None) -> Dict[str,str]:
988
- """
989
- Async version of chat method. Streaming is not supported.
990
- """
991
- processed_messages = self.config.preprocess_messages(messages)
992
-
993
- response = await self.async_client.chat.completions.create(
994
- model=self.model,
995
- messages=processed_messages,
996
- stream=False,
997
- **self.formatted_params
998
- )
999
-
1000
- if response.choices[0].finish_reason == "length":
1001
- warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
1002
-
1003
- res = self._format_response(response)
1004
-
1005
- # Postprocess response
1006
- res_dict = self.config.postprocess_response(res)
1007
- # Write to messages log
1008
- if messages_logger:
1009
- processed_messages.append({"role": "assistant",
1010
- "content": res_dict.get("response", ""),
1011
- "reasoning": res_dict.get("reasoning", "")})
1012
- messages_logger.log_messages(processed_messages)
1013
-
1014
- return res_dict
1015
-
1016
-
1017
- class VLLMInferenceEngine(OpenAICompatibleInferenceEngine):
1018
- def __init__(self, model:str, api_key:str="", base_url:str="http://localhost:8000/v1", config:LLMConfig=None, **kwrs):
1019
- """
1020
- vLLM OpenAI compatible server inference engine.
1021
- https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html
1022
-
1023
- For parameters and documentation, refer to https://platform.openai.com/docs/api-reference/introduction
1024
-
1025
- Parameters:
1026
- ----------
1027
- model : str
1028
- model name as shown in the vLLM server
1029
- api_key : str, Optional
1030
- the API key for the vLLM server.
1031
- base_url : str, Optional
1032
- the base url for the vLLM server.
1033
- config : LLMConfig
1034
- the LLM configuration.
1035
- """
1036
- super().__init__(model, api_key, base_url, config, **kwrs)
1037
-
1038
-
1039
- def _format_response(self, response: Any) -> Dict[str, str]:
1040
- """
1041
- This method formats the response from the OpenAI API into a dict with keys "type" and "data".
1042
-
1043
- Parameters:
1044
- ----------
1045
- response : Any
1046
- the response from OpenAI-compatible API. Could be a dict, generator, or object.
1047
- """
1048
- if isinstance(response, self.ChatCompletionChunk):
1049
- if hasattr(response.choices[0].delta, "reasoning_content") and getattr(response.choices[0].delta, "reasoning_content") is not None:
1050
- chunk_text = getattr(response.choices[0].delta, "reasoning_content", "")
1051
- if chunk_text is None:
1052
- chunk_text = ""
1053
- return {"type": "reasoning", "data": chunk_text}
1054
- else:
1055
- chunk_text = getattr(response.choices[0].delta, "content", "")
1056
- if chunk_text is None:
1057
- chunk_text = ""
1058
- return {"type": "response", "data": chunk_text}
1059
-
1060
- return {"reasoning": getattr(response.choices[0].message, "reasoning_content", ""),
1061
- "response": getattr(response.choices[0].message, "content", "")}
1062
-
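A usage sketch for `VLLMInferenceEngine`, assuming a local vLLM OpenAI-compatible server; the model name is a placeholder, and `reasoning` is only populated when the served model emits `reasoning_content`:

```python
from llm_ie.engines import VLLMInferenceEngine, ReasoningLLMConfig

# Placeholder model name; expects a vLLM server at localhost:8000.
engine = VLLMInferenceEngine(model="Qwen/Qwen3-8B",
                             base_url="http://localhost:8000/v1",
                             config=ReasoningLLMConfig())
result = engine.chat([{"role": "user", "content": "List three symptoms of flu."}])
print(result["reasoning"])
print(result["response"])
```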
1063
- class SGLangInferenceEngine(OpenAICompatibleInferenceEngine):
1064
- def __init__(self, model:str, api_key:str="", base_url:str="http://localhost:30000/v1", config:LLMConfig=None, **kwrs):
1065
- """
1066
- SGLang OpenAI compatible API inference engine.
1067
- https://docs.sglang.ai/basic_usage/openai_api.html
1068
-
1069
- Parameters:
1070
- ----------
1071
- model : str
1072
- the model name as shown in the SGLang server
1073
- api_key : str, Optional
1074
- the API key for the SGLang server.
1075
- base_url : str, Optional
1076
- the base URL for the SGLang server.
1077
- config : LLMConfig
1078
- the LLM configuration.
1079
- """
1080
- super().__init__(model, api_key, base_url, config, **kwrs)
1081
-
1082
-
1083
- def _format_response(self, response: Any) -> Dict[str, str]:
1084
- """
1085
- This method formats the response from the OpenAI API into a dict with keys "type" and "data".
1086
-
1087
- Parameters:
1088
- ----------
1089
- response : Any
1090
- the response from OpenAI-compatible API. Could be a dict, generator, or object.
1091
- """
1092
- if isinstance(response, self.ChatCompletionChunk):
1093
- if hasattr(response.choices[0].delta, "reasoning_content") and getattr(response.choices[0].delta, "reasoning_content") is not None:
1094
- chunk_text = getattr(response.choices[0].delta, "reasoning_content", "")
1095
- if chunk_text is None:
1096
- chunk_text = ""
1097
- return {"type": "reasoning", "data": chunk_text}
1098
- else:
1099
- chunk_text = getattr(response.choices[0].delta, "content", "")
1100
- if chunk_text is None:
1101
- chunk_text = ""
1102
- return {"type": "response", "data": chunk_text}
1103
-
1104
- return {"reasoning": getattr(response.choices[0].message, "reasoning_content", ""),
1105
- "response": getattr(response.choices[0].message, "content", "")}
1106
-
1107
-
1108
- class OpenRouterInferenceEngine(OpenAICompatibleInferenceEngine):
1109
- def __init__(self, model:str, api_key:str=None, base_url:str="https://openrouter.ai/api/v1", config:LLMConfig=None, **kwrs):
1110
- """
1111
- OpenRouter OpenAI-compatible server inference engine.
1112
-
1113
- Parameters:
1114
- ----------
1115
- model : str
1116
- the model name as listed on OpenRouter
1117
- api_key : str, Optional
1118
- the API key for OpenRouter. If None, will use the key in os.environ['OPENROUTER_API_KEY'].
1119
- base_url : str, Optional
1120
- the base URL for the OpenRouter API.
1121
- config : LLMConfig
1122
- the LLM configuration.
1123
- """
1124
- self.api_key = api_key
1125
- if self.api_key is None:
1126
- self.api_key = os.getenv("OPENROUTER_API_KEY")
1127
- super().__init__(model, self.api_key, base_url, config, **kwrs)
1128
-
1129
- def _format_response(self, response: Any) -> Dict[str, str]:
1130
- """
1131
- This method formats the response from the OpenAI API into a dict with keys "type" and "data".
1132
-
1133
- Parameters:
1134
- ----------
1135
- response : Any
1136
- the response from OpenAI-compatible API. Could be a dict, generator, or object.
1137
- """
1138
- if isinstance(response, self.ChatCompletionChunk):
1139
- if hasattr(response.choices[0].delta, "reasoning") and getattr(response.choices[0].delta, "reasoning") is not None:
1140
- chunk_text = getattr(response.choices[0].delta, "reasoning", "")
1141
- if chunk_text is None:
1142
- chunk_text = ""
1143
- return {"type": "reasoning", "data": chunk_text}
1144
- else:
1145
- chunk_text = getattr(response.choices[0].delta, "content", "")
1146
- if chunk_text is None:
1147
- chunk_text = ""
1148
- return {"type": "response", "data": chunk_text}
1149
-
1150
- return {"reasoning": getattr(response.choices[0].message, "reasoning", ""),
1151
- "response": getattr(response.choices[0].message, "content", "")}
1152
-
1153
-
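A usage sketch; when `api_key` is omitted the constructor falls back to `os.environ['OPENROUTER_API_KEY']` (the model slug is a placeholder):

```python
import os
from llm_ie.engines import OpenRouterInferenceEngine, BasicLLMConfig

os.environ.setdefault("OPENROUTER_API_KEY", "<your-key>")  # or pass api_key=...
engine = OpenRouterInferenceEngine(model="meta-llama/llama-3.1-8b-instruct",
                                   config=BasicLLMConfig())
print(engine.chat([{"role": "user", "content": "Hello"}])["response"])
```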
1154
- class OpenAIInferenceEngine(InferenceEngine):
1155
- def __init__(self, model:str, config:LLMConfig=None, **kwrs):
1156
- """
1157
- The OpenAI API inference engine.
1158
- For parameters and documentation, refer to https://platform.openai.com/docs/api-reference/introduction
1159
-
1160
- Parameters:
1161
- ----------
1162
- model : str
1163
- model name as described in https://platform.openai.com/docs/models
1164
- """
1165
- if importlib.util.find_spec("openai") is None:
1166
- raise ImportError("OpenAI Python API library not found. Please install the OpenAI Python library (```pip install openai```).")
1167
-
1168
- from openai import OpenAI, AsyncOpenAI
1169
- super().__init__(config)
1170
- self.client = OpenAI(**kwrs)
1171
- self.async_client = AsyncOpenAI(**kwrs)
1172
- self.model = model
1173
- self.config = config if config else BasicLLMConfig()
1174
- self.formatted_params = self._format_config()
1175
-
1176
- def _format_config(self) -> Dict[str, Any]:
1177
- """
1178
- This method formats the LLM configuration with the correct key for the inference engine.
1179
- """
1180
- formatted_params = self.config.params.copy()
1181
- if "max_new_tokens" in formatted_params:
1182
- formatted_params["max_completion_tokens"] = formatted_params["max_new_tokens"]
1183
- formatted_params.pop("max_new_tokens")
1184
-
1185
- return formatted_params
1186
-
1187
- def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False, messages_logger:MessagesLogger=None) -> Union[Dict[str, str], Generator[Dict[str, str], None, None]]:
1188
- """
1189
- This method inputs chat messages and outputs LLM generated text.
1190
-
1191
- Parameters:
1192
- ----------
1193
- messages : List[Dict[str,str]]
1194
- a list of dict with role and content. role must be one of {"system", "user", "assistant"}
1195
- verbose : bool, Optional
1196
- if True, LLM generated text will be printed in terminal in real-time.
1197
- stream : bool, Optional
1198
- if True, returns a generator that yields the output in real-time.
1199
- messages_logger : MessagesLogger, Optional
1200
- the message logger that logs the chat messages.
1201
-
1202
- Returns:
1203
- -------
1204
- response : Union[Dict[str,str], Generator[Dict[str, str], None, None]]
1205
- a dict {"reasoning": <reasoning>, "response": <response>} or Generator {"type": <reasoning or response>, "data": <content>}
1206
- """
1207
- processed_messages = self.config.preprocess_messages(messages)
1208
-
1209
- if stream:
1210
- def _stream_generator():
1211
- response_stream = self.client.chat.completions.create(
1212
- model=self.model,
1213
- messages=processed_messages,
1214
- stream=True,
1215
- **self.formatted_params
1216
- )
1217
- res_text = ""
1218
- for chunk in response_stream:
1219
- if len(chunk.choices) > 0:
1220
- chunk_text = chunk.choices[0].delta.content
1221
- if chunk_text is not None:
1222
- res_text += chunk_text
1223
- yield chunk_text
1224
- if chunk.choices[0].finish_reason == "length":
1225
- warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
1226
-
1227
- # Postprocess response
1228
- res_dict = self.config.postprocess_response(res_text)
1229
- # Write to messages log
1230
- if messages_logger:
1231
- processed_messages.append({"role": "assistant",
1232
- "content": res_dict.get("response", ""),
1233
- "reasoning": res_dict.get("reasoning", "")})
1234
- messages_logger.log_messages(processed_messages)
1235
-
1236
- return self.config.postprocess_response(_stream_generator())
1237
-
1238
- elif verbose:
1239
- response = self.client.chat.completions.create(
1240
- model=self.model,
1241
- messages=processed_messages,
1242
- stream=True,
1243
- **self.formatted_params
1244
- )
1245
- res = ''
1246
- for chunk in response:
1247
- if len(chunk.choices) > 0:
1248
- if chunk.choices[0].delta.content is not None:
1249
- res += chunk.choices[0].delta.content
1250
- print(chunk.choices[0].delta.content, end="", flush=True)
1251
- if chunk.choices[0].finish_reason == "length":
1252
- warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
1253
-
1254
- print('\n')
1255
-
1256
- else:
1257
- response = self.client.chat.completions.create(
1258
- model=self.model,
1259
- messages=processed_messages,
1260
- stream=False,
1261
- **self.formatted_params
1262
- )
1263
- res = response.choices[0].message.content
1264
-
1265
- # Postprocess response
1266
- res_dict = self.config.postprocess_response(res)
1267
- # Write to messages log
1268
- if messages_logger:
1269
- processed_messages.append({"role": "assistant",
1270
- "content": res_dict.get("response", ""),
1271
- "reasoning": res_dict.get("reasoning", "")})
1272
- messages_logger.log_messages(processed_messages)
1273
-
1274
- return res_dict
1275
-
1276
-
1277
- async def chat_async(self, messages:List[Dict[str,str]], messages_logger:MessagesLogger=None) -> Dict[str,str]:
1278
- """
1279
- Async version of chat method. Streaming is not supported.
1280
- """
1281
- processed_messages = self.config.preprocess_messages(messages)
1282
-
1283
- response = await self.async_client.chat.completions.create(
1284
- model=self.model,
1285
- messages=processed_messages,
1286
- stream=False,
1287
- **self.formatted_params
1288
- )
1289
-
1290
- if response.choices[0].finish_reason == "length":
1291
- warnings.warn("Model stopped generating due to context length limit.", RuntimeWarning)
1292
-
1293
- res = response.choices[0].message.content
1294
- # Postprocess response
1295
- res_dict = self.config.postprocess_response(res)
1296
- # Write to messages log
1297
- if messages_logger:
1298
- processed_messages.append({"role": "assistant",
1299
- "content": res_dict.get("response", ""),
1300
- "reasoning": res_dict.get("reasoning", "")})
1301
- messages_logger.log_messages(processed_messages)
1302
-
1303
- return res_dict
1304
-
1305
-
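A usage sketch pairing `OpenAIInferenceEngine` with `OpenAIReasoningLLMConfig`; the model name is a placeholder, and `OPENAI_API_KEY` is read by the underlying client:

```python
from llm_ie.engines import OpenAIInferenceEngine, OpenAIReasoningLLMConfig

# Placeholder model name; requires OPENAI_API_KEY in the environment.
engine = OpenAIInferenceEngine(model="o4-mini",
                               config=OpenAIReasoningLLMConfig(reasoning_effort="low"))
out = engine.chat([{"role": "user", "content": "Extract the date: 'Seen on 2024-01-02.'"}])
print(out["response"])
```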
1306
- class AzureOpenAIInferenceEngine(OpenAIInferenceEngine):
1307
- def __init__(self, model:str, api_version:str, config:LLMConfig=None, **kwrs):
1308
- """
1309
- The Azure OpenAI API inference engine.
1310
- For parameters and documentation, refer to
1311
- - https://azure.microsoft.com/en-us/products/ai-services/openai-service
1312
- - https://learn.microsoft.com/en-us/azure/ai-services/openai/quickstart
1313
-
1314
- Parameters:
1315
- ----------
1316
- model : str
1317
- model name as described in https://platform.openai.com/docs/models
1318
- api_version : str
1319
- the Azure OpenAI API version
1320
- config : LLMConfig
1321
- the LLM configuration.
1322
- """
1323
- if importlib.util.find_spec("openai") is None:
1324
- raise ImportError("OpenAI Python API library not found. Please install the OpenAI Python library (```pip install openai```).")
1325
-
1326
- from openai import AzureOpenAI, AsyncAzureOpenAI
1327
- self.model = model
1328
- self.api_version = api_version
1329
- self.client = AzureOpenAI(api_version=self.api_version,
1330
- **kwrs)
1331
- self.async_client = AsyncAzureOpenAI(api_version=self.api_version,
1332
- **kwrs)
1333
- self.config = config if config else BasicLLMConfig()
1334
- self.formatted_params = self._format_config()
1335
-
1336
-
1337
- class LiteLLMInferenceEngine(InferenceEngine):
1338
- def __init__(self, model:str=None, base_url:str=None, api_key:str=None, config:LLMConfig=None):
1339
- """
1340
- The LiteLLM inference engine.
1341
- For parameters and documentation, refer to https://github.com/BerriAI/litellm?tab=readme-ov-file
1342
-
1343
- Parameters:
1344
- ----------
1345
- model : str
1346
- the model name
1347
- base_url : str, Optional
1348
- the base url for the LLM server
1349
- api_key : str, Optional
1350
- the API key for the LLM server
1351
- config : LLMConfig
1352
- the LLM configuration.
1353
- """
1354
- if importlib.util.find_spec("litellm") is None:
1355
- raise ImportError("litellm not found. Please install litellm (```pip install litellm```).")
1356
-
1357
- import litellm
1358
- super().__init__(config)
1359
- self.litellm = litellm
1360
- self.model = model
1361
- self.base_url = base_url
1362
- self.api_key = api_key
1363
- self.config = config if config else BasicLLMConfig()
1364
- self.formatted_params = self._format_config()
1365
-
1366
- def _format_config(self) -> Dict[str, Any]:
1367
- """
1368
- This method format the LLM configuration with the correct key for the inference engine.
1369
- """
1370
- formatted_params = self.config.params.copy()
1371
- if "max_new_tokens" in formatted_params:
1372
- formatted_params["max_tokens"] = formatted_params["max_new_tokens"]
1373
- formatted_params.pop("max_new_tokens")
1374
-
1375
- return formatted_params
1376
-
1377
- def chat(self, messages:List[Dict[str,str]], verbose:bool=False, stream:bool=False, messages_logger:MessagesLogger=None) -> Union[Dict[str,str], Generator[Dict[str, str], None, None]]:
1378
- """
1379
- This method inputs chat messages and outputs LLM generated text.
1380
-
1381
- Parameters:
1382
- ----------
1383
- messages : List[Dict[str,str]]
1384
- a list of dict with role and content. role must be one of {"system", "user", "assistant"}
1385
- verbose : bool, Optional
1386
- if True, LLM generated text will be printed in terminal in real-time.
1387
- stream : bool, Optional
1388
- if True, returns a generator that yields the output in real-time.
1389
- messages_logger: MessagesLogger, Optional
1390
- a messages logger that logs the messages.
1391
-
1392
- Returns:
1393
- -------
1394
- response : Union[Dict[str,str], Generator[Dict[str, str], None, None]]
1395
- a dict {"reasoning": <reasoning>, "response": <response>} or Generator {"type": <reasoning or response>, "data": <content>}
1396
- """
1397
- processed_messages = self.config.preprocess_messages(messages)
1398
-
1399
- if stream:
1400
- def _stream_generator():
1401
- response_stream = self.litellm.completion(
1402
- model=self.model,
1403
- messages=processed_messages,
1404
- stream=True,
1405
- base_url=self.base_url,
1406
- api_key=self.api_key,
1407
- **self.formatted_params
1408
- )
1409
- res_text = ""
1410
- for chunk in response_stream:
1411
- chunk_content = chunk.get('choices')[0].get('delta').get('content')
1412
- if chunk_content:
1413
- res_text += chunk_content
1414
- yield chunk_content
1415
-
1416
- # Postprocess response
1417
- res_dict = self.config.postprocess_response(res_text)
1418
- # Write to messages log
1419
- if messages_logger:
1420
- processed_messages.append({"role": "assistant",
1421
- "content": res_dict.get("response", ""),
1422
- "reasoning": res_dict.get("reasoning", "")})
1423
- messages_logger.log_messages(processed_messages)
1424
-
1425
- return self.config.postprocess_response(_stream_generator())
1426
-
1427
- elif verbose:
1428
- response = self.litellm.completion(
1429
- model=self.model,
1430
- messages=processed_messages,
1431
- stream=True,
1432
- base_url=self.base_url,
1433
- api_key=self.api_key,
1434
- **self.formatted_params
1435
- )
1436
-
1437
- res = ''
1438
- for chunk in response:
1439
- chunk_content = chunk.get('choices')[0].get('delta').get('content')
1440
- if chunk_content:
1441
- res += chunk_content
1442
- print(chunk_content, end='', flush=True)
1443
-
1444
- else:
1445
- response = self.litellm.completion(
1446
- model=self.model,
1447
- messages=processed_messages,
1448
- stream=False,
1449
- base_url=self.base_url,
1450
- api_key=self.api_key,
1451
- **self.formatted_params
1452
- )
1453
- res = response.choices[0].message.content
1454
-
1455
- # Postprocess response
1456
- res_dict = self.config.postprocess_response(res)
1457
- # Write to messages log
1458
- if messages_logger:
1459
- processed_messages.append({"role": "assistant",
1460
- "content": res_dict.get("response", ""),
1461
- "reasoning": res_dict.get("reasoning", "")})
1462
- messages_logger.log_messages(processed_messages)
1463
-
1464
- return res_dict
1465
-
1466
- async def chat_async(self, messages:List[Dict[str,str]], messages_logger:MessagesLogger=None) -> Dict[str,str]:
1467
- """
1468
- Async version of chat method. Streaming is not supported.
1469
- """
1470
- processed_messages = self.config.preprocess_messages(messages)
1471
-
1472
- response = await self.litellm.acompletion(
1473
- model=self.model,
1474
- messages=processed_messages,
1475
- stream=False,
1476
- base_url=self.base_url,
1477
- api_key=self.api_key,
1478
- **self.formatted_params
27
+ """
28
+ Deprecated: This engine is no longer supported. Please run llama.cpp as a server and use OpenAICompatibleInferenceEngine instead.
29
+ """
30
+ def __init__(self, *args, **kwargs):
31
+ raise NotImplementedError(
32
+ "LlamaCppInferenceEngine has been deprecated. "
33
+ "Please run llama.cpp as a server and use OpenAICompatibleInferenceEngine."
1479
34
  )
1480
-
1481
- res = response.get('choices')[0].get('message').get('content')
1482
35
 
1483
- # Postprocess response
1484
- res_dict = self.config.postprocess_response(res)
1485
- # Write to messages log
1486
- if messages_logger:
1487
- processed_messages.append({"role": "assistant",
1488
- "content": res_dict.get("response", ""),
1489
- "reasoning": res_dict.get("reasoning", "")})
1490
- messages_logger.log_messages(processed_messages)
1491
- return res_dict
36
+ def chat(self, *args, **kwargs):
37
+ raise NotImplementedError("This engine is deprecated.")
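A migration sketch for former `LlamaCppInferenceEngine` users, following the deprecation message above: serve the GGUF model with llama.cpp's OpenAI-compatible server and point `OpenAICompatibleInferenceEngine` at it (the port, model name, and exact server flags are placeholders to adapt):

```python
# Start a llama.cpp server first, e.g. (flags may vary by llama.cpp version):
#   llama-server -m model.gguf --port 8080
from llm_ie.engines import OpenAICompatibleInferenceEngine, BasicLLMConfig

engine = OpenAICompatibleInferenceEngine(model="model.gguf",
                                         api_key="not-needed",
                                         base_url="http://localhost:8080/v1",
                                         config=BasicLLMConfig())
print(engine.chat([{"role": "user", "content": "Hello"}])["response"])
```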