prompture 0.0.40.dev1__py3-none-any.whl → 0.0.41__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- prompture/_version.py +2 -2
- prompture/drivers/__init__.py +39 -0
- prompture/drivers/async_modelscope_driver.py +286 -0
- prompture/drivers/async_moonshot_driver.py +311 -0
- prompture/drivers/async_openrouter_driver.py +190 -2
- prompture/drivers/async_registry.py +30 -0
- prompture/drivers/async_zai_driver.py +302 -0
- prompture/drivers/modelscope_driver.py +303 -0
- prompture/drivers/moonshot_driver.py +341 -0
- prompture/drivers/openrouter_driver.py +235 -39
- prompture/drivers/zai_driver.py +317 -0
- prompture/model_rates.py +2 -0
- prompture/settings.py +15 -0
- {prompture-0.0.40.dev1.dist-info → prompture-0.0.41.dist-info}/METADATA +1 -1
- {prompture-0.0.40.dev1.dist-info → prompture-0.0.41.dist-info}/RECORD +19 -13
- {prompture-0.0.40.dev1.dist-info → prompture-0.0.41.dist-info}/WHEEL +0 -0
- {prompture-0.0.40.dev1.dist-info → prompture-0.0.41.dist-info}/entry_points.txt +0 -0
- {prompture-0.0.40.dev1.dist-info → prompture-0.0.41.dist-info}/licenses/LICENSE +0 -0
- {prompture-0.0.40.dev1.dist-info → prompture-0.0.41.dist-info}/top_level.txt +0 -0
@@ -2,7 +2,9 @@
 
 from __future__ import annotations
 
+import json
 import os
+from collections.abc import AsyncIterator
 from typing import Any
 
 import httpx
@@ -14,11 +16,14 @@ from .openrouter_driver import OpenRouterDriver
 
 class AsyncOpenRouterDriver(CostMixin, AsyncDriver):
     supports_json_mode = True
+    supports_json_schema = True
+    supports_tool_use = True
+    supports_streaming = True
     supports_vision = True
 
     MODEL_PRICING = OpenRouterDriver.MODEL_PRICING
 
-    def __init__(self, api_key: str | None = None, model: str = "openai/gpt-
+    def __init__(self, api_key: str | None = None, model: str = "openai/gpt-4o-mini"):
         self.api_key = api_key or os.getenv("OPENROUTER_API_KEY")
         if not self.api_key:
             raise ValueError("OpenRouter API key not found. Set OPENROUTER_API_KEY env var.")
@@ -51,6 +56,13 @@ class AsyncOpenRouterDriver(CostMixin, AsyncDriver):
         tokens_param = model_config["tokens_param"]
         supports_temperature = model_config["supports_temperature"]
 
+        # Validate capabilities against models.dev metadata
+        self._validate_model_capabilities(
+            "openrouter",
+            model,
+            using_json_schema=bool(options.get("json_schema")),
+        )
+
         opts = {"temperature": 1.0, "max_tokens": 512, **options}
 
         data = {
@@ -64,7 +76,18 @@ class AsyncOpenRouterDriver(CostMixin, AsyncDriver):
 
         # Native JSON mode support
         if options.get("json_mode"):
-
+            json_schema = options.get("json_schema")
+            if json_schema:
+                data["response_format"] = {
+                    "type": "json_schema",
+                    "json_schema": {
+                        "name": "extraction",
+                        "strict": True,
+                        "schema": json_schema,
+                    },
+                }
+            else:
+                data["response_format"] = {"type": "json_object"}
 
         async with httpx.AsyncClient() as client:
             try:
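The hunk above switches the JSON-mode branch between OpenRouter's strict `json_schema` response format and plain `json_object`, depending on whether a schema is supplied in the options. A minimal usage sketch, assuming `AsyncOpenRouterDriver` exposes the same `generate(prompt, options)` entry point the new Z.ai driver below defines; the schema and prompt here are invented:

```python
# Hedged usage sketch, not part of the diff: exercising the new strict
# json_schema branch. generate(prompt, options) is assumed to exist on
# AsyncOpenRouterDriver; the schema and prompt are invented.
import asyncio

from prompture.drivers.async_openrouter_driver import AsyncOpenRouterDriver

person_schema = {
    "type": "object",
    "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
    "required": ["name", "age"],
}


async def main() -> None:
    driver = AsyncOpenRouterDriver()  # picks up OPENROUTER_API_KEY from the env
    result = await driver.generate(
        "Extract the person mentioned in: 'Ada, 36, lives in London.'",
        {
            "json_mode": True,             # enables response_format
            "json_schema": person_schema,  # upgrades it to strict json_schema
            "max_tokens": 256,
        },
    )
    print(result["text"])
    print(result["meta"]["cost"])


asyncio.run(main())
```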
@@ -100,3 +123,168 @@ class AsyncOpenRouterDriver(CostMixin, AsyncDriver):
 
         text = resp["choices"][0]["message"]["content"]
         return {"text": text, "meta": meta}
+
+    # ------------------------------------------------------------------
+    # Tool use
+    # ------------------------------------------------------------------
+
+    async def generate_messages_with_tools(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> dict[str, Any]:
+        """Generate a response that may include tool calls."""
+        model = options.get("model", self.model)
+        model_config = self._get_model_config("openrouter", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
+
+        self._validate_model_capabilities("openrouter", model, using_tool_use=True)
+
+        opts = {"temperature": 1.0, "max_tokens": 512, **options}
+
+        data: dict[str, Any] = {
+            "model": model,
+            "messages": messages,
+            "tools": tools,
+        }
+        data[tokens_param] = opts.get("max_tokens", 512)
+
+        if supports_temperature and "temperature" in opts:
+            data["temperature"] = opts["temperature"]
+
+        async with httpx.AsyncClient() as client:
+            try:
+                response = await client.post(
+                    f"{self.base_url}/chat/completions",
+                    headers=self.headers,
+                    json=data,
+                    timeout=120,
+                )
+                response.raise_for_status()
+                resp = response.json()
+            except httpx.HTTPStatusError as e:
+                error_msg = f"OpenRouter API request failed: {e!s}"
+                raise RuntimeError(error_msg) from e
+            except Exception as e:
+                raise RuntimeError(f"OpenRouter API request failed: {e!s}") from e
+
+        usage = resp.get("usage", {})
+        prompt_tokens = usage.get("prompt_tokens", 0)
+        completion_tokens = usage.get("completion_tokens", 0)
+        total_tokens = usage.get("total_tokens", 0)
+        total_cost = self._calculate_cost("openrouter", model, prompt_tokens, completion_tokens)
+
+        meta = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": round(total_cost, 6),
+            "raw_response": resp,
+            "model_name": model,
+        }
+
+        choice = resp["choices"][0]
+        text = choice["message"].get("content") or ""
+        stop_reason = choice.get("finish_reason")
+
+        tool_calls_out: list[dict[str, Any]] = []
+        for tc in choice["message"].get("tool_calls", []):
+            try:
+                args = json.loads(tc["function"]["arguments"])
+            except (json.JSONDecodeError, TypeError):
+                args = {}
+            tool_calls_out.append({
+                "id": tc["id"],
+                "name": tc["function"]["name"],
+                "arguments": args,
+            })
+
+        return {
+            "text": text,
+            "meta": meta,
+            "tool_calls": tool_calls_out,
+            "stop_reason": stop_reason,
+        }
+
+    # ------------------------------------------------------------------
+    # Streaming
+    # ------------------------------------------------------------------
+
+    async def generate_messages_stream(
+        self,
+        messages: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> AsyncIterator[dict[str, Any]]:
+        """Yield response chunks via OpenRouter streaming API."""
+        model = options.get("model", self.model)
+        model_config = self._get_model_config("openrouter", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
+
+        opts = {"temperature": 1.0, "max_tokens": 512, **options}
+
+        data: dict[str, Any] = {
+            "model": model,
+            "messages": messages,
+            "stream": True,
+            "stream_options": {"include_usage": True},
+        }
+        data[tokens_param] = opts.get("max_tokens", 512)
+
+        if supports_temperature and "temperature" in opts:
+            data["temperature"] = opts["temperature"]
+
+        full_text = ""
+        prompt_tokens = 0
+        completion_tokens = 0
+
+        async with httpx.AsyncClient() as client, client.stream(
+            "POST",
+            f"{self.base_url}/chat/completions",
+            headers=self.headers,
+            json=data,
+            timeout=120,
+        ) as response:
+            response.raise_for_status()
+            async for line in response.aiter_lines():
+                if not line or not line.startswith("data: "):
+                    continue
+                payload = line[len("data: "):]
+                if payload.strip() == "[DONE]":
+                    break
+                try:
+                    chunk = json.loads(payload)
+                except json.JSONDecodeError:
+                    continue
+
+                # Usage comes in the final chunk
+                usage = chunk.get("usage")
+                if usage:
+                    prompt_tokens = usage.get("prompt_tokens", 0)
+                    completion_tokens = usage.get("completion_tokens", 0)
+
+                choices = chunk.get("choices", [])
+                if choices:
+                    delta = choices[0].get("delta", {})
+                    content = delta.get("content", "")
+                    if content:
+                        full_text += content
+                        yield {"type": "delta", "text": content}
+
+        total_tokens = prompt_tokens + completion_tokens
+        total_cost = self._calculate_cost("openrouter", model, prompt_tokens, completion_tokens)
+
+        yield {
+            "type": "done",
+            "text": full_text,
+            "meta": {
+                "prompt_tokens": prompt_tokens,
+                "completion_tokens": completion_tokens,
+                "total_tokens": total_tokens,
+                "cost": round(total_cost, 6),
+                "raw_response": {},
+                "model_name": model,
+            },
+        }
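The additions above give the async OpenRouter driver OpenAI-style tool calling and SSE streaming: `generate_messages_stream` yields `{"type": "delta", ...}` chunks as content arrives and a final `{"type": "done", ...}` chunk carrying usage and cost. A hedged consumer sketch; the prompt and option values are invented, while the chunk keys mirror the yield statements in the hunk above:

```python
# Hedged consumer sketch, not part of the diff: iterating the new streaming
# method. Chunk keys ("type", "text", "meta") follow the yields added above.
import asyncio

from prompture.drivers.async_openrouter_driver import AsyncOpenRouterDriver


async def main() -> None:
    driver = AsyncOpenRouterDriver()  # OPENROUTER_API_KEY must be set
    messages = [{"role": "user", "content": "Write one sentence about diffs."}]
    async for chunk in driver.generate_messages_stream(messages, {"max_tokens": 128}):
        if chunk["type"] == "delta":
            print(chunk["text"], end="", flush=True)  # partial content as it arrives
        elif chunk["type"] == "done":
            meta = chunk["meta"]
            print(f"\n[{meta['total_tokens']} tokens, cost {meta['cost']}]")


asyncio.run(main())
```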
@@ -22,9 +22,12 @@ from .async_grok_driver import AsyncGrokDriver
 from .async_groq_driver import AsyncGroqDriver
 from .async_lmstudio_driver import AsyncLMStudioDriver
 from .async_local_http_driver import AsyncLocalHTTPDriver
+from .async_modelscope_driver import AsyncModelScopeDriver
+from .async_moonshot_driver import AsyncMoonshotDriver
 from .async_ollama_driver import AsyncOllamaDriver
 from .async_openai_driver import AsyncOpenAIDriver
 from .async_openrouter_driver import AsyncOpenRouterDriver
+from .async_zai_driver import AsyncZaiDriver
 from .registry import (
     _get_async_registry,
     get_async_driver_factory,
@@ -90,6 +93,33 @@ register_async_driver(
     lambda model=None: AsyncGrokDriver(api_key=settings.grok_api_key, model=model or settings.grok_model),
     overwrite=True,
 )
+register_async_driver(
+    "moonshot",
+    lambda model=None: AsyncMoonshotDriver(
+        api_key=settings.moonshot_api_key,
+        model=model or settings.moonshot_model,
+        endpoint=settings.moonshot_endpoint,
+    ),
+    overwrite=True,
+)
+register_async_driver(
+    "modelscope",
+    lambda model=None: AsyncModelScopeDriver(
+        api_key=settings.modelscope_api_key,
+        model=model or settings.modelscope_model,
+        endpoint=settings.modelscope_endpoint,
+    ),
+    overwrite=True,
+)
+register_async_driver(
+    "zai",
+    lambda model=None: AsyncZaiDriver(
+        api_key=settings.zhipu_api_key,
+        model=model or settings.zhipu_model,
+        endpoint=settings.zhipu_endpoint,
+    ),
+    overwrite=True,
+)
 register_async_driver(
     "airllm",
     lambda model=None: AsyncAirLLMDriver(
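These hunks register the three new providers under the names `moonshot`, `modelscope`, and `zai`, each as a factory closing over the corresponding settings. The same pattern could be reused for a custom registration; in the sketch below the registry key `my-zai`, the import paths, and the exact behaviour of `get_async_driver_factory` are assumptions, while the `register_async_driver(name, factory, overwrite=True)` call shape and the `glm-4.7` default come from the diff:

```python
# Sketch of the registration pattern shown above, reused for a custom entry.
# "my-zai", the import paths, and the get_async_driver_factory lookup are
# assumptions made for illustration.
from prompture.drivers import get_async_driver_factory, register_async_driver
from prompture.drivers.async_zai_driver import AsyncZaiDriver

register_async_driver(
    "my-zai",  # hypothetical registry key
    lambda model=None: AsyncZaiDriver(model=model or "glm-4.7"),
    overwrite=True,
)

factory = get_async_driver_factory("my-zai")  # assumed: look up the factory by name
driver = factory(model="glm-4.7")
```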
@@ -0,0 +1,302 @@
+"""Async Z.ai (Zhipu AI) driver using httpx.
+
+All pricing comes from models.dev (provider: "zai") — no hardcoded pricing.
+"""
+
+from __future__ import annotations
+
+import json
+import os
+from collections.abc import AsyncIterator
+from typing import Any
+
+import httpx
+
+from ..async_driver import AsyncDriver
+from ..cost_mixin import CostMixin
+from .zai_driver import ZaiDriver
+
+
+class AsyncZaiDriver(CostMixin, AsyncDriver):
+    supports_json_mode = True
+    supports_json_schema = True
+    supports_tool_use = True
+    supports_streaming = True
+    supports_vision = True
+
+    MODEL_PRICING = ZaiDriver.MODEL_PRICING
+
+    def __init__(
+        self,
+        api_key: str | None = None,
+        model: str = "glm-4.7",
+        endpoint: str = "https://api.z.ai/api/paas/v4",
+    ):
+        self.api_key = api_key or os.getenv("ZHIPU_API_KEY")
+        if not self.api_key:
+            raise ValueError("Zhipu API key not found. Set ZHIPU_API_KEY env var.")
+        self.model = model
+        self.base_url = endpoint.rstrip("/")
+        self.headers = {
+            "Authorization": f"Bearer {self.api_key}",
+            "Content-Type": "application/json",
+        }
+
+    supports_messages = True
+
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_openai_vision_messages
+
+        return _prepare_openai_vision_messages(messages)
+
+    async def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
+        messages = [{"role": "user", "content": prompt}]
+        return await self._do_generate(messages, options)
+
+    async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        return await self._do_generate(self._prepare_messages(messages), options)
+
+    async def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        model = options.get("model", self.model)
+
+        model_config = self._get_model_config("zai", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
+
+        self._validate_model_capabilities(
+            "zai",
+            model,
+            using_json_schema=bool(options.get("json_schema")),
+        )
+
+        opts = {"temperature": 1.0, "max_tokens": 512, **options}
+
+        data: dict[str, Any] = {
+            "model": model,
+            "messages": messages,
+        }
+        data[tokens_param] = opts.get("max_tokens", 512)
+
+        if supports_temperature and "temperature" in opts:
+            data["temperature"] = opts["temperature"]
+
+        if options.get("json_mode"):
+            json_schema = options.get("json_schema")
+            if json_schema:
+                data["response_format"] = {
+                    "type": "json_schema",
+                    "json_schema": {
+                        "name": "extraction",
+                        "strict": True,
+                        "schema": json_schema,
+                    },
+                }
+            else:
+                data["response_format"] = {"type": "json_object"}
+
+        async with httpx.AsyncClient() as client:
+            try:
+                response = await client.post(
+                    f"{self.base_url}/chat/completions",
+                    headers=self.headers,
+                    json=data,
+                    timeout=120,
+                )
+                response.raise_for_status()
+                resp = response.json()
+            except httpx.HTTPStatusError as e:
+                error_msg = f"Z.ai API request failed: {e!s}"
+                raise RuntimeError(error_msg) from e
+            except Exception as e:
+                raise RuntimeError(f"Z.ai API request failed: {e!s}") from e
+
+        usage = resp.get("usage", {})
+        prompt_tokens = usage.get("prompt_tokens", 0)
+        completion_tokens = usage.get("completion_tokens", 0)
+        total_tokens = usage.get("total_tokens", 0)
+
+        total_cost = self._calculate_cost("zai", model, prompt_tokens, completion_tokens)
+
+        meta = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": round(total_cost, 6),
+            "raw_response": resp,
+            "model_name": model,
+        }
+
+        text = resp["choices"][0]["message"]["content"]
+        return {"text": text, "meta": meta}
+
+    # ------------------------------------------------------------------
+    # Tool use
+    # ------------------------------------------------------------------
+
+    async def generate_messages_with_tools(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> dict[str, Any]:
+        """Generate a response that may include tool calls."""
+        model = options.get("model", self.model)
+        model_config = self._get_model_config("zai", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
+
+        self._validate_model_capabilities("zai", model, using_tool_use=True)
+
+        opts = {"temperature": 1.0, "max_tokens": 512, **options}
+
+        data: dict[str, Any] = {
+            "model": model,
+            "messages": messages,
+            "tools": tools,
+        }
+        data[tokens_param] = opts.get("max_tokens", 512)
+
+        if supports_temperature and "temperature" in opts:
+            data["temperature"] = opts["temperature"]
+
+        if "tool_choice" in options:
+            data["tool_choice"] = options["tool_choice"]
+
+        async with httpx.AsyncClient() as client:
+            try:
+                response = await client.post(
+                    f"{self.base_url}/chat/completions",
+                    headers=self.headers,
+                    json=data,
+                    timeout=120,
+                )
+                response.raise_for_status()
+                resp = response.json()
+            except httpx.HTTPStatusError as e:
+                error_msg = f"Z.ai API request failed: {e!s}"
+                raise RuntimeError(error_msg) from e
+            except Exception as e:
+                raise RuntimeError(f"Z.ai API request failed: {e!s}") from e
+
+        usage = resp.get("usage", {})
+        prompt_tokens = usage.get("prompt_tokens", 0)
+        completion_tokens = usage.get("completion_tokens", 0)
+        total_tokens = usage.get("total_tokens", 0)
+        total_cost = self._calculate_cost("zai", model, prompt_tokens, completion_tokens)
+
+        meta = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": round(total_cost, 6),
+            "raw_response": resp,
+            "model_name": model,
+        }
+
+        choice = resp["choices"][0]
+        text = choice["message"].get("content") or ""
+        stop_reason = choice.get("finish_reason")
+
+        tool_calls_out: list[dict[str, Any]] = []
+        for tc in choice["message"].get("tool_calls", []):
+            try:
+                args = json.loads(tc["function"]["arguments"])
+            except (json.JSONDecodeError, TypeError):
+                args = {}
+            tool_calls_out.append(
+                {
+                    "id": tc["id"],
+                    "name": tc["function"]["name"],
+                    "arguments": args,
+                }
+            )
+
+        return {
+            "text": text,
+            "meta": meta,
+            "tool_calls": tool_calls_out,
+            "stop_reason": stop_reason,
+        }
+
+    # ------------------------------------------------------------------
+    # Streaming
+    # ------------------------------------------------------------------
+
+    async def generate_messages_stream(
+        self,
+        messages: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> AsyncIterator[dict[str, Any]]:
+        """Yield response chunks via Z.ai streaming API."""
+        model = options.get("model", self.model)
+        model_config = self._get_model_config("zai", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
+
+        opts = {"temperature": 1.0, "max_tokens": 512, **options}
+
+        data: dict[str, Any] = {
+            "model": model,
+            "messages": messages,
+            "stream": True,
+            "stream_options": {"include_usage": True},
+        }
+        data[tokens_param] = opts.get("max_tokens", 512)
+
+        if supports_temperature and "temperature" in opts:
+            data["temperature"] = opts["temperature"]
+
+        full_text = ""
+        prompt_tokens = 0
+        completion_tokens = 0
+
+        async with (
+            httpx.AsyncClient() as client,
+            client.stream(
+                "POST",
+                f"{self.base_url}/chat/completions",
+                headers=self.headers,
+                json=data,
+                timeout=120,
+            ) as response,
+        ):
+            response.raise_for_status()
+            async for line in response.aiter_lines():
+                if not line or not line.startswith("data: "):
+                    continue
+                payload = line[len("data: ") :]
+                if payload.strip() == "[DONE]":
+                    break
+                try:
+                    chunk = json.loads(payload)
+                except json.JSONDecodeError:
+                    continue
+
+                usage = chunk.get("usage")
+                if usage:
+                    prompt_tokens = usage.get("prompt_tokens", 0)
+                    completion_tokens = usage.get("completion_tokens", 0)
+
+                choices = chunk.get("choices", [])
+                if choices:
+                    delta = choices[0].get("delta", {})
+                    content = delta.get("content", "")
+                    if content:
+                        full_text += content
+                        yield {"type": "delta", "text": content}
+
+        total_tokens = prompt_tokens + completion_tokens
+        total_cost = self._calculate_cost("zai", model, prompt_tokens, completion_tokens)
+
+        yield {
+            "type": "done",
+            "text": full_text,
+            "meta": {
+                "prompt_tokens": prompt_tokens,
+                "completion_tokens": completion_tokens,
+                "total_tokens": total_tokens,
+                "cost": round(total_cost, 6),
+                "raw_response": {},
+                "model_name": model,
+            },
+        }
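Taken together, the new file adds a complete async Z.ai driver: plain generation, strict JSON-schema output, tool calls, and streaming, all priced through models.dev. A hypothetical end-to-end sketch; the tool definition and prompt are invented, while the constructor defaults and the `tool_calls` result shape come from the code above:

```python
# Hypothetical sketch, not part of the diff: driving the new AsyncZaiDriver
# tool-calling path. The weather tool and prompt are invented for illustration.
import asyncio

from prompture.drivers.async_zai_driver import AsyncZaiDriver

weather_tool = {
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Look up the current weather for a city.",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}


async def main() -> None:
    driver = AsyncZaiDriver()  # reads ZHIPU_API_KEY; defaults to glm-4.7
    result = await driver.generate_messages_with_tools(
        [{"role": "user", "content": "What is the weather in Lisbon right now?"}],
        [weather_tool],
        {"max_tokens": 256},
    )
    for call in result["tool_calls"]:
        print(call["name"], call["arguments"])  # parsed arguments, {} on bad JSON
    print(result["stop_reason"], result["meta"]["cost"])


asyncio.run(main())
```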