prompture 0.0.33.dev2__py3-none-any.whl → 0.0.34.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. prompture/__init__.py +112 -54
  2. prompture/_version.py +34 -0
  3. prompture/aio/__init__.py +74 -0
  4. prompture/async_conversation.py +484 -0
  5. prompture/async_core.py +803 -0
  6. prompture/async_driver.py +131 -0
  7. prompture/cache.py +469 -0
  8. prompture/callbacks.py +50 -0
  9. prompture/cli.py +7 -3
  10. prompture/conversation.py +504 -0
  11. prompture/core.py +475 -352
  12. prompture/cost_mixin.py +51 -0
  13. prompture/discovery.py +41 -36
  14. prompture/driver.py +125 -5
  15. prompture/drivers/__init__.py +63 -57
  16. prompture/drivers/airllm_driver.py +13 -20
  17. prompture/drivers/async_airllm_driver.py +26 -0
  18. prompture/drivers/async_azure_driver.py +117 -0
  19. prompture/drivers/async_claude_driver.py +107 -0
  20. prompture/drivers/async_google_driver.py +132 -0
  21. prompture/drivers/async_grok_driver.py +91 -0
  22. prompture/drivers/async_groq_driver.py +84 -0
  23. prompture/drivers/async_hugging_driver.py +61 -0
  24. prompture/drivers/async_lmstudio_driver.py +79 -0
  25. prompture/drivers/async_local_http_driver.py +44 -0
  26. prompture/drivers/async_ollama_driver.py +125 -0
  27. prompture/drivers/async_openai_driver.py +96 -0
  28. prompture/drivers/async_openrouter_driver.py +96 -0
  29. prompture/drivers/async_registry.py +80 -0
  30. prompture/drivers/azure_driver.py +36 -15
  31. prompture/drivers/claude_driver.py +86 -40
  32. prompture/drivers/google_driver.py +86 -58
  33. prompture/drivers/grok_driver.py +29 -38
  34. prompture/drivers/groq_driver.py +27 -32
  35. prompture/drivers/hugging_driver.py +6 -6
  36. prompture/drivers/lmstudio_driver.py +26 -13
  37. prompture/drivers/local_http_driver.py +6 -6
  38. prompture/drivers/ollama_driver.py +90 -23
  39. prompture/drivers/openai_driver.py +36 -15
  40. prompture/drivers/openrouter_driver.py +31 -31
  41. prompture/field_definitions.py +106 -96
  42. prompture/logging.py +80 -0
  43. prompture/model_rates.py +16 -15
  44. prompture/runner.py +49 -47
  45. prompture/session.py +117 -0
  46. prompture/settings.py +11 -1
  47. prompture/tools.py +172 -265
  48. prompture/validator.py +3 -3
  49. {prompture-0.0.33.dev2.dist-info → prompture-0.0.34.dev1.dist-info}/METADATA +18 -20
  50. prompture-0.0.34.dev1.dist-info/RECORD +54 -0
  51. prompture-0.0.33.dev2.dist-info/RECORD +0 -30
  52. {prompture-0.0.33.dev2.dist-info → prompture-0.0.34.dev1.dist-info}/WHEEL +0 -0
  53. {prompture-0.0.33.dev2.dist-info → prompture-0.0.34.dev1.dist-info}/entry_points.txt +0 -0
  54. {prompture-0.0.33.dev2.dist-info → prompture-0.0.34.dev1.dist-info}/licenses/LICENSE +0 -0
  55. {prompture-0.0.33.dev2.dist-info → prompture-0.0.34.dev1.dist-info}/top_level.txt +0 -0
prompture/drivers/hugging_driver.py
@@ -1,14 +1,14 @@
 import os
+from typing import Any
+
 import requests
+
 from ..driver import Driver
-from typing import Any, Dict
 
 
 class HuggingFaceDriver(Driver):
     # Hugging Face is usage-based (credits/subscription), but we set costs to 0 for now.
-    MODEL_PRICING = {
-        "default": {"prompt": 0.0, "completion": 0.0}
-    }
+    MODEL_PRICING = {"default": {"prompt": 0.0, "completion": 0.0}}
 
     def __init__(self, endpoint: str | None = None, token: str | None = None, model: str = "bert-base-uncased"):
         self.endpoint = endpoint or os.getenv("HF_ENDPOINT")
@@ -22,7 +22,7 @@ class HuggingFaceDriver(Driver):
 
         self.headers = {"Authorization": f"Bearer {self.token}"}
 
-    def generate(self, prompt: str, options: Dict[str, Any]) -> Dict[str, Any]:
+    def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
         payload = {
             "inputs": prompt,
             "parameters": options,  # HF allows temperature, max_new_tokens, etc. here
@@ -32,7 +32,7 @@ class HuggingFaceDriver(Driver):
             r.raise_for_status()
             response_data = r.json()
         except Exception as e:
-            raise RuntimeError(f"HuggingFaceDriver request failed: {e}")
+            raise RuntimeError(f"HuggingFaceDriver request failed: {e}") from e
 
         # Different HF models return slightly different response formats
         # Text-generation models usually return [{"generated_text": "..."}]
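Note: the `raise ... from e` change above recurs in every driver in this release. It chains the original exception as `__cause__`, so tracebacks keep the underlying requests error visible instead of hiding it. A standalone illustration, not part of the package:

import requests

def fetch(url: str) -> dict:
    try:
        r = requests.get(url, timeout=5)
        r.raise_for_status()
        return r.json()
    except Exception as e:
        # With "from e", the traceback reads "The above exception was the direct
        # cause of the following exception", preserving the root error.
        raise RuntimeError(f"request failed: {e}") from e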

prompture/drivers/lmstudio_driver.py
@@ -1,26 +1,26 @@
-import os
 import json
-import requests
 import logging
+import os
+from typing import Any, Optional
+
+import requests
+
 from ..driver import Driver
-from typing import Any, Dict
 
 logger = logging.getLogger(__name__)
 
 
 class LMStudioDriver(Driver):
+    supports_json_mode = True
+
     # LM Studio is local – costs are always zero.
-    MODEL_PRICING = {
-        "default": {"prompt": 0.0, "completion": 0.0}
-    }
+    MODEL_PRICING = {"default": {"prompt": 0.0, "completion": 0.0}}
 
     def __init__(self, endpoint: str | None = None, model: str = "deepseek/deepseek-r1-0528-qwen3-8b"):
         # Allow override via env var
-        self.endpoint = endpoint or os.getenv(
-            "LMSTUDIO_ENDPOINT", "http://127.0.0.1:1234/v1/chat/completions"
-        )
+        self.endpoint = endpoint or os.getenv("LMSTUDIO_ENDPOINT", "http://127.0.0.1:1234/v1/chat/completions")
         self.model = model
-        self.options: Dict[str, Any] = {}
+        self.options: dict[str, Any] = {}
 
         # Validate connection to LM Studio server
         self._validate_connection()
@@ -38,17 +38,30 @@ class LMStudioDriver(Driver):
         except requests.exceptions.RequestException as e:
             logger.warning(f"Could not validate connection to LM Studio server: {e}")
 
-    def generate(self, prompt: str, options: Dict[str, Any] = None) -> Dict[str, Any]:
+    supports_messages = True
+
+    def generate(self, prompt: str, options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
+        messages = [{"role": "user", "content": prompt}]
+        return self._do_generate(messages, options)
+
+    def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        return self._do_generate(messages, options)
+
+    def _do_generate(self, messages: list[dict[str, str]], options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
         merged_options = self.options.copy()
         if options:
             merged_options.update(options)
 
         payload = {
             "model": merged_options.get("model", self.model),
-            "messages": [{"role": "user", "content": prompt}],
+            "messages": messages,
             "temperature": merged_options.get("temperature", 0.7),
         }
 
+        # Native JSON mode support
+        if merged_options.get("json_mode"):
+            payload["response_format"] = {"type": "json_object"}
+
         try:
             logger.debug(f"Sending request to LM Studio endpoint: {self.endpoint}")
             logger.debug(f"Request payload: {payload}")
@@ -70,7 +83,7 @@ class LMStudioDriver(Driver):
             raise
         except Exception as e:
             logger.error(f"Unexpected error in LM Studio request: {e}")
-            raise RuntimeError(f"LM Studio request failed: {e}")
+            raise RuntimeError(f"LM Studio request failed: {e}") from e
 
         # Extract text
         text = response_data["choices"][0]["message"]["content"]
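Note: a hedged usage sketch of the LM Studio changes above. The import path and option keys come from this diff; the return shape is assumed to match the other drivers' {"text", "meta"} contract:

from prompture.drivers.lmstudio_driver import LMStudioDriver

driver = LMStudioDriver()  # defaults to http://127.0.0.1:1234/v1/chat/completions

# Single prompt: generate() now wraps the prompt into a one-message list.
single = driver.generate("Return a JSON object with a 'city' key.", {"json_mode": True})

# Multi-turn: generate_messages() passes the list through and, with json_mode set,
# adds response_format={"type": "json_object"} to the request payload.
messages = [
    {"role": "system", "content": "You extract structured data."},
    {"role": "user", "content": "Extract the city from: 'I live in Lisbon.'"},
]
multi = driver.generate_messages(messages, {"json_mode": True})
print(single["text"], multi["text"])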

prompture/drivers/local_http_driver.py
@@ -1,27 +1,27 @@
 import os
+from typing import Any
+
 import requests
+
 from ..driver import Driver
-from typing import Any, Dict
 
 
 class LocalHTTPDriver(Driver):
     # Default: no cost; extend if your local service has pricing logic
-    MODEL_PRICING = {
-        "default": {"prompt": 0.0, "completion": 0.0}
-    }
+    MODEL_PRICING = {"default": {"prompt": 0.0, "completion": 0.0}}
 
     def __init__(self, endpoint: str | None = None, model: str = "local-model"):
         self.endpoint = endpoint or os.getenv("LOCAL_HTTP_ENDPOINT", "http://localhost:8000/generate")
         self.model = model
 
-    def generate(self, prompt: str, options: Dict[str, Any]) -> Dict[str, Any]:
+    def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
         payload = {"prompt": prompt, "options": options}
         try:
             r = requests.post(self.endpoint, json=payload, timeout=options.get("timeout", 30))
             r.raise_for_status()
             response_data = r.json()
         except Exception as e:
-            raise RuntimeError(f"LocalHTTPDriver request failed: {e}")
+            raise RuntimeError(f"LocalHTTPDriver request failed: {e}") from e
 
         # If the local API already provides {"text": "...", "meta": {...}}, just return it
         if "text" in response_data and "meta" in response_data:

prompture/drivers/ollama_driver.py
@@ -1,38 +1,38 @@
-import os
 import json
-import requests
 import logging
+import os
+from typing import Any, Optional
+
+import requests
+
 from ..driver import Driver
-from typing import Any, Dict
 
 logger = logging.getLogger(__name__)
 
 
 class OllamaDriver(Driver):
+    supports_json_mode = True
+
     # Ollama is free – costs are always zero.
-    MODEL_PRICING = {
-        "default": {"prompt": 0.0, "completion": 0.0}
-    }
+    MODEL_PRICING = {"default": {"prompt": 0.0, "completion": 0.0}}
 
     def __init__(self, endpoint: str | None = None, model: str = "llama3"):
         # Allow override via env var
-        self.endpoint = endpoint or os.getenv(
-            "OLLAMA_ENDPOINT", "http://localhost:11434/api/generate"
-        )
+        self.endpoint = endpoint or os.getenv("OLLAMA_ENDPOINT", "http://localhost:11434/api/generate")
         self.model = model
         self.options = {}  # Initialize empty options dict
-
+
         # Validate connection to Ollama server
         self._validate_connection()
-
+
     def _validate_connection(self):
         """Validate connection to the Ollama server."""
         try:
             # Send a simple HEAD request to check if server is accessible
             # Use the base API endpoint without the specific path
-            base_url = self.endpoint.split('/api/')[0]
+            base_url = self.endpoint.split("/api/")[0]
             health_url = f"{base_url}/api/version"
-
+
             logger.debug(f"Validating connection to Ollama server at: {health_url}")
             response = requests.head(health_url, timeout=5)
             response.raise_for_status()
@@ -42,7 +42,9 @@ class OllamaDriver(Driver):
         # We don't raise an error here to allow for delayed server startup
         # The actual error will be raised when generate() is called
 
-    def generate(self, prompt: str, options: Dict[str, Any] = None) -> Dict[str, Any]:
+    supports_messages = True
+
+    def generate(self, prompt: str, options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
         # Merge instance options with call-specific options
         merged_options = self.options.copy()
         if options:
@@ -54,6 +56,10 @@
             "stream": False,
         }
 
+        # Native JSON mode support
+        if merged_options.get("json_mode"):
+            payload["format"] = "json"
+
         # Add any Ollama-specific options from merged_options
         if "temperature" in merged_options:
             payload["temperature"] = merged_options["temperature"]
@@ -65,21 +71,21 @@
         try:
             logger.debug(f"Sending request to Ollama endpoint: {self.endpoint}")
             logger.debug(f"Request payload: {payload}")
-
+
             r = requests.post(self.endpoint, json=payload, timeout=120)
             logger.debug(f"Response status code: {r.status_code}")
-
+
             r.raise_for_status()
-
+
             response_text = r.text
             logger.debug(f"Raw response text: {response_text}")
-
+
             response_data = r.json()
             logger.debug(f"Parsed response data: {response_data}")
-
+
             if not isinstance(response_data, dict):
                 raise ValueError(f"Expected dict response, got {type(response_data)}")
-
+
         except requests.exceptions.ConnectionError as e:
             logger.error(f"Connection error to Ollama endpoint: {e}")
             # Preserve original exception
@@ -91,11 +97,11 @@
         except json.JSONDecodeError as e:
             logger.error(f"Failed to decode JSON response: {e}")
             # Re-raise JSONDecodeError with more context
-            raise json.JSONDecodeError(f"Invalid JSON response from Ollama: {e.msg}", e.doc, e.pos)
+            raise json.JSONDecodeError(f"Invalid JSON response from Ollama: {e.msg}", e.doc, e.pos) from e
         except Exception as e:
             logger.error(f"Unexpected error in Ollama request: {e}")
             # Only wrap unknown exceptions in RuntimeError
-            raise RuntimeError(f"Ollama request failed: {e}")
+            raise RuntimeError(f"Ollama request failed: {e}") from e
 
         # Extract token counts
         prompt_tokens = response_data.get("prompt_eval_count", 0)
@@ -113,4 +119,65 @@
         }
 
         # Ollama returns text in "response"
-        return {"text": response_data.get("response", ""), "meta": meta}
+        return {"text": response_data.get("response", ""), "meta": meta}
+
+    def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        """Use Ollama's /api/chat endpoint for multi-turn conversations."""
+        merged_options = self.options.copy()
+        if options:
+            merged_options.update(options)
+
+        # Derive the chat endpoint from the generate endpoint
+        chat_endpoint = self.endpoint.replace("/api/generate", "/api/chat")
+
+        payload: dict[str, Any] = {
+            "model": merged_options.get("model", self.model),
+            "messages": messages,
+            "stream": False,
+        }
+
+        # Native JSON mode support
+        if merged_options.get("json_mode"):
+            payload["format"] = "json"
+
+        if "temperature" in merged_options:
+            payload["temperature"] = merged_options["temperature"]
+        if "top_p" in merged_options:
+            payload["top_p"] = merged_options["top_p"]
+        if "top_k" in merged_options:
+            payload["top_k"] = merged_options["top_k"]
+
+        try:
+            logger.debug(f"Sending chat request to Ollama endpoint: {chat_endpoint}")
+            r = requests.post(chat_endpoint, json=payload, timeout=120)
+            r.raise_for_status()
+            response_data = r.json()
+
+            if not isinstance(response_data, dict):
+                raise ValueError(f"Expected dict response, got {type(response_data)}")
+        except requests.exceptions.ConnectionError:
+            raise
+        except requests.exceptions.HTTPError:
+            raise
+        except json.JSONDecodeError as e:
+            raise json.JSONDecodeError(f"Invalid JSON response from Ollama: {e.msg}", e.doc, e.pos) from e
+        except Exception as e:
+            raise RuntimeError(f"Ollama chat request failed: {e}") from e
+
+        prompt_tokens = response_data.get("prompt_eval_count", 0)
+        completion_tokens = response_data.get("eval_count", 0)
+        total_tokens = prompt_tokens + completion_tokens
+
+        meta = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": 0.0,
+            "raw_response": response_data,
+            "model_name": merged_options.get("model", self.model),
+        }
+
+        # Chat endpoint returns response in message.content
+        message = response_data.get("message", {})
+        text = message.get("content", "")
+        return {"text": text, "meta": meta}
@@ -1,17 +1,23 @@
1
1
  """Minimal OpenAI driver (migrated to openai>=1.0.0).
2
2
  Requires the `openai` package. Uses OPENAI_API_KEY env var.
3
3
  """
4
+
4
5
  import os
5
- from typing import Any, Dict
6
+ from typing import Any
7
+
6
8
  try:
7
9
  from openai import OpenAI
8
10
  except Exception:
9
11
  OpenAI = None
10
12
 
13
+ from ..cost_mixin import CostMixin
11
14
  from ..driver import Driver
12
15
 
13
16
 
14
- class OpenAIDriver(Driver):
17
+ class OpenAIDriver(CostMixin, Driver):
18
+ supports_json_mode = True
19
+ supports_json_schema = True
20
+
15
21
  # Approximate pricing per 1K tokens (keep updated with OpenAI's official pricing)
16
22
  # Each model entry also defines which token parameter it supports and
17
23
  # whether it accepts temperature.
@@ -62,7 +68,16 @@ class OpenAIDriver(Driver):
62
68
  else:
63
69
  self.client = None
64
70
 
65
- def generate(self, prompt: str, options: Dict[str, Any]) -> Dict[str, Any]:
71
+ supports_messages = True
72
+
73
+ def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
74
+ messages = [{"role": "user", "content": prompt}]
75
+ return self._do_generate(messages, options)
76
+
77
+ def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
78
+ return self._do_generate(messages, options)
79
+
80
+ def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
66
81
  if self.client is None:
67
82
  raise RuntimeError("openai package (>=1.0.0) is not installed")
68
83
 
@@ -79,7 +94,7 @@ class OpenAIDriver(Driver):
79
94
  # Base kwargs
80
95
  kwargs = {
81
96
  "model": model,
82
- "messages": [{"role": "user", "content": prompt}],
97
+ "messages": messages,
83
98
  }
84
99
 
85
100
  # Assign token limit with the correct parameter name
@@ -89,6 +104,21 @@ class OpenAIDriver(Driver):
89
104
  if supports_temperature and "temperature" in opts:
90
105
  kwargs["temperature"] = opts["temperature"]
91
106
 
107
+ # Native JSON mode support
108
+ if options.get("json_mode"):
109
+ json_schema = options.get("json_schema")
110
+ if json_schema:
111
+ kwargs["response_format"] = {
112
+ "type": "json_schema",
113
+ "json_schema": {
114
+ "name": "extraction",
115
+ "strict": True,
116
+ "schema": json_schema,
117
+ },
118
+ }
119
+ else:
120
+ kwargs["response_format"] = {"type": "json_object"}
121
+
92
122
  resp = self.client.chat.completions.create(**kwargs)
93
123
 
94
124
  # Extract usage info
@@ -97,17 +127,8 @@ class OpenAIDriver(Driver):
97
127
  completion_tokens = getattr(usage, "completion_tokens", 0)
98
128
  total_tokens = getattr(usage, "total_tokens", 0)
99
129
 
100
- # Calculate cost try live rates first (per 1M tokens), fall back to hardcoded (per 1K tokens)
101
- from ..model_rates import get_model_rates
102
- live_rates = get_model_rates("openai", model)
103
- if live_rates:
104
- prompt_cost = (prompt_tokens / 1_000_000) * live_rates["input"]
105
- completion_cost = (completion_tokens / 1_000_000) * live_rates["output"]
106
- else:
107
- model_pricing = self.MODEL_PRICING.get(model, {"prompt": 0, "completion": 0})
108
- prompt_cost = (prompt_tokens / 1000) * model_pricing["prompt"]
109
- completion_cost = (completion_tokens / 1000) * model_pricing["completion"]
110
- total_cost = prompt_cost + completion_cost
130
+ # Calculate cost via shared mixin
131
+ total_cost = self._calculate_cost("openai", model, prompt_tokens, completion_tokens)
111
132
 
112
133
  # Standardized meta object
113
134
  meta = {
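Note: the inline pricing code removed above moves into the new prompture/cost_mixin.py, whose body is not shown in this section. A plausible sketch of _calculate_cost, reconstructed from the deleted lines (live rates quoted per 1M tokens, the driver's hardcoded MODEL_PRICING table per 1K tokens); the actual mixin may differ:

from typing import Any


class CostMixin:
    """Sketch only; see prompture/cost_mixin.py for the real implementation."""

    MODEL_PRICING: dict[str, dict[str, Any]] = {}

    def _calculate_cost(self, provider: str, model: str, prompt_tokens: int, completion_tokens: int) -> float:
        from .model_rates import get_model_rates  # live rates, per 1M tokens

        live_rates = get_model_rates(provider, model)
        if live_rates:
            prompt_cost = (prompt_tokens / 1_000_000) * live_rates["input"]
            completion_cost = (completion_tokens / 1_000_000) * live_rates["output"]
        else:
            # Fall back to the driver's hardcoded table, quoted per 1K tokens.
            pricing = self.MODEL_PRICING.get(model, {"prompt": 0, "completion": 0})
            prompt_cost = (prompt_tokens / 1000) * pricing["prompt"]
            completion_cost = (completion_tokens / 1000) * pricing["completion"]
        return prompt_cost + completion_cost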
@@ -1,14 +1,19 @@
1
1
  """OpenRouter driver implementation.
2
2
  Requires the `requests` package. Uses OPENROUTER_API_KEY env var.
3
3
  """
4
+
4
5
  import os
5
- from typing import Any, Dict
6
+ from typing import Any
7
+
6
8
  import requests
7
9
 
10
+ from ..cost_mixin import CostMixin
8
11
  from ..driver import Driver
9
12
 
10
13
 
11
- class OpenRouterDriver(Driver):
14
+ class OpenRouterDriver(CostMixin, Driver):
15
+ supports_json_mode = True
16
+
12
17
  # Approximate pricing per 1K tokens based on OpenRouter's pricing
13
18
  # https://openrouter.ai/docs#pricing
14
19
  MODEL_PRICING = {
@@ -40,7 +45,7 @@ class OpenRouterDriver(Driver):
40
45
 
41
46
  def __init__(self, api_key: str | None = None, model: str = "openai/gpt-3.5-turbo"):
42
47
  """Initialize OpenRouter driver.
43
-
48
+
44
49
  Args:
45
50
  api_key: OpenRouter API key. If not provided, will look for OPENROUTER_API_KEY env var
46
51
  model: Model to use. Defaults to openai/gpt-3.5-turbo
@@ -48,10 +53,10 @@ class OpenRouterDriver(Driver):
48
53
  self.api_key = api_key or os.getenv("OPENROUTER_API_KEY")
49
54
  if not self.api_key:
50
55
  raise ValueError("OpenRouter API key not found. Set OPENROUTER_API_KEY env var.")
51
-
56
+
52
57
  self.model = model
53
58
  self.base_url = "https://openrouter.ai/api/v1"
54
-
59
+
55
60
  # Required headers for OpenRouter
56
61
  self.headers = {
57
62
  "Authorization": f"Bearer {self.api_key}",
@@ -59,21 +64,21 @@ class OpenRouterDriver(Driver):
59
64
  "Content-Type": "application/json",
60
65
  }
61
66
 
62
- def generate(self, prompt: str, options: Dict[str, Any]) -> Dict[str, Any]:
63
- """Generate completion using OpenRouter API.
64
-
65
- Args:
66
- prompt: The prompt text
67
- options: Generation options
68
-
69
- Returns:
70
- Dict containing generated text and metadata
71
- """
67
+ supports_messages = True
68
+
69
+ def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
70
+ messages = [{"role": "user", "content": prompt}]
71
+ return self._do_generate(messages, options)
72
+
73
+ def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
74
+ return self._do_generate(messages, options)
75
+
76
+ def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
72
77
  if not self.api_key:
73
78
  raise RuntimeError("OpenRouter API key not found")
74
79
 
75
80
  model = options.get("model", self.model)
76
-
81
+
77
82
  # Lookup model-specific config
78
83
  model_info = self.MODEL_PRICING.get(model, {})
79
84
  tokens_param = model_info.get("tokens_param", "max_tokens")
@@ -85,7 +90,7 @@ class OpenRouterDriver(Driver):
85
90
  # Base request data
86
91
  data = {
87
92
  "model": model,
88
- "messages": [{"role": "user", "content": prompt}],
93
+ "messages": messages,
89
94
  }
90
95
 
91
96
  # Add token limit with correct parameter name
@@ -95,6 +100,10 @@ class OpenRouterDriver(Driver):
95
100
  if supports_temperature and "temperature" in opts:
96
101
  data["temperature"] = opts["temperature"]
97
102
 
103
+ # Native JSON mode support
104
+ if options.get("json_mode"):
105
+ data["response_format"] = {"type": "json_object"}
106
+
98
107
  try:
99
108
  response = requests.post(
100
109
  f"{self.base_url}/chat/completions",
@@ -110,17 +119,8 @@ class OpenRouterDriver(Driver):
110
119
  completion_tokens = usage.get("completion_tokens", 0)
111
120
  total_tokens = usage.get("total_tokens", 0)
112
121
 
113
- # Calculate cost try live rates first (per 1M tokens), fall back to hardcoded (per 1K tokens)
114
- from ..model_rates import get_model_rates
115
- live_rates = get_model_rates("openrouter", model)
116
- if live_rates:
117
- prompt_cost = (prompt_tokens / 1_000_000) * live_rates["input"]
118
- completion_cost = (completion_tokens / 1_000_000) * live_rates["output"]
119
- else:
120
- model_pricing = self.MODEL_PRICING.get(model, {"prompt": 0, "completion": 0})
121
- prompt_cost = (prompt_tokens / 1000) * model_pricing["prompt"]
122
- completion_cost = (completion_tokens / 1000) * model_pricing["completion"]
123
- total_cost = prompt_cost + completion_cost
122
+ # Calculate cost via shared mixin
123
+ total_cost = self._calculate_cost("openrouter", model, prompt_tokens, completion_tokens)
124
124
 
125
125
  # Standardized meta object
126
126
  meta = {
@@ -136,11 +136,11 @@ class OpenRouterDriver(Driver):
136
136
  return {"text": text, "meta": meta}
137
137
 
138
138
  except requests.exceptions.RequestException as e:
139
- error_msg = f"OpenRouter API request failed: {str(e)}"
140
- if hasattr(e.response, 'json'):
139
+ error_msg = f"OpenRouter API request failed: {e!s}"
140
+ if hasattr(e.response, "json"):
141
141
  try:
142
142
  error_details = e.response.json()
143
143
  error_msg = f"{error_msg} - {error_details.get('error', {}).get('message', '')}"
144
144
  except Exception:
145
145
  pass
146
- raise RuntimeError(error_msg) from e
146
+ raise RuntimeError(error_msg) from e
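Note: across these drivers the new class flags supports_json_mode, supports_json_schema, and supports_messages advertise per-driver capabilities. A hedged sketch of how calling code might branch on them; the helper below is illustrative only, not an API of this package:

from typing import Any, Optional


def run_prompt(driver: Any, prompt: str, schema: Optional[dict] = None) -> dict[str, Any]:
    """Illustrative: route a prompt through whatever features a driver advertises."""
    options: dict[str, Any] = {}

    if getattr(driver, "supports_json_mode", False):
        options["json_mode"] = True
        # In this diff, only the OpenAI driver also advertises strict schema support.
        if schema and getattr(driver, "supports_json_schema", False):
            options["json_schema"] = schema

    if getattr(driver, "supports_messages", False):
        messages = [{"role": "user", "content": prompt}]
        return driver.generate_messages(messages, options)
    return driver.generate(prompt, options)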