prompture 0.0.35__py3-none-any.whl → 0.0.40.dev1__py3-none-any.whl
This diff compares the contents of two publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in the public registry.
- prompture/__init__.py +132 -3
- prompture/_version.py +2 -2
- prompture/agent.py +924 -0
- prompture/agent_types.py +156 -0
- prompture/async_agent.py +880 -0
- prompture/async_conversation.py +208 -17
- prompture/async_core.py +16 -0
- prompture/async_driver.py +63 -0
- prompture/async_groups.py +551 -0
- prompture/conversation.py +222 -18
- prompture/core.py +46 -12
- prompture/cost_mixin.py +37 -0
- prompture/discovery.py +132 -44
- prompture/driver.py +77 -0
- prompture/drivers/__init__.py +5 -1
- prompture/drivers/async_azure_driver.py +11 -5
- prompture/drivers/async_claude_driver.py +184 -9
- prompture/drivers/async_google_driver.py +222 -28
- prompture/drivers/async_grok_driver.py +11 -5
- prompture/drivers/async_groq_driver.py +11 -5
- prompture/drivers/async_lmstudio_driver.py +74 -5
- prompture/drivers/async_ollama_driver.py +13 -3
- prompture/drivers/async_openai_driver.py +162 -5
- prompture/drivers/async_openrouter_driver.py +11 -5
- prompture/drivers/async_registry.py +5 -1
- prompture/drivers/azure_driver.py +10 -4
- prompture/drivers/claude_driver.py +17 -1
- prompture/drivers/google_driver.py +227 -33
- prompture/drivers/grok_driver.py +11 -5
- prompture/drivers/groq_driver.py +11 -5
- prompture/drivers/lmstudio_driver.py +73 -8
- prompture/drivers/ollama_driver.py +16 -5
- prompture/drivers/openai_driver.py +26 -11
- prompture/drivers/openrouter_driver.py +11 -5
- prompture/drivers/vision_helpers.py +153 -0
- prompture/group_types.py +147 -0
- prompture/groups.py +530 -0
- prompture/image.py +180 -0
- prompture/ledger.py +252 -0
- prompture/model_rates.py +112 -2
- prompture/persistence.py +254 -0
- prompture/persona.py +482 -0
- prompture/serialization.py +218 -0
- prompture/settings.py +1 -0
- prompture-0.0.40.dev1.dist-info/METADATA +369 -0
- prompture-0.0.40.dev1.dist-info/RECORD +78 -0
- prompture-0.0.35.dist-info/METADATA +0 -464
- prompture-0.0.35.dist-info/RECORD +0 -66
- {prompture-0.0.35.dist-info → prompture-0.0.40.dev1.dist-info}/WHEEL +0 -0
- {prompture-0.0.35.dist-info → prompture-0.0.40.dev1.dist-info}/entry_points.txt +0 -0
- {prompture-0.0.35.dist-info → prompture-0.0.40.dev1.dist-info}/licenses/LICENSE +0 -0
- {prompture-0.0.35.dist-info → prompture-0.0.40.dev1.dist-info}/top_level.txt +0 -0
prompture/drivers/async_openai_driver.py
@@ -2,7 +2,9 @@

 from __future__ import annotations

+import json
 import os
+from collections.abc import AsyncIterator
 from typing import Any

 try:
@@ -18,6 +20,9 @@ from .openai_driver import OpenAIDriver
 class AsyncOpenAIDriver(CostMixin, AsyncDriver):
     supports_json_mode = True
     supports_json_schema = True
+    supports_tool_use = True
+    supports_streaming = True
+    supports_vision = True

     MODEL_PRICING = OpenAIDriver.MODEL_PRICING

@@ -31,12 +36,17 @@ class AsyncOpenAIDriver(CostMixin, AsyncDriver):

     supports_messages = True

+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_openai_vision_messages
+
+        return _prepare_openai_vision_messages(messages)
+
     async def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
         messages = [{"role": "user", "content": prompt}]
         return await self._do_generate(messages, options)

     async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
-        return await self._do_generate(messages, options)
+        return await self._do_generate(self._prepare_messages(messages), options)

     async def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
         if self.client is None:
@@ -44,9 +54,16 @@ class AsyncOpenAIDriver(CostMixin, AsyncDriver):

         model = options.get("model", self.model)

-
-        tokens_param =
-        supports_temperature =
+        model_config = self._get_model_config("openai", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
+
+        # Validate capabilities against models.dev metadata
+        self._validate_model_capabilities(
+            "openai",
+            model,
+            using_json_schema=bool(options.get("json_schema")),
+        )

         opts = {"temperature": 1.0, "max_tokens": 512, **options}

@@ -87,10 +104,150 @@ class AsyncOpenAIDriver(CostMixin, AsyncDriver):
             "prompt_tokens": prompt_tokens,
             "completion_tokens": completion_tokens,
             "total_tokens": total_tokens,
-            "cost": total_cost,
+            "cost": round(total_cost, 6),
             "raw_response": resp.model_dump(),
             "model_name": model,
         }

         text = resp.choices[0].message.content
         return {"text": text, "meta": meta}
+
+    # ------------------------------------------------------------------
+    # Tool use
+    # ------------------------------------------------------------------
+
+    async def generate_messages_with_tools(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> dict[str, Any]:
+        """Generate a response that may include tool calls."""
+        if self.client is None:
+            raise RuntimeError("openai package (>=1.0.0) is not installed")
+
+        model = options.get("model", self.model)
+        model_config = self._get_model_config("openai", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
+
+        self._validate_model_capabilities("openai", model, using_tool_use=True)
+
+        opts = {"temperature": 1.0, "max_tokens": 512, **options}
+
+        kwargs: dict[str, Any] = {
+            "model": model,
+            "messages": messages,
+            "tools": tools,
+        }
+        kwargs[tokens_param] = opts.get("max_tokens", 512)
+
+        if supports_temperature and "temperature" in opts:
+            kwargs["temperature"] = opts["temperature"]
+
+        resp = await self.client.chat.completions.create(**kwargs)
+
+        usage = getattr(resp, "usage", None)
+        prompt_tokens = getattr(usage, "prompt_tokens", 0)
+        completion_tokens = getattr(usage, "completion_tokens", 0)
+        total_tokens = getattr(usage, "total_tokens", 0)
+        total_cost = self._calculate_cost("openai", model, prompt_tokens, completion_tokens)
+
+        meta = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": round(total_cost, 6),
+            "raw_response": resp.model_dump(),
+            "model_name": model,
+        }
+
+        choice = resp.choices[0]
+        text = choice.message.content or ""
+        stop_reason = choice.finish_reason
+
+        tool_calls_out: list[dict[str, Any]] = []
+        if choice.message.tool_calls:
+            for tc in choice.message.tool_calls:
+                try:
+                    args = json.loads(tc.function.arguments)
+                except (json.JSONDecodeError, TypeError):
+                    args = {}
+                tool_calls_out.append({
+                    "id": tc.id,
+                    "name": tc.function.name,
+                    "arguments": args,
+                })
+
+        return {
+            "text": text,
+            "meta": meta,
+            "tool_calls": tool_calls_out,
+            "stop_reason": stop_reason,
+        }
+
+    # ------------------------------------------------------------------
+    # Streaming
+    # ------------------------------------------------------------------
+
+    async def generate_messages_stream(
+        self,
+        messages: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> AsyncIterator[dict[str, Any]]:
+        """Yield response chunks via OpenAI streaming API."""
+        if self.client is None:
+            raise RuntimeError("openai package (>=1.0.0) is not installed")
+
+        model = options.get("model", self.model)
+        model_config = self._get_model_config("openai", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
+
+        opts = {"temperature": 1.0, "max_tokens": 512, **options}
+
+        kwargs: dict[str, Any] = {
+            "model": model,
+            "messages": messages,
+            "stream": True,
+            "stream_options": {"include_usage": True},
+        }
+        kwargs[tokens_param] = opts.get("max_tokens", 512)
+
+        if supports_temperature and "temperature" in opts:
+            kwargs["temperature"] = opts["temperature"]
+
+        stream = await self.client.chat.completions.create(**kwargs)
+
+        full_text = ""
+        prompt_tokens = 0
+        completion_tokens = 0
+
+        async for chunk in stream:
+            # Usage comes in the final chunk
+            if getattr(chunk, "usage", None):
+                prompt_tokens = chunk.usage.prompt_tokens or 0
+                completion_tokens = chunk.usage.completion_tokens or 0
+
+            if chunk.choices:
+                delta = chunk.choices[0].delta
+                content = getattr(delta, "content", None) or ""
+                if content:
+                    full_text += content
+                    yield {"type": "delta", "text": content}
+
+        total_tokens = prompt_tokens + completion_tokens
+        total_cost = self._calculate_cost("openai", model, prompt_tokens, completion_tokens)
+
+        yield {
+            "type": "done",
+            "text": full_text,
+            "meta": {
+                "prompt_tokens": prompt_tokens,
+                "completion_tokens": completion_tokens,
+                "total_tokens": total_tokens,
+                "cost": round(total_cost, 6),
+                "raw_response": {},
+                "model_name": model,
+            },
+        }
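For orientation, here is a short usage sketch (not part of the diff) showing how a caller might consume the two new AsyncOpenAIDriver methods. Only the returned dictionary shapes are taken from the diff above; the constructor arguments, the driver import path, and the example tool definition are assumptions.

import asyncio

from prompture.drivers.async_openai_driver import AsyncOpenAIDriver  # path per the file list above


async def main() -> None:
    driver = AsyncOpenAIDriver(model="gpt-4o-mini")  # assumed constructor signature

    # Tool calls: per the diff, the result carries "text", "meta",
    # "tool_calls" ([{"id", "name", "arguments"}, ...]) and "stop_reason".
    tools = [{
        "type": "function",
        "function": {
            "name": "get_weather",  # hypothetical tool
            "description": "Look up the weather for a city",
            "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
        },
    }]
    result = await driver.generate_messages_with_tools(
        [{"role": "user", "content": "Weather in Paris?"}], tools, {"max_tokens": 256}
    )
    for call in result["tool_calls"]:
        print(call["name"], call["arguments"])

    # Streaming: chunks are tagged "delta" while text arrives, and a final
    # "done" chunk carries the accumulated text plus token/cost metadata.
    async for chunk in driver.generate_messages_stream(
        [{"role": "user", "content": "Tell me a joke"}], {"max_tokens": 128}
    ):
        if chunk["type"] == "delta":
            print(chunk["text"], end="", flush=True)
        else:
            print("\ncost:", chunk["meta"]["cost"])


asyncio.run(main())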
prompture/drivers/async_openrouter_driver.py
@@ -14,6 +14,7 @@ from .openrouter_driver import OpenRouterDriver

 class AsyncOpenRouterDriver(CostMixin, AsyncDriver):
     supports_json_mode = True
+    supports_vision = True

     MODEL_PRICING = OpenRouterDriver.MODEL_PRICING

@@ -31,19 +32,24 @@ class AsyncOpenRouterDriver(CostMixin, AsyncDriver):

     supports_messages = True

+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_openai_vision_messages
+
+        return _prepare_openai_vision_messages(messages)
+
     async def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
         messages = [{"role": "user", "content": prompt}]
         return await self._do_generate(messages, options)

     async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
-        return await self._do_generate(messages, options)
+        return await self._do_generate(self._prepare_messages(messages), options)

     async def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
         model = options.get("model", self.model)

-
-        tokens_param =
-        supports_temperature =
+        model_config = self._get_model_config("openrouter", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]

         opts = {"temperature": 1.0, "max_tokens": 512, **options}

@@ -87,7 +93,7 @@ class AsyncOpenRouterDriver(CostMixin, AsyncDriver):
             "prompt_tokens": prompt_tokens,
             "completion_tokens": completion_tokens,
             "total_tokens": total_tokens,
-            "cost": total_cost,
+            "cost": round(total_cost, 6),
             "raw_response": resp,
             "model_name": model,
         }
prompture/drivers/async_registry.py
@@ -49,7 +49,11 @@ register_async_driver(
 )
 register_async_driver(
     "lmstudio",
-    lambda model=None: AsyncLMStudioDriver(
+    lambda model=None: AsyncLMStudioDriver(
+        endpoint=settings.lmstudio_endpoint,
+        model=model or settings.lmstudio_model,
+        api_key=settings.lmstudio_api_key,
+    ),
     overwrite=True,
 )
 register_async_driver(
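The registration call above suggests the same hook can register custom async drivers. A minimal sketch under that assumption follows; the driver class and the "mybackend" name are hypothetical, and only the register_async_driver(name, factory, overwrite=...) call shape is taken from the diff.

from prompture.drivers.async_registry import register_async_driver  # assumed import path


class MyAsyncDriver:  # illustrative stand-in for a real AsyncDriver subclass
    def __init__(self, model: str):
        self.model = model


# Register a factory keyed by name, mirroring the lmstudio entry shown above.
register_async_driver(
    "mybackend",
    lambda model=None: MyAsyncDriver(model=model or "default-model"),
    overwrite=True,
)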
prompture/drivers/azure_driver.py
@@ -17,6 +17,7 @@ from ..driver import Driver
 class AzureDriver(CostMixin, Driver):
     supports_json_mode = True
     supports_json_schema = True
+    supports_vision = True

     # Pricing per 1K tokens (adjust if your Azure pricing differs from OpenAI defaults)
     MODEL_PRICING = {
@@ -90,21 +91,26 @@ class AzureDriver(CostMixin, Driver):

     supports_messages = True

+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_openai_vision_messages
+
+        return _prepare_openai_vision_messages(messages)
+
     def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
         messages = [{"role": "user", "content": prompt}]
         return self._do_generate(messages, options)

     def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
-        return self._do_generate(messages, options)
+        return self._do_generate(self._prepare_messages(messages), options)

     def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
         if self.client is None:
             raise RuntimeError("openai package (>=1.0.0) with AzureOpenAI not installed")

         model = options.get("model", self.model)
-
-        tokens_param =
-        supports_temperature =
+        model_config = self._get_model_config("azure", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]

         opts = {"temperature": 1.0, "max_tokens": 512, **options}

prompture/drivers/claude_driver.py
@@ -21,6 +21,7 @@ class ClaudeDriver(CostMixin, Driver):
     supports_json_schema = True
     supports_tool_use = True
     supports_streaming = True
+    supports_vision = True

     # Claude pricing per 1000 tokens (prices should be kept current with Anthropic's pricing)
     MODEL_PRICING = {
@@ -57,12 +58,17 @@ class ClaudeDriver(CostMixin, Driver):

     supports_messages = True

+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_claude_vision_messages
+
+        return _prepare_claude_vision_messages(messages)
+
     def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
         messages = [{"role": "user", "content": prompt}]
         return self._do_generate(messages, options)

     def generate_messages(self, messages: list[dict[str, Any]], options: dict[str, Any]) -> dict[str, Any]:
-        return self._do_generate(messages, options)
+        return self._do_generate(self._prepare_messages(messages), options)

     def _do_generate(self, messages: list[dict[str, Any]], options: dict[str, Any]) -> dict[str, Any]:
         if anthropic is None:
@@ -71,6 +77,13 @@ class ClaudeDriver(CostMixin, Driver):
         opts = {**{"temperature": 0.0, "max_tokens": 512}, **options}
         model = options.get("model", self.model)

+        # Validate capabilities against models.dev metadata
+        self._validate_model_capabilities(
+            "claude",
+            model,
+            using_json_schema=bool(options.get("json_schema")),
+        )
+
         client = anthropic.Anthropic(api_key=self.api_key)

         # Anthropic requires system messages as a top-level parameter
@@ -171,6 +184,9 @@ class ClaudeDriver(CostMixin, Driver):

         opts = {**{"temperature": 0.0, "max_tokens": 512}, **options}
         model = options.get("model", self.model)
+
+        self._validate_model_capabilities("claude", model, using_tool_use=True)
+
         client = anthropic.Anthropic(api_key=self.api_key)

         system_content, api_messages = self._extract_system_and_messages(messages)
prompture/drivers/google_driver.py
@@ -1,5 +1,7 @@
 import logging
 import os
+import uuid
+from collections.abc import Iterator
 from typing import Any, Optional

 import google.generativeai as genai
@@ -15,6 +17,9 @@ class GoogleDriver(CostMixin, Driver):

     supports_json_mode = True
     supports_json_schema = True
+    supports_vision = True
+    supports_tool_use = True
+    supports_streaming = True

     # Based on current Gemini pricing (as of 2025)
     # Source: https://cloud.google.com/vertex-ai/pricing#gemini_models
@@ -105,25 +110,62 @@ class GoogleDriver(CostMixin, Driver):
         completion_cost = (completion_chars / 1_000_000) * model_pricing["completion"]
         return round(prompt_cost + completion_cost, 6)

+    def _extract_usage_metadata(self, response: Any, messages: list[dict[str, Any]]) -> dict[str, Any]:
+        """Extract token counts from response, falling back to character estimation."""
+        usage = getattr(response, "usage_metadata", None)
+        if usage:
+            prompt_tokens = getattr(usage, "prompt_token_count", 0) or 0
+            completion_tokens = getattr(usage, "candidates_token_count", 0) or 0
+            total_tokens = getattr(usage, "total_token_count", 0) or (prompt_tokens + completion_tokens)
+            cost = self._calculate_cost("google", self.model, prompt_tokens, completion_tokens)
+        else:
+            # Fallback: estimate from character counts
+            total_prompt_chars = 0
+            for msg in messages:
+                c = msg.get("content", "")
+                if isinstance(c, str):
+                    total_prompt_chars += len(c)
+                elif isinstance(c, list):
+                    for part in c:
+                        if isinstance(part, str):
+                            total_prompt_chars += len(part)
+                        elif isinstance(part, dict) and "text" in part:
+                            total_prompt_chars += len(part["text"])
+            completion_chars = len(response.text) if response.text else 0
+            prompt_tokens = total_prompt_chars // 4
+            completion_tokens = completion_chars // 4
+            total_tokens = prompt_tokens + completion_tokens
+            cost = self._calculate_cost_chars(total_prompt_chars, completion_chars)
+
+        return {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": round(cost, 6),
+        }
+
     supports_messages = True

-    def
-
-        return self._do_generate(messages, options)
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_google_vision_messages

-
-        return self._do_generate(messages, options)
+        return _prepare_google_vision_messages(messages)

-    def
+    def _build_generation_args(
+        self, messages: list[dict[str, Any]], options: Optional[dict[str, Any]] = None
+    ) -> tuple[Any, dict[str, Any]]:
+        """Parse messages and options into (gen_input, kwargs) for generate_content.
+
+        Returns the content input and a dict of keyword arguments
+        (generation_config, safety_settings, model kwargs including system_instruction).
+        """
         merged_options = self.options.copy()
         if options:
             merged_options.update(options)

-        # Extract specific options for Google's API
         generation_config = merged_options.get("generation_config", {})
         safety_settings = merged_options.get("safety_settings", {})

-        # Map common options to generation_config if not present
         if "temperature" in merged_options and "temperature" not in generation_config:
             generation_config["temperature"] = merged_options["temperature"]
         if "max_tokens" in merged_options and "max_output_tokens" not in generation_config:
@@ -147,44 +189,66 @@ class GoogleDriver(CostMixin, Driver):
             role = msg.get("role", "user")
             content = msg.get("content", "")
             if role == "system":
-                system_instruction = content
+                system_instruction = content if isinstance(content, str) else str(content)
             else:
-                # Gemini uses "model" for assistant role
                 gemini_role = "model" if role == "assistant" else "user"
-
+                if msg.get("_vision_parts"):
+                    contents.append({"role": gemini_role, "parts": content})
+                else:
+                    contents.append({"role": gemini_role, "parts": [content]})
+
+        # For a single message, unwrap only if it has exactly one string part
+        if len(contents) == 1:
+            parts = contents[0]["parts"]
+            if len(parts) == 1 and isinstance(parts[0], str):
+                gen_input = parts[0]
+            else:
+                gen_input = contents
+        else:
+            gen_input = contents
+
+        model_kwargs: dict[str, Any] = {}
+        if system_instruction:
+            model_kwargs["system_instruction"] = system_instruction
+
+        gen_kwargs: dict[str, Any] = {
+            "generation_config": generation_config if generation_config else None,
+            "safety_settings": safety_settings if safety_settings else None,
+        }
+
+        return gen_input, gen_kwargs, model_kwargs
+
+    def generate(self, prompt: str, options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
+        messages = [{"role": "user", "content": prompt}]
+        return self._do_generate(messages, options)
+
+    def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        return self._do_generate(self._prepare_messages(messages), options)
+
+    def _do_generate(self, messages: list[dict[str, str]], options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
+        gen_input, gen_kwargs, model_kwargs = self._build_generation_args(messages, options)
+
+        # Validate capabilities against models.dev metadata
+        self._validate_model_capabilities(
+            "google",
+            self.model,
+            using_json_schema=bool((options or {}).get("json_schema")),
+        )

         try:
             logger.debug(f"Initializing {self.model} for generation")
-            model_kwargs: dict[str, Any] = {}
-            if system_instruction:
-                model_kwargs["system_instruction"] = system_instruction
             model = genai.GenerativeModel(self.model, **model_kwargs)

-
-
-            # If single user message, pass content directly for backward compatibility
-            gen_input: Any = contents if len(contents) != 1 else contents[0]["parts"][0]
-            response = model.generate_content(
-                gen_input,
-                generation_config=generation_config if generation_config else None,
-                safety_settings=safety_settings if safety_settings else None,
-            )
+            logger.debug(f"Generating with model {self.model}")
+            response = model.generate_content(gen_input, **gen_kwargs)

             if not response.text:
                 raise ValueError("Empty response from model")

-
-            total_prompt_chars = sum(len(msg.get("content", "")) for msg in messages)
-            completion_chars = len(response.text)
-
-            # Google uses character-based cost estimation
-            total_cost = self._calculate_cost_chars(total_prompt_chars, completion_chars)
+            usage_meta = self._extract_usage_metadata(response, messages)

             meta = {
-
-                "completion_chars": completion_chars,
-                "total_chars": total_prompt_chars + completion_chars,
-                "cost": total_cost,
+                **usage_meta,
                 "raw_response": response.prompt_feedback if hasattr(response, "prompt_feedback") else None,
                 "model_name": self.model,
             }
@@ -194,3 +258,133 @@ class GoogleDriver(CostMixin, Driver):
         except Exception as e:
             logger.error(f"Google API request failed: {e}")
             raise RuntimeError(f"Google API request failed: {e}") from e
+
+    # ------------------------------------------------------------------
+    # Tool use
+    # ------------------------------------------------------------------
+
+    def generate_messages_with_tools(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> dict[str, Any]:
+        """Generate a response that may include tool/function calls."""
+        model = options.get("model", self.model)
+        self._validate_model_capabilities("google", model, using_tool_use=True)
+
+        gen_input, gen_kwargs, model_kwargs = self._build_generation_args(
+            self._prepare_messages(messages), options
+        )
+
+        # Convert tools from OpenAI format to Gemini function declarations
+        function_declarations = []
+        for t in tools:
+            if "type" in t and t["type"] == "function":
+                fn = t["function"]
+                decl = {
+                    "name": fn["name"],
+                    "description": fn.get("description", ""),
+                }
+                params = fn.get("parameters")
+                if params:
+                    decl["parameters"] = params
+                function_declarations.append(decl)
+            elif "name" in t:
+                # Already in a generic format
+                decl = {"name": t["name"], "description": t.get("description", "")}
+                params = t.get("parameters") or t.get("input_schema")
+                if params:
+                    decl["parameters"] = params
+                function_declarations.append(decl)
+
+        try:
+            model = genai.GenerativeModel(self.model, **model_kwargs)
+
+            gemini_tools = [genai.types.Tool(function_declarations=function_declarations)]
+            response = model.generate_content(gen_input, tools=gemini_tools, **gen_kwargs)
+
+            usage_meta = self._extract_usage_metadata(response, messages)
+            meta = {
+                **usage_meta,
+                "raw_response": response.prompt_feedback if hasattr(response, "prompt_feedback") else None,
+                "model_name": self.model,
+            }
+
+            text = ""
+            tool_calls_out: list[dict[str, Any]] = []
+            stop_reason = "stop"
+
+            for candidate in response.candidates:
+                for part in candidate.content.parts:
+                    if hasattr(part, "text") and part.text:
+                        text += part.text
+                    if hasattr(part, "function_call") and part.function_call.name:
+                        fc = part.function_call
+                        tool_calls_out.append({
+                            "id": str(uuid.uuid4()),
+                            "name": fc.name,
+                            "arguments": dict(fc.args) if fc.args else {},
+                        })
+
+                finish_reason = getattr(candidate, "finish_reason", None)
+                if finish_reason is not None:
+                    # Map Gemini finish reasons to standard stop reasons
+                    reason_map = {1: "stop", 2: "max_tokens", 3: "safety", 4: "recitation", 5: "other"}
+                    stop_reason = reason_map.get(finish_reason, "stop")
+
+            if tool_calls_out:
+                stop_reason = "tool_use"
+
+            return {
+                "text": text,
+                "meta": meta,
+                "tool_calls": tool_calls_out,
+                "stop_reason": stop_reason,
+            }
+
+        except Exception as e:
+            logger.error(f"Google API tool call request failed: {e}")
+            raise RuntimeError(f"Google API tool call request failed: {e}") from e
+
+    # ------------------------------------------------------------------
+    # Streaming
+    # ------------------------------------------------------------------
+
+    def generate_messages_stream(
+        self,
+        messages: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> Iterator[dict[str, Any]]:
+        """Yield response chunks via Gemini streaming API."""
+        gen_input, gen_kwargs, model_kwargs = self._build_generation_args(
+            self._prepare_messages(messages), options
+        )
+
+        try:
+            model = genai.GenerativeModel(self.model, **model_kwargs)
+            response = model.generate_content(gen_input, stream=True, **gen_kwargs)
+
+            full_text = ""
+            for chunk in response:
+                chunk_text = getattr(chunk, "text", None) or ""
+                if chunk_text:
+                    full_text += chunk_text
+                    yield {"type": "delta", "text": chunk_text}
+
+            # After iteration completes, resolve() has been called on the response
+            usage_meta = self._extract_usage_metadata(response, messages)
+
+            yield {
+                "type": "done",
+                "text": full_text,
+                "meta": {
+                    **usage_meta,
+                    "raw_response": {},
+                    "model_name": self.model,
+                },
+            }
+
+        except Exception as e:
+            logger.error(f"Google API streaming request failed: {e}")
+            raise RuntimeError(f"Google API streaming request failed: {e}") from e
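The tool-use hunk above converts tools supplied in the OpenAI function-calling format (or a looser generic format) into Gemini function declarations. A small standalone sketch of that transformation, using a hypothetical get_weather tool, shows the before and after shapes; it only mirrors the conversion loop visible in the diff.

from typing import Any

# Hypothetical tool in the OpenAI function-calling shape the loop above accepts.
openai_style_tool: dict[str, Any] = {
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Look up the weather for a city",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}

# What generate_messages_with_tools derives from it for Gemini:
fn = openai_style_tool["function"]
declaration = {
    "name": fn["name"],
    "description": fn.get("description", ""),
    "parameters": fn["parameters"],
}
# The declaration is then wrapped as genai.types.Tool(function_declarations=[declaration]).
print(declaration)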