prompture 0.0.29.dev8__py3-none-any.whl → 0.0.35__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- prompture/__init__.py +146 -23
- prompture/_version.py +34 -0
- prompture/aio/__init__.py +74 -0
- prompture/async_conversation.py +607 -0
- prompture/async_core.py +803 -0
- prompture/async_driver.py +169 -0
- prompture/cache.py +469 -0
- prompture/callbacks.py +55 -0
- prompture/cli.py +63 -4
- prompture/conversation.py +631 -0
- prompture/core.py +876 -263
- prompture/cost_mixin.py +51 -0
- prompture/discovery.py +164 -0
- prompture/driver.py +168 -5
- prompture/drivers/__init__.py +173 -69
- prompture/drivers/airllm_driver.py +109 -0
- prompture/drivers/async_airllm_driver.py +26 -0
- prompture/drivers/async_azure_driver.py +117 -0
- prompture/drivers/async_claude_driver.py +107 -0
- prompture/drivers/async_google_driver.py +132 -0
- prompture/drivers/async_grok_driver.py +91 -0
- prompture/drivers/async_groq_driver.py +84 -0
- prompture/drivers/async_hugging_driver.py +61 -0
- prompture/drivers/async_lmstudio_driver.py +79 -0
- prompture/drivers/async_local_http_driver.py +44 -0
- prompture/drivers/async_ollama_driver.py +125 -0
- prompture/drivers/async_openai_driver.py +96 -0
- prompture/drivers/async_openrouter_driver.py +96 -0
- prompture/drivers/async_registry.py +129 -0
- prompture/drivers/azure_driver.py +36 -9
- prompture/drivers/claude_driver.py +251 -34
- prompture/drivers/google_driver.py +107 -38
- prompture/drivers/grok_driver.py +29 -32
- prompture/drivers/groq_driver.py +27 -26
- prompture/drivers/hugging_driver.py +6 -6
- prompture/drivers/lmstudio_driver.py +26 -13
- prompture/drivers/local_http_driver.py +6 -6
- prompture/drivers/ollama_driver.py +157 -23
- prompture/drivers/openai_driver.py +178 -9
- prompture/drivers/openrouter_driver.py +31 -25
- prompture/drivers/registry.py +306 -0
- prompture/field_definitions.py +106 -96
- prompture/logging.py +80 -0
- prompture/model_rates.py +217 -0
- prompture/runner.py +49 -47
- prompture/scaffold/__init__.py +1 -0
- prompture/scaffold/generator.py +84 -0
- prompture/scaffold/templates/Dockerfile.j2 +12 -0
- prompture/scaffold/templates/README.md.j2 +41 -0
- prompture/scaffold/templates/config.py.j2 +21 -0
- prompture/scaffold/templates/env.example.j2 +8 -0
- prompture/scaffold/templates/main.py.j2 +86 -0
- prompture/scaffold/templates/models.py.j2 +40 -0
- prompture/scaffold/templates/requirements.txt.j2 +5 -0
- prompture/server.py +183 -0
- prompture/session.py +117 -0
- prompture/settings.py +18 -1
- prompture/tools.py +219 -267
- prompture/tools_schema.py +254 -0
- prompture/validator.py +3 -3
- {prompture-0.0.29.dev8.dist-info → prompture-0.0.35.dist-info}/METADATA +117 -21
- prompture-0.0.35.dist-info/RECORD +66 -0
- {prompture-0.0.29.dev8.dist-info → prompture-0.0.35.dist-info}/WHEEL +1 -1
- prompture-0.0.29.dev8.dist-info/RECORD +0 -27
- {prompture-0.0.29.dev8.dist-info → prompture-0.0.35.dist-info}/entry_points.txt +0 -0
- {prompture-0.0.29.dev8.dist-info → prompture-0.0.35.dist-info}/licenses/LICENSE +0 -0
- {prompture-0.0.29.dev8.dist-info → prompture-0.0.35.dist-info}/top_level.txt +0 -0
prompture/drivers/grok_driver.py
CHANGED
|
@@ -1,15 +1,21 @@
|
|
|
1
1
|
"""xAI Grok driver.
|
|
2
2
|
Requires the `requests` package. Uses GROK_API_KEY env var.
|
|
3
3
|
"""
|
|
4
|
+
|
|
4
5
|
import os
|
|
5
|
-
from typing import Any
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
6
8
|
import requests
|
|
7
9
|
|
|
10
|
+
from ..cost_mixin import CostMixin
|
|
8
11
|
from ..driver import Driver
|
|
9
12
|
|
|
10
13
|
|
|
11
|
-
class GrokDriver(Driver):
|
|
14
|
+
class GrokDriver(CostMixin, Driver):
|
|
15
|
+
supports_json_mode = True
|
|
16
|
+
|
|
12
17
|
# Pricing per 1M tokens based on xAI's documentation
|
|
18
|
+
_PRICING_UNIT = 1_000_000
|
|
13
19
|
MODEL_PRICING = {
|
|
14
20
|
"grok-code-fast-1": {
|
|
15
21
|
"prompt": 0.20,
|
|
@@ -72,19 +78,16 @@ class GrokDriver(Driver):
|
|
|
72
78
|
self.model = model
|
|
73
79
|
self.api_base = "https://api.x.ai/v1"
|
|
74
80
|
|
|
75
|
-
|
|
76
|
-
"""Generate completion using Grok API.
|
|
81
|
+
supports_messages = True
|
|
77
82
|
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
RuntimeError: If API key is missing or request fails
|
|
87
|
-
"""
|
|
83
|
+
def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
    """Run a single-prompt completion against the Grok chat API.

    The prompt is wrapped as the only user turn of a conversation and handed
    to ``self._do_generate``, which builds and sends the actual request.

    Args:
        prompt: Text sent as the sole ``user`` message.
        options: Per-call options forwarded unchanged to ``_do_generate``.

    Returns:
        Whatever ``_do_generate`` returns (a ``{"text", "meta"}`` dict).
    """
    single_turn = [dict(role="user", content=prompt)]
    return self._do_generate(single_turn, options)
|
|
86
|
+
|
|
87
|
+
def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
    """Run a multi-turn completion; *messages* are passed through untouched.

    Args:
        messages: Chat history in ``{"role", "content"}`` form.
        options: Per-call options forwarded unchanged to ``_do_generate``.

    Returns:
        The ``_do_generate`` result.
    """
    result = self._do_generate(messages, options)
    return result
|
|
89
|
+
|
|
90
|
+
def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
|
|
88
91
|
if not self.api_key:
|
|
89
92
|
raise RuntimeError("GROK_API_KEY environment variable is required")
|
|
90
93
|
|
|
@@ -101,7 +104,7 @@ class GrokDriver(Driver):
|
|
|
101
104
|
# Base request payload
|
|
102
105
|
payload = {
|
|
103
106
|
"model": model,
|
|
104
|
-
"messages":
|
|
107
|
+
"messages": messages,
|
|
105
108
|
}
|
|
106
109
|
|
|
107
110
|
# Add token limit with correct parameter name
|
|
@@ -111,33 +114,27 @@ class GrokDriver(Driver):
|
|
|
111
114
|
if supports_temperature and "temperature" in opts:
|
|
112
115
|
payload["temperature"] = opts["temperature"]
|
|
113
116
|
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
"
|
|
117
|
-
|
|
117
|
+
# Native JSON mode support
|
|
118
|
+
if options.get("json_mode"):
|
|
119
|
+
payload["response_format"] = {"type": "json_object"}
|
|
120
|
+
|
|
121
|
+
headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
|
|
118
122
|
|
|
119
123
|
try:
|
|
120
|
-
response = requests.post(
|
|
121
|
-
f"{self.api_base}/chat/completions",
|
|
122
|
-
headers=headers,
|
|
123
|
-
json=payload
|
|
124
|
-
)
|
|
124
|
+
response = requests.post(f"{self.api_base}/chat/completions", headers=headers, json=payload)
|
|
125
125
|
response.raise_for_status()
|
|
126
126
|
resp = response.json()
|
|
127
127
|
except requests.exceptions.RequestException as e:
|
|
128
|
-
raise RuntimeError(f"Grok API request failed: {
|
|
128
|
+
raise RuntimeError(f"Grok API request failed: {e!s}") from e
|
|
129
129
|
|
|
130
130
|
# Extract usage info
|
|
131
131
|
usage = resp.get("usage", {})
|
|
132
132
|
prompt_tokens = usage.get("prompt_tokens", 0)
|
|
133
|
-
completion_tokens = usage.get("completion_tokens", 0)
|
|
133
|
+
completion_tokens = usage.get("completion_tokens", 0)
|
|
134
134
|
total_tokens = usage.get("total_tokens", 0)
|
|
135
135
|
|
|
136
|
-
# Calculate cost
|
|
137
|
-
|
|
138
|
-
prompt_cost = (prompt_tokens / 1000000) * model_pricing["prompt"]
|
|
139
|
-
completion_cost = (completion_tokens / 1000000) * model_pricing["completion"]
|
|
140
|
-
total_cost = prompt_cost + completion_cost
|
|
136
|
+
# Calculate cost via shared mixin
|
|
137
|
+
total_cost = self._calculate_cost("grok", model, prompt_tokens, completion_tokens)
|
|
141
138
|
|
|
142
139
|
# Standardized meta object
|
|
143
140
|
meta = {
|
|
@@ -150,4 +147,4 @@ class GrokDriver(Driver):
|
|
|
150
147
|
}
|
|
151
148
|
|
|
152
149
|
text = resp["choices"][0]["message"]["content"]
|
|
153
|
-
return {"text": text, "meta": meta}
|
|
150
|
+
return {"text": text, "meta": meta}
|
prompture/drivers/groq_driver.py
CHANGED
|
@@ -1,18 +1,22 @@
|
|
|
1
1
|
"""Groq driver for prompture.
|
|
2
2
|
Requires the `groq` package. Uses GROQ_API_KEY env var.
|
|
3
3
|
"""
|
|
4
|
+
|
|
4
5
|
import os
|
|
5
|
-
from typing import Any
|
|
6
|
+
from typing import Any
|
|
6
7
|
|
|
7
8
|
try:
|
|
8
9
|
import groq
|
|
9
10
|
except Exception:
|
|
10
11
|
groq = None
|
|
11
12
|
|
|
13
|
+
from ..cost_mixin import CostMixin
|
|
12
14
|
from ..driver import Driver
|
|
13
15
|
|
|
14
16
|
|
|
15
|
-
class GroqDriver(Driver):
|
|
17
|
+
class GroqDriver(CostMixin, Driver):
|
|
18
|
+
supports_json_mode = True
|
|
19
|
+
|
|
16
20
|
# Approximate pricing per 1K tokens (to be updated with official pricing)
|
|
17
21
|
# Each model entry defines token parameters and temperature support
|
|
18
22
|
MODEL_PRICING = {
|
|
@@ -32,7 +36,7 @@ class GroqDriver(Driver):
|
|
|
32
36
|
|
|
33
37
|
def __init__(self, api_key: str | None = None, model: str = "llama2-70b-4096"):
|
|
34
38
|
"""Initialize Groq driver.
|
|
35
|
-
|
|
39
|
+
|
|
36
40
|
Args:
|
|
37
41
|
api_key: Groq API key (defaults to GROQ_API_KEY env var)
|
|
38
42
|
model: Model to use (defaults to llama2-70b-4096)
|
|
@@ -44,20 +48,16 @@ class GroqDriver(Driver):
|
|
|
44
48
|
else:
|
|
45
49
|
self.client = None
|
|
46
50
|
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
Raises:
|
|
58
|
-
RuntimeError: If groq package is not installed
|
|
59
|
-
groq.error.*: Various Groq API errors
|
|
60
|
-
"""
|
|
51
|
+
supports_messages = True
|
|
52
|
+
|
|
53
|
+
def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
    """Run a single-prompt completion through the Groq client.

    The prompt becomes the only ``user`` message; request building and the
    API call are handled by ``self._do_generate``.

    Args:
        prompt: Text sent as the sole user turn.
        options: Per-call options forwarded unchanged.

    Returns:
        The ``_do_generate`` result (``{"text", "meta"}``).
    """
    conversation = [dict(role="user", content=prompt)]
    return self._do_generate(conversation, options)
|
|
56
|
+
|
|
57
|
+
def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
    """Run a multi-turn completion; *messages* are forwarded as-is.

    Args:
        messages: Chat history in ``{"role", "content"}`` form.
        options: Per-call options forwarded unchanged.

    Returns:
        The ``_do_generate`` result.
    """
    result = self._do_generate(messages, options)
    return result
|
|
59
|
+
|
|
60
|
+
def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
|
|
61
61
|
if self.client is None:
|
|
62
62
|
raise RuntimeError("groq package is not installed")
|
|
63
63
|
|
|
@@ -74,7 +74,7 @@ class GroqDriver(Driver):
|
|
|
74
74
|
# Base kwargs for API call
|
|
75
75
|
kwargs = {
|
|
76
76
|
"model": model,
|
|
77
|
-
"messages":
|
|
77
|
+
"messages": messages,
|
|
78
78
|
}
|
|
79
79
|
|
|
80
80
|
# Set token limit with correct parameter name
|
|
@@ -84,23 +84,24 @@ class GroqDriver(Driver):
|
|
|
84
84
|
if supports_temperature and "temperature" in opts:
|
|
85
85
|
kwargs["temperature"] = opts["temperature"]
|
|
86
86
|
|
|
87
|
+
# Native JSON mode support
|
|
88
|
+
if options.get("json_mode"):
|
|
89
|
+
kwargs["response_format"] = {"type": "json_object"}
|
|
90
|
+
|
|
87
91
|
try:
|
|
88
92
|
resp = self.client.chat.completions.create(**kwargs)
|
|
89
|
-
except Exception
|
|
93
|
+
except Exception:
|
|
90
94
|
# Re-raise any Groq API errors
|
|
91
95
|
raise
|
|
92
96
|
|
|
93
97
|
# Extract usage statistics
|
|
94
98
|
usage = getattr(resp, "usage", None)
|
|
95
99
|
prompt_tokens = getattr(usage, "prompt_tokens", 0)
|
|
96
|
-
completion_tokens = getattr(usage, "completion_tokens", 0)
|
|
100
|
+
completion_tokens = getattr(usage, "completion_tokens", 0)
|
|
97
101
|
total_tokens = getattr(usage, "total_tokens", 0)
|
|
98
102
|
|
|
99
|
-
# Calculate
|
|
100
|
-
|
|
101
|
-
prompt_cost = (prompt_tokens / 1000) * model_pricing["prompt"]
|
|
102
|
-
completion_cost = (completion_tokens / 1000) * model_pricing["completion"]
|
|
103
|
-
total_cost = prompt_cost + completion_cost
|
|
103
|
+
# Calculate cost via shared mixin
|
|
104
|
+
total_cost = self._calculate_cost("groq", model, prompt_tokens, completion_tokens)
|
|
104
105
|
|
|
105
106
|
# Standard metadata object
|
|
106
107
|
meta = {
|
|
@@ -114,4 +115,4 @@ class GroqDriver(Driver):
|
|
|
114
115
|
|
|
115
116
|
# Extract generated text
|
|
116
117
|
text = resp.choices[0].message.content
|
|
117
|
-
return {"text": text, "meta": meta}
|
|
118
|
+
return {"text": text, "meta": meta}
|
|
@@ -1,14 +1,14 @@
|
|
|
1
1
|
import os
|
|
2
|
+
from typing import Any
|
|
3
|
+
|
|
2
4
|
import requests
|
|
5
|
+
|
|
3
6
|
from ..driver import Driver
|
|
4
|
-
from typing import Any, Dict
|
|
5
7
|
|
|
6
8
|
|
|
7
9
|
class HuggingFaceDriver(Driver):
|
|
8
10
|
# Hugging Face is usage-based (credits/subscription), but we set costs to 0 for now.
|
|
9
|
-
MODEL_PRICING = {
|
|
10
|
-
"default": {"prompt": 0.0, "completion": 0.0}
|
|
11
|
-
}
|
|
11
|
+
MODEL_PRICING = {"default": {"prompt": 0.0, "completion": 0.0}}
|
|
12
12
|
|
|
13
13
|
def __init__(self, endpoint: str | None = None, token: str | None = None, model: str = "bert-base-uncased"):
|
|
14
14
|
self.endpoint = endpoint or os.getenv("HF_ENDPOINT")
|
|
@@ -22,7 +22,7 @@ class HuggingFaceDriver(Driver):
|
|
|
22
22
|
|
|
23
23
|
self.headers = {"Authorization": f"Bearer {self.token}"}
|
|
24
24
|
|
|
25
|
-
def generate(self, prompt: str, options:
|
|
25
|
+
def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
|
|
26
26
|
payload = {
|
|
27
27
|
"inputs": prompt,
|
|
28
28
|
"parameters": options, # HF allows temperature, max_new_tokens, etc. here
|
|
@@ -32,7 +32,7 @@ class HuggingFaceDriver(Driver):
|
|
|
32
32
|
r.raise_for_status()
|
|
33
33
|
response_data = r.json()
|
|
34
34
|
except Exception as e:
|
|
35
|
-
raise RuntimeError(f"HuggingFaceDriver request failed: {e}")
|
|
35
|
+
raise RuntimeError(f"HuggingFaceDriver request failed: {e}") from e
|
|
36
36
|
|
|
37
37
|
# Different HF models return slightly different response formats
|
|
38
38
|
# Text-generation models usually return [{"generated_text": "..."}]
|
|
@@ -1,26 +1,26 @@
|
|
|
1
|
-
import os
|
|
2
1
|
import json
|
|
3
|
-
import requests
|
|
4
2
|
import logging
|
|
3
|
+
import os
|
|
4
|
+
from typing import Any, Optional
|
|
5
|
+
|
|
6
|
+
import requests
|
|
7
|
+
|
|
5
8
|
from ..driver import Driver
|
|
6
|
-
from typing import Any, Dict
|
|
7
9
|
|
|
8
10
|
logger = logging.getLogger(__name__)
|
|
9
11
|
|
|
10
12
|
|
|
11
13
|
class LMStudioDriver(Driver):
|
|
14
|
+
supports_json_mode = True
|
|
15
|
+
|
|
12
16
|
# LM Studio is local – costs are always zero.
|
|
13
|
-
MODEL_PRICING = {
|
|
14
|
-
"default": {"prompt": 0.0, "completion": 0.0}
|
|
15
|
-
}
|
|
17
|
+
MODEL_PRICING = {"default": {"prompt": 0.0, "completion": 0.0}}
|
|
16
18
|
|
|
17
19
|
def __init__(self, endpoint: str | None = None, model: str = "deepseek/deepseek-r1-0528-qwen3-8b"):
|
|
18
20
|
# Allow override via env var
|
|
19
|
-
self.endpoint = endpoint or os.getenv(
|
|
20
|
-
"LMSTUDIO_ENDPOINT", "http://127.0.0.1:1234/v1/chat/completions"
|
|
21
|
-
)
|
|
21
|
+
self.endpoint = endpoint or os.getenv("LMSTUDIO_ENDPOINT", "http://127.0.0.1:1234/v1/chat/completions")
|
|
22
22
|
self.model = model
|
|
23
|
-
self.options:
|
|
23
|
+
self.options: dict[str, Any] = {}
|
|
24
24
|
|
|
25
25
|
# Validate connection to LM Studio server
|
|
26
26
|
self._validate_connection()
|
|
@@ -38,17 +38,30 @@ class LMStudioDriver(Driver):
|
|
|
38
38
|
except requests.exceptions.RequestException as e:
|
|
39
39
|
logger.warning(f"Could not validate connection to LM Studio server: {e}")
|
|
40
40
|
|
|
41
|
-
|
|
41
|
+
supports_messages = True
|
|
42
|
+
|
|
43
|
+
def generate(self, prompt: str, options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
    """Run a single-prompt completion against the LM Studio server.

    Args:
        prompt: Text sent as the only ``user`` message.
        options: Optional per-call options; ``None`` is forwarded as-is and
            handled by ``_do_generate``.

    Returns:
        The ``_do_generate`` result.
    """
    conversation = [dict(role="user", content=prompt)]
    return self._do_generate(conversation, options)
|
|
46
|
+
|
|
47
|
+
def generate_messages(
    self,
    messages: list[dict[str, str]],
    options: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
    """Run a multi-turn chat completion against the LM Studio server.

    ``options`` now defaults to ``None``, matching ``generate`` and
    ``_do_generate`` (both of which already accept ``None``). Existing
    callers that pass a dict are unaffected.

    Args:
        messages: Chat history in ``{"role", "content"}`` form, forwarded as-is.
        options: Optional per-call options merged by ``_do_generate`` over the
            instance defaults.

    Returns:
        The ``_do_generate`` result (``{"text", "meta"}``).
    """
    return self._do_generate(messages, options)
|
|
49
|
+
|
|
50
|
+
def _do_generate(self, messages: list[dict[str, str]], options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
|
|
42
51
|
merged_options = self.options.copy()
|
|
43
52
|
if options:
|
|
44
53
|
merged_options.update(options)
|
|
45
54
|
|
|
46
55
|
payload = {
|
|
47
56
|
"model": merged_options.get("model", self.model),
|
|
48
|
-
"messages":
|
|
57
|
+
"messages": messages,
|
|
49
58
|
"temperature": merged_options.get("temperature", 0.7),
|
|
50
59
|
}
|
|
51
60
|
|
|
61
|
+
# Native JSON mode support
|
|
62
|
+
if merged_options.get("json_mode"):
|
|
63
|
+
payload["response_format"] = {"type": "json_object"}
|
|
64
|
+
|
|
52
65
|
try:
|
|
53
66
|
logger.debug(f"Sending request to LM Studio endpoint: {self.endpoint}")
|
|
54
67
|
logger.debug(f"Request payload: {payload}")
|
|
@@ -70,7 +83,7 @@ class LMStudioDriver(Driver):
|
|
|
70
83
|
raise
|
|
71
84
|
except Exception as e:
|
|
72
85
|
logger.error(f"Unexpected error in LM Studio request: {e}")
|
|
73
|
-
raise RuntimeError(f"LM Studio request failed: {e}")
|
|
86
|
+
raise RuntimeError(f"LM Studio request failed: {e}") from e
|
|
74
87
|
|
|
75
88
|
# Extract text
|
|
76
89
|
text = response_data["choices"][0]["message"]["content"]
|
|
@@ -1,27 +1,27 @@
|
|
|
1
1
|
import os
|
|
2
|
+
from typing import Any
|
|
3
|
+
|
|
2
4
|
import requests
|
|
5
|
+
|
|
3
6
|
from ..driver import Driver
|
|
4
|
-
from typing import Any, Dict
|
|
5
7
|
|
|
6
8
|
|
|
7
9
|
class LocalHTTPDriver(Driver):
|
|
8
10
|
# Default: no cost; extend if your local service has pricing logic
|
|
9
|
-
MODEL_PRICING = {
|
|
10
|
-
"default": {"prompt": 0.0, "completion": 0.0}
|
|
11
|
-
}
|
|
11
|
+
MODEL_PRICING = {"default": {"prompt": 0.0, "completion": 0.0}}
|
|
12
12
|
|
|
13
13
|
def __init__(self, endpoint: str | None = None, model: str = "local-model"):
|
|
14
14
|
self.endpoint = endpoint or os.getenv("LOCAL_HTTP_ENDPOINT", "http://localhost:8000/generate")
|
|
15
15
|
self.model = model
|
|
16
16
|
|
|
17
|
-
def generate(self, prompt: str, options:
|
|
17
|
+
def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
|
|
18
18
|
payload = {"prompt": prompt, "options": options}
|
|
19
19
|
try:
|
|
20
20
|
r = requests.post(self.endpoint, json=payload, timeout=options.get("timeout", 30))
|
|
21
21
|
r.raise_for_status()
|
|
22
22
|
response_data = r.json()
|
|
23
23
|
except Exception as e:
|
|
24
|
-
raise RuntimeError(f"LocalHTTPDriver request failed: {e}")
|
|
24
|
+
raise RuntimeError(f"LocalHTTPDriver request failed: {e}") from e
|
|
25
25
|
|
|
26
26
|
# If the local API already provides {"text": "...", "meta": {...}}, just return it
|
|
27
27
|
if "text" in response_data and "meta" in response_data:
|
|
@@ -1,38 +1,40 @@
|
|
|
1
|
-
import os
|
|
2
1
|
import json
|
|
3
|
-
import requests
|
|
4
2
|
import logging
|
|
3
|
+
import os
|
|
4
|
+
from collections.abc import Iterator
|
|
5
|
+
from typing import Any, Optional
|
|
6
|
+
|
|
7
|
+
import requests
|
|
8
|
+
|
|
5
9
|
from ..driver import Driver
|
|
6
|
-
from typing import Any, Dict
|
|
7
10
|
|
|
8
11
|
logger = logging.getLogger(__name__)
|
|
9
12
|
|
|
10
13
|
|
|
11
14
|
class OllamaDriver(Driver):
|
|
15
|
+
supports_json_mode = True
|
|
16
|
+
supports_streaming = True
|
|
17
|
+
|
|
12
18
|
# Ollama is free – costs are always zero.
|
|
13
|
-
MODEL_PRICING = {
|
|
14
|
-
"default": {"prompt": 0.0, "completion": 0.0}
|
|
15
|
-
}
|
|
19
|
+
MODEL_PRICING = {"default": {"prompt": 0.0, "completion": 0.0}}
|
|
16
20
|
|
|
17
21
|
def __init__(self, endpoint: str | None = None, model: str = "llama3"):
|
|
18
22
|
# Allow override via env var
|
|
19
|
-
self.endpoint = endpoint or os.getenv(
|
|
20
|
-
"OLLAMA_ENDPOINT", "http://localhost:11434/api/generate"
|
|
21
|
-
)
|
|
23
|
+
self.endpoint = endpoint or os.getenv("OLLAMA_ENDPOINT", "http://localhost:11434/api/generate")
|
|
22
24
|
self.model = model
|
|
23
25
|
self.options = {} # Initialize empty options dict
|
|
24
|
-
|
|
26
|
+
|
|
25
27
|
# Validate connection to Ollama server
|
|
26
28
|
self._validate_connection()
|
|
27
|
-
|
|
29
|
+
|
|
28
30
|
def _validate_connection(self):
|
|
29
31
|
"""Validate connection to the Ollama server."""
|
|
30
32
|
try:
|
|
31
33
|
# Send a simple HEAD request to check if server is accessible
|
|
32
34
|
# Use the base API endpoint without the specific path
|
|
33
|
-
base_url = self.endpoint.split(
|
|
35
|
+
base_url = self.endpoint.split("/api/")[0]
|
|
34
36
|
health_url = f"{base_url}/api/version"
|
|
35
|
-
|
|
37
|
+
|
|
36
38
|
logger.debug(f"Validating connection to Ollama server at: {health_url}")
|
|
37
39
|
response = requests.head(health_url, timeout=5)
|
|
38
40
|
response.raise_for_status()
|
|
@@ -42,7 +44,9 @@ class OllamaDriver(Driver):
|
|
|
42
44
|
# We don't raise an error here to allow for delayed server startup
|
|
43
45
|
# The actual error will be raised when generate() is called
|
|
44
46
|
|
|
45
|
-
|
|
47
|
+
supports_messages = True
|
|
48
|
+
|
|
49
|
+
def generate(self, prompt: str, options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
|
|
46
50
|
# Merge instance options with call-specific options
|
|
47
51
|
merged_options = self.options.copy()
|
|
48
52
|
if options:
|
|
@@ -54,6 +58,10 @@ class OllamaDriver(Driver):
|
|
|
54
58
|
"stream": False,
|
|
55
59
|
}
|
|
56
60
|
|
|
61
|
+
# Native JSON mode support
|
|
62
|
+
if merged_options.get("json_mode"):
|
|
63
|
+
payload["format"] = "json"
|
|
64
|
+
|
|
57
65
|
# Add any Ollama-specific options from merged_options
|
|
58
66
|
if "temperature" in merged_options:
|
|
59
67
|
payload["temperature"] = merged_options["temperature"]
|
|
@@ -65,21 +73,21 @@ class OllamaDriver(Driver):
|
|
|
65
73
|
try:
|
|
66
74
|
logger.debug(f"Sending request to Ollama endpoint: {self.endpoint}")
|
|
67
75
|
logger.debug(f"Request payload: {payload}")
|
|
68
|
-
|
|
76
|
+
|
|
69
77
|
r = requests.post(self.endpoint, json=payload, timeout=120)
|
|
70
78
|
logger.debug(f"Response status code: {r.status_code}")
|
|
71
|
-
|
|
79
|
+
|
|
72
80
|
r.raise_for_status()
|
|
73
|
-
|
|
81
|
+
|
|
74
82
|
response_text = r.text
|
|
75
83
|
logger.debug(f"Raw response text: {response_text}")
|
|
76
|
-
|
|
84
|
+
|
|
77
85
|
response_data = r.json()
|
|
78
86
|
logger.debug(f"Parsed response data: {response_data}")
|
|
79
|
-
|
|
87
|
+
|
|
80
88
|
if not isinstance(response_data, dict):
|
|
81
89
|
raise ValueError(f"Expected dict response, got {type(response_data)}")
|
|
82
|
-
|
|
90
|
+
|
|
83
91
|
except requests.exceptions.ConnectionError as e:
|
|
84
92
|
logger.error(f"Connection error to Ollama endpoint: {e}")
|
|
85
93
|
# Preserve original exception
|
|
@@ -91,11 +99,11 @@ class OllamaDriver(Driver):
|
|
|
91
99
|
except json.JSONDecodeError as e:
|
|
92
100
|
logger.error(f"Failed to decode JSON response: {e}")
|
|
93
101
|
# Re-raise JSONDecodeError with more context
|
|
94
|
-
raise json.JSONDecodeError(f"Invalid JSON response from Ollama: {e.msg}", e.doc, e.pos)
|
|
102
|
+
raise json.JSONDecodeError(f"Invalid JSON response from Ollama: {e.msg}", e.doc, e.pos) from e
|
|
95
103
|
except Exception as e:
|
|
96
104
|
logger.error(f"Unexpected error in Ollama request: {e}")
|
|
97
105
|
# Only wrap unknown exceptions in RuntimeError
|
|
98
|
-
raise RuntimeError(f"Ollama request failed: {e}")
|
|
106
|
+
raise RuntimeError(f"Ollama request failed: {e}") from e
|
|
99
107
|
|
|
100
108
|
# Extract token counts
|
|
101
109
|
prompt_tokens = response_data.get("prompt_eval_count", 0)
|
|
@@ -113,4 +121,130 @@ class OllamaDriver(Driver):
|
|
|
113
121
|
}
|
|
114
122
|
|
|
115
123
|
# Ollama returns text in "response"
|
|
116
|
-
return {"text": response_data.get("response", ""), "meta": meta}
|
|
124
|
+
return {"text": response_data.get("response", ""), "meta": meta}
|
|
125
|
+
|
|
126
|
+
# ------------------------------------------------------------------
|
|
127
|
+
# Streaming
|
|
128
|
+
# ------------------------------------------------------------------
|
|
129
|
+
|
|
130
|
+
def generate_messages_stream(
    self,
    messages: list[dict[str, Any]],
    options: Optional[dict[str, Any]] = None,
) -> Iterator[dict[str, Any]]:
    """Stream a chat completion from Ollama's ``/api/chat`` endpoint.

    Yields ``{"type": "delta", "text": ...}`` for every non-empty content
    fragment, then a single ``{"type": "done", "text": ..., "meta": ...}``
    chunk carrying the accumulated text and token usage.

    Args:
        messages: Chat messages sent verbatim as the request's ``messages``.
        options: Per-call options merged over ``self.options``. Recognised
            keys: ``model``, ``json_mode``, ``temperature``, ``top_p``, ``top_k``.

    Raises:
        requests.exceptions.RequestException: On connection or HTTP errors.
        RuntimeError: If Ollama reports an ``error`` object inside the stream.
    """
    merged_options = self.options.copy()
    if options:
        merged_options.update(options)

    model_name = merged_options.get("model", self.model)
    # Derive the chat endpoint from the configured generate endpoint.
    chat_endpoint = self.endpoint.replace("/api/generate", "/api/chat")

    payload: dict[str, Any] = {
        "model": model_name,
        "messages": messages,
        "stream": True,
    }

    if merged_options.get("json_mode"):
        payload["format"] = "json"

    # Ollama reads sampling parameters from the nested "options" object;
    # top-level "temperature"/"top_p"/"top_k" are not part of the /api/chat
    # request schema and would be silently ignored.
    sampling = {key: merged_options[key] for key in ("temperature", "top_p", "top_k") if key in merged_options}
    if sampling:
        payload["options"] = sampling

    parts: list[str] = []
    prompt_tokens = 0
    completion_tokens = 0

    # Using the response as a context manager releases the streamed
    # connection even if the consumer stops iterating before "done".
    with requests.post(chat_endpoint, json=payload, timeout=120, stream=True) as r:
        r.raise_for_status()
        for line in r.iter_lines():
            if not line:
                continue
            chunk = json.loads(line)
            if "error" in chunk:
                # Failures after the stream has started arrive as an error object.
                raise RuntimeError(f"Ollama stream error: {chunk['error']}")
            content = (chunk.get("message") or {}).get("content", "")
            if content:
                parts.append(content)
                yield {"type": "delta", "text": content}
            if chunk.get("done"):
                # Token counts are only reported on the final chunk.
                prompt_tokens = chunk.get("prompt_eval_count", 0)
                completion_tokens = chunk.get("eval_count", 0)

    yield {
        "type": "done",
        "text": "".join(parts),
        "meta": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
            "cost": 0.0,
            "raw_response": {},
            "model_name": model_name,
        },
    }
|
|
190
|
+
|
|
191
|
+
def generate_messages(
    self,
    messages: list[dict[str, Any]],
    options: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
    """Run a non-streaming multi-turn chat via Ollama's ``/api/chat`` endpoint.

    Args:
        messages: Chat messages sent verbatim as the request's ``messages``.
        options: Per-call options merged over ``self.options``. Recognised
            keys: ``model``, ``json_mode``, ``temperature``, ``top_p``, ``top_k``.

    Returns:
        ``{"text": <assistant content>, "meta": {...}}``. ``meta`` holds token
        counts taken from ``prompt_eval_count`` / ``eval_count``, a zero cost,
        the raw response dict and the model name.

    Raises:
        requests.exceptions.ConnectionError: If the server cannot be reached.
        requests.exceptions.HTTPError: On a non-2xx response.
        json.JSONDecodeError: If the response body is not valid JSON.
        RuntimeError: For any other failure, including a non-dict JSON body.
    """
    merged_options = self.options.copy()
    if options:
        merged_options.update(options)

    model_name = merged_options.get("model", self.model)
    # Derive the chat endpoint from the configured generate endpoint.
    chat_endpoint = self.endpoint.replace("/api/generate", "/api/chat")

    payload: dict[str, Any] = {
        "model": model_name,
        "messages": messages,
        "stream": False,
    }

    # Native JSON mode support
    if merged_options.get("json_mode"):
        payload["format"] = "json"

    # Ollama reads sampling parameters from the nested "options" object;
    # top-level "temperature"/"top_p"/"top_k" are not part of the /api/chat
    # request schema and would be silently ignored.
    sampling = {key: merged_options[key] for key in ("temperature", "top_p", "top_k") if key in merged_options}
    if sampling:
        payload["options"] = sampling

    try:
        logger.debug(f"Sending chat request to Ollama endpoint: {chat_endpoint}")
        r = requests.post(chat_endpoint, json=payload, timeout=120)
        r.raise_for_status()
        response_data = r.json()

        if not isinstance(response_data, dict):
            raise ValueError(f"Expected dict response, got {type(response_data)}")
    except (requests.exceptions.ConnectionError, requests.exceptions.HTTPError):
        # Transport/HTTP errors propagate unchanged so callers can inspect them.
        raise
    except json.JSONDecodeError as e:
        raise json.JSONDecodeError(f"Invalid JSON response from Ollama: {e.msg}", e.doc, e.pos) from e
    except Exception as e:
        raise RuntimeError(f"Ollama chat request failed: {e}") from e

    prompt_tokens = response_data.get("prompt_eval_count", 0)
    completion_tokens = response_data.get("eval_count", 0)

    meta = {
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": prompt_tokens + completion_tokens,
        "cost": 0.0,
        "raw_response": response_data,
        "model_name": model_name,
    }

    # The chat endpoint returns the assistant reply in message.content;
    # tolerate an explicit null "message".
    message = response_data.get("message") or {}
    text = message.get("content", "")
    return {"text": text, "meta": meta}
|