prompture 0.0.38.dev1__tar.gz → 0.0.38.dev2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {prompture-0.0.38.dev1/prompture.egg-info → prompture-0.0.38.dev2}/PKG-INFO +1 -1
- prompture-0.0.38.dev2/docs/source/_templates/footer.html +16 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/docs/source/conf.py +1 -1
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/_version.py +2 -2
- prompture-0.0.38.dev2/prompture/drivers/async_google_driver.py +316 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/drivers/google_driver.py +207 -43
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2/prompture.egg-info}/PKG-INFO +1 -1
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture.egg-info/SOURCES.txt +1 -0
- prompture-0.0.38.dev1/prompture/drivers/async_google_driver.py +0 -152
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/.claude/skills/add-driver/SKILL.md +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/.claude/skills/add-driver/references/driver-template.md +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/.claude/skills/add-example/SKILL.md +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/.claude/skills/add-field/SKILL.md +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/.claude/skills/add-test/SKILL.md +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/.claude/skills/run-tests/SKILL.md +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/.claude/skills/scaffold-extraction/SKILL.md +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/.claude/skills/update-pricing/SKILL.md +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/.env.copy +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/.github/FUNDING.yml +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/.github/scripts/update_docs_version.py +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/.github/scripts/update_wrapper_version.py +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/.github/workflows/dev.yml +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/.github/workflows/documentation.yml +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/.github/workflows/publish.yml +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/CLAUDE.md +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/LICENSE +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/MANIFEST.in +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/README.md +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/ROADMAP.md +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/docs/source/_static/custom.css +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/docs/source/api/core.rst +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/docs/source/api/drivers.rst +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/docs/source/api/field_definitions.rst +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/docs/source/api/index.rst +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/docs/source/api/runner.rst +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/docs/source/api/tools.rst +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/docs/source/api/validator.rst +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/docs/source/contributing.rst +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/docs/source/examples.rst +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/docs/source/field_definitions_reference.rst +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/docs/source/index.rst +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/docs/source/installation.rst +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/docs/source/quickstart.rst +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/docs/source/toon_input_guide.rst +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/packages/README.md +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/packages/llm_to_json/README.md +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/packages/llm_to_json/llm_to_json/__init__.py +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/packages/llm_to_json/pyproject.toml +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/packages/llm_to_json/test.py +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/packages/llm_to_toon/README.md +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/packages/llm_to_toon/llm_to_toon/__init__.py +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/packages/llm_to_toon/pyproject.toml +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/packages/llm_to_toon/test.py +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/__init__.py +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/agent.py +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/agent_types.py +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/aio/__init__.py +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/async_agent.py +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/async_conversation.py +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/async_core.py +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/async_driver.py +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/async_groups.py +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/cache.py +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/callbacks.py +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/cli.py +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/conversation.py +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/core.py +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/cost_mixin.py +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/discovery.py +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/driver.py +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/drivers/__init__.py +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/drivers/airllm_driver.py +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/drivers/async_airllm_driver.py +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/drivers/async_azure_driver.py +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/drivers/async_claude_driver.py +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/drivers/async_grok_driver.py +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/drivers/async_groq_driver.py +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/drivers/async_hugging_driver.py +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/drivers/async_lmstudio_driver.py +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/drivers/async_local_http_driver.py +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/drivers/async_ollama_driver.py +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/drivers/async_openai_driver.py +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/drivers/async_openrouter_driver.py +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/drivers/async_registry.py +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/drivers/azure_driver.py +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/drivers/claude_driver.py +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/drivers/grok_driver.py +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/drivers/groq_driver.py +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/drivers/hugging_driver.py +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/drivers/lmstudio_driver.py +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/drivers/local_http_driver.py +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/drivers/ollama_driver.py +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/drivers/openai_driver.py +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/drivers/openrouter_driver.py +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/drivers/registry.py +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/drivers/vision_helpers.py +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/field_definitions.py +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/group_types.py +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/groups.py +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/image.py +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/logging.py +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/model_rates.py +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/persistence.py +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/persona.py +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/runner.py +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/scaffold/__init__.py +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/scaffold/generator.py +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/scaffold/templates/Dockerfile.j2 +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/scaffold/templates/README.md.j2 +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/scaffold/templates/config.py.j2 +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/scaffold/templates/env.example.j2 +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/scaffold/templates/main.py.j2 +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/scaffold/templates/models.py.j2 +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/scaffold/templates/requirements.txt.j2 +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/serialization.py +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/server.py +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/session.py +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/settings.py +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/tools.py +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/tools_schema.py +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/validator.py +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture.egg-info/dependency_links.txt +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture.egg-info/entry_points.txt +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture.egg-info/requires.txt +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture.egg-info/top_level.txt +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/pyproject.toml +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/requirements.txt +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/setup.cfg +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/test.py +0 -0
- {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/test_version_diagnosis.py +0 -0
docs/source/_templates/footer.html (new file)

@@ -0,0 +1,16 @@
+{%- extends "!footer.html" %}
+
+{% block extrafooter %}
+<script>
+document.addEventListener("DOMContentLoaded", function() {
+  var footerCopy = document.querySelector("footer .copyright");
+  if (footerCopy) {
+    footerCopy.innerHTML = footerCopy.innerHTML.replace(
+      "Juan Denis",
+      '<a href="https://juandenis.com">Juan Denis</a>'
+    );
+  }
+});
+</script>
+{{ super() }}
+{% endblock %}
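The override above only takes effect if the Sphinx build searches the _templates directory. That setting is not part of this diff, so the line below is an assumption about the project's existing docs/source/conf.py rather than a change being shipped here:

# docs/source/conf.py (assumed setting, not part of this diff)
templates_path = ["_templates"]

With that in place, the extends "!footer.html" line pulls in the theme's original footer, and the extrafooter block rewrites the copyright holder's name into a link to https://juandenis.com once the page loads.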
docs/source/conf.py

@@ -14,7 +14,7 @@ sys.path.insert(0, os.path.abspath("../../"))
 # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
 
 project = "Prompture"
-copyright = '2026,
+copyright = '2026, Juan Denis'
 author = "Juan Denis"
 
 # Read version dynamically: VERSION file > setuptools_scm > fallback
prompture/_version.py

@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID
 
-__version__ = version = '0.0.38.dev1'
-__version_tuple__ = version_tuple = (0, 0, 38, 'dev1')
+__version__ = version = '0.0.38.dev2'
+__version_tuple__ = version_tuple = (0, 0, 38, 'dev2')
 
 __commit_id__ = commit_id = None
prompture/drivers/async_google_driver.py (new file)

@@ -0,0 +1,316 @@
+"""Async Google Generative AI (Gemini) driver."""
+
+from __future__ import annotations
+
+import logging
+import os
+import uuid
+from collections.abc import AsyncIterator
+from typing import Any
+
+import google.generativeai as genai
+
+from ..async_driver import AsyncDriver
+from ..cost_mixin import CostMixin
+from .google_driver import GoogleDriver
+
+logger = logging.getLogger(__name__)
+
+
+class AsyncGoogleDriver(CostMixin, AsyncDriver):
+    """Async driver for Google's Generative AI API (Gemini)."""
+
+    supports_json_mode = True
+    supports_json_schema = True
+    supports_vision = True
+    supports_tool_use = True
+    supports_streaming = True
+
+    MODEL_PRICING = GoogleDriver.MODEL_PRICING
+    _PRICING_UNIT = 1_000_000
+
+    def __init__(self, api_key: str | None = None, model: str = "gemini-1.5-pro"):
+        self.api_key = api_key or os.getenv("GOOGLE_API_KEY")
+        if not self.api_key:
+            raise ValueError("Google API key not found. Set GOOGLE_API_KEY env var or pass api_key to constructor")
+        self.model = model
+        genai.configure(api_key=self.api_key)
+        self.options: dict[str, Any] = {}
+
+    def _calculate_cost_chars(self, prompt_chars: int, completion_chars: int) -> float:
+        """Calculate cost from character counts (same logic as sync GoogleDriver)."""
+        from ..model_rates import get_model_rates
+
+        live_rates = get_model_rates("google", self.model)
+        if live_rates:
+            est_prompt_tokens = prompt_chars / 4
+            est_completion_tokens = completion_chars / 4
+            prompt_cost = (est_prompt_tokens / 1_000_000) * live_rates["input"]
+            completion_cost = (est_completion_tokens / 1_000_000) * live_rates["output"]
+        else:
+            model_pricing = self.MODEL_PRICING.get(self.model, {"prompt": 0, "completion": 0})
+            prompt_cost = (prompt_chars / 1_000_000) * model_pricing["prompt"]
+            completion_cost = (completion_chars / 1_000_000) * model_pricing["completion"]
+        return round(prompt_cost + completion_cost, 6)
+
+    def _extract_usage_metadata(self, response: Any, messages: list[dict[str, Any]]) -> dict[str, Any]:
+        """Extract token counts from response, falling back to character estimation."""
+        usage = getattr(response, "usage_metadata", None)
+        if usage:
+            prompt_tokens = getattr(usage, "prompt_token_count", 0) or 0
+            completion_tokens = getattr(usage, "candidates_token_count", 0) or 0
+            total_tokens = getattr(usage, "total_token_count", 0) or (prompt_tokens + completion_tokens)
+            cost = self._calculate_cost("google", self.model, prompt_tokens, completion_tokens)
+        else:
+            # Fallback: estimate from character counts
+            total_prompt_chars = 0
+            for msg in messages:
+                c = msg.get("content", "")
+                if isinstance(c, str):
+                    total_prompt_chars += len(c)
+                elif isinstance(c, list):
+                    for part in c:
+                        if isinstance(part, str):
+                            total_prompt_chars += len(part)
+                        elif isinstance(part, dict) and "text" in part:
+                            total_prompt_chars += len(part["text"])
+            completion_chars = len(response.text) if response.text else 0
+            prompt_tokens = total_prompt_chars // 4
+            completion_tokens = completion_chars // 4
+            total_tokens = prompt_tokens + completion_tokens
+            cost = self._calculate_cost_chars(total_prompt_chars, completion_chars)
+
+        return {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": round(cost, 6),
+        }
+
+    supports_messages = True
+
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_google_vision_messages
+
+        return _prepare_google_vision_messages(messages)
+
+    def _build_generation_args(
+        self, messages: list[dict[str, Any]], options: dict[str, Any] | None = None
+    ) -> tuple[Any, dict[str, Any], dict[str, Any]]:
+        """Parse messages and options into (gen_input, gen_kwargs, model_kwargs)."""
+        merged_options = self.options.copy()
+        if options:
+            merged_options.update(options)
+
+        generation_config = merged_options.get("generation_config", {})
+        safety_settings = merged_options.get("safety_settings", {})
+
+        if "temperature" in merged_options and "temperature" not in generation_config:
+            generation_config["temperature"] = merged_options["temperature"]
+        if "max_tokens" in merged_options and "max_output_tokens" not in generation_config:
+            generation_config["max_output_tokens"] = merged_options["max_tokens"]
+        if "top_p" in merged_options and "top_p" not in generation_config:
+            generation_config["top_p"] = merged_options["top_p"]
+        if "top_k" in merged_options and "top_k" not in generation_config:
+            generation_config["top_k"] = merged_options["top_k"]
+
+        # Native JSON mode support
+        if merged_options.get("json_mode"):
+            generation_config["response_mime_type"] = "application/json"
+            json_schema = merged_options.get("json_schema")
+            if json_schema:
+                generation_config["response_schema"] = json_schema
+
+        # Convert messages to Gemini format
+        system_instruction = None
+        contents: list[dict[str, Any]] = []
+        for msg in messages:
+            role = msg.get("role", "user")
+            content = msg.get("content", "")
+            if role == "system":
+                system_instruction = content if isinstance(content, str) else str(content)
+            else:
+                gemini_role = "model" if role == "assistant" else "user"
+                if msg.get("_vision_parts"):
+                    contents.append({"role": gemini_role, "parts": content})
+                else:
+                    contents.append({"role": gemini_role, "parts": [content]})
+
+        # For a single message, unwrap only if it has exactly one string part
+        if len(contents) == 1:
+            parts = contents[0]["parts"]
+            if len(parts) == 1 and isinstance(parts[0], str):
+                gen_input = parts[0]
+            else:
+                gen_input = contents
+        else:
+            gen_input = contents
+
+        model_kwargs: dict[str, Any] = {}
+        if system_instruction:
+            model_kwargs["system_instruction"] = system_instruction
+
+        gen_kwargs: dict[str, Any] = {
+            "generation_config": generation_config if generation_config else None,
+            "safety_settings": safety_settings if safety_settings else None,
+        }
+
+        return gen_input, gen_kwargs, model_kwargs
+
+    async def generate(self, prompt: str, options: dict[str, Any] | None = None) -> dict[str, Any]:
+        messages = [{"role": "user", "content": prompt}]
+        return await self._do_generate(messages, options)
+
+    async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        return await self._do_generate(self._prepare_messages(messages), options)
+
+    async def _do_generate(
+        self, messages: list[dict[str, str]], options: dict[str, Any] | None = None
+    ) -> dict[str, Any]:
+        gen_input, gen_kwargs, model_kwargs = self._build_generation_args(messages, options)
+
+        try:
+            model = genai.GenerativeModel(self.model, **model_kwargs)
+            response = await model.generate_content_async(gen_input, **gen_kwargs)
+
+            if not response.text:
+                raise ValueError("Empty response from model")
+
+            usage_meta = self._extract_usage_metadata(response, messages)
+
+            meta = {
+                **usage_meta,
+                "raw_response": response.prompt_feedback if hasattr(response, "prompt_feedback") else None,
+                "model_name": self.model,
+            }
+
+            return {"text": response.text, "meta": meta}
+
+        except Exception as e:
+            logger.error(f"Google API request failed: {e}")
+            raise RuntimeError(f"Google API request failed: {e}") from e
+
+    # ------------------------------------------------------------------
+    # Tool use
+    # ------------------------------------------------------------------
+
+    async def generate_messages_with_tools(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> dict[str, Any]:
+        """Generate a response that may include tool/function calls (async)."""
+        gen_input, gen_kwargs, model_kwargs = self._build_generation_args(
+            self._prepare_messages(messages), options
+        )
+
+        # Convert tools from OpenAI format to Gemini function declarations
+        function_declarations = []
+        for t in tools:
+            if "type" in t and t["type"] == "function":
+                fn = t["function"]
+                decl = {
+                    "name": fn["name"],
+                    "description": fn.get("description", ""),
+                }
+                params = fn.get("parameters")
+                if params:
+                    decl["parameters"] = params
+                function_declarations.append(decl)
+            elif "name" in t:
+                decl = {"name": t["name"], "description": t.get("description", "")}
+                params = t.get("parameters") or t.get("input_schema")
+                if params:
+                    decl["parameters"] = params
+                function_declarations.append(decl)
+
+        try:
+            model = genai.GenerativeModel(self.model, **model_kwargs)
+
+            gemini_tools = [genai.types.Tool(function_declarations=function_declarations)]
+            response = await model.generate_content_async(gen_input, tools=gemini_tools, **gen_kwargs)
+
+            usage_meta = self._extract_usage_metadata(response, messages)
+            meta = {
+                **usage_meta,
+                "raw_response": response.prompt_feedback if hasattr(response, "prompt_feedback") else None,
+                "model_name": self.model,
+            }
+
+            text = ""
+            tool_calls_out: list[dict[str, Any]] = []
+            stop_reason = "stop"
+
+            for candidate in response.candidates:
+                for part in candidate.content.parts:
+                    if hasattr(part, "text") and part.text:
+                        text += part.text
+                    if hasattr(part, "function_call") and part.function_call.name:
+                        fc = part.function_call
+                        tool_calls_out.append({
+                            "id": str(uuid.uuid4()),
+                            "name": fc.name,
+                            "arguments": dict(fc.args) if fc.args else {},
+                        })
+
+                finish_reason = getattr(candidate, "finish_reason", None)
+                if finish_reason is not None:
+                    reason_map = {1: "stop", 2: "max_tokens", 3: "safety", 4: "recitation", 5: "other"}
+                    stop_reason = reason_map.get(finish_reason, "stop")
+
+            if tool_calls_out:
+                stop_reason = "tool_use"
+
+            return {
+                "text": text,
+                "meta": meta,
+                "tool_calls": tool_calls_out,
+                "stop_reason": stop_reason,
+            }
+
+        except Exception as e:
+            logger.error(f"Google API tool call request failed: {e}")
+            raise RuntimeError(f"Google API tool call request failed: {e}") from e
+
+    # ------------------------------------------------------------------
+    # Streaming
+    # ------------------------------------------------------------------
+
+    async def generate_messages_stream(
+        self,
+        messages: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> AsyncIterator[dict[str, Any]]:
+        """Yield response chunks via Gemini async streaming API."""
+        gen_input, gen_kwargs, model_kwargs = self._build_generation_args(
+            self._prepare_messages(messages), options
+        )
+
+        try:
+            model = genai.GenerativeModel(self.model, **model_kwargs)
+            response = await model.generate_content_async(gen_input, stream=True, **gen_kwargs)
+
+            full_text = ""
+            async for chunk in response:
+                chunk_text = getattr(chunk, "text", None) or ""
+                if chunk_text:
+                    full_text += chunk_text
+                    yield {"type": "delta", "text": chunk_text}
+
+            # After iteration completes, usage_metadata should be available
+            usage_meta = self._extract_usage_metadata(response, messages)
+
+            yield {
+                "type": "done",
+                "text": full_text,
+                "meta": {
+                    **usage_meta,
+                    "raw_response": {},
+                    "model_name": self.model,
+                },
+            }
+
+        except Exception as e:
+            logger.error(f"Google API streaming request failed: {e}")
+            raise RuntimeError(f"Google API streaming request failed: {e}") from e
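For orientation, here is a minimal usage sketch of the new async driver. It is illustrative only: the import path simply mirrors the file location above, the model name is the constructor default, and it assumes GOOGLE_API_KEY is set in the environment.

import asyncio

from prompture.drivers.async_google_driver import AsyncGoogleDriver


async def main() -> None:
    driver = AsyncGoogleDriver(model="gemini-1.5-pro")  # falls back to GOOGLE_API_KEY when api_key is omitted

    # generate() returns {"text": ..., "meta": {...}} with token counts and cost in meta
    result = await driver.generate("Summarize the Prompture project in one sentence.")
    print(result["text"], result["meta"]["total_tokens"])

    # generate_messages_stream() yields {"type": "delta", ...} chunks and a final {"type": "done", ...}
    async for chunk in driver.generate_messages_stream(
        [{"role": "user", "content": "Count to three."}], {}
    ):
        if chunk["type"] == "delta":
            print(chunk["text"], end="", flush=True)


asyncio.run(main())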
prompture/drivers/google_driver.py

@@ -1,5 +1,7 @@
 import logging
 import os
+import uuid
+from collections.abc import Iterator
 from typing import Any, Optional
 
 import google.generativeai as genai
@@ -16,6 +18,8 @@ class GoogleDriver(CostMixin, Driver):
     supports_json_mode = True
     supports_json_schema = True
     supports_vision = True
+    supports_tool_use = True
+    supports_streaming = True
 
     # Based on current Gemini pricing (as of 2025)
     # Source: https://cloud.google.com/vertex-ai/pricing#gemini_models
@@ -106,6 +110,40 @@ class GoogleDriver(CostMixin, Driver):
             completion_cost = (completion_chars / 1_000_000) * model_pricing["completion"]
         return round(prompt_cost + completion_cost, 6)
 
+    def _extract_usage_metadata(self, response: Any, messages: list[dict[str, Any]]) -> dict[str, Any]:
+        """Extract token counts from response, falling back to character estimation."""
+        usage = getattr(response, "usage_metadata", None)
+        if usage:
+            prompt_tokens = getattr(usage, "prompt_token_count", 0) or 0
+            completion_tokens = getattr(usage, "candidates_token_count", 0) or 0
+            total_tokens = getattr(usage, "total_token_count", 0) or (prompt_tokens + completion_tokens)
+            cost = self._calculate_cost("google", self.model, prompt_tokens, completion_tokens)
+        else:
+            # Fallback: estimate from character counts
+            total_prompt_chars = 0
+            for msg in messages:
+                c = msg.get("content", "")
+                if isinstance(c, str):
+                    total_prompt_chars += len(c)
+                elif isinstance(c, list):
+                    for part in c:
+                        if isinstance(part, str):
+                            total_prompt_chars += len(part)
+                        elif isinstance(part, dict) and "text" in part:
+                            total_prompt_chars += len(part["text"])
+            completion_chars = len(response.text) if response.text else 0
+            prompt_tokens = total_prompt_chars // 4
+            completion_tokens = completion_chars // 4
+            total_tokens = prompt_tokens + completion_tokens
+            cost = self._calculate_cost_chars(total_prompt_chars, completion_chars)
+
+        return {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": round(cost, 6),
+        }
+
     supports_messages = True
 
     def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
@@ -113,23 +151,21 @@ class GoogleDriver(CostMixin, Driver):
 
         return _prepare_google_vision_messages(messages)
 
-    def
-        messages
-
-
-    def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
-        return self._do_generate(self._prepare_messages(messages), options)
+    def _build_generation_args(
+        self, messages: list[dict[str, Any]], options: Optional[dict[str, Any]] = None
+    ) -> tuple[Any, dict[str, Any]]:
+        """Parse messages and options into (gen_input, kwargs) for generate_content.
 
-
+        Returns the content input and a dict of keyword arguments
+        (generation_config, safety_settings, model kwargs including system_instruction).
+        """
         merged_options = self.options.copy()
         if options:
             merged_options.update(options)
 
-        # Extract specific options for Google's API
         generation_config = merged_options.get("generation_config", {})
         safety_settings = merged_options.get("safety_settings", {})
 
-        # Map common options to generation_config if not present
         if "temperature" in merged_options and "temperature" not in generation_config:
             generation_config["temperature"] = merged_options["temperature"]
         if "max_tokens" in merged_options and "max_output_tokens" not in generation_config:
@@ -155,56 +191,57 @@ class GoogleDriver(CostMixin, Driver):
             if role == "system":
                 system_instruction = content if isinstance(content, str) else str(content)
             else:
-                # Gemini uses "model" for assistant role
                 gemini_role = "model" if role == "assistant" else "user"
                 if msg.get("_vision_parts"):
-                    # Already converted to Gemini parts by _prepare_messages
                     contents.append({"role": gemini_role, "parts": content})
                 else:
                     contents.append({"role": gemini_role, "parts": [content]})
 
+        # For a single message, unwrap only if it has exactly one string part
+        if len(contents) == 1:
+            parts = contents[0]["parts"]
+            if len(parts) == 1 and isinstance(parts[0], str):
+                gen_input = parts[0]
+            else:
+                gen_input = contents
+        else:
+            gen_input = contents
+
+        model_kwargs: dict[str, Any] = {}
+        if system_instruction:
+            model_kwargs["system_instruction"] = system_instruction
+
+        gen_kwargs: dict[str, Any] = {
+            "generation_config": generation_config if generation_config else None,
+            "safety_settings": safety_settings if safety_settings else None,
+        }
+
+        return gen_input, gen_kwargs, model_kwargs
+
+    def generate(self, prompt: str, options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
+        messages = [{"role": "user", "content": prompt}]
+        return self._do_generate(messages, options)
+
+    def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        return self._do_generate(self._prepare_messages(messages), options)
+
+    def _do_generate(self, messages: list[dict[str, str]], options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
+        gen_input, gen_kwargs, model_kwargs = self._build_generation_args(messages, options)
+
         try:
             logger.debug(f"Initializing {self.model} for generation")
-            model_kwargs: dict[str, Any] = {}
-            if system_instruction:
-                model_kwargs["system_instruction"] = system_instruction
             model = genai.GenerativeModel(self.model, **model_kwargs)
 
-
-
-            # If single user message, pass content directly for backward compatibility
-            gen_input: Any = contents if len(contents) != 1 else contents[0]["parts"][0]
-            response = model.generate_content(
-                gen_input,
-                generation_config=generation_config if generation_config else None,
-                safety_settings=safety_settings if safety_settings else None,
-            )
+            logger.debug(f"Generating with model {self.model}")
+            response = model.generate_content(gen_input, **gen_kwargs)
 
             if not response.text:
                 raise ValueError("Empty response from model")
 
-
-            total_prompt_chars = 0
-            for msg in messages:
-                c = msg.get("content", "")
-                if isinstance(c, str):
-                    total_prompt_chars += len(c)
-                elif isinstance(c, list):
-                    for part in c:
-                        if isinstance(part, str):
-                            total_prompt_chars += len(part)
-                        elif isinstance(part, dict) and "text" in part:
-                            total_prompt_chars += len(part["text"])
-            completion_chars = len(response.text)
-
-            # Google uses character-based cost estimation
-            total_cost = self._calculate_cost_chars(total_prompt_chars, completion_chars)
+            usage_meta = self._extract_usage_metadata(response, messages)
 
             meta = {
-
-                "completion_chars": completion_chars,
-                "total_chars": total_prompt_chars + completion_chars,
-                "cost": total_cost,
+                **usage_meta,
                 "raw_response": response.prompt_feedback if hasattr(response, "prompt_feedback") else None,
                 "model_name": self.model,
             }
@@ -214,3 +251,130 @@ class GoogleDriver(CostMixin, Driver):
         except Exception as e:
             logger.error(f"Google API request failed: {e}")
             raise RuntimeError(f"Google API request failed: {e}") from e
+
+    # ------------------------------------------------------------------
+    # Tool use
+    # ------------------------------------------------------------------
+
+    def generate_messages_with_tools(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> dict[str, Any]:
+        """Generate a response that may include tool/function calls."""
+        gen_input, gen_kwargs, model_kwargs = self._build_generation_args(
+            self._prepare_messages(messages), options
+        )
+
+        # Convert tools from OpenAI format to Gemini function declarations
+        function_declarations = []
+        for t in tools:
+            if "type" in t and t["type"] == "function":
+                fn = t["function"]
+                decl = {
+                    "name": fn["name"],
+                    "description": fn.get("description", ""),
+                }
+                params = fn.get("parameters")
+                if params:
+                    decl["parameters"] = params
+                function_declarations.append(decl)
+            elif "name" in t:
+                # Already in a generic format
+                decl = {"name": t["name"], "description": t.get("description", "")}
+                params = t.get("parameters") or t.get("input_schema")
+                if params:
+                    decl["parameters"] = params
+                function_declarations.append(decl)
+
+        try:
+            model = genai.GenerativeModel(self.model, **model_kwargs)
+
+            gemini_tools = [genai.types.Tool(function_declarations=function_declarations)]
+            response = model.generate_content(gen_input, tools=gemini_tools, **gen_kwargs)
+
+            usage_meta = self._extract_usage_metadata(response, messages)
+            meta = {
+                **usage_meta,
+                "raw_response": response.prompt_feedback if hasattr(response, "prompt_feedback") else None,
+                "model_name": self.model,
+            }
+
+            text = ""
+            tool_calls_out: list[dict[str, Any]] = []
+            stop_reason = "stop"
+
+            for candidate in response.candidates:
+                for part in candidate.content.parts:
+                    if hasattr(part, "text") and part.text:
+                        text += part.text
+                    if hasattr(part, "function_call") and part.function_call.name:
+                        fc = part.function_call
+                        tool_calls_out.append({
+                            "id": str(uuid.uuid4()),
+                            "name": fc.name,
+                            "arguments": dict(fc.args) if fc.args else {},
+                        })
+
+                finish_reason = getattr(candidate, "finish_reason", None)
+                if finish_reason is not None:
+                    # Map Gemini finish reasons to standard stop reasons
+                    reason_map = {1: "stop", 2: "max_tokens", 3: "safety", 4: "recitation", 5: "other"}
+                    stop_reason = reason_map.get(finish_reason, "stop")
+
+            if tool_calls_out:
+                stop_reason = "tool_use"
+
+            return {
+                "text": text,
+                "meta": meta,
+                "tool_calls": tool_calls_out,
+                "stop_reason": stop_reason,
+            }
+
+        except Exception as e:
+            logger.error(f"Google API tool call request failed: {e}")
+            raise RuntimeError(f"Google API tool call request failed: {e}") from e
+
+    # ------------------------------------------------------------------
+    # Streaming
+    # ------------------------------------------------------------------
+
+    def generate_messages_stream(
+        self,
+        messages: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> Iterator[dict[str, Any]]:
+        """Yield response chunks via Gemini streaming API."""
+        gen_input, gen_kwargs, model_kwargs = self._build_generation_args(
+            self._prepare_messages(messages), options
+        )
+
+        try:
+            model = genai.GenerativeModel(self.model, **model_kwargs)
+            response = model.generate_content(gen_input, stream=True, **gen_kwargs)
+
+            full_text = ""
+            for chunk in response:
+                chunk_text = getattr(chunk, "text", None) or ""
+                if chunk_text:
+                    full_text += chunk_text
+                    yield {"type": "delta", "text": chunk_text}
+
+            # After iteration completes, resolve() has been called on the response
+            usage_meta = self._extract_usage_metadata(response, messages)
+
+            yield {
+                "type": "done",
+                "text": full_text,
+                "meta": {
+                    **usage_meta,
+                    "raw_response": {},
+                    "model_name": self.model,
+                },
+            }
+
+        except Exception as e:
+            logger.error(f"Google API streaming request failed: {e}")
+            raise RuntimeError(f"Google API streaming request failed: {e}") from e
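The same tool-use and streaming surface is now available on the synchronous GoogleDriver. A hedged sketch of the tool-calling path follows; the tool definition is a made-up example, the constructor arguments are assumed to mirror the async driver, and executing the returned tool calls is left to the caller.

from prompture.drivers.google_driver import GoogleDriver

driver = GoogleDriver(model="gemini-1.5-pro")

# OpenAI-style tool definition; the driver converts it into a Gemini function declaration
tools = [{
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Look up the current weather for a city.",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}]

result = driver.generate_messages_with_tools(
    [{"role": "user", "content": "What's the weather in Paris?"}],
    tools,
    {},
)

# stop_reason is "tool_use" whenever the model requested at least one function call
if result["stop_reason"] == "tool_use":
    for call in result["tool_calls"]:
        print(call["name"], call["arguments"])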