prompture 0.0.29.dev8__py3-none-any.whl → 0.0.35__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. prompture/__init__.py +146 -23
  2. prompture/_version.py +34 -0
  3. prompture/aio/__init__.py +74 -0
  4. prompture/async_conversation.py +607 -0
  5. prompture/async_core.py +803 -0
  6. prompture/async_driver.py +169 -0
  7. prompture/cache.py +469 -0
  8. prompture/callbacks.py +55 -0
  9. prompture/cli.py +63 -4
  10. prompture/conversation.py +631 -0
  11. prompture/core.py +876 -263
  12. prompture/cost_mixin.py +51 -0
  13. prompture/discovery.py +164 -0
  14. prompture/driver.py +168 -5
  15. prompture/drivers/__init__.py +173 -69
  16. prompture/drivers/airllm_driver.py +109 -0
  17. prompture/drivers/async_airllm_driver.py +26 -0
  18. prompture/drivers/async_azure_driver.py +117 -0
  19. prompture/drivers/async_claude_driver.py +107 -0
  20. prompture/drivers/async_google_driver.py +132 -0
  21. prompture/drivers/async_grok_driver.py +91 -0
  22. prompture/drivers/async_groq_driver.py +84 -0
  23. prompture/drivers/async_hugging_driver.py +61 -0
  24. prompture/drivers/async_lmstudio_driver.py +79 -0
  25. prompture/drivers/async_local_http_driver.py +44 -0
  26. prompture/drivers/async_ollama_driver.py +125 -0
  27. prompture/drivers/async_openai_driver.py +96 -0
  28. prompture/drivers/async_openrouter_driver.py +96 -0
  29. prompture/drivers/async_registry.py +129 -0
  30. prompture/drivers/azure_driver.py +36 -9
  31. prompture/drivers/claude_driver.py +251 -34
  32. prompture/drivers/google_driver.py +107 -38
  33. prompture/drivers/grok_driver.py +29 -32
  34. prompture/drivers/groq_driver.py +27 -26
  35. prompture/drivers/hugging_driver.py +6 -6
  36. prompture/drivers/lmstudio_driver.py +26 -13
  37. prompture/drivers/local_http_driver.py +6 -6
  38. prompture/drivers/ollama_driver.py +157 -23
  39. prompture/drivers/openai_driver.py +178 -9
  40. prompture/drivers/openrouter_driver.py +31 -25
  41. prompture/drivers/registry.py +306 -0
  42. prompture/field_definitions.py +106 -96
  43. prompture/logging.py +80 -0
  44. prompture/model_rates.py +217 -0
  45. prompture/runner.py +49 -47
  46. prompture/scaffold/__init__.py +1 -0
  47. prompture/scaffold/generator.py +84 -0
  48. prompture/scaffold/templates/Dockerfile.j2 +12 -0
  49. prompture/scaffold/templates/README.md.j2 +41 -0
  50. prompture/scaffold/templates/config.py.j2 +21 -0
  51. prompture/scaffold/templates/env.example.j2 +8 -0
  52. prompture/scaffold/templates/main.py.j2 +86 -0
  53. prompture/scaffold/templates/models.py.j2 +40 -0
  54. prompture/scaffold/templates/requirements.txt.j2 +5 -0
  55. prompture/server.py +183 -0
  56. prompture/session.py +117 -0
  57. prompture/settings.py +18 -1
  58. prompture/tools.py +219 -267
  59. prompture/tools_schema.py +254 -0
  60. prompture/validator.py +3 -3
  61. {prompture-0.0.29.dev8.dist-info → prompture-0.0.35.dist-info}/METADATA +117 -21
  62. prompture-0.0.35.dist-info/RECORD +66 -0
  63. {prompture-0.0.29.dev8.dist-info → prompture-0.0.35.dist-info}/WHEEL +1 -1
  64. prompture-0.0.29.dev8.dist-info/RECORD +0 -27
  65. {prompture-0.0.29.dev8.dist-info → prompture-0.0.35.dist-info}/entry_points.txt +0 -0
  66. {prompture-0.0.29.dev8.dist-info → prompture-0.0.35.dist-info}/licenses/LICENSE +0 -0
  67. {prompture-0.0.29.dev8.dist-info → prompture-0.0.35.dist-info}/top_level.txt +0 -0
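A recurring change across the driver diffs below is the extraction of per-request cost math into the new prompture/cost_mixin.py: each driver now calls self._calculate_cost(provider, model, prompt_tokens, completion_tokens) instead of repeating per-1K-token arithmetic inline. The mixin's internals are not shown in this diff; the following is a hypothetical sketch of the contract the drivers appear to rely on, assuming a live-rate lookup via prompture/model_rates.py (visible in google_driver.py below) with a fallback to each driver's hardcoded MODEL_PRICING table.

    # Hypothetical sketch only -- the real prompture/cost_mixin.py may differ.
    from prompture.model_rates import get_model_rates


    class CostMixin:
        # Per-1K-token rates, keyed by model name; each driver supplies its own.
        MODEL_PRICING: dict[str, dict[str, float]] = {}

        def _calculate_cost(
            self, provider: str, model: str, prompt_tokens: int, completion_tokens: int
        ) -> float:
            # Prefer live per-1M-token rates when available (as GoogleDriver does).
            live = get_model_rates(provider, model)
            if live:
                return (prompt_tokens / 1_000_000) * live["input"] + (
                    completion_tokens / 1_000_000
                ) * live["output"]
            # Fall back to the driver's hardcoded per-1K pricing table.
            pricing = self.MODEL_PRICING.get(model, {"prompt": 0, "completion": 0})
            return (prompt_tokens / 1000) * pricing["prompt"] + (
                completion_tokens / 1000
            ) * pricing["completion"]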
prompture/drivers/azure_driver.py

@@ -1,17 +1,23 @@
 """Driver for Azure OpenAI Service (migrated to openai>=1.0.0).
 Requires the `openai` package.
 """
+
 import os
-from typing import Any, Dict
+from typing import Any
+
 try:
     from openai import AzureOpenAI
 except Exception:
     AzureOpenAI = None
 
+from ..cost_mixin import CostMixin
 from ..driver import Driver
 
 
-class AzureDriver(Driver):
+class AzureDriver(CostMixin, Driver):
+    supports_json_mode = True
+    supports_json_schema = True
+
     # Pricing per 1K tokens (adjust if your Azure pricing differs from OpenAI defaults)
     MODEL_PRICING = {
         "gpt-5-mini": {
@@ -82,7 +88,16 @@ class AzureDriver(Driver):
         else:
             self.client = None
 
-    def generate(self, prompt: str, options: Dict[str, Any]) -> Dict[str, Any]:
+    supports_messages = True
+
+    def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
+        messages = [{"role": "user", "content": prompt}]
+        return self._do_generate(messages, options)
+
+    def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        return self._do_generate(messages, options)
+
+    def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
         if self.client is None:
             raise RuntimeError("openai package (>=1.0.0) with AzureOpenAI not installed")
 
@@ -96,13 +111,28 @@
         # Build request kwargs
         kwargs = {
             "model": self.deployment_id,  # for Azure, use deployment name
-            "messages": [{"role": "user", "content": prompt}],
+            "messages": messages,
         }
         kwargs[tokens_param] = opts.get("max_tokens", 512)
 
         if supports_temperature and "temperature" in opts:
             kwargs["temperature"] = opts["temperature"]
 
+        # Native JSON mode support
+        if options.get("json_mode"):
+            json_schema = options.get("json_schema")
+            if json_schema:
+                kwargs["response_format"] = {
+                    "type": "json_schema",
+                    "json_schema": {
+                        "name": "extraction",
+                        "strict": True,
+                        "schema": json_schema,
+                    },
+                }
+            else:
+                kwargs["response_format"] = {"type": "json_object"}
+
         resp = self.client.chat.completions.create(**kwargs)
 
         # Extract usage
@@ -111,11 +141,8 @@
         completion_tokens = getattr(usage, "completion_tokens", 0)
         total_tokens = getattr(usage, "total_tokens", 0)
 
-        # Calculate cost
-        model_pricing = self.MODEL_PRICING.get(model, {"prompt": 0, "completion": 0})
-        prompt_cost = (prompt_tokens / 1000) * model_pricing["prompt"]
-        completion_cost = (completion_tokens / 1000) * model_pricing["completion"]
-        total_cost = prompt_cost + completion_cost
+        # Calculate cost via shared mixin
+        total_cost = self._calculate_cost("azure", model, prompt_tokens, completion_tokens)
 
         # Standardized meta object
         meta = {
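Taken together, the azure_driver.py changes split the old single-prompt generate() into a messages-based pipeline with opt-in native JSON mode. A hypothetical calling sketch, assuming the constructor can pick up Azure credentials and the deployment name from its usual configuration (not shown in this diff):

    from prompture.drivers.azure_driver import AzureDriver

    driver = AzureDriver()  # assumed: endpoint/key/deployment come from env or config
    schema = {
        "type": "object",
        "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
        "required": ["name", "age"],
    }
    result = driver.generate_messages(
        [
            {"role": "system", "content": "Extract the person described."},
            {"role": "user", "content": "Ada Lovelace was 36."},
        ],
        {"json_mode": True, "json_schema": schema, "max_tokens": 256},
    )
    print(result["text"])          # JSON text constrained by response_format
    print(result["meta"]["cost"])  # computed via CostMixin._calculate_cost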
prompture/drivers/claude_driver.py

@@ -1,75 +1,131 @@
 """Driver for Anthropic's Claude models. Requires the `anthropic` library.
 Use with API key in CLAUDE_API_KEY env var or provide directly.
 """
+
+import json
 import os
-from typing import Any, Dict
+from collections.abc import Iterator
+from typing import Any
+
 try:
     import anthropic
 except Exception:
     anthropic = None
 
+from ..cost_mixin import CostMixin
 from ..driver import Driver
 
-class ClaudeDriver(Driver):
+
+class ClaudeDriver(CostMixin, Driver):
+    supports_json_mode = True
+    supports_json_schema = True
+    supports_tool_use = True
+    supports_streaming = True
+
     # Claude pricing per 1000 tokens (prices should be kept current with Anthropic's pricing)
     MODEL_PRICING = {
         # Claude Opus 4.1
         "claude-opus-4-1-20250805": {
-            "prompt": 0.015, # $15 per 1M prompt tokens
-            "completion": 0.075, # $75 per 1M completion tokens
+            "prompt": 0.015,  # $15 per 1M prompt tokens
+            "completion": 0.075,  # $75 per 1M completion tokens
         },
         # Claude Opus 4.0
         "claude-opus-4-20250514": {
-            "prompt": 0.015, # $15 per 1M prompt tokens
-            "completion": 0.075, # $75 per 1M completion tokens
+            "prompt": 0.015,  # $15 per 1M prompt tokens
+            "completion": 0.075,  # $75 per 1M completion tokens
         },
         # Claude Sonnet 4.0
         "claude-sonnet-4-20250514": {
-            "prompt": 0.003, # $3 per 1M prompt tokens
-            "completion": 0.015, # $15 per 1M completion tokens
+            "prompt": 0.003,  # $3 per 1M prompt tokens
+            "completion": 0.015,  # $15 per 1M completion tokens
         },
         # Claude Sonnet 3.7
         "claude-3-7-sonnet-20250219": {
-            "prompt": 0.003, # $3 per 1M prompt tokens
-            "completion": 0.015, # $15 per 1M completion tokens
+            "prompt": 0.003,  # $3 per 1M prompt tokens
+            "completion": 0.015,  # $15 per 1M completion tokens
         },
         # Claude Haiku 3.5
         "claude-3-5-haiku-20241022": {
-            "prompt": 0.0008, # $0.80 per 1M prompt tokens
-            "completion": 0.004, # $4 per 1M completion tokens
-        }
+            "prompt": 0.0008,  # $0.80 per 1M prompt tokens
+            "completion": 0.004,  # $4 per 1M completion tokens
+        },
     }
 
     def __init__(self, api_key: str | None = None, model: str = "claude-3-5-haiku-20241022"):
         self.api_key = api_key or os.getenv("CLAUDE_API_KEY")
         self.model = model or os.getenv("CLAUDE_MODEL_NAME", "claude-3-5-haiku-20241022")
 
-    def generate(self, prompt: str, options: Dict[str,Any]) -> Dict[str,Any]:
+    supports_messages = True
+
+    def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
+        messages = [{"role": "user", "content": prompt}]
+        return self._do_generate(messages, options)
+
+    def generate_messages(self, messages: list[dict[str, Any]], options: dict[str, Any]) -> dict[str, Any]:
+        return self._do_generate(messages, options)
+
+    def _do_generate(self, messages: list[dict[str, Any]], options: dict[str, Any]) -> dict[str, Any]:
         if anthropic is None:
             raise RuntimeError("anthropic package not installed")
-
+
         opts = {**{"temperature": 0.0, "max_tokens": 512}, **options}
         model = options.get("model", self.model)
-
+
         client = anthropic.Anthropic(api_key=self.api_key)
-        resp = client.messages.create(
-            model=model,
-            messages=[{"role": "user", "content": prompt}],
-            temperature=opts["temperature"],
-            max_tokens=opts["max_tokens"]
-        )
-
+
+        # Anthropic requires system messages as a top-level parameter
+        system_content = None
+        api_messages = []
+        for msg in messages:
+            if msg.get("role") == "system":
+                system_content = msg.get("content", "")
+            else:
+                api_messages.append(msg)
+
+        # Build common kwargs
+        common_kwargs: dict[str, Any] = {
+            "model": model,
+            "messages": api_messages,
+            "temperature": opts["temperature"],
+            "max_tokens": opts["max_tokens"],
+        }
+        if system_content:
+            common_kwargs["system"] = system_content
+
+        # Native JSON mode: use tool-use for schema enforcement
+        if options.get("json_mode"):
+            json_schema = options.get("json_schema")
+            if json_schema:
+                tool_def = {
+                    "name": "extract_json",
+                    "description": "Extract structured data matching the schema",
+                    "input_schema": json_schema,
+                }
+                resp = client.messages.create(
+                    **common_kwargs,
+                    tools=[tool_def],
+                    tool_choice={"type": "tool", "name": "extract_json"},
+                )
+                text = ""
+                for block in resp.content:
+                    if block.type == "tool_use":
+                        text = json.dumps(block.input)
+                        break
+            else:
+                resp = client.messages.create(**common_kwargs)
+                text = resp.content[0].text
+        else:
+            resp = client.messages.create(**common_kwargs)
+            text = resp.content[0].text
+
         # Extract token usage from Claude response
         prompt_tokens = resp.usage.input_tokens
         completion_tokens = resp.usage.output_tokens
         total_tokens = prompt_tokens + completion_tokens
-
-        # Calculate cost based on model pricing
-        model_pricing = self.MODEL_PRICING.get(model, {"prompt": 0, "completion": 0})
-        prompt_cost = (prompt_tokens / 1000) * model_pricing["prompt"]
-        completion_cost = (completion_tokens / 1000) * model_pricing["completion"]
-        total_cost = prompt_cost + completion_cost
-
+
+        # Calculate cost via shared mixin
+        total_cost = self._calculate_cost("claude", model, prompt_tokens, completion_tokens)
+
         # Create standardized meta object
         meta = {
             "prompt_tokens": prompt_tokens,
@@ -77,8 +133,169 @@ class ClaudeDriver(Driver):
             "total_tokens": total_tokens,
             "cost": round(total_cost, 6),  # Round to 6 decimal places
             "raw_response": dict(resp),
-            "model_name": model
+            "model_name": model,
+        }
+
+        return {"text": text, "meta": meta}
+
+    # ------------------------------------------------------------------
+    # Helpers
+    # ------------------------------------------------------------------
+
+    def _extract_system_and_messages(
+        self, messages: list[dict[str, Any]]
+    ) -> tuple[str | None, list[dict[str, Any]]]:
+        """Separate system message from conversation messages for Anthropic API."""
+        system_content = None
+        api_messages: list[dict[str, Any]] = []
+        for msg in messages:
+            if msg.get("role") == "system":
+                system_content = msg.get("content", "")
+            else:
+                api_messages.append(msg)
+        return system_content, api_messages
+
+    # ------------------------------------------------------------------
+    # Tool use
+    # ------------------------------------------------------------------
+
+    def generate_messages_with_tools(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> dict[str, Any]:
+        """Generate a response that may include tool calls (Anthropic)."""
+        if anthropic is None:
+            raise RuntimeError("anthropic package not installed")
+
+        opts = {**{"temperature": 0.0, "max_tokens": 512}, **options}
+        model = options.get("model", self.model)
+        client = anthropic.Anthropic(api_key=self.api_key)
+
+        system_content, api_messages = self._extract_system_and_messages(messages)
+
+        # Convert tools from OpenAI format to Anthropic format if needed
+        anthropic_tools = []
+        for t in tools:
+            if "type" in t and t["type"] == "function":
+                # OpenAI format -> Anthropic format
+                fn = t["function"]
+                anthropic_tools.append({
+                    "name": fn["name"],
+                    "description": fn.get("description", ""),
+                    "input_schema": fn.get("parameters", {"type": "object", "properties": {}}),
+                })
+            elif "input_schema" in t:
+                # Already Anthropic format
+                anthropic_tools.append(t)
+            else:
+                anthropic_tools.append(t)
+
+        kwargs: dict[str, Any] = {
+            "model": model,
+            "messages": api_messages,
+            "temperature": opts["temperature"],
+            "max_tokens": opts["max_tokens"],
+            "tools": anthropic_tools,
+        }
+        if system_content:
+            kwargs["system"] = system_content
+
+        resp = client.messages.create(**kwargs)
+
+        prompt_tokens = resp.usage.input_tokens
+        completion_tokens = resp.usage.output_tokens
+        total_tokens = prompt_tokens + completion_tokens
+        total_cost = self._calculate_cost("claude", model, prompt_tokens, completion_tokens)
+
+        meta = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": round(total_cost, 6),
+            "raw_response": dict(resp),
+            "model_name": model,
+        }
+
+        text = ""
+        tool_calls_out: list[dict[str, Any]] = []
+        for block in resp.content:
+            if block.type == "text":
+                text += block.text
+            elif block.type == "tool_use":
+                tool_calls_out.append({
+                    "id": block.id,
+                    "name": block.name,
+                    "arguments": block.input,
+                })
+
+        return {
+            "text": text,
+            "meta": meta,
+            "tool_calls": tool_calls_out,
+            "stop_reason": resp.stop_reason,
+        }
+
+    # ------------------------------------------------------------------
+    # Streaming
+    # ------------------------------------------------------------------
+
+    def generate_messages_stream(
+        self,
+        messages: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> Iterator[dict[str, Any]]:
+        """Yield response chunks via Anthropic streaming API."""
+        if anthropic is None:
+            raise RuntimeError("anthropic package not installed")
+
+        opts = {**{"temperature": 0.0, "max_tokens": 512}, **options}
+        model = options.get("model", self.model)
+        client = anthropic.Anthropic(api_key=self.api_key)
+
+        system_content, api_messages = self._extract_system_and_messages(messages)
+
+        kwargs: dict[str, Any] = {
+            "model": model,
+            "messages": api_messages,
+            "temperature": opts["temperature"],
+            "max_tokens": opts["max_tokens"],
+        }
+        if system_content:
+            kwargs["system"] = system_content
+
+        full_text = ""
+        prompt_tokens = 0
+        completion_tokens = 0
+
+        with client.messages.stream(**kwargs) as stream:
+            for event in stream:
+                if hasattr(event, "type"):
+                    if event.type == "content_block_delta" and hasattr(event, "delta"):
+                        delta_text = getattr(event.delta, "text", "")
+                        if delta_text:
+                            full_text += delta_text
+                            yield {"type": "delta", "text": delta_text}
+                    elif event.type == "message_delta" and hasattr(event, "usage"):
+                        completion_tokens = getattr(event.usage, "output_tokens", 0)
+                    elif event.type == "message_start" and hasattr(event, "message"):
+                        usage = getattr(event.message, "usage", None)
+                        if usage:
+                            prompt_tokens = getattr(usage, "input_tokens", 0)
+
+        total_tokens = prompt_tokens + completion_tokens
+        total_cost = self._calculate_cost("claude", model, prompt_tokens, completion_tokens)
+
+        yield {
+            "type": "done",
+            "text": full_text,
+            "meta": {
+                "prompt_tokens": prompt_tokens,
+                "completion_tokens": completion_tokens,
+                "total_tokens": total_tokens,
+                "cost": round(total_cost, 6),
+                "raw_response": {},
+                "model_name": model,
+            },
         }
-
-        text = resp.content[0].text
-        return {"text": text, "meta": meta}
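The new ClaudeDriver streaming generator yields "delta" chunks followed by a single "done" chunk carrying the usage meta, so callers can render tokens incrementally and still collect cost data. A consumption sketch (the chunk shapes come from the diff above; the surrounding setup is illustrative):

    from prompture.drivers.claude_driver import ClaudeDriver

    driver = ClaudeDriver()  # reads CLAUDE_API_KEY from the environment
    messages = [
        {"role": "system", "content": "You are terse."},
        {"role": "user", "content": "Name three prime numbers."},
    ]
    for chunk in driver.generate_messages_stream(messages, {"max_tokens": 128}):
        if chunk["type"] == "delta":
            print(chunk["text"], end="", flush=True)  # incremental text
        elif chunk["type"] == "done":
            print("\ncost:", chunk["meta"]["cost"])   # final usage summary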
prompture/drivers/google_driver.py

@@ -1,46 +1,55 @@
-import os
 import logging
+import os
+from typing import Any, Optional
+
 import google.generativeai as genai
-from typing import Any, Dict
+
+from ..cost_mixin import CostMixin
 from ..driver import Driver
 
 logger = logging.getLogger(__name__)
 
 
-class GoogleDriver(Driver):
+class GoogleDriver(CostMixin, Driver):
     """Driver for Google's Generative AI API (Gemini)."""
 
+    supports_json_mode = True
+    supports_json_schema = True
+
     # Based on current Gemini pricing (as of 2025)
     # Source: https://cloud.google.com/vertex-ai/pricing#gemini_models
+    _PRICING_UNIT = 1_000_000
    MODEL_PRICING = {
         "gemini-1.5-pro": {
             "prompt": 0.00025,  # $0.25/1M chars input
-            "completion": 0.0005  # $0.50/1M chars output
+            "completion": 0.0005,  # $0.50/1M chars output
         },
         "gemini-1.5-pro-vision": {
             "prompt": 0.00025,  # $0.25/1M chars input
-            "completion": 0.0005  # $0.50/1M chars output
+            "completion": 0.0005,  # $0.50/1M chars output
         },
         "gemini-2.5-pro": {
             "prompt": 0.0004,  # $0.40/1M chars input
-            "completion": 0.0008  # $0.80/1M chars output
+            "completion": 0.0008,  # $0.80/1M chars output
         },
         "gemini-2.5-flash": {
             "prompt": 0.0004,  # $0.40/1M chars input
-            "completion": 0.0008  # $0.80/1M chars output
+            "completion": 0.0008,  # $0.80/1M chars output
         },
         "gemini-2.5-flash-lite": {
             "prompt": 0.0002,  # $0.20/1M chars input
-            "completion": 0.0004  # $0.40/1M chars output
+            "completion": 0.0004,  # $0.40/1M chars output
         },
-        "gemini-2.0-flash": {
+        "gemini-2.0-flash": {
             "prompt": 0.0004,  # $0.40/1M chars input
-            "completion": 0.0008  # $0.80/1M chars output
+            "completion": 0.0008,  # $0.80/1M chars output
         },
         "gemini-2.0-flash-lite": {
             "prompt": 0.0002,  # $0.20/1M chars input
-            "completion": 0.0004  # $0.40/1M chars output
+            "completion": 0.0004,  # $0.40/1M chars output
         },
+        "gemini-1.5-flash": {"prompt": 0.00001875, "completion": 0.000075},
+        "gemini-1.5-flash-8b": {"prompt": 0.00001, "completion": 0.00004},
     }
 
     def __init__(self, api_key: str | None = None, model: str = "gemini-1.5-pro"):
@@ -55,13 +64,14 @@ class GoogleDriver(Driver):
             raise ValueError("Google API key not found. Set GOOGLE_API_KEY env var or pass api_key to constructor")
 
         self.model = model
+        # Warn if model is not in pricing table but allow it (might be new)
         if model not in self.MODEL_PRICING:
-            raise ValueError(f"Unsupported model: {model}. Must be one of: {list(self.MODEL_PRICING.keys())}")
+            logger.warning(f"Model {model} not found in pricing table. Cost calculations will be 0.")
 
         # Configure google.generativeai
         genai.configure(api_key=self.api_key)
-        self.options: Dict[str, Any] = {}
-
+        self.options: dict[str, Any] = {}
+
         # Validate connection and model availability
         self._validate_connection()
 
@@ -75,48 +85,107 @@ class GoogleDriver(Driver):
             logger.warning(f"Could not validate connection to Google API: {e}")
             raise
 
-    def generate(self, prompt: str, options: Dict[str, Any] = None) -> Dict[str, Any]:
-        """Generate text using Google's Generative AI.
-
-        Args:
-            prompt: The input prompt
-            options: Additional options to pass to the model
+    def _calculate_cost_chars(self, prompt_chars: int, completion_chars: int) -> float:
+        """Calculate cost from character counts.
 
-        Returns:
-            Dict containing generated text and metadata
+        Live rates use token-based pricing (estimate ~4 chars/token).
+        Hardcoded MODEL_PRICING uses per-1M-character rates.
         """
+        from ..model_rates import get_model_rates
+
+        live_rates = get_model_rates("google", self.model)
+        if live_rates:
+            est_prompt_tokens = prompt_chars / 4
+            est_completion_tokens = completion_chars / 4
+            prompt_cost = (est_prompt_tokens / 1_000_000) * live_rates["input"]
+            completion_cost = (est_completion_tokens / 1_000_000) * live_rates["output"]
+        else:
+            model_pricing = self.MODEL_PRICING.get(self.model, {"prompt": 0, "completion": 0})
+            prompt_cost = (prompt_chars / 1_000_000) * model_pricing["prompt"]
+            completion_cost = (completion_chars / 1_000_000) * model_pricing["completion"]
+        return round(prompt_cost + completion_cost, 6)
+
+    supports_messages = True
+
+    def generate(self, prompt: str, options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
+        messages = [{"role": "user", "content": prompt}]
+        return self._do_generate(messages, options)
+
+    def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        return self._do_generate(messages, options)
+
+    def _do_generate(self, messages: list[dict[str, str]], options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
         merged_options = self.options.copy()
         if options:
             merged_options.update(options)
 
+        # Extract specific options for Google's API
+        generation_config = merged_options.get("generation_config", {})
+        safety_settings = merged_options.get("safety_settings", {})
+
+        # Map common options to generation_config if not present
+        if "temperature" in merged_options and "temperature" not in generation_config:
+            generation_config["temperature"] = merged_options["temperature"]
+        if "max_tokens" in merged_options and "max_output_tokens" not in generation_config:
+            generation_config["max_output_tokens"] = merged_options["max_tokens"]
+        if "top_p" in merged_options and "top_p" not in generation_config:
+            generation_config["top_p"] = merged_options["top_p"]
+        if "top_k" in merged_options and "top_k" not in generation_config:
+            generation_config["top_k"] = merged_options["top_k"]
+
+        # Native JSON mode support
+        if merged_options.get("json_mode"):
+            generation_config["response_mime_type"] = "application/json"
+            json_schema = merged_options.get("json_schema")
+            if json_schema:
+                generation_config["response_schema"] = json_schema
+
+        # Convert messages to Gemini format
+        system_instruction = None
+        contents: list[dict[str, Any]] = []
+        for msg in messages:
+            role = msg.get("role", "user")
+            content = msg.get("content", "")
+            if role == "system":
+                system_instruction = content
+            else:
+                # Gemini uses "model" for assistant role
+                gemini_role = "model" if role == "assistant" else "user"
+                contents.append({"role": gemini_role, "parts": [content]})
+
         try:
             logger.debug(f"Initializing {self.model} for generation")
-            model = genai.GenerativeModel(self.model)
+            model_kwargs: dict[str, Any] = {}
+            if system_instruction:
+                model_kwargs["system_instruction"] = system_instruction
+            model = genai.GenerativeModel(self.model, **model_kwargs)
 
             # Generate response
-            logger.debug(f"Generating with prompt: {prompt}")
-            response = model.generate_content(prompt)
-
+            logger.debug(f"Generating with {len(contents)} content parts")
+            # If single user message, pass content directly for backward compatibility
+            gen_input: Any = contents if len(contents) != 1 else contents[0]["parts"][0]
+            response = model.generate_content(
+                gen_input,
+                generation_config=generation_config if generation_config else None,
+                safety_settings=safety_settings if safety_settings else None,
+            )
+
             if not response.text:
                 raise ValueError("Empty response from model")
 
             # Calculate token usage and cost
-            # Note: Using character count as proxy since Google charges per character
-            prompt_chars = len(prompt)
+            total_prompt_chars = sum(len(msg.get("content", "")) for msg in messages)
             completion_chars = len(response.text)
-
-            # Calculate costs
-            model_pricing = self.MODEL_PRICING[self.model]
-            prompt_cost = (prompt_chars / 1_000_000) * model_pricing["prompt"]
-            completion_cost = (completion_chars / 1_000_000) * model_pricing["completion"]
-            total_cost = prompt_cost + completion_cost
+
+            # Google uses character-based cost estimation
+            total_cost = self._calculate_cost_chars(total_prompt_chars, completion_chars)
 
             meta = {
-                "prompt_chars": prompt_chars,
+                "prompt_chars": total_prompt_chars,
                 "completion_chars": completion_chars,
-                "total_chars": prompt_chars + completion_chars,
+                "total_chars": total_prompt_chars + completion_chars,
                 "cost": total_cost,
-                "raw_response": response.prompt_feedback,
+                "raw_response": response.prompt_feedback if hasattr(response, "prompt_feedback") else None,
                 "model_name": self.model,
             }
 
@@ -124,4 +193,4 @@ class GoogleDriver(Driver):
 
         except Exception as e:
             logger.error(f"Google API request failed: {e}")
-            raise RuntimeError(f"Google API request failed: {e}")
+            raise RuntimeError(f"Google API request failed: {e}") from e
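Since GoogleDriver keeps character-based accounting, _calculate_cost_chars bridges to token-priced live rates with a ~4 chars/token estimate. A worked example of that branch, using made-up live rates of $1.25 per 1M input tokens and $5.00 per 1M output tokens:

    prompt_chars = 8_000
    completion_chars = 2_000

    est_prompt_tokens = prompt_chars / 4          # 2,000 tokens
    est_completion_tokens = completion_chars / 4  # 500 tokens

    cost = (est_prompt_tokens / 1_000_000) * 1.25 \
         + (est_completion_tokens / 1_000_000) * 5.00
    print(round(cost, 6))  # 0.005  ($0.0025 input + $0.0025 output)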