prompture 0.0.38-py3-none-any.whl → 0.0.38.dev1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- prompture/_version.py +2 -2
- prompture/drivers/async_azure_driver.py +1 -1
- prompture/drivers/async_claude_driver.py +8 -167
- prompture/drivers/async_google_driver.py +39 -203
- prompture/drivers/async_grok_driver.py +1 -1
- prompture/drivers/async_groq_driver.py +1 -1
- prompture/drivers/async_openai_driver.py +1 -143
- prompture/drivers/async_openrouter_driver.py +1 -1
- prompture/drivers/google_driver.py +43 -207
- {prompture-0.0.38.dist-info → prompture-0.0.38.dev1.dist-info}/METADATA +1 -1
- {prompture-0.0.38.dist-info → prompture-0.0.38.dev1.dist-info}/RECORD +15 -15
- {prompture-0.0.38.dist-info → prompture-0.0.38.dev1.dist-info}/WHEEL +0 -0
- {prompture-0.0.38.dist-info → prompture-0.0.38.dev1.dist-info}/entry_points.txt +0 -0
- {prompture-0.0.38.dist-info → prompture-0.0.38.dev1.dist-info}/licenses/LICENSE +0 -0
- {prompture-0.0.38.dist-info → prompture-0.0.38.dev1.dist-info}/top_level.txt +0 -0
prompture/_version.py
CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID
 
-__version__ = version = '0.0.38'
-__version_tuple__ = version_tuple = (0, 0, 38)
+__version__ = version = '0.0.38.dev1'
+__version_tuple__ = version_tuple = (0, 0, 38, 'dev1')
 
 __commit_id__ = commit_id = None
prompture/drivers/async_azure_driver.py
CHANGED
@@ -113,7 +113,7 @@ class AsyncAzureDriver(CostMixin, AsyncDriver):
             "prompt_tokens": prompt_tokens,
             "completion_tokens": completion_tokens,
             "total_tokens": total_tokens,
-            "cost":
+            "cost": total_cost,
             "raw_response": resp.model_dump(),
             "model_name": model,
             "deployment_id": self.deployment_id,
prompture/drivers/async_claude_driver.py
CHANGED
@@ -4,7 +4,6 @@ from __future__ import annotations
 
 import json
 import os
-from collections.abc import AsyncIterator
 from typing import Any
 
 try:
@@ -20,8 +19,6 @@ from .claude_driver import ClaudeDriver
 class AsyncClaudeDriver(CostMixin, AsyncDriver):
     supports_json_mode = True
     supports_json_schema = True
-    supports_tool_use = True
-    supports_streaming = True
     supports_vision = True
 
     MODEL_PRICING = ClaudeDriver.MODEL_PRICING
@@ -54,7 +51,13 @@ class AsyncClaudeDriver(CostMixin, AsyncDriver):
         client = anthropic.AsyncAnthropic(api_key=self.api_key)
 
         # Anthropic requires system messages as a top-level parameter
-        system_content
+        system_content = None
+        api_messages = []
+        for msg in messages:
+            if msg.get("role") == "system":
+                system_content = msg.get("content", "")
+            else:
+                api_messages.append(msg)
 
         # Build common kwargs
         common_kwargs: dict[str, Any] = {
@@ -102,171 +105,9 @@ class AsyncClaudeDriver(CostMixin, AsyncDriver):
             "prompt_tokens": prompt_tokens,
             "completion_tokens": completion_tokens,
             "total_tokens": total_tokens,
-            "cost":
+            "cost": total_cost,
             "raw_response": dict(resp),
             "model_name": model,
         }
 
         return {"text": text, "meta": meta}
-
-    # ------------------------------------------------------------------
-    # Helpers
-    # ------------------------------------------------------------------
-
-    def _extract_system_and_messages(
-        self, messages: list[dict[str, Any]]
-    ) -> tuple[str | None, list[dict[str, Any]]]:
-        """Separate system message from conversation messages for Anthropic API."""
-        system_content = None
-        api_messages: list[dict[str, Any]] = []
-        for msg in messages:
-            if msg.get("role") == "system":
-                system_content = msg.get("content", "")
-            else:
-                api_messages.append(msg)
-        return system_content, api_messages
-
-    # ------------------------------------------------------------------
-    # Tool use
-    # ------------------------------------------------------------------
-
-    async def generate_messages_with_tools(
-        self,
-        messages: list[dict[str, Any]],
-        tools: list[dict[str, Any]],
-        options: dict[str, Any],
-    ) -> dict[str, Any]:
-        """Generate a response that may include tool calls (Anthropic)."""
-        if anthropic is None:
-            raise RuntimeError("anthropic package not installed")
-
-        opts = {**{"temperature": 0.0, "max_tokens": 512}, **options}
-        model = options.get("model", self.model)
-        client = anthropic.AsyncAnthropic(api_key=self.api_key)
-
-        system_content, api_messages = self._extract_system_and_messages(messages)
-
-        # Convert tools from OpenAI format to Anthropic format if needed
-        anthropic_tools = []
-        for t in tools:
-            if "type" in t and t["type"] == "function":
-                # OpenAI format -> Anthropic format
-                fn = t["function"]
-                anthropic_tools.append({
-                    "name": fn["name"],
-                    "description": fn.get("description", ""),
-                    "input_schema": fn.get("parameters", {"type": "object", "properties": {}}),
-                })
-            elif "input_schema" in t:
-                # Already Anthropic format
-                anthropic_tools.append(t)
-            else:
-                anthropic_tools.append(t)
-
-        kwargs: dict[str, Any] = {
-            "model": model,
-            "messages": api_messages,
-            "temperature": opts["temperature"],
-            "max_tokens": opts["max_tokens"],
-            "tools": anthropic_tools,
-        }
-        if system_content:
-            kwargs["system"] = system_content
-
-        resp = await client.messages.create(**kwargs)
-
-        prompt_tokens = resp.usage.input_tokens
-        completion_tokens = resp.usage.output_tokens
-        total_tokens = prompt_tokens + completion_tokens
-        total_cost = self._calculate_cost("claude", model, prompt_tokens, completion_tokens)
-
-        meta = {
-            "prompt_tokens": prompt_tokens,
-            "completion_tokens": completion_tokens,
-            "total_tokens": total_tokens,
-            "cost": round(total_cost, 6),
-            "raw_response": dict(resp),
-            "model_name": model,
-        }
-
-        text = ""
-        tool_calls_out: list[dict[str, Any]] = []
-        for block in resp.content:
-            if block.type == "text":
-                text += block.text
-            elif block.type == "tool_use":
-                tool_calls_out.append({
-                    "id": block.id,
-                    "name": block.name,
-                    "arguments": block.input,
-                })
-
-        return {
-            "text": text,
-            "meta": meta,
-            "tool_calls": tool_calls_out,
-            "stop_reason": resp.stop_reason,
-        }
-
-    # ------------------------------------------------------------------
-    # Streaming
-    # ------------------------------------------------------------------
-
-    async def generate_messages_stream(
-        self,
-        messages: list[dict[str, Any]],
-        options: dict[str, Any],
-    ) -> AsyncIterator[dict[str, Any]]:
-        """Yield response chunks via Anthropic streaming API."""
-        if anthropic is None:
-            raise RuntimeError("anthropic package not installed")
-
-        opts = {**{"temperature": 0.0, "max_tokens": 512}, **options}
-        model = options.get("model", self.model)
-        client = anthropic.AsyncAnthropic(api_key=self.api_key)
-
-        system_content, api_messages = self._extract_system_and_messages(messages)
-
-        kwargs: dict[str, Any] = {
-            "model": model,
-            "messages": api_messages,
-            "temperature": opts["temperature"],
-            "max_tokens": opts["max_tokens"],
-        }
-        if system_content:
-            kwargs["system"] = system_content
-
-        full_text = ""
-        prompt_tokens = 0
-        completion_tokens = 0
-
-        async with client.messages.stream(**kwargs) as stream:
-            async for event in stream:
-                if hasattr(event, "type"):
-                    if event.type == "content_block_delta" and hasattr(event, "delta"):
-                        delta_text = getattr(event.delta, "text", "")
-                        if delta_text:
-                            full_text += delta_text
-                            yield {"type": "delta", "text": delta_text}
-                    elif event.type == "message_delta" and hasattr(event, "usage"):
-                        completion_tokens = getattr(event.usage, "output_tokens", 0)
-                    elif event.type == "message_start" and hasattr(event, "message"):
-                        usage = getattr(event.message, "usage", None)
-                        if usage:
-                            prompt_tokens = getattr(usage, "input_tokens", 0)
-
-        total_tokens = prompt_tokens + completion_tokens
-        total_cost = self._calculate_cost("claude", model, prompt_tokens, completion_tokens)
-
-        yield {
-            "type": "done",
-            "text": full_text,
-            "meta": {
-                "prompt_tokens": prompt_tokens,
-                "completion_tokens": completion_tokens,
-                "total_tokens": total_tokens,
-                "cost": round(total_cost, 6),
-                "raw_response": {},
-                "model_name": model,
-            },
-        }
prompture/drivers/async_google_driver.py
CHANGED
@@ -4,8 +4,6 @@ from __future__ import annotations
 
 import logging
 import os
-import uuid
-from collections.abc import AsyncIterator
 from typing import Any
 
 import google.generativeai as genai
@@ -23,8 +21,6 @@ class AsyncGoogleDriver(CostMixin, AsyncDriver):
     supports_json_mode = True
     supports_json_schema = True
     supports_vision = True
-    supports_tool_use = True
-    supports_streaming = True
 
     MODEL_PRICING = GoogleDriver.MODEL_PRICING
     _PRICING_UNIT = 1_000_000
@@ -53,40 +49,6 @@ class AsyncGoogleDriver(CostMixin, AsyncDriver):
         completion_cost = (completion_chars / 1_000_000) * model_pricing["completion"]
         return round(prompt_cost + completion_cost, 6)
 
-    def _extract_usage_metadata(self, response: Any, messages: list[dict[str, Any]]) -> dict[str, Any]:
-        """Extract token counts from response, falling back to character estimation."""
-        usage = getattr(response, "usage_metadata", None)
-        if usage:
-            prompt_tokens = getattr(usage, "prompt_token_count", 0) or 0
-            completion_tokens = getattr(usage, "candidates_token_count", 0) or 0
-            total_tokens = getattr(usage, "total_token_count", 0) or (prompt_tokens + completion_tokens)
-            cost = self._calculate_cost("google", self.model, prompt_tokens, completion_tokens)
-        else:
-            # Fallback: estimate from character counts
-            total_prompt_chars = 0
-            for msg in messages:
-                c = msg.get("content", "")
-                if isinstance(c, str):
-                    total_prompt_chars += len(c)
-                elif isinstance(c, list):
-                    for part in c:
-                        if isinstance(part, str):
-                            total_prompt_chars += len(part)
-                        elif isinstance(part, dict) and "text" in part:
-                            total_prompt_chars += len(part["text"])
-            completion_chars = len(response.text) if response.text else 0
-            prompt_tokens = total_prompt_chars // 4
-            completion_tokens = completion_chars // 4
-            total_tokens = prompt_tokens + completion_tokens
-            cost = self._calculate_cost_chars(total_prompt_chars, completion_chars)
-
-        return {
-            "prompt_tokens": prompt_tokens,
-            "completion_tokens": completion_tokens,
-            "total_tokens": total_tokens,
-            "cost": round(cost, 6),
-        }
-
     supports_messages = True
 
     def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
@@ -94,10 +56,16 @@ class AsyncGoogleDriver(CostMixin, AsyncDriver):
 
         return _prepare_google_vision_messages(messages)
 
-    def
-
-
-
+    async def generate(self, prompt: str, options: dict[str, Any] | None = None) -> dict[str, Any]:
+        messages = [{"role": "user", "content": prompt}]
+        return await self._do_generate(messages, options)
+
+    async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        return await self._do_generate(self._prepare_messages(messages), options)
+
+    async def _do_generate(
+        self, messages: list[dict[str, str]], options: dict[str, Any] | None = None
+    ) -> dict[str, Any]:
         merged_options = self.options.copy()
         if options:
             merged_options.update(options)
@@ -132,54 +100,47 @@ class AsyncGoogleDriver(CostMixin, AsyncDriver):
             else:
                 gemini_role = "model" if role == "assistant" else "user"
                 if msg.get("_vision_parts"):
+                    # Already converted to Gemini parts by _prepare_messages
                     contents.append({"role": gemini_role, "parts": content})
                 else:
                     contents.append({"role": gemini_role, "parts": [content]})
 
-        # For a single message, unwrap only if it has exactly one string part
-        if len(contents) == 1:
-            parts = contents[0]["parts"]
-            if len(parts) == 1 and isinstance(parts[0], str):
-                gen_input = parts[0]
-            else:
-                gen_input = contents
-        else:
-            gen_input = contents
-
-        model_kwargs: dict[str, Any] = {}
-        if system_instruction:
-            model_kwargs["system_instruction"] = system_instruction
-
-        gen_kwargs: dict[str, Any] = {
-            "generation_config": generation_config if generation_config else None,
-            "safety_settings": safety_settings if safety_settings else None,
-        }
-
-        return gen_input, gen_kwargs, model_kwargs
-
-    async def generate(self, prompt: str, options: dict[str, Any] | None = None) -> dict[str, Any]:
-        messages = [{"role": "user", "content": prompt}]
-        return await self._do_generate(messages, options)
-
-    async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
-        return await self._do_generate(self._prepare_messages(messages), options)
-
-    async def _do_generate(
-        self, messages: list[dict[str, str]], options: dict[str, Any] | None = None
-    ) -> dict[str, Any]:
-        gen_input, gen_kwargs, model_kwargs = self._build_generation_args(messages, options)
-
         try:
+            model_kwargs: dict[str, Any] = {}
+            if system_instruction:
+                model_kwargs["system_instruction"] = system_instruction
             model = genai.GenerativeModel(self.model, **model_kwargs)
-
+
+            gen_input: Any = contents if len(contents) != 1 else contents[0]["parts"][0]
+            response = await model.generate_content_async(
+                gen_input,
+                generation_config=generation_config if generation_config else None,
+                safety_settings=safety_settings if safety_settings else None,
+            )
 
             if not response.text:
                 raise ValueError("Empty response from model")
 
-
+            total_prompt_chars = 0
+            for msg in messages:
+                c = msg.get("content", "")
+                if isinstance(c, str):
+                    total_prompt_chars += len(c)
+                elif isinstance(c, list):
+                    for part in c:
+                        if isinstance(part, str):
+                            total_prompt_chars += len(part)
+                        elif isinstance(part, dict) and "text" in part:
+                            total_prompt_chars += len(part["text"])
+            completion_chars = len(response.text)
+
+            total_cost = self._calculate_cost_chars(total_prompt_chars, completion_chars)
 
             meta = {
-
+                "prompt_chars": total_prompt_chars,
+                "completion_chars": completion_chars,
+                "total_chars": total_prompt_chars + completion_chars,
+                "cost": total_cost,
                 "raw_response": response.prompt_feedback if hasattr(response, "prompt_feedback") else None,
                 "model_name": self.model,
             }
@@ -189,128 +150,3 @@ class AsyncGoogleDriver(CostMixin, AsyncDriver):
         except Exception as e:
             logger.error(f"Google API request failed: {e}")
             raise RuntimeError(f"Google API request failed: {e}") from e
-
-    # ------------------------------------------------------------------
-    # Tool use
-    # ------------------------------------------------------------------
-
-    async def generate_messages_with_tools(
-        self,
-        messages: list[dict[str, Any]],
-        tools: list[dict[str, Any]],
-        options: dict[str, Any],
-    ) -> dict[str, Any]:
-        """Generate a response that may include tool/function calls (async)."""
-        gen_input, gen_kwargs, model_kwargs = self._build_generation_args(
-            self._prepare_messages(messages), options
-        )
-
-        # Convert tools from OpenAI format to Gemini function declarations
-        function_declarations = []
-        for t in tools:
-            if "type" in t and t["type"] == "function":
-                fn = t["function"]
-                decl = {
-                    "name": fn["name"],
-                    "description": fn.get("description", ""),
-                }
-                params = fn.get("parameters")
-                if params:
-                    decl["parameters"] = params
-                function_declarations.append(decl)
-            elif "name" in t:
-                decl = {"name": t["name"], "description": t.get("description", "")}
-                params = t.get("parameters") or t.get("input_schema")
-                if params:
-                    decl["parameters"] = params
-                function_declarations.append(decl)
-
-        try:
-            model = genai.GenerativeModel(self.model, **model_kwargs)
-
-            gemini_tools = [genai.types.Tool(function_declarations=function_declarations)]
-            response = await model.generate_content_async(gen_input, tools=gemini_tools, **gen_kwargs)
-
-            usage_meta = self._extract_usage_metadata(response, messages)
-            meta = {
-                **usage_meta,
-                "raw_response": response.prompt_feedback if hasattr(response, "prompt_feedback") else None,
-                "model_name": self.model,
-            }
-
-            text = ""
-            tool_calls_out: list[dict[str, Any]] = []
-            stop_reason = "stop"
-
-            for candidate in response.candidates:
-                for part in candidate.content.parts:
-                    if hasattr(part, "text") and part.text:
-                        text += part.text
-                    if hasattr(part, "function_call") and part.function_call.name:
-                        fc = part.function_call
-                        tool_calls_out.append({
-                            "id": str(uuid.uuid4()),
-                            "name": fc.name,
-                            "arguments": dict(fc.args) if fc.args else {},
-                        })
-
-                finish_reason = getattr(candidate, "finish_reason", None)
-                if finish_reason is not None:
-                    reason_map = {1: "stop", 2: "max_tokens", 3: "safety", 4: "recitation", 5: "other"}
-                    stop_reason = reason_map.get(finish_reason, "stop")
-
-            if tool_calls_out:
-                stop_reason = "tool_use"
-
-            return {
-                "text": text,
-                "meta": meta,
-                "tool_calls": tool_calls_out,
-                "stop_reason": stop_reason,
-            }
-
-        except Exception as e:
-            logger.error(f"Google API tool call request failed: {e}")
-            raise RuntimeError(f"Google API tool call request failed: {e}") from e
-
-    # ------------------------------------------------------------------
-    # Streaming
-    # ------------------------------------------------------------------
-
-    async def generate_messages_stream(
-        self,
-        messages: list[dict[str, Any]],
-        options: dict[str, Any],
-    ) -> AsyncIterator[dict[str, Any]]:
-        """Yield response chunks via Gemini async streaming API."""
-        gen_input, gen_kwargs, model_kwargs = self._build_generation_args(
-            self._prepare_messages(messages), options
-        )
-
-        try:
-            model = genai.GenerativeModel(self.model, **model_kwargs)
-            response = await model.generate_content_async(gen_input, stream=True, **gen_kwargs)
-
-            full_text = ""
-            async for chunk in response:
-                chunk_text = getattr(chunk, "text", None) or ""
-                if chunk_text:
-                    full_text += chunk_text
-                    yield {"type": "delta", "text": chunk_text}
-
-            # After iteration completes, usage_metadata should be available
-            usage_meta = self._extract_usage_metadata(response, messages)
-
-            yield {
-                "type": "done",
-                "text": full_text,
-                "meta": {
-                    **usage_meta,
-                    "raw_response": {},
-                    "model_name": self.model,
-                },
-            }
-
-        except Exception as e:
-            logger.error(f"Google API streaming request failed: {e}")
-            raise RuntimeError(f"Google API streaming request failed: {e}") from e
prompture/drivers/async_grok_driver.py
CHANGED
@@ -88,7 +88,7 @@ class AsyncGrokDriver(CostMixin, AsyncDriver):
             "prompt_tokens": prompt_tokens,
             "completion_tokens": completion_tokens,
             "total_tokens": total_tokens,
-            "cost":
+            "cost": total_cost,
            "raw_response": resp,
             "model_name": model,
         }
prompture/drivers/async_groq_driver.py
CHANGED
@@ -81,7 +81,7 @@ class AsyncGroqDriver(CostMixin, AsyncDriver):
             "prompt_tokens": prompt_tokens,
             "completion_tokens": completion_tokens,
             "total_tokens": total_tokens,
-            "cost":
+            "cost": total_cost,
             "raw_response": resp.model_dump(),
             "model_name": model,
         }
prompture/drivers/async_openai_driver.py
CHANGED
@@ -2,9 +2,7 @@
 
 from __future__ import annotations
 
-import json
 import os
-from collections.abc import AsyncIterator
 from typing import Any
 
 try:
@@ -20,8 +18,6 @@ from .openai_driver import OpenAIDriver
 class AsyncOpenAIDriver(CostMixin, AsyncDriver):
     supports_json_mode = True
     supports_json_schema = True
-    supports_tool_use = True
-    supports_streaming = True
     supports_vision = True
 
     MODEL_PRICING = OpenAIDriver.MODEL_PRICING
@@ -97,148 +93,10 @@ class AsyncOpenAIDriver(CostMixin, AsyncDriver):
             "prompt_tokens": prompt_tokens,
             "completion_tokens": completion_tokens,
             "total_tokens": total_tokens,
-            "cost":
+            "cost": total_cost,
             "raw_response": resp.model_dump(),
             "model_name": model,
         }
 
         text = resp.choices[0].message.content
         return {"text": text, "meta": meta}
-
-    # ------------------------------------------------------------------
-    # Tool use
-    # ------------------------------------------------------------------
-
-    async def generate_messages_with_tools(
-        self,
-        messages: list[dict[str, Any]],
-        tools: list[dict[str, Any]],
-        options: dict[str, Any],
-    ) -> dict[str, Any]:
-        """Generate a response that may include tool calls."""
-        if self.client is None:
-            raise RuntimeError("openai package (>=1.0.0) is not installed")
-
-        model = options.get("model", self.model)
-        model_info = self.MODEL_PRICING.get(model, {})
-        tokens_param = model_info.get("tokens_param", "max_tokens")
-        supports_temperature = model_info.get("supports_temperature", True)
-
-        opts = {"temperature": 1.0, "max_tokens": 512, **options}
-
-        kwargs: dict[str, Any] = {
-            "model": model,
-            "messages": messages,
-            "tools": tools,
-        }
-        kwargs[tokens_param] = opts.get("max_tokens", 512)
-
-        if supports_temperature and "temperature" in opts:
-            kwargs["temperature"] = opts["temperature"]
-
-        resp = await self.client.chat.completions.create(**kwargs)
-
-        usage = getattr(resp, "usage", None)
-        prompt_tokens = getattr(usage, "prompt_tokens", 0)
-        completion_tokens = getattr(usage, "completion_tokens", 0)
-        total_tokens = getattr(usage, "total_tokens", 0)
-        total_cost = self._calculate_cost("openai", model, prompt_tokens, completion_tokens)
-
-        meta = {
-            "prompt_tokens": prompt_tokens,
-            "completion_tokens": completion_tokens,
-            "total_tokens": total_tokens,
-            "cost": round(total_cost, 6),
-            "raw_response": resp.model_dump(),
-            "model_name": model,
-        }
-
-        choice = resp.choices[0]
-        text = choice.message.content or ""
-        stop_reason = choice.finish_reason
-
-        tool_calls_out: list[dict[str, Any]] = []
-        if choice.message.tool_calls:
-            for tc in choice.message.tool_calls:
-                try:
-                    args = json.loads(tc.function.arguments)
-                except (json.JSONDecodeError, TypeError):
-                    args = {}
-                tool_calls_out.append({
-                    "id": tc.id,
-                    "name": tc.function.name,
-                    "arguments": args,
-                })
-
-        return {
-            "text": text,
-            "meta": meta,
-            "tool_calls": tool_calls_out,
-            "stop_reason": stop_reason,
-        }
-
-    # ------------------------------------------------------------------
-    # Streaming
-    # ------------------------------------------------------------------
-
-    async def generate_messages_stream(
-        self,
-        messages: list[dict[str, Any]],
-        options: dict[str, Any],
-    ) -> AsyncIterator[dict[str, Any]]:
-        """Yield response chunks via OpenAI streaming API."""
-        if self.client is None:
-            raise RuntimeError("openai package (>=1.0.0) is not installed")
-
-        model = options.get("model", self.model)
-        model_info = self.MODEL_PRICING.get(model, {})
-        tokens_param = model_info.get("tokens_param", "max_tokens")
-        supports_temperature = model_info.get("supports_temperature", True)
-
-        opts = {"temperature": 1.0, "max_tokens": 512, **options}
-
-        kwargs: dict[str, Any] = {
-            "model": model,
-            "messages": messages,
-            "stream": True,
-            "stream_options": {"include_usage": True},
-        }
-        kwargs[tokens_param] = opts.get("max_tokens", 512)
-
-        if supports_temperature and "temperature" in opts:
-            kwargs["temperature"] = opts["temperature"]
-
-        stream = await self.client.chat.completions.create(**kwargs)
-
-        full_text = ""
-        prompt_tokens = 0
-        completion_tokens = 0
-
-        async for chunk in stream:
-            # Usage comes in the final chunk
-            if getattr(chunk, "usage", None):
-                prompt_tokens = chunk.usage.prompt_tokens or 0
-                completion_tokens = chunk.usage.completion_tokens or 0
-
-            if chunk.choices:
-                delta = chunk.choices[0].delta
-                content = getattr(delta, "content", None) or ""
-                if content:
-                    full_text += content
-                    yield {"type": "delta", "text": content}
-
-        total_tokens = prompt_tokens + completion_tokens
-        total_cost = self._calculate_cost("openai", model, prompt_tokens, completion_tokens)
-
-        yield {
-            "type": "done",
-            "text": full_text,
-            "meta": {
-                "prompt_tokens": prompt_tokens,
-                "completion_tokens": completion_tokens,
-                "total_tokens": total_tokens,
-                "cost": round(total_cost, 6),
-                "raw_response": {},
-                "model_name": model,
-            },
-        }
prompture/drivers/async_openrouter_driver.py
CHANGED
@@ -93,7 +93,7 @@ class AsyncOpenRouterDriver(CostMixin, AsyncDriver):
             "prompt_tokens": prompt_tokens,
             "completion_tokens": completion_tokens,
             "total_tokens": total_tokens,
-            "cost":
+            "cost": total_cost,
             "raw_response": resp,
             "model_name": model,
         }
prompture/drivers/google_driver.py
CHANGED
@@ -1,7 +1,5 @@
 import logging
 import os
-import uuid
-from collections.abc import Iterator
 from typing import Any, Optional
 
 import google.generativeai as genai
@@ -18,8 +16,6 @@ class GoogleDriver(CostMixin, Driver):
     supports_json_mode = True
     supports_json_schema = True
     supports_vision = True
-    supports_tool_use = True
-    supports_streaming = True
 
     # Based on current Gemini pricing (as of 2025)
     # Source: https://cloud.google.com/vertex-ai/pricing#gemini_models
@@ -110,40 +106,6 @@ class GoogleDriver(CostMixin, Driver):
         completion_cost = (completion_chars / 1_000_000) * model_pricing["completion"]
         return round(prompt_cost + completion_cost, 6)
 
-    def _extract_usage_metadata(self, response: Any, messages: list[dict[str, Any]]) -> dict[str, Any]:
-        """Extract token counts from response, falling back to character estimation."""
-        usage = getattr(response, "usage_metadata", None)
-        if usage:
-            prompt_tokens = getattr(usage, "prompt_token_count", 0) or 0
-            completion_tokens = getattr(usage, "candidates_token_count", 0) or 0
-            total_tokens = getattr(usage, "total_token_count", 0) or (prompt_tokens + completion_tokens)
-            cost = self._calculate_cost("google", self.model, prompt_tokens, completion_tokens)
-        else:
-            # Fallback: estimate from character counts
-            total_prompt_chars = 0
-            for msg in messages:
-                c = msg.get("content", "")
-                if isinstance(c, str):
-                    total_prompt_chars += len(c)
-                elif isinstance(c, list):
-                    for part in c:
-                        if isinstance(part, str):
-                            total_prompt_chars += len(part)
-                        elif isinstance(part, dict) and "text" in part:
-                            total_prompt_chars += len(part["text"])
-            completion_chars = len(response.text) if response.text else 0
-            prompt_tokens = total_prompt_chars // 4
-            completion_tokens = completion_chars // 4
-            total_tokens = prompt_tokens + completion_tokens
-            cost = self._calculate_cost_chars(total_prompt_chars, completion_chars)
-
-        return {
-            "prompt_tokens": prompt_tokens,
-            "completion_tokens": completion_tokens,
-            "total_tokens": total_tokens,
-            "cost": round(cost, 6),
-        }
-
     supports_messages = True
 
     def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
@@ -151,21 +113,23 @@ class GoogleDriver(CostMixin, Driver):
 
         return _prepare_google_vision_messages(messages)
 
-    def
-
-
-        """Parse messages and options into (gen_input, kwargs) for generate_content.
+    def generate(self, prompt: str, options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
+        messages = [{"role": "user", "content": prompt}]
+        return self._do_generate(messages, options)
 
-
-    (
-
+    def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        return self._do_generate(self._prepare_messages(messages), options)
+
+    def _do_generate(self, messages: list[dict[str, str]], options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
         merged_options = self.options.copy()
         if options:
            merged_options.update(options)
 
+        # Extract specific options for Google's API
         generation_config = merged_options.get("generation_config", {})
         safety_settings = merged_options.get("safety_settings", {})
 
+        # Map common options to generation_config if not present
         if "temperature" in merged_options and "temperature" not in generation_config:
             generation_config["temperature"] = merged_options["temperature"]
         if "max_tokens" in merged_options and "max_output_tokens" not in generation_config:
@@ -191,57 +155,56 @@ class GoogleDriver(CostMixin, Driver):
             if role == "system":
                 system_instruction = content if isinstance(content, str) else str(content)
             else:
+                # Gemini uses "model" for assistant role
                 gemini_role = "model" if role == "assistant" else "user"
                 if msg.get("_vision_parts"):
+                    # Already converted to Gemini parts by _prepare_messages
                     contents.append({"role": gemini_role, "parts": content})
                 else:
                     contents.append({"role": gemini_role, "parts": [content]})
 
-        # For a single message, unwrap only if it has exactly one string part
-        if len(contents) == 1:
-            parts = contents[0]["parts"]
-            if len(parts) == 1 and isinstance(parts[0], str):
-                gen_input = parts[0]
-            else:
-                gen_input = contents
-        else:
-            gen_input = contents
-
-        model_kwargs: dict[str, Any] = {}
-        if system_instruction:
-            model_kwargs["system_instruction"] = system_instruction
-
-        gen_kwargs: dict[str, Any] = {
-            "generation_config": generation_config if generation_config else None,
-            "safety_settings": safety_settings if safety_settings else None,
-        }
-
-        return gen_input, gen_kwargs, model_kwargs
-
-    def generate(self, prompt: str, options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
-        messages = [{"role": "user", "content": prompt}]
-        return self._do_generate(messages, options)
-
-    def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
-        return self._do_generate(self._prepare_messages(messages), options)
-
-    def _do_generate(self, messages: list[dict[str, str]], options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
-        gen_input, gen_kwargs, model_kwargs = self._build_generation_args(messages, options)
-
         try:
             logger.debug(f"Initializing {self.model} for generation")
+            model_kwargs: dict[str, Any] = {}
+            if system_instruction:
+                model_kwargs["system_instruction"] = system_instruction
             model = genai.GenerativeModel(self.model, **model_kwargs)
 
-
-
+            # Generate response
+            logger.debug(f"Generating with {len(contents)} content parts")
+            # If single user message, pass content directly for backward compatibility
+            gen_input: Any = contents if len(contents) != 1 else contents[0]["parts"][0]
+            response = model.generate_content(
+                gen_input,
+                generation_config=generation_config if generation_config else None,
+                safety_settings=safety_settings if safety_settings else None,
+            )
 
             if not response.text:
                 raise ValueError("Empty response from model")
 
-
+            # Calculate token usage and cost
+            total_prompt_chars = 0
+            for msg in messages:
+                c = msg.get("content", "")
+                if isinstance(c, str):
+                    total_prompt_chars += len(c)
+                elif isinstance(c, list):
+                    for part in c:
+                        if isinstance(part, str):
+                            total_prompt_chars += len(part)
+                        elif isinstance(part, dict) and "text" in part:
+                            total_prompt_chars += len(part["text"])
+            completion_chars = len(response.text)
+
+            # Google uses character-based cost estimation
+            total_cost = self._calculate_cost_chars(total_prompt_chars, completion_chars)
 
             meta = {
-
+                "prompt_chars": total_prompt_chars,
+                "completion_chars": completion_chars,
+                "total_chars": total_prompt_chars + completion_chars,
+                "cost": total_cost,
                 "raw_response": response.prompt_feedback if hasattr(response, "prompt_feedback") else None,
                 "model_name": self.model,
             }
@@ -251,130 +214,3 @@ class GoogleDriver(CostMixin, Driver):
         except Exception as e:
             logger.error(f"Google API request failed: {e}")
             raise RuntimeError(f"Google API request failed: {e}") from e
-
-    # ------------------------------------------------------------------
-    # Tool use
-    # ------------------------------------------------------------------
-
-    def generate_messages_with_tools(
-        self,
-        messages: list[dict[str, Any]],
-        tools: list[dict[str, Any]],
-        options: dict[str, Any],
-    ) -> dict[str, Any]:
-        """Generate a response that may include tool/function calls."""
-        gen_input, gen_kwargs, model_kwargs = self._build_generation_args(
-            self._prepare_messages(messages), options
-        )
-
-        # Convert tools from OpenAI format to Gemini function declarations
-        function_declarations = []
-        for t in tools:
-            if "type" in t and t["type"] == "function":
-                fn = t["function"]
-                decl = {
-                    "name": fn["name"],
-                    "description": fn.get("description", ""),
-                }
-                params = fn.get("parameters")
-                if params:
-                    decl["parameters"] = params
-                function_declarations.append(decl)
-            elif "name" in t:
-                # Already in a generic format
-                decl = {"name": t["name"], "description": t.get("description", "")}
-                params = t.get("parameters") or t.get("input_schema")
-                if params:
-                    decl["parameters"] = params
-                function_declarations.append(decl)
-
-        try:
-            model = genai.GenerativeModel(self.model, **model_kwargs)
-
-            gemini_tools = [genai.types.Tool(function_declarations=function_declarations)]
-            response = model.generate_content(gen_input, tools=gemini_tools, **gen_kwargs)
-
-            usage_meta = self._extract_usage_metadata(response, messages)
-            meta = {
-                **usage_meta,
-                "raw_response": response.prompt_feedback if hasattr(response, "prompt_feedback") else None,
-                "model_name": self.model,
-            }
-
-            text = ""
-            tool_calls_out: list[dict[str, Any]] = []
-            stop_reason = "stop"
-
-            for candidate in response.candidates:
-                for part in candidate.content.parts:
-                    if hasattr(part, "text") and part.text:
-                        text += part.text
-                    if hasattr(part, "function_call") and part.function_call.name:
-                        fc = part.function_call
-                        tool_calls_out.append({
-                            "id": str(uuid.uuid4()),
-                            "name": fc.name,
-                            "arguments": dict(fc.args) if fc.args else {},
-                        })
-
-                finish_reason = getattr(candidate, "finish_reason", None)
-                if finish_reason is not None:
-                    # Map Gemini finish reasons to standard stop reasons
-                    reason_map = {1: "stop", 2: "max_tokens", 3: "safety", 4: "recitation", 5: "other"}
-                    stop_reason = reason_map.get(finish_reason, "stop")
-
-            if tool_calls_out:
-                stop_reason = "tool_use"
-
-            return {
-                "text": text,
-                "meta": meta,
-                "tool_calls": tool_calls_out,
-                "stop_reason": stop_reason,
-            }
-
-        except Exception as e:
-            logger.error(f"Google API tool call request failed: {e}")
-            raise RuntimeError(f"Google API tool call request failed: {e}") from e
-
-    # ------------------------------------------------------------------
-    # Streaming
-    # ------------------------------------------------------------------
-
-    def generate_messages_stream(
-        self,
-        messages: list[dict[str, Any]],
-        options: dict[str, Any],
-    ) -> Iterator[dict[str, Any]]:
-        """Yield response chunks via Gemini streaming API."""
-        gen_input, gen_kwargs, model_kwargs = self._build_generation_args(
-            self._prepare_messages(messages), options
-        )
-
-        try:
-            model = genai.GenerativeModel(self.model, **model_kwargs)
-            response = model.generate_content(gen_input, stream=True, **gen_kwargs)
-
-            full_text = ""
-            for chunk in response:
-                chunk_text = getattr(chunk, "text", None) or ""
-                if chunk_text:
-                    full_text += chunk_text
-                    yield {"type": "delta", "text": chunk_text}
-
-            # After iteration completes, resolve() has been called on the response
-            usage_meta = self._extract_usage_metadata(response, messages)
-
-            yield {
-                "type": "done",
-                "text": full_text,
-                "meta": {
-                    **usage_meta,
-                    "raw_response": {},
-                    "model_name": self.model,
-                },
-            }
-
-        except Exception as e:
-            logger.error(f"Google API streaming request failed: {e}")
-            raise RuntimeError(f"Google API streaming request failed: {e}") from e
{prompture-0.0.38.dist-info → prompture-0.0.38.dev1.dist-info}/RECORD
CHANGED
@@ -1,5 +1,5 @@
 prompture/__init__.py,sha256=RrpHZlLPpzntUOp2tL2II2DdVxQRoCxY6JBF_b4k3s0,7213
-prompture/_version.py,sha256=
+prompture/_version.py,sha256=RC6_NeerdSjHaWAxl4iygvpfefFayk2zxKyKmOU7s08,719
 prompture/agent.py,sha256=xe_yFHGDzTxaU4tmaLt5AQnzrN0I72hBGwGVrCxg2D0,34704
 prompture/agent_types.py,sha256=Icl16PQI-ThGLMFCU43adtQA6cqETbsPn4KssKBI4xc,4664
 prompture/async_agent.py,sha256=nOLOQCNkg0sKKTpryIiidmIcAAlA3FR2NfnZwrNBuCg,33066
@@ -35,21 +35,21 @@ prompture/aio/__init__.py,sha256=bKqTu4Jxld16aP_7SP9wU5au45UBIb041ORo4E4HzVo,181
 prompture/drivers/__init__.py,sha256=VuEBZPqaQzXLl_Lvn_c5mRlJJrrlObZCLeHaR8n2eJ4,7050
 prompture/drivers/airllm_driver.py,sha256=SaTh7e7Plvuct_TfRqQvsJsKHvvM_3iVqhBtlciM-Kw,3858
 prompture/drivers/async_airllm_driver.py,sha256=1hIWLXfyyIg9tXaOE22tLJvFyNwHnOi1M5BIKnV8ysk,908
-prompture/drivers/async_azure_driver.py,sha256=
-prompture/drivers/async_claude_driver.py,sha256=
-prompture/drivers/async_google_driver.py,sha256=
-prompture/drivers/async_grok_driver.py,sha256=
-prompture/drivers/async_groq_driver.py,sha256=
+prompture/drivers/async_azure_driver.py,sha256=Rqq_5Utgr-lvxMHwlU0B5lwCTtqDhuUW212G9k8P0fQ,4463
+prompture/drivers/async_claude_driver.py,sha256=yB5QLbXD7Uqs4j45yulj73QSJJx1-IyIo84YGA1xjkw,4092
+prompture/drivers/async_google_driver.py,sha256=UL3WtQ2gdVYXPpq_HqzNkOifYiR7GLADr7DOOel1SjI,6634
+prompture/drivers/async_grok_driver.py,sha256=bblcUY5c5NJ_IeuFQ-jHRapGi_WywVgH6SSWWWbUMzo,3546
+prompture/drivers/async_groq_driver.py,sha256=gHvVe4M5VaRcyvonK9FQMLmCuL7i7HV9hwWcRgASUSg,3075
 prompture/drivers/async_hugging_driver.py,sha256=IblxqU6TpNUiigZ0BCgNkAgzpUr2FtPHJOZnOZMnHF0,2152
 prompture/drivers/async_lmstudio_driver.py,sha256=rPn2qVPm6UE2APzAn7ZHYTELUwr0dQMi8XHv6gAhyH8,5782
 prompture/drivers/async_local_http_driver.py,sha256=qoigIf-w3_c2dbVdM6m1e2RMAWP4Gk4VzVs5hM3lPvQ,1609
 prompture/drivers/async_ollama_driver.py,sha256=FaSXtFXrgeVHIe0b90Vg6rGeSTWLpPnjaThh9Ai7qQo,5042
-prompture/drivers/async_openai_driver.py,sha256=
-prompture/drivers/async_openrouter_driver.py,sha256=
+prompture/drivers/async_openai_driver.py,sha256=eLdVYQ8BUErQzVr4Ek1BZ75riMbHMz3ZPm6VQSTNFxk,3572
+prompture/drivers/async_openrouter_driver.py,sha256=VcSYOeBhbzRbzorYh_7K58yWCXB4UO0d6MmpBLf-7lQ,3783
 prompture/drivers/async_registry.py,sha256=syervbb7THneJ-NUVSuxy4cnxGW6VuNzKv-Aqqn2ysU,4329
 prompture/drivers/azure_driver.py,sha256=QZr7HEvgSKT9LOTCtCjuBdHl57yvrnWmeTHtmewuJQY,5727
 prompture/drivers/claude_driver.py,sha256=8XnCBHtk6N_PzHStwxIUlcvekdPN896BqOLShmgxU9k,11536
-prompture/drivers/google_driver.py,sha256=
+prompture/drivers/google_driver.py,sha256=2V2mfWO8TuJTtvOKBW11WM1dicNfYFhBJrt7SsgiBbE,9432
 prompture/drivers/grok_driver.py,sha256=AIwuzNAQyOhmVDA07ISWt2e-rsv5aYk3I5AM4HkLM7o,5294
 prompture/drivers/groq_driver.py,sha256=9cZI21RsgYJTjnrtX2fVA0AadDL-VklhY4ugjDCutwM,4195
 prompture/drivers/hugging_driver.py,sha256=gZir3XnM77VfYIdnu3S1pRftlZJM6G3L8bgGn5esg-Q,2346
@@ -69,9 +69,9 @@ prompture/scaffold/templates/env.example.j2,sha256=eESKr1KWgyrczO6d-nwAhQwSpf_G-
 prompture/scaffold/templates/main.py.j2,sha256=TEgc5OvsZOEX0JthkSW1NI_yLwgoeVN_x97Ibg-vyWY,2632
 prompture/scaffold/templates/models.py.j2,sha256=JrZ99GCVK6TKWapskVRSwCssGrTu5cGZ_r46fOhY2GE,858
 prompture/scaffold/templates/requirements.txt.j2,sha256=m3S5fi1hq9KG9l_9j317rjwWww0a43WMKd8VnUWv2A4,102
-prompture-0.0.38.dist-info/licenses/LICENSE,sha256=0HgDepH7aaHNFhHF-iXuW6_GqDfYPnVkjtiCAZ4yS8I,1060
-prompture-0.0.38.dist-info/METADATA,sha256=
-prompture-0.0.38.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
-prompture-0.0.38.dist-info/entry_points.txt,sha256=AFPG3lJR86g4IJMoWQUW5Ph7G6MLNWG3A2u2Tp9zkp8,48
-prompture-0.0.38.dist-info/top_level.txt,sha256=to86zq_kjfdoLeAxQNr420UWqT0WzkKoZ509J7Qr2t4,10
-prompture-0.0.38.dist-info/RECORD,,
+prompture-0.0.38.dev1.dist-info/licenses/LICENSE,sha256=0HgDepH7aaHNFhHF-iXuW6_GqDfYPnVkjtiCAZ4yS8I,1060
+prompture-0.0.38.dev1.dist-info/METADATA,sha256=ZDa9mNU6SdEy4IKb7l-wVvR2Tp_bO3RZ8sHshWtq6Y8,10842
+prompture-0.0.38.dev1.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
+prompture-0.0.38.dev1.dist-info/entry_points.txt,sha256=AFPG3lJR86g4IJMoWQUW5Ph7G6MLNWG3A2u2Tp9zkp8,48
+prompture-0.0.38.dev1.dist-info/top_level.txt,sha256=to86zq_kjfdoLeAxQNr420UWqT0WzkKoZ509J7Qr2t4,10
+prompture-0.0.38.dev1.dist-info/RECORD,,

{prompture-0.0.38.dist-info → prompture-0.0.38.dev1.dist-info}/WHEEL
File without changes

{prompture-0.0.38.dist-info → prompture-0.0.38.dev1.dist-info}/entry_points.txt
File without changes

{prompture-0.0.38.dist-info → prompture-0.0.38.dev1.dist-info}/licenses/LICENSE
File without changes

{prompture-0.0.38.dist-info → prompture-0.0.38.dev1.dist-info}/top_level.txt
File without changes