prompture 0.0.38.dev2__py3-none-any.whl → 0.0.42__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. prompture/__init__.py +12 -1
  2. prompture/_version.py +2 -2
  3. prompture/agent.py +11 -11
  4. prompture/async_agent.py +11 -11
  5. prompture/async_conversation.py +9 -0
  6. prompture/async_core.py +16 -0
  7. prompture/async_driver.py +39 -0
  8. prompture/async_groups.py +63 -0
  9. prompture/conversation.py +9 -0
  10. prompture/core.py +16 -0
  11. prompture/cost_mixin.py +62 -0
  12. prompture/discovery.py +108 -43
  13. prompture/driver.py +39 -0
  14. prompture/drivers/__init__.py +39 -0
  15. prompture/drivers/async_azure_driver.py +7 -6
  16. prompture/drivers/async_claude_driver.py +177 -8
  17. prompture/drivers/async_google_driver.py +10 -0
  18. prompture/drivers/async_grok_driver.py +4 -4
  19. prompture/drivers/async_groq_driver.py +4 -4
  20. prompture/drivers/async_modelscope_driver.py +286 -0
  21. prompture/drivers/async_moonshot_driver.py +312 -0
  22. prompture/drivers/async_openai_driver.py +158 -6
  23. prompture/drivers/async_openrouter_driver.py +196 -7
  24. prompture/drivers/async_registry.py +30 -0
  25. prompture/drivers/async_zai_driver.py +303 -0
  26. prompture/drivers/azure_driver.py +6 -5
  27. prompture/drivers/claude_driver.py +10 -0
  28. prompture/drivers/google_driver.py +10 -0
  29. prompture/drivers/grok_driver.py +4 -4
  30. prompture/drivers/groq_driver.py +4 -4
  31. prompture/drivers/modelscope_driver.py +303 -0
  32. prompture/drivers/moonshot_driver.py +342 -0
  33. prompture/drivers/openai_driver.py +22 -12
  34. prompture/drivers/openrouter_driver.py +248 -44
  35. prompture/drivers/zai_driver.py +318 -0
  36. prompture/groups.py +42 -0
  37. prompture/ledger.py +252 -0
  38. prompture/model_rates.py +114 -2
  39. prompture/settings.py +16 -1
  40. {prompture-0.0.38.dev2.dist-info → prompture-0.0.42.dist-info}/METADATA +1 -1
  41. prompture-0.0.42.dist-info/RECORD +84 -0
  42. prompture-0.0.38.dev2.dist-info/RECORD +0 -77
  43. {prompture-0.0.38.dev2.dist-info → prompture-0.0.42.dist-info}/WHEEL +0 -0
  44. {prompture-0.0.38.dev2.dist-info → prompture-0.0.42.dist-info}/entry_points.txt +0 -0
  45. {prompture-0.0.38.dev2.dist-info → prompture-0.0.42.dist-info}/licenses/LICENSE +0 -0
  46. {prompture-0.0.38.dev2.dist-info → prompture-0.0.42.dist-info}/top_level.txt +0 -0
prompture/drivers/async_zai_driver.py (new file)
@@ -0,0 +1,303 @@
+"""Async Z.ai (Zhipu AI) driver using httpx.
+
+All pricing comes from models.dev (provider: "zai") — no hardcoded pricing.
+"""
+
+from __future__ import annotations
+
+import json
+import os
+from collections.abc import AsyncIterator
+from typing import Any
+
+import httpx
+
+from ..async_driver import AsyncDriver
+from ..cost_mixin import CostMixin, prepare_strict_schema
+from .zai_driver import ZaiDriver
+
+
+class AsyncZaiDriver(CostMixin, AsyncDriver):
+    supports_json_mode = True
+    supports_json_schema = True
+    supports_tool_use = True
+    supports_streaming = True
+    supports_vision = True
+
+    MODEL_PRICING = ZaiDriver.MODEL_PRICING
+
+    def __init__(
+        self,
+        api_key: str | None = None,
+        model: str = "glm-4.7",
+        endpoint: str = "https://api.z.ai/api/paas/v4",
+    ):
+        self.api_key = api_key or os.getenv("ZHIPU_API_KEY")
+        if not self.api_key:
+            raise ValueError("Zhipu API key not found. Set ZHIPU_API_KEY env var.")
+        self.model = model
+        self.base_url = endpoint.rstrip("/")
+        self.headers = {
+            "Authorization": f"Bearer {self.api_key}",
+            "Content-Type": "application/json",
+        }
+
+    supports_messages = True
+
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_openai_vision_messages
+
+        return _prepare_openai_vision_messages(messages)
+
+    async def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
+        messages = [{"role": "user", "content": prompt}]
+        return await self._do_generate(messages, options)
+
+    async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        return await self._do_generate(self._prepare_messages(messages), options)
+
+    async def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        model = options.get("model", self.model)
+
+        model_config = self._get_model_config("zai", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
+
+        self._validate_model_capabilities(
+            "zai",
+            model,
+            using_json_schema=bool(options.get("json_schema")),
+        )
+
+        opts = {"temperature": 1.0, "max_tokens": 512, **options}
+
+        data: dict[str, Any] = {
+            "model": model,
+            "messages": messages,
+        }
+        data[tokens_param] = opts.get("max_tokens", 512)
+
+        if supports_temperature and "temperature" in opts:
+            data["temperature"] = opts["temperature"]
+
+        if options.get("json_mode"):
+            json_schema = options.get("json_schema")
+            if json_schema:
+                schema_copy = prepare_strict_schema(json_schema)
+                data["response_format"] = {
+                    "type": "json_schema",
+                    "json_schema": {
+                        "name": "extraction",
+                        "strict": True,
+                        "schema": schema_copy,
+                    },
+                }
+            else:
+                data["response_format"] = {"type": "json_object"}
+
+        async with httpx.AsyncClient() as client:
+            try:
+                response = await client.post(
+                    f"{self.base_url}/chat/completions",
+                    headers=self.headers,
+                    json=data,
+                    timeout=120,
+                )
+                response.raise_for_status()
+                resp = response.json()
+            except httpx.HTTPStatusError as e:
+                error_msg = f"Z.ai API request failed: {e!s}"
+                raise RuntimeError(error_msg) from e
+            except Exception as e:
+                raise RuntimeError(f"Z.ai API request failed: {e!s}") from e
+
+        usage = resp.get("usage", {})
+        prompt_tokens = usage.get("prompt_tokens", 0)
+        completion_tokens = usage.get("completion_tokens", 0)
+        total_tokens = usage.get("total_tokens", 0)
+
+        total_cost = self._calculate_cost("zai", model, prompt_tokens, completion_tokens)
+
+        meta = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": round(total_cost, 6),
+            "raw_response": resp,
+            "model_name": model,
+        }
+
+        text = resp["choices"][0]["message"]["content"]
+        return {"text": text, "meta": meta}
+
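For orientation, a minimal usage sketch of the driver above (not part of the diff; the prompt is illustrative and a valid ZHIPU_API_KEY is assumed):

    # Hypothetical usage of AsyncZaiDriver.generate with strict JSON schema output.
    import asyncio
    from prompture.drivers.async_zai_driver import AsyncZaiDriver

    async def main():
        driver = AsyncZaiDriver(model="glm-4.7")  # reads ZHIPU_API_KEY from the environment
        schema = {
            "type": "object",
            "properties": {"name": {"type": "string"}},
            "required": ["name"],
        }
        result = await driver.generate(
            "Extract the person's name: 'Ada Lovelace wrote the first program.'",
            {"json_mode": True, "json_schema": schema},
        )
        print(result["text"])          # JSON string conforming to the schema
        print(result["meta"]["cost"])  # cost computed from models.dev rates

    asyncio.run(main())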
+    # ------------------------------------------------------------------
+    # Tool use
+    # ------------------------------------------------------------------
+
+    async def generate_messages_with_tools(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> dict[str, Any]:
+        """Generate a response that may include tool calls."""
+        model = options.get("model", self.model)
+        model_config = self._get_model_config("zai", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
+
+        self._validate_model_capabilities("zai", model, using_tool_use=True)
+
+        opts = {"temperature": 1.0, "max_tokens": 512, **options}
+
+        data: dict[str, Any] = {
+            "model": model,
+            "messages": messages,
+            "tools": tools,
+        }
+        data[tokens_param] = opts.get("max_tokens", 512)
+
+        if supports_temperature and "temperature" in opts:
+            data["temperature"] = opts["temperature"]
+
+        if "tool_choice" in options:
+            data["tool_choice"] = options["tool_choice"]
+
+        async with httpx.AsyncClient() as client:
+            try:
+                response = await client.post(
+                    f"{self.base_url}/chat/completions",
+                    headers=self.headers,
+                    json=data,
+                    timeout=120,
+                )
+                response.raise_for_status()
+                resp = response.json()
+            except httpx.HTTPStatusError as e:
+                error_msg = f"Z.ai API request failed: {e!s}"
+                raise RuntimeError(error_msg) from e
+            except Exception as e:
+                raise RuntimeError(f"Z.ai API request failed: {e!s}") from e
+
+        usage = resp.get("usage", {})
+        prompt_tokens = usage.get("prompt_tokens", 0)
+        completion_tokens = usage.get("completion_tokens", 0)
+        total_tokens = usage.get("total_tokens", 0)
+        total_cost = self._calculate_cost("zai", model, prompt_tokens, completion_tokens)
+
+        meta = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": round(total_cost, 6),
+            "raw_response": resp,
+            "model_name": model,
+        }
+
+        choice = resp["choices"][0]
+        text = choice["message"].get("content") or ""
+        stop_reason = choice.get("finish_reason")
+
+        tool_calls_out: list[dict[str, Any]] = []
+        for tc in choice["message"].get("tool_calls", []):
+            try:
+                args = json.loads(tc["function"]["arguments"])
+            except (json.JSONDecodeError, TypeError):
+                args = {}
+            tool_calls_out.append(
+                {
+                    "id": tc["id"],
+                    "name": tc["function"]["name"],
+                    "arguments": args,
+                }
+            )
+
+        return {
+            "text": text,
+            "meta": meta,
+            "tool_calls": tool_calls_out,
+            "stop_reason": stop_reason,
+        }
+
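A hedged sketch of calling the tool-use path above, to be run inside an async function with `driver` as constructed in the earlier sketch (the OpenAI-style tools payload is an assumption based on the request body the driver builds; get_weather is a made-up tool):

    # Hypothetical tool-use call against AsyncZaiDriver.
    tools = [{
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Look up current weather for a city.",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }]
    result = await driver.generate_messages_with_tools(
        [{"role": "user", "content": "What's the weather in Paris?"}],
        tools,
        {"tool_choice": "auto"},
    )
    for call in result["tool_calls"]:
        print(call["id"], call["name"], call["arguments"])  # arguments arrive parsed as a dict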
+    # ------------------------------------------------------------------
+    # Streaming
+    # ------------------------------------------------------------------
+
+    async def generate_messages_stream(
+        self,
+        messages: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> AsyncIterator[dict[str, Any]]:
+        """Yield response chunks via Z.ai streaming API."""
+        model = options.get("model", self.model)
+        model_config = self._get_model_config("zai", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
+
+        opts = {"temperature": 1.0, "max_tokens": 512, **options}
+
+        data: dict[str, Any] = {
+            "model": model,
+            "messages": messages,
+            "stream": True,
+            "stream_options": {"include_usage": True},
+        }
+        data[tokens_param] = opts.get("max_tokens", 512)
+
+        if supports_temperature and "temperature" in opts:
+            data["temperature"] = opts["temperature"]
+
+        full_text = ""
+        prompt_tokens = 0
+        completion_tokens = 0
+
+        async with (
+            httpx.AsyncClient() as client,
+            client.stream(
+                "POST",
+                f"{self.base_url}/chat/completions",
+                headers=self.headers,
+                json=data,
+                timeout=120,
+            ) as response,
+        ):
+            response.raise_for_status()
+            async for line in response.aiter_lines():
+                if not line or not line.startswith("data: "):
+                    continue
+                payload = line[len("data: ") :]
+                if payload.strip() == "[DONE]":
+                    break
+                try:
+                    chunk = json.loads(payload)
+                except json.JSONDecodeError:
+                    continue
+
+                usage = chunk.get("usage")
+                if usage:
+                    prompt_tokens = usage.get("prompt_tokens", 0)
+                    completion_tokens = usage.get("completion_tokens", 0)
+
+                choices = chunk.get("choices", [])
+                if choices:
+                    delta = choices[0].get("delta", {})
+                    content = delta.get("content", "")
+                    if content:
+                        full_text += content
+                        yield {"type": "delta", "text": content}
+
+        total_tokens = prompt_tokens + completion_tokens
+        total_cost = self._calculate_cost("zai", model, prompt_tokens, completion_tokens)
+
+        yield {
+            "type": "done",
+            "text": full_text,
+            "meta": {
+                "prompt_tokens": prompt_tokens,
+                "completion_tokens": completion_tokens,
+                "total_tokens": total_tokens,
+                "cost": round(total_cost, 6),
+                "raw_response": {},
+                "model_name": model,
+            },
+        }
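Consuming the stream is straightforward given the chunk shapes yielded above; a minimal sketch, again inside an async function with the same `driver` instance:

    # Print deltas as they arrive, then the usage summary from the final "done" chunk.
    async for chunk in driver.generate_messages_stream(
        [{"role": "user", "content": "Tell me a short joke."}],
        {},
    ):
        if chunk["type"] == "delta":
            print(chunk["text"], end="", flush=True)
        elif chunk["type"] == "done":
            meta = chunk["meta"]
            print(f"\n[{meta['total_tokens']} tokens, ${meta['cost']}]")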
prompture/drivers/azure_driver.py
@@ -10,7 +10,7 @@ try:
 except Exception:
     AzureOpenAI = None
 
-from ..cost_mixin import CostMixin
+from ..cost_mixin import CostMixin, prepare_strict_schema
 from ..driver import Driver
 
 
@@ -108,9 +108,9 @@ class AzureDriver(CostMixin, Driver):
             raise RuntimeError("openai package (>=1.0.0) with AzureOpenAI not installed")
 
         model = options.get("model", self.model)
-        model_info = self.MODEL_PRICING.get(model, {})
-        tokens_param = model_info.get("tokens_param", "max_tokens")
-        supports_temperature = model_info.get("supports_temperature", True)
+        model_config = self._get_model_config("azure", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
 
         opts = {"temperature": 1.0, "max_tokens": 512, **options}
 
@@ -128,12 +128,13 @@ class AzureDriver(CostMixin, Driver):
         if options.get("json_mode"):
             json_schema = options.get("json_schema")
             if json_schema:
+                schema_copy = prepare_strict_schema(json_schema)
                 kwargs["response_format"] = {
                     "type": "json_schema",
                     "json_schema": {
                         "name": "extraction",
                         "strict": True,
-                        "schema": json_schema,
+                        "schema": schema_copy,
                     },
                 }
             else:
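The prepare_strict_schema helper imported here lives in the new cost_mixin.py, which this excerpt does not show. As an assumption only: strict-mode JSON schemas for OpenAI-compatible APIs typically require additionalProperties: false and every property listed as required, so a plausible sketch of such a helper is:

    # Speculative sketch of prepare_strict_schema; NOT the actual implementation.
    import copy

    def prepare_strict_schema(schema: dict) -> dict:
        out = copy.deepcopy(schema)  # never mutate the caller's schema

        def walk(node: dict) -> None:
            if node.get("type") == "object":
                node.setdefault("additionalProperties", False)
                props = node.get("properties", {})
                node["required"] = list(props)  # strict mode: every key required
                for child in props.values():
                    walk(child)
            elif node.get("type") == "array" and isinstance(node.get("items"), dict):
                walk(node["items"])

        walk(out)
        return out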
prompture/drivers/claude_driver.py
@@ -77,6 +77,13 @@ class ClaudeDriver(CostMixin, Driver):
         opts = {**{"temperature": 0.0, "max_tokens": 512}, **options}
         model = options.get("model", self.model)
 
+        # Validate capabilities against models.dev metadata
+        self._validate_model_capabilities(
+            "claude",
+            model,
+            using_json_schema=bool(options.get("json_schema")),
+        )
+
         client = anthropic.Anthropic(api_key=self.api_key)
 
         # Anthropic requires system messages as a top-level parameter
@@ -177,6 +184,9 @@ class ClaudeDriver(CostMixin, Driver):
 
         opts = {**{"temperature": 0.0, "max_tokens": 512}, **options}
         model = options.get("model", self.model)
+
+        self._validate_model_capabilities("claude", model, using_tool_use=True)
+
         client = anthropic.Anthropic(api_key=self.api_key)
 
         system_content, api_messages = self._extract_system_and_messages(messages)
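_validate_model_capabilities is also defined in the unshown cost_mixin.py. Judging only from its call sites in this diff, a plausible shape (assumption; the metadata lookup helper name is hypothetical) is:

    # Speculative sketch; the real helper lives in prompture/cost_mixin.py.
    def _validate_model_capabilities(
        self,
        provider: str,
        model: str,
        using_json_schema: bool = False,
        using_tool_use: bool = False,
    ) -> None:
        caps = self._lookup_models_dev(provider, model) or {}  # hypothetical lookup
        if using_json_schema and caps.get("supports_json_schema") is False:
            raise ValueError(f"{provider}/{model} does not support JSON schema output")
        if using_tool_use and caps.get("supports_tool_use") is False:
            raise ValueError(f"{provider}/{model} does not support tool use")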
prompture/drivers/google_driver.py
@@ -228,6 +228,13 @@ class GoogleDriver(CostMixin, Driver):
     def _do_generate(self, messages: list[dict[str, str]], options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
         gen_input, gen_kwargs, model_kwargs = self._build_generation_args(messages, options)
 
+        # Validate capabilities against models.dev metadata
+        self._validate_model_capabilities(
+            "google",
+            self.model,
+            using_json_schema=bool((options or {}).get("json_schema")),
+        )
+
         try:
             logger.debug(f"Initializing {self.model} for generation")
             model = genai.GenerativeModel(self.model, **model_kwargs)
@@ -263,6 +270,9 @@ class GoogleDriver(CostMixin, Driver):
         options: dict[str, Any],
     ) -> dict[str, Any]:
         """Generate a response that may include tool/function calls."""
+        model = options.get("model", self.model)
+        self._validate_model_capabilities("google", model, using_tool_use=True)
+
         gen_input, gen_kwargs, model_kwargs = self._build_generation_args(
             self._prepare_messages(messages), options
         )
prompture/drivers/grok_driver.py
@@ -99,10 +99,10 @@ class GrokDriver(CostMixin, Driver):
 
         model = options.get("model", self.model)
 
-        # Lookup model-specific config
-        model_info = self.MODEL_PRICING.get(model, {})
-        tokens_param = model_info.get("tokens_param", "max_tokens")
-        supports_temperature = model_info.get("supports_temperature", True)
+        # Lookup model-specific config (live models.dev data + hardcoded fallback)
+        model_config = self._get_model_config("grok", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
 
         # Defaults
         opts = {"temperature": 1.0, "max_tokens": 512, **options}
prompture/drivers/groq_driver.py
@@ -69,10 +69,10 @@ class GroqDriver(CostMixin, Driver):
 
         model = options.get("model", self.model)
 
-        # Lookup model-specific config
-        model_info = self.MODEL_PRICING.get(model, {})
-        tokens_param = model_info.get("tokens_param", "max_tokens")
-        supports_temperature = model_info.get("supports_temperature", True)
+        # Lookup model-specific config (live models.dev data + hardcoded fallback)
+        model_config = self._get_model_config("groq", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
 
         # Base configuration
         opts = {"temperature": 0.7, "max_tokens": 512, **options}
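The drivers above now route model metadata through CostMixin._get_model_config instead of reading MODEL_PRICING directly. That helper is also not shown in this excerpt; judging from the lines it replaces, a plausible sketch (assumption; the models.dev lookup helper is hypothetical) is:

    # Speculative sketch; mirrors the defaults the old inline lookups used.
    def _get_model_config(self, provider: str, model: str) -> dict:
        info = self._lookup_models_dev(provider, model) or self.MODEL_PRICING.get(model, {})
        return {
            "tokens_param": info.get("tokens_param", "max_tokens"),
            "supports_temperature": info.get("supports_temperature", True),
        }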