prompture 0.0.40.dev1__py3-none-any.whl → 0.0.41.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- prompture/_version.py +2 -2
- prompture/drivers/__init__.py +39 -0
- prompture/drivers/async_modelscope_driver.py +286 -0
- prompture/drivers/async_moonshot_driver.py +311 -0
- prompture/drivers/async_openrouter_driver.py +190 -2
- prompture/drivers/async_registry.py +30 -0
- prompture/drivers/async_zai_driver.py +302 -0
- prompture/drivers/modelscope_driver.py +303 -0
- prompture/drivers/moonshot_driver.py +341 -0
- prompture/drivers/openrouter_driver.py +235 -39
- prompture/drivers/zai_driver.py +317 -0
- prompture/model_rates.py +2 -0
- prompture/settings.py +15 -0
- {prompture-0.0.40.dev1.dist-info → prompture-0.0.41.dev1.dist-info}/METADATA +1 -1
- {prompture-0.0.40.dev1.dist-info → prompture-0.0.41.dev1.dist-info}/RECORD +19 -13
- {prompture-0.0.40.dev1.dist-info → prompture-0.0.41.dev1.dist-info}/WHEEL +0 -0
- {prompture-0.0.40.dev1.dist-info → prompture-0.0.41.dev1.dist-info}/entry_points.txt +0 -0
- {prompture-0.0.40.dev1.dist-info → prompture-0.0.41.dev1.dist-info}/licenses/LICENSE +0 -0
- {prompture-0.0.40.dev1.dist-info → prompture-0.0.41.dev1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,303 @@
|
|
|
1
|
+
"""ModelScope (Alibaba Cloud) driver implementation.
|
|
2
|
+
Requires the `requests` package. Uses MODELSCOPE_API_KEY env var.
|
|
3
|
+
|
|
4
|
+
The ModelScope API-Inference endpoint is fully OpenAI-compatible (/v1/chat/completions).
|
|
5
|
+
No hardcoded pricing — ModelScope's free tier has no per-token cost.
|
|
6
|
+
|
|
7
|
+
Model IDs are namespace-prefixed (e.g. Qwen/Qwen3-235B-A22B-Instruct-2507).
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
import json
|
|
11
|
+
import os
|
|
12
|
+
from collections.abc import Iterator
|
|
13
|
+
from typing import Any
|
|
14
|
+
|
|
15
|
+
import requests
|
|
16
|
+
|
|
17
|
+
from ..cost_mixin import CostMixin
|
|
18
|
+
from ..driver import Driver
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class ModelScopeDriver(CostMixin, Driver):
    """Blocking driver for ModelScope's OpenAI-style ``/chat/completions`` endpoint.

    Every request is sent with ``requests`` to ``<endpoint>/chat/completions``
    using Bearer authentication. Per-model request parameters are looked up
    through ``self._get_model_config`` and the reported token usage is handed
    to ``self._calculate_cost``; both are defined outside this class.
    """

    supports_json_mode = True
    supports_json_schema = False
    supports_tool_use = True
    supports_streaming = True
    supports_vision = False
    supports_messages = True

    # No pricing data available — ModelScope free tier has no per-token cost
    MODEL_PRICING: dict[str, dict[str, Any]] = {}

    def __init__(
        self,
        api_key: str | None = None,
        model: str = "Qwen/Qwen3-235B-A22B-Instruct-2507",
        endpoint: str = "https://api-inference.modelscope.cn/v1",
    ):
        """Initialize ModelScope driver.

        Args:
            api_key: ModelScope API key. If not provided, will look for MODELSCOPE_API_KEY env var.
            model: Model to use. Defaults to Qwen/Qwen3-235B-A22B-Instruct-2507.
            endpoint: API base URL. Defaults to https://api-inference.modelscope.cn/v1.

        Raises:
            ValueError: If no API key is passed and MODELSCOPE_API_KEY is unset or empty.
        """
        self.api_key = api_key or os.getenv("MODELSCOPE_API_KEY")
        if not self.api_key:
            raise ValueError("ModelScope API key not found. Set MODELSCOPE_API_KEY env var.")

        self.model = model
        # Trailing slashes are stripped so the URL can be built with a plain "/".
        self.base_url = endpoint.rstrip("/")

        self.headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _build_request_body(
        model: str,
        model_config: dict[str, Any],
        messages: list[dict[str, Any]],
        options: dict[str, Any],
        **extra: Any,
    ) -> dict[str, Any]:
        """Assemble the JSON body shared by all three request flavours.

        ``extra`` entries (e.g. ``tools`` or ``stream``) are copied into the
        body as-is. ``temperature`` defaults to 1.0 and ``max_tokens`` to 512
        unless ``options`` overrides them.
        """
        opts = {"temperature": 1.0, "max_tokens": 512, **options}
        data: dict[str, Any] = {"model": model, "messages": messages, **extra}
        # The name of the token-limit field is model dependent.
        data[model_config["tokens_param"]] = opts["max_tokens"]
        if model_config["supports_temperature"]:
            data["temperature"] = opts["temperature"]
        return data

    def _post_chat_completion(self, data: dict[str, Any]) -> dict[str, Any]:
        """POST ``data`` and return the decoded JSON response.

        Raises:
            RuntimeError: On any ``requests`` failure (transport error, HTTP
                error status, or undecodable body). ``HTTPError`` is a subclass
                of ``RequestException``, so one handler covers both.
        """
        try:
            response = requests.post(
                f"{self.base_url}/chat/completions",
                headers=self.headers,
                json=data,
                timeout=120,
            )
            response.raise_for_status()
            return response.json()
        except requests.exceptions.RequestException as e:
            raise RuntimeError(f"ModelScope API request failed: {e!s}") from e

    def _build_meta(self, model: str, resp: dict[str, Any]) -> dict[str, Any]:
        """Extract token usage from ``resp`` and attach the computed cost."""
        # "usage" may be absent or an explicit null; treat both as "no data".
        usage = resp.get("usage") or {}
        prompt_tokens = usage.get("prompt_tokens", 0)
        completion_tokens = usage.get("completion_tokens", 0)
        total_cost = self._calculate_cost("modelscope", model, prompt_tokens, completion_tokens)
        return {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": usage.get("total_tokens", 0),
            "cost": round(total_cost, 6),
            "raw_response": resp,
            "model_name": model,
        }

    # ------------------------------------------------------------------
    # Plain generation
    # ------------------------------------------------------------------

    def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
        """Send ``prompt`` as a single user message and return ``{"text", "meta"}``."""
        messages = [{"role": "user", "content": prompt}]
        return self._do_generate(messages, options)

    def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
        """Send an already-built message list and return ``{"text", "meta"}``."""
        return self._do_generate(messages, options)

    def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
        """Run one non-streaming chat completion.

        Args:
            messages: Chat messages forwarded unchanged to the API.
            options: Per-call overrides; recognised keys are ``model``,
                ``temperature``, ``max_tokens``, ``json_mode`` and
                ``json_schema`` (the latter is only passed to capability
                validation).

        Returns:
            ``{"text": <first choice content>, "meta": <usage/cost dict>}``.

        Raises:
            RuntimeError: If the API key is missing or the request fails.
        """
        if not self.api_key:
            raise RuntimeError("ModelScope API key not found")

        model = options.get("model", self.model)
        model_config = self._get_model_config("modelscope", model)

        self._validate_model_capabilities(
            "modelscope",
            model,
            using_json_schema=bool(options.get("json_schema")),
        )

        data = self._build_request_body(model, model_config, messages, options)

        # Native JSON mode support
        if options.get("json_mode"):
            data["response_format"] = {"type": "json_object"}

        resp = self._post_chat_completion(data)
        meta = self._build_meta(model, resp)

        text = resp["choices"][0]["message"]["content"]
        return {"text": text, "meta": meta}

    # ------------------------------------------------------------------
    # Tool use
    # ------------------------------------------------------------------

    def generate_messages_with_tools(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]],
        options: dict[str, Any],
    ) -> dict[str, Any]:
        """Generate a response that may include tool calls.

        Returns:
            A dict with ``text`` (empty string when the model returned no
            content), ``meta``, ``tool_calls`` (list of ``{"id", "name",
            "arguments"}``) and ``stop_reason`` (the choice's ``finish_reason``).

        Raises:
            RuntimeError: If the API key is missing or the request fails.
        """
        if not self.api_key:
            raise RuntimeError("ModelScope API key not found")

        model = options.get("model", self.model)
        model_config = self._get_model_config("modelscope", model)

        self._validate_model_capabilities("modelscope", model, using_tool_use=True)

        data = self._build_request_body(model, model_config, messages, options, tools=tools)
        if "tool_choice" in options:
            data["tool_choice"] = options["tool_choice"]

        resp = self._post_chat_completion(data)
        meta = self._build_meta(model, resp)

        choice = resp["choices"][0]
        message = choice["message"]

        tool_calls_out: list[dict[str, Any]] = []
        # "tool_calls" can be missing or null when the model answered in text.
        for tc in message.get("tool_calls") or []:
            try:
                args = json.loads(tc["function"]["arguments"])
            except (json.JSONDecodeError, TypeError):
                # Best effort: malformed or non-string arguments become {}.
                args = {}
            tool_calls_out.append(
                {
                    "id": tc["id"],
                    "name": tc["function"]["name"],
                    "arguments": args,
                }
            )

        return {
            "text": message.get("content") or "",
            "meta": meta,
            "tool_calls": tool_calls_out,
            "stop_reason": choice.get("finish_reason"),
        }

    # ------------------------------------------------------------------
    # Streaming
    # ------------------------------------------------------------------

    def generate_messages_stream(
        self,
        messages: list[dict[str, Any]],
        options: dict[str, Any],
    ) -> Iterator[dict[str, Any]]:
        """Yield response chunks via ModelScope streaming API.

        Yields ``{"type": "delta", "text": ...}`` for every non-empty content
        fragment, then one final ``{"type": "done", "text": ..., "meta": ...}``.

        Raises:
            RuntimeError: If the API key is missing.
            requests.exceptions.RequestException: Transport and HTTP errors are
                propagated unwrapped from this method.
        """
        if not self.api_key:
            raise RuntimeError("ModelScope API key not found")

        model = options.get("model", self.model)
        model_config = self._get_model_config("modelscope", model)

        data = self._build_request_body(
            model,
            model_config,
            messages,
            options,
            stream=True,
            stream_options={"include_usage": True},
        )

        text_parts: list[str] = []
        prompt_tokens = 0
        completion_tokens = 0

        # The context manager releases the connection even when the consumer
        # abandons the generator early or an exception interrupts the loop.
        with requests.post(
            f"{self.base_url}/chat/completions",
            headers=self.headers,
            json=data,
            stream=True,
            timeout=120,
        ) as response:
            response.raise_for_status()
            # Event streams are UTF-8. Without this, requests may fall back to
            # ISO-8859-1 for text/* bodies (garbling non-ASCII text) or yield
            # bytes when no encoding is known.
            response.encoding = "utf-8"

            for line in response.iter_lines(decode_unicode=True):
                if not line or not line.startswith("data: "):
                    continue
                payload = line[len("data: ") :]
                if payload.strip() == "[DONE]":
                    break
                try:
                    chunk = json.loads(payload)
                except json.JSONDecodeError:
                    # Skip undecodable events rather than aborting the stream.
                    continue

                usage = chunk.get("usage")
                if usage:
                    prompt_tokens = usage.get("prompt_tokens", 0)
                    completion_tokens = usage.get("completion_tokens", 0)

                choices = chunk.get("choices") or []
                if choices:
                    delta = choices[0].get("delta") or {}
                    content = delta.get("content", "")
                    if content:
                        text_parts.append(content)
                        yield {"type": "delta", "text": content}

        total_cost = self._calculate_cost("modelscope", model, prompt_tokens, completion_tokens)

        yield {
            "type": "done",
            "text": "".join(text_parts),
            "meta": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens,
                "cost": round(total_cost, 6),
                "raw_response": {},
                "model_name": model,
            },
        }
|
|
@@ -0,0 +1,341 @@
|
|
|
1
|
+
"""Moonshot AI (Kimi) driver implementation.
|
|
2
|
+
Requires the `requests` package. Uses MOONSHOT_API_KEY env var.
|
|
3
|
+
|
|
4
|
+
The Moonshot API is fully OpenAI-compatible (/v1/chat/completions).
|
|
5
|
+
All pricing comes from models.dev (provider: "moonshotai") — no hardcoded pricing.
|
|
6
|
+
|
|
7
|
+
Moonshot-specific constraints:
|
|
8
|
+
- Temperature clamped to [0, 1] (OpenAI allows [0, 2])
|
|
9
|
+
- tool_choice: "required" not supported — only "auto" or "none"
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
import json
|
|
13
|
+
import os
|
|
14
|
+
from collections.abc import Iterator
|
|
15
|
+
from typing import Any
|
|
16
|
+
|
|
17
|
+
import requests
|
|
18
|
+
|
|
19
|
+
from ..cost_mixin import CostMixin
|
|
20
|
+
from ..driver import Driver
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class MoonshotDriver(CostMixin, Driver):
    """Blocking driver for the Moonshot AI (Kimi) ``/chat/completions`` endpoint.

    Every request is sent with ``requests`` to ``<endpoint>/chat/completions``
    using Bearer authentication. Before sending, ``temperature`` is clamped to
    [0, 1] and ``tool_choice="required"`` is downgraded to ``"auto"``.
    Per-model request parameters come from ``self._get_model_config`` and cost
    from ``self._calculate_cost``; both are defined outside this class.
    """

    supports_json_mode = True
    supports_json_schema = True
    supports_tool_use = True
    supports_streaming = True
    supports_vision = True
    supports_messages = True

    # All pricing resolved live from models.dev (provider: "moonshotai")
    MODEL_PRICING: dict[str, dict[str, Any]] = {}

    def __init__(
        self,
        api_key: str | None = None,
        model: str = "kimi-k2-0905-preview",
        endpoint: str = "https://api.moonshot.ai/v1",
    ):
        """Initialize Moonshot driver.

        Args:
            api_key: Moonshot API key. If not provided, will look for MOONSHOT_API_KEY env var.
            model: Model to use. Defaults to kimi-k2-0905-preview.
            endpoint: API base URL. Defaults to https://api.moonshot.ai/v1.
                Use https://api.moonshot.cn/v1 for the China endpoint.

        Raises:
            ValueError: If no API key is passed and MOONSHOT_API_KEY is unset or empty.
        """
        self.api_key = api_key or os.getenv("MOONSHOT_API_KEY")
        if not self.api_key:
            raise ValueError("Moonshot API key not found. Set MOONSHOT_API_KEY env var.")

        self.model = model
        # Trailing slashes are stripped so the URL can be built with a plain "/".
        self.base_url = endpoint.rstrip("/")

        self.headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _clamp_temperature(opts: dict[str, Any]) -> dict[str, Any]:
        """Clamp temperature to Moonshot's supported range [0, 1].

        Mutates and returns ``opts``; a missing ``temperature`` key is left alone.
        """
        if "temperature" in opts:
            opts["temperature"] = max(0.0, min(1.0, float(opts["temperature"])))
        return opts

    @staticmethod
    def _sanitize_tool_choice(data: dict[str, Any]) -> dict[str, Any]:
        """Downgrade tool_choice='required' to 'auto' (unsupported by Moonshot).

        Mutates and returns ``data``; any other ``tool_choice`` value is kept.
        """
        if data.get("tool_choice") == "required":
            data["tool_choice"] = "auto"
        return data

    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Run ``messages`` through the shared OpenAI-style vision helper."""
        # Imported lazily, as in the original code.
        from .vision_helpers import _prepare_openai_vision_messages

        return _prepare_openai_vision_messages(messages)

    def _build_request_body(
        self,
        model: str,
        model_config: dict[str, Any],
        messages: list[dict[str, Any]],
        options: dict[str, Any],
        **extra: Any,
    ) -> dict[str, Any]:
        """Assemble the JSON body shared by all three request flavours.

        ``extra`` entries (e.g. ``tools`` or ``stream``) are copied into the
        body as-is. ``temperature`` defaults to 1.0 and ``max_tokens`` to 512
        unless ``options`` overrides them; the temperature is then clamped.
        """
        # The merge creates a fresh dict, so clamping never touches the
        # caller's ``options``.
        opts = self._clamp_temperature({"temperature": 1.0, "max_tokens": 512, **options})
        data: dict[str, Any] = {"model": model, "messages": messages, **extra}
        # The name of the token-limit field is model dependent.
        data[model_config["tokens_param"]] = opts["max_tokens"]
        if model_config["supports_temperature"]:
            data["temperature"] = opts["temperature"]
        return data

    def _post_chat_completion(self, data: dict[str, Any]) -> dict[str, Any]:
        """POST ``data`` and return the decoded JSON response.

        Raises:
            RuntimeError: On any ``requests`` failure (transport error, HTTP
                error status, or undecodable body). ``HTTPError`` is a subclass
                of ``RequestException``, so one handler covers both.
        """
        try:
            response = requests.post(
                f"{self.base_url}/chat/completions",
                headers=self.headers,
                json=data,
                timeout=120,
            )
            response.raise_for_status()
            return response.json()
        except requests.exceptions.RequestException as e:
            raise RuntimeError(f"Moonshot API request failed: {e!s}") from e

    def _build_meta(self, model: str, resp: dict[str, Any]) -> dict[str, Any]:
        """Extract token usage from ``resp`` and attach the computed cost."""
        # "usage" may be absent or an explicit null; treat both as "no data".
        usage = resp.get("usage") or {}
        prompt_tokens = usage.get("prompt_tokens", 0)
        completion_tokens = usage.get("completion_tokens", 0)
        total_cost = self._calculate_cost("moonshot", model, prompt_tokens, completion_tokens)
        return {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": usage.get("total_tokens", 0),
            "cost": round(total_cost, 6),
            "raw_response": resp,
            "model_name": model,
        }

    # ------------------------------------------------------------------
    # Plain generation
    # ------------------------------------------------------------------

    def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
        """Send ``prompt`` as a single user message and return ``{"text", "meta"}``."""
        messages = [{"role": "user", "content": prompt}]
        return self._do_generate(messages, options)

    def generate_messages(self, messages: list[dict[str, Any]], options: dict[str, Any]) -> dict[str, Any]:
        """Send a message list (after vision preparation) and return ``{"text", "meta"}``."""
        return self._do_generate(self._prepare_messages(messages), options)

    def _do_generate(self, messages: list[dict[str, Any]], options: dict[str, Any]) -> dict[str, Any]:
        """Run one non-streaming chat completion.

        Args:
            messages: Chat messages forwarded unchanged to the API.
            options: Per-call overrides; recognised keys are ``model``,
                ``temperature``, ``max_tokens``, ``json_mode`` and
                ``json_schema``.

        Returns:
            ``{"text": <first choice content>, "meta": <usage/cost dict>}``.

        Raises:
            RuntimeError: If the API key is missing or the request fails.
        """
        if not self.api_key:
            raise RuntimeError("Moonshot API key not found")

        model = options.get("model", self.model)
        model_config = self._get_model_config("moonshot", model)

        self._validate_model_capabilities(
            "moonshot",
            model,
            using_json_schema=bool(options.get("json_schema")),
        )

        data = self._build_request_body(model, model_config, messages, options)

        # Native JSON mode support: a schema is only honoured when json_mode
        # is also set; otherwise no response_format is sent at all.
        if options.get("json_mode"):
            json_schema = options.get("json_schema")
            if json_schema:
                data["response_format"] = {
                    "type": "json_schema",
                    "json_schema": {
                        "name": "extraction",
                        "strict": True,
                        "schema": json_schema,
                    },
                }
            else:
                data["response_format"] = {"type": "json_object"}

        resp = self._post_chat_completion(data)
        meta = self._build_meta(model, resp)

        text = resp["choices"][0]["message"]["content"]
        return {"text": text, "meta": meta}

    # ------------------------------------------------------------------
    # Tool use
    # ------------------------------------------------------------------

    def generate_messages_with_tools(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]],
        options: dict[str, Any],
    ) -> dict[str, Any]:
        """Generate a response that may include tool calls.

        Returns:
            A dict with ``text`` (empty string when the model returned no
            content), ``meta``, ``tool_calls`` (list of ``{"id", "name",
            "arguments"}``) and ``stop_reason`` (the choice's ``finish_reason``).

        Raises:
            RuntimeError: If the API key is missing or the request fails.
        """
        # NOTE(review): unlike generate_messages, messages are sent without
        # _prepare_messages here — confirm whether that is intentional.
        if not self.api_key:
            raise RuntimeError("Moonshot API key not found")

        model = options.get("model", self.model)
        model_config = self._get_model_config("moonshot", model)

        self._validate_model_capabilities("moonshot", model, using_tool_use=True)

        data = self._build_request_body(model, model_config, messages, options, tools=tools)
        if "tool_choice" in options:
            data["tool_choice"] = options["tool_choice"]
        data = self._sanitize_tool_choice(data)

        resp = self._post_chat_completion(data)
        meta = self._build_meta(model, resp)

        choice = resp["choices"][0]
        message = choice["message"]

        tool_calls_out: list[dict[str, Any]] = []
        # "tool_calls" can be missing or null when the model answered in text.
        for tc in message.get("tool_calls") or []:
            try:
                args = json.loads(tc["function"]["arguments"])
            except (json.JSONDecodeError, TypeError):
                # Best effort: malformed or non-string arguments become {}.
                args = {}
            tool_calls_out.append(
                {
                    "id": tc["id"],
                    "name": tc["function"]["name"],
                    "arguments": args,
                }
            )

        return {
            "text": message.get("content") or "",
            "meta": meta,
            "tool_calls": tool_calls_out,
            "stop_reason": choice.get("finish_reason"),
        }

    # ------------------------------------------------------------------
    # Streaming
    # ------------------------------------------------------------------

    def generate_messages_stream(
        self,
        messages: list[dict[str, Any]],
        options: dict[str, Any],
    ) -> Iterator[dict[str, Any]]:
        """Yield response chunks via Moonshot streaming API.

        Yields ``{"type": "delta", "text": ...}`` for every non-empty content
        fragment, then one final ``{"type": "done", "text": ..., "meta": ...}``.

        Raises:
            RuntimeError: If the API key is missing.
            requests.exceptions.RequestException: Transport and HTTP errors are
                propagated unwrapped from this method.
        """
        # NOTE(review): unlike generate_messages, messages are sent without
        # _prepare_messages here — confirm whether that is intentional.
        if not self.api_key:
            raise RuntimeError("Moonshot API key not found")

        model = options.get("model", self.model)
        model_config = self._get_model_config("moonshot", model)

        data = self._build_request_body(
            model,
            model_config,
            messages,
            options,
            stream=True,
            stream_options={"include_usage": True},
        )

        text_parts: list[str] = []
        prompt_tokens = 0
        completion_tokens = 0

        # The context manager releases the connection even when the consumer
        # abandons the generator early or an exception interrupts the loop.
        with requests.post(
            f"{self.base_url}/chat/completions",
            headers=self.headers,
            json=data,
            stream=True,
            timeout=120,
        ) as response:
            response.raise_for_status()
            # Event streams are UTF-8. Without this, requests may fall back to
            # ISO-8859-1 for text/* bodies (garbling non-ASCII text) or yield
            # bytes when no encoding is known.
            response.encoding = "utf-8"

            for line in response.iter_lines(decode_unicode=True):
                if not line or not line.startswith("data: "):
                    continue
                payload = line[len("data: ") :]
                if payload.strip() == "[DONE]":
                    break
                try:
                    chunk = json.loads(payload)
                except json.JSONDecodeError:
                    # Skip undecodable events rather than aborting the stream.
                    continue

                usage = chunk.get("usage")
                if usage:
                    prompt_tokens = usage.get("prompt_tokens", 0)
                    completion_tokens = usage.get("completion_tokens", 0)

                choices = chunk.get("choices") or []
                if choices:
                    delta = choices[0].get("delta") or {}
                    content = delta.get("content", "")
                    if content:
                        text_parts.append(content)
                        yield {"type": "delta", "text": content}

        total_cost = self._calculate_cost("moonshot", model, prompt_tokens, completion_tokens)

        yield {
            "type": "done",
            "text": "".join(text_parts),
            "meta": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens,
                "cost": round(total_cost, 6),
                "raw_response": {},
                "model_name": model,
            },
        }
|