prompture 0.0.38.dev1__tar.gz → 0.0.38.dev2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (130)
  1. {prompture-0.0.38.dev1/prompture.egg-info → prompture-0.0.38.dev2}/PKG-INFO +1 -1
  2. prompture-0.0.38.dev2/docs/source/_templates/footer.html +16 -0
  3. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/docs/source/conf.py +1 -1
  4. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/_version.py +2 -2
  5. prompture-0.0.38.dev2/prompture/drivers/async_google_driver.py +316 -0
  6. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/drivers/google_driver.py +207 -43
  7. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2/prompture.egg-info}/PKG-INFO +1 -1
  8. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture.egg-info/SOURCES.txt +1 -0
  9. prompture-0.0.38.dev1/prompture/drivers/async_google_driver.py +0 -152
  10. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/.claude/skills/add-driver/SKILL.md +0 -0
  11. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/.claude/skills/add-driver/references/driver-template.md +0 -0
  12. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/.claude/skills/add-example/SKILL.md +0 -0
  13. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/.claude/skills/add-field/SKILL.md +0 -0
  14. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/.claude/skills/add-test/SKILL.md +0 -0
  15. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/.claude/skills/run-tests/SKILL.md +0 -0
  16. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/.claude/skills/scaffold-extraction/SKILL.md +0 -0
  17. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/.claude/skills/update-pricing/SKILL.md +0 -0
  18. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/.env.copy +0 -0
  19. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/.github/FUNDING.yml +0 -0
  20. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/.github/scripts/update_docs_version.py +0 -0
  21. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/.github/scripts/update_wrapper_version.py +0 -0
  22. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/.github/workflows/dev.yml +0 -0
  23. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/.github/workflows/documentation.yml +0 -0
  24. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/.github/workflows/publish.yml +0 -0
  25. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/CLAUDE.md +0 -0
  26. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/LICENSE +0 -0
  27. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/MANIFEST.in +0 -0
  28. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/README.md +0 -0
  29. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/ROADMAP.md +0 -0
  30. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/docs/source/_static/custom.css +0 -0
  31. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/docs/source/api/core.rst +0 -0
  32. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/docs/source/api/drivers.rst +0 -0
  33. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/docs/source/api/field_definitions.rst +0 -0
  34. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/docs/source/api/index.rst +0 -0
  35. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/docs/source/api/runner.rst +0 -0
  36. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/docs/source/api/tools.rst +0 -0
  37. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/docs/source/api/validator.rst +0 -0
  38. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/docs/source/contributing.rst +0 -0
  39. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/docs/source/examples.rst +0 -0
  40. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/docs/source/field_definitions_reference.rst +0 -0
  41. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/docs/source/index.rst +0 -0
  42. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/docs/source/installation.rst +0 -0
  43. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/docs/source/quickstart.rst +0 -0
  44. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/docs/source/toon_input_guide.rst +0 -0
  45. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/packages/README.md +0 -0
  46. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/packages/llm_to_json/README.md +0 -0
  47. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/packages/llm_to_json/llm_to_json/__init__.py +0 -0
  48. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/packages/llm_to_json/pyproject.toml +0 -0
  49. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/packages/llm_to_json/test.py +0 -0
  50. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/packages/llm_to_toon/README.md +0 -0
  51. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/packages/llm_to_toon/llm_to_toon/__init__.py +0 -0
  52. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/packages/llm_to_toon/pyproject.toml +0 -0
  53. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/packages/llm_to_toon/test.py +0 -0
  54. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/__init__.py +0 -0
  55. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/agent.py +0 -0
  56. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/agent_types.py +0 -0
  57. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/aio/__init__.py +0 -0
  58. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/async_agent.py +0 -0
  59. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/async_conversation.py +0 -0
  60. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/async_core.py +0 -0
  61. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/async_driver.py +0 -0
  62. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/async_groups.py +0 -0
  63. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/cache.py +0 -0
  64. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/callbacks.py +0 -0
  65. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/cli.py +0 -0
  66. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/conversation.py +0 -0
  67. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/core.py +0 -0
  68. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/cost_mixin.py +0 -0
  69. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/discovery.py +0 -0
  70. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/driver.py +0 -0
  71. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/drivers/__init__.py +0 -0
  72. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/drivers/airllm_driver.py +0 -0
  73. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/drivers/async_airllm_driver.py +0 -0
  74. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/drivers/async_azure_driver.py +0 -0
  75. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/drivers/async_claude_driver.py +0 -0
  76. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/drivers/async_grok_driver.py +0 -0
  77. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/drivers/async_groq_driver.py +0 -0
  78. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/drivers/async_hugging_driver.py +0 -0
  79. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/drivers/async_lmstudio_driver.py +0 -0
  80. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/drivers/async_local_http_driver.py +0 -0
  81. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/drivers/async_ollama_driver.py +0 -0
  82. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/drivers/async_openai_driver.py +0 -0
  83. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/drivers/async_openrouter_driver.py +0 -0
  84. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/drivers/async_registry.py +0 -0
  85. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/drivers/azure_driver.py +0 -0
  86. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/drivers/claude_driver.py +0 -0
  87. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/drivers/grok_driver.py +0 -0
  88. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/drivers/groq_driver.py +0 -0
  89. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/drivers/hugging_driver.py +0 -0
  90. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/drivers/lmstudio_driver.py +0 -0
  91. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/drivers/local_http_driver.py +0 -0
  92. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/drivers/ollama_driver.py +0 -0
  93. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/drivers/openai_driver.py +0 -0
  94. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/drivers/openrouter_driver.py +0 -0
  95. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/drivers/registry.py +0 -0
  96. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/drivers/vision_helpers.py +0 -0
  97. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/field_definitions.py +0 -0
  98. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/group_types.py +0 -0
  99. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/groups.py +0 -0
  100. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/image.py +0 -0
  101. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/logging.py +0 -0
  102. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/model_rates.py +0 -0
  103. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/persistence.py +0 -0
  104. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/persona.py +0 -0
  105. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/runner.py +0 -0
  106. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/scaffold/__init__.py +0 -0
  107. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/scaffold/generator.py +0 -0
  108. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/scaffold/templates/Dockerfile.j2 +0 -0
  109. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/scaffold/templates/README.md.j2 +0 -0
  110. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/scaffold/templates/config.py.j2 +0 -0
  111. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/scaffold/templates/env.example.j2 +0 -0
  112. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/scaffold/templates/main.py.j2 +0 -0
  113. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/scaffold/templates/models.py.j2 +0 -0
  114. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/scaffold/templates/requirements.txt.j2 +0 -0
  115. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/serialization.py +0 -0
  116. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/server.py +0 -0
  117. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/session.py +0 -0
  118. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/settings.py +0 -0
  119. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/tools.py +0 -0
  120. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/tools_schema.py +0 -0
  121. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture/validator.py +0 -0
  122. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture.egg-info/dependency_links.txt +0 -0
  123. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture.egg-info/entry_points.txt +0 -0
  124. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture.egg-info/requires.txt +0 -0
  125. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/prompture.egg-info/top_level.txt +0 -0
  126. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/pyproject.toml +0 -0
  127. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/requirements.txt +0 -0
  128. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/setup.cfg +0 -0
  129. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/test.py +0 -0
  130. {prompture-0.0.38.dev1 → prompture-0.0.38.dev2}/test_version_diagnosis.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: prompture
3
- Version: 0.0.38.dev1
3
+ Version: 0.0.38.dev2
4
4
  Summary: Ask LLMs to return structured JSON and run cross-model tests. API-first.
5
5
  Author-email: Juan Denis <juan@vene.co>
6
6
  License-Expression: MIT
@@ -0,0 +1,16 @@
1
{%- extends "!footer.html" %}

{% block extrafooter %}
{# Turn the plain-text author name in the footer copyright into a link once the DOM is ready. #}
<script>
  document.addEventListener("DOMContentLoaded", () => {
    const copyrightNode = document.querySelector("footer .copyright");
    if (!copyrightNode) {
      return;
    }
    const authorLink = '<a href="https://juandenis.com">Juan Denis</a>';
    copyrightNode.innerHTML = copyrightNode.innerHTML.replace("Juan Denis", authorLink);
  });
</script>
{{ super() }}
{% endblock %}
@@ -14,7 +14,7 @@ sys.path.insert(0, os.path.abspath("../../"))
14
14
  # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
15
15
 
16
16
  project = "Prompture"
17
- copyright = '2026, <a href="https://juandenis.com">Juan Denis</a>'
17
+ copyright = '2026, Juan Denis'
18
18
  author = "Juan Denis"
19
19
 
20
20
  # Read version dynamically: VERSION file > setuptools_scm > fallback
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
28
28
  commit_id: COMMIT_ID
29
29
  __commit_id__: COMMIT_ID
30
30
 
31
- __version__ = version = '0.0.38.dev1'
32
- __version_tuple__ = version_tuple = (0, 0, 38, 'dev1')
31
+ __version__ = version = '0.0.38.dev2'
32
+ __version_tuple__ = version_tuple = (0, 0, 38, 'dev2')
33
33
 
34
34
  __commit_id__ = commit_id = None
@@ -0,0 +1,316 @@
1
+ """Async Google Generative AI (Gemini) driver."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ import os
7
+ import uuid
8
+ from collections.abc import AsyncIterator
9
+ from typing import Any
10
+
11
+ import google.generativeai as genai
12
+
13
+ from ..async_driver import AsyncDriver
14
+ from ..cost_mixin import CostMixin
15
+ from .google_driver import GoogleDriver
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
class AsyncGoogleDriver(CostMixin, AsyncDriver):
    """Async driver for Google's Generative AI API (Gemini).

    Provides plain generation, tool/function calling and streaming through
    ``GenerativeModel.generate_content_async``.  Every result carries a
    ``meta`` dict with token counts, an estimated cost and the model name.
    """

    supports_json_mode = True
    supports_json_schema = True
    supports_vision = True
    supports_tool_use = True
    supports_streaming = True
    supports_messages = True

    # Reuse the sync driver's pricing table so both drivers report the same costs.
    MODEL_PRICING = GoogleDriver.MODEL_PRICING
    _PRICING_UNIT = 1_000_000

    # Numeric Gemini finish reasons -> stop-reason strings returned by this driver.
    _FINISH_REASONS = {1: "stop", 2: "max_tokens", 3: "safety", 4: "recitation", 5: "other"}

    # Generic option name -> key inside Gemini's ``generation_config``.
    _OPTION_TO_CONFIG_KEY = (
        ("temperature", "temperature"),
        ("max_tokens", "max_output_tokens"),
        ("top_p", "top_p"),
        ("top_k", "top_k"),
    )

    def __init__(self, api_key: str | None = None, model: str = "gemini-1.5-pro"):
        """Configure the Gemini client.

        Args:
            api_key: API key; falls back to the ``GOOGLE_API_KEY`` environment variable.
            model: Name of the Gemini model used for every request.

        Raises:
            ValueError: If no API key is given and ``GOOGLE_API_KEY`` is unset or empty.
        """
        self.api_key = api_key or os.getenv("GOOGLE_API_KEY")
        if not self.api_key:
            raise ValueError("Google API key not found. Set GOOGLE_API_KEY env var or pass api_key to constructor")
        self.model = model
        genai.configure(api_key=self.api_key)
        self.options: dict[str, Any] = {}

    # ------------------------------------------------------------------
    # Usage / cost helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _response_text(response: Any) -> str:
        """Return ``response.text``, or ``""`` when no text can be read.

        The ``text`` accessor is a property and may raise instead of returning
        an empty value (e.g. for a response holding only a function call), so
        a ``getattr`` default is not enough to make the read best-effort.
        """
        try:
            return response.text or ""
        except (ValueError, AttributeError):
            return ""

    def _calculate_cost_chars(self, prompt_chars: int, completion_chars: int) -> float:
        """Estimate the cost of a call from character counts.

        When live rates are available, characters are converted to tokens at
        four characters per token and priced per million tokens.  Otherwise
        ``MODEL_PRICING`` is applied per million characters, with unknown
        models priced at zero.
        """
        from ..model_rates import get_model_rates

        live_rates = get_model_rates("google", self.model)
        if live_rates:
            est_prompt_tokens = prompt_chars / 4
            est_completion_tokens = completion_chars / 4
            prompt_cost = (est_prompt_tokens / 1_000_000) * live_rates["input"]
            completion_cost = (est_completion_tokens / 1_000_000) * live_rates["output"]
        else:
            model_pricing = self.MODEL_PRICING.get(self.model, {"prompt": 0, "completion": 0})
            prompt_cost = (prompt_chars / 1_000_000) * model_pricing["prompt"]
            completion_cost = (completion_chars / 1_000_000) * model_pricing["completion"]
        return round(prompt_cost + completion_cost, 6)

    def _extract_usage_metadata(self, response: Any, messages: list[dict[str, Any]]) -> dict[str, Any]:
        """Return token counts and cost for *response*.

        Uses ``response.usage_metadata`` when present; otherwise estimates the
        counts from the character lengths of *messages* and the response text
        (four characters per token).

        Returns:
            Dict with ``prompt_tokens``, ``completion_tokens``, ``total_tokens`` and ``cost``.
        """
        usage = getattr(response, "usage_metadata", None)
        if usage:
            prompt_tokens = getattr(usage, "prompt_token_count", 0) or 0
            completion_tokens = getattr(usage, "candidates_token_count", 0) or 0
            total_tokens = getattr(usage, "total_token_count", 0) or (prompt_tokens + completion_tokens)
            cost = self._calculate_cost("google", self.model, prompt_tokens, completion_tokens)
        else:
            total_prompt_chars = 0
            for msg in messages:
                content = msg.get("content", "")
                if isinstance(content, str):
                    total_prompt_chars += len(content)
                elif isinstance(content, list):
                    for part in content:
                        if isinstance(part, str):
                            total_prompt_chars += len(part)
                        elif isinstance(part, dict) and "text" in part:
                            total_prompt_chars += len(part["text"])
            # Guarded read: a response without a text part must not break cost accounting.
            completion_chars = len(self._response_text(response))
            prompt_tokens = total_prompt_chars // 4
            completion_tokens = completion_chars // 4
            total_tokens = prompt_tokens + completion_tokens
            cost = self._calculate_cost_chars(total_prompt_chars, completion_chars)

        return {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": total_tokens,
            "cost": round(cost, 6),
        }

    def _build_meta(self, response: Any, messages: list[dict[str, Any]]) -> dict[str, Any]:
        """Assemble the ``meta`` dict shared by the non-streaming entry points."""
        return {
            **self._extract_usage_metadata(response, messages),
            "raw_response": getattr(response, "prompt_feedback", None),
            "model_name": self.model,
        }

    # ------------------------------------------------------------------
    # Request construction
    # ------------------------------------------------------------------

    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Run *messages* through the Google vision helper before conversion."""
        from .vision_helpers import _prepare_google_vision_messages

        return _prepare_google_vision_messages(messages)

    def _build_generation_args(
        self, messages: list[dict[str, Any]], options: dict[str, Any] | None = None
    ) -> tuple[Any, dict[str, Any], dict[str, Any]]:
        """Parse messages and options into ``(gen_input, gen_kwargs, model_kwargs)``.

        Args:
            messages: Chat messages with ``role`` and ``content`` keys.  A message
                flagged with ``_vision_parts`` has its content used as the parts
                list directly.
            options: Per-call options merged over ``self.options``.

        Returns:
            ``gen_input`` for ``generate_content*`` (a bare string for a single
            one-string message, otherwise the Gemini contents list),
            ``gen_kwargs`` holding ``generation_config`` / ``safety_settings``
            (``None`` when empty), and ``model_kwargs`` for ``GenerativeModel``
            (``system_instruction`` when a system message was given).
        """
        merged_options = self.options.copy()
        if options:
            merged_options.update(options)

        generation_config = merged_options.get("generation_config") or {}
        if isinstance(generation_config, dict):
            # Work on a copy: the dict belongs to the caller (or to self.options)
            # and the values filled in below must not leak into later calls.
            generation_config = dict(generation_config)
        safety_settings = merged_options.get("safety_settings", {})

        for option_key, config_key in self._OPTION_TO_CONFIG_KEY:
            if option_key in merged_options and config_key not in generation_config:
                generation_config[config_key] = merged_options[option_key]

        # Native JSON mode support
        if merged_options.get("json_mode"):
            generation_config["response_mime_type"] = "application/json"
            json_schema = merged_options.get("json_schema")
            if json_schema:
                generation_config["response_schema"] = json_schema

        # Convert messages to Gemini format; when several system messages are
        # present the last one is used as the system instruction.
        system_instruction = None
        contents: list[dict[str, Any]] = []
        for msg in messages:
            role = msg.get("role", "user")
            content = msg.get("content", "")
            if role == "system":
                system_instruction = content if isinstance(content, str) else str(content)
                continue
            gemini_role = "model" if role == "assistant" else "user"
            parts = content if msg.get("_vision_parts") else [content]
            contents.append({"role": gemini_role, "parts": parts})

        # For a single message, unwrap only if it has exactly one string part.
        gen_input: Any = contents
        if len(contents) == 1:
            only_parts = contents[0]["parts"]
            if len(only_parts) == 1 and isinstance(only_parts[0], str):
                gen_input = only_parts[0]

        model_kwargs: dict[str, Any] = {}
        if system_instruction:
            model_kwargs["system_instruction"] = system_instruction

        gen_kwargs: dict[str, Any] = {
            "generation_config": generation_config or None,
            "safety_settings": safety_settings or None,
        }

        return gen_input, gen_kwargs, model_kwargs

    @staticmethod
    def _to_function_declarations(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Convert tool specs into Gemini function-declaration dicts.

        Accepts OpenAI-style entries (``{"type": "function", "function": {...}}``)
        and flat entries with a top-level ``name`` whose schema sits under
        ``parameters`` or ``input_schema``.  Entries matching neither shape are
        skipped.
        """
        declarations: list[dict[str, Any]] = []
        for tool in tools:
            if tool.get("type") == "function":
                spec = tool["function"]
                params = spec.get("parameters")
            elif "name" in tool:
                spec = tool
                params = tool.get("parameters") or tool.get("input_schema")
            else:
                continue
            decl: dict[str, Any] = {"name": spec["name"], "description": spec.get("description", "")}
            if params:
                decl["parameters"] = params
            declarations.append(decl)
        return declarations

    # ------------------------------------------------------------------
    # Generation
    # ------------------------------------------------------------------

    async def generate(self, prompt: str, options: dict[str, Any] | None = None) -> dict[str, Any]:
        """Generate a completion for a single user *prompt*."""
        messages = [{"role": "user", "content": prompt}]
        return await self._do_generate(messages, options)

    async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
        """Generate a completion for a chat-style list of *messages*."""
        return await self._do_generate(self._prepare_messages(messages), options)

    async def _do_generate(
        self, messages: list[dict[str, str]], options: dict[str, Any] | None = None
    ) -> dict[str, Any]:
        """Send already-prepared *messages* and return ``{"text", "meta"}``.

        Raises:
            RuntimeError: On any failure, including an empty model response;
                the original exception is chained as the cause.
        """
        gen_input, gen_kwargs, model_kwargs = self._build_generation_args(messages, options)

        try:
            model = genai.GenerativeModel(self.model, **model_kwargs)
            response = await model.generate_content_async(gen_input, **gen_kwargs)

            if not response.text:
                raise ValueError("Empty response from model")

            return {"text": response.text, "meta": self._build_meta(response, messages)}

        except Exception as e:
            logger.error("Google API request failed: %s", e)
            raise RuntimeError(f"Google API request failed: {e}") from e

    # ------------------------------------------------------------------
    # Tool use
    # ------------------------------------------------------------------

    async def generate_messages_with_tools(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]],
        options: dict[str, Any],
    ) -> dict[str, Any]:
        """Generate a response that may include tool/function calls (async).

        Returns:
            Dict with ``text`` (concatenated text parts), ``meta``,
            ``tool_calls`` (each with a generated ``id``, ``name`` and
            ``arguments``) and ``stop_reason`` (``"tool_use"`` whenever at
            least one call was returned).

        Raises:
            RuntimeError: On any failure; the original exception is chained.
        """
        gen_input, gen_kwargs, model_kwargs = self._build_generation_args(
            self._prepare_messages(messages), options
        )
        function_declarations = self._to_function_declarations(tools)

        try:
            model = genai.GenerativeModel(self.model, **model_kwargs)

            gemini_tools = [genai.types.Tool(function_declarations=function_declarations)]
            response = await model.generate_content_async(gen_input, tools=gemini_tools, **gen_kwargs)

            meta = self._build_meta(response, messages)

            text_parts: list[str] = []
            tool_calls_out: list[dict[str, Any]] = []
            stop_reason = "stop"

            for candidate in response.candidates:
                for part in candidate.content.parts:
                    part_text = getattr(part, "text", None)
                    if part_text:
                        text_parts.append(part_text)
                    fc = getattr(part, "function_call", None)
                    if fc is not None and fc.name:
                        tool_calls_out.append(
                            {
                                # The call carries no id of its own here, so one is generated.
                                "id": str(uuid.uuid4()),
                                "name": fc.name,
                                "arguments": dict(fc.args) if fc.args else {},
                            }
                        )

                finish_reason = getattr(candidate, "finish_reason", None)
                if finish_reason is not None:
                    stop_reason = self._FINISH_REASONS.get(finish_reason, "stop")

            if tool_calls_out:
                stop_reason = "tool_use"

            return {
                "text": "".join(text_parts),
                "meta": meta,
                "tool_calls": tool_calls_out,
                "stop_reason": stop_reason,
            }

        except Exception as e:
            logger.error("Google API tool call request failed: %s", e)
            raise RuntimeError(f"Google API tool call request failed: {e}") from e

    # ------------------------------------------------------------------
    # Streaming
    # ------------------------------------------------------------------

    async def generate_messages_stream(
        self,
        messages: list[dict[str, Any]],
        options: dict[str, Any],
    ) -> AsyncIterator[dict[str, Any]]:
        """Yield response chunks via Gemini async streaming API.

        Yields ``{"type": "delta", "text": ...}`` for every non-empty chunk,
        then one final ``{"type": "done", "text": <full text>, "meta": ...}``.

        Raises:
            RuntimeError: On any failure; the original exception is chained.
        """
        gen_input, gen_kwargs, model_kwargs = self._build_generation_args(
            self._prepare_messages(messages), options
        )

        try:
            model = genai.GenerativeModel(self.model, **model_kwargs)
            response = await model.generate_content_async(gen_input, stream=True, **gen_kwargs)

            pieces: list[str] = []
            async for chunk in response:
                # Chunks without readable text are skipped instead of aborting the stream.
                chunk_text = self._response_text(chunk)
                if chunk_text:
                    pieces.append(chunk_text)
                    yield {"type": "delta", "text": chunk_text}

            # Usage is read only after the stream has been fully consumed.
            usage_meta = self._extract_usage_metadata(response, messages)

            yield {
                "type": "done",
                "text": "".join(pieces),
                "meta": {
                    **usage_meta,
                    "raw_response": {},
                    "model_name": self.model,
                },
            }

        except Exception as e:
            logger.error("Google API streaming request failed: %s", e)
            raise RuntimeError(f"Google API streaming request failed: {e}") from e
@@ -1,5 +1,7 @@
1
1
  import logging
2
2
  import os
3
+ import uuid
4
+ from collections.abc import Iterator
3
5
  from typing import Any, Optional
4
6
 
5
7
  import google.generativeai as genai
@@ -16,6 +18,8 @@ class GoogleDriver(CostMixin, Driver):
16
18
  supports_json_mode = True
17
19
  supports_json_schema = True
18
20
  supports_vision = True
21
+ supports_tool_use = True
22
+ supports_streaming = True
19
23
 
20
24
  # Based on current Gemini pricing (as of 2025)
21
25
  # Source: https://cloud.google.com/vertex-ai/pricing#gemini_models
@@ -106,6 +110,40 @@ class GoogleDriver(CostMixin, Driver):
106
110
  completion_cost = (completion_chars / 1_000_000) * model_pricing["completion"]
107
111
  return round(prompt_cost + completion_cost, 6)
108
112
 
113
+ def _extract_usage_metadata(self, response: Any, messages: list[dict[str, Any]]) -> dict[str, Any]:
114
+ """Extract token counts from response, falling back to character estimation."""
115
+ usage = getattr(response, "usage_metadata", None)
116
+ if usage:
117
+ prompt_tokens = getattr(usage, "prompt_token_count", 0) or 0
118
+ completion_tokens = getattr(usage, "candidates_token_count", 0) or 0
119
+ total_tokens = getattr(usage, "total_token_count", 0) or (prompt_tokens + completion_tokens)
120
+ cost = self._calculate_cost("google", self.model, prompt_tokens, completion_tokens)
121
+ else:
122
+ # Fallback: estimate from character counts
123
+ total_prompt_chars = 0
124
+ for msg in messages:
125
+ c = msg.get("content", "")
126
+ if isinstance(c, str):
127
+ total_prompt_chars += len(c)
128
+ elif isinstance(c, list):
129
+ for part in c:
130
+ if isinstance(part, str):
131
+ total_prompt_chars += len(part)
132
+ elif isinstance(part, dict) and "text" in part:
133
+ total_prompt_chars += len(part["text"])
134
+ completion_chars = len(response.text) if response.text else 0
135
+ prompt_tokens = total_prompt_chars // 4
136
+ completion_tokens = completion_chars // 4
137
+ total_tokens = prompt_tokens + completion_tokens
138
+ cost = self._calculate_cost_chars(total_prompt_chars, completion_chars)
139
+
140
+ return {
141
+ "prompt_tokens": prompt_tokens,
142
+ "completion_tokens": completion_tokens,
143
+ "total_tokens": total_tokens,
144
+ "cost": round(cost, 6),
145
+ }
146
+
109
147
  supports_messages = True
110
148
 
111
149
  def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
@@ -113,23 +151,21 @@ class GoogleDriver(CostMixin, Driver):
113
151
 
114
152
  return _prepare_google_vision_messages(messages)
115
153
 
116
- def generate(self, prompt: str, options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
117
- messages = [{"role": "user", "content": prompt}]
118
- return self._do_generate(messages, options)
119
-
120
- def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
121
- return self._do_generate(self._prepare_messages(messages), options)
154
+ def _build_generation_args(
155
+ self, messages: list[dict[str, Any]], options: Optional[dict[str, Any]] = None
156
+ ) -> tuple[Any, dict[str, Any]]:
157
+ """Parse messages and options into (gen_input, kwargs) for generate_content.
122
158
 
123
- def _do_generate(self, messages: list[dict[str, str]], options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
159
+ Returns the content input and a dict of keyword arguments
160
+ (generation_config, safety_settings, model kwargs including system_instruction).
161
+ """
124
162
  merged_options = self.options.copy()
125
163
  if options:
126
164
  merged_options.update(options)
127
165
 
128
- # Extract specific options for Google's API
129
166
  generation_config = merged_options.get("generation_config", {})
130
167
  safety_settings = merged_options.get("safety_settings", {})
131
168
 
132
- # Map common options to generation_config if not present
133
169
  if "temperature" in merged_options and "temperature" not in generation_config:
134
170
  generation_config["temperature"] = merged_options["temperature"]
135
171
  if "max_tokens" in merged_options and "max_output_tokens" not in generation_config:
@@ -155,56 +191,57 @@ class GoogleDriver(CostMixin, Driver):
155
191
  if role == "system":
156
192
  system_instruction = content if isinstance(content, str) else str(content)
157
193
  else:
158
- # Gemini uses "model" for assistant role
159
194
  gemini_role = "model" if role == "assistant" else "user"
160
195
  if msg.get("_vision_parts"):
161
- # Already converted to Gemini parts by _prepare_messages
162
196
  contents.append({"role": gemini_role, "parts": content})
163
197
  else:
164
198
  contents.append({"role": gemini_role, "parts": [content]})
165
199
 
200
+ # For a single message, unwrap only if it has exactly one string part
201
+ if len(contents) == 1:
202
+ parts = contents[0]["parts"]
203
+ if len(parts) == 1 and isinstance(parts[0], str):
204
+ gen_input = parts[0]
205
+ else:
206
+ gen_input = contents
207
+ else:
208
+ gen_input = contents
209
+
210
+ model_kwargs: dict[str, Any] = {}
211
+ if system_instruction:
212
+ model_kwargs["system_instruction"] = system_instruction
213
+
214
+ gen_kwargs: dict[str, Any] = {
215
+ "generation_config": generation_config if generation_config else None,
216
+ "safety_settings": safety_settings if safety_settings else None,
217
+ }
218
+
219
+ return gen_input, gen_kwargs, model_kwargs
220
+
221
def generate(self, prompt: str, options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
    """Generate a completion for a single user prompt.

    The prompt is wrapped in a one-element message list and handed to
    ``_do_generate`` together with *options*.
    """
    return self._do_generate([{"role": "user", "content": prompt}], options)
224
+
225
def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
    """Generate a completion for a list of chat messages.

    The messages are first passed through ``_prepare_messages``; the result
    is then sent via ``_do_generate`` with *options*.
    """
    prepared = self._prepare_messages(messages)
    return self._do_generate(prepared, options)
227
+
228
+ def _do_generate(self, messages: list[dict[str, str]], options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
229
+ gen_input, gen_kwargs, model_kwargs = self._build_generation_args(messages, options)
230
+
166
231
  try:
167
232
  logger.debug(f"Initializing {self.model} for generation")
168
- model_kwargs: dict[str, Any] = {}
169
- if system_instruction:
170
- model_kwargs["system_instruction"] = system_instruction
171
233
  model = genai.GenerativeModel(self.model, **model_kwargs)
172
234
 
173
- # Generate response
174
- logger.debug(f"Generating with {len(contents)} content parts")
175
- # If single user message, pass content directly for backward compatibility
176
- gen_input: Any = contents if len(contents) != 1 else contents[0]["parts"][0]
177
- response = model.generate_content(
178
- gen_input,
179
- generation_config=generation_config if generation_config else None,
180
- safety_settings=safety_settings if safety_settings else None,
181
- )
235
+ logger.debug(f"Generating with model {self.model}")
236
+ response = model.generate_content(gen_input, **gen_kwargs)
182
237
 
183
238
  if not response.text:
184
239
  raise ValueError("Empty response from model")
185
240
 
186
- # Calculate token usage and cost
187
- total_prompt_chars = 0
188
- for msg in messages:
189
- c = msg.get("content", "")
190
- if isinstance(c, str):
191
- total_prompt_chars += len(c)
192
- elif isinstance(c, list):
193
- for part in c:
194
- if isinstance(part, str):
195
- total_prompt_chars += len(part)
196
- elif isinstance(part, dict) and "text" in part:
197
- total_prompt_chars += len(part["text"])
198
- completion_chars = len(response.text)
199
-
200
- # Google uses character-based cost estimation
201
- total_cost = self._calculate_cost_chars(total_prompt_chars, completion_chars)
241
+ usage_meta = self._extract_usage_metadata(response, messages)
202
242
 
203
243
  meta = {
204
- "prompt_chars": total_prompt_chars,
205
- "completion_chars": completion_chars,
206
- "total_chars": total_prompt_chars + completion_chars,
207
- "cost": total_cost,
244
+ **usage_meta,
208
245
  "raw_response": response.prompt_feedback if hasattr(response, "prompt_feedback") else None,
209
246
  "model_name": self.model,
210
247
  }
@@ -214,3 +251,130 @@ class GoogleDriver(CostMixin, Driver):
214
251
  except Exception as e:
215
252
  logger.error(f"Google API request failed: {e}")
216
253
  raise RuntimeError(f"Google API request failed: {e}") from e
254
+
255
+ # ------------------------------------------------------------------
256
+ # Tool use
257
+ # ------------------------------------------------------------------
258
+
259
def generate_messages_with_tools(
    self,
    messages: list[dict[str, Any]],
    tools: list[dict[str, Any]],
    options: dict[str, Any],
) -> dict[str, Any]:
    """Generate a response that may include tool/function calls.

    Args:
        messages: Conversation messages; they are run through
            ``self._prepare_messages`` before the request is built.
        tools: Tool definitions. Two shapes are recognised: the OpenAI
            shape (``{"type": "function", "function": {...}}``) and a flat
            shape carrying ``name`` plus optional ``description`` and
            ``parameters``/``input_schema``. Other entries are ignored.
        options: Generation options forwarded to
            ``self._build_generation_args``.

    Returns:
        A dict with the keys ``"text"`` (concatenated text parts),
        ``"meta"`` (usage metadata plus ``raw_response`` and
        ``model_name``), ``"tool_calls"`` (list of ``id``/``name``/
        ``arguments`` dicts) and ``"stop_reason"``.

    Raises:
        RuntimeError: If the request or the response handling fails; the
            original exception is chained as the cause.
    """
    from collections.abc import Mapping, Sequence

    def _to_native(value: Any) -> Any:
        """Recursively turn mapping/sequence containers into dict/list."""
        if isinstance(value, Mapping):
            return {key: _to_native(item) for key, item in value.items()}
        if isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)):
            return [_to_native(item) for item in value]
        return value

    gen_input, gen_kwargs, model_kwargs = self._build_generation_args(
        self._prepare_messages(messages), options
    )

    # Convert tools from OpenAI format to Gemini function declarations
    function_declarations: list[dict[str, Any]] = []
    for tool in tools:
        if tool.get("type") == "function":
            fn = tool["function"]
            decl = {
                "name": fn["name"],
                "description": fn.get("description", ""),
            }
            params = fn.get("parameters")
        elif "name" in tool:
            # Already in a generic format
            decl = {"name": tool["name"], "description": tool.get("description", "")}
            params = tool.get("parameters") or tool.get("input_schema")
        else:
            continue
        if params:
            decl["parameters"] = params
        function_declarations.append(decl)

    # Map Gemini finish reasons to standard stop reasons
    reason_map = {1: "stop", 2: "max_tokens", 3: "safety", 4: "recitation", 5: "other"}

    try:
        model = genai.GenerativeModel(self.model, **model_kwargs)

        call_kwargs = dict(gen_kwargs)
        if function_declarations:
            # A Tool without any declaration exposes nothing callable, so the
            # argument is only sent when at least one tool was recognised.
            call_kwargs["tools"] = [genai.types.Tool(function_declarations=function_declarations)]
        response = model.generate_content(gen_input, **call_kwargs)

        usage_meta = self._extract_usage_metadata(response, messages)
        meta = {
            **usage_meta,
            "raw_response": response.prompt_feedback if hasattr(response, "prompt_feedback") else None,
            "model_name": self.model,
        }

        text_parts: list[str] = []
        tool_calls_out: list[dict[str, Any]] = []
        stop_reason = "stop"

        for candidate in response.candidates:
            for part in candidate.content.parts:
                if hasattr(part, "text") and part.text:
                    text_parts.append(part.text)
                if hasattr(part, "function_call") and part.function_call.name:
                    fc = part.function_call
                    # dict() alone converts only the top level; nested
                    # objects/arrays must be converted too so callers get
                    # plain, JSON-serialisable data.
                    arguments = (
                        {key: _to_native(item) for key, item in dict(fc.args).items()}
                        if fc.args
                        else {}
                    )
                    tool_calls_out.append({
                        "id": str(uuid.uuid4()),
                        "name": fc.name,
                        "arguments": arguments,
                    })

            finish_reason = getattr(candidate, "finish_reason", None)
            if finish_reason is not None:
                stop_reason = reason_map.get(finish_reason, "stop")

        # Any requested function call takes precedence over the finish reason.
        if tool_calls_out:
            stop_reason = "tool_use"

        return {
            "text": "".join(text_parts),
            "meta": meta,
            "tool_calls": tool_calls_out,
            "stop_reason": stop_reason,
        }

    except Exception as e:
        logger.error(f"Google API tool call request failed: {e}")
        raise RuntimeError(f"Google API tool call request failed: {e}") from e
339
+
340
+ # ------------------------------------------------------------------
341
+ # Streaming
342
+ # ------------------------------------------------------------------
343
+
344
def generate_messages_stream(
    self,
    messages: list[dict[str, Any]],
    options: dict[str, Any],
) -> Iterator[dict[str, Any]]:
    """Yield response chunks via Gemini streaming API.

    Args:
        messages: Conversation messages; they are run through
            ``self._prepare_messages`` before the request is built.
        options: Generation options forwarded to
            ``self._build_generation_args``.

    Yields:
        ``{"type": "delta", "text": ...}`` for every chunk that carries
        text, followed by one ``{"type": "done", "text": ..., "meta": ...}``
        item holding the full text and the usage metadata.

    Raises:
        RuntimeError: If the request or the stream handling fails; the
            original exception is chained as the cause.
    """
    gen_input, gen_kwargs, model_kwargs = self._build_generation_args(
        self._prepare_messages(messages), options
    )

    try:
        model = genai.GenerativeModel(self.model, **model_kwargs)
        response = model.generate_content(gen_input, stream=True, **gen_kwargs)

        text_parts: list[str] = []
        for chunk in response:
            try:
                chunk_text = getattr(chunk, "text", None) or ""
            except ValueError:
                # NOTE(review): the SDK's ``text`` accessor raises ValueError
                # when a chunk has no text part (for example a trailing chunk
                # that only reports the finish reason); getattr's default
                # does not cover that, so treat it as an empty chunk.
                chunk_text = ""
            if chunk_text:
                text_parts.append(chunk_text)
                yield {"type": "delta", "text": chunk_text}

        # After iteration completes, resolve() has been called on the response
        usage_meta = self._extract_usage_metadata(response, messages)

        yield {
            "type": "done",
            "text": "".join(text_parts),
            "meta": {
                **usage_meta,
                "raw_response": {},
                "model_name": self.model,
            },
        }

    except Exception as e:
        logger.error(f"Google API streaming request failed: {e}")
        raise RuntimeError(f"Google API streaming request failed: {e}") from e