prompture 0.0.38.dev2__py3-none-any.whl → 0.0.42__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- prompture/__init__.py +12 -1
- prompture/_version.py +2 -2
- prompture/agent.py +11 -11
- prompture/async_agent.py +11 -11
- prompture/async_conversation.py +9 -0
- prompture/async_core.py +16 -0
- prompture/async_driver.py +39 -0
- prompture/async_groups.py +63 -0
- prompture/conversation.py +9 -0
- prompture/core.py +16 -0
- prompture/cost_mixin.py +62 -0
- prompture/discovery.py +108 -43
- prompture/driver.py +39 -0
- prompture/drivers/__init__.py +39 -0
- prompture/drivers/async_azure_driver.py +7 -6
- prompture/drivers/async_claude_driver.py +177 -8
- prompture/drivers/async_google_driver.py +10 -0
- prompture/drivers/async_grok_driver.py +4 -4
- prompture/drivers/async_groq_driver.py +4 -4
- prompture/drivers/async_modelscope_driver.py +286 -0
- prompture/drivers/async_moonshot_driver.py +312 -0
- prompture/drivers/async_openai_driver.py +158 -6
- prompture/drivers/async_openrouter_driver.py +196 -7
- prompture/drivers/async_registry.py +30 -0
- prompture/drivers/async_zai_driver.py +303 -0
- prompture/drivers/azure_driver.py +6 -5
- prompture/drivers/claude_driver.py +10 -0
- prompture/drivers/google_driver.py +10 -0
- prompture/drivers/grok_driver.py +4 -4
- prompture/drivers/groq_driver.py +4 -4
- prompture/drivers/modelscope_driver.py +303 -0
- prompture/drivers/moonshot_driver.py +342 -0
- prompture/drivers/openai_driver.py +22 -12
- prompture/drivers/openrouter_driver.py +248 -44
- prompture/drivers/zai_driver.py +318 -0
- prompture/groups.py +42 -0
- prompture/ledger.py +252 -0
- prompture/model_rates.py +114 -2
- prompture/settings.py +16 -1
- {prompture-0.0.38.dev2.dist-info → prompture-0.0.42.dist-info}/METADATA +1 -1
- prompture-0.0.42.dist-info/RECORD +84 -0
- prompture-0.0.38.dev2.dist-info/RECORD +0 -77
- {prompture-0.0.38.dev2.dist-info → prompture-0.0.42.dist-info}/WHEEL +0 -0
- {prompture-0.0.38.dev2.dist-info → prompture-0.0.42.dist-info}/entry_points.txt +0 -0
- {prompture-0.0.38.dev2.dist-info → prompture-0.0.42.dist-info}/licenses/LICENSE +0 -0
- {prompture-0.0.38.dev2.dist-info → prompture-0.0.42.dist-info}/top_level.txt +0 -0
prompture/drivers/async_groq_driver.py
@@ -49,9 +49,9 @@ class AsyncGroqDriver(CostMixin, AsyncDriver):
 
         model = options.get("model", self.model)
 
-
-        tokens_param =
-        supports_temperature =
+        model_config = self._get_model_config("groq", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
 
         opts = {"temperature": 0.7, "max_tokens": 512, **options}
 
@@ -81,7 +81,7 @@ class AsyncGroqDriver(CostMixin, AsyncDriver):
             "prompt_tokens": prompt_tokens,
             "completion_tokens": completion_tokens,
             "total_tokens": total_tokens,
-            "cost": total_cost,
+            "cost": round(total_cost, 6),
             "raw_response": resp.model_dump(),
             "model_name": model,
         }
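Both hunks lean on the same pattern used across the updated drivers: per-model request parameters come from CostMixin._get_model_config(provider, model) instead of per-driver constants, and the computed cost is rounded to six decimals before it lands in the meta dict. Below is a minimal sketch of that pattern with a hypothetical config and rate table; the real tables live in prompture/cost_mixin.py and prompture/model_rates.py, which this diff does not show, so names and numbers here are illustrative only.

from typing import Any

# Illustrative tables only: the real values live in prompture/cost_mixin.py and
# prompture/model_rates.py, which are not shown in this diff.
_MODEL_CONFIG: dict[str, dict[str, Any]] = {
    "groq": {"tokens_param": "max_tokens", "supports_temperature": True},
}
_RATES_PER_MTOK: dict[str, dict[str, float]] = {
    "groq": {"input": 0.05, "output": 0.08},  # made-up example rates, USD per million tokens
}


class CostMixinSketch:
    def _get_model_config(self, provider: str, model: str) -> dict[str, Any]:
        # Fall back to a permissive default for unknown providers/models.
        return _MODEL_CONFIG.get(provider, {"tokens_param": "max_tokens", "supports_temperature": True})

    def _calculate_cost(self, provider: str, model: str, prompt_tokens: int, completion_tokens: int) -> float:
        rates = _RATES_PER_MTOK.get(provider, {"input": 0.0, "output": 0.0})
        return (prompt_tokens * rates["input"] + completion_tokens * rates["output"]) / 1_000_000


mixin = CostMixinSketch()
config = mixin._get_model_config("groq", "llama-3.1-8b-instant")
request = {"model": "llama-3.1-8b-instant", config["tokens_param"]: 512}
cost = mixin._calculate_cost("groq", "llama-3.1-8b-instant", 1200, 340)
print(request, round(cost, 6))  # rounding mirrors the new "cost" meta field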
prompture/drivers/async_modelscope_driver.py (new file)
@@ -0,0 +1,286 @@
+"""Async ModelScope (Alibaba Cloud) driver using httpx.
+
+No hardcoded pricing — ModelScope's free tier has no per-token cost.
+"""
+
+from __future__ import annotations
+
+import json
+import os
+from collections.abc import AsyncIterator
+from typing import Any
+
+import httpx
+
+from ..async_driver import AsyncDriver
+from ..cost_mixin import CostMixin
+from .modelscope_driver import ModelScopeDriver
+
+
+class AsyncModelScopeDriver(CostMixin, AsyncDriver):
+    supports_json_mode = True
+    supports_json_schema = False
+    supports_tool_use = True
+    supports_streaming = True
+    supports_vision = False
+
+    MODEL_PRICING = ModelScopeDriver.MODEL_PRICING
+
+    def __init__(
+        self,
+        api_key: str | None = None,
+        model: str = "Qwen/Qwen3-235B-A22B-Instruct-2507",
+        endpoint: str = "https://api-inference.modelscope.cn/v1",
+    ):
+        self.api_key = api_key or os.getenv("MODELSCOPE_API_KEY")
+        if not self.api_key:
+            raise ValueError("ModelScope API key not found. Set MODELSCOPE_API_KEY env var.")
+        self.model = model
+        self.base_url = endpoint.rstrip("/")
+        self.headers = {
+            "Authorization": f"Bearer {self.api_key}",
+            "Content-Type": "application/json",
+        }
+
+    supports_messages = True
+
+    async def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
+        messages = [{"role": "user", "content": prompt}]
+        return await self._do_generate(messages, options)
+
+    async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        return await self._do_generate(messages, options)
+
+    async def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        model = options.get("model", self.model)
+
+        model_config = self._get_model_config("modelscope", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
+
+        self._validate_model_capabilities(
+            "modelscope",
+            model,
+            using_json_schema=bool(options.get("json_schema")),
+        )
+
+        opts = {"temperature": 1.0, "max_tokens": 512, **options}
+
+        data: dict[str, Any] = {
+            "model": model,
+            "messages": messages,
+        }
+        data[tokens_param] = opts.get("max_tokens", 512)
+
+        if supports_temperature and "temperature" in opts:
+            data["temperature"] = opts["temperature"]
+
+        if options.get("json_mode"):
+            data["response_format"] = {"type": "json_object"}
+
+        async with httpx.AsyncClient() as client:
+            try:
+                response = await client.post(
+                    f"{self.base_url}/chat/completions",
+                    headers=self.headers,
+                    json=data,
+                    timeout=120,
+                )
+                response.raise_for_status()
+                resp = response.json()
+            except httpx.HTTPStatusError as e:
+                error_msg = f"ModelScope API request failed: {e!s}"
+                raise RuntimeError(error_msg) from e
+            except Exception as e:
+                raise RuntimeError(f"ModelScope API request failed: {e!s}") from e
+
+        usage = resp.get("usage", {})
+        prompt_tokens = usage.get("prompt_tokens", 0)
+        completion_tokens = usage.get("completion_tokens", 0)
+        total_tokens = usage.get("total_tokens", 0)
+
+        total_cost = self._calculate_cost("modelscope", model, prompt_tokens, completion_tokens)
+
+        meta = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": round(total_cost, 6),
+            "raw_response": resp,
+            "model_name": model,
+        }
+
+        text = resp["choices"][0]["message"]["content"]
+        return {"text": text, "meta": meta}
+
+    # ------------------------------------------------------------------
+    # Tool use
+    # ------------------------------------------------------------------
+
+    async def generate_messages_with_tools(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> dict[str, Any]:
+        """Generate a response that may include tool calls."""
+        model = options.get("model", self.model)
+        model_config = self._get_model_config("modelscope", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
+
+        self._validate_model_capabilities("modelscope", model, using_tool_use=True)
+
+        opts = {"temperature": 1.0, "max_tokens": 512, **options}
+
+        data: dict[str, Any] = {
+            "model": model,
+            "messages": messages,
+            "tools": tools,
+        }
+        data[tokens_param] = opts.get("max_tokens", 512)
+
+        if supports_temperature and "temperature" in opts:
+            data["temperature"] = opts["temperature"]
+
+        if "tool_choice" in options:
+            data["tool_choice"] = options["tool_choice"]
+
+        async with httpx.AsyncClient() as client:
+            try:
+                response = await client.post(
+                    f"{self.base_url}/chat/completions",
+                    headers=self.headers,
+                    json=data,
+                    timeout=120,
+                )
+                response.raise_for_status()
+                resp = response.json()
+            except httpx.HTTPStatusError as e:
+                error_msg = f"ModelScope API request failed: {e!s}"
+                raise RuntimeError(error_msg) from e
+            except Exception as e:
+                raise RuntimeError(f"ModelScope API request failed: {e!s}") from e
+
+        usage = resp.get("usage", {})
+        prompt_tokens = usage.get("prompt_tokens", 0)
+        completion_tokens = usage.get("completion_tokens", 0)
+        total_tokens = usage.get("total_tokens", 0)
+        total_cost = self._calculate_cost("modelscope", model, prompt_tokens, completion_tokens)
+
+        meta = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": round(total_cost, 6),
+            "raw_response": resp,
+            "model_name": model,
+        }
+
+        choice = resp["choices"][0]
+        text = choice["message"].get("content") or ""
+        stop_reason = choice.get("finish_reason")
+
+        tool_calls_out: list[dict[str, Any]] = []
+        for tc in choice["message"].get("tool_calls", []):
+            try:
+                args = json.loads(tc["function"]["arguments"])
+            except (json.JSONDecodeError, TypeError):
+                args = {}
+            tool_calls_out.append(
+                {
+                    "id": tc["id"],
+                    "name": tc["function"]["name"],
+                    "arguments": args,
+                }
+            )
+
+        return {
+            "text": text,
+            "meta": meta,
+            "tool_calls": tool_calls_out,
+            "stop_reason": stop_reason,
+        }
+
+    # ------------------------------------------------------------------
+    # Streaming
+    # ------------------------------------------------------------------
+
+    async def generate_messages_stream(
+        self,
+        messages: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> AsyncIterator[dict[str, Any]]:
+        """Yield response chunks via ModelScope streaming API."""
+        model = options.get("model", self.model)
+        model_config = self._get_model_config("modelscope", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
+
+        opts = {"temperature": 1.0, "max_tokens": 512, **options}
+
+        data: dict[str, Any] = {
+            "model": model,
+            "messages": messages,
+            "stream": True,
+            "stream_options": {"include_usage": True},
+        }
+        data[tokens_param] = opts.get("max_tokens", 512)
+
+        if supports_temperature and "temperature" in opts:
+            data["temperature"] = opts["temperature"]
+
+        full_text = ""
+        prompt_tokens = 0
+        completion_tokens = 0
+
+        async with (
+            httpx.AsyncClient() as client,
+            client.stream(
+                "POST",
+                f"{self.base_url}/chat/completions",
+                headers=self.headers,
+                json=data,
+                timeout=120,
+            ) as response,
+        ):
+            response.raise_for_status()
+            async for line in response.aiter_lines():
+                if not line or not line.startswith("data: "):
+                    continue
+                payload = line[len("data: ") :]
+                if payload.strip() == "[DONE]":
+                    break
+                try:
+                    chunk = json.loads(payload)
+                except json.JSONDecodeError:
+                    continue
+
+                usage = chunk.get("usage")
+                if usage:
+                    prompt_tokens = usage.get("prompt_tokens", 0)
+                    completion_tokens = usage.get("completion_tokens", 0)
+
+                choices = chunk.get("choices", [])
+                if choices:
+                    delta = choices[0].get("delta", {})
+                    content = delta.get("content", "")
+                    if content:
+                        full_text += content
+                        yield {"type": "delta", "text": content}
+
+        total_tokens = prompt_tokens + completion_tokens
+        total_cost = self._calculate_cost("modelscope", model, prompt_tokens, completion_tokens)
+
+        yield {
+            "type": "done",
+            "text": full_text,
+            "meta": {
+                "prompt_tokens": prompt_tokens,
+                "completion_tokens": completion_tokens,
+                "total_tokens": total_tokens,
+                "cost": round(total_cost, 6),
+                "raw_response": {},
+                "model_name": model,
+            },
+        }
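The new async ModelScope driver exposes the same surface as the other async drivers: generate, generate_messages, generate_messages_with_tools, and generate_messages_stream, each returning a dict with "text" and token/cost "meta". A short usage sketch follows; the import path is an assumption based on the drivers/__init__.py and async_registry.py changes in this release, so adjust it if the class is exported elsewhere.

import asyncio

# Assumed export location; the diff touches prompture/drivers/__init__.py but
# the exact re-exports are not shown here.
from prompture.drivers import AsyncModelScopeDriver


async def main() -> None:
    driver = AsyncModelScopeDriver()  # reads MODELSCOPE_API_KEY from the environment

    # Plain generation: the result carries the text plus token/cost metadata.
    result = await driver.generate("Summarize httpx in one sentence.", {"max_tokens": 128})
    print(result["text"], result["meta"]["total_tokens"])

    # Streaming: "delta" chunks carry incremental text; the final "done" chunk carries meta.
    async for chunk in driver.generate_messages_stream(
        [{"role": "user", "content": "Count to five."}], {"max_tokens": 64}
    ):
        if chunk["type"] == "delta":
            print(chunk["text"], end="", flush=True)
        else:
            print("\n", chunk["meta"]["cost"])


asyncio.run(main())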
prompture/drivers/async_moonshot_driver.py (new file)
@@ -0,0 +1,312 @@
+"""Async Moonshot AI (Kimi) driver using httpx.
+
+All pricing comes from models.dev (provider: "moonshotai") — no hardcoded pricing.
+
+Moonshot-specific constraints:
+- Temperature clamped to [0, 1] (OpenAI allows [0, 2])
+- tool_choice: "required" not supported — only "auto" or "none"
+"""
+
+from __future__ import annotations
+
+import json
+import os
+from collections.abc import AsyncIterator
+from typing import Any
+
+import httpx
+
+from ..async_driver import AsyncDriver
+from ..cost_mixin import CostMixin, prepare_strict_schema
+from .moonshot_driver import MoonshotDriver
+
+
+class AsyncMoonshotDriver(CostMixin, AsyncDriver):
+    supports_json_mode = True
+    supports_json_schema = True
+    supports_tool_use = True
+    supports_streaming = True
+    supports_vision = True
+
+    MODEL_PRICING = MoonshotDriver.MODEL_PRICING
+
+    def __init__(
+        self,
+        api_key: str | None = None,
+        model: str = "kimi-k2-0905-preview",
+        endpoint: str = "https://api.moonshot.ai/v1",
+    ):
+        self.api_key = api_key or os.getenv("MOONSHOT_API_KEY")
+        if not self.api_key:
+            raise ValueError("Moonshot API key not found. Set MOONSHOT_API_KEY env var.")
+        self.model = model
+        self.base_url = endpoint.rstrip("/")
+        self.headers = {
+            "Authorization": f"Bearer {self.api_key}",
+            "Content-Type": "application/json",
+        }
+
+    supports_messages = True
+
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_openai_vision_messages
+
+        return _prepare_openai_vision_messages(messages)
+
+    async def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
+        messages = [{"role": "user", "content": prompt}]
+        return await self._do_generate(messages, options)
+
+    async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        return await self._do_generate(self._prepare_messages(messages), options)
+
+    async def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        model = options.get("model", self.model)
+
+        model_config = self._get_model_config("moonshot", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
+
+        self._validate_model_capabilities(
+            "moonshot",
+            model,
+            using_json_schema=bool(options.get("json_schema")),
+        )
+
+        opts = {"temperature": 1.0, "max_tokens": 512, **options}
+        MoonshotDriver._clamp_temperature(opts)
+
+        data: dict[str, Any] = {
+            "model": model,
+            "messages": messages,
+        }
+        data[tokens_param] = opts.get("max_tokens", 512)
+
+        if supports_temperature and "temperature" in opts:
+            data["temperature"] = opts["temperature"]
+
+        if options.get("json_mode"):
+            json_schema = options.get("json_schema")
+            if json_schema:
+                schema_copy = prepare_strict_schema(json_schema)
+                data["response_format"] = {
+                    "type": "json_schema",
+                    "json_schema": {
+                        "name": "extraction",
+                        "strict": True,
+                        "schema": schema_copy,
+                    },
+                }
+            else:
+                data["response_format"] = {"type": "json_object"}
+
+        async with httpx.AsyncClient() as client:
+            try:
+                response = await client.post(
+                    f"{self.base_url}/chat/completions",
+                    headers=self.headers,
+                    json=data,
+                    timeout=120,
+                )
+                response.raise_for_status()
+                resp = response.json()
+            except httpx.HTTPStatusError as e:
+                error_msg = f"Moonshot API request failed: {e!s}"
+                raise RuntimeError(error_msg) from e
+            except Exception as e:
+                raise RuntimeError(f"Moonshot API request failed: {e!s}") from e
+
+        usage = resp.get("usage", {})
+        prompt_tokens = usage.get("prompt_tokens", 0)
+        completion_tokens = usage.get("completion_tokens", 0)
+        total_tokens = usage.get("total_tokens", 0)
+
+        total_cost = self._calculate_cost("moonshot", model, prompt_tokens, completion_tokens)
+
+        meta = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": round(total_cost, 6),
+            "raw_response": resp,
+            "model_name": model,
+        }
+
+        text = resp["choices"][0]["message"]["content"]
+        return {"text": text, "meta": meta}
+
+    # ------------------------------------------------------------------
+    # Tool use
+    # ------------------------------------------------------------------
+
+    async def generate_messages_with_tools(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> dict[str, Any]:
+        """Generate a response that may include tool calls."""
+        model = options.get("model", self.model)
+        model_config = self._get_model_config("moonshot", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
+
+        self._validate_model_capabilities("moonshot", model, using_tool_use=True)
+
+        opts = {"temperature": 1.0, "max_tokens": 512, **options}
+        MoonshotDriver._clamp_temperature(opts)
+
+        data: dict[str, Any] = {
+            "model": model,
+            "messages": messages,
+            "tools": tools,
+        }
+        data[tokens_param] = opts.get("max_tokens", 512)
+
+        if supports_temperature and "temperature" in opts:
+            data["temperature"] = opts["temperature"]
+
+        if "tool_choice" in options:
+            data["tool_choice"] = options["tool_choice"]
+
+        MoonshotDriver._sanitize_tool_choice(data)
+
+        async with httpx.AsyncClient() as client:
+            try:
+                response = await client.post(
+                    f"{self.base_url}/chat/completions",
+                    headers=self.headers,
+                    json=data,
+                    timeout=120,
+                )
+                response.raise_for_status()
+                resp = response.json()
+            except httpx.HTTPStatusError as e:
+                error_msg = f"Moonshot API request failed: {e!s}"
+                raise RuntimeError(error_msg) from e
+            except Exception as e:
+                raise RuntimeError(f"Moonshot API request failed: {e!s}") from e
+
+        usage = resp.get("usage", {})
+        prompt_tokens = usage.get("prompt_tokens", 0)
+        completion_tokens = usage.get("completion_tokens", 0)
+        total_tokens = usage.get("total_tokens", 0)
+        total_cost = self._calculate_cost("moonshot", model, prompt_tokens, completion_tokens)
+
+        meta = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": round(total_cost, 6),
+            "raw_response": resp,
+            "model_name": model,
+        }
+
+        choice = resp["choices"][0]
+        text = choice["message"].get("content") or ""
+        stop_reason = choice.get("finish_reason")
+
+        tool_calls_out: list[dict[str, Any]] = []
+        for tc in choice["message"].get("tool_calls", []):
+            try:
+                args = json.loads(tc["function"]["arguments"])
+            except (json.JSONDecodeError, TypeError):
+                args = {}
+            tool_calls_out.append(
+                {
+                    "id": tc["id"],
+                    "name": tc["function"]["name"],
+                    "arguments": args,
+                }
+            )
+
+        return {
+            "text": text,
+            "meta": meta,
+            "tool_calls": tool_calls_out,
+            "stop_reason": stop_reason,
+        }
+
+    # ------------------------------------------------------------------
+    # Streaming
+    # ------------------------------------------------------------------
+
+    async def generate_messages_stream(
+        self,
+        messages: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> AsyncIterator[dict[str, Any]]:
+        """Yield response chunks via Moonshot streaming API."""
+        model = options.get("model", self.model)
+        model_config = self._get_model_config("moonshot", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
+
+        opts = {"temperature": 1.0, "max_tokens": 512, **options}
+        MoonshotDriver._clamp_temperature(opts)
+
+        data: dict[str, Any] = {
+            "model": model,
+            "messages": messages,
+            "stream": True,
+            "stream_options": {"include_usage": True},
+        }
+        data[tokens_param] = opts.get("max_tokens", 512)
+
+        if supports_temperature and "temperature" in opts:
+            data["temperature"] = opts["temperature"]
+
+        full_text = ""
+        prompt_tokens = 0
+        completion_tokens = 0
+
+        async with (
+            httpx.AsyncClient() as client,
+            client.stream(
+                "POST",
+                f"{self.base_url}/chat/completions",
+                headers=self.headers,
+                json=data,
+                timeout=120,
+            ) as response,
+        ):
+            response.raise_for_status()
+            async for line in response.aiter_lines():
+                if not line or not line.startswith("data: "):
+                    continue
+                payload = line[len("data: ") :]
+                if payload.strip() == "[DONE]":
+                    break
+                try:
+                    chunk = json.loads(payload)
+                except json.JSONDecodeError:
+                    continue
+
+                usage = chunk.get("usage")
+                if usage:
+                    prompt_tokens = usage.get("prompt_tokens", 0)
+                    completion_tokens = usage.get("completion_tokens", 0)
+
+                choices = chunk.get("choices", [])
+                if choices:
+                    delta = choices[0].get("delta", {})
+                    content = delta.get("content", "")
+                    if content:
+                        full_text += content
+                        yield {"type": "delta", "text": content}
+
+        total_tokens = prompt_tokens + completion_tokens
+        total_cost = self._calculate_cost("moonshot", model, prompt_tokens, completion_tokens)
+
+        yield {
+            "type": "done",
+            "text": full_text,
+            "meta": {
+                "prompt_tokens": prompt_tokens,
+                "completion_tokens": completion_tokens,
+                "total_tokens": total_tokens,
+                "cost": round(total_cost, 6),
+                "raw_response": {},
+                "model_name": model,
+            },
+        }
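The Moonshot driver's docstring flags two provider constraints: temperature is clamped to [0, 1] and tool_choice "required" is not accepted, which is what the MoonshotDriver._clamp_temperature and _sanitize_tool_choice helpers enforce before the request is sent. A tool-calling usage sketch follows; as above, the import path is an assumption, and the tool definition is a made-up example.

import asyncio

# Assumed export location; adjust if the package exposes the class elsewhere.
from prompture.drivers import AsyncMoonshotDriver

# Hypothetical tool definition in the OpenAI-compatible "function" format the driver forwards.
WEATHER_TOOL = {
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Look up the current weather for a city.",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}


async def main() -> None:
    driver = AsyncMoonshotDriver()  # reads MOONSHOT_API_KEY from the environment
    result = await driver.generate_messages_with_tools(
        [{"role": "user", "content": "What's the weather in Osaka?"}],
        [WEATHER_TOOL],
        # A temperature above 1 gets clamped by _clamp_temperature, and tool_choice
        # values Moonshot rejects are handled by _sanitize_tool_choice, so pass "auto" or "none".
        {"temperature": 1.5, "tool_choice": "auto", "max_tokens": 256},
    )
    for call in result["tool_calls"]:
        print(call["name"], call["arguments"])


asyncio.run(main())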