yaicli 0.6.0__tar.gz → 0.6.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49) hide show
  1. {yaicli-0.6.0 → yaicli-0.6.2}/PKG-INFO +11 -3
  2. {yaicli-0.6.0 → yaicli-0.6.2}/README.md +10 -2
  3. {yaicli-0.6.0 → yaicli-0.6.2}/pyproject.toml +1 -1
  4. {yaicli-0.6.0 → yaicli-0.6.2}/yaicli/llms/provider.py +12 -9
  5. {yaicli-0.6.0 → yaicli-0.6.2}/yaicli/llms/providers/ai21_provider.py +13 -4
  6. {yaicli-0.6.0 → yaicli-0.6.2}/yaicli/llms/providers/chatglm_provider.py +6 -1
  7. yaicli-0.6.2/yaicli/llms/providers/chutes_provider.py +24 -0
  8. yaicli-0.6.2/yaicli/llms/providers/deepseek_provider.py +14 -0
  9. {yaicli-0.6.0 → yaicli-0.6.2}/yaicli/llms/providers/doubao_provider.py +24 -22
  10. yaicli-0.6.2/yaicli/llms/providers/groq_provider.py +36 -0
  11. {yaicli-0.6.0 → yaicli-0.6.2}/yaicli/llms/providers/infiniai_provider.py +7 -1
  12. yaicli-0.6.2/yaicli/llms/providers/minimax_provider.py +21 -0
  13. {yaicli-0.6.0 → yaicli-0.6.2}/yaicli/llms/providers/modelscope_provider.py +6 -3
  14. {yaicli-0.6.0 → yaicli-0.6.2}/yaicli/llms/providers/openai_provider.py +84 -26
  15. yaicli-0.6.2/yaicli/llms/providers/openrouter_provider.py +22 -0
  16. yaicli-0.6.2/yaicli/llms/providers/sambanova_provider.py +59 -0
  17. {yaicli-0.6.0 → yaicli-0.6.2}/yaicli/llms/providers/siliconflow_provider.py +6 -3
  18. yaicli-0.6.2/yaicli/llms/providers/targon_provider.py +22 -0
  19. yaicli-0.6.2/yaicli/llms/providers/xai_provider.py +7 -0
  20. yaicli-0.6.2/yaicli/llms/providers/yi_provider.py +22 -0
  21. yaicli-0.6.0/yaicli/llms/providers/chutes_provider.py +0 -7
  22. yaicli-0.6.0/yaicli/llms/providers/deepseek_provider.py +0 -11
  23. yaicli-0.6.0/yaicli/llms/providers/groq_provider.py +0 -14
  24. yaicli-0.6.0/yaicli/llms/providers/openrouter_provider.py +0 -11
  25. yaicli-0.6.0/yaicli/llms/providers/sambanova_provider.py +0 -28
  26. yaicli-0.6.0/yaicli/llms/providers/yi_provider.py +0 -7
  27. {yaicli-0.6.0 → yaicli-0.6.2}/.gitignore +0 -0
  28. {yaicli-0.6.0 → yaicli-0.6.2}/LICENSE +0 -0
  29. {yaicli-0.6.0 → yaicli-0.6.2}/yaicli/__init__.py +0 -0
  30. {yaicli-0.6.0 → yaicli-0.6.2}/yaicli/chat.py +0 -0
  31. {yaicli-0.6.0 → yaicli-0.6.2}/yaicli/cli.py +0 -0
  32. {yaicli-0.6.0 → yaicli-0.6.2}/yaicli/config.py +0 -0
  33. {yaicli-0.6.0 → yaicli-0.6.2}/yaicli/console.py +0 -0
  34. {yaicli-0.6.0 → yaicli-0.6.2}/yaicli/const.py +0 -0
  35. {yaicli-0.6.0 → yaicli-0.6.2}/yaicli/entry.py +0 -0
  36. {yaicli-0.6.0 → yaicli-0.6.2}/yaicli/exceptions.py +0 -0
  37. {yaicli-0.6.0 → yaicli-0.6.2}/yaicli/functions/__init__.py +0 -0
  38. {yaicli-0.6.0 → yaicli-0.6.2}/yaicli/functions/buildin/execute_shell_command.py +0 -0
  39. {yaicli-0.6.0 → yaicli-0.6.2}/yaicli/history.py +0 -0
  40. {yaicli-0.6.0 → yaicli-0.6.2}/yaicli/llms/__init__.py +0 -0
  41. {yaicli-0.6.0 → yaicli-0.6.2}/yaicli/llms/client.py +0 -0
  42. {yaicli-0.6.0 → yaicli-0.6.2}/yaicli/llms/providers/cohere_provider.py +0 -0
  43. {yaicli-0.6.0 → yaicli-0.6.2}/yaicli/llms/providers/ollama_provider.py +0 -0
  44. {yaicli-0.6.0 → yaicli-0.6.2}/yaicli/printer.py +0 -0
  45. {yaicli-0.6.0 → yaicli-0.6.2}/yaicli/render.py +0 -0
  46. {yaicli-0.6.0 → yaicli-0.6.2}/yaicli/role.py +0 -0
  47. {yaicli-0.6.0 → yaicli-0.6.2}/yaicli/schemas.py +0 -0
  48. {yaicli-0.6.0 → yaicli-0.6.2}/yaicli/tools.py +0 -0
  49. {yaicli-0.6.0 → yaicli-0.6.2}/yaicli/utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: yaicli
3
- Version: 0.6.0
3
+ Version: 0.6.2
4
4
  Summary: A simple CLI tool to interact with LLM
5
5
  Project-URL: Homepage, https://github.com/belingud/yaicli
6
6
  Project-URL: Repository, https://github.com/belingud/yaicli
@@ -375,7 +375,7 @@ settings, just as below:
375
375
  ```ini
376
376
  [core]
377
377
  PROVIDER=openai
378
- BASE_URL=https://api.openai.com/v1
378
+ BASE_URL=
379
379
  API_KEY=
380
380
  MODEL=gpt-4o
381
381
 
@@ -387,7 +387,7 @@ OS_NAME=auto
387
387
  STREAM=true
388
388
 
389
389
  # LLM parameters
390
- TEMPERATURE=0.5
390
+ TEMPERATURE=0.3
391
391
  TOP_P=1.0
392
392
  MAX_TOKENS=1024
393
393
  TIMEOUT=60
@@ -518,6 +518,14 @@ API_KEY=
518
518
  MODEL=llama-3.3-70b-versatile
519
519
  ```
520
520
 
521
+ #### XAI
522
+
523
+ ```ini
524
+ PROVIDER=xai
525
+ API_KEY=
526
+ MODEL=grok-3
527
+ ```
528
+
521
529
  #### Chatglm
522
530
 
523
531
  ```ini
@@ -138,7 +138,7 @@ settings, just as below:
138
138
  ```ini
139
139
  [core]
140
140
  PROVIDER=openai
141
- BASE_URL=https://api.openai.com/v1
141
+ BASE_URL=
142
142
  API_KEY=
143
143
  MODEL=gpt-4o
144
144
 
@@ -150,7 +150,7 @@ OS_NAME=auto
150
150
  STREAM=true
151
151
 
152
152
  # LLM parameters
153
- TEMPERATURE=0.5
153
+ TEMPERATURE=0.3
154
154
  TOP_P=1.0
155
155
  MAX_TOKENS=1024
156
156
  TIMEOUT=60
@@ -281,6 +281,14 @@ API_KEY=
281
281
  MODEL=llama-3.3-70b-versatile
282
282
  ```
283
283
 
284
+ #### XAI
285
+
286
+ ```ini
287
+ PROVIDER=xai
288
+ API_KEY=
289
+ MODEL=grok-3
290
+ ```
291
+
284
292
  #### Chatglm
285
293
 
286
294
  ```ini
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "yaicli"
3
- version = "0.6.0"
3
+ version = "0.6.2"
4
4
  description = "A simple CLI tool to interact with LLM"
5
5
  authors = [{ name = "belingud", email = "im.victor@qq.com" }]
6
6
  readme = "README.md"
@@ -9,7 +9,7 @@ class Provider(ABC):
9
9
  """Base abstract class for LLM providers"""
10
10
 
11
11
  APP_NAME = "yaicli"
12
- APPA_REFERER = "https://github.com/halfrost/yaicli"
12
+ APP_REFERER = "https://github.com/halfrost/yaicli"
13
13
 
14
14
  @abstractmethod
15
15
  def completion(
@@ -39,21 +39,24 @@ class ProviderFactory:
39
39
  """Factory to create LLM provider instances"""
40
40
 
41
41
  providers_map = {
42
- "openai": (".providers.openai_provider", "OpenAIProvider"),
43
- "modelscope": (".providers.modelscope_provider", "ModelScopeProvider"),
42
+ "ai21": (".providers.ai21_provider", "AI21Provider"),
44
43
  "chatglm": (".providers.chatglm_provider", "ChatglmProvider"),
45
- "openrouter": (".providers.openrouter_provider", "OpenRouterProvider"),
46
- "siliconflow": (".providers.siliconflow_provider", "SiliconFlowProvider"),
47
44
  "chutes": (".providers.chutes_provider", "ChutesProvider"),
48
- "infini-ai": (".providers.infiniai_provider", "InfiniAIProvider"),
49
- "yi": (".providers.yi_provider", "YiProvider"),
45
+ "cohere": (".providers.cohere_provider", "CohereProvider"),
50
46
  "deepseek": (".providers.deepseek_provider", "DeepSeekProvider"),
51
47
  "doubao": (".providers.doubao_provider", "DoubaoProvider"),
52
48
  "groq": (".providers.groq_provider", "GroqProvider"),
53
- "ai21": (".providers.ai21_provider", "AI21Provider"),
49
+ "infini-ai": (".providers.infiniai_provider", "InfiniAIProvider"),
50
+ "minimax": (".providers.minimax_provider", "MinimaxProvider"),
51
+ "modelscope": (".providers.modelscope_provider", "ModelScopeProvider"),
54
52
  "ollama": (".providers.ollama_provider", "OllamaProvider"),
55
- "cohere": (".providers.cohere_provider", "CohereProvider"),
53
+ "openai": (".providers.openai_provider", "OpenAIProvider"),
54
+ "openrouter": (".providers.openrouter_provider", "OpenRouterProvider"),
56
55
  "sambanova": (".providers.sambanova_provider", "SambanovaProvider"),
56
+ "siliconflow": (".providers.siliconflow_provider", "SiliconFlowProvider"),
57
+ "targon": (".providers.targon_provider", "TargonProvider"),
58
+ "xai": (".providers.xai_provider", "XaiProvider"),
59
+ "yi": (".providers.yi_provider", "YiProvider"),
57
60
  }
58
61
 
59
62
  @classmethod
@@ -1,4 +1,4 @@
1
- from typing import Generator, Optional
1
+ from typing import Dict, Generator, Optional
2
2
 
3
3
  from openai._streaming import Stream
4
4
  from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
@@ -12,9 +12,18 @@ class AI21Provider(OpenAIProvider):
12
12
 
13
13
  DEFAULT_BASE_URL = "https://api.ai21.com/studio/v1"
14
14
 
15
- def __init__(self, config: dict = ..., **kwargs):
16
- super().__init__(config, **kwargs)
17
- self.completion_params["max_tokens"] = self.completion_params.pop("max_completion_tokens")
15
+ def get_completion_params_keys(self) -> Dict[str, str]:
16
+ """
17
+ Customize completion parameter keys for AI21 API.
18
+ Maps 'max_completion_tokens' to 'max_tokens' for compatibility.
19
+
20
+ Returns:
21
+ Dict[str, str]: Modified parameter mapping dictionary
22
+ """
23
+ keys = super().get_completion_params_keys()
24
+ if "max_completion_tokens" in keys:
25
+ keys["max_tokens"] = keys.pop("max_completion_tokens")
26
+ return keys
18
27
 
19
28
  def _handle_stream_response(self, response: Stream[ChatCompletionChunk]) -> Generator[LLMResponse, None, None]:
20
29
  """Handle streaming response from AI21 models
@@ -1,5 +1,5 @@
1
1
  import json
2
- from typing import Generator, Optional
2
+ from typing import Any, Dict, Generator, Optional
3
3
 
4
4
  from openai._streaming import Stream
5
5
  from openai.types.chat.chat_completion import ChatCompletion, Choice
@@ -14,6 +14,11 @@ class ChatglmProvider(OpenAIProvider):
14
14
 
15
15
  DEFAULT_BASE_URL = "https://open.bigmodel.cn/api/paas/v4/"
16
16
 
17
+ def get_completion_params(self) -> Dict[str, Any]:
18
+ params = super().get_completion_params()
19
+ params["max_tokens"] = params.pop("max_completion_tokens")
20
+ return params
21
+
17
22
  def _handle_normal_response(self, response: ChatCompletion) -> Generator[LLMResponse, None, None]:
18
23
  """Handle normal (non-streaming) response
19
24
  Support both openai capabilities and chatglm
@@ -0,0 +1,24 @@
1
+ from .openai_provider import OpenAIProvider
2
+
3
+
4
+ class ChutesProvider(OpenAIProvider):
5
+ """Chutes provider implementation based on openai-compatible API"""
6
+
7
+ DEFAULT_BASE_URL = "https://llm.chutes.ai/v1"
8
+
9
+ def get_completion_params_keys(self) -> dict:
10
+ """
11
+ Customize completion parameter keys for Chutes API.
12
+ Maps 'max_completion_tokens' to 'max_tokens' and removes 'reasoning_effort'
13
+ which is not supported by this provider.
14
+
15
+ Returns:
16
+ dict: Modified parameter mapping dictionary
17
+ """
18
+ keys = super().get_completion_params_keys()
19
+ # Replace max_completion_tokens with max_tokens in the API
20
+ if "max_completion_tokens" in keys:
21
+ keys["max_tokens"] = keys.pop("max_completion_tokens")
22
+ # Remove unsupported parameters
23
+ keys.pop("reasoning_effort", None)
24
+ return keys
@@ -0,0 +1,14 @@
1
+ from typing import Any, Dict
2
+
3
+ from .openai_provider import OpenAIProvider
4
+
5
+
6
+ class DeepSeekProvider(OpenAIProvider):
7
+ """DeepSeek provider implementation based on openai-compatible API"""
8
+
9
+ DEFAULT_BASE_URL = "https://api.deepseek.com/v1"
10
+
11
+ def get_completion_params(self) -> Dict[str, Any]:
12
+ params = super().get_completion_params()
13
+ params["max_tokens"] = params.pop("max_completion_tokens")
14
+ return params
@@ -1,3 +1,5 @@
1
+ from typing import Any, Dict
2
+
1
3
  from volcenginesdkarkruntime import Ark
2
4
 
3
5
  from ...config import cfg
@@ -9,43 +11,43 @@ class DoubaoProvider(OpenAIProvider):
9
11
  """Doubao provider implementation based on openai-compatible API"""
10
12
 
11
13
  DEFAULT_BASE_URL = "https://ark.cn-beijing.volces.com/api/v3"
14
+ CLIENT_CLS = Ark
12
15
 
13
16
  def __init__(self, config: dict = cfg, **kwargs):
14
17
  self.config = config
15
18
  self.enable_function = self.config["ENABLE_FUNCTIONS"]
19
+ self.client_params = self.get_client_params()
20
+
21
+ # Initialize client
22
+ self.client = self.CLIENT_CLS(**self.client_params)
23
+ self.console = get_console()
24
+
25
+ # Store completion params
26
+ self.completion_params = self.get_completion_params()
27
+
28
+ def get_client_params(self) -> Dict[str, Any]:
16
29
  # Initialize client params
17
- self.client_params = {"base_url": self.DEFAULT_BASE_URL}
30
+ client_params = {"base_url": self.DEFAULT_BASE_URL}
18
31
  if self.config.get("API_KEY", None):
19
- self.client_params["api_key"] = self.config["API_KEY"]
32
+ client_params["api_key"] = self.config["API_KEY"]
20
33
  if self.config.get("BASE_URL", None):
21
- self.client_params["base_url"] = self.config["BASE_URL"]
34
+ client_params["base_url"] = self.config["BASE_URL"]
22
35
  if self.config.get("AK", None):
23
- self.client_params["ak"] = self.config["AK"]
36
+ client_params["ak"] = self.config["AK"]
24
37
  if self.config.get("SK", None):
25
- self.client_params["sk"] = self.config["SK"]
38
+ client_params["sk"] = self.config["SK"]
26
39
  if self.config.get("REGION", None):
27
- self.client_params["region"] = self.config["REGION"]
40
+ client_params["region"] = self.config["REGION"]
41
+ return client_params
28
42
 
29
- # Initialize client
30
- self.client = Ark(**self.client_params)
31
- self.console = get_console()
32
-
33
- # Store completion params
34
- self.completion_params = {
43
+ def get_completion_params(self) -> Dict[str, Any]:
44
+ params = {
35
45
  "model": self.config["MODEL"],
36
46
  "temperature": self.config["TEMPERATURE"],
37
47
  "top_p": self.config["TOP_P"],
38
48
  "max_tokens": self.config["MAX_TOKENS"],
39
49
  "timeout": self.config["TIMEOUT"],
40
50
  }
41
- # Add extra headers if set
42
- if self.config.get("EXTRA_HEADERS", None):
43
- self.completion_params["extra_headers"] = {
44
- **self.config["EXTRA_HEADERS"],
45
- "X-Title": self.APP_NAME,
46
- "HTTP-Referer": self.APPA_REFERER,
47
- }
48
-
49
- # Add extra body params if set
50
51
  if self.config.get("EXTRA_BODY", None):
51
- self.completion_params["extra_body"] = self.config["EXTRA_BODY"]
52
+ params["extra_body"] = self.config["EXTRA_BODY"]
53
+ return params
@@ -0,0 +1,36 @@
1
+ from typing import Any, Dict
2
+
3
+ from .openai_provider import OpenAIProvider
4
+
5
+
6
+ class GroqProvider(OpenAIProvider):
7
+ """Groq provider implementation based on openai-compatible API"""
8
+
9
+ DEFAULT_BASE_URL = "https://api.groq.com/openai/v1"
10
+
11
+ def get_completion_params_keys(self) -> Dict[str, str]:
12
+ """
13
+ Customize completion parameter keys for Groq API.
14
+ Maps 'max_completion_tokens' to 'max_tokens' for compatibility.
15
+
16
+ Returns:
17
+ Dict[str, str]: Modified parameter mapping dictionary
18
+ """
19
+ keys = super().get_completion_params_keys()
20
+ if "max_completion_tokens" in keys:
21
+ keys["max_tokens"] = keys.pop("max_completion_tokens")
22
+ return keys
23
+
24
+ def get_completion_params(self) -> Dict[str, Any]:
25
+ """
26
+ Get completion parameters with Groq-specific adjustments.
27
+ Enforce N=1 as Groq doesn't support multiple completions.
28
+
29
+ Returns:
30
+ Dict[str, Any]: Parameters for completion API call
31
+ """
32
+ params = super().get_completion_params()
33
+ if self.config["EXTRA_BODY"] and "N" in self.config["EXTRA_BODY"] and self.config["EXTRA_BODY"]["N"] != 1:
34
+ self.console.print("Groq does not support N parameter, setting N to 1 as Groq default", style="yellow")
35
+ params["extra_body"]["N"] = 1
36
+ return params
@@ -1,3 +1,5 @@
1
+ from typing import Any, Dict
2
+
1
3
  from .openai_provider import OpenAIProvider
2
4
 
3
5
 
@@ -11,4 +13,8 @@ class InfiniAIProvider(OpenAIProvider):
11
13
  if self.enable_function:
12
14
  self.console.print("InfiniAI does not support functions, disabled", style="yellow")
13
15
  self.enable_function = False
14
- self.completion_params["max_tokens"] = self.completion_params.pop("max_completion_tokens")
16
+
17
+ def get_completion_params(self) -> Dict[str, Any]:
18
+ params = super().get_completion_params()
19
+ params["max_tokens"] = params.pop("max_completion_tokens")
20
+ return params
@@ -0,0 +1,21 @@
1
+ from .openai_provider import OpenAIProvider
2
+
3
+
4
+ class MinimaxProvider(OpenAIProvider):
5
+ """Minimax provider implementation based on openai-compatible API"""
6
+
7
+ DEFAULT_BASE_URL = "https://api.minimaxi.com/v1"
8
+
9
+ def get_completion_params_keys(self) -> dict:
10
+ """
11
+ Customize completion parameter keys for Minimax API.
12
+ Maps 'max_completion_tokens' to 'max_tokens' for compatibility.
13
+
14
+ Returns:
15
+ dict: Modified parameter mapping dictionary
16
+ """
17
+ keys = super().get_completion_params_keys()
18
+ # Replace max_completion_tokens with max_tokens in the API
19
+ if "max_completion_tokens" in keys:
20
+ keys["max_tokens"] = keys.pop("max_completion_tokens")
21
+ return keys
@@ -1,3 +1,5 @@
1
+ from typing import Any, Dict
2
+
1
3
  from .openai_provider import OpenAIProvider
2
4
 
3
5
 
@@ -6,6 +8,7 @@ class ModelScopeProvider(OpenAIProvider):
6
8
 
7
9
  DEFAULT_BASE_URL = "https://api-inference.modelscope.cn/v1/"
8
10
 
9
- def __init__(self, config: dict = ..., **kwargs):
10
- super().__init__(config, **kwargs)
11
- self.completion_params["max_tokens"] = self.completion_params.pop("max_completion_tokens")
11
+ def get_completion_params(self) -> Dict[str, Any]:
12
+ params = super().get_completion_params()
13
+ params["max_tokens"] = params.pop("max_completion_tokens")
14
+ return params
@@ -1,3 +1,4 @@
1
+ import json
1
2
  from typing import Any, Dict, Generator, List, Optional
2
3
 
3
4
  import openai
@@ -16,41 +17,73 @@ class OpenAIProvider(Provider):
16
17
  """OpenAI provider implementation based on openai library"""
17
18
 
18
19
  DEFAULT_BASE_URL = "https://api.openai.com/v1"
20
+ CLIENT_CLS = openai.OpenAI
21
+ # Base mapping between config keys and API parameter names
22
+ _BASE_COMPLETION_PARAMS_KEYS = {
23
+ "model": "MODEL",
24
+ "temperature": "TEMPERATURE",
25
+ "top_p": "TOP_P",
26
+ "max_completion_tokens": "MAX_TOKENS",
27
+ "timeout": "TIMEOUT",
28
+ "extra_body": "EXTRA_BODY",
29
+ "reasoning_effort": "REASONING_EFFORT",
30
+ }
19
31
 
20
32
  def __init__(self, config: dict = cfg, verbose: bool = False, **kwargs):
21
33
  self.config = config
34
+ if not self.config.get("API_KEY"):
35
+ raise ValueError("API_KEY is required")
22
36
  self.enable_function = self.config["ENABLE_FUNCTIONS"]
23
37
  self.verbose = verbose
38
+
39
+ # Initialize client
40
+ self.client_params = self.get_client_params()
41
+ self.client = self.CLIENT_CLS(**self.client_params)
42
+ self.console = get_console()
43
+
44
+ # Store completion params
45
+ self.completion_params = self.get_completion_params()
46
+
47
+ def get_client_params(self) -> Dict[str, Any]:
48
+ """Get the client parameters"""
24
49
  # Initialize client params
25
- self.client_params = {
50
+ client_params = {
26
51
  "api_key": self.config["API_KEY"],
27
52
  "base_url": self.config["BASE_URL"] or self.DEFAULT_BASE_URL,
28
53
  }
29
54
 
30
55
  # Add extra headers if set
31
56
  if self.config["EXTRA_HEADERS"]:
32
- self.client_params["default_headers"] = {
57
+ client_params["default_headers"] = {
33
58
  **self.config["EXTRA_HEADERS"],
34
59
  "X-Title": self.APP_NAME,
35
- "HTTP-Referer": self.APPA_REFERER,
60
+ "HTTP-Referer": self.APP_REFERER,
36
61
  }
37
-
38
- # Initialize client
39
- self.client = openai.OpenAI(**self.client_params)
40
- self.console = get_console()
41
-
42
- # Store completion params
43
- self.completion_params = {
44
- "model": self.config["MODEL"],
45
- "temperature": self.config["TEMPERATURE"],
46
- "top_p": self.config["TOP_P"],
47
- "max_completion_tokens": self.config["MAX_TOKENS"],
48
- "timeout": self.config["TIMEOUT"],
49
- }
50
-
51
- # Add extra body params if set
52
- if self.config["EXTRA_BODY"]:
53
- self.completion_params["extra_body"] = self.config["EXTRA_BODY"]
62
+ return client_params
63
+
64
+ def get_completion_params_keys(self) -> Dict[str, str]:
65
+ """
66
+ Get the mapping between completion parameter keys and config keys.
67
+ Subclasses can override this method to customize parameter mapping.
68
+
69
+ Returns:
70
+ Dict[str, str]: Mapping from API parameter names to config keys
71
+ """
72
+ return self._BASE_COMPLETION_PARAMS_KEYS.copy()
73
+
74
+ def get_completion_params(self) -> Dict[str, Any]:
75
+ """
76
+ Get the completion parameters based on config and parameter mapping.
77
+
78
+ Returns:
79
+ Dict[str, Any]: Parameters for completion API call
80
+ """
81
+ completion_params = {}
82
+ params_keys = self.get_completion_params_keys()
83
+ for api_key, config_key in params_keys.items():
84
+ if self.config.get(config_key, None) is not None:
85
+ completion_params[api_key] = self.config[config_key]
86
+ return completion_params
54
87
 
55
88
  def _convert_messages(self, messages: List[ChatMessage]) -> List[Dict[str, Any]]:
56
89
  """Convert a list of ChatMessage objects to a list of OpenAI message dicts."""
@@ -79,7 +112,20 @@ class OpenAIProvider(Provider):
79
112
  messages: List[ChatMessage],
80
113
  stream: bool = False,
81
114
  ) -> Generator[LLMResponse, None, None]:
82
- """Send completion request to OpenAI and return responses"""
115
+ """
116
+ Send completion request to OpenAI and return responses.
117
+
118
+ Args:
119
+ messages: List of chat messages to send
120
+ stream: Whether to stream the response
121
+
122
+ Yields:
123
+ LLMResponse: Response objects containing content, tool calls, etc.
124
+
125
+ Raises:
126
+ ValueError: If messages is empty or invalid
127
+ openai.APIError: If API request fails
128
+ """
83
129
  openai_messages = self._convert_messages(messages)
84
130
  if self.verbose:
85
131
  self.console.print("Messages:")
@@ -103,6 +149,11 @@ class OpenAIProvider(Provider):
103
149
 
104
150
  def _handle_normal_response(self, response: ChatCompletion) -> Generator[LLMResponse, None, None]:
105
151
  """Handle normal (non-streaming) response"""
152
+ if not response.choices:
153
+ yield LLMResponse(
154
+ content=json.dumps(getattr(response, "base_resp", None) or response.to_dict()), finish_reason="stop"
155
+ )
156
+ return
106
157
  choice = response.choices[0]
107
158
  content = choice.message.content or "" # type: ignore
108
159
  reasoning = choice.message.reasoning_content # type: ignore
@@ -124,15 +175,22 @@ class OpenAIProvider(Provider):
124
175
  """Handle streaming response from OpenAI API"""
125
176
  # Initialize tool call object to accumulate tool call data across chunks
126
177
  tool_call: Optional[ToolCall] = None
127
-
178
+ started = False
128
179
  # Process each chunk in the response stream
129
180
  for chunk in response:
130
- if not chunk.choices:
181
+ if not chunk.choices and not started:
182
+ # Some api could return error message in the first chunk, no choices to handle, return raw response to show the message
183
+ yield LLMResponse(
184
+ content=json.dumps(getattr(chunk, "base_resp", None) or chunk.to_dict()), finish_reason="stop"
185
+ )
186
+ started = True
131
187
  continue
132
188
 
133
- choice = chunk.choices[0]
134
- delta = choice.delta
135
- finish_reason = choice.finish_reason
189
+ if not chunk.choices:
190
+ continue
191
+ started = True
192
+ delta = chunk.choices[0].delta
193
+ finish_reason = chunk.choices[0].finish_reason
136
194
 
137
195
  # Extract content from current chunk
138
196
  content = delta.content or ""
@@ -0,0 +1,22 @@
1
+ from typing import Dict
2
+
3
+ from .openai_provider import OpenAIProvider
4
+
5
+
6
+ class OpenRouterProvider(OpenAIProvider):
7
+ """OpenRouter provider implementation based on openai-compatible API"""
8
+
9
+ DEFAULT_BASE_URL = "https://openrouter.ai/api/v1"
10
+
11
+ def get_completion_params_keys(self) -> Dict[str, str]:
12
+ """
13
+ Customize completion parameter keys for OpenRouter API.
14
+ Maps 'max_completion_tokens' to 'max_tokens' for compatibility.
15
+
16
+ Returns:
17
+ Dict[str, str]: Modified parameter mapping dictionary
18
+ """
19
+ keys = super().get_completion_params_keys()
20
+ if "max_completion_tokens" in keys:
21
+ keys["max_tokens"] = keys.pop("max_completion_tokens")
22
+ return keys
@@ -0,0 +1,59 @@
1
+ from typing import Any, Dict
2
+
3
+ from ...const import DEFAULT_TEMPERATURE
4
+ from .openai_provider import OpenAIProvider
5
+
6
+
7
+ class SambanovaProvider(OpenAIProvider):
8
+ """Sambanova provider implementation based on OpenAI API"""
9
+
10
+ DEFAULT_BASE_URL = "https://api.sambanova.ai/v1"
11
+ SUPPORT_FUNCTION_CALL_MOELS = (
12
+ "Meta-Llama-3.1-8B-Instruct",
13
+ "Meta-Llama-3.1-405B-Instruct",
14
+ "Meta-Llama-3.3-70B-Instruct",
15
+ "Llama-4-Scout-17B-16E-Instruct",
16
+ "DeepSeek-V3-0324",
17
+ )
18
+
19
+ def get_completion_params_keys(self) -> Dict[str, str]:
20
+ """
21
+ Customize completion parameter keys for Sambanova API.
22
+ Maps 'max_completion_tokens' to 'max_tokens' for compatibility
23
+ and removes parameters not supported by Sambanova.
24
+
25
+ Returns:
26
+ Dict[str, str]: Modified parameter mapping dictionary
27
+ """
28
+ keys = super().get_completion_params_keys()
29
+ # Replace max_completion_tokens with max_tokens
30
+ if "max_completion_tokens" in keys:
31
+ keys["max_tokens"] = keys.pop("max_completion_tokens")
32
+ # Remove unsupported parameters
33
+ keys.pop("presence_penalty", None)
34
+ keys.pop("frequency_penalty", None)
35
+ return keys
36
+
37
+ def get_completion_params(self) -> Dict[str, Any]:
38
+ """
39
+ Get completion parameters with Sambanova-specific adjustments.
40
+ Validate temperature range and check for function call compatibility.
41
+
42
+ Returns:
43
+ Dict[str, Any]: Parameters for completion API call
44
+ """
45
+ params = super().get_completion_params()
46
+
47
+ # Validate temperature
48
+ if params.get("temperature") is not None and (params["temperature"] < 0 or params["temperature"] > 1):
49
+ self.console.print("Sambanova temperature must be between 0 and 1, setting to 0.4", style="yellow")
50
+ params["temperature"] = DEFAULT_TEMPERATURE
51
+
52
+ # Check function call compatibility
53
+ if self.enable_function and self.config["MODEL"] not in self.SUPPORT_FUNCTION_CALL_MOELS:
54
+ self.console.print(
55
+ f"Sambanova supports function call models: {', '.join(self.SUPPORT_FUNCTION_CALL_MOELS)}",
56
+ style="yellow",
57
+ )
58
+
59
+ return params
@@ -1,3 +1,5 @@
1
+ from typing import Any, Dict
2
+
1
3
  from .openai_provider import OpenAIProvider
2
4
 
3
5
 
@@ -6,6 +8,7 @@ class SiliconFlowProvider(OpenAIProvider):
6
8
 
7
9
  DEFAULT_BASE_URL = "https://api.siliconflow.cn/v1"
8
10
 
9
- def __init__(self, config: dict = ..., **kwargs):
10
- super().__init__(config, **kwargs)
11
- self.completion_params["max_tokens"] = self.completion_params.pop("max_completion_tokens")
11
+ def get_completion_params(self) -> Dict[str, Any]:
12
+ params = super().get_completion_params()
13
+ params["max_tokens"] = params.pop("max_completion_tokens")
14
+ return params
@@ -0,0 +1,22 @@
1
+ from typing import Dict
2
+
3
+ from .openai_provider import OpenAIProvider
4
+
5
+
6
+ class TargonProvider(OpenAIProvider):
7
+ """Targon provider implementation based on openai-compatible API"""
8
+
9
+ DEFAULT_BASE_URL = "https://api.targon.com/v1"
10
+
11
+ def get_completion_params_keys(self) -> Dict[str, str]:
12
+ """
13
+ Customize completion parameter keys for Targon API.
14
+ Maps 'max_completion_tokens' to 'max_tokens' for compatibility.
15
+
16
+ Returns:
17
+ Dict[str, str]: Modified parameter mapping dictionary
18
+ """
19
+ keys = super().get_completion_params_keys()
20
+ if "max_completion_tokens" in keys:
21
+ keys["max_tokens"] = keys.pop("max_completion_tokens")
22
+ return keys
@@ -0,0 +1,7 @@
1
+ from .openai_provider import OpenAIProvider
2
+
3
+
4
+ class XaiProvider(OpenAIProvider):
5
+ """Xai provider implementation based on openai-compatible API"""
6
+
7
+ DEFAULT_BASE_URL = "https://api.xai.com/v1"
@@ -0,0 +1,22 @@
1
+ from typing import Dict
2
+
3
+ from .openai_provider import OpenAIProvider
4
+
5
+
6
+ class YiProvider(OpenAIProvider):
7
+ """Lingyiwanwu provider implementation based on openai-compatible API"""
8
+
9
+ DEFAULT_BASE_URL = "https://api.lingyiwanwu.com/v1"
10
+
11
+ def get_completion_params_keys(self) -> Dict[str, str]:
12
+ """
13
+ Customize completion parameter keys for Yi API.
14
+ Maps 'max_completion_tokens' to 'max_tokens' for compatibility.
15
+
16
+ Returns:
17
+ Dict[str, str]: Modified parameter mapping dictionary
18
+ """
19
+ keys = super().get_completion_params_keys()
20
+ if "max_completion_tokens" in keys:
21
+ keys["max_tokens"] = keys.pop("max_completion_tokens")
22
+ return keys
@@ -1,7 +0,0 @@
1
- from .openai_provider import OpenAIProvider
2
-
3
-
4
- class ChutesProvider(OpenAIProvider):
5
- """Chutes provider implementation based on openai-compatible API"""
6
-
7
- DEFAULT_BASE_URL = "https://llm.chutes.ai/v1"
@@ -1,11 +0,0 @@
1
- from .openai_provider import OpenAIProvider
2
-
3
-
4
- class DeepSeekProvider(OpenAIProvider):
5
- """DeepSeek provider implementation based on openai-compatible API"""
6
-
7
- DEFAULT_BASE_URL = "https://api.deepseek.com/v1"
8
-
9
- def __init__(self, config: dict = ..., **kwargs):
10
- super().__init__(config, **kwargs)
11
- self.completion_params["max_tokens"] = self.completion_params.pop("max_completion_tokens")
@@ -1,14 +0,0 @@
1
- from .openai_provider import OpenAIProvider
2
-
3
-
4
- class GroqProvider(OpenAIProvider):
5
- """Groq provider implementation based on openai-compatible API"""
6
-
7
- DEFAULT_BASE_URL = "https://api.groq.com/openai/v1"
8
-
9
- def __init__(self, config: dict = ..., **kwargs):
10
- super().__init__(config, **kwargs)
11
- if self.config.get("EXTRA_BODY") and "N" in self.config["EXTRA_BODY"] and self.config["EXTRA_BODY"]["N"] != 1:
12
- self.console.print("Groq does not support N parameter, setting N to 1 as Groq default", style="yellow")
13
- if "extra_body" in self.completion_params:
14
- self.completion_params["extra_body"]["N"] = 1
@@ -1,11 +0,0 @@
1
- from .openai_provider import OpenAIProvider
2
-
3
-
4
- class OpenRouterProvider(OpenAIProvider):
5
- """OpenRouter provider implementation based on openai-compatible API"""
6
-
7
- DEFAULT_BASE_URL = "https://openrouter.ai/api/v1"
8
-
9
- def __init__(self, config: dict = ..., **kwargs):
10
- super().__init__(config, **kwargs)
11
- self.completion_params["max_tokens"] = self.completion_params.pop("max_completion_tokens")
@@ -1,28 +0,0 @@
1
- from ...const import DEFAULT_TEMPERATURE
2
- from .openai_provider import OpenAIProvider
3
-
4
-
5
- class SambanovaProvider(OpenAIProvider):
6
- """Sambanova provider implementation based on OpenAI API"""
7
-
8
- DEFAULT_BASE_URL = "https://api.sambanova.ai/v1"
9
- SUPPORT_FUNCTION_CALL_MOELS = (
10
- "Meta-Llama-3.1-8B-Instruct",
11
- "Meta-Llama-3.1-405B-Instruct",
12
- "Meta-Llama-3.3-70B-Instruct",
13
- "Llama-4-Scout-17B-16E-Instruct",
14
- "DeepSeek-V3-0324",
15
- )
16
-
17
- def __init__(self, config: dict = ..., verbose: bool = False, **kwargs):
18
- super().__init__(config, verbose, **kwargs)
19
- self.completion_params.pop("presence_penalty", None)
20
- self.completion_params.pop("frequency_penalty", None)
21
- if self.completion_params.get("temperature") < 0 or self.completion_params.get("temperature") > 1:
22
- self.console.print("Sambanova temperature must be between 0 and 1, setting to 0.4", style="yellow")
23
- self.completion_params["temperature"] = DEFAULT_TEMPERATURE
24
- if self.enable_function and self.config["MODEL"] not in self.SUPPORT_FUNCTION_CALL_MOELS:
25
- self.console.print(
26
- f"Sambanova supports function call models: {', '.join(self.SUPPORT_FUNCTION_CALL_MOELS)}",
27
- style="yellow",
28
- )
@@ -1,7 +0,0 @@
1
- from .openai_provider import OpenAIProvider
2
-
3
-
4
- class YiProvider(OpenAIProvider):
5
- """Yi provider implementation based on openai-compatible API"""
6
-
7
- DEFAULT_BASE_URL = "https://api.lingyiwanwu.com/v1"
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes