prompture-0.0.29.dev8-py3-none-any.whl → prompture-0.0.38.dev2-py3-none-any.whl

This diff compares two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
Files changed (79)
  1. prompture/__init__.py +264 -23
  2. prompture/_version.py +34 -0
  3. prompture/agent.py +924 -0
  4. prompture/agent_types.py +156 -0
  5. prompture/aio/__init__.py +74 -0
  6. prompture/async_agent.py +880 -0
  7. prompture/async_conversation.py +789 -0
  8. prompture/async_core.py +803 -0
  9. prompture/async_driver.py +193 -0
  10. prompture/async_groups.py +551 -0
  11. prompture/cache.py +469 -0
  12. prompture/callbacks.py +55 -0
  13. prompture/cli.py +63 -4
  14. prompture/conversation.py +826 -0
  15. prompture/core.py +894 -263
  16. prompture/cost_mixin.py +51 -0
  17. prompture/discovery.py +187 -0
  18. prompture/driver.py +206 -5
  19. prompture/drivers/__init__.py +175 -67
  20. prompture/drivers/airllm_driver.py +109 -0
  21. prompture/drivers/async_airllm_driver.py +26 -0
  22. prompture/drivers/async_azure_driver.py +123 -0
  23. prompture/drivers/async_claude_driver.py +113 -0
  24. prompture/drivers/async_google_driver.py +316 -0
  25. prompture/drivers/async_grok_driver.py +97 -0
  26. prompture/drivers/async_groq_driver.py +90 -0
  27. prompture/drivers/async_hugging_driver.py +61 -0
  28. prompture/drivers/async_lmstudio_driver.py +148 -0
  29. prompture/drivers/async_local_http_driver.py +44 -0
  30. prompture/drivers/async_ollama_driver.py +135 -0
  31. prompture/drivers/async_openai_driver.py +102 -0
  32. prompture/drivers/async_openrouter_driver.py +102 -0
  33. prompture/drivers/async_registry.py +133 -0
  34. prompture/drivers/azure_driver.py +42 -9
  35. prompture/drivers/claude_driver.py +257 -34
  36. prompture/drivers/google_driver.py +295 -42
  37. prompture/drivers/grok_driver.py +35 -32
  38. prompture/drivers/groq_driver.py +33 -26
  39. prompture/drivers/hugging_driver.py +6 -6
  40. prompture/drivers/lmstudio_driver.py +97 -19
  41. prompture/drivers/local_http_driver.py +6 -6
  42. prompture/drivers/ollama_driver.py +168 -23
  43. prompture/drivers/openai_driver.py +184 -9
  44. prompture/drivers/openrouter_driver.py +37 -25
  45. prompture/drivers/registry.py +306 -0
  46. prompture/drivers/vision_helpers.py +153 -0
  47. prompture/field_definitions.py +106 -96
  48. prompture/group_types.py +147 -0
  49. prompture/groups.py +530 -0
  50. prompture/image.py +180 -0
  51. prompture/logging.py +80 -0
  52. prompture/model_rates.py +217 -0
  53. prompture/persistence.py +254 -0
  54. prompture/persona.py +482 -0
  55. prompture/runner.py +49 -47
  56. prompture/scaffold/__init__.py +1 -0
  57. prompture/scaffold/generator.py +84 -0
  58. prompture/scaffold/templates/Dockerfile.j2 +12 -0
  59. prompture/scaffold/templates/README.md.j2 +41 -0
  60. prompture/scaffold/templates/config.py.j2 +21 -0
  61. prompture/scaffold/templates/env.example.j2 +8 -0
  62. prompture/scaffold/templates/main.py.j2 +86 -0
  63. prompture/scaffold/templates/models.py.j2 +40 -0
  64. prompture/scaffold/templates/requirements.txt.j2 +5 -0
  65. prompture/serialization.py +218 -0
  66. prompture/server.py +183 -0
  67. prompture/session.py +117 -0
  68. prompture/settings.py +19 -1
  69. prompture/tools.py +219 -267
  70. prompture/tools_schema.py +254 -0
  71. prompture/validator.py +3 -3
  72. prompture-0.0.38.dev2.dist-info/METADATA +369 -0
  73. prompture-0.0.38.dev2.dist-info/RECORD +77 -0
  74. {prompture-0.0.29.dev8.dist-info → prompture-0.0.38.dev2.dist-info}/WHEEL +1 -1
  75. prompture-0.0.29.dev8.dist-info/METADATA +0 -368
  76. prompture-0.0.29.dev8.dist-info/RECORD +0 -27
  77. {prompture-0.0.29.dev8.dist-info → prompture-0.0.38.dev2.dist-info}/entry_points.txt +0 -0
  78. {prompture-0.0.29.dev8.dist-info → prompture-0.0.38.dev2.dist-info}/licenses/LICENSE +0 -0
  79. {prompture-0.0.29.dev8.dist-info → prompture-0.0.38.dev2.dist-info}/top_level.txt +0 -0
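
Three of the reworked drivers are shown in full below. The common thread: each now inherits CostMixin, declares capability flags (supports_json_mode, supports_vision, and so on), and exposes a message-based generate_messages entry point next to the legacy prompt-based generate, with both funneling into a shared _do_generate. A minimal usage sketch against that surface, reconstructed from the diffs that follow (the API key, model name, and prompts are illustrative placeholders, not official documentation):

# Sketch of the driver surface visible in the diffs below.
from prompture.drivers.google_driver import GoogleDriver

driver = GoogleDriver(api_key="YOUR_KEY", model="gemini-1.5-flash")

# Legacy single-prompt path, now routed through _do_generate internally.
result = driver.generate("Name three prime numbers.", options={"max_tokens": 64})

# New message-based path with an OpenAI-style message list.
result = driver.generate_messages(
    [
        {"role": "system", "content": "Answer in one line."},
        {"role": "user", "content": "Name three prime numbers."},
    ],
    {"temperature": 0.2},
)
print(result["text"])
# Standardized metadata shared across drivers:
print(result["meta"]["prompt_tokens"], result["meta"]["completion_tokens"], result["meta"]["cost"])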
--- a/prompture/drivers/google_driver.py
+++ b/prompture/drivers/google_driver.py
@@ -1,46 +1,60 @@
-import os
 import logging
+import os
+import uuid
+from collections.abc import Iterator
+from typing import Any, Optional
+
 import google.generativeai as genai
-from typing import Any, Dict
+
+from ..cost_mixin import CostMixin
 from ..driver import Driver
 
 logger = logging.getLogger(__name__)
 
 
-class GoogleDriver(Driver):
+class GoogleDriver(CostMixin, Driver):
     """Driver for Google's Generative AI API (Gemini)."""
 
+    supports_json_mode = True
+    supports_json_schema = True
+    supports_vision = True
+    supports_tool_use = True
+    supports_streaming = True
+
     # Based on current Gemini pricing (as of 2025)
     # Source: https://cloud.google.com/vertex-ai/pricing#gemini_models
+    _PRICING_UNIT = 1_000_000
     MODEL_PRICING = {
         "gemini-1.5-pro": {
             "prompt": 0.00025,  # $0.25/1M chars input
-            "completion": 0.0005  # $0.50/1M chars output
+            "completion": 0.0005,  # $0.50/1M chars output
         },
         "gemini-1.5-pro-vision": {
             "prompt": 0.00025,  # $0.25/1M chars input
-            "completion": 0.0005  # $0.50/1M chars output
+            "completion": 0.0005,  # $0.50/1M chars output
         },
         "gemini-2.5-pro": {
             "prompt": 0.0004,  # $0.40/1M chars input
-            "completion": 0.0008  # $0.80/1M chars output
+            "completion": 0.0008,  # $0.80/1M chars output
         },
         "gemini-2.5-flash": {
             "prompt": 0.0004,  # $0.40/1M chars input
-            "completion": 0.0008  # $0.80/1M chars output
+            "completion": 0.0008,  # $0.80/1M chars output
         },
         "gemini-2.5-flash-lite": {
             "prompt": 0.0002,  # $0.20/1M chars input
-            "completion": 0.0004  # $0.40/1M chars output
+            "completion": 0.0004,  # $0.40/1M chars output
         },
-        "gemini-2.0-flash": {
+        "gemini-2.0-flash": {
             "prompt": 0.0004,  # $0.40/1M chars input
-            "completion": 0.0008  # $0.80/1M chars output
+            "completion": 0.0008,  # $0.80/1M chars output
         },
         "gemini-2.0-flash-lite": {
             "prompt": 0.0002,  # $0.20/1M chars input
-            "completion": 0.0004  # $0.40/1M chars output
+            "completion": 0.0004,  # $0.40/1M chars output
         },
+        "gemini-1.5-flash": {"prompt": 0.00001875, "completion": 0.000075},
+        "gemini-1.5-flash-8b": {"prompt": 0.00001, "completion": 0.00004},
     }
 
     def __init__(self, api_key: str | None = None, model: str = "gemini-1.5-pro"):
@@ -55,13 +69,14 @@ class GoogleDriver(Driver):
             raise ValueError("Google API key not found. Set GOOGLE_API_KEY env var or pass api_key to constructor")
 
         self.model = model
+        # Warn if model is not in pricing table but allow it (might be new)
         if model not in self.MODEL_PRICING:
-            raise ValueError(f"Unsupported model: {model}. Must be one of: {list(self.MODEL_PRICING.keys())}")
+            logger.warning(f"Model {model} not found in pricing table. Cost calculations will be 0.")
 
         # Configure google.generativeai
         genai.configure(api_key=self.api_key)
-        self.options: Dict[str, Any] = {}
-
+        self.options: dict[str, Any] = {}
+
         # Validate connection and model availability
         self._validate_connection()
 
@@ -75,48 +90,159 @@ class GoogleDriver(Driver):
             logger.warning(f"Could not validate connection to Google API: {e}")
             raise
 
-    def generate(self, prompt: str, options: Dict[str, Any] = None) -> Dict[str, Any]:
-        """Generate text using Google's Generative AI.
+    def _calculate_cost_chars(self, prompt_chars: int, completion_chars: int) -> float:
+        """Calculate cost from character counts.
 
-        Args:
-            prompt: The input prompt
-            options: Additional options to pass to the model
+        Live rates use token-based pricing (estimate ~4 chars/token).
+        Hardcoded MODEL_PRICING uses per-1M-character rates.
+        """
+        from ..model_rates import get_model_rates
+
+        live_rates = get_model_rates("google", self.model)
+        if live_rates:
+            est_prompt_tokens = prompt_chars / 4
+            est_completion_tokens = completion_chars / 4
+            prompt_cost = (est_prompt_tokens / 1_000_000) * live_rates["input"]
+            completion_cost = (est_completion_tokens / 1_000_000) * live_rates["output"]
+        else:
+            model_pricing = self.MODEL_PRICING.get(self.model, {"prompt": 0, "completion": 0})
+            prompt_cost = (prompt_chars / 1_000_000) * model_pricing["prompt"]
+            completion_cost = (completion_chars / 1_000_000) * model_pricing["completion"]
+        return round(prompt_cost + completion_cost, 6)
+
+    def _extract_usage_metadata(self, response: Any, messages: list[dict[str, Any]]) -> dict[str, Any]:
+        """Extract token counts from response, falling back to character estimation."""
+        usage = getattr(response, "usage_metadata", None)
+        if usage:
+            prompt_tokens = getattr(usage, "prompt_token_count", 0) or 0
+            completion_tokens = getattr(usage, "candidates_token_count", 0) or 0
+            total_tokens = getattr(usage, "total_token_count", 0) or (prompt_tokens + completion_tokens)
+            cost = self._calculate_cost("google", self.model, prompt_tokens, completion_tokens)
+        else:
+            # Fallback: estimate from character counts
+            total_prompt_chars = 0
+            for msg in messages:
+                c = msg.get("content", "")
+                if isinstance(c, str):
+                    total_prompt_chars += len(c)
+                elif isinstance(c, list):
+                    for part in c:
+                        if isinstance(part, str):
+                            total_prompt_chars += len(part)
+                        elif isinstance(part, dict) and "text" in part:
+                            total_prompt_chars += len(part["text"])
+            completion_chars = len(response.text) if response.text else 0
+            prompt_tokens = total_prompt_chars // 4
+            completion_tokens = completion_chars // 4
+            total_tokens = prompt_tokens + completion_tokens
+            cost = self._calculate_cost_chars(total_prompt_chars, completion_chars)
+
+        return {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": round(cost, 6),
+        }
+
+    supports_messages = True
+
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_google_vision_messages
 
-        Returns:
-            Dict containing generated text and metadata
+        return _prepare_google_vision_messages(messages)
+
+    def _build_generation_args(
+        self, messages: list[dict[str, Any]], options: Optional[dict[str, Any]] = None
+    ) -> tuple[Any, dict[str, Any], dict[str, Any]]:
+        """Parse messages and options into (gen_input, gen_kwargs, model_kwargs) for generate_content.
+
+        Returns the content input plus dicts of keyword arguments
+        (generation_config, safety_settings, and model kwargs including system_instruction).
         """
         merged_options = self.options.copy()
         if options:
             merged_options.update(options)
 
+        generation_config = merged_options.get("generation_config", {})
+        safety_settings = merged_options.get("safety_settings", {})
+
+        if "temperature" in merged_options and "temperature" not in generation_config:
+            generation_config["temperature"] = merged_options["temperature"]
+        if "max_tokens" in merged_options and "max_output_tokens" not in generation_config:
+            generation_config["max_output_tokens"] = merged_options["max_tokens"]
+        if "top_p" in merged_options and "top_p" not in generation_config:
+            generation_config["top_p"] = merged_options["top_p"]
+        if "top_k" in merged_options and "top_k" not in generation_config:
+            generation_config["top_k"] = merged_options["top_k"]
+
+        # Native JSON mode support
+        if merged_options.get("json_mode"):
+            generation_config["response_mime_type"] = "application/json"
+            json_schema = merged_options.get("json_schema")
+            if json_schema:
+                generation_config["response_schema"] = json_schema
+
+        # Convert messages to Gemini format
+        system_instruction = None
+        contents: list[dict[str, Any]] = []
+        for msg in messages:
+            role = msg.get("role", "user")
+            content = msg.get("content", "")
+            if role == "system":
+                system_instruction = content if isinstance(content, str) else str(content)
+            else:
+                gemini_role = "model" if role == "assistant" else "user"
+                if msg.get("_vision_parts"):
+                    contents.append({"role": gemini_role, "parts": content})
+                else:
+                    contents.append({"role": gemini_role, "parts": [content]})
+
+        # For a single message, unwrap only if it has exactly one string part
+        if len(contents) == 1:
+            parts = contents[0]["parts"]
+            if len(parts) == 1 and isinstance(parts[0], str):
+                gen_input = parts[0]
+            else:
+                gen_input = contents
+        else:
+            gen_input = contents
+
+        model_kwargs: dict[str, Any] = {}
+        if system_instruction:
+            model_kwargs["system_instruction"] = system_instruction
+
+        gen_kwargs: dict[str, Any] = {
+            "generation_config": generation_config if generation_config else None,
+            "safety_settings": safety_settings if safety_settings else None,
+        }
+
+        return gen_input, gen_kwargs, model_kwargs
+
+    def generate(self, prompt: str, options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
+        messages = [{"role": "user", "content": prompt}]
+        return self._do_generate(messages, options)
+
+    def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        return self._do_generate(self._prepare_messages(messages), options)
+
+    def _do_generate(self, messages: list[dict[str, str]], options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
+        gen_input, gen_kwargs, model_kwargs = self._build_generation_args(messages, options)
+
         try:
             logger.debug(f"Initializing {self.model} for generation")
-            model = genai.GenerativeModel(self.model)
+            model = genai.GenerativeModel(self.model, **model_kwargs)
+
+            logger.debug(f"Generating with model {self.model}")
+            response = model.generate_content(gen_input, **gen_kwargs)
 
-            # Generate response
-            logger.debug(f"Generating with prompt: {prompt}")
-            response = model.generate_content(prompt)
-
             if not response.text:
                 raise ValueError("Empty response from model")
 
-            # Calculate token usage and cost
-            # Note: Using character count as proxy since Google charges per character
-            prompt_chars = len(prompt)
-            completion_chars = len(response.text)
-
-            # Calculate costs
-            model_pricing = self.MODEL_PRICING[self.model]
-            prompt_cost = (prompt_chars / 1_000_000) * model_pricing["prompt"]
-            completion_cost = (completion_chars / 1_000_000) * model_pricing["completion"]
-            total_cost = prompt_cost + completion_cost
+            usage_meta = self._extract_usage_metadata(response, messages)
 
             meta = {
-                "prompt_chars": prompt_chars,
-                "completion_chars": completion_chars,
-                "total_chars": prompt_chars + completion_chars,
-                "cost": total_cost,
-                "raw_response": response.prompt_feedback,
+                **usage_meta,
+                "raw_response": response.prompt_feedback if hasattr(response, "prompt_feedback") else None,
                 "model_name": self.model,
             }
 
@@ -124,4 +250,131 @@ class GoogleDriver(Driver):
 
         except Exception as e:
             logger.error(f"Google API request failed: {e}")
-            raise RuntimeError(f"Google API request failed: {e}")
+            raise RuntimeError(f"Google API request failed: {e}") from e
+
+    # ------------------------------------------------------------------
+    # Tool use
+    # ------------------------------------------------------------------
+
+    def generate_messages_with_tools(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> dict[str, Any]:
+        """Generate a response that may include tool/function calls."""
+        gen_input, gen_kwargs, model_kwargs = self._build_generation_args(
+            self._prepare_messages(messages), options
+        )
+
+        # Convert tools from OpenAI format to Gemini function declarations
+        function_declarations = []
+        for t in tools:
+            if "type" in t and t["type"] == "function":
+                fn = t["function"]
+                decl = {
+                    "name": fn["name"],
+                    "description": fn.get("description", ""),
+                }
+                params = fn.get("parameters")
+                if params:
+                    decl["parameters"] = params
+                function_declarations.append(decl)
+            elif "name" in t:
+                # Already in a generic format
+                decl = {"name": t["name"], "description": t.get("description", "")}
+                params = t.get("parameters") or t.get("input_schema")
+                if params:
+                    decl["parameters"] = params
+                function_declarations.append(decl)
+
+        try:
+            model = genai.GenerativeModel(self.model, **model_kwargs)
+
+            gemini_tools = [genai.types.Tool(function_declarations=function_declarations)]
+            response = model.generate_content(gen_input, tools=gemini_tools, **gen_kwargs)
+
+            usage_meta = self._extract_usage_metadata(response, messages)
+            meta = {
+                **usage_meta,
+                "raw_response": response.prompt_feedback if hasattr(response, "prompt_feedback") else None,
+                "model_name": self.model,
+            }
+
+            text = ""
+            tool_calls_out: list[dict[str, Any]] = []
+            stop_reason = "stop"
+
+            for candidate in response.candidates:
+                for part in candidate.content.parts:
+                    if hasattr(part, "text") and part.text:
+                        text += part.text
+                    if hasattr(part, "function_call") and part.function_call.name:
+                        fc = part.function_call
+                        tool_calls_out.append({
+                            "id": str(uuid.uuid4()),
+                            "name": fc.name,
+                            "arguments": dict(fc.args) if fc.args else {},
+                        })
+
+                finish_reason = getattr(candidate, "finish_reason", None)
+                if finish_reason is not None:
+                    # Map Gemini finish reasons to standard stop reasons
+                    reason_map = {1: "stop", 2: "max_tokens", 3: "safety", 4: "recitation", 5: "other"}
+                    stop_reason = reason_map.get(finish_reason, "stop")
+
+            if tool_calls_out:
+                stop_reason = "tool_use"
+
+            return {
+                "text": text,
+                "meta": meta,
+                "tool_calls": tool_calls_out,
+                "stop_reason": stop_reason,
+            }
+
+        except Exception as e:
+            logger.error(f"Google API tool call request failed: {e}")
+            raise RuntimeError(f"Google API tool call request failed: {e}") from e
+
+    # ------------------------------------------------------------------
+    # Streaming
+    # ------------------------------------------------------------------
+
+    def generate_messages_stream(
+        self,
+        messages: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> Iterator[dict[str, Any]]:
+        """Yield response chunks via Gemini streaming API."""
+        gen_input, gen_kwargs, model_kwargs = self._build_generation_args(
+            self._prepare_messages(messages), options
+        )
+
+        try:
+            model = genai.GenerativeModel(self.model, **model_kwargs)
+            response = model.generate_content(gen_input, stream=True, **gen_kwargs)
+
+            full_text = ""
+            for chunk in response:
+                chunk_text = getattr(chunk, "text", None) or ""
+                if chunk_text:
+                    full_text += chunk_text
+                    yield {"type": "delta", "text": chunk_text}
+
+            # After iteration completes, resolve() has been called on the response
+            usage_meta = self._extract_usage_metadata(response, messages)
+
+            yield {
+                "type": "done",
+                "text": full_text,
+                "meta": {
+                    **usage_meta,
+                    "raw_response": {},
+                    "model_name": self.model,
+                },
+            }
+
+        except Exception as e:
+            logger.error(f"Google API streaming request failed: {e}")
+            raise RuntimeError(f"Google API streaming request failed: {e}") from e
--- a/prompture/drivers/grok_driver.py
+++ b/prompture/drivers/grok_driver.py
@@ -1,15 +1,22 @@
 """xAI Grok driver.
 Requires the `requests` package. Uses GROK_API_KEY env var.
 """
+
 import os
-from typing import Any, Dict
+from typing import Any
+
 import requests
 
+from ..cost_mixin import CostMixin
 from ..driver import Driver
 
 
-class GrokDriver(Driver):
+class GrokDriver(CostMixin, Driver):
+    supports_json_mode = True
+    supports_vision = True
+
     # Pricing per 1M tokens based on xAI's documentation
+    _PRICING_UNIT = 1_000_000
     MODEL_PRICING = {
         "grok-code-fast-1": {
             "prompt": 0.20,
@@ -72,19 +79,21 @@ class GrokDriver(Driver):
         self.model = model
         self.api_base = "https://api.x.ai/v1"
 
-    def generate(self, prompt: str, options: Dict[str, Any]) -> Dict[str, Any]:
-        """Generate completion using Grok API.
+    supports_messages = True
 
-        Args:
-            prompt: Input prompt
-            options: Generation options
-
-        Returns:
-            Dict containing generated text and metadata
-
-        Raises:
-            RuntimeError: If API key is missing or request fails
-        """
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_openai_vision_messages
+
+        return _prepare_openai_vision_messages(messages)
+
+    def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
+        messages = [{"role": "user", "content": prompt}]
+        return self._do_generate(messages, options)
+
+    def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        return self._do_generate(self._prepare_messages(messages), options)
+
+    def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
         if not self.api_key:
             raise RuntimeError("GROK_API_KEY environment variable is required")
 
@@ -101,7 +110,7 @@ class GrokDriver(Driver):
         # Base request payload
         payload = {
             "model": model,
-            "messages": [{"role": "user", "content": prompt}],
+            "messages": messages,
         }
 
         # Add token limit with correct parameter name
@@ -111,33 +120,27 @@ class GrokDriver(Driver):
         if supports_temperature and "temperature" in opts:
             payload["temperature"] = opts["temperature"]
 
-        headers = {
-            "Authorization": f"Bearer {self.api_key}",
-            "Content-Type": "application/json"
-        }
+        # Native JSON mode support
+        if options.get("json_mode"):
+            payload["response_format"] = {"type": "json_object"}
+
+        headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
 
         try:
-            response = requests.post(
-                f"{self.api_base}/chat/completions",
-                headers=headers,
-                json=payload
-            )
+            response = requests.post(f"{self.api_base}/chat/completions", headers=headers, json=payload)
             response.raise_for_status()
             resp = response.json()
         except requests.exceptions.RequestException as e:
-            raise RuntimeError(f"Grok API request failed: {str(e)}")
+            raise RuntimeError(f"Grok API request failed: {e!s}") from e
 
         # Extract usage info
         usage = resp.get("usage", {})
         prompt_tokens = usage.get("prompt_tokens", 0)
-        completion_tokens = usage.get("completion_tokens", 0)
+        completion_tokens = usage.get("completion_tokens", 0)
         total_tokens = usage.get("total_tokens", 0)
 
-        # Calculate cost
-        model_pricing = self.MODEL_PRICING.get(model, {"prompt": 0, "completion": 0})
-        prompt_cost = (prompt_tokens / 1000000) * model_pricing["prompt"]
-        completion_cost = (completion_tokens / 1000000) * model_pricing["completion"]
-        total_cost = prompt_cost + completion_cost
+        # Calculate cost via shared mixin
+        total_cost = self._calculate_cost("grok", model, prompt_tokens, completion_tokens)
 
         # Standardized meta object
         meta = {
@@ -150,4 +153,4 @@ class GrokDriver(Driver):
         }
 
         text = resp["choices"][0]["message"]["content"]
-        return {"text": text, "meta": meta}
+        return {"text": text, "meta": meta}
--- a/prompture/drivers/groq_driver.py
+++ b/prompture/drivers/groq_driver.py
@@ -1,18 +1,23 @@
 """Groq driver for prompture.
 Requires the `groq` package. Uses GROQ_API_KEY env var.
 """
+
 import os
-from typing import Any, Dict
+from typing import Any
 
 try:
     import groq
 except Exception:
     groq = None
 
+from ..cost_mixin import CostMixin
 from ..driver import Driver
 
 
-class GroqDriver(Driver):
+class GroqDriver(CostMixin, Driver):
+    supports_json_mode = True
+    supports_vision = True
+
     # Approximate pricing per 1K tokens (to be updated with official pricing)
     # Each model entry defines token parameters and temperature support
     MODEL_PRICING = {
@@ -32,7 +37,7 @@ class GroqDriver(Driver):
 
     def __init__(self, api_key: str | None = None, model: str = "llama2-70b-4096"):
         """Initialize Groq driver.
-
+
         Args:
             api_key: Groq API key (defaults to GROQ_API_KEY env var)
             model: Model to use (defaults to llama2-70b-4096)
@@ -44,20 +49,21 @@ class GroqDriver(Driver):
         else:
             self.client = None
 
-    def generate(self, prompt: str, options: Dict[str, Any]) -> Dict[str, Any]:
-        """Generate completion using Groq API.
-
-        Args:
-            prompt: Input prompt
-            options: Generation options
-
-        Returns:
-            Dict containing generated text and metadata
-
-        Raises:
-            RuntimeError: If groq package is not installed
-            groq.error.*: Various Groq API errors
-        """
+    supports_messages = True
+
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_openai_vision_messages
+
+        return _prepare_openai_vision_messages(messages)
+
+    def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
+        messages = [{"role": "user", "content": prompt}]
+        return self._do_generate(messages, options)
+
+    def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        return self._do_generate(self._prepare_messages(messages), options)
+
+    def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
         if self.client is None:
             raise RuntimeError("groq package is not installed")
 
@@ -74,7 +80,7 @@ class GroqDriver(Driver):
         # Base kwargs for API call
         kwargs = {
             "model": model,
-            "messages": [{"role": "user", "content": prompt}],
+            "messages": messages,
         }
 
         # Set token limit with correct parameter name
@@ -84,23 +90,24 @@ class GroqDriver(Driver):
         if supports_temperature and "temperature" in opts:
             kwargs["temperature"] = opts["temperature"]
 
+        # Native JSON mode support
+        if options.get("json_mode"):
+            kwargs["response_format"] = {"type": "json_object"}
+
         try:
             resp = self.client.chat.completions.create(**kwargs)
-        except Exception as e:
+        except Exception:
             # Re-raise any Groq API errors
             raise
 
         # Extract usage statistics
         usage = getattr(resp, "usage", None)
         prompt_tokens = getattr(usage, "prompt_tokens", 0)
-        completion_tokens = getattr(usage, "completion_tokens", 0)
+        completion_tokens = getattr(usage, "completion_tokens", 0)
         total_tokens = getattr(usage, "total_tokens", 0)
 
-        # Calculate costs
-        model_pricing = self.MODEL_PRICING.get(model, {"prompt": 0, "completion": 0})
-        prompt_cost = (prompt_tokens / 1000) * model_pricing["prompt"]
-        completion_cost = (completion_tokens / 1000) * model_pricing["completion"]
-        total_cost = prompt_cost + completion_cost
+        # Calculate cost via shared mixin
+        total_cost = self._calculate_cost("groq", model, prompt_tokens, completion_tokens)
 
         # Standard metadata object
         meta = {
@@ -114,4 +121,4 @@ class GroqDriver(Driver):
 
         # Extract generated text
         text = resp.choices[0].message.content
-        return {"text": text, "meta": meta}
+        return {"text": text, "meta": meta}