prompture 0.0.38.dev1__py3-none-any.whl → 0.0.38.dev3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- prompture/_version.py +2 -2
- prompture/drivers/async_azure_driver.py +1 -1
- prompture/drivers/async_claude_driver.py +167 -8
- prompture/drivers/async_google_driver.py +203 -39
- prompture/drivers/async_grok_driver.py +1 -1
- prompture/drivers/async_groq_driver.py +1 -1
- prompture/drivers/async_openai_driver.py +143 -1
- prompture/drivers/async_openrouter_driver.py +1 -1
- prompture/drivers/google_driver.py +207 -43
- {prompture-0.0.38.dev1.dist-info → prompture-0.0.38.dev3.dist-info}/METADATA +1 -1
- {prompture-0.0.38.dev1.dist-info → prompture-0.0.38.dev3.dist-info}/RECORD +15 -15
- {prompture-0.0.38.dev1.dist-info → prompture-0.0.38.dev3.dist-info}/WHEEL +0 -0
- {prompture-0.0.38.dev1.dist-info → prompture-0.0.38.dev3.dist-info}/entry_points.txt +0 -0
- {prompture-0.0.38.dev1.dist-info → prompture-0.0.38.dev3.dist-info}/licenses/LICENSE +0 -0
- {prompture-0.0.38.dev1.dist-info → prompture-0.0.38.dev3.dist-info}/top_level.txt +0 -0
prompture/_version.py
CHANGED
|
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
|
|
|
28
28
|
commit_id: COMMIT_ID
|
|
29
29
|
__commit_id__: COMMIT_ID
|
|
30
30
|
|
|
31
|
-
__version__ = version = '0.0.38.
|
|
32
|
-
__version_tuple__ = version_tuple = (0, 0, 38, '
|
|
31
|
+
__version__ = version = '0.0.38.dev3'
|
|
32
|
+
__version_tuple__ = version_tuple = (0, 0, 38, 'dev3')
|
|
33
33
|
|
|
34
34
|
__commit_id__ = commit_id = None
|
|
@@ -113,7 +113,7 @@ class AsyncAzureDriver(CostMixin, AsyncDriver):
|
|
|
113
113
|
"prompt_tokens": prompt_tokens,
|
|
114
114
|
"completion_tokens": completion_tokens,
|
|
115
115
|
"total_tokens": total_tokens,
|
|
116
|
-
"cost": total_cost,
|
|
116
|
+
"cost": round(total_cost, 6),
|
|
117
117
|
"raw_response": resp.model_dump(),
|
|
118
118
|
"model_name": model,
|
|
119
119
|
"deployment_id": self.deployment_id,
|
|
@@ -4,6 +4,7 @@ from __future__ import annotations
|
|
|
4
4
|
|
|
5
5
|
import json
|
|
6
6
|
import os
|
|
7
|
+
from collections.abc import AsyncIterator
|
|
7
8
|
from typing import Any
|
|
8
9
|
|
|
9
10
|
try:
|
|
@@ -19,6 +20,8 @@ from .claude_driver import ClaudeDriver
|
|
|
19
20
|
class AsyncClaudeDriver(CostMixin, AsyncDriver):
|
|
20
21
|
supports_json_mode = True
|
|
21
22
|
supports_json_schema = True
|
|
23
|
+
supports_tool_use = True
|
|
24
|
+
supports_streaming = True
|
|
22
25
|
supports_vision = True
|
|
23
26
|
|
|
24
27
|
MODEL_PRICING = ClaudeDriver.MODEL_PRICING
|
|
@@ -51,13 +54,7 @@ class AsyncClaudeDriver(CostMixin, AsyncDriver):
|
|
|
51
54
|
client = anthropic.AsyncAnthropic(api_key=self.api_key)
|
|
52
55
|
|
|
53
56
|
# Anthropic requires system messages as a top-level parameter
|
|
54
|
-
system_content =
|
|
55
|
-
api_messages = []
|
|
56
|
-
for msg in messages:
|
|
57
|
-
if msg.get("role") == "system":
|
|
58
|
-
system_content = msg.get("content", "")
|
|
59
|
-
else:
|
|
60
|
-
api_messages.append(msg)
|
|
57
|
+
system_content, api_messages = self._extract_system_and_messages(messages)
|
|
61
58
|
|
|
62
59
|
# Build common kwargs
|
|
63
60
|
common_kwargs: dict[str, Any] = {
|
|
@@ -105,9 +102,171 @@ class AsyncClaudeDriver(CostMixin, AsyncDriver):
|
|
|
105
102
|
"prompt_tokens": prompt_tokens,
|
|
106
103
|
"completion_tokens": completion_tokens,
|
|
107
104
|
"total_tokens": total_tokens,
|
|
108
|
-
"cost": total_cost,
|
|
105
|
+
"cost": round(total_cost, 6),
|
|
109
106
|
"raw_response": dict(resp),
|
|
110
107
|
"model_name": model,
|
|
111
108
|
}
|
|
112
109
|
|
|
113
110
|
return {"text": text, "meta": meta}
|
|
111
|
+
|
|
112
|
+
# ------------------------------------------------------------------
|
|
113
|
+
# Helpers
|
|
114
|
+
# ------------------------------------------------------------------
|
|
115
|
+
|
|
116
|
+
def _extract_system_and_messages(
|
|
117
|
+
self, messages: list[dict[str, Any]]
|
|
118
|
+
) -> tuple[str | None, list[dict[str, Any]]]:
|
|
119
|
+
"""Separate system message from conversation messages for Anthropic API."""
|
|
120
|
+
system_content = None
|
|
121
|
+
api_messages: list[dict[str, Any]] = []
|
|
122
|
+
for msg in messages:
|
|
123
|
+
if msg.get("role") == "system":
|
|
124
|
+
system_content = msg.get("content", "")
|
|
125
|
+
else:
|
|
126
|
+
api_messages.append(msg)
|
|
127
|
+
return system_content, api_messages
|
|
128
|
+
|
|
129
|
+
# ------------------------------------------------------------------
|
|
130
|
+
# Tool use
|
|
131
|
+
# ------------------------------------------------------------------
|
|
132
|
+
|
|
133
|
+
async def generate_messages_with_tools(
    self,
    messages: list[dict[str, Any]],
    tools: list[dict[str, Any]],
    options: dict[str, Any],
) -> dict[str, Any]:
    """Generate a response that may include tool calls (Anthropic).

    Args:
        messages: Conversation messages; system messages are lifted into the
            top-level ``system`` parameter via ``_extract_system_and_messages``.
        tools: Tool definitions. OpenAI-format entries
            (``{"type": "function", "function": {...}}``) are converted to
            Anthropic's ``name`` / ``description`` / ``input_schema`` shape;
            any other entry is passed through unchanged.
        options: Per-call options; ``model``, ``temperature`` (default 0.0)
            and ``max_tokens`` (default 512) are read.

    Returns:
        Dict with ``text`` (concatenated text blocks), ``meta`` (token usage,
        cost, raw response, model name), ``tool_calls`` (``id`` / ``name`` /
        ``arguments`` per ``tool_use`` block) and ``stop_reason`` as reported
        by Anthropic.

    Raises:
        RuntimeError: If the ``anthropic`` package is not installed.
    """
    if anthropic is None:
        raise RuntimeError("anthropic package not installed")

    opts = {**{"temperature": 0.0, "max_tokens": 512}, **options}
    model = options.get("model", self.model)
    client = anthropic.AsyncAnthropic(api_key=self.api_key)

    system_content, api_messages = self._extract_system_and_messages(messages)

    anthropic_tools: list[dict[str, Any]] = []
    for t in tools:
        if t.get("type") == "function":
            # OpenAI format -> Anthropic format
            fn = t["function"]
            anthropic_tools.append({
                "name": fn["name"],
                "description": fn.get("description") or "",
                # Anthropic needs a schema object; an OpenAI definition may omit
                # ``parameters`` or set it to null for zero-argument functions.
                "input_schema": fn.get("parameters") or {"type": "object", "properties": {}},
            })
        else:
            # Already Anthropic format (or an unrecognised shape): pass through.
            anthropic_tools.append(t)

    kwargs: dict[str, Any] = {
        "model": model,
        "messages": api_messages,
        "temperature": opts["temperature"],
        "max_tokens": opts["max_tokens"],
        "tools": anthropic_tools,
    }
    if system_content:
        kwargs["system"] = system_content

    resp = await client.messages.create(**kwargs)

    prompt_tokens = resp.usage.input_tokens
    completion_tokens = resp.usage.output_tokens
    total_tokens = prompt_tokens + completion_tokens
    total_cost = self._calculate_cost("claude", model, prompt_tokens, completion_tokens)

    meta = {
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": total_tokens,
        "cost": round(total_cost, 6),
        "raw_response": dict(resp),
        "model_name": model,
    }

    text = ""
    tool_calls_out: list[dict[str, Any]] = []
    for block in resp.content:
        if block.type == "text":
            text += block.text
        elif block.type == "tool_use":
            tool_calls_out.append({
                "id": block.id,
                "name": block.name,
                "arguments": block.input,
            })

    return {
        "text": text,
        "meta": meta,
        "tool_calls": tool_calls_out,
        "stop_reason": resp.stop_reason,
    }
|
|
210
|
+
|
|
211
|
+
# ------------------------------------------------------------------
|
|
212
|
+
# Streaming
|
|
213
|
+
# ------------------------------------------------------------------
|
|
214
|
+
|
|
215
|
+
async def generate_messages_stream(
    self,
    messages: list[dict[str, Any]],
    options: dict[str, Any],
) -> AsyncIterator[dict[str, Any]]:
    """Stream a Claude reply, yielding text deltas followed by one summary event.

    Each non-empty text delta is yielded as ``{"type": "delta", "text": ...}``.
    When the stream closes, a single ``{"type": "done", "text": <full text>,
    "meta": {...}}`` event follows. Its input token count comes from the
    ``message_start`` event, its output token count from the last
    ``message_delta`` event, and its cost from ``_calculate_cost``.

    Raises:
        RuntimeError: If the ``anthropic`` package is not installed.
    """
    if anthropic is None:
        raise RuntimeError("anthropic package not installed")

    opts = {**{"temperature": 0.0, "max_tokens": 512}, **options}
    model = options.get("model", self.model)
    client = anthropic.AsyncAnthropic(api_key=self.api_key)

    system_content, api_messages = self._extract_system_and_messages(messages)
    request: dict[str, Any] = {
        "model": model,
        "messages": api_messages,
        "temperature": opts["temperature"],
        "max_tokens": opts["max_tokens"],
    }
    if system_content:
        request["system"] = system_content

    pieces: list[str] = []
    input_tokens = 0
    output_tokens = 0

    async with client.messages.stream(**request) as stream:
        async for event in stream:
            if not hasattr(event, "type"):
                continue
            kind = event.type
            if kind == "content_block_delta" and hasattr(event, "delta"):
                piece = getattr(event.delta, "text", "")
                if piece:
                    pieces.append(piece)
                    yield {"type": "delta", "text": piece}
            elif kind == "message_delta" and hasattr(event, "usage"):
                output_tokens = getattr(event.usage, "output_tokens", 0)
            elif kind == "message_start" and hasattr(event, "message"):
                start_usage = getattr(event.message, "usage", None)
                if start_usage:
                    input_tokens = getattr(start_usage, "input_tokens", 0)

    cost = self._calculate_cost("claude", model, input_tokens, output_tokens)

    yield {
        "type": "done",
        "text": "".join(pieces),
        "meta": {
            "prompt_tokens": input_tokens,
            "completion_tokens": output_tokens,
            "total_tokens": input_tokens + output_tokens,
            "cost": round(cost, 6),
            "raw_response": {},
            "model_name": model,
        },
    }
|
|
@@ -4,6 +4,8 @@ from __future__ import annotations
|
|
|
4
4
|
|
|
5
5
|
import logging
|
|
6
6
|
import os
|
|
7
|
+
import uuid
|
|
8
|
+
from collections.abc import AsyncIterator
|
|
7
9
|
from typing import Any
|
|
8
10
|
|
|
9
11
|
import google.generativeai as genai
|
|
@@ -21,6 +23,8 @@ class AsyncGoogleDriver(CostMixin, AsyncDriver):
|
|
|
21
23
|
supports_json_mode = True
|
|
22
24
|
supports_json_schema = True
|
|
23
25
|
supports_vision = True
|
|
26
|
+
supports_tool_use = True
|
|
27
|
+
supports_streaming = True
|
|
24
28
|
|
|
25
29
|
MODEL_PRICING = GoogleDriver.MODEL_PRICING
|
|
26
30
|
_PRICING_UNIT = 1_000_000
|
|
@@ -49,6 +53,40 @@ class AsyncGoogleDriver(CostMixin, AsyncDriver):
|
|
|
49
53
|
completion_cost = (completion_chars / 1_000_000) * model_pricing["completion"]
|
|
50
54
|
return round(prompt_cost + completion_cost, 6)
|
|
51
55
|
|
|
56
|
+
def _extract_usage_metadata(self, response: Any, messages: list[dict[str, Any]]) -> dict[str, Any]:
|
|
57
|
+
"""Extract token counts from response, falling back to character estimation."""
|
|
58
|
+
usage = getattr(response, "usage_metadata", None)
|
|
59
|
+
if usage:
|
|
60
|
+
prompt_tokens = getattr(usage, "prompt_token_count", 0) or 0
|
|
61
|
+
completion_tokens = getattr(usage, "candidates_token_count", 0) or 0
|
|
62
|
+
total_tokens = getattr(usage, "total_token_count", 0) or (prompt_tokens + completion_tokens)
|
|
63
|
+
cost = self._calculate_cost("google", self.model, prompt_tokens, completion_tokens)
|
|
64
|
+
else:
|
|
65
|
+
# Fallback: estimate from character counts
|
|
66
|
+
total_prompt_chars = 0
|
|
67
|
+
for msg in messages:
|
|
68
|
+
c = msg.get("content", "")
|
|
69
|
+
if isinstance(c, str):
|
|
70
|
+
total_prompt_chars += len(c)
|
|
71
|
+
elif isinstance(c, list):
|
|
72
|
+
for part in c:
|
|
73
|
+
if isinstance(part, str):
|
|
74
|
+
total_prompt_chars += len(part)
|
|
75
|
+
elif isinstance(part, dict) and "text" in part:
|
|
76
|
+
total_prompt_chars += len(part["text"])
|
|
77
|
+
completion_chars = len(response.text) if response.text else 0
|
|
78
|
+
prompt_tokens = total_prompt_chars // 4
|
|
79
|
+
completion_tokens = completion_chars // 4
|
|
80
|
+
total_tokens = prompt_tokens + completion_tokens
|
|
81
|
+
cost = self._calculate_cost_chars(total_prompt_chars, completion_chars)
|
|
82
|
+
|
|
83
|
+
return {
|
|
84
|
+
"prompt_tokens": prompt_tokens,
|
|
85
|
+
"completion_tokens": completion_tokens,
|
|
86
|
+
"total_tokens": total_tokens,
|
|
87
|
+
"cost": round(cost, 6),
|
|
88
|
+
}
|
|
89
|
+
|
|
52
90
|
supports_messages = True
|
|
53
91
|
|
|
54
92
|
def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
|
@@ -56,16 +94,10 @@ class AsyncGoogleDriver(CostMixin, AsyncDriver):
|
|
|
56
94
|
|
|
57
95
|
return _prepare_google_vision_messages(messages)
|
|
58
96
|
|
|
59
|
-
|
|
60
|
-
messages
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
|
|
64
|
-
return await self._do_generate(self._prepare_messages(messages), options)
|
|
65
|
-
|
|
66
|
-
async def _do_generate(
|
|
67
|
-
self, messages: list[dict[str, str]], options: dict[str, Any] | None = None
|
|
68
|
-
) -> dict[str, Any]:
|
|
97
|
+
def _build_generation_args(
|
|
98
|
+
self, messages: list[dict[str, Any]], options: dict[str, Any] | None = None
|
|
99
|
+
) -> tuple[Any, dict[str, Any], dict[str, Any]]:
|
|
100
|
+
"""Parse messages and options into (gen_input, gen_kwargs, model_kwargs)."""
|
|
69
101
|
merged_options = self.options.copy()
|
|
70
102
|
if options:
|
|
71
103
|
merged_options.update(options)
|
|
@@ -100,47 +132,54 @@ class AsyncGoogleDriver(CostMixin, AsyncDriver):
|
|
|
100
132
|
else:
|
|
101
133
|
gemini_role = "model" if role == "assistant" else "user"
|
|
102
134
|
if msg.get("_vision_parts"):
|
|
103
|
-
# Already converted to Gemini parts by _prepare_messages
|
|
104
135
|
contents.append({"role": gemini_role, "parts": content})
|
|
105
136
|
else:
|
|
106
137
|
contents.append({"role": gemini_role, "parts": [content]})
|
|
107
138
|
|
|
139
|
+
# For a single message, unwrap only if it has exactly one string part
|
|
140
|
+
if len(contents) == 1:
|
|
141
|
+
parts = contents[0]["parts"]
|
|
142
|
+
if len(parts) == 1 and isinstance(parts[0], str):
|
|
143
|
+
gen_input = parts[0]
|
|
144
|
+
else:
|
|
145
|
+
gen_input = contents
|
|
146
|
+
else:
|
|
147
|
+
gen_input = contents
|
|
148
|
+
|
|
149
|
+
model_kwargs: dict[str, Any] = {}
|
|
150
|
+
if system_instruction:
|
|
151
|
+
model_kwargs["system_instruction"] = system_instruction
|
|
152
|
+
|
|
153
|
+
gen_kwargs: dict[str, Any] = {
|
|
154
|
+
"generation_config": generation_config if generation_config else None,
|
|
155
|
+
"safety_settings": safety_settings if safety_settings else None,
|
|
156
|
+
}
|
|
157
|
+
|
|
158
|
+
return gen_input, gen_kwargs, model_kwargs
|
|
159
|
+
|
|
160
|
+
async def generate(self, prompt: str, options: dict[str, Any] | None = None) -> dict[str, Any]:
    """Generate a completion for a single user *prompt*.

    The prompt is wrapped as a one-message user conversation and passed
    straight to ``_do_generate``, together with *options*. Unlike
    ``generate_messages``, it does not go through ``_prepare_messages``.
    """
    return await self._do_generate([{"role": "user", "content": prompt}], options)
|
|
163
|
+
|
|
164
|
+
async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
    """Generate a reply for a full conversation.

    The messages are first passed through ``_prepare_messages``, and the
    result is handed to ``_do_generate`` together with *options*.
    """
    prepared = self._prepare_messages(messages)
    return await self._do_generate(prepared, options)
|
|
166
|
+
|
|
167
|
+
async def _do_generate(
|
|
168
|
+
self, messages: list[dict[str, str]], options: dict[str, Any] | None = None
|
|
169
|
+
) -> dict[str, Any]:
|
|
170
|
+
gen_input, gen_kwargs, model_kwargs = self._build_generation_args(messages, options)
|
|
171
|
+
|
|
108
172
|
try:
|
|
109
|
-
model_kwargs: dict[str, Any] = {}
|
|
110
|
-
if system_instruction:
|
|
111
|
-
model_kwargs["system_instruction"] = system_instruction
|
|
112
173
|
model = genai.GenerativeModel(self.model, **model_kwargs)
|
|
113
|
-
|
|
114
|
-
gen_input: Any = contents if len(contents) != 1 else contents[0]["parts"][0]
|
|
115
|
-
response = await model.generate_content_async(
|
|
116
|
-
gen_input,
|
|
117
|
-
generation_config=generation_config if generation_config else None,
|
|
118
|
-
safety_settings=safety_settings if safety_settings else None,
|
|
119
|
-
)
|
|
174
|
+
response = await model.generate_content_async(gen_input, **gen_kwargs)
|
|
120
175
|
|
|
121
176
|
if not response.text:
|
|
122
177
|
raise ValueError("Empty response from model")
|
|
123
178
|
|
|
124
|
-
|
|
125
|
-
for msg in messages:
|
|
126
|
-
c = msg.get("content", "")
|
|
127
|
-
if isinstance(c, str):
|
|
128
|
-
total_prompt_chars += len(c)
|
|
129
|
-
elif isinstance(c, list):
|
|
130
|
-
for part in c:
|
|
131
|
-
if isinstance(part, str):
|
|
132
|
-
total_prompt_chars += len(part)
|
|
133
|
-
elif isinstance(part, dict) and "text" in part:
|
|
134
|
-
total_prompt_chars += len(part["text"])
|
|
135
|
-
completion_chars = len(response.text)
|
|
136
|
-
|
|
137
|
-
total_cost = self._calculate_cost_chars(total_prompt_chars, completion_chars)
|
|
179
|
+
usage_meta = self._extract_usage_metadata(response, messages)
|
|
138
180
|
|
|
139
181
|
meta = {
|
|
140
|
-
|
|
141
|
-
"completion_chars": completion_chars,
|
|
142
|
-
"total_chars": total_prompt_chars + completion_chars,
|
|
143
|
-
"cost": total_cost,
|
|
182
|
+
**usage_meta,
|
|
144
183
|
"raw_response": response.prompt_feedback if hasattr(response, "prompt_feedback") else None,
|
|
145
184
|
"model_name": self.model,
|
|
146
185
|
}
|
|
@@ -150,3 +189,128 @@ class AsyncGoogleDriver(CostMixin, AsyncDriver):
|
|
|
150
189
|
except Exception as e:
|
|
151
190
|
logger.error(f"Google API request failed: {e}")
|
|
152
191
|
raise RuntimeError(f"Google API request failed: {e}") from e
|
|
192
|
+
|
|
193
|
+
# ------------------------------------------------------------------
|
|
194
|
+
# Tool use
|
|
195
|
+
# ------------------------------------------------------------------
|
|
196
|
+
|
|
197
|
+
async def generate_messages_with_tools(
    self,
    messages: list[dict[str, Any]],
    tools: list[dict[str, Any]],
    options: dict[str, Any],
) -> dict[str, Any]:
    """Generate a response that may include tool/function calls (async).

    Args:
        messages: Conversation messages. They go through ``_prepare_messages``
            and ``_build_generation_args`` exactly like a normal generation call.
        tools: Tool definitions, in either OpenAI format
            (``{"type": "function", "function": {...}}``) or a flat ``name`` /
            ``description`` / ``parameters`` (or ``input_schema``) form.
            Entries matching neither shape are ignored.
        options: Per-call options forwarded to ``_build_generation_args``.

    Returns:
        Dict with these keys:

        * ``text``;
        * ``meta`` (usage, cost, prompt feedback, model name);
        * ``tool_calls`` (``id`` / ``name`` / ``arguments``);
        * ``stop_reason``: ``"tool_use"`` whenever a function call is
          present, otherwise derived from the candidate's ``finish_reason``.

    Raises:
        RuntimeError: Wrapping any error raised while calling Gemini or
            reading its response.
    """
    # Local import: only this method needs these ABCs.
    from collections.abc import Mapping, Sequence

    def _to_plain(value: Any) -> Any:
        """Recursively turn mapping/sequence wrappers into plain dicts/lists."""
        if isinstance(value, Mapping):
            return {key: _to_plain(item) for key, item in value.items()}
        if isinstance(value, Sequence) and not isinstance(value, (str, bytes)):
            return [_to_plain(item) for item in value]
        return value

    gen_input, gen_kwargs, model_kwargs = self._build_generation_args(
        self._prepare_messages(messages), options
    )

    # Convert tools from OpenAI / flat format to Gemini function declarations
    function_declarations: list[dict[str, Any]] = []
    for t in tools:
        if t.get("type") == "function":
            fn = t["function"]
            decl: dict[str, Any] = {"name": fn["name"], "description": fn.get("description", "")}
            params = fn.get("parameters")
        elif "name" in t:
            decl = {"name": t["name"], "description": t.get("description", "")}
            params = t.get("parameters") or t.get("input_schema")
        else:
            continue
        if params:
            decl["parameters"] = params
        function_declarations.append(decl)

    reason_map = {1: "stop", 2: "max_tokens", 3: "safety", 4: "recitation", 5: "other"}

    try:
        model = genai.GenerativeModel(self.model, **model_kwargs)

        # Send no ``tools`` at all rather than a Tool with an empty declaration list.
        gemini_tools = (
            [genai.types.Tool(function_declarations=function_declarations)]
            if function_declarations
            else None
        )
        response = await model.generate_content_async(gen_input, tools=gemini_tools, **gen_kwargs)

        usage_meta = self._extract_usage_metadata(response, messages)
        meta = {
            **usage_meta,
            "raw_response": response.prompt_feedback if hasattr(response, "prompt_feedback") else None,
            "model_name": self.model,
        }

        text = ""
        tool_calls_out: list[dict[str, Any]] = []
        stop_reason = "stop"

        for candidate in response.candidates:
            for part in candidate.content.parts:
                if hasattr(part, "text") and part.text:
                    text += part.text
                if hasattr(part, "function_call") and part.function_call.name:
                    fc = part.function_call
                    tool_calls_out.append({
                        "id": str(uuid.uuid4()),
                        "name": fc.name,
                        # dict() alone is shallow; nested objects/lists inside
                        # ``fc.args`` may still be proto containers, so convert deeply.
                        "arguments": _to_plain(fc.args) if fc.args else {},
                    })

            finish_reason = getattr(candidate, "finish_reason", None)
            if finish_reason is not None:
                stop_reason = reason_map.get(finish_reason, "stop")

        if tool_calls_out:
            stop_reason = "tool_use"

        return {
            "text": text,
            "meta": meta,
            "tool_calls": tool_calls_out,
            "stop_reason": stop_reason,
        }

    except Exception as e:
        logger.error(f"Google API tool call request failed: {e}")
        raise RuntimeError(f"Google API tool call request failed: {e}") from e
|
|
275
|
+
|
|
276
|
+
# ------------------------------------------------------------------
|
|
277
|
+
# Streaming
|
|
278
|
+
# ------------------------------------------------------------------
|
|
279
|
+
|
|
280
|
+
async def generate_messages_stream(
    self,
    messages: list[dict[str, Any]],
    options: dict[str, Any],
) -> AsyncIterator[dict[str, Any]]:
    """Yield response chunks via the Gemini async streaming API.

    Each chunk with non-empty text is yielded as
    ``{"type": "delta", "text": ...}``. After that comes one
    ``{"type": "done", "text": <full text>, "meta": {...}}`` event whose
    usage and cost come from ``_extract_usage_metadata``.

    Raises:
        RuntimeError: Wrapping any error raised while streaming from Gemini.
    """
    gen_input, gen_kwargs, model_kwargs = self._build_generation_args(
        self._prepare_messages(messages), options
    )

    try:
        model = genai.GenerativeModel(self.model, **model_kwargs)
        response = await model.generate_content_async(gen_input, stream=True, **gen_kwargs)

        full_text = ""
        async for chunk in response:
            try:
                chunk_text = chunk.text or ""
            except (AttributeError, ValueError):
                # A getattr default only covers AttributeError, but the ``text``
                # accessor raises ValueError for a chunk with no text part
                # (e.g. one carrying only a finish reason). Such a chunk must not
                # abort the whole stream.
                chunk_text = ""
            if chunk_text:
                full_text += chunk_text
                yield {"type": "delta", "text": chunk_text}

        # After iteration completes, usage_metadata should be available
        usage_meta = self._extract_usage_metadata(response, messages)

        yield {
            "type": "done",
            "text": full_text,
            "meta": {
                **usage_meta,
                "raw_response": {},
                "model_name": self.model,
            },
        }

    except Exception as e:
        logger.error(f"Google API streaming request failed: {e}")
        raise RuntimeError(f"Google API streaming request failed: {e}") from e
|
|
@@ -88,7 +88,7 @@ class AsyncGrokDriver(CostMixin, AsyncDriver):
|
|
|
88
88
|
"prompt_tokens": prompt_tokens,
|
|
89
89
|
"completion_tokens": completion_tokens,
|
|
90
90
|
"total_tokens": total_tokens,
|
|
91
|
-
"cost": total_cost,
|
|
91
|
+
"cost": round(total_cost, 6),
|
|
92
92
|
"raw_response": resp,
|
|
93
93
|
"model_name": model,
|
|
94
94
|
}
|
|
@@ -81,7 +81,7 @@ class AsyncGroqDriver(CostMixin, AsyncDriver):
|
|
|
81
81
|
"prompt_tokens": prompt_tokens,
|
|
82
82
|
"completion_tokens": completion_tokens,
|
|
83
83
|
"total_tokens": total_tokens,
|
|
84
|
-
"cost": total_cost,
|
|
84
|
+
"cost": round(total_cost, 6),
|
|
85
85
|
"raw_response": resp.model_dump(),
|
|
86
86
|
"model_name": model,
|
|
87
87
|
}
|
|
@@ -2,7 +2,9 @@
|
|
|
2
2
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
|
+
import json
|
|
5
6
|
import os
|
|
7
|
+
from collections.abc import AsyncIterator
|
|
6
8
|
from typing import Any
|
|
7
9
|
|
|
8
10
|
try:
|
|
@@ -18,6 +20,8 @@ from .openai_driver import OpenAIDriver
|
|
|
18
20
|
class AsyncOpenAIDriver(CostMixin, AsyncDriver):
|
|
19
21
|
supports_json_mode = True
|
|
20
22
|
supports_json_schema = True
|
|
23
|
+
supports_tool_use = True
|
|
24
|
+
supports_streaming = True
|
|
21
25
|
supports_vision = True
|
|
22
26
|
|
|
23
27
|
MODEL_PRICING = OpenAIDriver.MODEL_PRICING
|
|
@@ -93,10 +97,148 @@ class AsyncOpenAIDriver(CostMixin, AsyncDriver):
|
|
|
93
97
|
"prompt_tokens": prompt_tokens,
|
|
94
98
|
"completion_tokens": completion_tokens,
|
|
95
99
|
"total_tokens": total_tokens,
|
|
96
|
-
"cost": total_cost,
|
|
100
|
+
"cost": round(total_cost, 6),
|
|
97
101
|
"raw_response": resp.model_dump(),
|
|
98
102
|
"model_name": model,
|
|
99
103
|
}
|
|
100
104
|
|
|
101
105
|
text = resp.choices[0].message.content
|
|
102
106
|
return {"text": text, "meta": meta}
|
|
107
|
+
|
|
108
|
+
# ------------------------------------------------------------------
|
|
109
|
+
# Tool use
|
|
110
|
+
# ------------------------------------------------------------------
|
|
111
|
+
|
|
112
|
+
async def generate_messages_with_tools(
    self,
    messages: list[dict[str, Any]],
    tools: list[dict[str, Any]],
    options: dict[str, Any],
) -> dict[str, Any]:
    """Run an OpenAI chat completion with *tools* available to the model.

    The pricing table entry for the model decides two things: which request
    parameter carries the token limit (``tokens_param``), and whether
    ``temperature`` is sent at all (``supports_temperature``).

    Returns:
        Dict with ``text`` (``""`` when the message has no content), ``meta``
        (usage, cost, raw response, model name), ``tool_calls`` (``id`` /
        ``name`` / JSON-decoded ``arguments``, or ``{}`` when decoding fails)
        and ``stop_reason`` (the choice's ``finish_reason``).

    Raises:
        RuntimeError: If ``self.client`` is ``None``.
    """
    if self.client is None:
        raise RuntimeError("openai package (>=1.0.0) is not installed")

    model = options.get("model", self.model)
    pricing = self.MODEL_PRICING.get(model, {})
    merged = {"temperature": 1.0, "max_tokens": 512, **options}

    request: dict[str, Any] = {"model": model, "messages": messages, "tools": tools}
    request[pricing.get("tokens_param", "max_tokens")] = merged.get("max_tokens", 512)
    if pricing.get("supports_temperature", True) and "temperature" in merged:
        request["temperature"] = merged["temperature"]

    resp = await self.client.chat.completions.create(**request)

    usage = getattr(resp, "usage", None)
    prompt_tokens = getattr(usage, "prompt_tokens", 0)
    completion_tokens = getattr(usage, "completion_tokens", 0)
    cost = self._calculate_cost("openai", model, prompt_tokens, completion_tokens)

    meta = {
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": getattr(usage, "total_tokens", 0),
        "cost": round(cost, 6),
        "raw_response": resp.model_dump(),
        "model_name": model,
    }

    def _decode_arguments(raw: Any) -> Any:
        """JSON-decode a tool call's argument string, or return ``{}``."""
        try:
            return json.loads(raw)
        except (json.JSONDecodeError, TypeError):
            return {}

    choice = resp.choices[0]
    message = choice.message
    calls = [
        {
            "id": call.id,
            "name": call.function.name,
            "arguments": _decode_arguments(call.function.arguments),
        }
        for call in (message.tool_calls or [])
    ]

    return {
        "text": message.content or "",
        "meta": meta,
        "tool_calls": calls,
        "stop_reason": choice.finish_reason,
    }
|
|
179
|
+
|
|
180
|
+
# ------------------------------------------------------------------
|
|
181
|
+
# Streaming
|
|
182
|
+
# ------------------------------------------------------------------
|
|
183
|
+
|
|
184
|
+
async def generate_messages_stream(
    self,
    messages: list[dict[str, Any]],
    options: dict[str, Any],
) -> AsyncIterator[dict[str, Any]]:
    """Stream an OpenAI chat completion.

    Each non-empty content delta is yielded as
    ``{"type": "delta", "text": ...}``. The stream ends with a single
    ``{"type": "done", ...}`` event whose ``meta`` holds token usage and cost.
    The usage is read from whichever chunk carries ``usage``; it is requested
    via ``stream_options={"include_usage": True}``.

    Raises:
        RuntimeError: If ``self.client`` is ``None``.
    """
    if self.client is None:
        raise RuntimeError("openai package (>=1.0.0) is not installed")

    model = options.get("model", self.model)
    pricing = self.MODEL_PRICING.get(model, {})
    merged = {"temperature": 1.0, "max_tokens": 512, **options}

    request: dict[str, Any] = {
        "model": model,
        "messages": messages,
        "stream": True,
        "stream_options": {"include_usage": True},
    }
    # The pricing table decides which parameter name carries the token limit.
    request[pricing.get("tokens_param", "max_tokens")] = merged.get("max_tokens", 512)
    if pricing.get("supports_temperature", True) and "temperature" in merged:
        request["temperature"] = merged["temperature"]

    stream = await self.client.chat.completions.create(**request)

    pieces: list[str] = []
    prompt_tokens = completion_tokens = 0

    async for chunk in stream:
        chunk_usage = getattr(chunk, "usage", None)
        if chunk_usage:
            prompt_tokens = chunk_usage.prompt_tokens or 0
            completion_tokens = chunk_usage.completion_tokens or 0

        if not chunk.choices:
            continue
        content = getattr(chunk.choices[0].delta, "content", None) or ""
        if content:
            pieces.append(content)
            yield {"type": "delta", "text": content}

    cost = self._calculate_cost("openai", model, prompt_tokens, completion_tokens)

    yield {
        "type": "done",
        "text": "".join(pieces),
        "meta": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
            "cost": round(cost, 6),
            "raw_response": {},
            "model_name": model,
        },
    }
|
|
@@ -93,7 +93,7 @@ class AsyncOpenRouterDriver(CostMixin, AsyncDriver):
|
|
|
93
93
|
"prompt_tokens": prompt_tokens,
|
|
94
94
|
"completion_tokens": completion_tokens,
|
|
95
95
|
"total_tokens": total_tokens,
|
|
96
|
-
"cost": total_cost,
|
|
96
|
+
"cost": round(total_cost, 6),
|
|
97
97
|
"raw_response": resp,
|
|
98
98
|
"model_name": model,
|
|
99
99
|
}
|
|
@@ -1,5 +1,7 @@
|
|
|
1
1
|
import logging
|
|
2
2
|
import os
|
|
3
|
+
import uuid
|
|
4
|
+
from collections.abc import Iterator
|
|
3
5
|
from typing import Any, Optional
|
|
4
6
|
|
|
5
7
|
import google.generativeai as genai
|
|
@@ -16,6 +18,8 @@ class GoogleDriver(CostMixin, Driver):
|
|
|
16
18
|
supports_json_mode = True
|
|
17
19
|
supports_json_schema = True
|
|
18
20
|
supports_vision = True
|
|
21
|
+
supports_tool_use = True
|
|
22
|
+
supports_streaming = True
|
|
19
23
|
|
|
20
24
|
# Based on current Gemini pricing (as of 2025)
|
|
21
25
|
# Source: https://cloud.google.com/vertex-ai/pricing#gemini_models
|
|
@@ -106,6 +110,40 @@ class GoogleDriver(CostMixin, Driver):
|
|
|
106
110
|
completion_cost = (completion_chars / 1_000_000) * model_pricing["completion"]
|
|
107
111
|
return round(prompt_cost + completion_cost, 6)
|
|
108
112
|
|
|
113
|
+
def _extract_usage_metadata(self, response: Any, messages: list[dict[str, Any]]) -> dict[str, Any]:
|
|
114
|
+
"""Extract token counts from response, falling back to character estimation."""
|
|
115
|
+
usage = getattr(response, "usage_metadata", None)
|
|
116
|
+
if usage:
|
|
117
|
+
prompt_tokens = getattr(usage, "prompt_token_count", 0) or 0
|
|
118
|
+
completion_tokens = getattr(usage, "candidates_token_count", 0) or 0
|
|
119
|
+
total_tokens = getattr(usage, "total_token_count", 0) or (prompt_tokens + completion_tokens)
|
|
120
|
+
cost = self._calculate_cost("google", self.model, prompt_tokens, completion_tokens)
|
|
121
|
+
else:
|
|
122
|
+
# Fallback: estimate from character counts
|
|
123
|
+
total_prompt_chars = 0
|
|
124
|
+
for msg in messages:
|
|
125
|
+
c = msg.get("content", "")
|
|
126
|
+
if isinstance(c, str):
|
|
127
|
+
total_prompt_chars += len(c)
|
|
128
|
+
elif isinstance(c, list):
|
|
129
|
+
for part in c:
|
|
130
|
+
if isinstance(part, str):
|
|
131
|
+
total_prompt_chars += len(part)
|
|
132
|
+
elif isinstance(part, dict) and "text" in part:
|
|
133
|
+
total_prompt_chars += len(part["text"])
|
|
134
|
+
completion_chars = len(response.text) if response.text else 0
|
|
135
|
+
prompt_tokens = total_prompt_chars // 4
|
|
136
|
+
completion_tokens = completion_chars // 4
|
|
137
|
+
total_tokens = prompt_tokens + completion_tokens
|
|
138
|
+
cost = self._calculate_cost_chars(total_prompt_chars, completion_chars)
|
|
139
|
+
|
|
140
|
+
return {
|
|
141
|
+
"prompt_tokens": prompt_tokens,
|
|
142
|
+
"completion_tokens": completion_tokens,
|
|
143
|
+
"total_tokens": total_tokens,
|
|
144
|
+
"cost": round(cost, 6),
|
|
145
|
+
}
|
|
146
|
+
|
|
109
147
|
supports_messages = True
|
|
110
148
|
|
|
111
149
|
def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
|
@@ -113,23 +151,21 @@ class GoogleDriver(CostMixin, Driver):
|
|
|
113
151
|
|
|
114
152
|
return _prepare_google_vision_messages(messages)
|
|
115
153
|
|
|
116
|
-
def
|
|
117
|
-
messages
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
|
|
121
|
-
return self._do_generate(self._prepare_messages(messages), options)
|
|
154
|
+
def _build_generation_args(
|
|
155
|
+
self, messages: list[dict[str, Any]], options: Optional[dict[str, Any]] = None
|
|
156
|
+
) -> tuple[Any, dict[str, Any]]:
|
|
157
|
+
"""Parse messages and options into (gen_input, kwargs) for generate_content.
|
|
122
158
|
|
|
123
|
-
|
|
159
|
+
Returns the content input and a dict of keyword arguments
|
|
160
|
+
(generation_config, safety_settings, model kwargs including system_instruction).
|
|
161
|
+
"""
|
|
124
162
|
merged_options = self.options.copy()
|
|
125
163
|
if options:
|
|
126
164
|
merged_options.update(options)
|
|
127
165
|
|
|
128
|
-
# Extract specific options for Google's API
|
|
129
166
|
generation_config = merged_options.get("generation_config", {})
|
|
130
167
|
safety_settings = merged_options.get("safety_settings", {})
|
|
131
168
|
|
|
132
|
-
# Map common options to generation_config if not present
|
|
133
169
|
if "temperature" in merged_options and "temperature" not in generation_config:
|
|
134
170
|
generation_config["temperature"] = merged_options["temperature"]
|
|
135
171
|
if "max_tokens" in merged_options and "max_output_tokens" not in generation_config:
|
|
@@ -155,56 +191,57 @@ class GoogleDriver(CostMixin, Driver):
|
|
|
155
191
|
if role == "system":
|
|
156
192
|
system_instruction = content if isinstance(content, str) else str(content)
|
|
157
193
|
else:
|
|
158
|
-
# Gemini uses "model" for assistant role
|
|
159
194
|
gemini_role = "model" if role == "assistant" else "user"
|
|
160
195
|
if msg.get("_vision_parts"):
|
|
161
|
-
# Already converted to Gemini parts by _prepare_messages
|
|
162
196
|
contents.append({"role": gemini_role, "parts": content})
|
|
163
197
|
else:
|
|
164
198
|
contents.append({"role": gemini_role, "parts": [content]})
|
|
165
199
|
|
|
200
|
+
# For a single message, unwrap only if it has exactly one string part
|
|
201
|
+
if len(contents) == 1:
|
|
202
|
+
parts = contents[0]["parts"]
|
|
203
|
+
if len(parts) == 1 and isinstance(parts[0], str):
|
|
204
|
+
gen_input = parts[0]
|
|
205
|
+
else:
|
|
206
|
+
gen_input = contents
|
|
207
|
+
else:
|
|
208
|
+
gen_input = contents
|
|
209
|
+
|
|
210
|
+
model_kwargs: dict[str, Any] = {}
|
|
211
|
+
if system_instruction:
|
|
212
|
+
model_kwargs["system_instruction"] = system_instruction
|
|
213
|
+
|
|
214
|
+
gen_kwargs: dict[str, Any] = {
|
|
215
|
+
"generation_config": generation_config if generation_config else None,
|
|
216
|
+
"safety_settings": safety_settings if safety_settings else None,
|
|
217
|
+
}
|
|
218
|
+
|
|
219
|
+
return gen_input, gen_kwargs, model_kwargs
|
|
220
|
+
|
|
221
|
+
def generate(self, prompt: str, options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
|
|
222
|
+
messages = [{"role": "user", "content": prompt}]
|
|
223
|
+
return self._do_generate(messages, options)
|
|
224
|
+
|
|
225
|
+
def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
|
|
226
|
+
return self._do_generate(self._prepare_messages(messages), options)
|
|
227
|
+
|
|
228
|
+
def _do_generate(self, messages: list[dict[str, str]], options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
|
|
229
|
+
gen_input, gen_kwargs, model_kwargs = self._build_generation_args(messages, options)
|
|
230
|
+
|
|
166
231
|
try:
|
|
167
232
|
logger.debug(f"Initializing {self.model} for generation")
|
|
168
|
-
model_kwargs: dict[str, Any] = {}
|
|
169
|
-
if system_instruction:
|
|
170
|
-
model_kwargs["system_instruction"] = system_instruction
|
|
171
233
|
model = genai.GenerativeModel(self.model, **model_kwargs)
|
|
172
234
|
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
# If single user message, pass content directly for backward compatibility
|
|
176
|
-
gen_input: Any = contents if len(contents) != 1 else contents[0]["parts"][0]
|
|
177
|
-
response = model.generate_content(
|
|
178
|
-
gen_input,
|
|
179
|
-
generation_config=generation_config if generation_config else None,
|
|
180
|
-
safety_settings=safety_settings if safety_settings else None,
|
|
181
|
-
)
|
|
235
|
+
logger.debug(f"Generating with model {self.model}")
|
|
236
|
+
response = model.generate_content(gen_input, **gen_kwargs)
|
|
182
237
|
|
|
183
238
|
if not response.text:
|
|
184
239
|
raise ValueError("Empty response from model")
|
|
185
240
|
|
|
186
|
-
|
|
187
|
-
total_prompt_chars = 0
|
|
188
|
-
for msg in messages:
|
|
189
|
-
c = msg.get("content", "")
|
|
190
|
-
if isinstance(c, str):
|
|
191
|
-
total_prompt_chars += len(c)
|
|
192
|
-
elif isinstance(c, list):
|
|
193
|
-
for part in c:
|
|
194
|
-
if isinstance(part, str):
|
|
195
|
-
total_prompt_chars += len(part)
|
|
196
|
-
elif isinstance(part, dict) and "text" in part:
|
|
197
|
-
total_prompt_chars += len(part["text"])
|
|
198
|
-
completion_chars = len(response.text)
|
|
199
|
-
|
|
200
|
-
# Google uses character-based cost estimation
|
|
201
|
-
total_cost = self._calculate_cost_chars(total_prompt_chars, completion_chars)
|
|
241
|
+
usage_meta = self._extract_usage_metadata(response, messages)
|
|
202
242
|
|
|
203
243
|
meta = {
|
|
204
|
-
|
|
205
|
-
"completion_chars": completion_chars,
|
|
206
|
-
"total_chars": total_prompt_chars + completion_chars,
|
|
207
|
-
"cost": total_cost,
|
|
244
|
+
**usage_meta,
|
|
208
245
|
"raw_response": response.prompt_feedback if hasattr(response, "prompt_feedback") else None,
|
|
209
246
|
"model_name": self.model,
|
|
210
247
|
}
|
|
@@ -214,3 +251,130 @@ class GoogleDriver(CostMixin, Driver):
|
|
|
214
251
|
except Exception as e:
|
|
215
252
|
logger.error(f"Google API request failed: {e}")
|
|
216
253
|
raise RuntimeError(f"Google API request failed: {e}") from e
|
|
254
|
+
|
|
255
|
+
# ------------------------------------------------------------------
|
|
256
|
+
# Tool use
|
|
257
|
+
# ------------------------------------------------------------------
|
|
258
|
+
|
|
259
|
+
def generate_messages_with_tools(
|
|
260
|
+
self,
|
|
261
|
+
messages: list[dict[str, Any]],
|
|
262
|
+
tools: list[dict[str, Any]],
|
|
263
|
+
options: dict[str, Any],
|
|
264
|
+
) -> dict[str, Any]:
|
|
265
|
+
"""Generate a response that may include tool/function calls."""
|
|
266
|
+
gen_input, gen_kwargs, model_kwargs = self._build_generation_args(
|
|
267
|
+
self._prepare_messages(messages), options
|
|
268
|
+
)
|
|
269
|
+
|
|
270
|
+
# Convert tools from OpenAI format to Gemini function declarations
|
|
271
|
+
function_declarations = []
|
|
272
|
+
for t in tools:
|
|
273
|
+
if "type" in t and t["type"] == "function":
|
|
274
|
+
fn = t["function"]
|
|
275
|
+
decl = {
|
|
276
|
+
"name": fn["name"],
|
|
277
|
+
"description": fn.get("description", ""),
|
|
278
|
+
}
|
|
279
|
+
params = fn.get("parameters")
|
|
280
|
+
if params:
|
|
281
|
+
decl["parameters"] = params
|
|
282
|
+
function_declarations.append(decl)
|
|
283
|
+
elif "name" in t:
|
|
284
|
+
# Already in a generic format
|
|
285
|
+
decl = {"name": t["name"], "description": t.get("description", "")}
|
|
286
|
+
params = t.get("parameters") or t.get("input_schema")
|
|
287
|
+
if params:
|
|
288
|
+
decl["parameters"] = params
|
|
289
|
+
function_declarations.append(decl)
|
|
290
|
+
|
|
291
|
+
try:
|
|
292
|
+
model = genai.GenerativeModel(self.model, **model_kwargs)
|
|
293
|
+
|
|
294
|
+
gemini_tools = [genai.types.Tool(function_declarations=function_declarations)]
|
|
295
|
+
response = model.generate_content(gen_input, tools=gemini_tools, **gen_kwargs)
|
|
296
|
+
|
|
297
|
+
usage_meta = self._extract_usage_metadata(response, messages)
|
|
298
|
+
meta = {
|
|
299
|
+
**usage_meta,
|
|
300
|
+
"raw_response": response.prompt_feedback if hasattr(response, "prompt_feedback") else None,
|
|
301
|
+
"model_name": self.model,
|
|
302
|
+
}
|
|
303
|
+
|
|
304
|
+
text = ""
|
|
305
|
+
tool_calls_out: list[dict[str, Any]] = []
|
|
306
|
+
stop_reason = "stop"
|
|
307
|
+
|
|
308
|
+
for candidate in response.candidates:
|
|
309
|
+
for part in candidate.content.parts:
|
|
310
|
+
if hasattr(part, "text") and part.text:
|
|
311
|
+
text += part.text
|
|
312
|
+
if hasattr(part, "function_call") and part.function_call.name:
|
|
313
|
+
fc = part.function_call
|
|
314
|
+
tool_calls_out.append({
|
|
315
|
+
"id": str(uuid.uuid4()),
|
|
316
|
+
"name": fc.name,
|
|
317
|
+
"arguments": dict(fc.args) if fc.args else {},
|
|
318
|
+
})
|
|
319
|
+
|
|
320
|
+
finish_reason = getattr(candidate, "finish_reason", None)
|
|
321
|
+
if finish_reason is not None:
|
|
322
|
+
# Map Gemini finish reasons to standard stop reasons
|
|
323
|
+
reason_map = {1: "stop", 2: "max_tokens", 3: "safety", 4: "recitation", 5: "other"}
|
|
324
|
+
stop_reason = reason_map.get(finish_reason, "stop")
|
|
325
|
+
|
|
326
|
+
if tool_calls_out:
|
|
327
|
+
stop_reason = "tool_use"
|
|
328
|
+
|
|
329
|
+
return {
|
|
330
|
+
"text": text,
|
|
331
|
+
"meta": meta,
|
|
332
|
+
"tool_calls": tool_calls_out,
|
|
333
|
+
"stop_reason": stop_reason,
|
|
334
|
+
}
|
|
335
|
+
|
|
336
|
+
except Exception as e:
|
|
337
|
+
logger.error(f"Google API tool call request failed: {e}")
|
|
338
|
+
raise RuntimeError(f"Google API tool call request failed: {e}") from e
|
|
339
|
+
|
|
340
|
+
# ------------------------------------------------------------------
|
|
341
|
+
# Streaming
|
|
342
|
+
# ------------------------------------------------------------------
|
|
343
|
+
|
|
344
|
+
def generate_messages_stream(
|
|
345
|
+
self,
|
|
346
|
+
messages: list[dict[str, Any]],
|
|
347
|
+
options: dict[str, Any],
|
|
348
|
+
) -> Iterator[dict[str, Any]]:
|
|
349
|
+
"""Yield response chunks via Gemini streaming API."""
|
|
350
|
+
gen_input, gen_kwargs, model_kwargs = self._build_generation_args(
|
|
351
|
+
self._prepare_messages(messages), options
|
|
352
|
+
)
|
|
353
|
+
|
|
354
|
+
try:
|
|
355
|
+
model = genai.GenerativeModel(self.model, **model_kwargs)
|
|
356
|
+
response = model.generate_content(gen_input, stream=True, **gen_kwargs)
|
|
357
|
+
|
|
358
|
+
full_text = ""
|
|
359
|
+
for chunk in response:
|
|
360
|
+
chunk_text = getattr(chunk, "text", None) or ""
|
|
361
|
+
if chunk_text:
|
|
362
|
+
full_text += chunk_text
|
|
363
|
+
yield {"type": "delta", "text": chunk_text}
|
|
364
|
+
|
|
365
|
+
# After iteration completes, resolve() has been called on the response
|
|
366
|
+
usage_meta = self._extract_usage_metadata(response, messages)
|
|
367
|
+
|
|
368
|
+
yield {
|
|
369
|
+
"type": "done",
|
|
370
|
+
"text": full_text,
|
|
371
|
+
"meta": {
|
|
372
|
+
**usage_meta,
|
|
373
|
+
"raw_response": {},
|
|
374
|
+
"model_name": self.model,
|
|
375
|
+
},
|
|
376
|
+
}
|
|
377
|
+
|
|
378
|
+
except Exception as e:
|
|
379
|
+
logger.error(f"Google API streaming request failed: {e}")
|
|
380
|
+
raise RuntimeError(f"Google API streaming request failed: {e}") from e
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
prompture/__init__.py,sha256=RrpHZlLPpzntUOp2tL2II2DdVxQRoCxY6JBF_b4k3s0,7213
|
|
2
|
-
prompture/_version.py,sha256=
|
|
2
|
+
prompture/_version.py,sha256=e1uep7-PEqCFbKHaF3uTPcu4UaXdHJjkYrnGcuFmFZM,719
|
|
3
3
|
prompture/agent.py,sha256=xe_yFHGDzTxaU4tmaLt5AQnzrN0I72hBGwGVrCxg2D0,34704
|
|
4
4
|
prompture/agent_types.py,sha256=Icl16PQI-ThGLMFCU43adtQA6cqETbsPn4KssKBI4xc,4664
|
|
5
5
|
prompture/async_agent.py,sha256=nOLOQCNkg0sKKTpryIiidmIcAAlA3FR2NfnZwrNBuCg,33066
|
|
@@ -35,21 +35,21 @@ prompture/aio/__init__.py,sha256=bKqTu4Jxld16aP_7SP9wU5au45UBIb041ORo4E4HzVo,181
|
|
|
35
35
|
prompture/drivers/__init__.py,sha256=VuEBZPqaQzXLl_Lvn_c5mRlJJrrlObZCLeHaR8n2eJ4,7050
|
|
36
36
|
prompture/drivers/airllm_driver.py,sha256=SaTh7e7Plvuct_TfRqQvsJsKHvvM_3iVqhBtlciM-Kw,3858
|
|
37
37
|
prompture/drivers/async_airllm_driver.py,sha256=1hIWLXfyyIg9tXaOE22tLJvFyNwHnOi1M5BIKnV8ysk,908
|
|
38
|
-
prompture/drivers/async_azure_driver.py,sha256=
|
|
39
|
-
prompture/drivers/async_claude_driver.py,sha256=
|
|
40
|
-
prompture/drivers/async_google_driver.py,sha256=
|
|
41
|
-
prompture/drivers/async_grok_driver.py,sha256=
|
|
42
|
-
prompture/drivers/async_groq_driver.py,sha256=
|
|
38
|
+
prompture/drivers/async_azure_driver.py,sha256=lGZICROspP2_o2XlwIZZvrCDenSJZPNYTu7clCgRD68,4473
|
|
39
|
+
prompture/drivers/async_claude_driver.py,sha256=dbUHH2EEotxUWz8cTXVCWtf4ExtiLv3FzzNenvHSVVI,10275
|
|
40
|
+
prompture/drivers/async_google_driver.py,sha256=MIemYcE0ppSWfvVaxv4V-Tqjmy6BKO7sRG6UfZqtdV8,13349
|
|
41
|
+
prompture/drivers/async_grok_driver.py,sha256=fvqEK-mrAx4U4_0C1RePGdZ-TUmQI9Qvj-x1f_uGI5c,3556
|
|
42
|
+
prompture/drivers/async_groq_driver.py,sha256=PEAAj7QHjVqT9UtLfnFY4i__Mk-QpngmHGvbaBNEUrE,3085
|
|
43
43
|
prompture/drivers/async_hugging_driver.py,sha256=IblxqU6TpNUiigZ0BCgNkAgzpUr2FtPHJOZnOZMnHF0,2152
|
|
44
44
|
prompture/drivers/async_lmstudio_driver.py,sha256=rPn2qVPm6UE2APzAn7ZHYTELUwr0dQMi8XHv6gAhyH8,5782
|
|
45
45
|
prompture/drivers/async_local_http_driver.py,sha256=qoigIf-w3_c2dbVdM6m1e2RMAWP4Gk4VzVs5hM3lPvQ,1609
|
|
46
46
|
prompture/drivers/async_ollama_driver.py,sha256=FaSXtFXrgeVHIe0b90Vg6rGeSTWLpPnjaThh9Ai7qQo,5042
|
|
47
|
-
prompture/drivers/async_openai_driver.py,sha256=
|
|
48
|
-
prompture/drivers/async_openrouter_driver.py,sha256=
|
|
47
|
+
prompture/drivers/async_openai_driver.py,sha256=6p538rPlfAWhsTZ5HKAg8KEW1xM4WEFzXVPZsigz_P4,8704
|
|
48
|
+
prompture/drivers/async_openrouter_driver.py,sha256=qvvwJADjnEj6J9f8m0eGlfWTBEm6oXTjwrgt_Im4K7w,3793
|
|
49
49
|
prompture/drivers/async_registry.py,sha256=syervbb7THneJ-NUVSuxy4cnxGW6VuNzKv-Aqqn2ysU,4329
|
|
50
50
|
prompture/drivers/azure_driver.py,sha256=QZr7HEvgSKT9LOTCtCjuBdHl57yvrnWmeTHtmewuJQY,5727
|
|
51
51
|
prompture/drivers/claude_driver.py,sha256=8XnCBHtk6N_PzHStwxIUlcvekdPN896BqOLShmgxU9k,11536
|
|
52
|
-
prompture/drivers/google_driver.py,sha256=
|
|
52
|
+
prompture/drivers/google_driver.py,sha256=8bnAcve1xtgpUXrCdVzWpU_yAqwaeuiBWk8-PbG1cmM,15956
|
|
53
53
|
prompture/drivers/grok_driver.py,sha256=AIwuzNAQyOhmVDA07ISWt2e-rsv5aYk3I5AM4HkLM7o,5294
|
|
54
54
|
prompture/drivers/groq_driver.py,sha256=9cZI21RsgYJTjnrtX2fVA0AadDL-VklhY4ugjDCutwM,4195
|
|
55
55
|
prompture/drivers/hugging_driver.py,sha256=gZir3XnM77VfYIdnu3S1pRftlZJM6G3L8bgGn5esg-Q,2346
|
|
@@ -69,9 +69,9 @@ prompture/scaffold/templates/env.example.j2,sha256=eESKr1KWgyrczO6d-nwAhQwSpf_G-
|
|
|
69
69
|
prompture/scaffold/templates/main.py.j2,sha256=TEgc5OvsZOEX0JthkSW1NI_yLwgoeVN_x97Ibg-vyWY,2632
|
|
70
70
|
prompture/scaffold/templates/models.py.j2,sha256=JrZ99GCVK6TKWapskVRSwCssGrTu5cGZ_r46fOhY2GE,858
|
|
71
71
|
prompture/scaffold/templates/requirements.txt.j2,sha256=m3S5fi1hq9KG9l_9j317rjwWww0a43WMKd8VnUWv2A4,102
|
|
72
|
-
prompture-0.0.38.
|
|
73
|
-
prompture-0.0.38.
|
|
74
|
-
prompture-0.0.38.
|
|
75
|
-
prompture-0.0.38.
|
|
76
|
-
prompture-0.0.38.
|
|
77
|
-
prompture-0.0.38.
|
|
72
|
+
prompture-0.0.38.dev3.dist-info/licenses/LICENSE,sha256=0HgDepH7aaHNFhHF-iXuW6_GqDfYPnVkjtiCAZ4yS8I,1060
|
|
73
|
+
prompture-0.0.38.dev3.dist-info/METADATA,sha256=ejIH91dOyVKrmJ4nKEbsutiI5Gb2xMRiqKuhzgz04Kw,10842
|
|
74
|
+
prompture-0.0.38.dev3.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
|
|
75
|
+
prompture-0.0.38.dev3.dist-info/entry_points.txt,sha256=AFPG3lJR86g4IJMoWQUW5Ph7G6MLNWG3A2u2Tp9zkp8,48
|
|
76
|
+
prompture-0.0.38.dev3.dist-info/top_level.txt,sha256=to86zq_kjfdoLeAxQNr420UWqT0WzkKoZ509J7Qr2t4,10
|
|
77
|
+
prompture-0.0.38.dev3.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|