prompture 0.0.38.dev2__py3-none-any.whl → 0.0.42__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. prompture/__init__.py +12 -1
  2. prompture/_version.py +2 -2
  3. prompture/agent.py +11 -11
  4. prompture/async_agent.py +11 -11
  5. prompture/async_conversation.py +9 -0
  6. prompture/async_core.py +16 -0
  7. prompture/async_driver.py +39 -0
  8. prompture/async_groups.py +63 -0
  9. prompture/conversation.py +9 -0
  10. prompture/core.py +16 -0
  11. prompture/cost_mixin.py +62 -0
  12. prompture/discovery.py +108 -43
  13. prompture/driver.py +39 -0
  14. prompture/drivers/__init__.py +39 -0
  15. prompture/drivers/async_azure_driver.py +7 -6
  16. prompture/drivers/async_claude_driver.py +177 -8
  17. prompture/drivers/async_google_driver.py +10 -0
  18. prompture/drivers/async_grok_driver.py +4 -4
  19. prompture/drivers/async_groq_driver.py +4 -4
  20. prompture/drivers/async_modelscope_driver.py +286 -0
  21. prompture/drivers/async_moonshot_driver.py +312 -0
  22. prompture/drivers/async_openai_driver.py +158 -6
  23. prompture/drivers/async_openrouter_driver.py +196 -7
  24. prompture/drivers/async_registry.py +30 -0
  25. prompture/drivers/async_zai_driver.py +303 -0
  26. prompture/drivers/azure_driver.py +6 -5
  27. prompture/drivers/claude_driver.py +10 -0
  28. prompture/drivers/google_driver.py +10 -0
  29. prompture/drivers/grok_driver.py +4 -4
  30. prompture/drivers/groq_driver.py +4 -4
  31. prompture/drivers/modelscope_driver.py +303 -0
  32. prompture/drivers/moonshot_driver.py +342 -0
  33. prompture/drivers/openai_driver.py +22 -12
  34. prompture/drivers/openrouter_driver.py +248 -44
  35. prompture/drivers/zai_driver.py +318 -0
  36. prompture/groups.py +42 -0
  37. prompture/ledger.py +252 -0
  38. prompture/model_rates.py +114 -2
  39. prompture/settings.py +16 -1
  40. {prompture-0.0.38.dev2.dist-info → prompture-0.0.42.dist-info}/METADATA +1 -1
  41. prompture-0.0.42.dist-info/RECORD +84 -0
  42. prompture-0.0.38.dev2.dist-info/RECORD +0 -77
  43. {prompture-0.0.38.dev2.dist-info → prompture-0.0.42.dist-info}/WHEEL +0 -0
  44. {prompture-0.0.38.dev2.dist-info → prompture-0.0.42.dist-info}/entry_points.txt +0 -0
  45. {prompture-0.0.38.dev2.dist-info → prompture-0.0.42.dist-info}/licenses/LICENSE +0 -0
  46. {prompture-0.0.38.dev2.dist-info → prompture-0.0.42.dist-info}/top_level.txt +0 -0
prompture/drivers/async_openai_driver.py

@@ -2,7 +2,9 @@
 
 from __future__ import annotations
 
+import json
 import os
+from collections.abc import AsyncIterator
 from typing import Any
 
 try:
@@ -11,13 +13,15 @@ except Exception:
     AsyncOpenAI = None
 
 from ..async_driver import AsyncDriver
-from ..cost_mixin import CostMixin
+from ..cost_mixin import CostMixin, prepare_strict_schema
 from .openai_driver import OpenAIDriver
 
 
 class AsyncOpenAIDriver(CostMixin, AsyncDriver):
     supports_json_mode = True
     supports_json_schema = True
+    supports_tool_use = True
+    supports_streaming = True
     supports_vision = True
 
     MODEL_PRICING = OpenAIDriver.MODEL_PRICING
@@ -50,9 +54,16 @@ class AsyncOpenAIDriver(CostMixin, AsyncDriver):
 
         model = options.get("model", self.model)
 
-        model_info = self.MODEL_PRICING.get(model, {})
-        tokens_param = model_info.get("tokens_param", "max_tokens")
-        supports_temperature = model_info.get("supports_temperature", True)
+        model_config = self._get_model_config("openai", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
+
+        # Validate capabilities against models.dev metadata
+        self._validate_model_capabilities(
+            "openai",
+            model,
+            using_json_schema=bool(options.get("json_schema")),
+        )
 
         opts = {"temperature": 1.0, "max_tokens": 512, **options}
 
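The per-model lookup now goes through CostMixin._get_model_config instead of the driver-local MODEL_PRICING table, and each request is pre-checked with _validate_model_capabilities against models.dev metadata. Both helpers land in prompture/cost_mixin.py (+62 lines, not shown in this diff), so the following is only a hedged sketch of the shape such a check usually takes; the metadata keys are assumptions:

# Hedged sketch of the new CostMixin capability check; the real version
# ships in prompture/cost_mixin.py and may differ. The metadata keys
# ("supports_json_schema", "supports_tool_use") are assumptions.
def _validate_model_capabilities(
    self,
    provider: str,
    model: str,
    *,
    using_json_schema: bool = False,
    using_tool_use: bool = False,
) -> None:
    caps = self._get_model_config(provider, model)
    if using_json_schema and not caps.get("supports_json_schema", True):
        raise ValueError(f"{provider}/{model} does not support JSON schema output")
    if using_tool_use and not caps.get("supports_tool_use", True):
        raise ValueError(f"{provider}/{model} does not support tool calling")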
@@ -69,12 +80,13 @@ class AsyncOpenAIDriver(CostMixin, AsyncDriver):
         if options.get("json_mode"):
             json_schema = options.get("json_schema")
             if json_schema:
+                schema_copy = prepare_strict_schema(json_schema)
                 kwargs["response_format"] = {
                     "type": "json_schema",
                     "json_schema": {
                         "name": "extraction",
                         "strict": True,
-                        "schema": json_schema,
+                        "schema": schema_copy,
                     },
                 }
             else:
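OpenAI's strict structured-output mode rejects schemas whose object nodes allow extra keys or optional properties, which is presumably why the raw json_schema is now routed through prepare_strict_schema first. Its implementation lives in prompture/cost_mixin.py and is not part of this diff; a minimal sketch of what such a normalizer has to do, under that assumption:

# Hypothetical sketch of prepare_strict_schema; the real helper is in
# prompture/cost_mixin.py and is not shown in this diff.
import copy
from typing import Any

def prepare_strict_schema(schema: dict[str, Any]) -> dict[str, Any]:
    """Return a deep copy adjusted for OpenAI strict mode (assumed behavior)."""
    out = copy.deepcopy(schema)

    def _walk(node: Any) -> None:
        if isinstance(node, dict):
            if node.get("type") == "object":
                # Strict mode requires closed objects with every property listed
                # as required.
                node.setdefault("additionalProperties", False)
                if "properties" in node:
                    node["required"] = list(node["properties"])
            for value in node.values():
                _walk(value)
        elif isinstance(node, list):
            for item in node:
                _walk(item)

    _walk(out)
    return out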
@@ -93,10 +105,150 @@ class AsyncOpenAIDriver(CostMixin, AsyncDriver):
             "prompt_tokens": prompt_tokens,
             "completion_tokens": completion_tokens,
             "total_tokens": total_tokens,
-            "cost": total_cost,
+            "cost": round(total_cost, 6),
             "raw_response": resp.model_dump(),
             "model_name": model,
         }
 
         text = resp.choices[0].message.content
         return {"text": text, "meta": meta}
+
+    # ------------------------------------------------------------------
+    # Tool use
+    # ------------------------------------------------------------------
+
+    async def generate_messages_with_tools(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> dict[str, Any]:
+        """Generate a response that may include tool calls."""
+        if self.client is None:
+            raise RuntimeError("openai package (>=1.0.0) is not installed")
+
+        model = options.get("model", self.model)
+        model_config = self._get_model_config("openai", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
+
+        self._validate_model_capabilities("openai", model, using_tool_use=True)
+
+        opts = {"temperature": 1.0, "max_tokens": 512, **options}
+
+        kwargs: dict[str, Any] = {
+            "model": model,
+            "messages": messages,
+            "tools": tools,
+        }
+        kwargs[tokens_param] = opts.get("max_tokens", 512)
+
+        if supports_temperature and "temperature" in opts:
+            kwargs["temperature"] = opts["temperature"]
+
+        resp = await self.client.chat.completions.create(**kwargs)
+
+        usage = getattr(resp, "usage", None)
+        prompt_tokens = getattr(usage, "prompt_tokens", 0)
+        completion_tokens = getattr(usage, "completion_tokens", 0)
+        total_tokens = getattr(usage, "total_tokens", 0)
+        total_cost = self._calculate_cost("openai", model, prompt_tokens, completion_tokens)
+
+        meta = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": round(total_cost, 6),
+            "raw_response": resp.model_dump(),
+            "model_name": model,
+        }
+
+        choice = resp.choices[0]
+        text = choice.message.content or ""
+        stop_reason = choice.finish_reason
+
+        tool_calls_out: list[dict[str, Any]] = []
+        if choice.message.tool_calls:
+            for tc in choice.message.tool_calls:
+                try:
+                    args = json.loads(tc.function.arguments)
+                except (json.JSONDecodeError, TypeError):
+                    args = {}
+                tool_calls_out.append({
+                    "id": tc.id,
+                    "name": tc.function.name,
+                    "arguments": args,
+                })
+
+        return {
+            "text": text,
+            "meta": meta,
+            "tool_calls": tool_calls_out,
+            "stop_reason": stop_reason,
+        }
+
+    # ------------------------------------------------------------------
+    # Streaming
+    # ------------------------------------------------------------------
+
+    async def generate_messages_stream(
+        self,
+        messages: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> AsyncIterator[dict[str, Any]]:
+        """Yield response chunks via OpenAI streaming API."""
+        if self.client is None:
+            raise RuntimeError("openai package (>=1.0.0) is not installed")
+
+        model = options.get("model", self.model)
+        model_config = self._get_model_config("openai", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
+
+        opts = {"temperature": 1.0, "max_tokens": 512, **options}
+
+        kwargs: dict[str, Any] = {
+            "model": model,
+            "messages": messages,
+            "stream": True,
+            "stream_options": {"include_usage": True},
+        }
+        kwargs[tokens_param] = opts.get("max_tokens", 512)
+
+        if supports_temperature and "temperature" in opts:
+            kwargs["temperature"] = opts["temperature"]
+
+        stream = await self.client.chat.completions.create(**kwargs)
+
+        full_text = ""
+        prompt_tokens = 0
+        completion_tokens = 0
+
+        async for chunk in stream:
+            # Usage comes in the final chunk
+            if getattr(chunk, "usage", None):
+                prompt_tokens = chunk.usage.prompt_tokens or 0
+                completion_tokens = chunk.usage.completion_tokens or 0
+
+            if chunk.choices:
+                delta = chunk.choices[0].delta
+                content = getattr(delta, "content", None) or ""
+                if content:
+                    full_text += content
+                    yield {"type": "delta", "text": content}
+
+        total_tokens = prompt_tokens + completion_tokens
+        total_cost = self._calculate_cost("openai", model, prompt_tokens, completion_tokens)
+
+        yield {
+            "type": "done",
+            "text": full_text,
+            "meta": {
+                "prompt_tokens": prompt_tokens,
+                "completion_tokens": completion_tokens,
+                "total_tokens": total_tokens,
+                "cost": round(total_cost, 6),
+                "raw_response": {},
+                "model_name": model,
+            },
+        }
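Both new entry points hand back plain dicts, so callers stay decoupled from the OpenAI SDK types: the stream yields {"type": "delta", "text": ...} chunks followed by a single {"type": "done", ...} chunk carrying the usage meta. A short consumer sketch; the no-argument constructor is an assumption, the chunk shapes come from the diff above:

# Illustrative consumer of the new streaming protocol (not from the package).
import asyncio

from prompture.drivers.async_openai_driver import AsyncOpenAIDriver

async def main() -> None:
    driver = AsyncOpenAIDriver()  # assumed default constructor
    messages = [{"role": "user", "content": "Summarize RFC 2119 in one line."}]
    async for chunk in driver.generate_messages_stream(messages, {"max_tokens": 128}):
        if chunk["type"] == "delta":
            print(chunk["text"], end="", flush=True)
        else:  # the final "done" chunk carries the usage meta
            print(f"\ncost: {chunk['meta']['cost']}")

asyncio.run(main())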
prompture/drivers/async_openrouter_driver.py

@@ -2,23 +2,28 @@
 
 from __future__ import annotations
 
+import json
 import os
+from collections.abc import AsyncIterator
 from typing import Any
 
 import httpx
 
 from ..async_driver import AsyncDriver
-from ..cost_mixin import CostMixin
+from ..cost_mixin import CostMixin, prepare_strict_schema
 from .openrouter_driver import OpenRouterDriver
 
 
 class AsyncOpenRouterDriver(CostMixin, AsyncDriver):
     supports_json_mode = True
+    supports_json_schema = True
+    supports_tool_use = True
+    supports_streaming = True
     supports_vision = True
 
     MODEL_PRICING = OpenRouterDriver.MODEL_PRICING
 
-    def __init__(self, api_key: str | None = None, model: str = "openai/gpt-3.5-turbo"):
+    def __init__(self, api_key: str | None = None, model: str = "openai/gpt-4o-mini"):
         self.api_key = api_key or os.getenv("OPENROUTER_API_KEY")
         if not self.api_key:
             raise ValueError("OpenRouter API key not found. Set OPENROUTER_API_KEY env var.")
@@ -47,9 +52,16 @@ class AsyncOpenRouterDriver(CostMixin, AsyncDriver):
     async def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
         model = options.get("model", self.model)
 
-        model_info = self.MODEL_PRICING.get(model, {})
-        tokens_param = model_info.get("tokens_param", "max_tokens")
-        supports_temperature = model_info.get("supports_temperature", True)
+        model_config = self._get_model_config("openrouter", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
+
+        # Validate capabilities against models.dev metadata
+        self._validate_model_capabilities(
+            "openrouter",
+            model,
+            using_json_schema=bool(options.get("json_schema")),
+        )
 
         opts = {"temperature": 1.0, "max_tokens": 512, **options}
 
@@ -64,7 +76,19 @@ class AsyncOpenRouterDriver(CostMixin, AsyncDriver):
 
         # Native JSON mode support
         if options.get("json_mode"):
-            data["response_format"] = {"type": "json_object"}
+            json_schema = options.get("json_schema")
+            if json_schema:
+                schema_copy = prepare_strict_schema(json_schema)
+                data["response_format"] = {
+                    "type": "json_schema",
+                    "json_schema": {
+                        "name": "extraction",
+                        "strict": True,
+                        "schema": schema_copy,
+                    },
+                }
+            else:
+                data["response_format"] = {"type": "json_object"}
 
         async with httpx.AsyncClient() as client:
             try:
@@ -93,10 +117,175 @@ class AsyncOpenRouterDriver(CostMixin, AsyncDriver):
             "prompt_tokens": prompt_tokens,
             "completion_tokens": completion_tokens,
             "total_tokens": total_tokens,
-            "cost": total_cost,
+            "cost": round(total_cost, 6),
             "raw_response": resp,
             "model_name": model,
         }
 
         text = resp["choices"][0]["message"]["content"]
         return {"text": text, "meta": meta}
+
+    # ------------------------------------------------------------------
+    # Tool use
+    # ------------------------------------------------------------------
+
+    async def generate_messages_with_tools(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> dict[str, Any]:
+        """Generate a response that may include tool calls."""
+        model = options.get("model", self.model)
+        model_config = self._get_model_config("openrouter", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
+
+        self._validate_model_capabilities("openrouter", model, using_tool_use=True)
+
+        opts = {"temperature": 1.0, "max_tokens": 512, **options}
+
+        data: dict[str, Any] = {
+            "model": model,
+            "messages": messages,
+            "tools": tools,
+        }
+        data[tokens_param] = opts.get("max_tokens", 512)
+
+        if supports_temperature and "temperature" in opts:
+            data["temperature"] = opts["temperature"]
+
+        async with httpx.AsyncClient() as client:
+            try:
+                response = await client.post(
+                    f"{self.base_url}/chat/completions",
+                    headers=self.headers,
+                    json=data,
+                    timeout=120,
+                )
+                response.raise_for_status()
+                resp = response.json()
+            except httpx.HTTPStatusError as e:
+                error_msg = f"OpenRouter API request failed: {e!s}"
+                raise RuntimeError(error_msg) from e
+            except Exception as e:
+                raise RuntimeError(f"OpenRouter API request failed: {e!s}") from e
+
+        usage = resp.get("usage", {})
+        prompt_tokens = usage.get("prompt_tokens", 0)
+        completion_tokens = usage.get("completion_tokens", 0)
+        total_tokens = usage.get("total_tokens", 0)
+        total_cost = self._calculate_cost("openrouter", model, prompt_tokens, completion_tokens)
+
+        meta = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": round(total_cost, 6),
+            "raw_response": resp,
+            "model_name": model,
+        }
+
+        choice = resp["choices"][0]
+        text = choice["message"].get("content") or ""
+        stop_reason = choice.get("finish_reason")
+
+        tool_calls_out: list[dict[str, Any]] = []
+        for tc in choice["message"].get("tool_calls", []):
+            try:
+                args = json.loads(tc["function"]["arguments"])
+            except (json.JSONDecodeError, TypeError):
+                args = {}
+            tool_calls_out.append({
+                "id": tc["id"],
+                "name": tc["function"]["name"],
+                "arguments": args,
+            })
+
+        return {
+            "text": text,
+            "meta": meta,
+            "tool_calls": tool_calls_out,
+            "stop_reason": stop_reason,
+        }
+
+    # ------------------------------------------------------------------
+    # Streaming
+    # ------------------------------------------------------------------
+
+    async def generate_messages_stream(
+        self,
+        messages: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> AsyncIterator[dict[str, Any]]:
+        """Yield response chunks via OpenRouter streaming API."""
+        model = options.get("model", self.model)
+        model_config = self._get_model_config("openrouter", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
+
+        opts = {"temperature": 1.0, "max_tokens": 512, **options}
+
+        data: dict[str, Any] = {
+            "model": model,
+            "messages": messages,
+            "stream": True,
+            "stream_options": {"include_usage": True},
+        }
+        data[tokens_param] = opts.get("max_tokens", 512)
+
+        if supports_temperature and "temperature" in opts:
+            data["temperature"] = opts["temperature"]
+
+        full_text = ""
+        prompt_tokens = 0
+        completion_tokens = 0
+
+        async with httpx.AsyncClient() as client, client.stream(
+            "POST",
+            f"{self.base_url}/chat/completions",
+            headers=self.headers,
+            json=data,
+            timeout=120,
+        ) as response:
+            response.raise_for_status()
+            async for line in response.aiter_lines():
+                if not line or not line.startswith("data: "):
+                    continue
+                payload = line[len("data: "):]
+                if payload.strip() == "[DONE]":
+                    break
+                try:
+                    chunk = json.loads(payload)
+                except json.JSONDecodeError:
+                    continue
+
+                # Usage comes in the final chunk
+                usage = chunk.get("usage")
+                if usage:
+                    prompt_tokens = usage.get("prompt_tokens", 0)
+                    completion_tokens = usage.get("completion_tokens", 0)
+
+                choices = chunk.get("choices", [])
+                if choices:
+                    delta = choices[0].get("delta", {})
+                    content = delta.get("content", "")
+                    if content:
+                        full_text += content
+                        yield {"type": "delta", "text": content}
+
+        total_tokens = prompt_tokens + completion_tokens
+        total_cost = self._calculate_cost("openrouter", model, prompt_tokens, completion_tokens)
+
+        yield {
+            "type": "done",
+            "text": full_text,
+            "meta": {
+                "prompt_tokens": prompt_tokens,
+                "completion_tokens": completion_tokens,
+                "total_tokens": total_tokens,
+                "cost": round(total_cost, 6),
+                "raw_response": {},
+                "model_name": model,
+            },
+        }
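Both drivers normalize tool calls to {"id", "name", "arguments"} with arguments already parsed from JSON, which makes the caller's tool loop provider-agnostic. A hedged sketch of such a loop; echoing the assistant turn back via meta["raw_response"] assumes the dict shapes shown in the hunks above, since OpenAI-compatible APIs expect the original tool_calls structure in the history:

# Illustrative tool-execution loop (not part of the package).
import json
from typing import Any, Callable

async def run_tool_loop(
    driver: Any,
    messages: list[dict[str, Any]],
    tools: list[dict[str, Any]],
    impls: dict[str, Callable[..., Any]],
) -> str:
    """Keep calling tools until the model answers in plain text."""
    while True:
        result = await driver.generate_messages_with_tools(messages, tools, {})
        if not result["tool_calls"]:
            return result["text"]
        # Re-append the assistant turn from the raw payload, then one
        # "tool" message per call, keyed by tool_call_id.
        messages.append(result["meta"]["raw_response"]["choices"][0]["message"])
        for call in result["tool_calls"]:
            output = impls[call["name"]](**call["arguments"])
            messages.append({
                "role": "tool",
                "tool_call_id": call["id"],
                "content": json.dumps(output),
            })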
prompture/drivers/async_registry.py

@@ -22,9 +22,12 @@ from .async_grok_driver import AsyncGrokDriver
 from .async_groq_driver import AsyncGroqDriver
 from .async_lmstudio_driver import AsyncLMStudioDriver
 from .async_local_http_driver import AsyncLocalHTTPDriver
+from .async_modelscope_driver import AsyncModelScopeDriver
+from .async_moonshot_driver import AsyncMoonshotDriver
 from .async_ollama_driver import AsyncOllamaDriver
 from .async_openai_driver import AsyncOpenAIDriver
 from .async_openrouter_driver import AsyncOpenRouterDriver
+from .async_zai_driver import AsyncZaiDriver
 from .registry import (
     _get_async_registry,
     get_async_driver_factory,
@@ -90,6 +93,33 @@ register_async_driver(
     lambda model=None: AsyncGrokDriver(api_key=settings.grok_api_key, model=model or settings.grok_model),
     overwrite=True,
 )
+register_async_driver(
+    "moonshot",
+    lambda model=None: AsyncMoonshotDriver(
+        api_key=settings.moonshot_api_key,
+        model=model or settings.moonshot_model,
+        endpoint=settings.moonshot_endpoint,
+    ),
+    overwrite=True,
+)
+register_async_driver(
+    "modelscope",
+    lambda model=None: AsyncModelScopeDriver(
+        api_key=settings.modelscope_api_key,
+        model=model or settings.modelscope_model,
+        endpoint=settings.modelscope_endpoint,
+    ),
+    overwrite=True,
+)
+register_async_driver(
+    "zai",
+    lambda model=None: AsyncZaiDriver(
+        api_key=settings.zhipu_api_key,
+        model=model or settings.zhipu_model,
+        endpoint=settings.zhipu_endpoint,
+    ),
+    overwrite=True,
+)
 register_async_driver(
     "airllm",
     lambda model=None: AsyncAirLLMDriver(
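With the three new factories registered at import time, an async driver can be resolved by provider name. A sketch of the lookup; the import path (get_async_driver_factory comes from .registry in the hunk above, and whether it is re-exported elsewhere is not shown) and the model id are assumptions:

# Hedged sketch: resolving a newly registered async driver by name.
from prompture.drivers.registry import get_async_driver_factory  # assumed path

factory = get_async_driver_factory("moonshot")
driver = factory(model="moonshot-v1-8k")  # illustrative model id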