prompture 0.0.29.dev8-py3-none-any.whl → 0.0.38.dev2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- prompture/__init__.py +264 -23
- prompture/_version.py +34 -0
- prompture/agent.py +924 -0
- prompture/agent_types.py +156 -0
- prompture/aio/__init__.py +74 -0
- prompture/async_agent.py +880 -0
- prompture/async_conversation.py +789 -0
- prompture/async_core.py +803 -0
- prompture/async_driver.py +193 -0
- prompture/async_groups.py +551 -0
- prompture/cache.py +469 -0
- prompture/callbacks.py +55 -0
- prompture/cli.py +63 -4
- prompture/conversation.py +826 -0
- prompture/core.py +894 -263
- prompture/cost_mixin.py +51 -0
- prompture/discovery.py +187 -0
- prompture/driver.py +206 -5
- prompture/drivers/__init__.py +175 -67
- prompture/drivers/airllm_driver.py +109 -0
- prompture/drivers/async_airllm_driver.py +26 -0
- prompture/drivers/async_azure_driver.py +123 -0
- prompture/drivers/async_claude_driver.py +113 -0
- prompture/drivers/async_google_driver.py +316 -0
- prompture/drivers/async_grok_driver.py +97 -0
- prompture/drivers/async_groq_driver.py +90 -0
- prompture/drivers/async_hugging_driver.py +61 -0
- prompture/drivers/async_lmstudio_driver.py +148 -0
- prompture/drivers/async_local_http_driver.py +44 -0
- prompture/drivers/async_ollama_driver.py +135 -0
- prompture/drivers/async_openai_driver.py +102 -0
- prompture/drivers/async_openrouter_driver.py +102 -0
- prompture/drivers/async_registry.py +133 -0
- prompture/drivers/azure_driver.py +42 -9
- prompture/drivers/claude_driver.py +257 -34
- prompture/drivers/google_driver.py +295 -42
- prompture/drivers/grok_driver.py +35 -32
- prompture/drivers/groq_driver.py +33 -26
- prompture/drivers/hugging_driver.py +6 -6
- prompture/drivers/lmstudio_driver.py +97 -19
- prompture/drivers/local_http_driver.py +6 -6
- prompture/drivers/ollama_driver.py +168 -23
- prompture/drivers/openai_driver.py +184 -9
- prompture/drivers/openrouter_driver.py +37 -25
- prompture/drivers/registry.py +306 -0
- prompture/drivers/vision_helpers.py +153 -0
- prompture/field_definitions.py +106 -96
- prompture/group_types.py +147 -0
- prompture/groups.py +530 -0
- prompture/image.py +180 -0
- prompture/logging.py +80 -0
- prompture/model_rates.py +217 -0
- prompture/persistence.py +254 -0
- prompture/persona.py +482 -0
- prompture/runner.py +49 -47
- prompture/scaffold/__init__.py +1 -0
- prompture/scaffold/generator.py +84 -0
- prompture/scaffold/templates/Dockerfile.j2 +12 -0
- prompture/scaffold/templates/README.md.j2 +41 -0
- prompture/scaffold/templates/config.py.j2 +21 -0
- prompture/scaffold/templates/env.example.j2 +8 -0
- prompture/scaffold/templates/main.py.j2 +86 -0
- prompture/scaffold/templates/models.py.j2 +40 -0
- prompture/scaffold/templates/requirements.txt.j2 +5 -0
- prompture/serialization.py +218 -0
- prompture/server.py +183 -0
- prompture/session.py +117 -0
- prompture/settings.py +19 -1
- prompture/tools.py +219 -267
- prompture/tools_schema.py +254 -0
- prompture/validator.py +3 -3
- prompture-0.0.38.dev2.dist-info/METADATA +369 -0
- prompture-0.0.38.dev2.dist-info/RECORD +77 -0
- {prompture-0.0.29.dev8.dist-info → prompture-0.0.38.dev2.dist-info}/WHEEL +1 -1
- prompture-0.0.29.dev8.dist-info/METADATA +0 -368
- prompture-0.0.29.dev8.dist-info/RECORD +0 -27
- {prompture-0.0.29.dev8.dist-info → prompture-0.0.38.dev2.dist-info}/entry_points.txt +0 -0
- {prompture-0.0.29.dev8.dist-info → prompture-0.0.38.dev2.dist-info}/licenses/LICENSE +0 -0
- {prompture-0.0.29.dev8.dist-info → prompture-0.0.38.dev2.dist-info}/top_level.txt +0 -0
@@ -1,46 +1,60 @@
-import os
 import logging
+import os
+import uuid
+from collections.abc import Iterator
+from typing import Any, Optional
+
 import google.generativeai as genai
-
+
+from ..cost_mixin import CostMixin
 from ..driver import Driver
 
 logger = logging.getLogger(__name__)
 
 
-class GoogleDriver(Driver):
+class GoogleDriver(CostMixin, Driver):
     """Driver for Google's Generative AI API (Gemini)."""
 
+    supports_json_mode = True
+    supports_json_schema = True
+    supports_vision = True
+    supports_tool_use = True
+    supports_streaming = True
+
     # Based on current Gemini pricing (as of 2025)
     # Source: https://cloud.google.com/vertex-ai/pricing#gemini_models
+    _PRICING_UNIT = 1_000_000
     MODEL_PRICING = {
         "gemini-1.5-pro": {
             "prompt": 0.00025,  # $0.25/1M chars input
-            "completion": 0.0005  # $0.50/1M chars output
+            "completion": 0.0005,  # $0.50/1M chars output
         },
         "gemini-1.5-pro-vision": {
             "prompt": 0.00025,  # $0.25/1M chars input
-            "completion": 0.0005  # $0.50/1M chars output
+            "completion": 0.0005,  # $0.50/1M chars output
         },
         "gemini-2.5-pro": {
             "prompt": 0.0004,  # $0.40/1M chars input
-            "completion": 0.0008  # $0.80/1M chars output
+            "completion": 0.0008,  # $0.80/1M chars output
         },
         "gemini-2.5-flash": {
             "prompt": 0.0004,  # $0.40/1M chars input
-            "completion": 0.0008  # $0.80/1M chars output
+            "completion": 0.0008,  # $0.80/1M chars output
        },
         "gemini-2.5-flash-lite": {
             "prompt": 0.0002,  # $0.20/1M chars input
-            "completion": 0.0004  # $0.40/1M chars output
+            "completion": 0.0004,  # $0.40/1M chars output
         },
-
+        "gemini-2.0-flash": {
             "prompt": 0.0004,  # $0.40/1M chars input
-            "completion": 0.0008  # $0.80/1M chars output
+            "completion": 0.0008,  # $0.80/1M chars output
         },
         "gemini-2.0-flash-lite": {
             "prompt": 0.0002,  # $0.20/1M chars input
-            "completion": 0.0004  # $0.40/1M chars output
+            "completion": 0.0004,  # $0.40/1M chars output
         },
+        "gemini-1.5-flash": {"prompt": 0.00001875, "completion": 0.000075},
+        "gemini-1.5-flash-8b": {"prompt": 0.00001, "completion": 0.00004},
     }
 
     def __init__(self, api_key: str | None = None, model: str = "gemini-1.5-pro"):
@@ -55,13 +69,14 @@ class GoogleDriver(Driver):
             raise ValueError("Google API key not found. Set GOOGLE_API_KEY env var or pass api_key to constructor")
 
         self.model = model
+        # Warn if model is not in pricing table but allow it (might be new)
         if model not in self.MODEL_PRICING:
-
+            logger.warning(f"Model {model} not found in pricing table. Cost calculations will be 0.")
 
         # Configure google.generativeai
         genai.configure(api_key=self.api_key)
-        self.options:
-
+        self.options: dict[str, Any] = {}
+
         # Validate connection and model availability
         self._validate_connection()
 
@@ -75,48 +90,159 @@ class GoogleDriver(Driver):
             logger.warning(f"Could not validate connection to Google API: {e}")
             raise
 
-    def
-        """
+    def _calculate_cost_chars(self, prompt_chars: int, completion_chars: int) -> float:
+        """Calculate cost from character counts.
 
-
-
-
+        Live rates use token-based pricing (estimate ~4 chars/token).
+        Hardcoded MODEL_PRICING uses per-1M-character rates.
+        """
+        from ..model_rates import get_model_rates
+
+        live_rates = get_model_rates("google", self.model)
+        if live_rates:
+            est_prompt_tokens = prompt_chars / 4
+            est_completion_tokens = completion_chars / 4
+            prompt_cost = (est_prompt_tokens / 1_000_000) * live_rates["input"]
+            completion_cost = (est_completion_tokens / 1_000_000) * live_rates["output"]
+        else:
+            model_pricing = self.MODEL_PRICING.get(self.model, {"prompt": 0, "completion": 0})
+            prompt_cost = (prompt_chars / 1_000_000) * model_pricing["prompt"]
+            completion_cost = (completion_chars / 1_000_000) * model_pricing["completion"]
+        return round(prompt_cost + completion_cost, 6)
+
+    def _extract_usage_metadata(self, response: Any, messages: list[dict[str, Any]]) -> dict[str, Any]:
+        """Extract token counts from response, falling back to character estimation."""
+        usage = getattr(response, "usage_metadata", None)
+        if usage:
+            prompt_tokens = getattr(usage, "prompt_token_count", 0) or 0
+            completion_tokens = getattr(usage, "candidates_token_count", 0) or 0
+            total_tokens = getattr(usage, "total_token_count", 0) or (prompt_tokens + completion_tokens)
+            cost = self._calculate_cost("google", self.model, prompt_tokens, completion_tokens)
+        else:
+            # Fallback: estimate from character counts
+            total_prompt_chars = 0
+            for msg in messages:
+                c = msg.get("content", "")
+                if isinstance(c, str):
+                    total_prompt_chars += len(c)
+                elif isinstance(c, list):
+                    for part in c:
+                        if isinstance(part, str):
+                            total_prompt_chars += len(part)
+                        elif isinstance(part, dict) and "text" in part:
+                            total_prompt_chars += len(part["text"])
+            completion_chars = len(response.text) if response.text else 0
+            prompt_tokens = total_prompt_chars // 4
+            completion_tokens = completion_chars // 4
+            total_tokens = prompt_tokens + completion_tokens
+            cost = self._calculate_cost_chars(total_prompt_chars, completion_chars)
+
+        return {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": round(cost, 6),
+        }
+
+    supports_messages = True
+
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_google_vision_messages
 
-
-
+        return _prepare_google_vision_messages(messages)
+
+    def _build_generation_args(
+        self, messages: list[dict[str, Any]], options: Optional[dict[str, Any]] = None
+    ) -> tuple[Any, dict[str, Any]]:
+        """Parse messages and options into (gen_input, kwargs) for generate_content.
+
+        Returns the content input and a dict of keyword arguments
+        (generation_config, safety_settings, model kwargs including system_instruction).
         """
         merged_options = self.options.copy()
         if options:
             merged_options.update(options)
 
+        generation_config = merged_options.get("generation_config", {})
+        safety_settings = merged_options.get("safety_settings", {})
+
+        if "temperature" in merged_options and "temperature" not in generation_config:
+            generation_config["temperature"] = merged_options["temperature"]
+        if "max_tokens" in merged_options and "max_output_tokens" not in generation_config:
+            generation_config["max_output_tokens"] = merged_options["max_tokens"]
+        if "top_p" in merged_options and "top_p" not in generation_config:
+            generation_config["top_p"] = merged_options["top_p"]
+        if "top_k" in merged_options and "top_k" not in generation_config:
+            generation_config["top_k"] = merged_options["top_k"]
+
+        # Native JSON mode support
+        if merged_options.get("json_mode"):
+            generation_config["response_mime_type"] = "application/json"
+            json_schema = merged_options.get("json_schema")
+            if json_schema:
+                generation_config["response_schema"] = json_schema
+
+        # Convert messages to Gemini format
+        system_instruction = None
+        contents: list[dict[str, Any]] = []
+        for msg in messages:
+            role = msg.get("role", "user")
+            content = msg.get("content", "")
+            if role == "system":
+                system_instruction = content if isinstance(content, str) else str(content)
+            else:
+                gemini_role = "model" if role == "assistant" else "user"
+                if msg.get("_vision_parts"):
+                    contents.append({"role": gemini_role, "parts": content})
+                else:
+                    contents.append({"role": gemini_role, "parts": [content]})
+
+        # For a single message, unwrap only if it has exactly one string part
+        if len(contents) == 1:
+            parts = contents[0]["parts"]
+            if len(parts) == 1 and isinstance(parts[0], str):
+                gen_input = parts[0]
+            else:
+                gen_input = contents
+        else:
+            gen_input = contents
+
+        model_kwargs: dict[str, Any] = {}
+        if system_instruction:
+            model_kwargs["system_instruction"] = system_instruction
+
+        gen_kwargs: dict[str, Any] = {
+            "generation_config": generation_config if generation_config else None,
+            "safety_settings": safety_settings if safety_settings else None,
+        }
+
+        return gen_input, gen_kwargs, model_kwargs
+
+    def generate(self, prompt: str, options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
+        messages = [{"role": "user", "content": prompt}]
+        return self._do_generate(messages, options)
+
+    def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        return self._do_generate(self._prepare_messages(messages), options)
+
+    def _do_generate(self, messages: list[dict[str, str]], options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
+        gen_input, gen_kwargs, model_kwargs = self._build_generation_args(messages, options)
+
         try:
             logger.debug(f"Initializing {self.model} for generation")
-            model = genai.GenerativeModel(self.model)
+            model = genai.GenerativeModel(self.model, **model_kwargs)
+
+            logger.debug(f"Generating with model {self.model}")
+            response = model.generate_content(gen_input, **gen_kwargs)
 
-            # Generate response
-            logger.debug(f"Generating with prompt: {prompt}")
-            response = model.generate_content(prompt)
-
             if not response.text:
                 raise ValueError("Empty response from model")
 
-
-            # Note: Using character count as proxy since Google charges per character
-            prompt_chars = len(prompt)
-            completion_chars = len(response.text)
-
-            # Calculate costs
-            model_pricing = self.MODEL_PRICING[self.model]
-            prompt_cost = (prompt_chars / 1_000_000) * model_pricing["prompt"]
-            completion_cost = (completion_chars / 1_000_000) * model_pricing["completion"]
-            total_cost = prompt_cost + completion_cost
+            usage_meta = self._extract_usage_metadata(response, messages)
 
             meta = {
-
-                "
-                "total_chars": prompt_chars + completion_chars,
-                "cost": total_cost,
-                "raw_response": response.prompt_feedback,
+                **usage_meta,
+                "raw_response": response.prompt_feedback if hasattr(response, "prompt_feedback") else None,
                 "model_name": self.model,
             }
 
@@ -124,4 +250,131 @@ class GoogleDriver(Driver):
 
         except Exception as e:
             logger.error(f"Google API request failed: {e}")
-            raise RuntimeError(f"Google API request failed: {e}")
+            raise RuntimeError(f"Google API request failed: {e}") from e
+
+    # ------------------------------------------------------------------
+    # Tool use
+    # ------------------------------------------------------------------
+
+    def generate_messages_with_tools(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> dict[str, Any]:
+        """Generate a response that may include tool/function calls."""
+        gen_input, gen_kwargs, model_kwargs = self._build_generation_args(
+            self._prepare_messages(messages), options
+        )
+
+        # Convert tools from OpenAI format to Gemini function declarations
+        function_declarations = []
+        for t in tools:
+            if "type" in t and t["type"] == "function":
+                fn = t["function"]
+                decl = {
+                    "name": fn["name"],
+                    "description": fn.get("description", ""),
+                }
+                params = fn.get("parameters")
+                if params:
+                    decl["parameters"] = params
+                function_declarations.append(decl)
+            elif "name" in t:
+                # Already in a generic format
+                decl = {"name": t["name"], "description": t.get("description", "")}
+                params = t.get("parameters") or t.get("input_schema")
+                if params:
+                    decl["parameters"] = params
+                function_declarations.append(decl)
+
+        try:
+            model = genai.GenerativeModel(self.model, **model_kwargs)
+
+            gemini_tools = [genai.types.Tool(function_declarations=function_declarations)]
+            response = model.generate_content(gen_input, tools=gemini_tools, **gen_kwargs)
+
+            usage_meta = self._extract_usage_metadata(response, messages)
+            meta = {
+                **usage_meta,
+                "raw_response": response.prompt_feedback if hasattr(response, "prompt_feedback") else None,
+                "model_name": self.model,
+            }
+
+            text = ""
+            tool_calls_out: list[dict[str, Any]] = []
+            stop_reason = "stop"
+
+            for candidate in response.candidates:
+                for part in candidate.content.parts:
+                    if hasattr(part, "text") and part.text:
+                        text += part.text
+                    if hasattr(part, "function_call") and part.function_call.name:
+                        fc = part.function_call
+                        tool_calls_out.append({
+                            "id": str(uuid.uuid4()),
+                            "name": fc.name,
+                            "arguments": dict(fc.args) if fc.args else {},
+                        })
+
+                finish_reason = getattr(candidate, "finish_reason", None)
+                if finish_reason is not None:
+                    # Map Gemini finish reasons to standard stop reasons
+                    reason_map = {1: "stop", 2: "max_tokens", 3: "safety", 4: "recitation", 5: "other"}
+                    stop_reason = reason_map.get(finish_reason, "stop")
+
+            if tool_calls_out:
+                stop_reason = "tool_use"
+
+            return {
+                "text": text,
+                "meta": meta,
+                "tool_calls": tool_calls_out,
+                "stop_reason": stop_reason,
+            }
+
+        except Exception as e:
+            logger.error(f"Google API tool call request failed: {e}")
+            raise RuntimeError(f"Google API tool call request failed: {e}") from e
+
+    # ------------------------------------------------------------------
+    # Streaming
+    # ------------------------------------------------------------------
+
+    def generate_messages_stream(
+        self,
+        messages: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> Iterator[dict[str, Any]]:
+        """Yield response chunks via Gemini streaming API."""
+        gen_input, gen_kwargs, model_kwargs = self._build_generation_args(
+            self._prepare_messages(messages), options
+        )
+
+        try:
+            model = genai.GenerativeModel(self.model, **model_kwargs)
+            response = model.generate_content(gen_input, stream=True, **gen_kwargs)
+
+            full_text = ""
+            for chunk in response:
+                chunk_text = getattr(chunk, "text", None) or ""
+                if chunk_text:
+                    full_text += chunk_text
+                    yield {"type": "delta", "text": chunk_text}
+
+            # After iteration completes, resolve() has been called on the response
+            usage_meta = self._extract_usage_metadata(response, messages)
+
+            yield {
+                "type": "done",
+                "text": full_text,
+                "meta": {
+                    **usage_meta,
+                    "raw_response": {},
+                    "model_name": self.model,
+                },
+            }
+
+        except Exception as e:
+            logger.error(f"Google API streaming request failed: {e}")
+            raise RuntimeError(f"Google API streaming request failed: {e}") from e
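The rewritten GoogleDriver routes every entry point (generate, generate_messages, tool use, streaming) through _build_generation_args and _extract_usage_metadata, so all paths return the same meta shape. Below is a minimal sketch of how that surface might be exercised; it is inferred from the diff above rather than from package documentation, the import path and the shape of the returned dict are assumptions, and GOOGLE_API_KEY must be set in the environment.

# Sketch only -- illustrative, not documented usage.
from prompture.drivers.google_driver import GoogleDriver

driver = GoogleDriver(model="gemini-2.0-flash")

# Single-shot generation with native JSON mode (response_mime_type="application/json").
result = driver.generate(
    "Return a JSON object with keys 'title' and 'year'.",
    {"json_mode": True, "max_tokens": 256},
)
print(result["text"])          # model output
print(result["meta"]["cost"])  # token counts and cost from _extract_usage_metadata

# Streaming yields {"type": "delta", ...} chunks, then one final {"type": "done", ...} event.
for event in driver.generate_messages_stream(
    [{"role": "user", "content": "Write one sentence about wheels."}],
    {"temperature": 0.7},
):
    if event["type"] == "delta":
        print(event["text"], end="", flush=True)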
prompture/drivers/grok_driver.py
CHANGED
@@ -1,15 +1,22 @@
 """xAI Grok driver.
 Requires the `requests` package. Uses GROK_API_KEY env var.
 """
+
 import os
-from typing import Any
+from typing import Any
+
 import requests
 
+from ..cost_mixin import CostMixin
 from ..driver import Driver
 
 
-class GrokDriver(Driver):
+class GrokDriver(CostMixin, Driver):
+    supports_json_mode = True
+    supports_vision = True
+
     # Pricing per 1M tokens based on xAI's documentation
+    _PRICING_UNIT = 1_000_000
     MODEL_PRICING = {
         "grok-code-fast-1": {
             "prompt": 0.20,
@@ -72,19 +79,21 @@ class GrokDriver(Driver):
         self.model = model
         self.api_base = "https://api.x.ai/v1"
 
-
-        """Generate completion using Grok API.
+    supports_messages = True
 
-
-
-
-
-
-
-
-
-
-
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_openai_vision_messages
+
+        return _prepare_openai_vision_messages(messages)
+
+    def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
+        messages = [{"role": "user", "content": prompt}]
+        return self._do_generate(messages, options)
+
+    def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        return self._do_generate(self._prepare_messages(messages), options)
+
+    def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
         if not self.api_key:
             raise RuntimeError("GROK_API_KEY environment variable is required")
 
@@ -101,7 +110,7 @@ class GrokDriver(Driver):
         # Base request payload
         payload = {
             "model": model,
-            "messages":
+            "messages": messages,
         }
 
         # Add token limit with correct parameter name
@@ -111,33 +120,27 @@ class GrokDriver(Driver):
         if supports_temperature and "temperature" in opts:
             payload["temperature"] = opts["temperature"]
 
-
-
-            "
-
+        # Native JSON mode support
+        if options.get("json_mode"):
+            payload["response_format"] = {"type": "json_object"}
+
+        headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
 
         try:
-            response = requests.post(
-                f"{self.api_base}/chat/completions",
-                headers=headers,
-                json=payload
-            )
+            response = requests.post(f"{self.api_base}/chat/completions", headers=headers, json=payload)
             response.raise_for_status()
             resp = response.json()
         except requests.exceptions.RequestException as e:
-            raise RuntimeError(f"Grok API request failed: {
+            raise RuntimeError(f"Grok API request failed: {e!s}") from e
 
         # Extract usage info
         usage = resp.get("usage", {})
         prompt_tokens = usage.get("prompt_tokens", 0)
-        completion_tokens = usage.get("completion_tokens", 0)
+        completion_tokens = usage.get("completion_tokens", 0)
         total_tokens = usage.get("total_tokens", 0)
 
-        # Calculate cost
-
-        prompt_cost = (prompt_tokens / 1000000) * model_pricing["prompt"]
-        completion_cost = (completion_tokens / 1000000) * model_pricing["completion"]
-        total_cost = prompt_cost + completion_cost
+        # Calculate cost via shared mixin
+        total_cost = self._calculate_cost("grok", model, prompt_tokens, completion_tokens)
 
         # Standardized meta object
         meta = {
@@ -150,4 +153,4 @@ class GrokDriver(Driver):
         }
 
         text = resp["choices"][0]["message"]["content"]
-        return {"text": text, "meta": meta}
+        return {"text": text, "meta": meta}
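GrokDriver now follows the same pattern: a message-based _do_generate, native JSON mode via response_format, and cost computed by the shared CostMixin instead of inline arithmetic. A hypothetical call, inferred from the diff rather than documentation; the import path is an assumption, the model id comes from the MODEL_PRICING table above, and GROK_API_KEY must be exported.

# Illustrative sketch, not documented usage.
from prompture.drivers.grok_driver import GrokDriver

driver = GrokDriver(model="grok-code-fast-1")
result = driver.generate_messages(
    [
        {"role": "system", "content": "Reply with valid JSON only."},
        {"role": "user", "content": "Describe this package as a JSON object."},
    ],
    {"json_mode": True},  # becomes payload["response_format"] = {"type": "json_object"}
)
print(result["text"])
print(result["meta"])  # standardized meta with token counts and mixin-computed cost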
prompture/drivers/groq_driver.py
CHANGED
@@ -1,18 +1,23 @@
 """Groq driver for prompture.
 Requires the `groq` package. Uses GROQ_API_KEY env var.
 """
+
 import os
-from typing import Any
+from typing import Any
 
 try:
     import groq
 except Exception:
     groq = None
 
+from ..cost_mixin import CostMixin
 from ..driver import Driver
 
 
-class GroqDriver(Driver):
+class GroqDriver(CostMixin, Driver):
+    supports_json_mode = True
+    supports_vision = True
+
     # Approximate pricing per 1K tokens (to be updated with official pricing)
     # Each model entry defines token parameters and temperature support
     MODEL_PRICING = {
@@ -32,7 +37,7 @@ class GroqDriver(Driver):
 
     def __init__(self, api_key: str | None = None, model: str = "llama2-70b-4096"):
         """Initialize Groq driver.
-
+
         Args:
             api_key: Groq API key (defaults to GROQ_API_KEY env var)
             model: Model to use (defaults to llama2-70b-4096)
@@ -44,20 +49,21 @@ class GroqDriver(Driver):
         else:
             self.client = None
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    supports_messages = True
+
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_openai_vision_messages
+
+        return _prepare_openai_vision_messages(messages)
+
+    def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
+        messages = [{"role": "user", "content": prompt}]
+        return self._do_generate(messages, options)
+
+    def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        return self._do_generate(self._prepare_messages(messages), options)
+
+    def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
         if self.client is None:
             raise RuntimeError("groq package is not installed")
 
@@ -74,7 +80,7 @@ class GroqDriver(Driver):
         # Base kwargs for API call
         kwargs = {
             "model": model,
-            "messages":
+            "messages": messages,
        }
 
         # Set token limit with correct parameter name
@@ -84,23 +90,24 @@ class GroqDriver(Driver):
         if supports_temperature and "temperature" in opts:
             kwargs["temperature"] = opts["temperature"]
 
+        # Native JSON mode support
+        if options.get("json_mode"):
+            kwargs["response_format"] = {"type": "json_object"}
+
         try:
             resp = self.client.chat.completions.create(**kwargs)
-        except Exception
+        except Exception:
             # Re-raise any Groq API errors
             raise
 
         # Extract usage statistics
         usage = getattr(resp, "usage", None)
         prompt_tokens = getattr(usage, "prompt_tokens", 0)
-        completion_tokens = getattr(usage, "completion_tokens", 0)
+        completion_tokens = getattr(usage, "completion_tokens", 0)
         total_tokens = getattr(usage, "total_tokens", 0)
 
-        # Calculate
-
-        prompt_cost = (prompt_tokens / 1000) * model_pricing["prompt"]
-        completion_cost = (completion_tokens / 1000) * model_pricing["completion"]
-        total_cost = prompt_cost + completion_cost
+        # Calculate cost via shared mixin
+        total_cost = self._calculate_cost("groq", model, prompt_tokens, completion_tokens)
 
         # Standard metadata object
         meta = {
@@ -114,4 +121,4 @@ class GroqDriver(Driver):
 
         # Extract generated text
         text = resp.choices[0].message.content
-        return {"text": text, "meta": meta}
+        return {"text": text, "meta": meta}
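GroqDriver gets the same split between generate/generate_messages and _do_generate, and the new capability flags let callers probe for JSON-mode support before enabling it. A sketch under the assumption that the groq package is installed, GROQ_API_KEY is set, and the class is importable from prompture.drivers.groq_driver; the option keys mirror what _do_generate reads in the diff above.

# Sketch only -- inferred from the diff, not from package docs.
from prompture.drivers.groq_driver import GroqDriver

driver = GroqDriver()  # defaults to model="llama2-70b-4096" per __init__

options: dict = {}
if getattr(driver, "supports_json_mode", False):
    options["json_mode"] = True  # translated to response_format={"type": "json_object"}

result = driver.generate("Return a JSON object with a 'status' key.", options)
print(result["text"])
print(result["meta"])  # token usage plus cost from the shared CostMixin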