prompture 0.0.35__py3-none-any.whl → 0.0.40.dev1__py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. prompture/__init__.py +132 -3
  2. prompture/_version.py +2 -2
  3. prompture/agent.py +924 -0
  4. prompture/agent_types.py +156 -0
  5. prompture/async_agent.py +880 -0
  6. prompture/async_conversation.py +208 -17
  7. prompture/async_core.py +16 -0
  8. prompture/async_driver.py +63 -0
  9. prompture/async_groups.py +551 -0
  10. prompture/conversation.py +222 -18
  11. prompture/core.py +46 -12
  12. prompture/cost_mixin.py +37 -0
  13. prompture/discovery.py +132 -44
  14. prompture/driver.py +77 -0
  15. prompture/drivers/__init__.py +5 -1
  16. prompture/drivers/async_azure_driver.py +11 -5
  17. prompture/drivers/async_claude_driver.py +184 -9
  18. prompture/drivers/async_google_driver.py +222 -28
  19. prompture/drivers/async_grok_driver.py +11 -5
  20. prompture/drivers/async_groq_driver.py +11 -5
  21. prompture/drivers/async_lmstudio_driver.py +74 -5
  22. prompture/drivers/async_ollama_driver.py +13 -3
  23. prompture/drivers/async_openai_driver.py +162 -5
  24. prompture/drivers/async_openrouter_driver.py +11 -5
  25. prompture/drivers/async_registry.py +5 -1
  26. prompture/drivers/azure_driver.py +10 -4
  27. prompture/drivers/claude_driver.py +17 -1
  28. prompture/drivers/google_driver.py +227 -33
  29. prompture/drivers/grok_driver.py +11 -5
  30. prompture/drivers/groq_driver.py +11 -5
  31. prompture/drivers/lmstudio_driver.py +73 -8
  32. prompture/drivers/ollama_driver.py +16 -5
  33. prompture/drivers/openai_driver.py +26 -11
  34. prompture/drivers/openrouter_driver.py +11 -5
  35. prompture/drivers/vision_helpers.py +153 -0
  36. prompture/group_types.py +147 -0
  37. prompture/groups.py +530 -0
  38. prompture/image.py +180 -0
  39. prompture/ledger.py +252 -0
  40. prompture/model_rates.py +112 -2
  41. prompture/persistence.py +254 -0
  42. prompture/persona.py +482 -0
  43. prompture/serialization.py +218 -0
  44. prompture/settings.py +1 -0
  45. prompture-0.0.40.dev1.dist-info/METADATA +369 -0
  46. prompture-0.0.40.dev1.dist-info/RECORD +78 -0
  47. prompture-0.0.35.dist-info/METADATA +0 -464
  48. prompture-0.0.35.dist-info/RECORD +0 -66
  49. {prompture-0.0.35.dist-info → prompture-0.0.40.dev1.dist-info}/WHEEL +0 -0
  50. {prompture-0.0.35.dist-info → prompture-0.0.40.dev1.dist-info}/entry_points.txt +0 -0
  51. {prompture-0.0.35.dist-info → prompture-0.0.40.dev1.dist-info}/licenses/LICENSE +0 -0
  52. {prompture-0.0.35.dist-info → prompture-0.0.40.dev1.dist-info}/top_level.txt +0 -0

prompture/drivers/async_openai_driver.py
@@ -2,7 +2,9 @@
 
 from __future__ import annotations
 
+import json
 import os
+from collections.abc import AsyncIterator
 from typing import Any
 
 try:
@@ -18,6 +20,9 @@ from .openai_driver import OpenAIDriver
 class AsyncOpenAIDriver(CostMixin, AsyncDriver):
     supports_json_mode = True
     supports_json_schema = True
+    supports_tool_use = True
+    supports_streaming = True
+    supports_vision = True
 
     MODEL_PRICING = OpenAIDriver.MODEL_PRICING
 
@@ -31,12 +36,17 @@
 
     supports_messages = True
 
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_openai_vision_messages
+
+        return _prepare_openai_vision_messages(messages)
+
     async def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
         messages = [{"role": "user", "content": prompt}]
         return await self._do_generate(messages, options)
 
     async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
-        return await self._do_generate(messages, options)
+        return await self._do_generate(self._prepare_messages(messages), options)
 
     async def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
         if self.client is None:
@@ -44,9 +54,16 @@
 
         model = options.get("model", self.model)
 
-        model_info = self.MODEL_PRICING.get(model, {})
-        tokens_param = model_info.get("tokens_param", "max_tokens")
-        supports_temperature = model_info.get("supports_temperature", True)
+        model_config = self._get_model_config("openai", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
+
+        # Validate capabilities against models.dev metadata
+        self._validate_model_capabilities(
+            "openai",
+            model,
+            using_json_schema=bool(options.get("json_schema")),
+        )
 
         opts = {"temperature": 1.0, "max_tokens": 512, **options}
 
@@ -87,10 +104,150 @@
             "prompt_tokens": prompt_tokens,
             "completion_tokens": completion_tokens,
             "total_tokens": total_tokens,
-            "cost": total_cost,
+            "cost": round(total_cost, 6),
             "raw_response": resp.model_dump(),
             "model_name": model,
         }
 
         text = resp.choices[0].message.content
         return {"text": text, "meta": meta}
+
+    # ------------------------------------------------------------------
+    # Tool use
+    # ------------------------------------------------------------------
+
+    async def generate_messages_with_tools(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> dict[str, Any]:
+        """Generate a response that may include tool calls."""
+        if self.client is None:
+            raise RuntimeError("openai package (>=1.0.0) is not installed")
+
+        model = options.get("model", self.model)
+        model_config = self._get_model_config("openai", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
+
+        self._validate_model_capabilities("openai", model, using_tool_use=True)
+
+        opts = {"temperature": 1.0, "max_tokens": 512, **options}
+
+        kwargs: dict[str, Any] = {
+            "model": model,
+            "messages": messages,
+            "tools": tools,
+        }
+        kwargs[tokens_param] = opts.get("max_tokens", 512)
+
+        if supports_temperature and "temperature" in opts:
+            kwargs["temperature"] = opts["temperature"]
+
+        resp = await self.client.chat.completions.create(**kwargs)
+
+        usage = getattr(resp, "usage", None)
+        prompt_tokens = getattr(usage, "prompt_tokens", 0)
+        completion_tokens = getattr(usage, "completion_tokens", 0)
+        total_tokens = getattr(usage, "total_tokens", 0)
+        total_cost = self._calculate_cost("openai", model, prompt_tokens, completion_tokens)
+
+        meta = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": round(total_cost, 6),
+            "raw_response": resp.model_dump(),
+            "model_name": model,
+        }
+
+        choice = resp.choices[0]
+        text = choice.message.content or ""
+        stop_reason = choice.finish_reason
+
+        tool_calls_out: list[dict[str, Any]] = []
+        if choice.message.tool_calls:
+            for tc in choice.message.tool_calls:
+                try:
+                    args = json.loads(tc.function.arguments)
+                except (json.JSONDecodeError, TypeError):
+                    args = {}
+                tool_calls_out.append({
+                    "id": tc.id,
+                    "name": tc.function.name,
+                    "arguments": args,
+                })
+
+        return {
+            "text": text,
+            "meta": meta,
+            "tool_calls": tool_calls_out,
+            "stop_reason": stop_reason,
+        }
+
+    # ------------------------------------------------------------------
+    # Streaming
+    # ------------------------------------------------------------------
+
+    async def generate_messages_stream(
+        self,
+        messages: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> AsyncIterator[dict[str, Any]]:
+        """Yield response chunks via OpenAI streaming API."""
+        if self.client is None:
+            raise RuntimeError("openai package (>=1.0.0) is not installed")
+
+        model = options.get("model", self.model)
+        model_config = self._get_model_config("openai", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
+
+        opts = {"temperature": 1.0, "max_tokens": 512, **options}
+
+        kwargs: dict[str, Any] = {
+            "model": model,
+            "messages": messages,
+            "stream": True,
+            "stream_options": {"include_usage": True},
+        }
+        kwargs[tokens_param] = opts.get("max_tokens", 512)
+
+        if supports_temperature and "temperature" in opts:
+            kwargs["temperature"] = opts["temperature"]
+
+        stream = await self.client.chat.completions.create(**kwargs)
+
+        full_text = ""
+        prompt_tokens = 0
+        completion_tokens = 0
+
+        async for chunk in stream:
+            # Usage comes in the final chunk
+            if getattr(chunk, "usage", None):
+                prompt_tokens = chunk.usage.prompt_tokens or 0
+                completion_tokens = chunk.usage.completion_tokens or 0
+
+            if chunk.choices:
+                delta = chunk.choices[0].delta
+                content = getattr(delta, "content", None) or ""
+                if content:
+                    full_text += content
+                    yield {"type": "delta", "text": content}
+
+        total_tokens = prompt_tokens + completion_tokens
+        total_cost = self._calculate_cost("openai", model, prompt_tokens, completion_tokens)
+
+        yield {
+            "type": "done",
+            "text": full_text,
+            "meta": {
+                "prompt_tokens": prompt_tokens,
+                "completion_tokens": completion_tokens,
+                "total_tokens": total_tokens,
+                "cost": round(total_cost, 6),
+                "raw_response": {},
+                "model_name": model,
+            },
+        }
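
Both methods added above return plain dictionaries, so the consuming code needs no SDK types: streaming yields {"type": "delta", "text": ...} chunks followed by one {"type": "done", ...} chunk carrying the meta dict, and tool calls arrive as {"id", "name", "arguments"} entries. A minimal consumption sketch follows; the AsyncOpenAIDriver constructor arguments and the example tool schema are assumptions, only the returned shapes come from the code above.

    # Sketch only: the constructor signature below is assumed, not taken from the diff.
    import asyncio
    from prompture.drivers.async_openai_driver import AsyncOpenAIDriver

    async def main() -> None:
        driver = AsyncOpenAIDriver(api_key="sk-...", model="gpt-4o-mini")  # assumed args
        messages = [{"role": "user", "content": "Say hello"}]

        # Streaming: print deltas as they arrive, then read usage from the final chunk.
        async for chunk in driver.generate_messages_stream(messages, {"max_tokens": 64}):
            if chunk["type"] == "delta":
                print(chunk["text"], end="", flush=True)
            else:  # "done"
                print("\ncost:", chunk["meta"]["cost"])

        # Tool use: arguments are already parsed from JSON into a dict.
        tools = [{
            "type": "function",
            "function": {"name": "get_time", "description": "Current UTC time",
                         "parameters": {"type": "object", "properties": {}}},
        }]
        result = await driver.generate_messages_with_tools(messages, tools, {})
        for call in result["tool_calls"]:
            print(call["id"], call["name"], call["arguments"])

    asyncio.run(main())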

prompture/drivers/async_openrouter_driver.py
@@ -14,6 +14,7 @@ from .openrouter_driver import OpenRouterDriver
 
 class AsyncOpenRouterDriver(CostMixin, AsyncDriver):
     supports_json_mode = True
+    supports_vision = True
 
     MODEL_PRICING = OpenRouterDriver.MODEL_PRICING
 
@@ -31,19 +32,24 @@
 
     supports_messages = True
 
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_openai_vision_messages
+
+        return _prepare_openai_vision_messages(messages)
+
     async def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
         messages = [{"role": "user", "content": prompt}]
         return await self._do_generate(messages, options)
 
     async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
-        return await self._do_generate(messages, options)
+        return await self._do_generate(self._prepare_messages(messages), options)
 
     async def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
         model = options.get("model", self.model)
 
-        model_info = self.MODEL_PRICING.get(model, {})
-        tokens_param = model_info.get("tokens_param", "max_tokens")
-        supports_temperature = model_info.get("supports_temperature", True)
+        model_config = self._get_model_config("openrouter", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
 
         opts = {"temperature": 1.0, "max_tokens": 512, **options}
 
@@ -87,7 +93,7 @@
             "prompt_tokens": prompt_tokens,
             "completion_tokens": completion_tokens,
             "total_tokens": total_tokens,
-            "cost": total_cost,
+            "cost": round(total_cost, 6),
             "raw_response": resp,
             "model_name": model,
         }
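
The only change in this last hunk is the cost rounding, which together with the "per 1K tokens" pricing comments elsewhere in the diff implies the usual linear formula. A worked example of that arithmetic (the rates are placeholders, not values from MODEL_PRICING):

    # Illustrative only: assumed per-1K-token formula with the round(..., 6) used above.
    prompt_tokens, completion_tokens = 1_234, 567
    prompt_rate_per_1k, completion_rate_per_1k = 0.005, 0.015  # placeholder USD prices

    total_cost = (prompt_tokens / 1000) * prompt_rate_per_1k \
        + (completion_tokens / 1000) * completion_rate_per_1k
    print(round(total_cost, 6))  # 0.014675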

prompture/drivers/async_registry.py
@@ -49,7 +49,11 @@ register_async_driver(
 )
 register_async_driver(
     "lmstudio",
-    lambda model=None: AsyncLMStudioDriver(endpoint=settings.lmstudio_endpoint, model=model or settings.lmstudio_model),
+    lambda model=None: AsyncLMStudioDriver(
+        endpoint=settings.lmstudio_endpoint,
+        model=model or settings.lmstudio_model,
+        api_key=settings.lmstudio_api_key,
+    ),
     overwrite=True,
 )
 register_async_driver(
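
Because the factory is registered with overwrite=True, the same call can be repeated by applications that need a non-default LM Studio endpoint. A sketch of such a re-registration: the import paths follow the file layout listed above and the constructor keywords mirror the lambda in this hunk, but treating this as a supported extension point is an assumption.

    # Sketch: point the async LM Studio driver at a custom endpoint and key.
    from prompture.drivers.async_lmstudio_driver import AsyncLMStudioDriver
    from prompture.drivers.async_registry import register_async_driver

    register_async_driver(
        "lmstudio",
        lambda model=None: AsyncLMStudioDriver(
            endpoint="http://192.168.1.20:1234/v1",   # local LM Studio server
            model=model or "qwen2.5-7b-instruct",     # placeholder model name
            api_key="lm-studio",                      # forwarded as of this change
        ),
        overwrite=True,
    )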

prompture/drivers/azure_driver.py
@@ -17,6 +17,7 @@ from ..driver import Driver
 class AzureDriver(CostMixin, Driver):
     supports_json_mode = True
     supports_json_schema = True
+    supports_vision = True
 
     # Pricing per 1K tokens (adjust if your Azure pricing differs from OpenAI defaults)
     MODEL_PRICING = {
@@ -90,21 +91,26 @@
 
     supports_messages = True
 
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_openai_vision_messages
+
+        return _prepare_openai_vision_messages(messages)
+
     def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
         messages = [{"role": "user", "content": prompt}]
         return self._do_generate(messages, options)
 
     def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
-        return self._do_generate(messages, options)
+        return self._do_generate(self._prepare_messages(messages), options)
 
     def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
         if self.client is None:
             raise RuntimeError("openai package (>=1.0.0) with AzureOpenAI not installed")
 
         model = options.get("model", self.model)
-        model_info = self.MODEL_PRICING.get(model, {})
-        tokens_param = model_info.get("tokens_param", "max_tokens")
-        supports_temperature = model_info.get("supports_temperature", True)
+        model_config = self._get_model_config("azure", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
 
         opts = {"temperature": 1.0, "max_tokens": 512, **options}
 

prompture/drivers/claude_driver.py
@@ -21,6 +21,7 @@ class ClaudeDriver(CostMixin, Driver):
     supports_json_schema = True
     supports_tool_use = True
     supports_streaming = True
+    supports_vision = True
 
     # Claude pricing per 1000 tokens (prices should be kept current with Anthropic's pricing)
     MODEL_PRICING = {
@@ -57,12 +58,17 @@
 
     supports_messages = True
 
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_claude_vision_messages
+
+        return _prepare_claude_vision_messages(messages)
+
     def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
         messages = [{"role": "user", "content": prompt}]
         return self._do_generate(messages, options)
 
     def generate_messages(self, messages: list[dict[str, Any]], options: dict[str, Any]) -> dict[str, Any]:
-        return self._do_generate(messages, options)
+        return self._do_generate(self._prepare_messages(messages), options)
 
     def _do_generate(self, messages: list[dict[str, Any]], options: dict[str, Any]) -> dict[str, Any]:
         if anthropic is None:
@@ -71,6 +77,13 @@
         opts = {**{"temperature": 0.0, "max_tokens": 512}, **options}
         model = options.get("model", self.model)
 
+        # Validate capabilities against models.dev metadata
+        self._validate_model_capabilities(
+            "claude",
+            model,
+            using_json_schema=bool(options.get("json_schema")),
+        )
+
         client = anthropic.Anthropic(api_key=self.api_key)
 
         # Anthropic requires system messages as a top-level parameter
@@ -171,6 +184,9 @@
 
         opts = {**{"temperature": 0.0, "max_tokens": 512}, **options}
         model = options.get("model", self.model)
+
+        self._validate_model_capabilities("claude", model, using_tool_use=True)
+
         client = anthropic.Anthropic(api_key=self.api_key)
 
         system_content, api_messages = self._extract_system_and_messages(messages)
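
The _validate_model_capabilities calls added to the Claude driver (and to the OpenAI and Google drivers in this diff) all follow one pattern: pass the provider name, the resolved model, and flags for the features the request is about to use. The method itself is not shown in this diff; a purely hypothetical sketch of what such a check might look like, using only names that do appear above (the body and the exception type are invented for illustration):

    # Hypothetical sketch -- not the package's implementation.
    from typing import Any

    def _validate_model_capabilities(
        self,
        provider: str,
        model: str,
        using_json_schema: bool = False,
        using_tool_use: bool = False,
    ) -> None:
        caps: dict[str, Any] = self._get_model_config(provider, model)  # assumed to expose capability flags
        if using_json_schema and not caps.get("supports_json_schema", True):
            raise ValueError(f"{provider}/{model} does not advertise JSON-schema support")
        if using_tool_use and not caps.get("supports_tool_use", True):
            raise ValueError(f"{provider}/{model} does not advertise tool-use support")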

prompture/drivers/google_driver.py
@@ -1,5 +1,7 @@
 import logging
 import os
+import uuid
+from collections.abc import Iterator
 from typing import Any, Optional
 
 import google.generativeai as genai
@@ -15,6 +17,9 @@ class GoogleDriver(CostMixin, Driver):
 
     supports_json_mode = True
     supports_json_schema = True
+    supports_vision = True
+    supports_tool_use = True
+    supports_streaming = True
 
     # Based on current Gemini pricing (as of 2025)
     # Source: https://cloud.google.com/vertex-ai/pricing#gemini_models
@@ -105,25 +110,62 @@
         completion_cost = (completion_chars / 1_000_000) * model_pricing["completion"]
         return round(prompt_cost + completion_cost, 6)
 
+    def _extract_usage_metadata(self, response: Any, messages: list[dict[str, Any]]) -> dict[str, Any]:
+        """Extract token counts from response, falling back to character estimation."""
+        usage = getattr(response, "usage_metadata", None)
+        if usage:
+            prompt_tokens = getattr(usage, "prompt_token_count", 0) or 0
+            completion_tokens = getattr(usage, "candidates_token_count", 0) or 0
+            total_tokens = getattr(usage, "total_token_count", 0) or (prompt_tokens + completion_tokens)
+            cost = self._calculate_cost("google", self.model, prompt_tokens, completion_tokens)
+        else:
+            # Fallback: estimate from character counts
+            total_prompt_chars = 0
+            for msg in messages:
+                c = msg.get("content", "")
+                if isinstance(c, str):
+                    total_prompt_chars += len(c)
+                elif isinstance(c, list):
+                    for part in c:
+                        if isinstance(part, str):
+                            total_prompt_chars += len(part)
+                        elif isinstance(part, dict) and "text" in part:
+                            total_prompt_chars += len(part["text"])
+            completion_chars = len(response.text) if response.text else 0
+            prompt_tokens = total_prompt_chars // 4
+            completion_tokens = completion_chars // 4
+            total_tokens = prompt_tokens + completion_tokens
+            cost = self._calculate_cost_chars(total_prompt_chars, completion_chars)
+
+        return {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": round(cost, 6),
+        }
+
     supports_messages = True
 
-    def generate(self, prompt: str, options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
-        messages = [{"role": "user", "content": prompt}]
-        return self._do_generate(messages, options)
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_google_vision_messages
 
-    def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
-        return self._do_generate(messages, options)
+        return _prepare_google_vision_messages(messages)
 
-    def _do_generate(self, messages: list[dict[str, str]], options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
+    def _build_generation_args(
+        self, messages: list[dict[str, Any]], options: Optional[dict[str, Any]] = None
+    ) -> tuple[Any, dict[str, Any]]:
+        """Parse messages and options into (gen_input, kwargs) for generate_content.
+
+        Returns the content input and a dict of keyword arguments
+        (generation_config, safety_settings, model kwargs including system_instruction).
+        """
         merged_options = self.options.copy()
         if options:
             merged_options.update(options)
 
-        # Extract specific options for Google's API
         generation_config = merged_options.get("generation_config", {})
         safety_settings = merged_options.get("safety_settings", {})
 
-        # Map common options to generation_config if not present
         if "temperature" in merged_options and "temperature" not in generation_config:
             generation_config["temperature"] = merged_options["temperature"]
         if "max_tokens" in merged_options and "max_output_tokens" not in generation_config:
@@ -147,44 +189,66 @@
             role = msg.get("role", "user")
             content = msg.get("content", "")
             if role == "system":
-                system_instruction = content
+                system_instruction = content if isinstance(content, str) else str(content)
             else:
-                # Gemini uses "model" for assistant role
                 gemini_role = "model" if role == "assistant" else "user"
-                contents.append({"role": gemini_role, "parts": [content]})
+                if msg.get("_vision_parts"):
+                    contents.append({"role": gemini_role, "parts": content})
+                else:
+                    contents.append({"role": gemini_role, "parts": [content]})
+
+        # For a single message, unwrap only if it has exactly one string part
+        if len(contents) == 1:
+            parts = contents[0]["parts"]
+            if len(parts) == 1 and isinstance(parts[0], str):
+                gen_input = parts[0]
+            else:
+                gen_input = contents
+        else:
+            gen_input = contents
+
+        model_kwargs: dict[str, Any] = {}
+        if system_instruction:
+            model_kwargs["system_instruction"] = system_instruction
+
+        gen_kwargs: dict[str, Any] = {
+            "generation_config": generation_config if generation_config else None,
+            "safety_settings": safety_settings if safety_settings else None,
+        }
+
+        return gen_input, gen_kwargs, model_kwargs
+
+    def generate(self, prompt: str, options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
+        messages = [{"role": "user", "content": prompt}]
+        return self._do_generate(messages, options)
+
+    def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        return self._do_generate(self._prepare_messages(messages), options)
+
+    def _do_generate(self, messages: list[dict[str, str]], options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
+        gen_input, gen_kwargs, model_kwargs = self._build_generation_args(messages, options)
+
+        # Validate capabilities against models.dev metadata
+        self._validate_model_capabilities(
+            "google",
+            self.model,
+            using_json_schema=bool((options or {}).get("json_schema")),
+        )
 
         try:
             logger.debug(f"Initializing {self.model} for generation")
-            model_kwargs: dict[str, Any] = {}
-            if system_instruction:
-                model_kwargs["system_instruction"] = system_instruction
             model = genai.GenerativeModel(self.model, **model_kwargs)
 
-            # Generate response
-            logger.debug(f"Generating with {len(contents)} content parts")
-            # If single user message, pass content directly for backward compatibility
-            gen_input: Any = contents if len(contents) != 1 else contents[0]["parts"][0]
-            response = model.generate_content(
-                gen_input,
-                generation_config=generation_config if generation_config else None,
-                safety_settings=safety_settings if safety_settings else None,
-            )
+            logger.debug(f"Generating with model {self.model}")
+            response = model.generate_content(gen_input, **gen_kwargs)
 
             if not response.text:
                 raise ValueError("Empty response from model")
 
-            # Calculate token usage and cost
-            total_prompt_chars = sum(len(msg.get("content", "")) for msg in messages)
-            completion_chars = len(response.text)
-
-            # Google uses character-based cost estimation
-            total_cost = self._calculate_cost_chars(total_prompt_chars, completion_chars)
+            usage_meta = self._extract_usage_metadata(response, messages)
 
             meta = {
-                "prompt_chars": total_prompt_chars,
-                "completion_chars": completion_chars,
-                "total_chars": total_prompt_chars + completion_chars,
-                "cost": total_cost,
+                **usage_meta,
                 "raw_response": response.prompt_feedback if hasattr(response, "prompt_feedback") else None,
                 "model_name": self.model,
             }
@@ -194,3 +258,133 @@
         except Exception as e:
             logger.error(f"Google API request failed: {e}")
             raise RuntimeError(f"Google API request failed: {e}") from e
+
+    # ------------------------------------------------------------------
+    # Tool use
+    # ------------------------------------------------------------------
+
+    def generate_messages_with_tools(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> dict[str, Any]:
+        """Generate a response that may include tool/function calls."""
+        model = options.get("model", self.model)
+        self._validate_model_capabilities("google", model, using_tool_use=True)
+
+        gen_input, gen_kwargs, model_kwargs = self._build_generation_args(
+            self._prepare_messages(messages), options
+        )
+
+        # Convert tools from OpenAI format to Gemini function declarations
+        function_declarations = []
+        for t in tools:
+            if "type" in t and t["type"] == "function":
+                fn = t["function"]
+                decl = {
+                    "name": fn["name"],
+                    "description": fn.get("description", ""),
+                }
+                params = fn.get("parameters")
+                if params:
+                    decl["parameters"] = params
+                function_declarations.append(decl)
+            elif "name" in t:
+                # Already in a generic format
+                decl = {"name": t["name"], "description": t.get("description", "")}
+                params = t.get("parameters") or t.get("input_schema")
+                if params:
+                    decl["parameters"] = params
+                function_declarations.append(decl)
+
+        try:
+            model = genai.GenerativeModel(self.model, **model_kwargs)
+
+            gemini_tools = [genai.types.Tool(function_declarations=function_declarations)]
+            response = model.generate_content(gen_input, tools=gemini_tools, **gen_kwargs)
+
+            usage_meta = self._extract_usage_metadata(response, messages)
+            meta = {
+                **usage_meta,
+                "raw_response": response.prompt_feedback if hasattr(response, "prompt_feedback") else None,
+                "model_name": self.model,
+            }
+
+            text = ""
+            tool_calls_out: list[dict[str, Any]] = []
+            stop_reason = "stop"
+
+            for candidate in response.candidates:
+                for part in candidate.content.parts:
+                    if hasattr(part, "text") and part.text:
+                        text += part.text
+                    if hasattr(part, "function_call") and part.function_call.name:
+                        fc = part.function_call
+                        tool_calls_out.append({
+                            "id": str(uuid.uuid4()),
+                            "name": fc.name,
+                            "arguments": dict(fc.args) if fc.args else {},
+                        })
+
+                finish_reason = getattr(candidate, "finish_reason", None)
+                if finish_reason is not None:
+                    # Map Gemini finish reasons to standard stop reasons
+                    reason_map = {1: "stop", 2: "max_tokens", 3: "safety", 4: "recitation", 5: "other"}
+                    stop_reason = reason_map.get(finish_reason, "stop")
+
+            if tool_calls_out:
+                stop_reason = "tool_use"
+
+            return {
+                "text": text,
+                "meta": meta,
+                "tool_calls": tool_calls_out,
+                "stop_reason": stop_reason,
+            }
+
+        except Exception as e:
+            logger.error(f"Google API tool call request failed: {e}")
+            raise RuntimeError(f"Google API tool call request failed: {e}") from e
+
+    # ------------------------------------------------------------------
+    # Streaming
+    # ------------------------------------------------------------------
+
+    def generate_messages_stream(
+        self,
+        messages: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> Iterator[dict[str, Any]]:
+        """Yield response chunks via Gemini streaming API."""
+        gen_input, gen_kwargs, model_kwargs = self._build_generation_args(
+            self._prepare_messages(messages), options
+        )
+
+        try:
+            model = genai.GenerativeModel(self.model, **model_kwargs)
+            response = model.generate_content(gen_input, stream=True, **gen_kwargs)
+
+            full_text = ""
+            for chunk in response:
+                chunk_text = getattr(chunk, "text", None) or ""
+                if chunk_text:
+                    full_text += chunk_text
+                    yield {"type": "delta", "text": chunk_text}
+
+            # After iteration completes, resolve() has been called on the response
+            usage_meta = self._extract_usage_metadata(response, messages)
+
+            yield {
+                "type": "done",
+                "text": full_text,
+                "meta": {
+                    **usage_meta,
+                    "raw_response": {},
+                    "model_name": self.model,
+                },
+            }
+
+        except Exception as e:
+            logger.error(f"Google API streaming request failed: {e}")
+            raise RuntimeError(f"Google API streaming request failed: {e}") from e