prompture 0.0.33.dev1__py3-none-any.whl → 0.0.34__py3-none-any.whl

This diff shows the changes between publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (56)
  1. prompture/__init__.py +133 -49
  2. prompture/_version.py +34 -0
  3. prompture/aio/__init__.py +74 -0
  4. prompture/async_conversation.py +484 -0
  5. prompture/async_core.py +803 -0
  6. prompture/async_driver.py +131 -0
  7. prompture/cache.py +469 -0
  8. prompture/callbacks.py +50 -0
  9. prompture/cli.py +7 -3
  10. prompture/conversation.py +504 -0
  11. prompture/core.py +475 -352
  12. prompture/cost_mixin.py +51 -0
  13. prompture/discovery.py +50 -35
  14. prompture/driver.py +125 -5
  15. prompture/drivers/__init__.py +171 -73
  16. prompture/drivers/airllm_driver.py +13 -20
  17. prompture/drivers/async_airllm_driver.py +26 -0
  18. prompture/drivers/async_azure_driver.py +117 -0
  19. prompture/drivers/async_claude_driver.py +107 -0
  20. prompture/drivers/async_google_driver.py +132 -0
  21. prompture/drivers/async_grok_driver.py +91 -0
  22. prompture/drivers/async_groq_driver.py +84 -0
  23. prompture/drivers/async_hugging_driver.py +61 -0
  24. prompture/drivers/async_lmstudio_driver.py +79 -0
  25. prompture/drivers/async_local_http_driver.py +44 -0
  26. prompture/drivers/async_ollama_driver.py +125 -0
  27. prompture/drivers/async_openai_driver.py +96 -0
  28. prompture/drivers/async_openrouter_driver.py +96 -0
  29. prompture/drivers/async_registry.py +129 -0
  30. prompture/drivers/azure_driver.py +36 -9
  31. prompture/drivers/claude_driver.py +86 -34
  32. prompture/drivers/google_driver.py +87 -51
  33. prompture/drivers/grok_driver.py +29 -32
  34. prompture/drivers/groq_driver.py +27 -26
  35. prompture/drivers/hugging_driver.py +6 -6
  36. prompture/drivers/lmstudio_driver.py +26 -13
  37. prompture/drivers/local_http_driver.py +6 -6
  38. prompture/drivers/ollama_driver.py +90 -23
  39. prompture/drivers/openai_driver.py +36 -9
  40. prompture/drivers/openrouter_driver.py +31 -25
  41. prompture/drivers/registry.py +306 -0
  42. prompture/field_definitions.py +106 -96
  43. prompture/logging.py +80 -0
  44. prompture/model_rates.py +217 -0
  45. prompture/runner.py +49 -47
  46. prompture/session.py +117 -0
  47. prompture/settings.py +14 -1
  48. prompture/tools.py +172 -265
  49. prompture/validator.py +3 -3
  50. {prompture-0.0.33.dev1.dist-info → prompture-0.0.34.dist-info}/METADATA +18 -20
  51. prompture-0.0.34.dist-info/RECORD +55 -0
  52. prompture-0.0.33.dev1.dist-info/RECORD +0 -29
  53. {prompture-0.0.33.dev1.dist-info → prompture-0.0.34.dist-info}/WHEEL +0 -0
  54. {prompture-0.0.33.dev1.dist-info → prompture-0.0.34.dist-info}/entry_points.txt +0 -0
  55. {prompture-0.0.33.dev1.dist-info → prompture-0.0.34.dist-info}/licenses/LICENSE +0 -0
  56. {prompture-0.0.33.dev1.dist-info → prompture-0.0.34.dist-info}/top_level.txt +0 -0
prompture/drivers/azure_driver.py

@@ -1,17 +1,23 @@
 """Driver for Azure OpenAI Service (migrated to openai>=1.0.0).
 Requires the `openai` package.
 """
+
 import os
-from typing import Any, Dict
+from typing import Any
+
 try:
     from openai import AzureOpenAI
 except Exception:
     AzureOpenAI = None
 
+from ..cost_mixin import CostMixin
 from ..driver import Driver
 
 
-class AzureDriver(Driver):
+class AzureDriver(CostMixin, Driver):
+    supports_json_mode = True
+    supports_json_schema = True
+
     # Pricing per 1K tokens (adjust if your Azure pricing differs from OpenAI defaults)
     MODEL_PRICING = {
         "gpt-5-mini": {
@@ -82,7 +88,16 @@ class AzureDriver(Driver):
         else:
             self.client = None
 
-    def generate(self, prompt: str, options: Dict[str, Any]) -> Dict[str, Any]:
+    supports_messages = True
+
+    def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
+        messages = [{"role": "user", "content": prompt}]
+        return self._do_generate(messages, options)
+
+    def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        return self._do_generate(messages, options)
+
+    def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
         if self.client is None:
             raise RuntimeError("openai package (>=1.0.0) with AzureOpenAI not installed")
 
@@ -96,13 +111,28 @@ class AzureDriver(Driver):
         # Build request kwargs
         kwargs = {
             "model": self.deployment_id,  # for Azure, use deployment name
-            "messages": [{"role": "user", "content": prompt}],
+            "messages": messages,
         }
         kwargs[tokens_param] = opts.get("max_tokens", 512)
 
         if supports_temperature and "temperature" in opts:
             kwargs["temperature"] = opts["temperature"]
 
+        # Native JSON mode support
+        if options.get("json_mode"):
+            json_schema = options.get("json_schema")
+            if json_schema:
+                kwargs["response_format"] = {
+                    "type": "json_schema",
+                    "json_schema": {
+                        "name": "extraction",
+                        "strict": True,
+                        "schema": json_schema,
+                    },
+                }
+            else:
+                kwargs["response_format"] = {"type": "json_object"}
+
         resp = self.client.chat.completions.create(**kwargs)
 
         # Extract usage
@@ -111,11 +141,8 @@ class AzureDriver(Driver):
         completion_tokens = getattr(usage, "completion_tokens", 0)
         total_tokens = getattr(usage, "total_tokens", 0)
 
-        # Calculate cost
-        model_pricing = self.MODEL_PRICING.get(model, {"prompt": 0, "completion": 0})
-        prompt_cost = (prompt_tokens / 1000) * model_pricing["prompt"]
-        completion_cost = (completion_tokens / 1000) * model_pricing["completion"]
-        total_cost = prompt_cost + completion_cost
+        # Calculate cost via shared mixin
+        total_cost = self._calculate_cost("azure", model, prompt_tokens, completion_tokens)
 
         # Standardized meta object
         meta = {
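
Note: the JSON-mode block added to AzureDriver maps prompture's json_mode/json_schema options onto OpenAI's response_format parameter. A minimal sketch of that branching, with build_response_format as a name invented here purely for illustration:

    from typing import Any

    def build_response_format(options: dict[str, Any]) -> dict[str, Any] | None:
        """Mirrors the branching added in AzureDriver._do_generate (sketch only)."""
        if not options.get("json_mode"):
            return None
        json_schema = options.get("json_schema")
        if json_schema:
            # Schema present: request strict structured output
            return {
                "type": "json_schema",
                "json_schema": {"name": "extraction", "strict": True, "schema": json_schema},
            }
        # No schema: fall back to plain JSON-object mode
        return {"type": "json_object"}

    print(build_response_format({"json_mode": True, "json_schema": {"type": "object"}}))
    # {'type': 'json_schema', 'json_schema': {'name': 'extraction', 'strict': True, 'schema': {'type': 'object'}}}
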
prompture/drivers/claude_driver.py

@@ -1,75 +1,128 @@
 """Driver for Anthropic's Claude models. Requires the `anthropic` library.
 Use with API key in CLAUDE_API_KEY env var or provide directly.
 """
+
+import json
 import os
-from typing import Any, Dict
+from typing import Any
+
 try:
     import anthropic
 except Exception:
     anthropic = None
 
+from ..cost_mixin import CostMixin
 from ..driver import Driver
 
-class ClaudeDriver(Driver):
+
+class ClaudeDriver(CostMixin, Driver):
+    supports_json_mode = True
+    supports_json_schema = True
+
     # Claude pricing per 1000 tokens (prices should be kept current with Anthropic's pricing)
     MODEL_PRICING = {
         # Claude Opus 4.1
         "claude-opus-4-1-20250805": {
-            "prompt": 0.015, # $15 per 1M prompt tokens
-            "completion": 0.075, # $75 per 1M completion tokens
+            "prompt": 0.015,  # $15 per 1M prompt tokens
+            "completion": 0.075,  # $75 per 1M completion tokens
         },
         # Claude Opus 4.0
         "claude-opus-4-20250514": {
-            "prompt": 0.015, # $15 per 1M prompt tokens
-            "completion": 0.075, # $75 per 1M completion tokens
+            "prompt": 0.015,  # $15 per 1M prompt tokens
+            "completion": 0.075,  # $75 per 1M completion tokens
         },
         # Claude Sonnet 4.0
         "claude-sonnet-4-20250514": {
-            "prompt": 0.003, # $3 per 1M prompt tokens
-            "completion": 0.015, # $15 per 1M completion tokens
+            "prompt": 0.003,  # $3 per 1M prompt tokens
+            "completion": 0.015,  # $15 per 1M completion tokens
         },
         # Claude Sonnet 3.7
         "claude-3-7-sonnet-20250219": {
-            "prompt": 0.003, # $3 per 1M prompt tokens
-            "completion": 0.015, # $15 per 1M completion tokens
+            "prompt": 0.003,  # $3 per 1M prompt tokens
+            "completion": 0.015,  # $15 per 1M completion tokens
        },
         # Claude Haiku 3.5
         "claude-3-5-haiku-20241022": {
-            "prompt": 0.0008, # $0.80 per 1M prompt tokens
-            "completion": 0.004, # $4 per 1M completion tokens
-        }
+            "prompt": 0.0008,  # $0.80 per 1M prompt tokens
+            "completion": 0.004,  # $4 per 1M completion tokens
+        },
     }
 
     def __init__(self, api_key: str | None = None, model: str = "claude-3-5-haiku-20241022"):
         self.api_key = api_key or os.getenv("CLAUDE_API_KEY")
         self.model = model or os.getenv("CLAUDE_MODEL_NAME", "claude-3-5-haiku-20241022")
 
-    def generate(self, prompt: str, options: Dict[str,Any]) -> Dict[str,Any]:
+    supports_messages = True
+
+    def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
+        messages = [{"role": "user", "content": prompt}]
+        return self._do_generate(messages, options)
+
+    def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        return self._do_generate(messages, options)
+
+    def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
         if anthropic is None:
             raise RuntimeError("anthropic package not installed")
-
+
         opts = {**{"temperature": 0.0, "max_tokens": 512}, **options}
         model = options.get("model", self.model)
-
+
         client = anthropic.Anthropic(api_key=self.api_key)
-        resp = client.messages.create(
-            model=model,
-            messages=[{"role": "user", "content": prompt}],
-            temperature=opts["temperature"],
-            max_tokens=opts["max_tokens"]
-        )
-
+
+        # Anthropic requires system messages as a top-level parameter
+        system_content = None
+        api_messages = []
+        for msg in messages:
+            if msg.get("role") == "system":
+                system_content = msg.get("content", "")
+            else:
+                api_messages.append(msg)
+
+        # Build common kwargs
+        common_kwargs: dict[str, Any] = {
+            "model": model,
+            "messages": api_messages,
+            "temperature": opts["temperature"],
+            "max_tokens": opts["max_tokens"],
+        }
+        if system_content:
+            common_kwargs["system"] = system_content
+
+        # Native JSON mode: use tool-use for schema enforcement
+        if options.get("json_mode"):
+            json_schema = options.get("json_schema")
+            if json_schema:
+                tool_def = {
+                    "name": "extract_json",
+                    "description": "Extract structured data matching the schema",
+                    "input_schema": json_schema,
+                }
+                resp = client.messages.create(
+                    **common_kwargs,
+                    tools=[tool_def],
+                    tool_choice={"type": "tool", "name": "extract_json"},
+                )
+                text = ""
+                for block in resp.content:
+                    if block.type == "tool_use":
+                        text = json.dumps(block.input)
+                        break
+            else:
+                resp = client.messages.create(**common_kwargs)
+                text = resp.content[0].text
+        else:
+            resp = client.messages.create(**common_kwargs)
+            text = resp.content[0].text
+
         # Extract token usage from Claude response
         prompt_tokens = resp.usage.input_tokens
         completion_tokens = resp.usage.output_tokens
         total_tokens = prompt_tokens + completion_tokens
-
-        # Calculate cost based on model pricing
-        model_pricing = self.MODEL_PRICING.get(model, {"prompt": 0, "completion": 0})
-        prompt_cost = (prompt_tokens / 1000) * model_pricing["prompt"]
-        completion_cost = (completion_tokens / 1000) * model_pricing["completion"]
-        total_cost = prompt_cost + completion_cost
-
+
+        # Calculate cost via shared mixin
+        total_cost = self._calculate_cost("claude", model, prompt_tokens, completion_tokens)
+
         # Create standardized meta object
         meta = {
             "prompt_tokens": prompt_tokens,
@@ -77,8 +130,7 @@ class ClaudeDriver(Driver):
             "total_tokens": total_tokens,
             "cost": round(total_cost, 6),  # Round to 6 decimal places
             "raw_response": dict(resp),
-            "model_name": model
+            "model_name": model,
         }
-
-        text = resp.content[0].text
-        return {"text": text, "meta": meta}
+
+        return {"text": text, "meta": meta}
prompture/drivers/google_driver.py

@@ -1,60 +1,55 @@
-import os
 import logging
-import google.generativeai as genai
-from typing import Any, Dict
-from ..driver import Driver
-
 import os
-import logging
+from typing import Any, Optional
+
 import google.generativeai as genai
-from typing import Any, Dict
+
+from ..cost_mixin import CostMixin
 from ..driver import Driver
 
 logger = logging.getLogger(__name__)
 
 
-class GoogleDriver(Driver):
+class GoogleDriver(CostMixin, Driver):
     """Driver for Google's Generative AI API (Gemini)."""
 
+    supports_json_mode = True
+    supports_json_schema = True
+
     # Based on current Gemini pricing (as of 2025)
     # Source: https://cloud.google.com/vertex-ai/pricing#gemini_models
+    _PRICING_UNIT = 1_000_000
     MODEL_PRICING = {
         "gemini-1.5-pro": {
             "prompt": 0.00025,  # $0.25/1M chars input
-            "completion": 0.0005 # $0.50/1M chars output
+            "completion": 0.0005,  # $0.50/1M chars output
         },
         "gemini-1.5-pro-vision": {
             "prompt": 0.00025,  # $0.25/1M chars input
-            "completion": 0.0005 # $0.50/1M chars output
+            "completion": 0.0005,  # $0.50/1M chars output
         },
         "gemini-2.5-pro": {
             "prompt": 0.0004,  # $0.40/1M chars input
-            "completion": 0.0008 # $0.80/1M chars output
+            "completion": 0.0008,  # $0.80/1M chars output
         },
         "gemini-2.5-flash": {
             "prompt": 0.0004,  # $0.40/1M chars input
-            "completion": 0.0008 # $0.80/1M chars output
+            "completion": 0.0008,  # $0.80/1M chars output
         },
         "gemini-2.5-flash-lite": {
             "prompt": 0.0002,  # $0.20/1M chars input
-            "completion": 0.0004 # $0.40/1M chars output
+            "completion": 0.0004,  # $0.40/1M chars output
         },
-        "gemini-2.0-flash": {
+        "gemini-2.0-flash": {
             "prompt": 0.0004,  # $0.40/1M chars input
-            "completion": 0.0008 # $0.80/1M chars output
+            "completion": 0.0008,  # $0.80/1M chars output
         },
         "gemini-2.0-flash-lite": {
             "prompt": 0.0002,  # $0.20/1M chars input
-            "completion": 0.0004 # $0.40/1M chars output
-        },
-        "gemini-1.5-flash": {
-            "prompt": 0.00001875,
-            "completion": 0.000075
+            "completion": 0.0004,  # $0.40/1M chars output
         },
-        "gemini-1.5-flash-8b": {
-            "prompt": 0.00001,
-            "completion": 0.00004
-        }
+        "gemini-1.5-flash": {"prompt": 0.00001875, "completion": 0.000075},
+        "gemini-1.5-flash-8b": {"prompt": 0.00001, "completion": 0.00004},
     }
 
     def __init__(self, api_key: str | None = None, model: str = "gemini-1.5-pro"):
@@ -75,8 +70,8 @@ class GoogleDriver(Driver):
 
         # Configure google.generativeai
         genai.configure(api_key=self.api_key)
-        self.options: Dict[str, Any] = {}
-
+        self.options: dict[str, Any] = {}
+
         # Validate connection and model availability
         self._validate_connection()
 
@@ -90,16 +85,36 @@ class GoogleDriver(Driver):
             logger.warning(f"Could not validate connection to Google API: {e}")
             raise
 
-    def generate(self, prompt: str, options: Dict[str, Any] = None) -> Dict[str, Any]:
-        """Generate text using Google's Generative AI.
+    def _calculate_cost_chars(self, prompt_chars: int, completion_chars: int) -> float:
+        """Calculate cost from character counts.
 
-        Args:
-            prompt: The input prompt
-            options: Additional options to pass to the model
-
-        Returns:
-            Dict containing generated text and metadata
+        Live rates use token-based pricing (estimate ~4 chars/token).
+        Hardcoded MODEL_PRICING uses per-1M-character rates.
         """
+        from ..model_rates import get_model_rates
+
+        live_rates = get_model_rates("google", self.model)
+        if live_rates:
+            est_prompt_tokens = prompt_chars / 4
+            est_completion_tokens = completion_chars / 4
+            prompt_cost = (est_prompt_tokens / 1_000_000) * live_rates["input"]
+            completion_cost = (est_completion_tokens / 1_000_000) * live_rates["output"]
+        else:
+            model_pricing = self.MODEL_PRICING.get(self.model, {"prompt": 0, "completion": 0})
+            prompt_cost = (prompt_chars / 1_000_000) * model_pricing["prompt"]
+            completion_cost = (completion_chars / 1_000_000) * model_pricing["completion"]
+        return round(prompt_cost + completion_cost, 6)
+
+    supports_messages = True
+
+    def generate(self, prompt: str, options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
+        messages = [{"role": "user", "content": prompt}]
+        return self._do_generate(messages, options)
+
+    def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        return self._do_generate(messages, options)
+
+    def _do_generate(self, messages: list[dict[str, str]], options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
         merged_options = self.options.copy()
         if options:
             merged_options.update(options)
@@ -107,7 +122,7 @@ class GoogleDriver(Driver):
         # Extract specific options for Google's API
         generation_config = merged_options.get("generation_config", {})
         safety_settings = merged_options.get("safety_settings", {})
-
+
         # Map common options to generation_config if not present
         if "temperature" in merged_options and "temperature" not in generation_config:
             generation_config["temperature"] = merged_options["temperature"]
@@ -118,36 +133,57 @@ class GoogleDriver(Driver):
         if "top_k" in merged_options and "top_k" not in generation_config:
             generation_config["top_k"] = merged_options["top_k"]
 
+        # Native JSON mode support
+        if merged_options.get("json_mode"):
+            generation_config["response_mime_type"] = "application/json"
+            json_schema = merged_options.get("json_schema")
+            if json_schema:
+                generation_config["response_schema"] = json_schema
+
+        # Convert messages to Gemini format
+        system_instruction = None
+        contents: list[dict[str, Any]] = []
+        for msg in messages:
+            role = msg.get("role", "user")
+            content = msg.get("content", "")
+            if role == "system":
+                system_instruction = content
+            else:
+                # Gemini uses "model" for assistant role
+                gemini_role = "model" if role == "assistant" else "user"
+                contents.append({"role": gemini_role, "parts": [content]})
+
         try:
             logger.debug(f"Initializing {self.model} for generation")
-            model = genai.GenerativeModel(self.model)
+            model_kwargs: dict[str, Any] = {}
+            if system_instruction:
+                model_kwargs["system_instruction"] = system_instruction
+            model = genai.GenerativeModel(self.model, **model_kwargs)
 
             # Generate response
-            logger.debug(f"Generating with prompt: {prompt}")
+            logger.debug(f"Generating with {len(contents)} content parts")
+            # If single user message, pass content directly for backward compatibility
+            gen_input: Any = contents if len(contents) != 1 else contents[0]["parts"][0]
             response = model.generate_content(
-                prompt,
+                gen_input,
                 generation_config=generation_config if generation_config else None,
-                safety_settings=safety_settings if safety_settings else None
+                safety_settings=safety_settings if safety_settings else None,
             )
-
+
             if not response.text:
                 raise ValueError("Empty response from model")
 
             # Calculate token usage and cost
-            # Note: Using character count as proxy since Google charges per character
-            prompt_chars = len(prompt)
+            total_prompt_chars = sum(len(msg.get("content", "")) for msg in messages)
             completion_chars = len(response.text)
-
-            # Calculate costs
-            model_pricing = self.MODEL_PRICING.get(self.model, {"prompt": 0, "completion": 0})
-            prompt_cost = (prompt_chars / 1_000_000) * model_pricing["prompt"]
-            completion_cost = (completion_chars / 1_000_000) * model_pricing["completion"]
-            total_cost = prompt_cost + completion_cost
+
+            # Google uses character-based cost estimation
+            total_cost = self._calculate_cost_chars(total_prompt_chars, completion_chars)
 
             meta = {
-                "prompt_chars": prompt_chars,
+                "prompt_chars": total_prompt_chars,
                 "completion_chars": completion_chars,
-                "total_chars": prompt_chars + completion_chars,
+                "total_chars": total_prompt_chars + completion_chars,
                 "cost": total_cost,
                 "raw_response": response.prompt_feedback if hasattr(response, "prompt_feedback") else None,
                 "model_name": self.model,
@@ -157,4 +193,4 @@ class GoogleDriver(Driver):
 
         except Exception as e:
             logger.error(f"Google API request failed: {e}")
-            raise RuntimeError(f"Google API request failed: {e}")
+            raise RuntimeError(f"Google API request failed: {e}") from e
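
Note: the message-conversion loop added to GoogleDriver._do_generate is what makes supports_messages work for Gemini: a system message becomes the model's system_instruction, and the assistant role maps to Gemini's "model" role. A standalone sketch of the same logic (to_gemini is a name invented here for illustration):

    from typing import Any

    def to_gemini(messages: list[dict[str, str]]) -> tuple[str | None, list[dict[str, Any]]]:
        """Sketch of the conversion in GoogleDriver._do_generate."""
        system_instruction = None
        contents: list[dict[str, Any]] = []
        for msg in messages:
            role = msg.get("role", "user")
            if role == "system":
                system_instruction = msg.get("content", "")
            else:
                gemini_role = "model" if role == "assistant" else "user"
                contents.append({"role": gemini_role, "parts": [msg.get("content", "")]})
        return system_instruction, contents

    system, contents = to_gemini([
        {"role": "system", "content": "Answer tersely."},
        {"role": "user", "content": "Hi"},
        {"role": "assistant", "content": "Hello!"},
    ])
    print(system)    # Answer tersely.
    print(contents)  # [{'role': 'user', 'parts': ['Hi']}, {'role': 'model', 'parts': ['Hello!']}]
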
prompture/drivers/grok_driver.py

@@ -1,15 +1,21 @@
 """xAI Grok driver.
 Requires the `requests` package. Uses GROK_API_KEY env var.
 """
+
 import os
-from typing import Any, Dict
+from typing import Any
+
 import requests
 
+from ..cost_mixin import CostMixin
 from ..driver import Driver
 
 
-class GrokDriver(Driver):
+class GrokDriver(CostMixin, Driver):
+    supports_json_mode = True
+
     # Pricing per 1M tokens based on xAI's documentation
+    _PRICING_UNIT = 1_000_000
     MODEL_PRICING = {
         "grok-code-fast-1": {
             "prompt": 0.20,
@@ -72,19 +78,16 @@ class GrokDriver(Driver):
         self.model = model
         self.api_base = "https://api.x.ai/v1"
 
-    def generate(self, prompt: str, options: Dict[str, Any]) -> Dict[str, Any]:
-        """Generate completion using Grok API.
+    supports_messages = True
 
-        Args:
-            prompt: Input prompt
-            options: Generation options
-
-        Returns:
-            Dict containing generated text and metadata
-
-        Raises:
-            RuntimeError: If API key is missing or request fails
-        """
+    def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
+        messages = [{"role": "user", "content": prompt}]
+        return self._do_generate(messages, options)
+
+    def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        return self._do_generate(messages, options)
+
+    def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
        if not self.api_key:
             raise RuntimeError("GROK_API_KEY environment variable is required")
 
@@ -101,7 +104,7 @@ class GrokDriver(Driver):
         # Base request payload
         payload = {
             "model": model,
-            "messages": [{"role": "user", "content": prompt}],
+            "messages": messages,
         }
 
         # Add token limit with correct parameter name
@@ -111,33 +114,27 @@ class GrokDriver(Driver):
         if supports_temperature and "temperature" in opts:
             payload["temperature"] = opts["temperature"]
 
-        headers = {
-            "Authorization": f"Bearer {self.api_key}",
-            "Content-Type": "application/json"
-        }
+        # Native JSON mode support
+        if options.get("json_mode"):
+            payload["response_format"] = {"type": "json_object"}
+
+        headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
 
         try:
-            response = requests.post(
-                f"{self.api_base}/chat/completions",
-                headers=headers,
-                json=payload
-            )
+            response = requests.post(f"{self.api_base}/chat/completions", headers=headers, json=payload)
             response.raise_for_status()
             resp = response.json()
         except requests.exceptions.RequestException as e:
-            raise RuntimeError(f"Grok API request failed: {str(e)}")
+            raise RuntimeError(f"Grok API request failed: {e!s}") from e
 
         # Extract usage info
         usage = resp.get("usage", {})
         prompt_tokens = usage.get("prompt_tokens", 0)
-        completion_tokens = usage.get("completion_tokens", 0)
+        completion_tokens = usage.get("completion_tokens", 0)
         total_tokens = usage.get("total_tokens", 0)
 
-        # Calculate cost
-        model_pricing = self.MODEL_PRICING.get(model, {"prompt": 0, "completion": 0})
-        prompt_cost = (prompt_tokens / 1000000) * model_pricing["prompt"]
-        completion_cost = (completion_tokens / 1000000) * model_pricing["completion"]
-        total_cost = prompt_cost + completion_cost
+        # Calculate cost via shared mixin
+        total_cost = self._calculate_cost("grok", model, prompt_tokens, completion_tokens)
 
         # Standardized meta object
         meta = {
@@ -150,4 +147,4 @@ class GrokDriver(Driver):
         }
 
         text = resp["choices"][0]["message"]["content"]
-        return {"text": text, "meta": meta}
+        return {"text": text, "meta": meta}