prompture 0.0.40.dev1__py3-none-any.whl → 0.0.41__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- prompture/_version.py +2 -2
- prompture/drivers/__init__.py +39 -0
- prompture/drivers/async_modelscope_driver.py +286 -0
- prompture/drivers/async_moonshot_driver.py +311 -0
- prompture/drivers/async_openrouter_driver.py +190 -2
- prompture/drivers/async_registry.py +30 -0
- prompture/drivers/async_zai_driver.py +302 -0
- prompture/drivers/modelscope_driver.py +303 -0
- prompture/drivers/moonshot_driver.py +341 -0
- prompture/drivers/openrouter_driver.py +235 -39
- prompture/drivers/zai_driver.py +317 -0
- prompture/model_rates.py +2 -0
- prompture/settings.py +15 -0
- {prompture-0.0.40.dev1.dist-info → prompture-0.0.41.dist-info}/METADATA +1 -1
- {prompture-0.0.40.dev1.dist-info → prompture-0.0.41.dist-info}/RECORD +19 -13
- {prompture-0.0.40.dev1.dist-info → prompture-0.0.41.dist-info}/WHEEL +0 -0
- {prompture-0.0.40.dev1.dist-info → prompture-0.0.41.dist-info}/entry_points.txt +0 -0
- {prompture-0.0.40.dev1.dist-info → prompture-0.0.41.dist-info}/licenses/LICENSE +0 -0
- {prompture-0.0.40.dev1.dist-info → prompture-0.0.41.dist-info}/top_level.txt +0 -0
prompture/drivers/openrouter_driver.py CHANGED

@@ -2,7 +2,9 @@
 Requires the `requests` package. Uses OPENROUTER_API_KEY env var.
 """
 
+import json
 import os
+from collections.abc import Iterator
 from typing import Any
 
 import requests
@@ -13,43 +15,52 @@ from ..driver import Driver
 
 class OpenRouterDriver(CostMixin, Driver):
     supports_json_mode = True
+    supports_json_schema = True
+    supports_tool_use = True
+    supports_streaming = True
     supports_vision = True
 
     # Approximate pricing per 1K tokens based on OpenRouter's pricing
     # https://openrouter.ai/docs#pricing
     MODEL_PRICING = {
-        "openai/gpt-
-            "prompt": 0.
-            "completion": 0.
+        "openai/gpt-4o": {
+            "prompt": 0.005,
+            "completion": 0.015,
             "tokens_param": "max_tokens",
             "supports_temperature": True,
         },
-        "
-            "prompt": 0.
-            "completion": 0.
+        "openai/gpt-4o-mini": {
+            "prompt": 0.00015,
+            "completion": 0.0006,
             "tokens_param": "max_tokens",
             "supports_temperature": True,
         },
-        "
-            "prompt": 0.
-            "completion": 0.
+        "anthropic/claude-sonnet-4-20250514": {
+            "prompt": 0.003,
+            "completion": 0.015,
             "tokens_param": "max_tokens",
             "supports_temperature": True,
         },
-        "
-            "prompt": 0.
-            "completion": 0.
+        "google/gemini-2.0-flash-001": {
+            "prompt": 0.0001,
+            "completion": 0.0004,
+            "tokens_param": "max_tokens",
+            "supports_temperature": True,
+        },
+        "meta-llama/llama-3.1-70b-instruct": {
+            "prompt": 0.0004,
+            "completion": 0.0004,
             "tokens_param": "max_tokens",
             "supports_temperature": True,
         },
     }
 
-    def __init__(self, api_key: str | None = None, model: str = "openai/gpt-
+    def __init__(self, api_key: str | None = None, model: str = "openai/gpt-4o-mini"):
         """Initialize OpenRouter driver.
 
         Args:
             api_key: OpenRouter API key. If not provided, will look for OPENROUTER_API_KEY env var
-            model: Model to use. Defaults to openai/gpt-
+            model: Model to use. Defaults to openai/gpt-4o-mini
         """
         self.api_key = api_key or os.getenv("OPENROUTER_API_KEY")
         if not self.api_key:
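The MODEL_PRICING table above is keyed in USD per 1K tokens. As a quick sanity check, here is a minimal sketch of the arithmetic the table implies, assuming CostMixin._calculate_cost (whose body this diff does not show) applies the rates this way:

# Hedged sketch: cost of a request against "openai/gpt-4o-mini" at the rates above
pricing = {"prompt": 0.00015, "completion": 0.0006}  # USD per 1K tokens
prompt_tokens, completion_tokens = 1200, 300
cost = (prompt_tokens / 1000) * pricing["prompt"] + (completion_tokens / 1000) * pricing["completion"]
print(round(cost, 6))  # 0.00036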
@@ -90,6 +101,13 @@ class OpenRouterDriver(CostMixin, Driver):
         tokens_param = model_config["tokens_param"]
         supports_temperature = model_config["supports_temperature"]
 
+        # Validate capabilities against models.dev metadata
+        self._validate_model_capabilities(
+            "openrouter",
+            model,
+            using_json_schema=bool(options.get("json_schema")),
+        )
+
         # Defaults
         opts = {"temperature": 1.0, "max_tokens": 512, **options}
 
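The new _validate_model_capabilities call gates each request on models.dev metadata. The diff shows only the call sites, so this is a hedged sketch of the call shape; the method body and its exact failure mode (raise vs. warn) are assumptions:

from prompture.drivers.openrouter_driver import OpenRouterDriver

driver = OpenRouterDriver(api_key="sk-or-demo")  # hypothetical key
# Assumed behavior: complains when the model lacks the requested feature
# (structured output here) according to models.dev metadata.
driver._validate_model_capabilities(
    "openrouter",
    "openai/gpt-4o-mini",
    using_json_schema=True,
)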
@@ -108,45 +126,223 @@ class OpenRouterDriver(CostMixin, Driver):
 
         # Native JSON mode support
         if options.get("json_mode"):
-
+            json_schema = options.get("json_schema")
+            if json_schema:
+                data["response_format"] = {
+                    "type": "json_schema",
+                    "json_schema": {
+                        "name": "extraction",
+                        "strict": True,
+                        "schema": json_schema,
+                    },
+                }
+            else:
+                data["response_format"] = {"type": "json_object"}
+
+        try:
+            response = requests.post(
+                f"{self.base_url}/chat/completions",
+                headers=self.headers,
+                json=data,
+                timeout=120,
+            )
+            response.raise_for_status()
+            resp = response.json()
+        except requests.exceptions.HTTPError as e:
+            error_msg = f"OpenRouter API request failed: {e!s}"
+            raise RuntimeError(error_msg) from e
+        except requests.exceptions.RequestException as e:
+            raise RuntimeError(f"OpenRouter API request failed: {e!s}") from e
+
+        # Extract usage info
+        usage = resp.get("usage", {})
+        prompt_tokens = usage.get("prompt_tokens", 0)
+        completion_tokens = usage.get("completion_tokens", 0)
+        total_tokens = usage.get("total_tokens", 0)
+
+        # Calculate cost via shared mixin
+        total_cost = self._calculate_cost("openrouter", model, prompt_tokens, completion_tokens)
+
+        # Standardized meta object
+        meta = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": round(total_cost, 6),
+            "raw_response": resp,
+            "model_name": model,
+        }
+
+        text = resp["choices"][0]["message"]["content"]
+        return {"text": text, "meta": meta}
+
+    # ------------------------------------------------------------------
+    # Tool use
+    # ------------------------------------------------------------------
+
+    def generate_messages_with_tools(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> dict[str, Any]:
+        """Generate a response that may include tool calls."""
+        if not self.api_key:
+            raise RuntimeError("OpenRouter API key not found")
+
+        model = options.get("model", self.model)
+        model_config = self._get_model_config("openrouter", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
+
+        self._validate_model_capabilities("openrouter", model, using_tool_use=True)
+
+        opts = {"temperature": 1.0, "max_tokens": 512, **options}
+
+        data: dict[str, Any] = {
+            "model": model,
+            "messages": messages,
+            "tools": tools,
+        }
+        data[tokens_param] = opts.get("max_tokens", 512)
+
+        if supports_temperature and "temperature" in opts:
+            data["temperature"] = opts["temperature"]
 
         try:
             response = requests.post(
                 f"{self.base_url}/chat/completions",
                 headers=self.headers,
                 json=data,
+                timeout=120,
             )
             response.raise_for_status()
             resp = response.json()
+        except requests.exceptions.HTTPError as e:
+            error_msg = f"OpenRouter API request failed: {e!s}"
+            raise RuntimeError(error_msg) from e
+        except requests.exceptions.RequestException as e:
+            raise RuntimeError(f"OpenRouter API request failed: {e!s}") from e
+
+        usage = resp.get("usage", {})
+        prompt_tokens = usage.get("prompt_tokens", 0)
+        completion_tokens = usage.get("completion_tokens", 0)
+        total_tokens = usage.get("total_tokens", 0)
+        total_cost = self._calculate_cost("openrouter", model, prompt_tokens, completion_tokens)
+
+        meta = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": round(total_cost, 6),
+            "raw_response": resp,
+            "model_name": model,
+        }
+
+        choice = resp["choices"][0]
+        text = choice["message"].get("content") or ""
+        stop_reason = choice.get("finish_reason")
+
+        tool_calls_out: list[dict[str, Any]] = []
+        for tc in choice["message"].get("tool_calls", []):
+            try:
+                args = json.loads(tc["function"]["arguments"])
+            except (json.JSONDecodeError, TypeError):
+                args = {}
+            tool_calls_out.append({
+                "id": tc["id"],
+                "name": tc["function"]["name"],
+                "arguments": args,
+            })
+
+        return {
+            "text": text,
+            "meta": meta,
+            "tool_calls": tool_calls_out,
+            "stop_reason": stop_reason,
+        }
+
+    # ------------------------------------------------------------------
+    # Streaming
+    # ------------------------------------------------------------------
+
+    def generate_messages_stream(
+        self,
+        messages: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> Iterator[dict[str, Any]]:
+        """Yield response chunks via OpenRouter streaming API."""
+        if not self.api_key:
+            raise RuntimeError("OpenRouter API key not found")
+
+        model = options.get("model", self.model)
+        model_config = self._get_model_config("openrouter", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
 
-
-
-
-
-
+        opts = {"temperature": 1.0, "max_tokens": 512, **options}
+
+        data: dict[str, Any] = {
+            "model": model,
+            "messages": messages,
+            "stream": True,
+            "stream_options": {"include_usage": True},
+        }
+        data[tokens_param] = opts.get("max_tokens", 512)
+
+        if supports_temperature and "temperature" in opts:
+            data["temperature"] = opts["temperature"]
+
+        response = requests.post(
+            f"{self.base_url}/chat/completions",
+            headers=self.headers,
+            json=data,
+            stream=True,
+            timeout=120,
+        )
+        response.raise_for_status()
 
-
-
+        full_text = ""
+        prompt_tokens = 0
+        completion_tokens = 0
 
-
-
+        for line in response.iter_lines(decode_unicode=True):
+            if not line or not line.startswith("data: "):
+                continue
+            payload = line[len("data: "):]
+            if payload.strip() == "[DONE]":
+                break
+            try:
+                chunk = json.loads(payload)
+            except json.JSONDecodeError:
+                continue
+
+            # Usage comes in the final chunk
+            usage = chunk.get("usage")
+            if usage:
+                prompt_tokens = usage.get("prompt_tokens", 0)
+                completion_tokens = usage.get("completion_tokens", 0)
+
+            choices = chunk.get("choices", [])
+            if choices:
+                delta = choices[0].get("delta", {})
+                content = delta.get("content", "")
+                if content:
+                    full_text += content
+                    yield {"type": "delta", "text": content}
+
+        total_tokens = prompt_tokens + completion_tokens
+        total_cost = self._calculate_cost("openrouter", model, prompt_tokens, completion_tokens)
+
+        yield {
+            "type": "done",
+            "text": full_text,
+            "meta": {
                 "prompt_tokens": prompt_tokens,
                 "completion_tokens": completion_tokens,
                 "total_tokens": total_tokens,
                 "cost": round(total_cost, 6),
-            "raw_response":
+                "raw_response": {},
                 "model_name": model,
-        }
-
-        text = resp["choices"][0]["message"]["content"]
-        return {"text": text, "meta": meta}
-
-        except requests.exceptions.RequestException as e:
-            error_msg = f"OpenRouter API request failed: {e!s}"
-            if hasattr(e.response, "json"):
-                try:
-                    error_details = e.response.json()
-                    error_msg = f"{error_msg} - {error_details.get('error', {}).get('message', '')}"
-                except Exception:
-                    pass
-            raise RuntimeError(error_msg) from e
+            },
+        }
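Taken together, the additions above give the OpenRouter driver strict JSON-schema output, tool calls, and SSE streaming. A minimal usage sketch, assuming generate(prompt, options) is the public entry point wired to this code path (as in the sibling drivers added in this release) and that OPENROUTER_API_KEY is set:

from prompture.drivers.openrouter_driver import OpenRouterDriver

driver = OpenRouterDriver(model="openai/gpt-4o-mini")  # reads OPENROUTER_API_KEY

# Strict structured output via the new json_schema branch
result = driver.generate(
    "Extract name and age from: Ada Lovelace, 36.",
    {
        "json_mode": True,
        "json_schema": {
            "type": "object",
            "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
            "required": ["name", "age"],
        },
    },
)
print(result["text"], result["meta"]["cost"])

# Streaming: "delta" chunks carry text; the final "done" chunk carries usage
for chunk in driver.generate_messages_stream(
    [{"role": "user", "content": "Say hello."}],
    {"max_tokens": 32},
):
    if chunk["type"] == "delta":
        print(chunk["text"], end="", flush=True)
    else:
        print("\ntokens:", chunk["meta"]["total_tokens"])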
prompture/drivers/zai_driver.py ADDED

@@ -0,0 +1,317 @@
+"""Z.ai (Zhipu AI) driver implementation.
+Requires the `requests` package. Uses ZHIPU_API_KEY env var.
+
+The Z.ai API is fully OpenAI-compatible (/chat/completions).
+All pricing comes from models.dev (provider: "zai") — no hardcoded pricing.
+"""
+
+import json
+import os
+from collections.abc import Iterator
+from typing import Any
+
+import requests
+
+from ..cost_mixin import CostMixin
+from ..driver import Driver
+
+
+class ZaiDriver(CostMixin, Driver):
+    supports_json_mode = True
+    supports_json_schema = True
+    supports_tool_use = True
+    supports_streaming = True
+    supports_vision = True
+
+    # All pricing resolved live from models.dev (provider: "zai")
+    MODEL_PRICING: dict[str, dict[str, Any]] = {}
+
+    def __init__(
+        self,
+        api_key: str | None = None,
+        model: str = "glm-4.7",
+        endpoint: str = "https://api.z.ai/api/paas/v4",
+    ):
+        """Initialize Z.ai driver.
+
+        Args:
+            api_key: Zhipu API key. If not provided, will look for ZHIPU_API_KEY env var.
+            model: Model to use. Defaults to glm-4.7.
+            endpoint: API base URL. Defaults to https://api.z.ai/api/paas/v4.
+        """
+        self.api_key = api_key or os.getenv("ZHIPU_API_KEY")
+        if not self.api_key:
+            raise ValueError("Zhipu API key not found. Set ZHIPU_API_KEY env var.")
+
+        self.model = model
+        self.base_url = endpoint.rstrip("/")
+
+        self.headers = {
+            "Authorization": f"Bearer {self.api_key}",
+            "Content-Type": "application/json",
+        }
+
+    supports_messages = True
+
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_openai_vision_messages
+
+        return _prepare_openai_vision_messages(messages)
+
+    def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
+        messages = [{"role": "user", "content": prompt}]
+        return self._do_generate(messages, options)
+
+    def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        return self._do_generate(self._prepare_messages(messages), options)
+
+    def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        if not self.api_key:
+            raise RuntimeError("Zhipu API key not found")
+
+        model = options.get("model", self.model)
+
+        model_config = self._get_model_config("zai", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
+
+        self._validate_model_capabilities(
+            "zai",
+            model,
+            using_json_schema=bool(options.get("json_schema")),
+        )
+
+        opts = {"temperature": 1.0, "max_tokens": 512, **options}
+
+        data: dict[str, Any] = {
+            "model": model,
+            "messages": messages,
+        }
+        data[tokens_param] = opts.get("max_tokens", 512)
+
+        if supports_temperature and "temperature" in opts:
+            data["temperature"] = opts["temperature"]
+
+        # Native JSON mode support
+        if options.get("json_mode"):
+            json_schema = options.get("json_schema")
+            if json_schema:
+                data["response_format"] = {
+                    "type": "json_schema",
+                    "json_schema": {
+                        "name": "extraction",
+                        "strict": True,
+                        "schema": json_schema,
+                    },
+                }
+            else:
+                data["response_format"] = {"type": "json_object"}
+
+        try:
+            response = requests.post(
+                f"{self.base_url}/chat/completions",
+                headers=self.headers,
+                json=data,
+                timeout=120,
+            )
+            response.raise_for_status()
+            resp = response.json()
+        except requests.exceptions.HTTPError as e:
+            error_msg = f"Z.ai API request failed: {e!s}"
+            raise RuntimeError(error_msg) from e
+        except requests.exceptions.RequestException as e:
+            raise RuntimeError(f"Z.ai API request failed: {e!s}") from e
+
+        usage = resp.get("usage", {})
+        prompt_tokens = usage.get("prompt_tokens", 0)
+        completion_tokens = usage.get("completion_tokens", 0)
+        total_tokens = usage.get("total_tokens", 0)
+
+        total_cost = self._calculate_cost("zai", model, prompt_tokens, completion_tokens)
+
+        meta = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": round(total_cost, 6),
+            "raw_response": resp,
+            "model_name": model,
+        }
+
+        text = resp["choices"][0]["message"]["content"]
+        return {"text": text, "meta": meta}
+
+    # ------------------------------------------------------------------
+    # Tool use
+    # ------------------------------------------------------------------
+
+    def generate_messages_with_tools(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> dict[str, Any]:
+        """Generate a response that may include tool calls."""
+        if not self.api_key:
+            raise RuntimeError("Zhipu API key not found")
+
+        model = options.get("model", self.model)
+        model_config = self._get_model_config("zai", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
+
+        self._validate_model_capabilities("zai", model, using_tool_use=True)
+
+        opts = {"temperature": 1.0, "max_tokens": 512, **options}
+
+        data: dict[str, Any] = {
+            "model": model,
+            "messages": messages,
+            "tools": tools,
+        }
+        data[tokens_param] = opts.get("max_tokens", 512)
+
+        if supports_temperature and "temperature" in opts:
+            data["temperature"] = opts["temperature"]
+
+        if "tool_choice" in options:
+            data["tool_choice"] = options["tool_choice"]
+
+        try:
+            response = requests.post(
+                f"{self.base_url}/chat/completions",
+                headers=self.headers,
+                json=data,
+                timeout=120,
+            )
+            response.raise_for_status()
+            resp = response.json()
+        except requests.exceptions.HTTPError as e:
+            error_msg = f"Z.ai API request failed: {e!s}"
+            raise RuntimeError(error_msg) from e
+        except requests.exceptions.RequestException as e:
+            raise RuntimeError(f"Z.ai API request failed: {e!s}") from e
+
+        usage = resp.get("usage", {})
+        prompt_tokens = usage.get("prompt_tokens", 0)
+        completion_tokens = usage.get("completion_tokens", 0)
+        total_tokens = usage.get("total_tokens", 0)
+        total_cost = self._calculate_cost("zai", model, prompt_tokens, completion_tokens)
+
+        meta = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": round(total_cost, 6),
+            "raw_response": resp,
+            "model_name": model,
+        }
+
+        choice = resp["choices"][0]
+        text = choice["message"].get("content") or ""
+        stop_reason = choice.get("finish_reason")
+
+        tool_calls_out: list[dict[str, Any]] = []
+        for tc in choice["message"].get("tool_calls", []):
+            try:
+                args = json.loads(tc["function"]["arguments"])
+            except (json.JSONDecodeError, TypeError):
+                args = {}
+            tool_calls_out.append(
+                {
+                    "id": tc["id"],
+                    "name": tc["function"]["name"],
+                    "arguments": args,
+                }
+            )
+
+        return {
+            "text": text,
+            "meta": meta,
+            "tool_calls": tool_calls_out,
+            "stop_reason": stop_reason,
+        }
+
+    # ------------------------------------------------------------------
+    # Streaming
+    # ------------------------------------------------------------------
+
+    def generate_messages_stream(
+        self,
+        messages: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> Iterator[dict[str, Any]]:
+        """Yield response chunks via Z.ai streaming API."""
+        if not self.api_key:
+            raise RuntimeError("Zhipu API key not found")
+
+        model = options.get("model", self.model)
+        model_config = self._get_model_config("zai", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
+
+        opts = {"temperature": 1.0, "max_tokens": 512, **options}
+
+        data: dict[str, Any] = {
+            "model": model,
+            "messages": messages,
+            "stream": True,
+            "stream_options": {"include_usage": True},
+        }
+        data[tokens_param] = opts.get("max_tokens", 512)
+
+        if supports_temperature and "temperature" in opts:
+            data["temperature"] = opts["temperature"]
+
+        response = requests.post(
+            f"{self.base_url}/chat/completions",
+            headers=self.headers,
+            json=data,
+            stream=True,
+            timeout=120,
+        )
+        response.raise_for_status()
+
+        full_text = ""
+        prompt_tokens = 0
+        completion_tokens = 0
+
+        for line in response.iter_lines(decode_unicode=True):
+            if not line or not line.startswith("data: "):
+                continue
+            payload = line[len("data: ") :]
+            if payload.strip() == "[DONE]":
+                break
+            try:
+                chunk = json.loads(payload)
+            except json.JSONDecodeError:
+                continue
+
+            usage = chunk.get("usage")
+            if usage:
+                prompt_tokens = usage.get("prompt_tokens", 0)
+                completion_tokens = usage.get("completion_tokens", 0)
+
+            choices = chunk.get("choices", [])
+            if choices:
+                delta = choices[0].get("delta", {})
+                content = delta.get("content", "")
+                if content:
+                    full_text += content
+                    yield {"type": "delta", "text": content}
+
+        total_tokens = prompt_tokens + completion_tokens
+        total_cost = self._calculate_cost("zai", model, prompt_tokens, completion_tokens)
+
+        yield {
+            "type": "done",
+            "text": full_text,
+            "meta": {
+                "prompt_tokens": prompt_tokens,
+                "completion_tokens": completion_tokens,
+                "total_tokens": total_tokens,
+                "cost": round(total_cost, 6),
+                "raw_response": {},
+                "model_name": model,
+            },
+        }
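For the new driver, a minimal tool-use sketch, assuming the import path prompture.drivers.zai_driver and a hypothetical get_weather tool; note the driver already json.loads() each tool call's arguments, falling back to {} on bad JSON:

from prompture.drivers.zai_driver import ZaiDriver

driver = ZaiDriver()  # reads ZHIPU_API_KEY

tools = [{
    "type": "function",
    "function": {
        "name": "get_weather",  # hypothetical tool
        "description": "Look up current weather for a city",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}]

result = driver.generate_messages_with_tools(
    [{"role": "user", "content": "What's the weather in Madrid?"}],
    tools,
    {"tool_choice": "auto"},
)
for call in result["tool_calls"]:
    print(call["id"], call["name"], call["arguments"])  # arguments pre-parsed
print("stop_reason:", result["stop_reason"])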
prompture/model_rates.py CHANGED

prompture/settings.py CHANGED
@@ -51,6 +51,21 @@ class Settings(BaseSettings):
     grok_api_key: Optional[str] = None
     grok_model: str = "grok-4-fast-reasoning"
 
+    # Moonshot AI (Kimi)
+    moonshot_api_key: Optional[str] = None
+    moonshot_model: str = "kimi-k2-0905-preview"
+    moonshot_endpoint: str = "https://api.moonshot.ai/v1"
+
+    # Z.ai (Zhipu AI)
+    zhipu_api_key: Optional[str] = None
+    zhipu_model: str = "glm-4.7"
+    zhipu_endpoint: str = "https://api.z.ai/api/paas/v4"
+
+    # ModelScope (Alibaba Cloud)
+    modelscope_api_key: Optional[str] = None
+    modelscope_model: str = "Qwen/Qwen3-235B-A22B-Instruct-2507"
+    modelscope_endpoint: str = "https://api-inference.modelscope.cn/v1"
+
     # AirLLM
     airllm_model: str = "meta-llama/Llama-2-7b-hf"
     airllm_compression: Optional[str] = None  # "4bit" or "8bit"
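These fields are read by pydantic BaseSettings, so the three new providers can be configured entirely from the environment (field zhipu_api_key maps to ZHIPU_API_KEY, and so on, under pydantic-settings' default env mapping). A minimal sketch, assuming Settings is importable from prompture.settings and that the remaining fields all have defaults:

from prompture.settings import Settings

settings = Settings()  # env vars override the defaults shown above
print(settings.zhipu_model)        # "glm-4.7"
print(settings.moonshot_endpoint)  # "https://api.moonshot.ai/v1"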
|