prompture 0.0.40.dev1__py3-none-any.whl → 0.0.42__py3-none-any.whl
This diff covers publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
- prompture/_version.py +2 -2
- prompture/agent.py +11 -11
- prompture/async_agent.py +11 -11
- prompture/async_groups.py +63 -0
- prompture/cost_mixin.py +25 -0
- prompture/drivers/__init__.py +39 -0
- prompture/drivers/async_azure_driver.py +3 -2
- prompture/drivers/async_modelscope_driver.py +286 -0
- prompture/drivers/async_moonshot_driver.py +312 -0
- prompture/drivers/async_openai_driver.py +3 -2
- prompture/drivers/async_openrouter_driver.py +192 -3
- prompture/drivers/async_registry.py +30 -0
- prompture/drivers/async_zai_driver.py +303 -0
- prompture/drivers/azure_driver.py +3 -2
- prompture/drivers/modelscope_driver.py +303 -0
- prompture/drivers/moonshot_driver.py +342 -0
- prompture/drivers/openai_driver.py +3 -2
- prompture/drivers/openrouter_driver.py +244 -40
- prompture/drivers/zai_driver.py +318 -0
- prompture/groups.py +42 -0
- prompture/model_rates.py +2 -0
- prompture/settings.py +16 -1
- {prompture-0.0.40.dev1.dist-info → prompture-0.0.42.dist-info}/METADATA +1 -1
- {prompture-0.0.40.dev1.dist-info → prompture-0.0.42.dist-info}/RECORD +28 -22
- {prompture-0.0.40.dev1.dist-info → prompture-0.0.42.dist-info}/WHEEL +0 -0
- {prompture-0.0.40.dev1.dist-info → prompture-0.0.42.dist-info}/entry_points.txt +0 -0
- {prompture-0.0.40.dev1.dist-info → prompture-0.0.42.dist-info}/licenses/LICENSE +0 -0
- {prompture-0.0.40.dev1.dist-info → prompture-0.0.42.dist-info}/top_level.txt +0 -0
```diff
--- a/prompture/drivers/openrouter_driver.py
+++ b/prompture/drivers/openrouter_driver.py
@@ -2,54 +2,66 @@
 Requires the `requests` package. Uses OPENROUTER_API_KEY env var.
 """
 
+import contextlib
+import json
 import os
+from collections.abc import Iterator
 from typing import Any
 
 import requests
 
-from ..cost_mixin import CostMixin
+from ..cost_mixin import CostMixin, prepare_strict_schema
 from ..driver import Driver
 
 
 class OpenRouterDriver(CostMixin, Driver):
     supports_json_mode = True
+    supports_json_schema = True
+    supports_tool_use = True
+    supports_streaming = True
     supports_vision = True
 
     # Approximate pricing per 1K tokens based on OpenRouter's pricing
     # https://openrouter.ai/docs#pricing
     MODEL_PRICING = {
-        "openai/gpt-
-            "prompt": 0.
-            "completion": 0.
+        "openai/gpt-4o": {
+            "prompt": 0.005,
+            "completion": 0.015,
             "tokens_param": "max_tokens",
             "supports_temperature": True,
         },
-        "
-            "prompt": 0.
-            "completion": 0.
+        "openai/gpt-4o-mini": {
+            "prompt": 0.00015,
+            "completion": 0.0006,
             "tokens_param": "max_tokens",
             "supports_temperature": True,
         },
-        "
-            "prompt": 0.
-            "completion": 0.
+        "anthropic/claude-sonnet-4-20250514": {
+            "prompt": 0.003,
+            "completion": 0.015,
             "tokens_param": "max_tokens",
             "supports_temperature": True,
         },
-        "
-            "prompt": 0.
-            "completion": 0.
+        "google/gemini-2.0-flash-001": {
+            "prompt": 0.0001,
+            "completion": 0.0004,
+            "tokens_param": "max_tokens",
+            "supports_temperature": True,
+        },
+        "meta-llama/llama-3.1-70b-instruct": {
+            "prompt": 0.0004,
+            "completion": 0.0004,
             "tokens_param": "max_tokens",
             "supports_temperature": True,
         },
     }
 
-    def __init__(self, api_key: str | None = None, model: str = "openai/gpt-
+    def __init__(self, api_key: str | None = None, model: str = "openai/gpt-4o-mini"):
         """Initialize OpenRouter driver.
 
         Args:
             api_key: OpenRouter API key. If not provided, will look for OPENROUTER_API_KEY env var
-            model: Model to use. Defaults to openai/gpt-
+            model: Model to use. Defaults to openai/gpt-4o-mini
         """
         self.api_key = api_key or os.getenv("OPENROUTER_API_KEY")
         if not self.api_key:
```
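The table above keeps per-1K-token fallback rates, and later in this diff `_calculate_cost("openrouter", model, prompt_tokens, completion_tokens)` turns usage counts into the rounded `meta["cost"]` value. A minimal sketch of that arithmetic, assuming rates are resolved from `MODEL_PRICING`-style entries (`estimate_cost` is a hypothetical stand-in; the actual `CostMixin` is not shown in this diff and resolves pricing from models.dev first):

```python
# Hypothetical stand-in for the per-1K-token arithmetic implied by
# MODEL_PRICING and the rounded meta["cost"] field; not the actual CostMixin.
def estimate_cost(pricing: dict, model: str, prompt_tokens: int, completion_tokens: int) -> float:
    rates = pricing.get(model, {"prompt": 0.0, "completion": 0.0})
    cost = (prompt_tokens / 1000) * rates["prompt"] + (completion_tokens / 1000) * rates["completion"]
    return round(cost, 6)

# e.g. 1,200 prompt + 300 completion tokens on openai/gpt-4o:
# 1.2 * 0.005 + 0.3 * 0.015 = 0.0105
```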
```diff
@@ -90,6 +102,13 @@ class OpenRouterDriver(CostMixin, Driver):
         tokens_param = model_config["tokens_param"]
         supports_temperature = model_config["supports_temperature"]
 
+        # Validate capabilities against models.dev metadata
+        self._validate_model_capabilities(
+            "openrouter",
+            model,
+            using_json_schema=bool(options.get("json_schema")),
+        )
+
         # Defaults
         opts = {"temperature": 1.0, "max_tokens": 512, **options}
 
```
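The validation step only asserts JSON-schema support when the caller actually supplies a schema, so plain `json_mode` requests are unaffected. A hedged caller-side sketch, assuming the `generate(prompt, options)` entry point that `ZaiDriver` exposes later in this diff also exists on `OpenRouterDriver`; the schema itself is illustrative:

```python
# Sketch of the options that exercise the new validation path; generate()
# is assumed from the shared Driver API, and the schema is illustrative.
driver = OpenRouterDriver(model="openai/gpt-4o-mini")  # reads OPENROUTER_API_KEY
result = driver.generate(
    "Extract the user's name and age as JSON.",
    options={
        "json_mode": True,
        "json_schema": {
            "type": "object",
            "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
            "required": ["name", "age"],
        },
    },
)
print(result["text"], result["meta"]["cost"])
```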
```diff
@@ -108,45 +127,230 @@ class OpenRouterDriver(CostMixin, Driver):
 
         # Native JSON mode support
         if options.get("json_mode"):
-
+            json_schema = options.get("json_schema")
+            if json_schema:
+                schema_copy = prepare_strict_schema(json_schema)
+                data["response_format"] = {
+                    "type": "json_schema",
+                    "json_schema": {
+                        "name": "extraction",
+                        "strict": True,
+                        "schema": schema_copy,
+                    },
+                }
+            else:
+                data["response_format"] = {"type": "json_object"}
+
+        try:
+            response = requests.post(
+                f"{self.base_url}/chat/completions",
+                headers=self.headers,
+                json=data,
+                timeout=120,
+            )
+            response.raise_for_status()
+            resp = response.json()
+        except requests.exceptions.HTTPError as e:
+            body = ""
+            if e.response is not None:
+                with contextlib.suppress(Exception):
+                    body = e.response.text
+            error_msg = f"OpenRouter API request failed: {e!s}"
+            if body:
+                error_msg += f"\nResponse: {body}"
+            raise RuntimeError(error_msg) from e
+        except requests.exceptions.RequestException as e:
+            raise RuntimeError(f"OpenRouter API request failed: {e!s}") from e
+
+        # Extract usage info
+        usage = resp.get("usage", {})
+        prompt_tokens = usage.get("prompt_tokens", 0)
+        completion_tokens = usage.get("completion_tokens", 0)
+        total_tokens = usage.get("total_tokens", 0)
+
+        # Calculate cost via shared mixin
+        total_cost = self._calculate_cost("openrouter", model, prompt_tokens, completion_tokens)
+
+        # Standardized meta object
+        meta = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": round(total_cost, 6),
+            "raw_response": resp,
+            "model_name": model,
+        }
+
+        text = resp["choices"][0]["message"]["content"]
+        return {"text": text, "meta": meta}
+
+    # ------------------------------------------------------------------
+    # Tool use
+    # ------------------------------------------------------------------
+
+    def generate_messages_with_tools(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> dict[str, Any]:
+        """Generate a response that may include tool calls."""
+        if not self.api_key:
+            raise RuntimeError("OpenRouter API key not found")
+
+        model = options.get("model", self.model)
+        model_config = self._get_model_config("openrouter", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
+
+        self._validate_model_capabilities("openrouter", model, using_tool_use=True)
+
+        opts = {"temperature": 1.0, "max_tokens": 512, **options}
+
+        data: dict[str, Any] = {
+            "model": model,
+            "messages": messages,
+            "tools": tools,
+        }
+        data[tokens_param] = opts.get("max_tokens", 512)
+
+        if supports_temperature and "temperature" in opts:
+            data["temperature"] = opts["temperature"]
 
         try:
             response = requests.post(
                 f"{self.base_url}/chat/completions",
                 headers=self.headers,
                 json=data,
+                timeout=120,
             )
             response.raise_for_status()
             resp = response.json()
+        except requests.exceptions.HTTPError as e:
+            error_msg = f"OpenRouter API request failed: {e!s}"
+            raise RuntimeError(error_msg) from e
+        except requests.exceptions.RequestException as e:
+            raise RuntimeError(f"OpenRouter API request failed: {e!s}") from e
+
+        usage = resp.get("usage", {})
+        prompt_tokens = usage.get("prompt_tokens", 0)
+        completion_tokens = usage.get("completion_tokens", 0)
+        total_tokens = usage.get("total_tokens", 0)
+        total_cost = self._calculate_cost("openrouter", model, prompt_tokens, completion_tokens)
+
+        meta = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": round(total_cost, 6),
+            "raw_response": resp,
+            "model_name": model,
+        }
+
+        choice = resp["choices"][0]
+        text = choice["message"].get("content") or ""
+        stop_reason = choice.get("finish_reason")
+
+        tool_calls_out: list[dict[str, Any]] = []
+        for tc in choice["message"].get("tool_calls", []):
+            try:
+                args = json.loads(tc["function"]["arguments"])
+            except (json.JSONDecodeError, TypeError):
+                args = {}
+            tool_calls_out.append({
+                "id": tc["id"],
+                "name": tc["function"]["name"],
+                "arguments": args,
+            })
+
+        return {
+            "text": text,
+            "meta": meta,
+            "tool_calls": tool_calls_out,
+            "stop_reason": stop_reason,
+        }
+
+    # ------------------------------------------------------------------
+    # Streaming
+    # ------------------------------------------------------------------
+
+    def generate_messages_stream(
+        self,
+        messages: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> Iterator[dict[str, Any]]:
+        """Yield response chunks via OpenRouter streaming API."""
+        if not self.api_key:
+            raise RuntimeError("OpenRouter API key not found")
+
+        model = options.get("model", self.model)
+        model_config = self._get_model_config("openrouter", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
 
-
-
-
-
-
+        opts = {"temperature": 1.0, "max_tokens": 512, **options}
+
+        data: dict[str, Any] = {
+            "model": model,
+            "messages": messages,
+            "stream": True,
+            "stream_options": {"include_usage": True},
+        }
+        data[tokens_param] = opts.get("max_tokens", 512)
+
+        if supports_temperature and "temperature" in opts:
+            data["temperature"] = opts["temperature"]
+
+        response = requests.post(
+            f"{self.base_url}/chat/completions",
+            headers=self.headers,
+            json=data,
+            stream=True,
+            timeout=120,
+        )
+        response.raise_for_status()
 
-
-
+        full_text = ""
+        prompt_tokens = 0
+        completion_tokens = 0
 
-
-
+        for line in response.iter_lines(decode_unicode=True):
+            if not line or not line.startswith("data: "):
+                continue
+            payload = line[len("data: "):]
+            if payload.strip() == "[DONE]":
+                break
+            try:
+                chunk = json.loads(payload)
+            except json.JSONDecodeError:
+                continue
+
+            # Usage comes in the final chunk
+            usage = chunk.get("usage")
+            if usage:
+                prompt_tokens = usage.get("prompt_tokens", 0)
+                completion_tokens = usage.get("completion_tokens", 0)
+
+            choices = chunk.get("choices", [])
+            if choices:
+                delta = choices[0].get("delta", {})
+                content = delta.get("content", "")
+                if content:
+                    full_text += content
+                    yield {"type": "delta", "text": content}
+
+        total_tokens = prompt_tokens + completion_tokens
+        total_cost = self._calculate_cost("openrouter", model, prompt_tokens, completion_tokens)
+
+        yield {
+            "type": "done",
+            "text": full_text,
+            "meta": {
                 "prompt_tokens": prompt_tokens,
                 "completion_tokens": completion_tokens,
                 "total_tokens": total_tokens,
                 "cost": round(total_cost, 6),
-                "raw_response":
+                "raw_response": {},
                 "model_name": model,
-            }
-
-            text = resp["choices"][0]["message"]["content"]
-            return {"text": text, "meta": meta}
-
-        except requests.exceptions.RequestException as e:
-            error_msg = f"OpenRouter API request failed: {e!s}"
-            if hasattr(e.response, "json"):
-                try:
-                    error_details = e.response.json()
-                    error_msg = f"{error_msg} - {error_details.get('error', {}).get('message', '')}"
-                except Exception:
-                    pass
-            raise RuntimeError(error_msg) from e
+            },
+        }
```
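The streaming method above normalizes OpenRouter's SSE lines into a simple two-chunk protocol: `delta` chunks carry incremental text, and a final `done` chunk carries the assembled text plus usage and cost meta. A consumption sketch against that protocol, reusing the hypothetical `driver` from the earlier sketch:

```python
# Iterate the normalized chunk protocol emitted by generate_messages_stream().
messages = [{"role": "user", "content": "Stream me a haiku."}]
for chunk in driver.generate_messages_stream(messages, options={}):
    if chunk["type"] == "delta":
        print(chunk["text"], end="", flush=True)  # incremental text
    elif chunk["type"] == "done":
        meta = chunk["meta"]  # usage and cost arrive only in the final chunk
        print(f"\ntokens={meta['total_tokens']} cost=${meta['cost']}")
```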
```diff
--- /dev/null
+++ b/prompture/drivers/zai_driver.py
@@ -0,0 +1,318 @@
+"""Z.ai (Zhipu AI) driver implementation.
+Requires the `requests` package. Uses ZHIPU_API_KEY env var.
+
+The Z.ai API is fully OpenAI-compatible (/chat/completions).
+All pricing comes from models.dev (provider: "zai") — no hardcoded pricing.
+"""
+
+import json
+import os
+from collections.abc import Iterator
+from typing import Any
+
+import requests
+
+from ..cost_mixin import CostMixin, prepare_strict_schema
+from ..driver import Driver
+
+
+class ZaiDriver(CostMixin, Driver):
+    supports_json_mode = True
+    supports_json_schema = True
+    supports_tool_use = True
+    supports_streaming = True
+    supports_vision = True
+
+    # All pricing resolved live from models.dev (provider: "zai")
+    MODEL_PRICING: dict[str, dict[str, Any]] = {}
+
+    def __init__(
+        self,
+        api_key: str | None = None,
+        model: str = "glm-4.7",
+        endpoint: str = "https://api.z.ai/api/paas/v4",
+    ):
+        """Initialize Z.ai driver.
+
+        Args:
+            api_key: Zhipu API key. If not provided, will look for ZHIPU_API_KEY env var.
+            model: Model to use. Defaults to glm-4.7.
+            endpoint: API base URL. Defaults to https://api.z.ai/api/paas/v4.
+        """
+        self.api_key = api_key or os.getenv("ZHIPU_API_KEY")
+        if not self.api_key:
+            raise ValueError("Zhipu API key not found. Set ZHIPU_API_KEY env var.")
+
+        self.model = model
+        self.base_url = endpoint.rstrip("/")
+
+        self.headers = {
+            "Authorization": f"Bearer {self.api_key}",
+            "Content-Type": "application/json",
+        }
+
+    supports_messages = True
+
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_openai_vision_messages
+
+        return _prepare_openai_vision_messages(messages)
+
+    def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
+        messages = [{"role": "user", "content": prompt}]
+        return self._do_generate(messages, options)
+
+    def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        return self._do_generate(self._prepare_messages(messages), options)
+
+    def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        if not self.api_key:
+            raise RuntimeError("Zhipu API key not found")
+
+        model = options.get("model", self.model)
+
+        model_config = self._get_model_config("zai", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
+
+        self._validate_model_capabilities(
+            "zai",
+            model,
+            using_json_schema=bool(options.get("json_schema")),
+        )
+
+        opts = {"temperature": 1.0, "max_tokens": 512, **options}
+
+        data: dict[str, Any] = {
+            "model": model,
+            "messages": messages,
+        }
+        data[tokens_param] = opts.get("max_tokens", 512)
+
+        if supports_temperature and "temperature" in opts:
+            data["temperature"] = opts["temperature"]
+
+        # Native JSON mode support
+        if options.get("json_mode"):
+            json_schema = options.get("json_schema")
+            if json_schema:
+                schema_copy = prepare_strict_schema(json_schema)
+                data["response_format"] = {
+                    "type": "json_schema",
+                    "json_schema": {
+                        "name": "extraction",
+                        "strict": True,
+                        "schema": schema_copy,
+                    },
+                }
+            else:
+                data["response_format"] = {"type": "json_object"}
+
+        try:
+            response = requests.post(
+                f"{self.base_url}/chat/completions",
+                headers=self.headers,
+                json=data,
+                timeout=120,
+            )
+            response.raise_for_status()
+            resp = response.json()
+        except requests.exceptions.HTTPError as e:
+            error_msg = f"Z.ai API request failed: {e!s}"
+            raise RuntimeError(error_msg) from e
+        except requests.exceptions.RequestException as e:
+            raise RuntimeError(f"Z.ai API request failed: {e!s}") from e
+
+        usage = resp.get("usage", {})
+        prompt_tokens = usage.get("prompt_tokens", 0)
+        completion_tokens = usage.get("completion_tokens", 0)
+        total_tokens = usage.get("total_tokens", 0)
+
+        total_cost = self._calculate_cost("zai", model, prompt_tokens, completion_tokens)
+
+        meta = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": round(total_cost, 6),
+            "raw_response": resp,
+            "model_name": model,
+        }
+
+        text = resp["choices"][0]["message"]["content"]
+        return {"text": text, "meta": meta}
+
+    # ------------------------------------------------------------------
+    # Tool use
+    # ------------------------------------------------------------------
+
+    def generate_messages_with_tools(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> dict[str, Any]:
+        """Generate a response that may include tool calls."""
+        if not self.api_key:
+            raise RuntimeError("Zhipu API key not found")
+
+        model = options.get("model", self.model)
+        model_config = self._get_model_config("zai", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
+
+        self._validate_model_capabilities("zai", model, using_tool_use=True)
+
+        opts = {"temperature": 1.0, "max_tokens": 512, **options}
+
+        data: dict[str, Any] = {
+            "model": model,
+            "messages": messages,
+            "tools": tools,
+        }
+        data[tokens_param] = opts.get("max_tokens", 512)
+
+        if supports_temperature and "temperature" in opts:
+            data["temperature"] = opts["temperature"]
+
+        if "tool_choice" in options:
+            data["tool_choice"] = options["tool_choice"]
+
+        try:
+            response = requests.post(
+                f"{self.base_url}/chat/completions",
+                headers=self.headers,
+                json=data,
+                timeout=120,
+            )
+            response.raise_for_status()
+            resp = response.json()
+        except requests.exceptions.HTTPError as e:
+            error_msg = f"Z.ai API request failed: {e!s}"
+            raise RuntimeError(error_msg) from e
+        except requests.exceptions.RequestException as e:
+            raise RuntimeError(f"Z.ai API request failed: {e!s}") from e
+
+        usage = resp.get("usage", {})
+        prompt_tokens = usage.get("prompt_tokens", 0)
+        completion_tokens = usage.get("completion_tokens", 0)
+        total_tokens = usage.get("total_tokens", 0)
+        total_cost = self._calculate_cost("zai", model, prompt_tokens, completion_tokens)
+
+        meta = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": round(total_cost, 6),
+            "raw_response": resp,
+            "model_name": model,
+        }
+
+        choice = resp["choices"][0]
+        text = choice["message"].get("content") or ""
+        stop_reason = choice.get("finish_reason")
+
+        tool_calls_out: list[dict[str, Any]] = []
+        for tc in choice["message"].get("tool_calls", []):
+            try:
+                args = json.loads(tc["function"]["arguments"])
+            except (json.JSONDecodeError, TypeError):
+                args = {}
+            tool_calls_out.append(
+                {
+                    "id": tc["id"],
+                    "name": tc["function"]["name"],
+                    "arguments": args,
+                }
+            )
+
+        return {
+            "text": text,
+            "meta": meta,
+            "tool_calls": tool_calls_out,
+            "stop_reason": stop_reason,
+        }
+
+    # ------------------------------------------------------------------
+    # Streaming
+    # ------------------------------------------------------------------
+
+    def generate_messages_stream(
+        self,
+        messages: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> Iterator[dict[str, Any]]:
+        """Yield response chunks via Z.ai streaming API."""
+        if not self.api_key:
+            raise RuntimeError("Zhipu API key not found")
+
+        model = options.get("model", self.model)
+        model_config = self._get_model_config("zai", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
+
+        opts = {"temperature": 1.0, "max_tokens": 512, **options}
+
+        data: dict[str, Any] = {
+            "model": model,
+            "messages": messages,
+            "stream": True,
+            "stream_options": {"include_usage": True},
+        }
+        data[tokens_param] = opts.get("max_tokens", 512)
+
+        if supports_temperature and "temperature" in opts:
+            data["temperature"] = opts["temperature"]
+
+        response = requests.post(
+            f"{self.base_url}/chat/completions",
+            headers=self.headers,
+            json=data,
+            stream=True,
+            timeout=120,
+        )
+        response.raise_for_status()
+
+        full_text = ""
+        prompt_tokens = 0
+        completion_tokens = 0
+
+        for line in response.iter_lines(decode_unicode=True):
+            if not line or not line.startswith("data: "):
+                continue
+            payload = line[len("data: ") :]
+            if payload.strip() == "[DONE]":
+                break
+            try:
+                chunk = json.loads(payload)
+            except json.JSONDecodeError:
+                continue
+
+            usage = chunk.get("usage")
+            if usage:
+                prompt_tokens = usage.get("prompt_tokens", 0)
+                completion_tokens = usage.get("completion_tokens", 0)
+
+            choices = chunk.get("choices", [])
+            if choices:
+                delta = choices[0].get("delta", {})
+                content = delta.get("content", "")
+                if content:
+                    full_text += content
+                    yield {"type": "delta", "text": content}
+
+        total_tokens = prompt_tokens + completion_tokens
+        total_cost = self._calculate_cost("zai", model, prompt_tokens, completion_tokens)
+
+        yield {
+            "type": "done",
+            "text": full_text,
+            "meta": {
+                "prompt_tokens": prompt_tokens,
+                "completion_tokens": completion_tokens,
+                "total_tokens": total_tokens,
+                "cost": round(total_cost, 6),
+                "raw_response": {},
+                "model_name": model,
+            },
+        }
```
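Since `generate_messages_with_tools` already json-decodes each call's arguments into a dict, dispatch on the caller side stays simple. A sketch of one round trip, assuming an OpenAI-format tool schema (the module docstring says the Z.ai API is OpenAI-compatible) and a hypothetical `get_weather` tool:

```python
# One tool-use round trip against the normalized output shape shown above.
driver = ZaiDriver(model="glm-4.7")  # reads ZHIPU_API_KEY from the environment
tools = [{
    "type": "function",
    "function": {
        "name": "get_weather",  # hypothetical tool, illustrative only
        "description": "Look up current weather for a city.",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}]
result = driver.generate_messages_with_tools(
    [{"role": "user", "content": "What's the weather in Madrid?"}],
    tools,
    options={"tool_choice": "auto"},
)
for call in result["tool_calls"]:
    print(call["id"], call["name"], call["arguments"])  # arguments already parsed to a dict
```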