prompture-0.0.46.dev1-py3-none-any.whl → prompture-0.0.47-py3-none-any.whl
This diff compares publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between versions as they appear in the public registry.
- prompture/_version.py +2 -2
- prompture/async_conversation.py +87 -2
- prompture/conversation.py +87 -2
- prompture/drivers/async_azure_driver.py +77 -0
- prompture/drivers/async_grok_driver.py +106 -2
- prompture/drivers/async_groq_driver.py +92 -2
- prompture/drivers/async_lmstudio_driver.py +10 -2
- prompture/drivers/async_moonshot_driver.py +32 -12
- prompture/drivers/async_ollama_driver.py +85 -0
- prompture/drivers/async_openrouter_driver.py +43 -17
- prompture/drivers/azure_driver.py +77 -0
- prompture/drivers/grok_driver.py +101 -2
- prompture/drivers/groq_driver.py +92 -2
- prompture/drivers/lmstudio_driver.py +11 -2
- prompture/drivers/moonshot_driver.py +32 -12
- prompture/drivers/ollama_driver.py +91 -0
- prompture/drivers/openrouter_driver.py +34 -10
- prompture/simulated_tools.py +115 -0
- prompture/tools_schema.py +22 -0
- {prompture-0.0.46.dev1.dist-info → prompture-0.0.47.dist-info}/METADATA +35 -2
- {prompture-0.0.46.dev1.dist-info → prompture-0.0.47.dist-info}/RECORD +25 -24
- {prompture-0.0.46.dev1.dist-info → prompture-0.0.47.dist-info}/WHEEL +0 -0
- {prompture-0.0.46.dev1.dist-info → prompture-0.0.47.dist-info}/entry_points.txt +0 -0
- {prompture-0.0.46.dev1.dist-info → prompture-0.0.47.dist-info}/licenses/LICENSE +0 -0
- {prompture-0.0.46.dev1.dist-info → prompture-0.0.47.dist-info}/top_level.txt +0 -0
prompture/_version.py
CHANGED
```diff
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID
 
-__version__ = version = '0.0.46.dev1'
-__version_tuple__ = version_tuple = (0, 0, 46, 'dev1')
+__version__ = version = '0.0.47'
+__version_tuple__ = version_tuple = (0, 0, 47)
 
 __commit_id__ = commit_id = None
```
prompture/async_conversation.py
CHANGED
```diff
@@ -55,6 +55,7 @@ class AsyncConversation:
         callbacks: DriverCallbacks | None = None,
         tools: ToolRegistry | None = None,
         max_tool_rounds: int = 10,
+        simulated_tools: bool | Literal["auto"] = "auto",
         conversation_id: str | None = None,
         auto_save: str | Path | None = None,
         tags: list[str] | None = None,
@@ -106,6 +107,10 @@ class AsyncConversation:
         }
         self._tools = tools or ToolRegistry()
        self._max_tool_rounds = max_tool_rounds
+        self._simulated_tools = simulated_tools
+
+        # Reasoning content from last response
+        self._last_reasoning: str | None = None
 
         # Persistence
         self._conversation_id = conversation_id or str(uuid.uuid4())
@@ -119,6 +124,11 @@ class AsyncConversation:
     # Public helpers
     # ------------------------------------------------------------------
 
+    @property
+    def last_reasoning(self) -> str | None:
+        """The reasoning/thinking content from the last LLM response, if any."""
+        return self._last_reasoning
+
     @property
     def messages(self) -> list[dict[str, Any]]:
         """Read-only view of the conversation history."""
@@ -324,8 +334,15 @@ class AsyncConversation:
         If tools are registered and the driver supports tool use,
         dispatches to the async tool execution loop.
         """
-
-
+        self._last_reasoning = None
+
+        # Route to appropriate tool handling
+        if self._tools:
+            use_native = getattr(self._driver, "supports_tool_use", False)
+            if self._simulated_tools is True or (self._simulated_tools == "auto" and not use_native):
+                return await self._ask_with_simulated_tools(content, options, images=images)
+            elif use_native and self._simulated_tools is not True:
+                return await self._ask_with_tools(content, options, images=images)
 
         merged = {**self._options, **(options or {})}
         messages = self._build_messages(content, images=images)
@@ -333,6 +350,7 @@ class AsyncConversation:
 
         text = resp.get("text", "")
         meta = resp.get("meta", {})
+        self._last_reasoning = resp.get("reasoning_content")
 
         user_content = self._build_content_with_images(content, images)
         self._messages.append({"role": "user", "content": user_content})
@@ -365,6 +383,7 @@ class AsyncConversation:
             text = resp.get("text", "")
 
             if not tool_calls:
+                self._last_reasoning = resp.get("reasoning_content")
                 self._messages.append({"role": "assistant", "content": text})
                 return text
 
@@ -377,6 +396,11 @@ class AsyncConversation:
                 }
                 for tc in tool_calls
             ]
+            # Preserve reasoning_content for providers that require it
+            # on subsequent requests (e.g. Moonshot reasoning models).
+            if resp.get("reasoning_content") is not None:
+                assistant_msg["reasoning_content"] = resp["reasoning_content"]
+
             self._messages.append(assistant_msg)
             msgs.append(assistant_msg)
 
@@ -397,6 +421,63 @@ class AsyncConversation:
 
         raise RuntimeError(f"Tool execution loop exceeded {self._max_tool_rounds} rounds")
 
+    async def _ask_with_simulated_tools(
+        self,
+        content: str,
+        options: dict[str, Any] | None = None,
+        images: list[ImageInput] | None = None,
+    ) -> str:
+        """Async prompt-based tool calling for drivers without native tool use."""
+        from .simulated_tools import build_tool_prompt, format_tool_result, parse_simulated_response
+
+        merged = {**self._options, **(options or {})}
+        tool_prompt = build_tool_prompt(self._tools)
+
+        # Augment system prompt with tool descriptions
+        augmented_system = tool_prompt
+        if self._system_prompt:
+            augmented_system = f"{self._system_prompt}\n\n{tool_prompt}"
+
+        # Record user message in history
+        user_content = self._build_content_with_images(content, images)
+        self._messages.append({"role": "user", "content": user_content})
+
+        for _round in range(self._max_tool_rounds):
+            # Build messages with the augmented system prompt
+            msgs: list[dict[str, Any]] = []
+            msgs.append({"role": "system", "content": augmented_system})
+            msgs.extend(self._messages)
+
+            resp = await self._driver.generate_messages_with_hooks(msgs, merged)
+            text = resp.get("text", "")
+            meta = resp.get("meta", {})
+            self._accumulate_usage(meta)
+
+            parsed = parse_simulated_response(text, self._tools)
+
+            if parsed["type"] == "final_answer":
+                answer = parsed["content"]
+                self._messages.append({"role": "assistant", "content": answer})
+                return answer
+
+            # Tool call
+            tool_name = parsed["name"]
+            tool_args = parsed["arguments"]
+
+            # Record assistant's tool call as an assistant message
+            self._messages.append({"role": "assistant", "content": text})
+
+            try:
+                result = self._tools.execute(tool_name, tool_args)
+                result_msg = format_tool_result(tool_name, result)
+            except Exception as exc:
+                result_msg = format_tool_result(tool_name, f"Error: {exc}")
+
+            # Record tool result as a user message
+            self._messages.append({"role": "user", "content": result_msg})
+
+        raise RuntimeError(f"Simulated tool execution loop exceeded {self._max_tool_rounds} rounds")
+
     def _build_messages_raw(self) -> list[dict[str, Any]]:
         """Build messages array from system prompt + full history (including tool messages)."""
         msgs: list[dict[str, Any]] = []
@@ -457,6 +538,8 @@ class AsyncConversation:
         images: list[ImageInput] | None = None,
     ) -> dict[str, Any]:
         """Send a message with schema enforcement and get structured JSON back (async)."""
+        self._last_reasoning = None
+
         merged = {**self._options, **(options or {})}
 
         schema_string = json.dumps(json_schema, indent=2)
@@ -494,6 +577,7 @@ class AsyncConversation:
 
         text = resp.get("text", "")
         meta = resp.get("meta", {})
+        self._last_reasoning = resp.get("reasoning_content")
 
         user_content = self._build_content_with_images(content, images)
         self._messages.append({"role": "user", "content": user_content})
@@ -528,6 +612,7 @@ class AsyncConversation:
             "json_object": json_obj,
             "usage": usage,
             "output_format": output_format,
+            "reasoning": self._last_reasoning,
         }
 
         if output_format == "toon":
```
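The `AsyncConversation` changes add an opt-in (or auto-detected) simulated-tools mode and surface the model's reasoning output. A minimal usage sketch, assuming only the names visible in this diff (`simulated_tools`, `last_reasoning`, `ask`); the import path, the constructor keywords other than those shown, and the driver/tool setup are illustrative assumptions, not the package's documented API:

```python
from prompture import AsyncConversation, ToolRegistry  # import path assumed


async def main(driver) -> None:
    # `driver` is any configured async driver instance; construction omitted.
    tools = ToolRegistry()
    # ... register tools on `tools` ...

    conv = AsyncConversation(
        driver=driver,            # keyword name assumed
        tools=tools,
        simulated_tools="auto",   # new in 0.0.47: True | False | "auto"
    )

    # With "auto", the prompt-based (simulated) loop is used only when the
    # driver does not advertise supports_tool_use; otherwise the native
    # tool-calling loop runs.
    answer = await conv.ask("Summarize today's sales using the report tool.")
    print(answer)

    # Also new in 0.0.47: reasoning/thinking content from the last response,
    # if the provider returned any (e.g. Grok/Groq/Moonshot reasoning models).
    if conv.last_reasoning:
        print("reasoning:", conv.last_reasoning)
```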
prompture/conversation.py
CHANGED
```diff
@@ -56,6 +56,7 @@ class Conversation:
         callbacks: DriverCallbacks | None = None,
         tools: ToolRegistry | None = None,
         max_tool_rounds: int = 10,
+        simulated_tools: bool | Literal["auto"] = "auto",
         conversation_id: str | None = None,
         auto_save: str | Path | None = None,
         tags: list[str] | None = None,
@@ -109,6 +110,10 @@ class Conversation:
         }
         self._tools = tools or ToolRegistry()
         self._max_tool_rounds = max_tool_rounds
+        self._simulated_tools = simulated_tools
+
+        # Reasoning content from last response
+        self._last_reasoning: str | None = None
 
         # Persistence
         self._conversation_id = conversation_id or str(uuid.uuid4())
@@ -122,6 +127,11 @@ class Conversation:
     # Public helpers
     # ------------------------------------------------------------------
 
+    @property
+    def last_reasoning(self) -> str | None:
+        """The reasoning/thinking content from the last LLM response, if any."""
+        return self._last_reasoning
+
     @property
     def messages(self) -> list[dict[str, Any]]:
         """Read-only view of the conversation history."""
@@ -338,8 +348,15 @@ class Conversation:
             images: Optional list of images to include (bytes, path, URL,
                 base64 string, or :class:`ImageContent`).
         """
-
-
+        self._last_reasoning = None
+
+        # Route to appropriate tool handling
+        if self._tools:
+            use_native = getattr(self._driver, "supports_tool_use", False)
+            if self._simulated_tools is True or (self._simulated_tools == "auto" and not use_native):
+                return self._ask_with_simulated_tools(content, options, images=images)
+            elif use_native and self._simulated_tools is not True:
+                return self._ask_with_tools(content, options, images=images)
 
         merged = {**self._options, **(options or {})}
         messages = self._build_messages(content, images=images)
@@ -347,6 +364,7 @@ class Conversation:
 
         text = resp.get("text", "")
         meta = resp.get("meta", {})
+        self._last_reasoning = resp.get("reasoning_content")
 
         # Record in history — store content with images for context
         user_content = self._build_content_with_images(content, images)
@@ -382,6 +400,7 @@ class Conversation:
 
             if not tool_calls:
                 # No tool calls -> final response
+                self._last_reasoning = resp.get("reasoning_content")
                 self._messages.append({"role": "assistant", "content": text})
                 return text
 
@@ -395,6 +414,11 @@ class Conversation:
                 }
                 for tc in tool_calls
             ]
+            # Preserve reasoning_content for providers that require it
+            # on subsequent requests (e.g. Moonshot reasoning models).
+            if resp.get("reasoning_content") is not None:
+                assistant_msg["reasoning_content"] = resp["reasoning_content"]
+
             self._messages.append(assistant_msg)
             msgs.append(assistant_msg)
 
@@ -416,6 +440,63 @@ class Conversation:
 
         raise RuntimeError(f"Tool execution loop exceeded {self._max_tool_rounds} rounds")
 
+    def _ask_with_simulated_tools(
+        self,
+        content: str,
+        options: dict[str, Any] | None = None,
+        images: list[ImageInput] | None = None,
+    ) -> str:
+        """Prompt-based tool calling for drivers without native tool use."""
+        from .simulated_tools import build_tool_prompt, format_tool_result, parse_simulated_response
+
+        merged = {**self._options, **(options or {})}
+        tool_prompt = build_tool_prompt(self._tools)
+
+        # Augment system prompt with tool descriptions
+        augmented_system = tool_prompt
+        if self._system_prompt:
+            augmented_system = f"{self._system_prompt}\n\n{tool_prompt}"
+
+        # Record user message in history
+        user_content = self._build_content_with_images(content, images)
+        self._messages.append({"role": "user", "content": user_content})
+
+        for _round in range(self._max_tool_rounds):
+            # Build messages with the augmented system prompt
+            msgs: list[dict[str, Any]] = []
+            msgs.append({"role": "system", "content": augmented_system})
+            msgs.extend(self._messages)
+
+            resp = self._driver.generate_messages_with_hooks(msgs, merged)
+            text = resp.get("text", "")
+            meta = resp.get("meta", {})
+            self._accumulate_usage(meta)
+
+            parsed = parse_simulated_response(text, self._tools)
+
+            if parsed["type"] == "final_answer":
+                answer = parsed["content"]
+                self._messages.append({"role": "assistant", "content": answer})
+                return answer
+
+            # Tool call
+            tool_name = parsed["name"]
+            tool_args = parsed["arguments"]
+
+            # Record assistant's tool call as an assistant message
+            self._messages.append({"role": "assistant", "content": text})
+
+            try:
+                result = self._tools.execute(tool_name, tool_args)
+                result_msg = format_tool_result(tool_name, result)
+            except Exception as exc:
+                result_msg = format_tool_result(tool_name, f"Error: {exc}")
+
+            # Record tool result as a user message (all drivers understand user/assistant)
+            self._messages.append({"role": "user", "content": result_msg})
+
+        raise RuntimeError(f"Simulated tool execution loop exceeded {self._max_tool_rounds} rounds")
+
     def _build_messages_raw(self) -> list[dict[str, Any]]:
         """Build messages array from system prompt + full history (including tool messages)."""
         msgs: list[dict[str, Any]] = []
@@ -484,6 +565,8 @@ class Conversation:
         context clean for subsequent turns.
         """
 
+        self._last_reasoning = None
+
         merged = {**self._options, **(options or {})}
 
         # Build the full prompt with schema instructions inline (handled by ask_for_json)
@@ -525,6 +608,7 @@ class Conversation:
 
         text = resp.get("text", "")
         meta = resp.get("meta", {})
+        self._last_reasoning = resp.get("reasoning_content")
 
         # Store original content (without schema boilerplate) for cleaner context
         # Include images in history so subsequent turns can reference them
@@ -563,6 +647,7 @@ class Conversation:
             "json_object": json_obj,
             "usage": usage,
             "output_format": output_format,
+            "reasoning": self._last_reasoning,
         }
 
         if output_format == "toon":
```
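Both `Conversation.ask()` and `AsyncConversation.ask()` now share the same dispatch. Restated below as a standalone sketch so the three `simulated_tools` settings are explicit; `route` is a hypothetical helper, and `supports_tool_use` is the driver capability flag the diff reads via `getattr`:

```python
from typing import Any, Literal


def route(simulated_tools: bool | Literal["auto"], driver: Any) -> str:
    """Mirrors the dispatch added to ask() when tools are registered."""
    use_native = getattr(driver, "supports_tool_use", False)
    if simulated_tools is True or (simulated_tools == "auto" and not use_native):
        return "simulated"  # prompt-based loop (_ask_with_simulated_tools)
    elif use_native and simulated_tools is not True:
        return "native"     # driver tool-calling loop (_ask_with_tools)
    return "plain"          # simulated_tools=False on a driver without tool use


class NativeDriver:
    supports_tool_use = True


class PlainDriver:
    pass


assert route("auto", NativeDriver()) == "native"    # prefer native when available
assert route("auto", PlainDriver()) == "simulated"  # fall back to prompting
assert route(True, NativeDriver()) == "simulated"   # force simulation
assert route(False, PlainDriver()) == "plain"       # tool handling disabled
```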
prompture/drivers/async_azure_driver.py
CHANGED
```diff
@@ -2,6 +2,7 @@
 
 from __future__ import annotations
 
+import json
 import os
 from typing import Any
 
@@ -18,6 +19,7 @@ from .azure_driver import AzureDriver
 class AsyncAzureDriver(CostMixin, AsyncDriver):
     supports_json_mode = True
     supports_json_schema = True
+    supports_tool_use = True
     supports_vision = True
 
     MODEL_PRICING = AzureDriver.MODEL_PRICING
@@ -122,3 +124,78 @@ class AsyncAzureDriver(CostMixin, AsyncDriver):
 
         text = resp.choices[0].message.content
         return {"text": text, "meta": meta}
+
+    # ------------------------------------------------------------------
+    # Tool use
+    # ------------------------------------------------------------------
+
+    async def generate_messages_with_tools(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> dict[str, Any]:
+        """Generate a response that may include tool calls."""
+        if self.client is None:
+            raise RuntimeError("openai package (>=1.0.0) with AsyncAzureOpenAI not installed")
+
+        model = options.get("model", self.model)
+        model_config = self._get_model_config("azure", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
+
+        self._validate_model_capabilities("azure", model, using_tool_use=True)
+
+        opts = {"temperature": 1.0, "max_tokens": 512, **options}
+
+        kwargs: dict[str, Any] = {
+            "model": self.deployment_id,
+            "messages": messages,
+            "tools": tools,
+        }
+        kwargs[tokens_param] = opts.get("max_tokens", 512)
+
+        if supports_temperature and "temperature" in opts:
+            kwargs["temperature"] = opts["temperature"]
+
+        resp = await self.client.chat.completions.create(**kwargs)
+
+        usage = getattr(resp, "usage", None)
+        prompt_tokens = getattr(usage, "prompt_tokens", 0)
+        completion_tokens = getattr(usage, "completion_tokens", 0)
+        total_tokens = getattr(usage, "total_tokens", 0)
+        total_cost = self._calculate_cost("azure", model, prompt_tokens, completion_tokens)
+
+        meta = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": round(total_cost, 6),
+            "raw_response": resp.model_dump(),
+            "model_name": model,
+            "deployment_id": self.deployment_id,
+        }
+
+        choice = resp.choices[0]
+        text = choice.message.content or ""
+        stop_reason = choice.finish_reason
+
+        tool_calls_out: list[dict[str, Any]] = []
+        if choice.message.tool_calls:
+            for tc in choice.message.tool_calls:
+                try:
+                    args = json.loads(tc.function.arguments)
+                except (json.JSONDecodeError, TypeError):
+                    args = {}
+                tool_calls_out.append({
+                    "id": tc.id,
+                    "name": tc.function.name,
+                    "arguments": args,
+                })
+
+        return {
+            "text": text,
+            "meta": meta,
+            "tool_calls": tool_calls_out,
+            "stop_reason": stop_reason,
+        }
```
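All of the `generate_messages_with_tools` implementations added in this release normalize provider responses into the same plain-dict shape, which keeps the conversation layer's tool loop provider-agnostic. A sketch of what a caller might receive; the values are invented, only the keys come from the diff above:

```python
# Hypothetical return value of generate_messages_with_tools(); values made up.
resp = {
    "text": "",  # assistant text; often empty when the model is calling tools
    "meta": {
        "prompt_tokens": 42,
        "completion_tokens": 17,
        "total_tokens": 59,
        "cost": 0.000118,
        "model_name": "example-model",
        # The Azure driver also includes "deployment_id"; all drivers include
        # "raw_response" (elided here).
    },
    "tool_calls": [
        {"id": "call_abc123", "name": "get_weather", "arguments": {"city": "Oslo"}},
    ],
    "stop_reason": "tool_calls",
}

for call in resp["tool_calls"]:
    # "arguments" is already json.loads()-decoded; malformed JSON degrades to {}.
    print(call["id"], call["name"], call["arguments"])
```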
@@ -2,6 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
|
+
import json
|
|
5
6
|
import os
|
|
6
7
|
from typing import Any
|
|
7
8
|
|
|
@@ -14,6 +15,7 @@ from .grok_driver import GrokDriver
|
|
|
14
15
|
|
|
15
16
|
class AsyncGrokDriver(CostMixin, AsyncDriver):
|
|
16
17
|
supports_json_mode = True
|
|
18
|
+
supports_tool_use = True
|
|
17
19
|
supports_vision = True
|
|
18
20
|
|
|
19
21
|
MODEL_PRICING = GrokDriver.MODEL_PRICING
|
|
@@ -93,5 +95,107 @@ class AsyncGrokDriver(CostMixin, AsyncDriver):
|
|
|
93
95
|
"model_name": model,
|
|
94
96
|
}
|
|
95
97
|
|
|
96
|
-
|
|
97
|
-
|
|
98
|
+
message = resp["choices"][0]["message"]
|
|
99
|
+
text = message.get("content") or ""
|
|
100
|
+
reasoning_content = message.get("reasoning_content")
|
|
101
|
+
|
|
102
|
+
if not text and reasoning_content:
|
|
103
|
+
text = reasoning_content
|
|
104
|
+
|
|
105
|
+
result: dict[str, Any] = {"text": text, "meta": meta}
|
|
106
|
+
if reasoning_content is not None:
|
|
107
|
+
result["reasoning_content"] = reasoning_content
|
|
108
|
+
return result
|
|
109
|
+
|
|
110
|
+
# ------------------------------------------------------------------
|
|
111
|
+
# Tool use
|
|
112
|
+
# ------------------------------------------------------------------
|
|
113
|
+
|
|
114
|
+
async def generate_messages_with_tools(
|
|
115
|
+
self,
|
|
116
|
+
messages: list[dict[str, Any]],
|
|
117
|
+
tools: list[dict[str, Any]],
|
|
118
|
+
options: dict[str, Any],
|
|
119
|
+
) -> dict[str, Any]:
|
|
120
|
+
"""Generate a response that may include tool calls."""
|
|
121
|
+
if not self.api_key:
|
|
122
|
+
raise RuntimeError("GROK_API_KEY environment variable is required")
|
|
123
|
+
|
|
124
|
+
model = options.get("model", self.model)
|
|
125
|
+
model_config = self._get_model_config("grok", model)
|
|
126
|
+
tokens_param = model_config["tokens_param"]
|
|
127
|
+
supports_temperature = model_config["supports_temperature"]
|
|
128
|
+
|
|
129
|
+
self._validate_model_capabilities("grok", model, using_tool_use=True)
|
|
130
|
+
|
|
131
|
+
opts = {"temperature": 1.0, "max_tokens": 512, **options}
|
|
132
|
+
|
|
133
|
+
payload: dict[str, Any] = {
|
|
134
|
+
"model": model,
|
|
135
|
+
"messages": messages,
|
|
136
|
+
"tools": tools,
|
|
137
|
+
}
|
|
138
|
+
payload[tokens_param] = opts.get("max_tokens", 512)
|
|
139
|
+
|
|
140
|
+
if supports_temperature and "temperature" in opts:
|
|
141
|
+
payload["temperature"] = opts["temperature"]
|
|
142
|
+
|
|
143
|
+
if "tool_choice" in options:
|
|
144
|
+
payload["tool_choice"] = options["tool_choice"]
|
|
145
|
+
|
|
146
|
+
headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
|
|
147
|
+
|
|
148
|
+
async with httpx.AsyncClient() as client:
|
|
149
|
+
try:
|
|
150
|
+
response = await client.post(
|
|
151
|
+
f"{self.api_base}/chat/completions", headers=headers, json=payload, timeout=120
|
|
152
|
+
)
|
|
153
|
+
response.raise_for_status()
|
|
154
|
+
resp = response.json()
|
|
155
|
+
except httpx.HTTPStatusError as e:
|
|
156
|
+
raise RuntimeError(f"Grok API request failed: {e!s}") from e
|
|
157
|
+
except Exception as e:
|
|
158
|
+
raise RuntimeError(f"Grok API request failed: {e!s}") from e
|
|
159
|
+
|
|
160
|
+
usage = resp.get("usage", {})
|
|
161
|
+
prompt_tokens = usage.get("prompt_tokens", 0)
|
|
162
|
+
completion_tokens = usage.get("completion_tokens", 0)
|
|
163
|
+
total_tokens = usage.get("total_tokens", 0)
|
|
164
|
+
total_cost = self._calculate_cost("grok", model, prompt_tokens, completion_tokens)
|
|
165
|
+
|
|
166
|
+
meta = {
|
|
167
|
+
"prompt_tokens": prompt_tokens,
|
|
168
|
+
"completion_tokens": completion_tokens,
|
|
169
|
+
"total_tokens": total_tokens,
|
|
170
|
+
"cost": round(total_cost, 6),
|
|
171
|
+
"raw_response": resp,
|
|
172
|
+
"model_name": model,
|
|
173
|
+
}
|
|
174
|
+
|
|
175
|
+
choice = resp["choices"][0]
|
|
176
|
+
text = choice["message"].get("content") or ""
|
|
177
|
+
stop_reason = choice.get("finish_reason")
|
|
178
|
+
|
|
179
|
+
tool_calls_out: list[dict[str, Any]] = []
|
|
180
|
+
for tc in choice["message"].get("tool_calls", []):
|
|
181
|
+
try:
|
|
182
|
+
args = json.loads(tc["function"]["arguments"])
|
|
183
|
+
except (json.JSONDecodeError, TypeError):
|
|
184
|
+
args = {}
|
|
185
|
+
tool_calls_out.append(
|
|
186
|
+
{
|
|
187
|
+
"id": tc["id"],
|
|
188
|
+
"name": tc["function"]["name"],
|
|
189
|
+
"arguments": args,
|
|
190
|
+
}
|
|
191
|
+
)
|
|
192
|
+
|
|
193
|
+
result: dict[str, Any] = {
|
|
194
|
+
"text": text,
|
|
195
|
+
"meta": meta,
|
|
196
|
+
"tool_calls": tool_calls_out,
|
|
197
|
+
"stop_reason": stop_reason,
|
|
198
|
+
}
|
|
199
|
+
if choice["message"].get("reasoning_content") is not None:
|
|
200
|
+
result["reasoning_content"] = choice["message"]["reasoning_content"]
|
|
201
|
+
return result
|
|
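The Grok driver above (and the Groq driver below) now applies a small fallback: when a reasoning model returns an empty `content` alongside a non-empty `reasoning_content`, the reasoning text is surfaced as the response text instead of an empty string. Condensed into a standalone sketch (`pick_text` is a hypothetical name):

```python
def pick_text(content: str | None, reasoning_content: str | None) -> str:
    """Mirrors the text/reasoning fallback added to both drivers."""
    text = content or ""
    if not text and reasoning_content:
        text = reasoning_content
    return text


assert pick_text(None, "step 1: ...") == "step 1: ..."  # reasoning-only reply
assert pick_text("final answer", "step 1: ...") == "final answer"
assert pick_text(None, None) == ""
```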
prompture/drivers/async_groq_driver.py
CHANGED
```diff
@@ -2,6 +2,7 @@
 
 from __future__ import annotations
 
+import json
 import os
 from typing import Any
 
@@ -17,6 +18,7 @@ from .groq_driver import GroqDriver
 
 class AsyncGroqDriver(CostMixin, AsyncDriver):
     supports_json_mode = True
+    supports_tool_use = True
     supports_vision = True
 
     MODEL_PRICING = GroqDriver.MODEL_PRICING
@@ -86,5 +88,93 @@ class AsyncGroqDriver(CostMixin, AsyncDriver):
             "model_name": model,
         }
 
-        text = resp.choices[0].message.content
-
+        text = resp.choices[0].message.content or ""
+        reasoning_content = getattr(resp.choices[0].message, "reasoning_content", None)
+
+        if not text and reasoning_content:
+            text = reasoning_content
+
+        result: dict[str, Any] = {"text": text, "meta": meta}
+        if reasoning_content is not None:
+            result["reasoning_content"] = reasoning_content
+        return result
+
+    # ------------------------------------------------------------------
+    # Tool use
+    # ------------------------------------------------------------------
+
+    async def generate_messages_with_tools(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> dict[str, Any]:
+        """Generate a response that may include tool calls."""
+        if self.client is None:
+            raise RuntimeError("groq package is not installed")
+
+        model = options.get("model", self.model)
+        model_config = self._get_model_config("groq", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
+
+        self._validate_model_capabilities("groq", model, using_tool_use=True)
+
+        opts = {"temperature": 0.7, "max_tokens": 512, **options}
+
+        kwargs: dict[str, Any] = {
+            "model": model,
+            "messages": messages,
+            "tools": tools,
+        }
+        kwargs[tokens_param] = opts.get("max_tokens", 512)
+
+        if supports_temperature and "temperature" in opts:
+            kwargs["temperature"] = opts["temperature"]
+
+        resp = await self.client.chat.completions.create(**kwargs)
+
+        usage = getattr(resp, "usage", None)
+        prompt_tokens = getattr(usage, "prompt_tokens", 0)
+        completion_tokens = getattr(usage, "completion_tokens", 0)
+        total_tokens = getattr(usage, "total_tokens", 0)
+        total_cost = self._calculate_cost("groq", model, prompt_tokens, completion_tokens)
+
+        meta = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": round(total_cost, 6),
+            "raw_response": resp.model_dump(),
+            "model_name": model,
+        }
+
+        choice = resp.choices[0]
+        text = choice.message.content or ""
+        stop_reason = choice.finish_reason
+
+        tool_calls_out: list[dict[str, Any]] = []
+        if choice.message.tool_calls:
+            for tc in choice.message.tool_calls:
+                try:
+                    args = json.loads(tc.function.arguments)
+                except (json.JSONDecodeError, TypeError):
+                    args = {}
+                tool_calls_out.append(
+                    {
+                        "id": tc.id,
+                        "name": tc.function.name,
+                        "arguments": args,
+                    }
+                )
+
+        result: dict[str, Any] = {
+            "text": text,
+            "meta": meta,
+            "tool_calls": tool_calls_out,
+            "stop_reason": stop_reason,
+        }
+        reasoning_content = getattr(choice.message, "reasoning_content", None)
+        if reasoning_content is not None:
+            result["reasoning_content"] = reasoning_content
+        return result
```