zai-cli 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zai/__init__.py +1 -0
- zai/__main__.py +4 -0
- zai/cli/__init__.py +1 -0
- zai/cli/common.py +16 -0
- zai/cli/integrations.py +319 -0
- zai/cli/interactive.py +518 -0
- zai/cli/settings.py +436 -0
- zai/cli/utilities.py +227 -0
- zai/cli/workflows.py +137 -0
- zai/commands/commit.md +24 -0
- zai/commands/explain.md +17 -0
- zai/commands/feature.md +34 -0
- zai/commands/fix.md +14 -0
- zai/commands/review.md +22 -0
- zai/config.py +307 -0
- zai/core/__init__.py +0 -0
- zai/core/agent.py +701 -0
- zai/core/cancellation.py +67 -0
- zai/core/commands.py +85 -0
- zai/core/context.py +299 -0
- zai/core/errors.py +125 -0
- zai/core/fallback.py +171 -0
- zai/core/hooks.py +115 -0
- zai/core/memory.py +57 -0
- zai/core/process.py +204 -0
- zai/core/repomap.py +381 -0
- zai/core/runtime.py +29 -0
- zai/core/security.py +33 -0
- zai/core/session.py +425 -0
- zai/core/storage.py +193 -0
- zai/core/streaming.py +157 -0
- zai/core/tool_schema.py +133 -0
- zai/core/undo.py +443 -0
- zai/core/watch.py +80 -0
- zai/main.py +210 -0
- zai/mcp/__init__.py +0 -0
- zai/mcp/client.py +431 -0
- zai/mcp/manager.py +118 -0
- zai/plugins/__init__.py +2 -0
- zai/plugins/base.py +49 -0
- zai/plugins/loader.py +404 -0
- zai/providers/__init__.py +22 -0
- zai/providers/anthropic.py +131 -0
- zai/providers/base.py +67 -0
- zai/providers/cerebras.py +57 -0
- zai/providers/gemini.py +119 -0
- zai/providers/groq.py +116 -0
- zai/providers/ollama.py +62 -0
- zai/providers/openai.py +124 -0
- zai/providers/openrouter.py +63 -0
- zai/providers/qwen.py +47 -0
- zai/skills/__init__.py +0 -0
- zai/skills/registry.py +52 -0
- zai/tools/__init__.py +0 -0
- zai/tools/browser.py +224 -0
- zai/tools/code_runner.py +49 -0
- zai/tools/files.py +53 -0
- zai/tools/git.py +38 -0
- zai/tools/search.py +157 -0
- zai/tools/vision.py +128 -0
- zai/ui/__init__.py +0 -0
- zai/ui/input.py +199 -0
- zai_cli-0.1.0.dist-info/METADATA +722 -0
- zai_cli-0.1.0.dist-info/RECORD +68 -0
- zai_cli-0.1.0.dist-info/WHEEL +5 -0
- zai_cli-0.1.0.dist-info/entry_points.txt +2 -0
- zai_cli-0.1.0.dist-info/licenses/LICENSE +21 -0
- zai_cli-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,131 @@
|
|
|
1
|
+
from .base import BaseProvider, Message, Response, ToolCall
|
|
2
|
+
from ..config import get_api_key
|
|
3
|
+
from ..core.errors import (
|
|
4
|
+
NoAPIKeyError,
|
|
5
|
+
AuthenticationError as ZaiAuthenticationError,
|
|
6
|
+
RateLimitError,
|
|
7
|
+
NetworkError,
|
|
8
|
+
classify_provider_error,
|
|
9
|
+
)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class AnthropicProvider(BaseProvider):
    """Anthropic Messages API backend with streaming and native tool use."""

    name = "anthropic"
    model_key = "claude"
    model_id = "claude-sonnet-4-6"
    context_window = 200000
    supports_streaming = True
    supports_native_tools = True
    # Cap on generated tokens per request. This was previously hard-coded as
    # 8096, an apparent typo for 8192 (2**13). It is exposed as a class
    # attribute so it can be tuned without editing ``chat``.
    max_tokens = 8192

    def is_available(self) -> bool:
        """Return True when an Anthropic API key is configured."""
        return bool(get_api_key("anthropic"))

    def stream_chat(self, messages: list[Message], system: str = "") -> str:
        """Stream a reply via ``stream_anthropic`` and return its text.

        Raises:
            NoAPIKeyError: no Anthropic API key is configured.
        """
        from ..core.streaming import stream_anthropic

        key = get_api_key("anthropic")
        if not key:
            raise NoAPIKeyError("anthropic")
        return stream_anthropic(
            messages, system or "You are zai, a helpful AI assistant.",
            self.model_id, key, self.timeout, self.retries,
        )

    def chat(
        self,
        messages: list[Message],
        system: str = "",
        tools: list[dict] | None = None,
    ) -> Response:
        """Send one non-streaming Messages API request.

        Args:
            messages: Conversation history. Assistant messages that carry
                ``tool_calls`` become ``tool_use`` blocks. ``tool`` messages
                become ``tool_result`` blocks inside a user turn, and
                consecutive results are merged into one turn.
            system: System prompt; a default zai prompt is used when empty.
            tools: Tool specs with ``name``, ``description`` and
                ``parameters``; ``parameters`` is sent as ``input_schema``.

        Returns:
            A ``Response`` containing the concatenated text blocks, any
            ``tool_use`` blocks as ``ToolCall`` objects, and the sum of
            input and output tokens.

        Raises:
            NoAPIKeyError: no API key is configured.
            RateLimitError: the SDK raised ``anthropic.RateLimitError``.
            ZaiAuthenticationError: the SDK raised
                ``anthropic.AuthenticationError``.
            NetworkError: connection failure or request timeout.
            Anything else is mapped through ``classify_provider_error``.
        """
        import anthropic

        key = get_api_key("anthropic")
        if not key:
            raise NoAPIKeyError("anthropic")

        try:
            client = anthropic.Anthropic(
                api_key=key,
                timeout=self.timeout,
                max_retries=self.retries,
            )
            formatted = []
            for message in messages:
                if message.role == "assistant" and message.tool_calls:
                    content = []
                    if message.content:
                        content.append({"type": "text", "text": message.content})
                    content.extend({
                        "type": "tool_use",
                        "id": call.id,
                        "name": call.name,
                        "input": call.arguments,
                    } for call in message.tool_calls)
                    formatted.append({"role": "assistant", "content": content})
                elif message.role == "tool":
                    result_block = {
                        "type": "tool_result",
                        "tool_use_id": message.tool_call_id,
                        "content": message.content,
                    }
                    # Keep all results for one assistant turn together: append
                    # to the previous user turn when it holds only tool results.
                    if (
                        formatted
                        and formatted[-1]["role"] == "user"
                        and isinstance(formatted[-1]["content"], list)
                        and all(
                            block.get("type") == "tool_result"
                            for block in formatted[-1]["content"]
                        )
                    ):
                        formatted[-1]["content"].append(result_block)
                    else:
                        formatted.append({
                            "role": "user",
                            "content": [result_block],
                        })
                else:
                    formatted.append({
                        "role": message.role,
                        "content": message.content,
                    })
            request = {
                "model": self.model_id,
                "max_tokens": self.max_tokens,
                "system": system or "You are zai, a helpful AI assistant.",
                "messages": formatted,
            }
            if tools:
                request["tools"] = [{
                    "name": tool["name"],
                    "description": tool["description"],
                    "input_schema": tool["parameters"],
                } for tool in tools]
            result = client.messages.create(
                **request,
            )
            text = "".join(
                block.text for block in result.content
                if (
                    getattr(block, "type", None) == "text"
                    # Blocks without a string ``type`` are also treated as text.
                    or not isinstance(getattr(block, "type", None), str)
                )
            )
            tool_calls = [
                ToolCall(id=block.id, name=block.name, arguments=block.input)
                for block in result.content
                if getattr(block, "type", "") == "tool_use"
            ]
            return Response(
                content=text,
                model=self.model_id,
                tokens_used=result.usage.input_tokens + result.usage.output_tokens,
                tool_calls=tool_calls,
            )
        except anthropic.RateLimitError as e:
            raise RateLimitError("anthropic") from e
        except anthropic.AuthenticationError as e:
            raise ZaiAuthenticationError("anthropic", "invalid API credentials") from e
        except anthropic.APIConnectionError as e:
            # This also covers APITimeoutError, a subclass. Its message,
            # "Request timed out.", contains none of the keywords checked below.
            raise NetworkError(str(e)) from e
        except Exception as e:
            err = str(e).lower()
            if "connect" in err or "timeout" in err or "network" in err:
                raise NetworkError(str(e)) from e
            raise classify_provider_error("anthropic", e) from e
|
zai/providers/base.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from dataclasses import dataclass, field
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
@dataclass()
class ToolCall:
    """One tool invocation requested by a model."""

    id: str  # call identifier, either from the provider or synthesized locally
    name: str  # name of the tool to invoke
    arguments: dict[str, Any]  # decoded argument object for the tool


@dataclass()
class Message:
    """A single conversation turn exchanged with a provider."""

    role: str  # "user", "assistant", or "tool"
    content: str = field(default="")
    # Tool invocations attached to an assistant turn.
    tool_calls: list[ToolCall] = field(default_factory=list)
    # For role == "tool": the id of the call being answered and its tool name.
    tool_call_id: str = field(default="")
    tool_name: str = field(default="")
    # NOTE(review): consumed outside this module; meaning not visible here.
    pinned: bool = field(default=False)


@dataclass()
class Response:
    """The result of one provider ``chat`` call."""

    content: str  # assistant text ("" when the model only called tools)
    model: str  # model id that produced the reply
    tokens_used: int = field(default=0)  # provider-reported count, 0 if unknown
    # NOTE(review): none of the providers shown set this; -1 looks like "unknown".
    tokens_remaining: int = field(default=-1)
    tool_calls: list[ToolCall] = field(default_factory=list)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class BaseProvider(ABC):
    """Abstract interface shared by every chat backend.

    The class attributes below are defaults. ``__init__`` overwrites
    ``model_id``, ``context_window``, ``timeout`` and ``retries`` with the
    values returned by ``get_model_config``.
    """

    # Provider identifier; also the config key when ``model_key`` is empty.
    name: str = ""
    # Key passed to ``get_model_config``; falls back to ``name``.
    model_key: str = ""
    model_id: str = ""
    context_window: int = 128000
    supports_streaming: bool = False
    supports_native_tools: bool = False
    timeout: float = 60
    retries: int = 2

    def __init__(self):
        from ..config import get_model_config

        config_key = self.model_key if self.model_key else self.name
        settings = get_model_config(config_key)
        # Copy the per-model settings onto the instance, in this fixed order.
        for attribute in ("model_id", "context_window", "timeout", "retries"):
            setattr(self, attribute, settings[attribute])

    @abstractmethod
    def chat(
        self,
        messages: list[Message],
        system: str = "",
        tools: list[dict] | None = None,
    ) -> Response:
        """Send ``messages`` (plus optional system prompt and tools) and return the reply."""

    def stream_chat(self, messages: list[Message], system: str = "") -> str:
        """Return the reply text; the default falls back to a blocking ``chat`` call.

        Providers with real streaming support override this.
        """
        return self.chat(messages, system=system).content

    @abstractmethod
    def is_available(self) -> bool:
        """Report whether this provider can currently be used."""
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
from .base import BaseProvider, Message, Response
|
|
2
|
+
from ..config import get_api_key
|
|
3
|
+
from ..core.errors import (
|
|
4
|
+
NoAPIKeyError,
|
|
5
|
+
AuthenticationError,
|
|
6
|
+
RateLimitError,
|
|
7
|
+
NetworkError,
|
|
8
|
+
classify_provider_error,
|
|
9
|
+
)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class CerebrasProvider(BaseProvider):
    """Cerebras Cloud chat-completions backend (text only, no native tools)."""

    name = "cerebras"
    model_id = "llama-3.3-70b"
    context_window = 128000

    def is_available(self) -> bool:
        """Return True when a Cerebras API key is configured."""
        return bool(get_api_key("cerebras"))

    def chat(
        self,
        messages: list[Message],
        system: str = "",
        tools: list[dict] | None = None,
    ) -> Response:
        """Send one chat-completions request and return the reply.

        ``tools`` is accepted for interface compatibility but ignored; only
        each message's role and content are forwarded.

        Raises:
            NoAPIKeyError: no API key is configured.
            RateLimitError: HTTP 429 or a rate-limit/quota message.
            AuthenticationError: HTTP 401 or a credential-related message.
            NetworkError: connection failure or timeout.
            Anything else is mapped through ``classify_provider_error``.
        """
        from cerebras.cloud.sdk import Cerebras

        key = get_api_key("cerebras")
        if not key:
            raise NoAPIKeyError("cerebras")

        try:
            client = Cerebras(
                api_key=key,
                timeout=self.timeout,
                max_retries=self.retries,
            )
            formatted = [{"role": "system", "content": system or "You are zai, a helpful AI assistant."}]
            formatted += [{"role": m.role, "content": m.content} for m in messages]
            result = client.chat.completions.create(
                model=self.model_id,
                messages=formatted,
            )
            return Response(
                # The SDK may return None content; normalize to "" like the
                # other providers do.
                content=result.choices[0].message.content or "",
                model=self.model_id,
                tokens_used=result.usage.total_tokens if result.usage else 0,
            )
        except Exception as e:
            err = str(e).lower()
            # Prefer the HTTP status when the SDK exposes one. The keyword
            # lists avoid bare "rate"/"auth", which also match ordinary words
            # such as "generate", "moderate" and "author".
            status = getattr(e, "status_code", None)
            if status == 429 or any(
                marker in err
                for marker in ("429", "rate limit", "rate_limit", "ratelimit",
                               "too many requests", "quota")
            ):
                raise RateLimitError("cerebras") from e
            if status == 401 or any(
                marker in err
                for marker in ("401", "api key", "api_key", "unauthorized",
                               "authentication", "authorization")
            ):
                raise AuthenticationError("cerebras", str(e)) from e
            # "timed out" catches the SDK's "Request timed out." message.
            if any(marker in err for marker in ("connect", "timeout", "timed out")):
                raise NetworkError(str(e)) from e
            raise classify_provider_error("cerebras", e) from e
|
zai/providers/gemini.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
1
|
+
from .base import BaseProvider, Message, Response, ToolCall
|
|
2
|
+
from ..config import MODELS, get_api_key
|
|
3
|
+
from ..core.errors import (
|
|
4
|
+
NoAPIKeyError,
|
|
5
|
+
AuthenticationError,
|
|
6
|
+
RateLimitError,
|
|
7
|
+
NetworkError,
|
|
8
|
+
classify_provider_error,
|
|
9
|
+
)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class GeminiProvider(BaseProvider):
    """Google Gemini backend built on the ``google-genai`` SDK."""

    name = "gemini"
    model_id = MODELS["gemini"]["model_id"]
    context_window = MODELS["gemini"]["context_window"]
    supports_streaming = True
    supports_native_tools = True

    def is_available(self) -> bool:
        """Return True when a Gemini API key is configured."""
        return bool(get_api_key("gemini"))

    def stream_chat(self, messages: list[Message], system: str = "") -> str:
        """Stream a reply via ``stream_gemini`` and return its text.

        Raises:
            NoAPIKeyError: no Gemini API key is configured.
        """
        from ..core.streaming import stream_gemini

        key = get_api_key("gemini")
        if not key:
            raise NoAPIKeyError("gemini")
        return stream_gemini(
            messages, system or "You are zai, a helpful AI assistant.",
            self.model_id, key, self.timeout,
        )

    def chat(
        self,
        messages: list[Message],
        system: str = "",
        tools: list[dict] | None = None,
    ) -> Response:
        """Send one ``generate_content`` request and return the reply.

        Message mapping:
            - assistant turns become ``model`` contents, with ``tool_calls``
              rendered as function-call parts;
            - consecutive ``tool`` messages are merged into ONE user content
              of function-response parts, so a parallel-call turn is answered
              by a single turn with one response per call;
            - every other role is sent as ``user`` text if it is ``"user"``,
              otherwise as ``model`` text.

        Gemini does not return call ids here, so ``ToolCall.id`` values are
        synthesized as ``gemini-<index>-<name>``.

        Raises:
            NoAPIKeyError: no API key is configured.
            RateLimitError: HTTP 429 / quota / RESOURCE_EXHAUSTED.
            AuthenticationError: HTTP 401 or an API-key related message.
            NetworkError: connection failure or timeout.
            Anything else is mapped through ``classify_provider_error``.
        """
        key = get_api_key("gemini")
        if not key:
            raise NoAPIKeyError("gemini")

        from google import genai
        from google.genai import types

        try:
            client = genai.Client(
                api_key=key,
                # self.timeout is converted to milliseconds for HttpOptions.
                http_options=types.HttpOptions(
                    timeout=int(self.timeout * 1000),
                ),
            )
            contents = []
            # Function-response parts waiting to be emitted as one user turn.
            pending_responses = []
            for message in messages:
                if message.role == "tool":
                    pending_responses.append(types.Part.from_function_response(
                        name=message.tool_name,
                        response={"result": message.content},
                    ))
                    continue
                if pending_responses:
                    contents.append(types.Content(role="user", parts=pending_responses))
                    pending_responses = []
                if message.role == "assistant" and message.tool_calls:
                    parts = []
                    if message.content:
                        parts.append(types.Part.from_text(text=message.content))
                    parts.extend(
                        types.Part.from_function_call(
                            name=call.name,
                            args=call.arguments,
                        )
                        for call in message.tool_calls
                    )
                    contents.append(types.Content(role="model", parts=parts))
                else:
                    contents.append(types.Content(
                        role=message.role if message.role == "user" else "model",
                        parts=[types.Part.from_text(text=message.content)],
                    ))
            if pending_responses:
                contents.append(types.Content(role="user", parts=pending_responses))
            config_args = {
                "system_instruction": system or "You are zai, a helpful AI assistant.",
            }
            if tools:
                declarations = [
                    types.FunctionDeclaration(
                        name=tool["name"],
                        description=tool["description"],
                        parameters_json_schema=tool["parameters"],
                    )
                    for tool in tools
                ]
                config_args["tools"] = [types.Tool(function_declarations=declarations)]
            result = client.models.generate_content(
                model=self.model_id,
                contents=contents,
                config=types.GenerateContentConfig(**config_args),
            )
            usage = getattr(result, "usage_metadata", None)
            tool_calls = []
            for index, call in enumerate(getattr(result, "function_calls", None) or []):
                tool_calls.append(ToolCall(
                    id=f"gemini-{index}-{call.name}",
                    name=call.name,
                    arguments=dict(call.args or {}),
                ))
            return Response(
                content=result.text or "",
                model=self.model_id,
                tokens_used=getattr(usage, "total_token_count", 0) or 0,
                tool_calls=tool_calls,
            )
        except Exception as e:
            err = str(e).lower()
            # Prefer the numeric status when the SDK error exposes one. The
            # keyword lists deliberately avoid bare "rate" and "invalid":
            # "rate" matches "generateContent" (e.g. model-not-found errors),
            # and "invalid" matches every INVALID_ARGUMENT request error.
            status = getattr(e, "code", None)
            if status == 429 or any(
                marker in err
                for marker in ("429", "quota", "resource_exhausted",
                               "rate limit", "rate_limit")
            ):
                raise RateLimitError("gemini") from e
            if status == 401 or any(
                marker in err
                for marker in ("401", "api key", "api_key", "unauthenticated")
            ):
                raise AuthenticationError("gemini", str(e)) from e
            if any(
                marker in err
                for marker in ("network", "connect", "timeout", "timed out")
            ):
                raise NetworkError(str(e)) from e
            raise classify_provider_error("gemini", e) from e
|
zai/providers/groq.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
1
|
+
import json
|
|
2
|
+
|
|
3
|
+
from .base import BaseProvider, Message, Response, ToolCall
|
|
4
|
+
from ..config import MODELS, get_api_key
|
|
5
|
+
from ..core.errors import (
|
|
6
|
+
NoAPIKeyError,
|
|
7
|
+
AuthenticationError as ZaiAuthenticationError,
|
|
8
|
+
RateLimitError,
|
|
9
|
+
NetworkError,
|
|
10
|
+
classify_provider_error,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class GroqProvider(BaseProvider):
    """Groq chat-completions backend with streaming and native tool calls."""

    name = "groq"
    model_id = MODELS["groq"]["model_id"]
    context_window = MODELS["groq"]["context_window"]
    supports_streaming = True
    supports_native_tools = True

    def is_available(self) -> bool:
        """Return True when a Groq API key is configured."""
        return bool(get_api_key("groq"))

    def stream_chat(self, messages: list[Message], system: str = "") -> str:
        """Stream a reply via ``stream_groq`` and return its text.

        Raises:
            NoAPIKeyError: no Groq API key is configured.
        """
        from ..core.streaming import stream_groq

        key = get_api_key("groq")
        if not key:
            raise NoAPIKeyError("groq")
        return stream_groq(
            messages, system or "You are zai, a helpful AI assistant.",
            self.model_id, key, self.timeout, self.retries,
        )

    def chat(
        self,
        messages: list[Message],
        system: str = "",
        tools: list[dict] | None = None,
    ) -> Response:
        """Send one chat-completions request and return the reply.

        Tool-call arguments are serialized to JSON on the way out and decoded
        on the way back. Arguments that cannot be decoded become ``{}``.

        Raises:
            NoAPIKeyError: no API key is configured.
            RateLimitError: the SDK raised ``groq.RateLimitError``.
            ZaiAuthenticationError: the SDK raised ``groq.AuthenticationError``.
            NetworkError: connection failure or request timeout.
            Anything else is mapped through ``classify_provider_error``.
        """
        from groq import (
            APIConnectionError,
            AuthenticationError,
            Groq,
            RateLimitError as GroqRateLimit,
        )

        key = get_api_key("groq")
        if not key:
            raise NoAPIKeyError("groq")

        try:
            client = Groq(
                api_key=key,
                timeout=self.timeout,
                max_retries=self.retries,
            )
            formatted = [{"role": "system", "content": system or "You are zai, a helpful AI assistant."}]
            for message in messages:
                if message.role == "assistant" and message.tool_calls:
                    formatted.append({
                        "role": "assistant",
                        "content": message.content or None,
                        "tool_calls": [{
                            "id": call.id,
                            "type": "function",
                            "function": {
                                "name": call.name,
                                "arguments": json.dumps(call.arguments),
                            },
                        } for call in message.tool_calls],
                    })
                elif message.role == "tool":
                    formatted.append({
                        "role": "tool",
                        "tool_call_id": message.tool_call_id,
                        "name": message.tool_name,
                        "content": message.content,
                    })
                else:
                    formatted.append({
                        "role": message.role,
                        "content": message.content,
                    })
            request = {"model": self.model_id, "messages": formatted}
            if tools:
                request["tools"] = [{
                    "type": "function",
                    "function": {
                        "name": tool["name"],
                        "description": tool["description"],
                        "parameters": tool["parameters"],
                    },
                } for tool in tools]
            result = client.chat.completions.create(**request)
            message = result.choices[0].message
            tool_calls = []
            for call in message.tool_calls or []:
                try:
                    arguments = json.loads(call.function.arguments)
                except (TypeError, json.JSONDecodeError):
                    arguments = {}
                tool_calls.append(ToolCall(
                    id=call.id,
                    name=call.function.name,
                    arguments=arguments,
                ))
            return Response(
                content=message.content or "",
                model=self.model_id,
                # usage can be missing; guard it like the OpenAI provider does.
                tokens_used=result.usage.total_tokens if result.usage else 0,
                tool_calls=tool_calls,
            )
        except GroqRateLimit as e:
            raise RateLimitError("groq") from e
        except AuthenticationError as e:
            raise ZaiAuthenticationError("groq", "invalid API credentials") from e
        except APIConnectionError as e:
            # This also covers APITimeoutError, whose message "Request timed
            # out." contains none of the keywords checked below.
            raise NetworkError(str(e)) from e
        except Exception as e:
            err = str(e).lower()
            if "connect" in err or "timeout" in err or "network" in err:
                raise NetworkError(str(e)) from e
            raise classify_provider_error("groq", e) from e
|
zai/providers/ollama.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
from .base import BaseProvider, Message, Response
|
|
2
|
+
from ..core.errors import NetworkError
|
|
3
|
+
|
|
4
|
+
OLLAMA_URL = "http://localhost:11434"


class OllamaProvider(BaseProvider):
    """Local Ollama backend that talks to the HTTP API at ``OLLAMA_URL``."""

    name = "ollama"
    model_id = "llama3.2"
    context_window = 128000
    supports_streaming = True

    def is_available(self) -> bool:
        """Return True when ``/api/tags`` answers HTTP 200 within 2 seconds."""
        import httpx

        try:
            r = httpx.get(f"{OLLAMA_URL}/api/tags", timeout=2)
            return r.status_code == 200
        except Exception:
            return False

    def stream_chat(self, messages: list[Message], system: str = "") -> str:
        """Stream a reply via ``stream_ollama`` and return its text."""
        from ..core.streaming import stream_ollama

        return stream_ollama(
            messages, system or "You are zai, a helpful AI assistant.",
            self.model_id, self.timeout,
        )

    def chat(
        self,
        messages: list[Message],
        system: str = "",
        tools: list[dict] | None = None,
    ) -> Response:
        """POST a non-streaming ``/api/chat`` request and return the reply.

        ``tools`` is accepted for interface compatibility but ignored; only
        each message's role and content are sent.

        Connection errors and timeouts are retried up to ``self.retries``
        extra times, with no delay between attempts. HTTP error statuses
        raised by ``raise_for_status`` are not retried and propagate as
        ``httpx.HTTPStatusError``.

        Raises:
            NetworkError: the final attempt failed to connect or timed out.
        """
        import httpx

        formatted = [{"role": "system", "content": system or "You are zai, a helpful AI assistant."}]
        formatted += [{"role": m.role, "content": m.content} for m in messages]

        for attempt in range(self.retries + 1):
            try:
                r = httpx.post(
                    f"{OLLAMA_URL}/api/chat",
                    json={"model": self.model_id, "messages": formatted, "stream": False},
                    timeout=self.timeout,
                )
                r.raise_for_status()
                data = r.json()
                content = (data.get("message") or {}).get("content") or ""
                # Report prompt + completion tokens, matching the total the
                # other providers store in tokens_used. Either count may be
                # absent from the reply, so missing/None values count as 0.
                prompt_tokens = data.get("prompt_eval_count") or 0
                completion_tokens = data.get("eval_count") or 0
                return Response(
                    content=content,
                    model=self.model_id,
                    tokens_used=prompt_tokens + completion_tokens,
                )
            except (httpx.ConnectError, httpx.TimeoutException) as error:
                if attempt == self.retries:
                    if isinstance(error, httpx.ConnectError):
                        raise NetworkError(
                            "Ollama not running. Start it with: ollama serve"
                        ) from error
                    raise NetworkError(str(error)) from error
        # Only reachable when retries is negative (the loop never runs).
        raise NetworkError("Ollama request failed")
|
zai/providers/openai.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
1
|
+
import json
|
|
2
|
+
|
|
3
|
+
from .base import BaseProvider, Message, Response, ToolCall
|
|
4
|
+
from ..config import get_api_key
|
|
5
|
+
from ..core.errors import (
|
|
6
|
+
NoAPIKeyError,
|
|
7
|
+
AuthenticationError as ZaiAuthenticationError,
|
|
8
|
+
RateLimitError,
|
|
9
|
+
NetworkError,
|
|
10
|
+
classify_provider_error,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class OpenAIProvider(BaseProvider):
    """OpenAI chat-completions backend with streaming and native tool calls."""

    name = "openai"
    model_key = "gpt4o"
    model_id = "gpt-4o-mini"
    context_window = 128000
    supports_streaming = True
    supports_native_tools = True

    def is_available(self) -> bool:
        """Return True when an OpenAI API key is configured."""
        return bool(get_api_key("openai"))

    def stream_chat(self, messages: list[Message], system: str = "") -> str:
        """Stream a reply via ``stream_openai`` and return its text.

        Raises:
            NoAPIKeyError: no OpenAI API key is configured.
        """
        from ..core.streaming import stream_openai

        key = get_api_key("openai")
        if not key:
            raise NoAPIKeyError("openai")
        return stream_openai(
            messages, system or "You are zai, a helpful AI assistant.",
            self.model_id, key, self.timeout, self.retries,
        )

    def chat(
        self,
        messages: list[Message],
        system: str = "",
        tools: list[dict] | None = None,
    ) -> Response:
        """Send one chat-completions request and return the reply.

        When tools are given, strict schema mode is enabled for every tool
        except ``mcp_call`` and ``plugin_call``, and parallel tool calls are
        allowed. Tool-call arguments that cannot be decoded become ``{}``.

        Raises:
            NoAPIKeyError: no API key is configured.
            RateLimitError: the SDK raised ``openai.RateLimitError``.
            ZaiAuthenticationError: the SDK raised ``openai.AuthenticationError``.
            NetworkError: connection failure or request timeout.
            Anything else is mapped through ``classify_provider_error``.
        """
        from openai import (
            APIConnectionError,
            AuthenticationError,
            OpenAI,
            RateLimitError as OAIRateLimit,
        )

        key = get_api_key("openai")
        if not key:
            raise NoAPIKeyError("openai")

        try:
            client = OpenAI(
                api_key=key,
                timeout=self.timeout,
                max_retries=self.retries,
            )
            formatted = [{
                "role": "system",
                "content": system or "You are zai, a helpful AI assistant.",
            }]
            for message in messages:
                if message.role == "assistant" and message.tool_calls:
                    formatted.append({
                        "role": "assistant",
                        "content": message.content or None,
                        "tool_calls": [{
                            "id": call.id,
                            "type": "function",
                            "function": {
                                "name": call.name,
                                "arguments": json.dumps(call.arguments),
                            },
                        } for call in message.tool_calls],
                    })
                elif message.role == "tool":
                    formatted.append({
                        "role": "tool",
                        "tool_call_id": message.tool_call_id,
                        "content": message.content,
                    })
                else:
                    formatted.append({"role": message.role, "content": message.content})

            request = {
                "model": self.model_id,
                "messages": formatted,
            }
            if tools:
                request["tools"] = [{
                    "type": "function",
                    "function": {
                        "name": tool["name"],
                        "description": tool["description"],
                        "parameters": tool["parameters"],
                        # Generic pass-through tools keep a loose schema.
                        "strict": tool["name"] not in {"mcp_call", "plugin_call"},
                    },
                } for tool in tools]
                request["parallel_tool_calls"] = True
            result = client.chat.completions.create(
                **request,
            )
            message = result.choices[0].message
            tool_calls = []
            for call in message.tool_calls or []:
                try:
                    arguments = json.loads(call.function.arguments)
                except (TypeError, json.JSONDecodeError):
                    arguments = {}
                tool_calls.append(ToolCall(
                    id=call.id,
                    name=call.function.name,
                    arguments=arguments,
                ))
            return Response(
                content=message.content or "",
                model=self.model_id,
                tokens_used=result.usage.total_tokens if result.usage else 0,
                tool_calls=tool_calls,
            )
        except OAIRateLimit as e:
            raise RateLimitError("openai") from e
        except AuthenticationError as e:
            raise ZaiAuthenticationError("openai", "invalid API credentials") from e
        except APIConnectionError as e:
            # This also covers APITimeoutError, whose message "Request timed
            # out." contains none of the keywords checked below.
            raise NetworkError(str(e)) from e
        except Exception as e:
            err = str(e).lower()
            if "connect" in err or "timeout" in err or "network" in err:
                raise NetworkError(str(e)) from e
            raise classify_provider_error("openai", e) from e
|