prompture 0.0.29.dev8__py3-none-any.whl → 0.0.35__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- prompture/__init__.py +146 -23
- prompture/_version.py +34 -0
- prompture/aio/__init__.py +74 -0
- prompture/async_conversation.py +607 -0
- prompture/async_core.py +803 -0
- prompture/async_driver.py +169 -0
- prompture/cache.py +469 -0
- prompture/callbacks.py +55 -0
- prompture/cli.py +63 -4
- prompture/conversation.py +631 -0
- prompture/core.py +876 -263
- prompture/cost_mixin.py +51 -0
- prompture/discovery.py +164 -0
- prompture/driver.py +168 -5
- prompture/drivers/__init__.py +173 -69
- prompture/drivers/airllm_driver.py +109 -0
- prompture/drivers/async_airllm_driver.py +26 -0
- prompture/drivers/async_azure_driver.py +117 -0
- prompture/drivers/async_claude_driver.py +107 -0
- prompture/drivers/async_google_driver.py +132 -0
- prompture/drivers/async_grok_driver.py +91 -0
- prompture/drivers/async_groq_driver.py +84 -0
- prompture/drivers/async_hugging_driver.py +61 -0
- prompture/drivers/async_lmstudio_driver.py +79 -0
- prompture/drivers/async_local_http_driver.py +44 -0
- prompture/drivers/async_ollama_driver.py +125 -0
- prompture/drivers/async_openai_driver.py +96 -0
- prompture/drivers/async_openrouter_driver.py +96 -0
- prompture/drivers/async_registry.py +129 -0
- prompture/drivers/azure_driver.py +36 -9
- prompture/drivers/claude_driver.py +251 -34
- prompture/drivers/google_driver.py +107 -38
- prompture/drivers/grok_driver.py +29 -32
- prompture/drivers/groq_driver.py +27 -26
- prompture/drivers/hugging_driver.py +6 -6
- prompture/drivers/lmstudio_driver.py +26 -13
- prompture/drivers/local_http_driver.py +6 -6
- prompture/drivers/ollama_driver.py +157 -23
- prompture/drivers/openai_driver.py +178 -9
- prompture/drivers/openrouter_driver.py +31 -25
- prompture/drivers/registry.py +306 -0
- prompture/field_definitions.py +106 -96
- prompture/logging.py +80 -0
- prompture/model_rates.py +217 -0
- prompture/runner.py +49 -47
- prompture/scaffold/__init__.py +1 -0
- prompture/scaffold/generator.py +84 -0
- prompture/scaffold/templates/Dockerfile.j2 +12 -0
- prompture/scaffold/templates/README.md.j2 +41 -0
- prompture/scaffold/templates/config.py.j2 +21 -0
- prompture/scaffold/templates/env.example.j2 +8 -0
- prompture/scaffold/templates/main.py.j2 +86 -0
- prompture/scaffold/templates/models.py.j2 +40 -0
- prompture/scaffold/templates/requirements.txt.j2 +5 -0
- prompture/server.py +183 -0
- prompture/session.py +117 -0
- prompture/settings.py +18 -1
- prompture/tools.py +219 -267
- prompture/tools_schema.py +254 -0
- prompture/validator.py +3 -3
- {prompture-0.0.29.dev8.dist-info → prompture-0.0.35.dist-info}/METADATA +117 -21
- prompture-0.0.35.dist-info/RECORD +66 -0
- {prompture-0.0.29.dev8.dist-info → prompture-0.0.35.dist-info}/WHEEL +1 -1
- prompture-0.0.29.dev8.dist-info/RECORD +0 -27
- {prompture-0.0.29.dev8.dist-info → prompture-0.0.35.dist-info}/entry_points.txt +0 -0
- {prompture-0.0.29.dev8.dist-info → prompture-0.0.35.dist-info}/licenses/LICENSE +0 -0
- {prompture-0.0.29.dev8.dist-info → prompture-0.0.35.dist-info}/top_level.txt +0 -0
|
@@ -1,17 +1,23 @@
|
|
|
1
1
|
"""Driver for Azure OpenAI Service (migrated to openai>=1.0.0).
|
|
2
2
|
Requires the `openai` package.
|
|
3
3
|
"""
|
|
4
|
+
|
|
4
5
|
import os
|
|
5
|
-
from typing import Any
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
6
8
|
try:
|
|
7
9
|
from openai import AzureOpenAI
|
|
8
10
|
except Exception:
|
|
9
11
|
AzureOpenAI = None
|
|
10
12
|
|
|
13
|
+
from ..cost_mixin import CostMixin
|
|
11
14
|
from ..driver import Driver
|
|
12
15
|
|
|
13
16
|
|
|
14
|
-
class AzureDriver(Driver):
|
|
17
|
+
class AzureDriver(CostMixin, Driver):
|
|
18
|
+
supports_json_mode = True
|
|
19
|
+
supports_json_schema = True
|
|
20
|
+
|
|
15
21
|
# Pricing per 1K tokens (adjust if your Azure pricing differs from OpenAI defaults)
|
|
16
22
|
MODEL_PRICING = {
|
|
17
23
|
"gpt-5-mini": {
|
|
@@ -82,7 +88,16 @@ class AzureDriver(Driver):
|
|
|
82
88
|
else:
|
|
83
89
|
self.client = None
|
|
84
90
|
|
|
85
|
-
|
|
91
|
+
supports_messages = True
|
|
92
|
+
|
|
93
|
+
def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
|
|
94
|
+
messages = [{"role": "user", "content": prompt}]
|
|
95
|
+
return self._do_generate(messages, options)
|
|
96
|
+
|
|
97
|
+
def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
|
|
98
|
+
return self._do_generate(messages, options)
|
|
99
|
+
|
|
100
|
+
def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
|
|
86
101
|
if self.client is None:
|
|
87
102
|
raise RuntimeError("openai package (>=1.0.0) with AzureOpenAI not installed")
|
|
88
103
|
|
|
@@ -96,13 +111,28 @@ class AzureDriver(Driver):
|
|
|
96
111
|
# Build request kwargs
|
|
97
112
|
kwargs = {
|
|
98
113
|
"model": self.deployment_id, # for Azure, use deployment name
|
|
99
|
-
"messages":
|
|
114
|
+
"messages": messages,
|
|
100
115
|
}
|
|
101
116
|
kwargs[tokens_param] = opts.get("max_tokens", 512)
|
|
102
117
|
|
|
103
118
|
if supports_temperature and "temperature" in opts:
|
|
104
119
|
kwargs["temperature"] = opts["temperature"]
|
|
105
120
|
|
|
121
|
+
# Native JSON mode support
|
|
122
|
+
if options.get("json_mode"):
|
|
123
|
+
json_schema = options.get("json_schema")
|
|
124
|
+
if json_schema:
|
|
125
|
+
kwargs["response_format"] = {
|
|
126
|
+
"type": "json_schema",
|
|
127
|
+
"json_schema": {
|
|
128
|
+
"name": "extraction",
|
|
129
|
+
"strict": True,
|
|
130
|
+
"schema": json_schema,
|
|
131
|
+
},
|
|
132
|
+
}
|
|
133
|
+
else:
|
|
134
|
+
kwargs["response_format"] = {"type": "json_object"}
|
|
135
|
+
|
|
106
136
|
resp = self.client.chat.completions.create(**kwargs)
|
|
107
137
|
|
|
108
138
|
# Extract usage
|
|
@@ -111,11 +141,8 @@ class AzureDriver(Driver):
|
|
|
111
141
|
completion_tokens = getattr(usage, "completion_tokens", 0)
|
|
112
142
|
total_tokens = getattr(usage, "total_tokens", 0)
|
|
113
143
|
|
|
114
|
-
# Calculate cost
|
|
115
|
-
|
|
116
|
-
prompt_cost = (prompt_tokens / 1000) * model_pricing["prompt"]
|
|
117
|
-
completion_cost = (completion_tokens / 1000) * model_pricing["completion"]
|
|
118
|
-
total_cost = prompt_cost + completion_cost
|
|
144
|
+
# Calculate cost via shared mixin
|
|
145
|
+
total_cost = self._calculate_cost("azure", model, prompt_tokens, completion_tokens)
|
|
119
146
|
|
|
120
147
|
# Standardized meta object
|
|
121
148
|
meta = {
|
|
@@ -1,75 +1,131 @@
|
|
|
1
1
|
"""Driver for Anthropic's Claude models. Requires the `anthropic` library.
|
|
2
2
|
Use with API key in CLAUDE_API_KEY env var or provide directly.
|
|
3
3
|
"""
|
|
4
|
+
|
|
5
|
+
import json
|
|
4
6
|
import os
|
|
5
|
-
from
|
|
7
|
+
from collections.abc import Iterator
|
|
8
|
+
from typing import Any
|
|
9
|
+
|
|
6
10
|
try:
|
|
7
11
|
import anthropic
|
|
8
12
|
except Exception:
|
|
9
13
|
anthropic = None
|
|
10
14
|
|
|
15
|
+
from ..cost_mixin import CostMixin
|
|
11
16
|
from ..driver import Driver
|
|
12
17
|
|
|
13
|
-
|
|
18
|
+
|
|
19
|
+
class ClaudeDriver(CostMixin, Driver):
|
|
20
|
+
supports_json_mode = True
|
|
21
|
+
supports_json_schema = True
|
|
22
|
+
supports_tool_use = True
|
|
23
|
+
supports_streaming = True
|
|
24
|
+
|
|
14
25
|
# Claude pricing per 1000 tokens (prices should be kept current with Anthropic's pricing)
|
|
15
26
|
MODEL_PRICING = {
|
|
16
27
|
# Claude Opus 4.1
|
|
17
28
|
"claude-opus-4-1-20250805": {
|
|
18
|
-
"prompt": 0.015,
|
|
19
|
-
"completion": 0.075,
|
|
29
|
+
"prompt": 0.015, # $15 per 1M prompt tokens
|
|
30
|
+
"completion": 0.075, # $75 per 1M completion tokens
|
|
20
31
|
},
|
|
21
32
|
# Claude Opus 4.0
|
|
22
33
|
"claude-opus-4-20250514": {
|
|
23
|
-
"prompt": 0.015,
|
|
24
|
-
"completion": 0.075,
|
|
34
|
+
"prompt": 0.015, # $15 per 1M prompt tokens
|
|
35
|
+
"completion": 0.075, # $75 per 1M completion tokens
|
|
25
36
|
},
|
|
26
37
|
# Claude Sonnet 4.0
|
|
27
38
|
"claude-sonnet-4-20250514": {
|
|
28
|
-
"prompt": 0.003,
|
|
29
|
-
"completion": 0.015,
|
|
39
|
+
"prompt": 0.003, # $3 per 1M prompt tokens
|
|
40
|
+
"completion": 0.015, # $15 per 1M completion tokens
|
|
30
41
|
},
|
|
31
42
|
# Claude Sonnet 3.7
|
|
32
43
|
"claude-3-7-sonnet-20250219": {
|
|
33
|
-
"prompt": 0.003,
|
|
34
|
-
"completion": 0.015,
|
|
44
|
+
"prompt": 0.003, # $3 per 1M prompt tokens
|
|
45
|
+
"completion": 0.015, # $15 per 1M completion tokens
|
|
35
46
|
},
|
|
36
47
|
# Claude Haiku 3.5
|
|
37
48
|
"claude-3-5-haiku-20241022": {
|
|
38
|
-
"prompt": 0.0008,
|
|
39
|
-
"completion": 0.004,
|
|
40
|
-
}
|
|
49
|
+
"prompt": 0.0008, # $0.80 per 1M prompt tokens
|
|
50
|
+
"completion": 0.004, # $4 per 1M completion tokens
|
|
51
|
+
},
|
|
41
52
|
}
|
|
42
53
|
|
|
43
54
|
def __init__(self, api_key: str | None = None, model: str = "claude-3-5-haiku-20241022"):
|
|
44
55
|
self.api_key = api_key or os.getenv("CLAUDE_API_KEY")
|
|
45
56
|
self.model = model or os.getenv("CLAUDE_MODEL_NAME", "claude-3-5-haiku-20241022")
|
|
46
57
|
|
|
47
|
-
|
|
58
|
+
supports_messages = True
|
|
59
|
+
|
|
60
|
+
def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
|
|
61
|
+
messages = [{"role": "user", "content": prompt}]
|
|
62
|
+
return self._do_generate(messages, options)
|
|
63
|
+
|
|
64
|
+
def generate_messages(self, messages: list[dict[str, Any]], options: dict[str, Any]) -> dict[str, Any]:
|
|
65
|
+
return self._do_generate(messages, options)
|
|
66
|
+
|
|
67
|
+
def _do_generate(self, messages: list[dict[str, Any]], options: dict[str, Any]) -> dict[str, Any]:
|
|
48
68
|
if anthropic is None:
|
|
49
69
|
raise RuntimeError("anthropic package not installed")
|
|
50
|
-
|
|
70
|
+
|
|
51
71
|
opts = {**{"temperature": 0.0, "max_tokens": 512}, **options}
|
|
52
72
|
model = options.get("model", self.model)
|
|
53
|
-
|
|
73
|
+
|
|
54
74
|
client = anthropic.Anthropic(api_key=self.api_key)
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
75
|
+
|
|
76
|
+
# Anthropic requires system messages as a top-level parameter
|
|
77
|
+
system_content = None
|
|
78
|
+
api_messages = []
|
|
79
|
+
for msg in messages:
|
|
80
|
+
if msg.get("role") == "system":
|
|
81
|
+
system_content = msg.get("content", "")
|
|
82
|
+
else:
|
|
83
|
+
api_messages.append(msg)
|
|
84
|
+
|
|
85
|
+
# Build common kwargs
|
|
86
|
+
common_kwargs: dict[str, Any] = {
|
|
87
|
+
"model": model,
|
|
88
|
+
"messages": api_messages,
|
|
89
|
+
"temperature": opts["temperature"],
|
|
90
|
+
"max_tokens": opts["max_tokens"],
|
|
91
|
+
}
|
|
92
|
+
if system_content:
|
|
93
|
+
common_kwargs["system"] = system_content
|
|
94
|
+
|
|
95
|
+
# Native JSON mode: use tool-use for schema enforcement
|
|
96
|
+
if options.get("json_mode"):
|
|
97
|
+
json_schema = options.get("json_schema")
|
|
98
|
+
if json_schema:
|
|
99
|
+
tool_def = {
|
|
100
|
+
"name": "extract_json",
|
|
101
|
+
"description": "Extract structured data matching the schema",
|
|
102
|
+
"input_schema": json_schema,
|
|
103
|
+
}
|
|
104
|
+
resp = client.messages.create(
|
|
105
|
+
**common_kwargs,
|
|
106
|
+
tools=[tool_def],
|
|
107
|
+
tool_choice={"type": "tool", "name": "extract_json"},
|
|
108
|
+
)
|
|
109
|
+
text = ""
|
|
110
|
+
for block in resp.content:
|
|
111
|
+
if block.type == "tool_use":
|
|
112
|
+
text = json.dumps(block.input)
|
|
113
|
+
break
|
|
114
|
+
else:
|
|
115
|
+
resp = client.messages.create(**common_kwargs)
|
|
116
|
+
text = resp.content[0].text
|
|
117
|
+
else:
|
|
118
|
+
resp = client.messages.create(**common_kwargs)
|
|
119
|
+
text = resp.content[0].text
|
|
120
|
+
|
|
62
121
|
# Extract token usage from Claude response
|
|
63
122
|
prompt_tokens = resp.usage.input_tokens
|
|
64
123
|
completion_tokens = resp.usage.output_tokens
|
|
65
124
|
total_tokens = prompt_tokens + completion_tokens
|
|
66
|
-
|
|
67
|
-
# Calculate cost
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
completion_cost = (completion_tokens / 1000) * model_pricing["completion"]
|
|
71
|
-
total_cost = prompt_cost + completion_cost
|
|
72
|
-
|
|
125
|
+
|
|
126
|
+
# Calculate cost via shared mixin
|
|
127
|
+
total_cost = self._calculate_cost("claude", model, prompt_tokens, completion_tokens)
|
|
128
|
+
|
|
73
129
|
# Create standardized meta object
|
|
74
130
|
meta = {
|
|
75
131
|
"prompt_tokens": prompt_tokens,
|
|
@@ -77,8 +133,169 @@ class ClaudeDriver(Driver):
|
|
|
77
133
|
"total_tokens": total_tokens,
|
|
78
134
|
"cost": round(total_cost, 6), # Round to 6 decimal places
|
|
79
135
|
"raw_response": dict(resp),
|
|
80
|
-
"model_name": model
|
|
136
|
+
"model_name": model,
|
|
137
|
+
}
|
|
138
|
+
|
|
139
|
+
return {"text": text, "meta": meta}
|
|
140
|
+
|
|
141
|
+
# ------------------------------------------------------------------
|
|
142
|
+
# Helpers
|
|
143
|
+
# ------------------------------------------------------------------
|
|
144
|
+
|
|
145
|
+
def _extract_system_and_messages(
|
|
146
|
+
self, messages: list[dict[str, Any]]
|
|
147
|
+
) -> tuple[str | None, list[dict[str, Any]]]:
|
|
148
|
+
"""Separate system message from conversation messages for Anthropic API."""
|
|
149
|
+
system_content = None
|
|
150
|
+
api_messages: list[dict[str, Any]] = []
|
|
151
|
+
for msg in messages:
|
|
152
|
+
if msg.get("role") == "system":
|
|
153
|
+
system_content = msg.get("content", "")
|
|
154
|
+
else:
|
|
155
|
+
api_messages.append(msg)
|
|
156
|
+
return system_content, api_messages
|
|
157
|
+
|
|
158
|
+
# ------------------------------------------------------------------
|
|
159
|
+
# Tool use
|
|
160
|
+
# ------------------------------------------------------------------
|
|
161
|
+
|
|
162
|
+
def generate_messages_with_tools(
|
|
163
|
+
self,
|
|
164
|
+
messages: list[dict[str, Any]],
|
|
165
|
+
tools: list[dict[str, Any]],
|
|
166
|
+
options: dict[str, Any],
|
|
167
|
+
) -> dict[str, Any]:
|
|
168
|
+
"""Generate a response that may include tool calls (Anthropic)."""
|
|
169
|
+
if anthropic is None:
|
|
170
|
+
raise RuntimeError("anthropic package not installed")
|
|
171
|
+
|
|
172
|
+
opts = {**{"temperature": 0.0, "max_tokens": 512}, **options}
|
|
173
|
+
model = options.get("model", self.model)
|
|
174
|
+
client = anthropic.Anthropic(api_key=self.api_key)
|
|
175
|
+
|
|
176
|
+
system_content, api_messages = self._extract_system_and_messages(messages)
|
|
177
|
+
|
|
178
|
+
# Convert tools from OpenAI format to Anthropic format if needed
|
|
179
|
+
anthropic_tools = []
|
|
180
|
+
for t in tools:
|
|
181
|
+
if "type" in t and t["type"] == "function":
|
|
182
|
+
# OpenAI format -> Anthropic format
|
|
183
|
+
fn = t["function"]
|
|
184
|
+
anthropic_tools.append({
|
|
185
|
+
"name": fn["name"],
|
|
186
|
+
"description": fn.get("description", ""),
|
|
187
|
+
"input_schema": fn.get("parameters", {"type": "object", "properties": {}}),
|
|
188
|
+
})
|
|
189
|
+
elif "input_schema" in t:
|
|
190
|
+
# Already Anthropic format
|
|
191
|
+
anthropic_tools.append(t)
|
|
192
|
+
else:
|
|
193
|
+
anthropic_tools.append(t)
|
|
194
|
+
|
|
195
|
+
kwargs: dict[str, Any] = {
|
|
196
|
+
"model": model,
|
|
197
|
+
"messages": api_messages,
|
|
198
|
+
"temperature": opts["temperature"],
|
|
199
|
+
"max_tokens": opts["max_tokens"],
|
|
200
|
+
"tools": anthropic_tools,
|
|
201
|
+
}
|
|
202
|
+
if system_content:
|
|
203
|
+
kwargs["system"] = system_content
|
|
204
|
+
|
|
205
|
+
resp = client.messages.create(**kwargs)
|
|
206
|
+
|
|
207
|
+
prompt_tokens = resp.usage.input_tokens
|
|
208
|
+
completion_tokens = resp.usage.output_tokens
|
|
209
|
+
total_tokens = prompt_tokens + completion_tokens
|
|
210
|
+
total_cost = self._calculate_cost("claude", model, prompt_tokens, completion_tokens)
|
|
211
|
+
|
|
212
|
+
meta = {
|
|
213
|
+
"prompt_tokens": prompt_tokens,
|
|
214
|
+
"completion_tokens": completion_tokens,
|
|
215
|
+
"total_tokens": total_tokens,
|
|
216
|
+
"cost": round(total_cost, 6),
|
|
217
|
+
"raw_response": dict(resp),
|
|
218
|
+
"model_name": model,
|
|
219
|
+
}
|
|
220
|
+
|
|
221
|
+
text = ""
|
|
222
|
+
tool_calls_out: list[dict[str, Any]] = []
|
|
223
|
+
for block in resp.content:
|
|
224
|
+
if block.type == "text":
|
|
225
|
+
text += block.text
|
|
226
|
+
elif block.type == "tool_use":
|
|
227
|
+
tool_calls_out.append({
|
|
228
|
+
"id": block.id,
|
|
229
|
+
"name": block.name,
|
|
230
|
+
"arguments": block.input,
|
|
231
|
+
})
|
|
232
|
+
|
|
233
|
+
return {
|
|
234
|
+
"text": text,
|
|
235
|
+
"meta": meta,
|
|
236
|
+
"tool_calls": tool_calls_out,
|
|
237
|
+
"stop_reason": resp.stop_reason,
|
|
238
|
+
}
|
|
239
|
+
|
|
240
|
+
# ------------------------------------------------------------------
|
|
241
|
+
# Streaming
|
|
242
|
+
# ------------------------------------------------------------------
|
|
243
|
+
|
|
244
|
+
def generate_messages_stream(
|
|
245
|
+
self,
|
|
246
|
+
messages: list[dict[str, Any]],
|
|
247
|
+
options: dict[str, Any],
|
|
248
|
+
) -> Iterator[dict[str, Any]]:
|
|
249
|
+
"""Yield response chunks via Anthropic streaming API."""
|
|
250
|
+
if anthropic is None:
|
|
251
|
+
raise RuntimeError("anthropic package not installed")
|
|
252
|
+
|
|
253
|
+
opts = {**{"temperature": 0.0, "max_tokens": 512}, **options}
|
|
254
|
+
model = options.get("model", self.model)
|
|
255
|
+
client = anthropic.Anthropic(api_key=self.api_key)
|
|
256
|
+
|
|
257
|
+
system_content, api_messages = self._extract_system_and_messages(messages)
|
|
258
|
+
|
|
259
|
+
kwargs: dict[str, Any] = {
|
|
260
|
+
"model": model,
|
|
261
|
+
"messages": api_messages,
|
|
262
|
+
"temperature": opts["temperature"],
|
|
263
|
+
"max_tokens": opts["max_tokens"],
|
|
264
|
+
}
|
|
265
|
+
if system_content:
|
|
266
|
+
kwargs["system"] = system_content
|
|
267
|
+
|
|
268
|
+
full_text = ""
|
|
269
|
+
prompt_tokens = 0
|
|
270
|
+
completion_tokens = 0
|
|
271
|
+
|
|
272
|
+
with client.messages.stream(**kwargs) as stream:
|
|
273
|
+
for event in stream:
|
|
274
|
+
if hasattr(event, "type"):
|
|
275
|
+
if event.type == "content_block_delta" and hasattr(event, "delta"):
|
|
276
|
+
delta_text = getattr(event.delta, "text", "")
|
|
277
|
+
if delta_text:
|
|
278
|
+
full_text += delta_text
|
|
279
|
+
yield {"type": "delta", "text": delta_text}
|
|
280
|
+
elif event.type == "message_delta" and hasattr(event, "usage"):
|
|
281
|
+
completion_tokens = getattr(event.usage, "output_tokens", 0)
|
|
282
|
+
elif event.type == "message_start" and hasattr(event, "message"):
|
|
283
|
+
usage = getattr(event.message, "usage", None)
|
|
284
|
+
if usage:
|
|
285
|
+
prompt_tokens = getattr(usage, "input_tokens", 0)
|
|
286
|
+
|
|
287
|
+
total_tokens = prompt_tokens + completion_tokens
|
|
288
|
+
total_cost = self._calculate_cost("claude", model, prompt_tokens, completion_tokens)
|
|
289
|
+
|
|
290
|
+
yield {
|
|
291
|
+
"type": "done",
|
|
292
|
+
"text": full_text,
|
|
293
|
+
"meta": {
|
|
294
|
+
"prompt_tokens": prompt_tokens,
|
|
295
|
+
"completion_tokens": completion_tokens,
|
|
296
|
+
"total_tokens": total_tokens,
|
|
297
|
+
"cost": round(total_cost, 6),
|
|
298
|
+
"raw_response": {},
|
|
299
|
+
"model_name": model,
|
|
300
|
+
},
|
|
81
301
|
}
|
|
82
|
-
|
|
83
|
-
text = resp.content[0].text
|
|
84
|
-
return {"text": text, "meta": meta}
|
|
@@ -1,46 +1,55 @@
|
|
|
1
|
-
import os
|
|
2
1
|
import logging
|
|
2
|
+
import os
|
|
3
|
+
from typing import Any, Optional
|
|
4
|
+
|
|
3
5
|
import google.generativeai as genai
|
|
4
|
-
|
|
6
|
+
|
|
7
|
+
from ..cost_mixin import CostMixin
|
|
5
8
|
from ..driver import Driver
|
|
6
9
|
|
|
7
10
|
logger = logging.getLogger(__name__)
|
|
8
11
|
|
|
9
12
|
|
|
10
|
-
class GoogleDriver(Driver):
|
|
13
|
+
class GoogleDriver(CostMixin, Driver):
|
|
11
14
|
"""Driver for Google's Generative AI API (Gemini)."""
|
|
12
15
|
|
|
16
|
+
supports_json_mode = True
|
|
17
|
+
supports_json_schema = True
|
|
18
|
+
|
|
13
19
|
# Based on current Gemini pricing (as of 2025)
|
|
14
20
|
# Source: https://cloud.google.com/vertex-ai/pricing#gemini_models
|
|
21
|
+
_PRICING_UNIT = 1_000_000
|
|
15
22
|
MODEL_PRICING = {
|
|
16
23
|
"gemini-1.5-pro": {
|
|
17
24
|
"prompt": 0.00025, # $0.25/1M chars input
|
|
18
|
-
"completion": 0.0005 # $0.50/1M chars output
|
|
25
|
+
"completion": 0.0005, # $0.50/1M chars output
|
|
19
26
|
},
|
|
20
27
|
"gemini-1.5-pro-vision": {
|
|
21
28
|
"prompt": 0.00025, # $0.25/1M chars input
|
|
22
|
-
"completion": 0.0005 # $0.50/1M chars output
|
|
29
|
+
"completion": 0.0005, # $0.50/1M chars output
|
|
23
30
|
},
|
|
24
31
|
"gemini-2.5-pro": {
|
|
25
32
|
"prompt": 0.0004, # $0.40/1M chars input
|
|
26
|
-
"completion": 0.0008 # $0.80/1M chars output
|
|
33
|
+
"completion": 0.0008, # $0.80/1M chars output
|
|
27
34
|
},
|
|
28
35
|
"gemini-2.5-flash": {
|
|
29
36
|
"prompt": 0.0004, # $0.40/1M chars input
|
|
30
|
-
"completion": 0.0008 # $0.80/1M chars output
|
|
37
|
+
"completion": 0.0008, # $0.80/1M chars output
|
|
31
38
|
},
|
|
32
39
|
"gemini-2.5-flash-lite": {
|
|
33
40
|
"prompt": 0.0002, # $0.20/1M chars input
|
|
34
|
-
"completion": 0.0004 # $0.40/1M chars output
|
|
41
|
+
"completion": 0.0004, # $0.40/1M chars output
|
|
35
42
|
},
|
|
36
|
-
|
|
43
|
+
"gemini-2.0-flash": {
|
|
37
44
|
"prompt": 0.0004, # $0.40/1M chars input
|
|
38
|
-
"completion": 0.0008 # $0.80/1M chars output
|
|
45
|
+
"completion": 0.0008, # $0.80/1M chars output
|
|
39
46
|
},
|
|
40
47
|
"gemini-2.0-flash-lite": {
|
|
41
48
|
"prompt": 0.0002, # $0.20/1M chars input
|
|
42
|
-
"completion": 0.0004 # $0.40/1M chars output
|
|
49
|
+
"completion": 0.0004, # $0.40/1M chars output
|
|
43
50
|
},
|
|
51
|
+
"gemini-1.5-flash": {"prompt": 0.00001875, "completion": 0.000075},
|
|
52
|
+
"gemini-1.5-flash-8b": {"prompt": 0.00001, "completion": 0.00004},
|
|
44
53
|
}
|
|
45
54
|
|
|
46
55
|
def __init__(self, api_key: str | None = None, model: str = "gemini-1.5-pro"):
|
|
@@ -55,13 +64,14 @@ class GoogleDriver(Driver):
|
|
|
55
64
|
raise ValueError("Google API key not found. Set GOOGLE_API_KEY env var or pass api_key to constructor")
|
|
56
65
|
|
|
57
66
|
self.model = model
|
|
67
|
+
# Warn if model is not in pricing table but allow it (might be new)
|
|
58
68
|
if model not in self.MODEL_PRICING:
|
|
59
|
-
|
|
69
|
+
logger.warning(f"Model {model} not found in pricing table. Cost calculations will be 0.")
|
|
60
70
|
|
|
61
71
|
# Configure google.generativeai
|
|
62
72
|
genai.configure(api_key=self.api_key)
|
|
63
|
-
self.options:
|
|
64
|
-
|
|
73
|
+
self.options: dict[str, Any] = {}
|
|
74
|
+
|
|
65
75
|
# Validate connection and model availability
|
|
66
76
|
self._validate_connection()
|
|
67
77
|
|
|
@@ -75,48 +85,107 @@ class GoogleDriver(Driver):
|
|
|
75
85
|
logger.warning(f"Could not validate connection to Google API: {e}")
|
|
76
86
|
raise
|
|
77
87
|
|
|
78
|
-
def
|
|
79
|
-
"""
|
|
80
|
-
|
|
81
|
-
Args:
|
|
82
|
-
prompt: The input prompt
|
|
83
|
-
options: Additional options to pass to the model
|
|
88
|
+
def _calculate_cost_chars(self, prompt_chars: int, completion_chars: int) -> float:
|
|
89
|
+
"""Calculate cost from character counts.
|
|
84
90
|
|
|
85
|
-
|
|
86
|
-
|
|
91
|
+
Live rates use token-based pricing (estimate ~4 chars/token).
|
|
92
|
+
Hardcoded MODEL_PRICING uses per-1M-character rates.
|
|
87
93
|
"""
|
|
94
|
+
from ..model_rates import get_model_rates
|
|
95
|
+
|
|
96
|
+
live_rates = get_model_rates("google", self.model)
|
|
97
|
+
if live_rates:
|
|
98
|
+
est_prompt_tokens = prompt_chars / 4
|
|
99
|
+
est_completion_tokens = completion_chars / 4
|
|
100
|
+
prompt_cost = (est_prompt_tokens / 1_000_000) * live_rates["input"]
|
|
101
|
+
completion_cost = (est_completion_tokens / 1_000_000) * live_rates["output"]
|
|
102
|
+
else:
|
|
103
|
+
model_pricing = self.MODEL_PRICING.get(self.model, {"prompt": 0, "completion": 0})
|
|
104
|
+
prompt_cost = (prompt_chars / 1_000_000) * model_pricing["prompt"]
|
|
105
|
+
completion_cost = (completion_chars / 1_000_000) * model_pricing["completion"]
|
|
106
|
+
return round(prompt_cost + completion_cost, 6)
|
|
107
|
+
|
|
108
|
+
supports_messages = True
|
|
109
|
+
|
|
110
|
+
def generate(self, prompt: str, options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
|
|
111
|
+
messages = [{"role": "user", "content": prompt}]
|
|
112
|
+
return self._do_generate(messages, options)
|
|
113
|
+
|
|
114
|
+
def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
|
|
115
|
+
return self._do_generate(messages, options)
|
|
116
|
+
|
|
117
|
+
def _do_generate(self, messages: list[dict[str, str]], options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
|
|
88
118
|
merged_options = self.options.copy()
|
|
89
119
|
if options:
|
|
90
120
|
merged_options.update(options)
|
|
91
121
|
|
|
122
|
+
# Extract specific options for Google's API
|
|
123
|
+
generation_config = merged_options.get("generation_config", {})
|
|
124
|
+
safety_settings = merged_options.get("safety_settings", {})
|
|
125
|
+
|
|
126
|
+
# Map common options to generation_config if not present
|
|
127
|
+
if "temperature" in merged_options and "temperature" not in generation_config:
|
|
128
|
+
generation_config["temperature"] = merged_options["temperature"]
|
|
129
|
+
if "max_tokens" in merged_options and "max_output_tokens" not in generation_config:
|
|
130
|
+
generation_config["max_output_tokens"] = merged_options["max_tokens"]
|
|
131
|
+
if "top_p" in merged_options and "top_p" not in generation_config:
|
|
132
|
+
generation_config["top_p"] = merged_options["top_p"]
|
|
133
|
+
if "top_k" in merged_options and "top_k" not in generation_config:
|
|
134
|
+
generation_config["top_k"] = merged_options["top_k"]
|
|
135
|
+
|
|
136
|
+
# Native JSON mode support
|
|
137
|
+
if merged_options.get("json_mode"):
|
|
138
|
+
generation_config["response_mime_type"] = "application/json"
|
|
139
|
+
json_schema = merged_options.get("json_schema")
|
|
140
|
+
if json_schema:
|
|
141
|
+
generation_config["response_schema"] = json_schema
|
|
142
|
+
|
|
143
|
+
# Convert messages to Gemini format
|
|
144
|
+
system_instruction = None
|
|
145
|
+
contents: list[dict[str, Any]] = []
|
|
146
|
+
for msg in messages:
|
|
147
|
+
role = msg.get("role", "user")
|
|
148
|
+
content = msg.get("content", "")
|
|
149
|
+
if role == "system":
|
|
150
|
+
system_instruction = content
|
|
151
|
+
else:
|
|
152
|
+
# Gemini uses "model" for assistant role
|
|
153
|
+
gemini_role = "model" if role == "assistant" else "user"
|
|
154
|
+
contents.append({"role": gemini_role, "parts": [content]})
|
|
155
|
+
|
|
92
156
|
try:
|
|
93
157
|
logger.debug(f"Initializing {self.model} for generation")
|
|
94
|
-
|
|
158
|
+
model_kwargs: dict[str, Any] = {}
|
|
159
|
+
if system_instruction:
|
|
160
|
+
model_kwargs["system_instruction"] = system_instruction
|
|
161
|
+
model = genai.GenerativeModel(self.model, **model_kwargs)
|
|
95
162
|
|
|
96
163
|
# Generate response
|
|
97
|
-
logger.debug(f"Generating with
|
|
98
|
-
|
|
99
|
-
|
|
164
|
+
logger.debug(f"Generating with {len(contents)} content parts")
|
|
165
|
+
# If single user message, pass content directly for backward compatibility
|
|
166
|
+
gen_input: Any = contents if len(contents) != 1 else contents[0]["parts"][0]
|
|
167
|
+
response = model.generate_content(
|
|
168
|
+
gen_input,
|
|
169
|
+
generation_config=generation_config if generation_config else None,
|
|
170
|
+
safety_settings=safety_settings if safety_settings else None,
|
|
171
|
+
)
|
|
172
|
+
|
|
100
173
|
if not response.text:
|
|
101
174
|
raise ValueError("Empty response from model")
|
|
102
175
|
|
|
103
176
|
# Calculate token usage and cost
|
|
104
|
-
|
|
105
|
-
prompt_chars = len(prompt)
|
|
177
|
+
total_prompt_chars = sum(len(msg.get("content", "")) for msg in messages)
|
|
106
178
|
completion_chars = len(response.text)
|
|
107
|
-
|
|
108
|
-
#
|
|
109
|
-
|
|
110
|
-
prompt_cost = (prompt_chars / 1_000_000) * model_pricing["prompt"]
|
|
111
|
-
completion_cost = (completion_chars / 1_000_000) * model_pricing["completion"]
|
|
112
|
-
total_cost = prompt_cost + completion_cost
|
|
179
|
+
|
|
180
|
+
# Google uses character-based cost estimation
|
|
181
|
+
total_cost = self._calculate_cost_chars(total_prompt_chars, completion_chars)
|
|
113
182
|
|
|
114
183
|
meta = {
|
|
115
|
-
"prompt_chars":
|
|
184
|
+
"prompt_chars": total_prompt_chars,
|
|
116
185
|
"completion_chars": completion_chars,
|
|
117
|
-
"total_chars":
|
|
186
|
+
"total_chars": total_prompt_chars + completion_chars,
|
|
118
187
|
"cost": total_cost,
|
|
119
|
-
"raw_response": response.prompt_feedback,
|
|
188
|
+
"raw_response": response.prompt_feedback if hasattr(response, "prompt_feedback") else None,
|
|
120
189
|
"model_name": self.model,
|
|
121
190
|
}
|
|
122
191
|
|
|
@@ -124,4 +193,4 @@ class GoogleDriver(Driver):
|
|
|
124
193
|
|
|
125
194
|
except Exception as e:
|
|
126
195
|
logger.error(f"Google API request failed: {e}")
|
|
127
|
-
raise RuntimeError(f"Google API request failed: {e}")
|
|
196
|
+
raise RuntimeError(f"Google API request failed: {e}") from e
|