prompture 0.0.37.dev3.tar.gz → 0.0.38.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {prompture-0.0.37.dev3/prompture.egg-info → prompture-0.0.38}/PKG-INFO +1 -1
- prompture-0.0.38/VERSION +1 -0
- prompture-0.0.38/docs/source/_templates/footer.html +16 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/docs/source/conf.py +1 -1
- {prompture-0.0.37.dev3 → prompture-0.0.38}/prompture/_version.py +2 -2
- {prompture-0.0.37.dev3 → prompture-0.0.38}/prompture/drivers/async_azure_driver.py +1 -1
- prompture-0.0.38/prompture/drivers/async_claude_driver.py +272 -0
- prompture-0.0.38/prompture/drivers/async_google_driver.py +316 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/prompture/drivers/async_grok_driver.py +1 -1
- {prompture-0.0.37.dev3 → prompture-0.0.38}/prompture/drivers/async_groq_driver.py +1 -1
- {prompture-0.0.37.dev3 → prompture-0.0.38}/prompture/drivers/async_lmstudio_driver.py +16 -3
- {prompture-0.0.37.dev3 → prompture-0.0.38}/prompture/drivers/async_ollama_driver.py +6 -3
- prompture-0.0.38/prompture/drivers/async_openai_driver.py +244 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/prompture/drivers/async_openrouter_driver.py +1 -1
- {prompture-0.0.37.dev3 → prompture-0.0.38}/prompture/drivers/google_driver.py +207 -43
- {prompture-0.0.37.dev3 → prompture-0.0.38}/prompture/drivers/lmstudio_driver.py +16 -3
- {prompture-0.0.37.dev3 → prompture-0.0.38}/prompture/drivers/ollama_driver.py +9 -5
- {prompture-0.0.37.dev3 → prompture-0.0.38/prompture.egg-info}/PKG-INFO +1 -1
- {prompture-0.0.37.dev3 → prompture-0.0.38}/prompture.egg-info/SOURCES.txt +2 -0
- prompture-0.0.37.dev3/prompture/drivers/async_claude_driver.py +0 -113
- prompture-0.0.37.dev3/prompture/drivers/async_google_driver.py +0 -152
- prompture-0.0.37.dev3/prompture/drivers/async_openai_driver.py +0 -102
- {prompture-0.0.37.dev3 → prompture-0.0.38}/.claude/skills/add-driver/SKILL.md +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/.claude/skills/add-driver/references/driver-template.md +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/.claude/skills/add-example/SKILL.md +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/.claude/skills/add-field/SKILL.md +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/.claude/skills/add-test/SKILL.md +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/.claude/skills/run-tests/SKILL.md +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/.claude/skills/scaffold-extraction/SKILL.md +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/.claude/skills/update-pricing/SKILL.md +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/.env.copy +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/.github/FUNDING.yml +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/.github/scripts/update_docs_version.py +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/.github/scripts/update_wrapper_version.py +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/.github/workflows/dev.yml +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/.github/workflows/documentation.yml +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/.github/workflows/publish.yml +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/CLAUDE.md +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/LICENSE +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/MANIFEST.in +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/README.md +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/ROADMAP.md +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/docs/source/_static/custom.css +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/docs/source/api/core.rst +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/docs/source/api/drivers.rst +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/docs/source/api/field_definitions.rst +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/docs/source/api/index.rst +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/docs/source/api/runner.rst +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/docs/source/api/tools.rst +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/docs/source/api/validator.rst +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/docs/source/contributing.rst +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/docs/source/examples.rst +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/docs/source/field_definitions_reference.rst +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/docs/source/index.rst +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/docs/source/installation.rst +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/docs/source/quickstart.rst +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/docs/source/toon_input_guide.rst +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/packages/README.md +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/packages/llm_to_json/README.md +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/packages/llm_to_json/llm_to_json/__init__.py +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/packages/llm_to_json/pyproject.toml +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/packages/llm_to_json/test.py +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/packages/llm_to_toon/README.md +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/packages/llm_to_toon/llm_to_toon/__init__.py +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/packages/llm_to_toon/pyproject.toml +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/packages/llm_to_toon/test.py +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/prompture/__init__.py +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/prompture/agent.py +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/prompture/agent_types.py +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/prompture/aio/__init__.py +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/prompture/async_agent.py +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/prompture/async_conversation.py +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/prompture/async_core.py +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/prompture/async_driver.py +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/prompture/async_groups.py +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/prompture/cache.py +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/prompture/callbacks.py +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/prompture/cli.py +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/prompture/conversation.py +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/prompture/core.py +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/prompture/cost_mixin.py +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/prompture/discovery.py +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/prompture/driver.py +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/prompture/drivers/__init__.py +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/prompture/drivers/airllm_driver.py +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/prompture/drivers/async_airllm_driver.py +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/prompture/drivers/async_hugging_driver.py +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/prompture/drivers/async_local_http_driver.py +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/prompture/drivers/async_registry.py +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/prompture/drivers/azure_driver.py +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/prompture/drivers/claude_driver.py +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/prompture/drivers/grok_driver.py +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/prompture/drivers/groq_driver.py +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/prompture/drivers/hugging_driver.py +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/prompture/drivers/local_http_driver.py +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/prompture/drivers/openai_driver.py +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/prompture/drivers/openrouter_driver.py +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/prompture/drivers/registry.py +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/prompture/drivers/vision_helpers.py +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/prompture/field_definitions.py +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/prompture/group_types.py +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/prompture/groups.py +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/prompture/image.py +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/prompture/logging.py +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/prompture/model_rates.py +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/prompture/persistence.py +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/prompture/persona.py +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/prompture/runner.py +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/prompture/scaffold/__init__.py +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/prompture/scaffold/generator.py +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/prompture/scaffold/templates/Dockerfile.j2 +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/prompture/scaffold/templates/README.md.j2 +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/prompture/scaffold/templates/config.py.j2 +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/prompture/scaffold/templates/env.example.j2 +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/prompture/scaffold/templates/main.py.j2 +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/prompture/scaffold/templates/models.py.j2 +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/prompture/scaffold/templates/requirements.txt.j2 +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/prompture/serialization.py +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/prompture/server.py +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/prompture/session.py +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/prompture/settings.py +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/prompture/tools.py +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/prompture/tools_schema.py +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/prompture/validator.py +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/prompture.egg-info/dependency_links.txt +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/prompture.egg-info/entry_points.txt +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/prompture.egg-info/requires.txt +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/prompture.egg-info/top_level.txt +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/pyproject.toml +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/requirements.txt +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/setup.cfg +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/test.py +0 -0
- {prompture-0.0.37.dev3 → prompture-0.0.38}/test_version_diagnosis.py +0 -0
prompture-0.0.38/VERSION ADDED
@@ -0,0 +1 @@
+0.0.38
prompture-0.0.38/docs/source/_templates/footer.html ADDED
@@ -0,0 +1,16 @@
+{%- extends "!footer.html" %}
+
+{% block extrafooter %}
+<script>
+  document.addEventListener("DOMContentLoaded", function() {
+    var footerCopy = document.querySelector("footer .copyright");
+    if (footerCopy) {
+      footerCopy.innerHTML = footerCopy.innerHTML.replace(
+        "Juan Denis",
+        '<a href="https://juandenis.com">Juan Denis</a>'
+      );
+    }
+  });
+</script>
+{{ super() }}
+{% endblock %}
{prompture-0.0.37.dev3 → prompture-0.0.38}/docs/source/conf.py
@@ -14,7 +14,7 @@ sys.path.insert(0, os.path.abspath("../../"))
 # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
 
 project = "Prompture"
-copyright = '2026,
+copyright = '2026, Juan Denis'
 author = "Juan Denis"
 
 # Read version dynamically: VERSION file > setuptools_scm > fallback
{prompture-0.0.37.dev3 → prompture-0.0.38}/prompture/_version.py
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID
 
-__version__ = version = '0.0.
-__version_tuple__ = version_tuple = (0, 0,
+__version__ = version = '0.0.38'
+__version_tuple__ = version_tuple = (0, 0, 38)
 
 __commit_id__ = commit_id = None
{prompture-0.0.37.dev3 → prompture-0.0.38}/prompture/drivers/async_azure_driver.py
@@ -113,7 +113,7 @@ class AsyncAzureDriver(CostMixin, AsyncDriver):
             "prompt_tokens": prompt_tokens,
             "completion_tokens": completion_tokens,
             "total_tokens": total_tokens,
-            "cost": total_cost,
+            "cost": round(total_cost, 6),
             "raw_response": resp.model_dump(),
             "model_name": model,
             "deployment_id": self.deployment_id,
prompture-0.0.38/prompture/drivers/async_claude_driver.py ADDED
@@ -0,0 +1,272 @@
+"""Async Anthropic Claude driver. Requires the ``anthropic`` package."""
+
+from __future__ import annotations
+
+import json
+import os
+from collections.abc import AsyncIterator
+from typing import Any
+
+try:
+    import anthropic
+except Exception:
+    anthropic = None
+
+from ..async_driver import AsyncDriver
+from ..cost_mixin import CostMixin
+from .claude_driver import ClaudeDriver
+
+
+class AsyncClaudeDriver(CostMixin, AsyncDriver):
+    supports_json_mode = True
+    supports_json_schema = True
+    supports_tool_use = True
+    supports_streaming = True
+    supports_vision = True
+
+    MODEL_PRICING = ClaudeDriver.MODEL_PRICING
+
+    def __init__(self, api_key: str | None = None, model: str = "claude-3-5-haiku-20241022"):
+        self.api_key = api_key or os.getenv("CLAUDE_API_KEY")
+        self.model = model or os.getenv("CLAUDE_MODEL_NAME", "claude-3-5-haiku-20241022")
+
+    supports_messages = True
+
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_claude_vision_messages
+
+        return _prepare_claude_vision_messages(messages)
+
+    async def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
+        messages = [{"role": "user", "content": prompt}]
+        return await self._do_generate(messages, options)
+
+    async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        return await self._do_generate(self._prepare_messages(messages), options)
+
+    async def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        if anthropic is None:
+            raise RuntimeError("anthropic package not installed")
+
+        opts = {**{"temperature": 0.0, "max_tokens": 512}, **options}
+        model = options.get("model", self.model)
+
+        client = anthropic.AsyncAnthropic(api_key=self.api_key)
+
+        # Anthropic requires system messages as a top-level parameter
+        system_content, api_messages = self._extract_system_and_messages(messages)
+
+        # Build common kwargs
+        common_kwargs: dict[str, Any] = {
+            "model": model,
+            "messages": api_messages,
+            "temperature": opts["temperature"],
+            "max_tokens": opts["max_tokens"],
+        }
+        if system_content:
+            common_kwargs["system"] = system_content
+
+        # Native JSON mode: use tool-use for schema enforcement
+        if options.get("json_mode"):
+            json_schema = options.get("json_schema")
+            if json_schema:
+                tool_def = {
+                    "name": "extract_json",
+                    "description": "Extract structured data matching the schema",
+                    "input_schema": json_schema,
+                }
+                resp = await client.messages.create(
+                    **common_kwargs,
+                    tools=[tool_def],
+                    tool_choice={"type": "tool", "name": "extract_json"},
+                )
+                text = ""
+                for block in resp.content:
+                    if block.type == "tool_use":
+                        text = json.dumps(block.input)
+                        break
+            else:
+                resp = await client.messages.create(**common_kwargs)
+                text = resp.content[0].text
+        else:
+            resp = await client.messages.create(**common_kwargs)
+            text = resp.content[0].text
+
+        prompt_tokens = resp.usage.input_tokens
+        completion_tokens = resp.usage.output_tokens
+        total_tokens = prompt_tokens + completion_tokens
+
+        total_cost = self._calculate_cost("claude", model, prompt_tokens, completion_tokens)
+
+        meta = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": round(total_cost, 6),
+            "raw_response": dict(resp),
+            "model_name": model,
+        }
+
+        return {"text": text, "meta": meta}
+
+    # ------------------------------------------------------------------
+    # Helpers
+    # ------------------------------------------------------------------
+
+    def _extract_system_and_messages(
+        self, messages: list[dict[str, Any]]
+    ) -> tuple[str | None, list[dict[str, Any]]]:
+        """Separate system message from conversation messages for Anthropic API."""
+        system_content = None
+        api_messages: list[dict[str, Any]] = []
+        for msg in messages:
+            if msg.get("role") == "system":
+                system_content = msg.get("content", "")
+            else:
+                api_messages.append(msg)
+        return system_content, api_messages
+
+    # ------------------------------------------------------------------
+    # Tool use
+    # ------------------------------------------------------------------
+
+    async def generate_messages_with_tools(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> dict[str, Any]:
+        """Generate a response that may include tool calls (Anthropic)."""
+        if anthropic is None:
+            raise RuntimeError("anthropic package not installed")
+
+        opts = {**{"temperature": 0.0, "max_tokens": 512}, **options}
+        model = options.get("model", self.model)
+        client = anthropic.AsyncAnthropic(api_key=self.api_key)
+
+        system_content, api_messages = self._extract_system_and_messages(messages)
+
+        # Convert tools from OpenAI format to Anthropic format if needed
+        anthropic_tools = []
+        for t in tools:
+            if "type" in t and t["type"] == "function":
+                # OpenAI format -> Anthropic format
+                fn = t["function"]
+                anthropic_tools.append({
+                    "name": fn["name"],
+                    "description": fn.get("description", ""),
+                    "input_schema": fn.get("parameters", {"type": "object", "properties": {}}),
+                })
+            elif "input_schema" in t:
+                # Already Anthropic format
+                anthropic_tools.append(t)
+            else:
+                anthropic_tools.append(t)
+
+        kwargs: dict[str, Any] = {
+            "model": model,
+            "messages": api_messages,
+            "temperature": opts["temperature"],
+            "max_tokens": opts["max_tokens"],
+            "tools": anthropic_tools,
+        }
+        if system_content:
+            kwargs["system"] = system_content
+
+        resp = await client.messages.create(**kwargs)
+
+        prompt_tokens = resp.usage.input_tokens
+        completion_tokens = resp.usage.output_tokens
+        total_tokens = prompt_tokens + completion_tokens
+        total_cost = self._calculate_cost("claude", model, prompt_tokens, completion_tokens)
+
+        meta = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": round(total_cost, 6),
+            "raw_response": dict(resp),
+            "model_name": model,
+        }
+
+        text = ""
+        tool_calls_out: list[dict[str, Any]] = []
+        for block in resp.content:
+            if block.type == "text":
+                text += block.text
+            elif block.type == "tool_use":
+                tool_calls_out.append({
+                    "id": block.id,
+                    "name": block.name,
+                    "arguments": block.input,
+                })
+
+        return {
+            "text": text,
+            "meta": meta,
+            "tool_calls": tool_calls_out,
+            "stop_reason": resp.stop_reason,
+        }
+
+    # ------------------------------------------------------------------
+    # Streaming
+    # ------------------------------------------------------------------
+
+    async def generate_messages_stream(
+        self,
+        messages: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> AsyncIterator[dict[str, Any]]:
+        """Yield response chunks via Anthropic streaming API."""
+        if anthropic is None:
+            raise RuntimeError("anthropic package not installed")
+
+        opts = {**{"temperature": 0.0, "max_tokens": 512}, **options}
+        model = options.get("model", self.model)
+        client = anthropic.AsyncAnthropic(api_key=self.api_key)
+
+        system_content, api_messages = self._extract_system_and_messages(messages)
+
+        kwargs: dict[str, Any] = {
+            "model": model,
+            "messages": api_messages,
+            "temperature": opts["temperature"],
+            "max_tokens": opts["max_tokens"],
+        }
+        if system_content:
+            kwargs["system"] = system_content
+
+        full_text = ""
+        prompt_tokens = 0
+        completion_tokens = 0
+
+        async with client.messages.stream(**kwargs) as stream:
+            async for event in stream:
+                if hasattr(event, "type"):
+                    if event.type == "content_block_delta" and hasattr(event, "delta"):
+                        delta_text = getattr(event.delta, "text", "")
+                        if delta_text:
+                            full_text += delta_text
+                            yield {"type": "delta", "text": delta_text}
+                    elif event.type == "message_delta" and hasattr(event, "usage"):
+                        completion_tokens = getattr(event.usage, "output_tokens", 0)
+                    elif event.type == "message_start" and hasattr(event, "message"):
+                        usage = getattr(event.message, "usage", None)
+                        if usage:
+                            prompt_tokens = getattr(usage, "input_tokens", 0)
+
+        total_tokens = prompt_tokens + completion_tokens
+        total_cost = self._calculate_cost("claude", model, prompt_tokens, completion_tokens)
+
+        yield {
+            "type": "done",
+            "text": full_text,
+            "meta": {
+                "prompt_tokens": prompt_tokens,
+                "completion_tokens": completion_tokens,
+                "total_tokens": total_tokens,
+                "cost": round(total_cost, 6),
+                "raw_response": {},
+                "model_name": model,
+            },
+        }
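For context, a minimal usage sketch of the new async Claude driver defined in the hunk above (not part of the diff). It instantiates the class directly, which is the only entry point visible here; driver registration via async_registry.py is not shown in this diff. It assumes the anthropic package is installed and CLAUDE_API_KEY is set.

```python
# Hedged sketch: exercise AsyncClaudeDriver as added in 0.0.38.
# Assumes `anthropic` is installed and CLAUDE_API_KEY is exported.
import asyncio

from prompture.drivers.async_claude_driver import AsyncClaudeDriver


async def main() -> None:
    driver = AsyncClaudeDriver()  # defaults to claude-3-5-haiku-20241022

    # Plain text generation; meta carries token counts and rounded cost.
    result = await driver.generate("Say hello in one word.", {"max_tokens": 32})
    print(result["text"], result["meta"]["cost"])

    # JSON mode routes through Anthropic tool-use (see _do_generate above).
    schema = {"type": "object", "properties": {"greeting": {"type": "string"}}}
    result = await driver.generate(
        "Return a greeting.",
        {"json_mode": True, "json_schema": schema},
    )
    print(result["text"])  # JSON string produced via the extract_json tool


asyncio.run(main())
```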
prompture-0.0.38/prompture/drivers/async_google_driver.py ADDED
@@ -0,0 +1,316 @@
+"""Async Google Generative AI (Gemini) driver."""
+
+from __future__ import annotations
+
+import logging
+import os
+import uuid
+from collections.abc import AsyncIterator
+from typing import Any
+
+import google.generativeai as genai
+
+from ..async_driver import AsyncDriver
+from ..cost_mixin import CostMixin
+from .google_driver import GoogleDriver
+
+logger = logging.getLogger(__name__)
+
+
+class AsyncGoogleDriver(CostMixin, AsyncDriver):
+    """Async driver for Google's Generative AI API (Gemini)."""
+
+    supports_json_mode = True
+    supports_json_schema = True
+    supports_vision = True
+    supports_tool_use = True
+    supports_streaming = True
+
+    MODEL_PRICING = GoogleDriver.MODEL_PRICING
+    _PRICING_UNIT = 1_000_000
+
+    def __init__(self, api_key: str | None = None, model: str = "gemini-1.5-pro"):
+        self.api_key = api_key or os.getenv("GOOGLE_API_KEY")
+        if not self.api_key:
+            raise ValueError("Google API key not found. Set GOOGLE_API_KEY env var or pass api_key to constructor")
+        self.model = model
+        genai.configure(api_key=self.api_key)
+        self.options: dict[str, Any] = {}
+
+    def _calculate_cost_chars(self, prompt_chars: int, completion_chars: int) -> float:
+        """Calculate cost from character counts (same logic as sync GoogleDriver)."""
+        from ..model_rates import get_model_rates
+
+        live_rates = get_model_rates("google", self.model)
+        if live_rates:
+            est_prompt_tokens = prompt_chars / 4
+            est_completion_tokens = completion_chars / 4
+            prompt_cost = (est_prompt_tokens / 1_000_000) * live_rates["input"]
+            completion_cost = (est_completion_tokens / 1_000_000) * live_rates["output"]
+        else:
+            model_pricing = self.MODEL_PRICING.get(self.model, {"prompt": 0, "completion": 0})
+            prompt_cost = (prompt_chars / 1_000_000) * model_pricing["prompt"]
+            completion_cost = (completion_chars / 1_000_000) * model_pricing["completion"]
+        return round(prompt_cost + completion_cost, 6)
+
+    def _extract_usage_metadata(self, response: Any, messages: list[dict[str, Any]]) -> dict[str, Any]:
+        """Extract token counts from response, falling back to character estimation."""
+        usage = getattr(response, "usage_metadata", None)
+        if usage:
+            prompt_tokens = getattr(usage, "prompt_token_count", 0) or 0
+            completion_tokens = getattr(usage, "candidates_token_count", 0) or 0
+            total_tokens = getattr(usage, "total_token_count", 0) or (prompt_tokens + completion_tokens)
+            cost = self._calculate_cost("google", self.model, prompt_tokens, completion_tokens)
+        else:
+            # Fallback: estimate from character counts
+            total_prompt_chars = 0
+            for msg in messages:
+                c = msg.get("content", "")
+                if isinstance(c, str):
+                    total_prompt_chars += len(c)
+                elif isinstance(c, list):
+                    for part in c:
+                        if isinstance(part, str):
+                            total_prompt_chars += len(part)
+                        elif isinstance(part, dict) and "text" in part:
+                            total_prompt_chars += len(part["text"])
+            completion_chars = len(response.text) if response.text else 0
+            prompt_tokens = total_prompt_chars // 4
+            completion_tokens = completion_chars // 4
+            total_tokens = prompt_tokens + completion_tokens
+            cost = self._calculate_cost_chars(total_prompt_chars, completion_chars)
+
+        return {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": round(cost, 6),
+        }
+
+    supports_messages = True
+
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_google_vision_messages
+
+        return _prepare_google_vision_messages(messages)
+
+    def _build_generation_args(
+        self, messages: list[dict[str, Any]], options: dict[str, Any] | None = None
+    ) -> tuple[Any, dict[str, Any], dict[str, Any]]:
+        """Parse messages and options into (gen_input, gen_kwargs, model_kwargs)."""
+        merged_options = self.options.copy()
+        if options:
+            merged_options.update(options)
+
+        generation_config = merged_options.get("generation_config", {})
+        safety_settings = merged_options.get("safety_settings", {})
+
+        if "temperature" in merged_options and "temperature" not in generation_config:
+            generation_config["temperature"] = merged_options["temperature"]
+        if "max_tokens" in merged_options and "max_output_tokens" not in generation_config:
+            generation_config["max_output_tokens"] = merged_options["max_tokens"]
+        if "top_p" in merged_options and "top_p" not in generation_config:
+            generation_config["top_p"] = merged_options["top_p"]
+        if "top_k" in merged_options and "top_k" not in generation_config:
+            generation_config["top_k"] = merged_options["top_k"]
+
+        # Native JSON mode support
+        if merged_options.get("json_mode"):
+            generation_config["response_mime_type"] = "application/json"
+            json_schema = merged_options.get("json_schema")
+            if json_schema:
+                generation_config["response_schema"] = json_schema
+
+        # Convert messages to Gemini format
+        system_instruction = None
+        contents: list[dict[str, Any]] = []
+        for msg in messages:
+            role = msg.get("role", "user")
+            content = msg.get("content", "")
+            if role == "system":
+                system_instruction = content if isinstance(content, str) else str(content)
+            else:
+                gemini_role = "model" if role == "assistant" else "user"
+                if msg.get("_vision_parts"):
+                    contents.append({"role": gemini_role, "parts": content})
+                else:
+                    contents.append({"role": gemini_role, "parts": [content]})
+
+        # For a single message, unwrap only if it has exactly one string part
+        if len(contents) == 1:
+            parts = contents[0]["parts"]
+            if len(parts) == 1 and isinstance(parts[0], str):
+                gen_input = parts[0]
+            else:
+                gen_input = contents
+        else:
+            gen_input = contents
+
+        model_kwargs: dict[str, Any] = {}
+        if system_instruction:
+            model_kwargs["system_instruction"] = system_instruction
+
+        gen_kwargs: dict[str, Any] = {
+            "generation_config": generation_config if generation_config else None,
+            "safety_settings": safety_settings if safety_settings else None,
+        }
+
+        return gen_input, gen_kwargs, model_kwargs
+
+    async def generate(self, prompt: str, options: dict[str, Any] | None = None) -> dict[str, Any]:
+        messages = [{"role": "user", "content": prompt}]
+        return await self._do_generate(messages, options)
+
+    async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        return await self._do_generate(self._prepare_messages(messages), options)
+
+    async def _do_generate(
+        self, messages: list[dict[str, str]], options: dict[str, Any] | None = None
+    ) -> dict[str, Any]:
+        gen_input, gen_kwargs, model_kwargs = self._build_generation_args(messages, options)
+
+        try:
+            model = genai.GenerativeModel(self.model, **model_kwargs)
+            response = await model.generate_content_async(gen_input, **gen_kwargs)
+
+            if not response.text:
+                raise ValueError("Empty response from model")
+
+            usage_meta = self._extract_usage_metadata(response, messages)
+
+            meta = {
+                **usage_meta,
+                "raw_response": response.prompt_feedback if hasattr(response, "prompt_feedback") else None,
+                "model_name": self.model,
+            }
+
+            return {"text": response.text, "meta": meta}
+
+        except Exception as e:
+            logger.error(f"Google API request failed: {e}")
+            raise RuntimeError(f"Google API request failed: {e}") from e
+
+    # ------------------------------------------------------------------
+    # Tool use
+    # ------------------------------------------------------------------
+
+    async def generate_messages_with_tools(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> dict[str, Any]:
+        """Generate a response that may include tool/function calls (async)."""
+        gen_input, gen_kwargs, model_kwargs = self._build_generation_args(
+            self._prepare_messages(messages), options
+        )
+
+        # Convert tools from OpenAI format to Gemini function declarations
+        function_declarations = []
+        for t in tools:
+            if "type" in t and t["type"] == "function":
+                fn = t["function"]
+                decl = {
+                    "name": fn["name"],
+                    "description": fn.get("description", ""),
+                }
+                params = fn.get("parameters")
+                if params:
+                    decl["parameters"] = params
+                function_declarations.append(decl)
+            elif "name" in t:
+                decl = {"name": t["name"], "description": t.get("description", "")}
+                params = t.get("parameters") or t.get("input_schema")
+                if params:
+                    decl["parameters"] = params
+                function_declarations.append(decl)
+
+        try:
+            model = genai.GenerativeModel(self.model, **model_kwargs)
+
+            gemini_tools = [genai.types.Tool(function_declarations=function_declarations)]
+            response = await model.generate_content_async(gen_input, tools=gemini_tools, **gen_kwargs)
+
+            usage_meta = self._extract_usage_metadata(response, messages)
+            meta = {
+                **usage_meta,
+                "raw_response": response.prompt_feedback if hasattr(response, "prompt_feedback") else None,
+                "model_name": self.model,
+            }
+
+            text = ""
+            tool_calls_out: list[dict[str, Any]] = []
+            stop_reason = "stop"
+
+            for candidate in response.candidates:
+                for part in candidate.content.parts:
+                    if hasattr(part, "text") and part.text:
+                        text += part.text
+                    if hasattr(part, "function_call") and part.function_call.name:
+                        fc = part.function_call
+                        tool_calls_out.append({
+                            "id": str(uuid.uuid4()),
+                            "name": fc.name,
+                            "arguments": dict(fc.args) if fc.args else {},
+                        })
+
+                finish_reason = getattr(candidate, "finish_reason", None)
+                if finish_reason is not None:
+                    reason_map = {1: "stop", 2: "max_tokens", 3: "safety", 4: "recitation", 5: "other"}
+                    stop_reason = reason_map.get(finish_reason, "stop")
+
+            if tool_calls_out:
+                stop_reason = "tool_use"
+
+            return {
+                "text": text,
+                "meta": meta,
+                "tool_calls": tool_calls_out,
+                "stop_reason": stop_reason,
+            }
+
+        except Exception as e:
+            logger.error(f"Google API tool call request failed: {e}")
+            raise RuntimeError(f"Google API tool call request failed: {e}") from e
+
+    # ------------------------------------------------------------------
+    # Streaming
+    # ------------------------------------------------------------------
+
+    async def generate_messages_stream(
+        self,
+        messages: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> AsyncIterator[dict[str, Any]]:
+        """Yield response chunks via Gemini async streaming API."""
+        gen_input, gen_kwargs, model_kwargs = self._build_generation_args(
+            self._prepare_messages(messages), options
+        )
+
+        try:
+            model = genai.GenerativeModel(self.model, **model_kwargs)
+            response = await model.generate_content_async(gen_input, stream=True, **gen_kwargs)
+
+            full_text = ""
+            async for chunk in response:
+                chunk_text = getattr(chunk, "text", None) or ""
+                if chunk_text:
+                    full_text += chunk_text
+                    yield {"type": "delta", "text": chunk_text}
+
+            # After iteration completes, usage_metadata should be available
+            usage_meta = self._extract_usage_metadata(response, messages)
+
+            yield {
+                "type": "done",
+                "text": full_text,
+                "meta": {
+                    **usage_meta,
+                    "raw_response": {},
+                    "model_name": self.model,
+                },
+            }
+
+        except Exception as e:
+            logger.error(f"Google API streaming request failed: {e}")
+            raise RuntimeError(f"Google API streaming request failed: {e}") from e
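Similarly, a minimal sketch (not part of the diff) showing the streaming path of the new async Gemini driver. It assumes google-generativeai is installed and GOOGLE_API_KEY is set; the chunk shapes follow generate_messages_stream in the hunk above.

```python
# Hedged sketch: stream from AsyncGoogleDriver as added in 0.0.38.
# Assumes `google-generativeai` is installed and GOOGLE_API_KEY is exported.
import asyncio

from prompture.drivers.async_google_driver import AsyncGoogleDriver


async def main() -> None:
    driver = AsyncGoogleDriver(model="gemini-1.5-pro")
    messages = [
        {"role": "system", "content": "You are terse."},
        {"role": "user", "content": "Name three prime numbers."},
    ]
    # Chunks are dicts: {"type": "delta", "text": ...} then {"type": "done", ...}
    async for chunk in driver.generate_messages_stream(messages, {"max_tokens": 64}):
        if chunk["type"] == "delta":
            print(chunk["text"], end="", flush=True)
        elif chunk["type"] == "done":
            print("\n", chunk["meta"]["total_tokens"], "tokens")


asyncio.run(main())
```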
{prompture-0.0.37.dev3 → prompture-0.0.38}/prompture/drivers/async_grok_driver.py
@@ -88,7 +88,7 @@ class AsyncGrokDriver(CostMixin, AsyncDriver):
             "prompt_tokens": prompt_tokens,
             "completion_tokens": completion_tokens,
             "total_tokens": total_tokens,
-            "cost": total_cost,
+            "cost": round(total_cost, 6),
             "raw_response": resp,
             "model_name": model,
         }
{prompture-0.0.37.dev3 → prompture-0.0.38}/prompture/drivers/async_groq_driver.py
@@ -81,7 +81,7 @@ class AsyncGroqDriver(CostMixin, AsyncDriver):
             "prompt_tokens": prompt_tokens,
             "completion_tokens": completion_tokens,
             "total_tokens": total_tokens,
-            "cost": total_cost,
+            "cost": round(total_cost, 6),
             "raw_response": resp.model_dump(),
             "model_name": model,
         }