prompture-0.0.46-py3-none-any.whl → prompture-0.0.47-py3-none-any.whl
This diff shows the changes between two publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
- prompture/_version.py +2 -2
- prompture/async_conversation.py +87 -2
- prompture/conversation.py +87 -2
- prompture/drivers/async_azure_driver.py +77 -0
- prompture/drivers/async_grok_driver.py +106 -2
- prompture/drivers/async_groq_driver.py +92 -2
- prompture/drivers/async_lmstudio_driver.py +10 -2
- prompture/drivers/async_moonshot_driver.py +32 -12
- prompture/drivers/async_ollama_driver.py +85 -0
- prompture/drivers/async_openrouter_driver.py +43 -17
- prompture/drivers/azure_driver.py +77 -0
- prompture/drivers/grok_driver.py +101 -2
- prompture/drivers/groq_driver.py +92 -2
- prompture/drivers/lmstudio_driver.py +11 -2
- prompture/drivers/moonshot_driver.py +32 -12
- prompture/drivers/ollama_driver.py +91 -0
- prompture/drivers/openrouter_driver.py +34 -10
- prompture/simulated_tools.py +115 -0
- prompture/tools_schema.py +22 -0
- {prompture-0.0.46.dist-info → prompture-0.0.47.dist-info}/METADATA +35 -2
- {prompture-0.0.46.dist-info → prompture-0.0.47.dist-info}/RECORD +25 -24
- {prompture-0.0.46.dist-info → prompture-0.0.47.dist-info}/WHEEL +0 -0
- {prompture-0.0.46.dist-info → prompture-0.0.47.dist-info}/entry_points.txt +0 -0
- {prompture-0.0.46.dist-info → prompture-0.0.47.dist-info}/licenses/LICENSE +0 -0
- {prompture-0.0.46.dist-info → prompture-0.0.47.dist-info}/top_level.txt +0 -0
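
The hunks shown below fall into two groups: the chat-completions drivers now propagate `reasoning_content` (falling back to it when a reasoning model returns an empty `content`, attaching it to results, and aggregating it across stream chunks), and the Azure, Grok, Groq, and Ollama drivers gain `supports_tool_use = True` together with a `generate_messages_with_tools` method that normalizes provider tool calls to `{"id", "name", "arguments"}` dicts.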
prompture/drivers/async_lmstudio_driver.py
CHANGED

@@ -98,7 +98,12 @@ class AsyncLMStudioDriver(AsyncDriver):
         if "choices" not in response_data or not response_data["choices"]:
             raise ValueError(f"Unexpected response format: {response_data}")
 
-
+        message = response_data["choices"][0]["message"]
+        text = message.get("content") or ""
+        reasoning_content = message.get("reasoning_content")
+
+        if not text and reasoning_content:
+            text = reasoning_content
 
         usage = response_data.get("usage", {})
         prompt_tokens = usage.get("prompt_tokens", 0)
@@ -114,7 +119,10 @@ class AsyncLMStudioDriver(AsyncDriver):
             "model_name": merged_options.get("model", self.model),
         }
 
-
+        result: dict[str, Any] = {"text": text, "meta": meta}
+        if reasoning_content is not None:
+            result["reasoning_content"] = reasoning_content
+        return result
 
     # -- Model management (LM Studio 0.4.0+) ----------------------------------
 
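This fallback-and-propagate pattern repeats across the drivers in this release. A minimal standalone sketch of it (the helper name and sample response are illustrative, not prompture API):

```python
from typing import Any

def extract_text(message: dict[str, Any], meta: dict[str, Any]) -> dict[str, Any]:
    """Mirror of the pattern in the hunks above: prefer content,
    fall back to reasoning_content, and surface the latter when present."""
    text = message.get("content") or ""
    reasoning_content = message.get("reasoning_content")

    if not text and reasoning_content:
        text = reasoning_content

    result: dict[str, Any] = {"text": text, "meta": meta}
    if reasoning_content is not None:
        result["reasoning_content"] = reasoning_content
    return result

# A reasoning model that returned an empty final answer:
print(extract_text({"content": "", "reasoning_content": "step 1..."}, {}))
# {'text': 'step 1...', 'meta': {}, 'reasoning_content': 'step 1...'}
```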
prompture/drivers/async_moonshot_driver.py
CHANGED

@@ -138,10 +138,11 @@ class AsyncMoonshotDriver(CostMixin, AsyncDriver):
 
         message = resp["choices"][0]["message"]
         text = message.get("content") or ""
+        reasoning_content = message.get("reasoning_content")
 
         # Reasoning models may return content in reasoning_content when content is empty
-        if not text and
-        text =
+        if not text and reasoning_content:
+            text = reasoning_content
 
         # Structured output fallback: if we used json_schema mode and got an
         # empty response, retry with json_object mode and schema in the prompt.
@@ -184,8 +185,9 @@ class AsyncMoonshotDriver(CostMixin, AsyncDriver):
             resp = fb_resp
             fb_message = fb_resp["choices"][0]["message"]
             text = fb_message.get("content") or ""
-
-
+            reasoning_content = fb_message.get("reasoning_content")
+            if not text and reasoning_content:
+                text = reasoning_content
 
         total_cost = self._calculate_cost("moonshot", model, prompt_tokens, completion_tokens)
 
@@ -198,7 +200,10 @@ class AsyncMoonshotDriver(CostMixin, AsyncDriver):
             "model_name": model,
         }
 
-
+        result: dict[str, Any] = {"text": text, "meta": meta}
+        if reasoning_content is not None:
+            result["reasoning_content"] = reasoning_content
+        return result
 
     # ------------------------------------------------------------------
     # Tool use
@@ -271,11 +276,12 @@ class AsyncMoonshotDriver(CostMixin, AsyncDriver):
         }
 
         choice = resp["choices"][0]
-
+        message = choice["message"]
+        text = message.get("content") or ""
         stop_reason = choice.get("finish_reason")
 
         tool_calls_out: list[dict[str, Any]] = []
-        for tc in
+        for tc in message.get("tool_calls", []):
             try:
                 args = json.loads(tc["function"]["arguments"])
             except (json.JSONDecodeError, TypeError):
@@ -288,13 +294,21 @@ class AsyncMoonshotDriver(CostMixin, AsyncDriver):
                 }
             )
 
-        return {
+        result: dict[str, Any] = {
             "text": text,
             "meta": meta,
             "tool_calls": tool_calls_out,
             "stop_reason": stop_reason,
         }
 
+        # Preserve reasoning_content for reasoning models so the
+        # conversation loop can include it when sending the assistant
+        # message back (Moonshot requires it on subsequent requests).
+        if message.get("reasoning_content") is not None:
+            result["reasoning_content"] = message["reasoning_content"]
+
+        return result
+
     # ------------------------------------------------------------------
     # Streaming
     # ------------------------------------------------------------------
@@ -325,6 +339,7 @@ class AsyncMoonshotDriver(CostMixin, AsyncDriver):
             data["temperature"] = opts["temperature"]
 
         full_text = ""
+        full_reasoning = ""
        prompt_tokens = 0
        completion_tokens = 0
 
@@ -359,9 +374,11 @@ class AsyncMoonshotDriver(CostMixin, AsyncDriver):
                 if choices:
                     delta = choices[0].get("delta", {})
                     content = delta.get("content") or ""
-
-                    if
-
+                    reasoning_chunk = delta.get("reasoning_content") or ""
+                    if reasoning_chunk:
+                        full_reasoning += reasoning_chunk
+                    if not content and reasoning_chunk:
+                        content = reasoning_chunk
                     if content:
                         full_text += content
                         yield {"type": "delta", "text": content}
@@ -369,7 +386,7 @@ class AsyncMoonshotDriver(CostMixin, AsyncDriver):
         total_tokens = prompt_tokens + completion_tokens
         total_cost = self._calculate_cost("moonshot", model, prompt_tokens, completion_tokens)
 
-        yield {
+        done_chunk: dict[str, Any] = {
             "type": "done",
             "text": full_text,
             "meta": {
@@ -381,3 +398,6 @@ class AsyncMoonshotDriver(CostMixin, AsyncDriver):
                 "model_name": model,
             },
         }
+        if full_reasoning:
+            done_chunk["reasoning_content"] = full_reasoning
+        yield done_chunk
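The streaming path applies the same idea incrementally: `reasoning_content` deltas are accumulated into `full_reasoning`, substituted for empty `content` chunks, and attached to the final `done` chunk. A sketch of a consumer follows; only the chunk shapes (`type`, `text`, `meta`, optional `reasoning_content`) come from the diff, and how you obtain the stream object is an assumption:

```python
async def collect(stream):
    """Drain a prompture-style chunk stream, returning any reasoning text."""
    reasoning = None
    async for chunk in stream:
        if chunk["type"] == "delta":
            print(chunk["text"], end="", flush=True)
        elif chunk["type"] == "done":
            # Present only if the model emitted reasoning_content deltas.
            reasoning = chunk.get("reasoning_content")
            print(f"\n[{chunk['meta']['total_tokens']} tokens]")
    return reasoning
```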
prompture/drivers/async_ollama_driver.py
CHANGED

@@ -2,8 +2,10 @@
 
 from __future__ import annotations
 
+import json
 import logging
 import os
+import uuid
 from typing import Any
 
 import httpx
@@ -16,6 +18,7 @@ logger = logging.getLogger(__name__)
 class AsyncOllamaDriver(AsyncDriver):
     supports_json_mode = True
     supports_json_schema = True
+    supports_tool_use = True
     supports_vision = True
 
     MODEL_PRICING = {"default": {"prompt": 0.0, "completion": 0.0}}
@@ -80,6 +83,88 @@ class AsyncOllamaDriver(AsyncDriver):
 
         return {"text": response_data.get("response", ""), "meta": meta}
 
+    # ------------------------------------------------------------------
+    # Tool use
+    # ------------------------------------------------------------------
+
+    async def generate_messages_with_tools(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> dict[str, Any]:
+        """Generate a response that may include tool calls via Ollama's /api/chat endpoint."""
+        merged_options = self.options.copy()
+        if options:
+            merged_options.update(options)
+
+        chat_endpoint = self.endpoint.replace("/api/generate", "/api/chat")
+
+        payload: dict[str, Any] = {
+            "model": merged_options.get("model", self.model),
+            "messages": messages,
+            "tools": tools,
+            "stream": False,
+        }
+
+        if "temperature" in merged_options:
+            payload["temperature"] = merged_options["temperature"]
+        if "top_p" in merged_options:
+            payload["top_p"] = merged_options["top_p"]
+        if "top_k" in merged_options:
+            payload["top_k"] = merged_options["top_k"]
+
+        async with httpx.AsyncClient() as client:
+            try:
+                r = await client.post(chat_endpoint, json=payload, timeout=120)
+                r.raise_for_status()
+                response_data = r.json()
+            except httpx.HTTPStatusError as e:
+                raise RuntimeError(f"Ollama tool use request failed: {e}") from e
+            except Exception as e:
+                raise RuntimeError(f"Ollama tool use request failed: {e}") from e
+
+        prompt_tokens = response_data.get("prompt_eval_count", 0)
+        completion_tokens = response_data.get("eval_count", 0)
+        total_tokens = prompt_tokens + completion_tokens
+
+        meta = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": 0.0,
+            "raw_response": response_data,
+            "model_name": merged_options.get("model", self.model),
+        }
+
+        message = response_data.get("message", {})
+        text = message.get("content") or ""
+        stop_reason = response_data.get("done_reason", "stop")
+
+        tool_calls_out: list[dict[str, Any]] = []
+        for tc in message.get("tool_calls", []):
+            func = tc.get("function", {})
+            # Ollama returns arguments as a dict already (no JSON string parsing needed)
+            args = func.get("arguments", {})
+            if isinstance(args, str):
+                try:
+                    args = json.loads(args)
+                except (json.JSONDecodeError, TypeError):
+                    args = {}
+            tool_calls_out.append({
+                # Ollama does not return tool_call IDs — generate one locally
+                "id": f"call_{uuid.uuid4().hex[:24]}",
+                "name": func.get("name", ""),
+                "arguments": args,
+            })
+
+        return {
+            "text": text,
+            "meta": meta,
+            "tool_calls": tool_calls_out,
+            "stop_reason": stop_reason,
+        }
+
     async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
         """Use Ollama's /api/chat endpoint for multi-turn conversations."""
         messages = self._prepare_messages(messages)
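For illustration, a hypothetical call to the new Ollama method. The tool schema and the normalized `tool_calls` shape follow the diff; the constructor defaults and model name are assumptions:

```python
import asyncio

from prompture.drivers.async_ollama_driver import AsyncOllamaDriver

TOOLS = [{
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Look up current weather for a city.",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}]

async def main() -> None:
    driver = AsyncOllamaDriver()  # assumed: defaults point at a local Ollama
    result = await driver.generate_messages_with_tools(
        [{"role": "user", "content": "Weather in Oslo?"}],
        TOOLS,
        {"model": "llama3.1"},  # assumed model name
    )
    # Calls are normalized to {"id", "name", "arguments"}; the id is
    # synthesized locally because Ollama does not return tool-call IDs.
    for call in result["tool_calls"]:
        print(call["id"], call["name"], call["arguments"])

asyncio.run(main())
```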
prompture/drivers/async_openrouter_driver.py
CHANGED

@@ -122,8 +122,17 @@ class AsyncOpenRouterDriver(CostMixin, AsyncDriver):
             "model_name": model,
         }
 
-
-
+        message = resp["choices"][0]["message"]
+        text = message.get("content") or ""
+        reasoning_content = message.get("reasoning_content")
+
+        if not text and reasoning_content:
+            text = reasoning_content
+
+        result: dict[str, Any] = {"text": text, "meta": meta}
+        if reasoning_content is not None:
+            result["reasoning_content"] = reasoning_content
+        return result
 
     # ------------------------------------------------------------------
     # Tool use
@@ -196,18 +205,23 @@ class AsyncOpenRouterDriver(CostMixin, AsyncDriver):
                 args = json.loads(tc["function"]["arguments"])
             except (json.JSONDecodeError, TypeError):
                 args = {}
-            tool_calls_out.append(
-
-
-
-
+            tool_calls_out.append(
+                {
+                    "id": tc["id"],
+                    "name": tc["function"]["name"],
+                    "arguments": args,
+                }
+            )
 
-        return {
+        result: dict[str, Any] = {
             "text": text,
             "meta": meta,
             "tool_calls": tool_calls_out,
             "stop_reason": stop_reason,
         }
+        if choice["message"].get("reasoning_content") is not None:
+            result["reasoning_content"] = choice["message"]["reasoning_content"]
+        return result
 
     # ------------------------------------------------------------------
     # Streaming
@@ -238,21 +252,25 @@ class AsyncOpenRouterDriver(CostMixin, AsyncDriver):
             data["temperature"] = opts["temperature"]
 
         full_text = ""
+        full_reasoning = ""
         prompt_tokens = 0
         completion_tokens = 0
 
-        async with
-
-
-
-
-
-
+        async with (
+            httpx.AsyncClient() as client,
+            client.stream(
+                "POST",
+                f"{self.base_url}/chat/completions",
+                headers=self.headers,
+                json=data,
+                timeout=120,
+            ) as response,
+        ):
             response.raise_for_status()
             async for line in response.aiter_lines():
                 if not line or not line.startswith("data: "):
                     continue
-                payload = line[len("data: "):]
+                payload = line[len("data: ") :]
                 if payload.strip() == "[DONE]":
                     break
                 try:
@@ -270,6 +288,11 @@ class AsyncOpenRouterDriver(CostMixin, AsyncDriver):
                 if choices:
                     delta = choices[0].get("delta", {})
                     content = delta.get("content", "")
+                    reasoning_chunk = delta.get("reasoning_content") or ""
+                    if reasoning_chunk:
+                        full_reasoning += reasoning_chunk
+                    if not content and reasoning_chunk:
+                        content = reasoning_chunk
                     if content:
                         full_text += content
                         yield {"type": "delta", "text": content}
@@ -277,7 +300,7 @@ class AsyncOpenRouterDriver(CostMixin, AsyncDriver):
         total_tokens = prompt_tokens + completion_tokens
         total_cost = self._calculate_cost("openrouter", model, prompt_tokens, completion_tokens)
 
-        yield {
+        done_chunk: dict[str, Any] = {
             "type": "done",
             "text": full_text,
             "meta": {
@@ -289,3 +312,6 @@ class AsyncOpenRouterDriver(CostMixin, AsyncDriver):
                 "model_name": model,
             },
         }
+        if full_reasoning:
+            done_chunk["reasoning_content"] = full_reasoning
+        yield done_chunk
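Beyond the reasoning changes, the OpenRouter stream setup is restructured into a single parenthesized `async with` holding both context managers, a form that requires Python 3.10+. A stripped-down sketch of the same pattern (the URL, headers, and payload here are placeholders, not prompture API):

```python
import httpx

async def sse_payloads(url: str, headers: dict, data: dict):
    # One grouped async-with keeps the client and stream lifetimes aligned.
    async with (
        httpx.AsyncClient() as client,
        client.stream("POST", url, headers=headers, json=data, timeout=120) as response,
    ):
        response.raise_for_status()
        async for line in response.aiter_lines():
            if line.startswith("data: "):
                yield line[len("data: ") :]  # strip the SSE prefix
```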
prompture/drivers/azure_driver.py
CHANGED

@@ -2,6 +2,7 @@
 Requires the `openai` package.
 """
 
+import json
 import os
 from typing import Any
 
@@ -17,6 +18,7 @@ from ..driver import Driver
 class AzureDriver(CostMixin, Driver):
     supports_json_mode = True
     supports_json_schema = True
+    supports_tool_use = True
     supports_vision = True
 
     # Pricing per 1K tokens (adjust if your Azure pricing differs from OpenAI defaults)
@@ -164,3 +166,78 @@ class AzureDriver(CostMixin, Driver):
 
         text = resp.choices[0].message.content
         return {"text": text, "meta": meta}
+
+    # ------------------------------------------------------------------
+    # Tool use
+    # ------------------------------------------------------------------
+
+    def generate_messages_with_tools(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> dict[str, Any]:
+        """Generate a response that may include tool calls."""
+        if self.client is None:
+            raise RuntimeError("openai package (>=1.0.0) with AzureOpenAI not installed")
+
+        model = options.get("model", self.model)
+        model_config = self._get_model_config("azure", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
+
+        self._validate_model_capabilities("azure", model, using_tool_use=True)
+
+        opts = {"temperature": 1.0, "max_tokens": 512, **options}
+
+        kwargs: dict[str, Any] = {
+            "model": self.deployment_id,
+            "messages": messages,
+            "tools": tools,
+        }
+        kwargs[tokens_param] = opts.get("max_tokens", 512)
+
+        if supports_temperature and "temperature" in opts:
+            kwargs["temperature"] = opts["temperature"]
+
+        resp = self.client.chat.completions.create(**kwargs)
+
+        usage = getattr(resp, "usage", None)
+        prompt_tokens = getattr(usage, "prompt_tokens", 0)
+        completion_tokens = getattr(usage, "completion_tokens", 0)
+        total_tokens = getattr(usage, "total_tokens", 0)
+        total_cost = self._calculate_cost("azure", model, prompt_tokens, completion_tokens)
+
+        meta = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": round(total_cost, 6),
+            "raw_response": resp.model_dump(),
+            "model_name": model,
+            "deployment_id": self.deployment_id,
+        }
+
+        choice = resp.choices[0]
+        text = choice.message.content or ""
+        stop_reason = choice.finish_reason
+
+        tool_calls_out: list[dict[str, Any]] = []
+        if choice.message.tool_calls:
+            for tc in choice.message.tool_calls:
+                try:
+                    args = json.loads(tc.function.arguments)
+                except (json.JSONDecodeError, TypeError):
+                    args = {}
+                tool_calls_out.append({
+                    "id": tc.id,
+                    "name": tc.function.name,
+                    "arguments": args,
+                })
+
+        return {
+            "text": text,
+            "meta": meta,
+            "tool_calls": tool_calls_out,
+            "stop_reason": stop_reason,
+        }
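A hypothetical round trip with the new Azure method. The result keys (`text`, `tool_calls`, `stop_reason`) come from the diff; the driver construction, the assistant-turn bookkeeping, and the `role="tool"` reply format (standard OpenAI chat convention) are assumptions:

```python
import json

from prompture.drivers.azure_driver import AzureDriver

TOOLS = [{
    "type": "function",
    "function": {
        "name": "add",
        "description": "Add two integers.",
        "parameters": {
            "type": "object",
            "properties": {"a": {"type": "integer"}, "b": {"type": "integer"}},
            "required": ["a", "b"],
        },
    },
}]

driver = AzureDriver()  # assumed: endpoint, key, and deployment configured elsewhere
messages = [{"role": "user", "content": "What is 2 + 2? Use the add tool."}]

result = driver.generate_messages_with_tools(messages, TOOLS, {})
if result["stop_reason"] == "tool_calls":
    for call in result["tool_calls"]:
        answer = call["arguments"]["a"] + call["arguments"]["b"]
        # A full round trip would also re-append the assistant turn with its
        # tool_calls before this tool message (OpenAI chat convention).
        messages.append({
            "role": "tool",
            "tool_call_id": call["id"],
            "content": json.dumps({"result": answer}),
        })
```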
prompture/drivers/grok_driver.py
CHANGED

@@ -2,6 +2,7 @@
 Requires the `requests` package. Uses GROK_API_KEY env var.
 """
 
+import json
 import os
 from typing import Any
 
@@ -13,6 +14,7 @@ from ..driver import Driver
 
 class GrokDriver(CostMixin, Driver):
     supports_json_mode = True
+    supports_tool_use = True
     supports_vision = True
 
     # Pricing per 1M tokens based on xAI's documentation
@@ -152,5 +154,102 @@ class GrokDriver(CostMixin, Driver):
             "model_name": model,
         }
 
-
-
+        message = resp["choices"][0]["message"]
+        text = message.get("content") or ""
+        reasoning_content = message.get("reasoning_content")
+
+        if not text and reasoning_content:
+            text = reasoning_content
+
+        result: dict[str, Any] = {"text": text, "meta": meta}
+        if reasoning_content is not None:
+            result["reasoning_content"] = reasoning_content
+        return result
+
+    # ------------------------------------------------------------------
+    # Tool use
+    # ------------------------------------------------------------------
+
+    def generate_messages_with_tools(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> dict[str, Any]:
+        """Generate a response that may include tool calls."""
+        if not self.api_key:
+            raise RuntimeError("GROK_API_KEY environment variable is required")
+
+        model = options.get("model", self.model)
+        model_config = self._get_model_config("grok", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
+
+        self._validate_model_capabilities("grok", model, using_tool_use=True)
+
+        opts = {"temperature": 1.0, "max_tokens": 512, **options}
+
+        payload: dict[str, Any] = {
+            "model": model,
+            "messages": messages,
+            "tools": tools,
+        }
+        payload[tokens_param] = opts.get("max_tokens", 512)
+
+        if supports_temperature and "temperature" in opts:
+            payload["temperature"] = opts["temperature"]
+
+        if "tool_choice" in options:
+            payload["tool_choice"] = options["tool_choice"]
+
+        headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
+
+        try:
+            response = requests.post(f"{self.api_base}/chat/completions", headers=headers, json=payload)
+            response.raise_for_status()
+            resp = response.json()
+        except requests.exceptions.RequestException as e:
+            raise RuntimeError(f"Grok API request failed: {e!s}") from e
+
+        usage = resp.get("usage", {})
+        prompt_tokens = usage.get("prompt_tokens", 0)
+        completion_tokens = usage.get("completion_tokens", 0)
+        total_tokens = usage.get("total_tokens", 0)
+        total_cost = self._calculate_cost("grok", model, prompt_tokens, completion_tokens)
+
+        meta = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": round(total_cost, 6),
+            "raw_response": resp,
+            "model_name": model,
+        }
+
+        choice = resp["choices"][0]
+        text = choice["message"].get("content") or ""
+        stop_reason = choice.get("finish_reason")
+
+        tool_calls_out: list[dict[str, Any]] = []
+        for tc in choice["message"].get("tool_calls", []):
+            try:
+                args = json.loads(tc["function"]["arguments"])
+            except (json.JSONDecodeError, TypeError):
+                args = {}
+            tool_calls_out.append(
+                {
+                    "id": tc["id"],
+                    "name": tc["function"]["name"],
+                    "arguments": args,
+                }
+            )
+
+        result: dict[str, Any] = {
+            "text": text,
+            "meta": meta,
+            "tool_calls": tool_calls_out,
+            "stop_reason": stop_reason,
+        }
+        if choice["message"].get("reasoning_content") is not None:
+            result["reasoning_content"] = choice["message"]["reasoning_content"]
+        return result
prompture/drivers/groq_driver.py
CHANGED

@@ -2,6 +2,7 @@
 Requires the `groq` package. Uses GROQ_API_KEY env var.
 """
 
+import json
 import os
 from typing import Any
 
@@ -16,6 +17,7 @@ from ..driver import Driver
 
 class GroqDriver(CostMixin, Driver):
     supports_json_mode = True
+    supports_tool_use = True
     supports_vision = True
 
     # Approximate pricing per 1K tokens (to be updated with official pricing)
@@ -120,5 +122,93 @@ class GroqDriver(CostMixin, Driver):
         }
 
         # Extract generated text
-        text = resp.choices[0].message.content
-
+        text = resp.choices[0].message.content or ""
+        reasoning_content = getattr(resp.choices[0].message, "reasoning_content", None)
+
+        if not text and reasoning_content:
+            text = reasoning_content
+
+        result: dict[str, Any] = {"text": text, "meta": meta}
+        if reasoning_content is not None:
+            result["reasoning_content"] = reasoning_content
+        return result
+
+    # ------------------------------------------------------------------
+    # Tool use
+    # ------------------------------------------------------------------
+
+    def generate_messages_with_tools(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> dict[str, Any]:
+        """Generate a response that may include tool calls."""
+        if self.client is None:
+            raise RuntimeError("groq package is not installed")
+
+        model = options.get("model", self.model)
+        model_config = self._get_model_config("groq", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
+
+        self._validate_model_capabilities("groq", model, using_tool_use=True)
+
+        opts = {"temperature": 0.7, "max_tokens": 512, **options}
+
+        kwargs: dict[str, Any] = {
+            "model": model,
+            "messages": messages,
+            "tools": tools,
+        }
+        kwargs[tokens_param] = opts.get("max_tokens", 512)
+
+        if supports_temperature and "temperature" in opts:
+            kwargs["temperature"] = opts["temperature"]
+
+        resp = self.client.chat.completions.create(**kwargs)
+
+        usage = getattr(resp, "usage", None)
+        prompt_tokens = getattr(usage, "prompt_tokens", 0)
+        completion_tokens = getattr(usage, "completion_tokens", 0)
+        total_tokens = getattr(usage, "total_tokens", 0)
+        total_cost = self._calculate_cost("groq", model, prompt_tokens, completion_tokens)
+
+        meta = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": round(total_cost, 6),
+            "raw_response": resp.model_dump(),
+            "model_name": model,
+        }
+
+        choice = resp.choices[0]
+        text = choice.message.content or ""
+        stop_reason = choice.finish_reason
+
+        tool_calls_out: list[dict[str, Any]] = []
+        if choice.message.tool_calls:
+            for tc in choice.message.tool_calls:
+                try:
+                    args = json.loads(tc.function.arguments)
+                except (json.JSONDecodeError, TypeError):
+                    args = {}
+                tool_calls_out.append(
+                    {
+                        "id": tc.id,
+                        "name": tc.function.name,
+                        "arguments": args,
+                    }
+                )
+
+        result: dict[str, Any] = {
+            "text": text,
+            "meta": meta,
+            "tool_calls": tool_calls_out,
+            "stop_reason": stop_reason,
+        }
+        reasoning_content = getattr(choice.message, "reasoning_content", None)
+        if reasoning_content is not None:
+            result["reasoning_content"] = reasoning_content
+        return result
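One detail specific to this driver: because the `groq` SDK returns message objects rather than dicts, the new code probes with `getattr(..., "reasoning_content", None)` instead of `.get(...)`, which stays safe on SDK versions whose message type lacks the attribute. A toy illustration (the class is a stand-in, not the real SDK type):

```python
class Message:
    """Object-style response message with no reasoning_content attribute."""
    content = "hello"

msg = Message()
print(getattr(msg, "reasoning_content", None))  # None, no AttributeError
```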