prompture 0.0.33.dev1__py3-none-any.whl → 0.0.34__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- prompture/__init__.py +133 -49
- prompture/_version.py +34 -0
- prompture/aio/__init__.py +74 -0
- prompture/async_conversation.py +484 -0
- prompture/async_core.py +803 -0
- prompture/async_driver.py +131 -0
- prompture/cache.py +469 -0
- prompture/callbacks.py +50 -0
- prompture/cli.py +7 -3
- prompture/conversation.py +504 -0
- prompture/core.py +475 -352
- prompture/cost_mixin.py +51 -0
- prompture/discovery.py +50 -35
- prompture/driver.py +125 -5
- prompture/drivers/__init__.py +171 -73
- prompture/drivers/airllm_driver.py +13 -20
- prompture/drivers/async_airllm_driver.py +26 -0
- prompture/drivers/async_azure_driver.py +117 -0
- prompture/drivers/async_claude_driver.py +107 -0
- prompture/drivers/async_google_driver.py +132 -0
- prompture/drivers/async_grok_driver.py +91 -0
- prompture/drivers/async_groq_driver.py +84 -0
- prompture/drivers/async_hugging_driver.py +61 -0
- prompture/drivers/async_lmstudio_driver.py +79 -0
- prompture/drivers/async_local_http_driver.py +44 -0
- prompture/drivers/async_ollama_driver.py +125 -0
- prompture/drivers/async_openai_driver.py +96 -0
- prompture/drivers/async_openrouter_driver.py +96 -0
- prompture/drivers/async_registry.py +129 -0
- prompture/drivers/azure_driver.py +36 -9
- prompture/drivers/claude_driver.py +86 -34
- prompture/drivers/google_driver.py +87 -51
- prompture/drivers/grok_driver.py +29 -32
- prompture/drivers/groq_driver.py +27 -26
- prompture/drivers/hugging_driver.py +6 -6
- prompture/drivers/lmstudio_driver.py +26 -13
- prompture/drivers/local_http_driver.py +6 -6
- prompture/drivers/ollama_driver.py +90 -23
- prompture/drivers/openai_driver.py +36 -9
- prompture/drivers/openrouter_driver.py +31 -25
- prompture/drivers/registry.py +306 -0
- prompture/field_definitions.py +106 -96
- prompture/logging.py +80 -0
- prompture/model_rates.py +217 -0
- prompture/runner.py +49 -47
- prompture/session.py +117 -0
- prompture/settings.py +14 -1
- prompture/tools.py +172 -265
- prompture/validator.py +3 -3
- {prompture-0.0.33.dev1.dist-info → prompture-0.0.34.dist-info}/METADATA +18 -20
- prompture-0.0.34.dist-info/RECORD +55 -0
- prompture-0.0.33.dev1.dist-info/RECORD +0 -29
- {prompture-0.0.33.dev1.dist-info → prompture-0.0.34.dist-info}/WHEEL +0 -0
- {prompture-0.0.33.dev1.dist-info → prompture-0.0.34.dist-info}/entry_points.txt +0 -0
- {prompture-0.0.33.dev1.dist-info → prompture-0.0.34.dist-info}/licenses/LICENSE +0 -0
- {prompture-0.0.33.dev1.dist-info → prompture-0.0.34.dist-info}/top_level.txt +0 -0
prompture/drivers/groq_driver.py
CHANGED
@@ -1,18 +1,22 @@
 """Groq driver for prompture.
 Requires the `groq` package. Uses GROQ_API_KEY env var.
 """
+
 import os
-from typing import Any
+from typing import Any
 
 try:
     import groq
 except Exception:
     groq = None
 
+from ..cost_mixin import CostMixin
 from ..driver import Driver
 
 
-class GroqDriver(Driver):
+class GroqDriver(CostMixin, Driver):
+    supports_json_mode = True
+
     # Approximate pricing per 1K tokens (to be updated with official pricing)
     # Each model entry defines token parameters and temperature support
     MODEL_PRICING = {
@@ -32,7 +36,7 @@ class GroqDriver(Driver):
 
     def __init__(self, api_key: str | None = None, model: str = "llama2-70b-4096"):
         """Initialize Groq driver.
-
+
         Args:
             api_key: Groq API key (defaults to GROQ_API_KEY env var)
             model: Model to use (defaults to llama2-70b-4096)
@@ -44,20 +48,16 @@ class GroqDriver(Driver):
         else:
             self.client = None
 
-
-
-
-
-
-
-
-
-
-
-        Raises:
-            RuntimeError: If groq package is not installed
-            groq.error.*: Various Groq API errors
-        """
+    supports_messages = True
+
+    def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
+        messages = [{"role": "user", "content": prompt}]
+        return self._do_generate(messages, options)
+
+    def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        return self._do_generate(messages, options)
+
+    def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
         if self.client is None:
             raise RuntimeError("groq package is not installed")
 
@@ -74,7 +74,7 @@ class GroqDriver(Driver):
         # Base kwargs for API call
         kwargs = {
             "model": model,
-            "messages":
+            "messages": messages,
         }
 
         # Set token limit with correct parameter name
@@ -84,23 +84,24 @@ class GroqDriver(Driver):
         if supports_temperature and "temperature" in opts:
             kwargs["temperature"] = opts["temperature"]
 
+        # Native JSON mode support
+        if options.get("json_mode"):
+            kwargs["response_format"] = {"type": "json_object"}
+
         try:
             resp = self.client.chat.completions.create(**kwargs)
-        except Exception
+        except Exception:
             # Re-raise any Groq API errors
             raise
 
         # Extract usage statistics
         usage = getattr(resp, "usage", None)
         prompt_tokens = getattr(usage, "prompt_tokens", 0)
-        completion_tokens = getattr(usage, "completion_tokens", 0)
+        completion_tokens = getattr(usage, "completion_tokens", 0)
         total_tokens = getattr(usage, "total_tokens", 0)
 
-        # Calculate
-
-        prompt_cost = (prompt_tokens / 1000) * model_pricing["prompt"]
-        completion_cost = (completion_tokens / 1000) * model_pricing["completion"]
-        total_cost = prompt_cost + completion_cost
+        # Calculate cost via shared mixin
+        total_cost = self._calculate_cost("groq", model, prompt_tokens, completion_tokens)
 
         # Standard metadata object
         meta = {
@@ -114,4 +115,4 @@ class GroqDriver(Driver):
 
         # Extract generated text
         text = resp.choices[0].message.content
-        return {"text": text, "meta": meta}
+        return {"text": text, "meta": meta}
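Across the drivers in this release the public surface converges on one pattern: `generate()` wraps a single prompt into a one-element chat message list, `generate_messages()` accepts a full message list, and both delegate to a shared `_do_generate()`. A minimal usage sketch of that surface, assuming the import path and constructor defaults shown in the diff above (the printed keys are illustrative, not a documented contract):

```python
# Illustrative sketch only: the class name, method signatures and the
# "json_mode" option come from the diff above; everything else is assumed.
from prompture.drivers.groq_driver import GroqDriver

driver = GroqDriver(model="llama2-70b-4096")  # falls back to the GROQ_API_KEY env var

# Single-prompt path: wrapped into [{"role": "user", "content": ...}] internally.
single = driver.generate("Return three primes as JSON.", {"json_mode": True})

# Multi-turn path: the caller supplies the message list directly.
chat = driver.generate_messages(
    [
        {"role": "system", "content": "Reply with valid JSON only."},
        {"role": "user", "content": "Return three primes."},
    ],
    {"json_mode": True},
)

print(single["text"])  # generated text
print(chat["meta"])    # token usage and cost metadata
```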
prompture/drivers/hugging_driver.py
CHANGED
@@ -1,14 +1,14 @@
 import os
+from typing import Any
+
 import requests
+
 from ..driver import Driver
-from typing import Any, Dict
 
 
 class HuggingFaceDriver(Driver):
     # Hugging Face is usage-based (credits/subscription), but we set costs to 0 for now.
-    MODEL_PRICING = {
-        "default": {"prompt": 0.0, "completion": 0.0}
-    }
+    MODEL_PRICING = {"default": {"prompt": 0.0, "completion": 0.0}}
 
     def __init__(self, endpoint: str | None = None, token: str | None = None, model: str = "bert-base-uncased"):
         self.endpoint = endpoint or os.getenv("HF_ENDPOINT")
@@ -22,7 +22,7 @@ class HuggingFaceDriver(Driver):
 
         self.headers = {"Authorization": f"Bearer {self.token}"}
 
-    def generate(self, prompt: str, options:
+    def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
         payload = {
             "inputs": prompt,
             "parameters": options,  # HF allows temperature, max_new_tokens, etc. here
@@ -32,7 +32,7 @@ class HuggingFaceDriver(Driver):
             r.raise_for_status()
             response_data = r.json()
         except Exception as e:
-            raise RuntimeError(f"HuggingFaceDriver request failed: {e}")
+            raise RuntimeError(f"HuggingFaceDriver request failed: {e}") from e
 
         # Different HF models return slightly different response formats
         # Text-generation models usually return [{"generated_text": "..."}]
prompture/drivers/lmstudio_driver.py
CHANGED
@@ -1,26 +1,26 @@
-import os
 import json
-import requests
 import logging
+import os
+from typing import Any, Optional
+
+import requests
+
 from ..driver import Driver
-from typing import Any, Dict
 
 logger = logging.getLogger(__name__)
 
 
 class LMStudioDriver(Driver):
+    supports_json_mode = True
+
     # LM Studio is local – costs are always zero.
-    MODEL_PRICING = {
-        "default": {"prompt": 0.0, "completion": 0.0}
-    }
+    MODEL_PRICING = {"default": {"prompt": 0.0, "completion": 0.0}}
 
     def __init__(self, endpoint: str | None = None, model: str = "deepseek/deepseek-r1-0528-qwen3-8b"):
         # Allow override via env var
-        self.endpoint = endpoint or os.getenv(
-            "LMSTUDIO_ENDPOINT", "http://127.0.0.1:1234/v1/chat/completions"
-        )
+        self.endpoint = endpoint or os.getenv("LMSTUDIO_ENDPOINT", "http://127.0.0.1:1234/v1/chat/completions")
         self.model = model
-        self.options:
+        self.options: dict[str, Any] = {}
 
         # Validate connection to LM Studio server
         self._validate_connection()
@@ -38,17 +38,30 @@ class LMStudioDriver(Driver):
         except requests.exceptions.RequestException as e:
             logger.warning(f"Could not validate connection to LM Studio server: {e}")
 
-
+    supports_messages = True
+
+    def generate(self, prompt: str, options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
+        messages = [{"role": "user", "content": prompt}]
+        return self._do_generate(messages, options)
+
+    def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        return self._do_generate(messages, options)
+
+    def _do_generate(self, messages: list[dict[str, str]], options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
         merged_options = self.options.copy()
         if options:
             merged_options.update(options)
 
         payload = {
             "model": merged_options.get("model", self.model),
-            "messages":
+            "messages": messages,
             "temperature": merged_options.get("temperature", 0.7),
         }
 
+        # Native JSON mode support
+        if merged_options.get("json_mode"):
+            payload["response_format"] = {"type": "json_object"}
+
         try:
             logger.debug(f"Sending request to LM Studio endpoint: {self.endpoint}")
             logger.debug(f"Request payload: {payload}")
@@ -70,7 +83,7 @@ class LMStudioDriver(Driver):
             raise
         except Exception as e:
             logger.error(f"Unexpected error in LM Studio request: {e}")
-            raise RuntimeError(f"LM Studio request failed: {e}")
+            raise RuntimeError(f"LM Studio request failed: {e}") from e
 
         # Extract text
         text = response_data["choices"][0]["message"]["content"]
prompture/drivers/local_http_driver.py
CHANGED
@@ -1,27 +1,27 @@
 import os
+from typing import Any
+
 import requests
+
 from ..driver import Driver
-from typing import Any, Dict
 
 
 class LocalHTTPDriver(Driver):
     # Default: no cost; extend if your local service has pricing logic
-    MODEL_PRICING = {
-        "default": {"prompt": 0.0, "completion": 0.0}
-    }
+    MODEL_PRICING = {"default": {"prompt": 0.0, "completion": 0.0}}
 
     def __init__(self, endpoint: str | None = None, model: str = "local-model"):
         self.endpoint = endpoint or os.getenv("LOCAL_HTTP_ENDPOINT", "http://localhost:8000/generate")
         self.model = model
 
-    def generate(self, prompt: str, options:
+    def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
         payload = {"prompt": prompt, "options": options}
         try:
             r = requests.post(self.endpoint, json=payload, timeout=options.get("timeout", 30))
             r.raise_for_status()
             response_data = r.json()
         except Exception as e:
-            raise RuntimeError(f"LocalHTTPDriver request failed: {e}")
+            raise RuntimeError(f"LocalHTTPDriver request failed: {e}") from e
 
         # If the local API already provides {"text": "...", "meta": {...}}, just return it
         if "text" in response_data and "meta" in response_data:
prompture/drivers/ollama_driver.py
CHANGED
@@ -1,38 +1,38 @@
-import os
 import json
-import requests
 import logging
+import os
+from typing import Any, Optional
+
+import requests
+
 from ..driver import Driver
-from typing import Any, Dict
 
 logger = logging.getLogger(__name__)
 
 
 class OllamaDriver(Driver):
+    supports_json_mode = True
+
     # Ollama is free – costs are always zero.
-    MODEL_PRICING = {
-        "default": {"prompt": 0.0, "completion": 0.0}
-    }
+    MODEL_PRICING = {"default": {"prompt": 0.0, "completion": 0.0}}
 
     def __init__(self, endpoint: str | None = None, model: str = "llama3"):
         # Allow override via env var
-        self.endpoint = endpoint or os.getenv(
-            "OLLAMA_ENDPOINT", "http://localhost:11434/api/generate"
-        )
+        self.endpoint = endpoint or os.getenv("OLLAMA_ENDPOINT", "http://localhost:11434/api/generate")
         self.model = model
         self.options = {}  # Initialize empty options dict
-
+
         # Validate connection to Ollama server
         self._validate_connection()
-
+
     def _validate_connection(self):
         """Validate connection to the Ollama server."""
         try:
             # Send a simple HEAD request to check if server is accessible
             # Use the base API endpoint without the specific path
-            base_url = self.endpoint.split(
+            base_url = self.endpoint.split("/api/")[0]
             health_url = f"{base_url}/api/version"
-
+
             logger.debug(f"Validating connection to Ollama server at: {health_url}")
             response = requests.head(health_url, timeout=5)
             response.raise_for_status()
@@ -42,7 +42,9 @@ class OllamaDriver(Driver):
         # We don't raise an error here to allow for delayed server startup
         # The actual error will be raised when generate() is called
 
-
+    supports_messages = True
+
+    def generate(self, prompt: str, options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
         # Merge instance options with call-specific options
         merged_options = self.options.copy()
         if options:
@@ -54,6 +56,10 @@ class OllamaDriver(Driver):
             "stream": False,
         }
 
+        # Native JSON mode support
+        if merged_options.get("json_mode"):
+            payload["format"] = "json"
+
         # Add any Ollama-specific options from merged_options
         if "temperature" in merged_options:
             payload["temperature"] = merged_options["temperature"]
@@ -65,21 +71,21 @@ class OllamaDriver(Driver):
         try:
             logger.debug(f"Sending request to Ollama endpoint: {self.endpoint}")
             logger.debug(f"Request payload: {payload}")
-
+
             r = requests.post(self.endpoint, json=payload, timeout=120)
             logger.debug(f"Response status code: {r.status_code}")
-
+
             r.raise_for_status()
-
+
             response_text = r.text
             logger.debug(f"Raw response text: {response_text}")
-
+
             response_data = r.json()
             logger.debug(f"Parsed response data: {response_data}")
-
+
             if not isinstance(response_data, dict):
                 raise ValueError(f"Expected dict response, got {type(response_data)}")
-
+
         except requests.exceptions.ConnectionError as e:
             logger.error(f"Connection error to Ollama endpoint: {e}")
             # Preserve original exception
@@ -91,11 +97,11 @@ class OllamaDriver(Driver):
         except json.JSONDecodeError as e:
             logger.error(f"Failed to decode JSON response: {e}")
             # Re-raise JSONDecodeError with more context
-            raise json.JSONDecodeError(f"Invalid JSON response from Ollama: {e.msg}", e.doc, e.pos)
+            raise json.JSONDecodeError(f"Invalid JSON response from Ollama: {e.msg}", e.doc, e.pos) from e
         except Exception as e:
             logger.error(f"Unexpected error in Ollama request: {e}")
             # Only wrap unknown exceptions in RuntimeError
-            raise RuntimeError(f"Ollama request failed: {e}")
+            raise RuntimeError(f"Ollama request failed: {e}") from e
 
         # Extract token counts
         prompt_tokens = response_data.get("prompt_eval_count", 0)
@@ -113,4 +119,65 @@ class OllamaDriver(Driver):
         }
 
         # Ollama returns text in "response"
-        return {"text": response_data.get("response", ""), "meta": meta}
+        return {"text": response_data.get("response", ""), "meta": meta}
+
+    def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        """Use Ollama's /api/chat endpoint for multi-turn conversations."""
+        merged_options = self.options.copy()
+        if options:
+            merged_options.update(options)
+
+        # Derive the chat endpoint from the generate endpoint
+        chat_endpoint = self.endpoint.replace("/api/generate", "/api/chat")
+
+        payload: dict[str, Any] = {
+            "model": merged_options.get("model", self.model),
+            "messages": messages,
+            "stream": False,
+        }
+
+        # Native JSON mode support
+        if merged_options.get("json_mode"):
+            payload["format"] = "json"
+
+        if "temperature" in merged_options:
+            payload["temperature"] = merged_options["temperature"]
+        if "top_p" in merged_options:
+            payload["top_p"] = merged_options["top_p"]
+        if "top_k" in merged_options:
+            payload["top_k"] = merged_options["top_k"]
+
+        try:
+            logger.debug(f"Sending chat request to Ollama endpoint: {chat_endpoint}")
+            r = requests.post(chat_endpoint, json=payload, timeout=120)
+            r.raise_for_status()
+            response_data = r.json()
+
+            if not isinstance(response_data, dict):
+                raise ValueError(f"Expected dict response, got {type(response_data)}")
+        except requests.exceptions.ConnectionError:
+            raise
+        except requests.exceptions.HTTPError:
+            raise
+        except json.JSONDecodeError as e:
+            raise json.JSONDecodeError(f"Invalid JSON response from Ollama: {e.msg}", e.doc, e.pos) from e
+        except Exception as e:
+            raise RuntimeError(f"Ollama chat request failed: {e}") from e
+
+        prompt_tokens = response_data.get("prompt_eval_count", 0)
+        completion_tokens = response_data.get("eval_count", 0)
+        total_tokens = prompt_tokens + completion_tokens
+
+        meta = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": 0.0,
+            "raw_response": response_data,
+            "model_name": merged_options.get("model", self.model),
+        }
+
+        # Chat endpoint returns response in message.content
+        message = response_data.get("message", {})
+        text = message.get("content", "")
+        return {"text": text, "meta": meta}
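The new `generate_messages()` above targets Ollama's `/api/chat` endpoint rather than `/api/generate`. A sketch of the request it builds and the response fields it reads, restated from the diff (the endpoint URL, payload keys, and response fields mirror the code above; the message content is made up):

```python
# Request/response shape used by OllamaDriver.generate_messages(), per the diff above.
import requests

payload = {
    "model": "llama3",
    "messages": [
        {"role": "system", "content": "You are terse."},
        {"role": "user", "content": "Name one prime number."},
    ],
    "stream": False,
    "format": "json",  # only set when the options include json_mode
}

resp = requests.post("http://localhost:11434/api/chat", json=payload, timeout=120)
resp.raise_for_status()
data = resp.json()

# The driver pulls the reply text and token counts from these fields:
text = data.get("message", {}).get("content", "")
prompt_tokens = data.get("prompt_eval_count", 0)
completion_tokens = data.get("eval_count", 0)
```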
prompture/drivers/openai_driver.py
CHANGED
@@ -1,17 +1,23 @@
 """Minimal OpenAI driver (migrated to openai>=1.0.0).
 Requires the `openai` package. Uses OPENAI_API_KEY env var.
 """
+
 import os
-from typing import Any
+from typing import Any
+
 try:
     from openai import OpenAI
 except Exception:
     OpenAI = None
 
+from ..cost_mixin import CostMixin
 from ..driver import Driver
 
 
-class OpenAIDriver(Driver):
+class OpenAIDriver(CostMixin, Driver):
+    supports_json_mode = True
+    supports_json_schema = True
+
     # Approximate pricing per 1K tokens (keep updated with OpenAI's official pricing)
     # Each model entry also defines which token parameter it supports and
     # whether it accepts temperature.
@@ -62,7 +68,16 @@ class OpenAIDriver(Driver):
         else:
             self.client = None
 
-
+    supports_messages = True
+
+    def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
+        messages = [{"role": "user", "content": prompt}]
+        return self._do_generate(messages, options)
+
+    def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        return self._do_generate(messages, options)
+
+    def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
         if self.client is None:
             raise RuntimeError("openai package (>=1.0.0) is not installed")
 
@@ -79,7 +94,7 @@ class OpenAIDriver(Driver):
         # Base kwargs
         kwargs = {
             "model": model,
-            "messages":
+            "messages": messages,
         }
 
         # Assign token limit with the correct parameter name
@@ -89,6 +104,21 @@ class OpenAIDriver(Driver):
         if supports_temperature and "temperature" in opts:
             kwargs["temperature"] = opts["temperature"]
 
+        # Native JSON mode support
+        if options.get("json_mode"):
+            json_schema = options.get("json_schema")
+            if json_schema:
+                kwargs["response_format"] = {
+                    "type": "json_schema",
+                    "json_schema": {
+                        "name": "extraction",
+                        "strict": True,
+                        "schema": json_schema,
+                    },
+                }
+            else:
+                kwargs["response_format"] = {"type": "json_object"}
+
         resp = self.client.chat.completions.create(**kwargs)
 
         # Extract usage info
@@ -97,11 +127,8 @@ class OpenAIDriver(Driver):
         completion_tokens = getattr(usage, "completion_tokens", 0)
         total_tokens = getattr(usage, "total_tokens", 0)
 
-        # Calculate cost
-
-        prompt_cost = (prompt_tokens / 1000) * model_pricing["prompt"]
-        completion_cost = (completion_tokens / 1000) * model_pricing["completion"]
-        total_cost = prompt_cost + completion_cost
+        # Calculate cost via shared mixin
+        total_cost = self._calculate_cost("openai", model, prompt_tokens, completion_tokens)
 
         # Standardized meta object
         meta = {
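The OpenAI driver now chooses between two response formats when `json_mode` is requested: a strict JSON-schema format when a schema is supplied, and plain `json_object` otherwise. The resulting `response_format` values, restated from the diff (the example schema itself is invented):

```python
# response_format values built by the driver above; the schema is a made-up example.
json_schema = {
    "type": "object",
    "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
    "required": ["name", "age"],
    "additionalProperties": False,
}

# options = {"json_mode": True, "json_schema": json_schema} produces:
strict_format = {
    "type": "json_schema",
    "json_schema": {"name": "extraction", "strict": True, "schema": json_schema},
}

# options = {"json_mode": True} without a schema falls back to:
fallback_format = {"type": "json_object"}
```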
prompture/drivers/openrouter_driver.py
CHANGED
@@ -1,14 +1,19 @@
 """OpenRouter driver implementation.
 Requires the `requests` package. Uses OPENROUTER_API_KEY env var.
 """
+
 import os
-from typing import Any
+from typing import Any
+
 import requests
 
+from ..cost_mixin import CostMixin
 from ..driver import Driver
 
 
-class OpenRouterDriver(Driver):
+class OpenRouterDriver(CostMixin, Driver):
+    supports_json_mode = True
+
     # Approximate pricing per 1K tokens based on OpenRouter's pricing
     # https://openrouter.ai/docs#pricing
     MODEL_PRICING = {
@@ -40,7 +45,7 @@ class OpenRouterDriver(Driver):
 
     def __init__(self, api_key: str | None = None, model: str = "openai/gpt-3.5-turbo"):
         """Initialize OpenRouter driver.
-
+
         Args:
             api_key: OpenRouter API key. If not provided, will look for OPENROUTER_API_KEY env var
             model: Model to use. Defaults to openai/gpt-3.5-turbo
@@ -48,10 +53,10 @@ class OpenRouterDriver(Driver):
         self.api_key = api_key or os.getenv("OPENROUTER_API_KEY")
         if not self.api_key:
             raise ValueError("OpenRouter API key not found. Set OPENROUTER_API_KEY env var.")
-
+
         self.model = model
         self.base_url = "https://openrouter.ai/api/v1"
-
+
         # Required headers for OpenRouter
         self.headers = {
             "Authorization": f"Bearer {self.api_key}",
@@ -59,21 +64,21 @@ class OpenRouterDriver(Driver):
             "Content-Type": "application/json",
         }
 
-
-
-
-
-
-
-
-
-
-
+    supports_messages = True
+
+    def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
+        messages = [{"role": "user", "content": prompt}]
+        return self._do_generate(messages, options)
+
+    def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        return self._do_generate(messages, options)
+
+    def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
         if not self.api_key:
             raise RuntimeError("OpenRouter API key not found")
 
         model = options.get("model", self.model)
-
+
         # Lookup model-specific config
         model_info = self.MODEL_PRICING.get(model, {})
         tokens_param = model_info.get("tokens_param", "max_tokens")
@@ -85,7 +90,7 @@ class OpenRouterDriver(Driver):
         # Base request data
         data = {
             "model": model,
-            "messages":
+            "messages": messages,
         }
 
         # Add token limit with correct parameter name
@@ -95,6 +100,10 @@ class OpenRouterDriver(Driver):
         if supports_temperature and "temperature" in opts:
             data["temperature"] = opts["temperature"]
 
+        # Native JSON mode support
+        if options.get("json_mode"):
+            data["response_format"] = {"type": "json_object"}
+
         try:
             response = requests.post(
                 f"{self.base_url}/chat/completions",
@@ -110,11 +119,8 @@ class OpenRouterDriver(Driver):
             completion_tokens = usage.get("completion_tokens", 0)
             total_tokens = usage.get("total_tokens", 0)
 
-            # Calculate cost
-
-            prompt_cost = (prompt_tokens / 1000) * model_pricing["prompt"]
-            completion_cost = (completion_tokens / 1000) * model_pricing["completion"]
-            total_cost = prompt_cost + completion_cost
+            # Calculate cost via shared mixin
+            total_cost = self._calculate_cost("openrouter", model, prompt_tokens, completion_tokens)
 
             # Standardized meta object
             meta = {
@@ -130,11 +136,11 @@ class OpenRouterDriver(Driver):
             return {"text": text, "meta": meta}
 
         except requests.exceptions.RequestException as e:
-            error_msg = f"OpenRouter API request failed: {
-            if hasattr(e.response,
+            error_msg = f"OpenRouter API request failed: {e!s}"
+            if hasattr(e.response, "json"):
                 try:
                     error_details = e.response.json()
                     error_msg = f"{error_msg} - {error_details.get('error', {}).get('message', '')}"
                 except Exception:
                     pass
-            raise RuntimeError(error_msg) from e
+            raise RuntimeError(error_msg) from e