prompture 0.0.29.dev8__py3-none-any.whl → 0.0.38.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- prompture/__init__.py +264 -23
- prompture/_version.py +34 -0
- prompture/agent.py +924 -0
- prompture/agent_types.py +156 -0
- prompture/aio/__init__.py +74 -0
- prompture/async_agent.py +880 -0
- prompture/async_conversation.py +789 -0
- prompture/async_core.py +803 -0
- prompture/async_driver.py +193 -0
- prompture/async_groups.py +551 -0
- prompture/cache.py +469 -0
- prompture/callbacks.py +55 -0
- prompture/cli.py +63 -4
- prompture/conversation.py +826 -0
- prompture/core.py +894 -263
- prompture/cost_mixin.py +51 -0
- prompture/discovery.py +187 -0
- prompture/driver.py +206 -5
- prompture/drivers/__init__.py +175 -67
- prompture/drivers/airllm_driver.py +109 -0
- prompture/drivers/async_airllm_driver.py +26 -0
- prompture/drivers/async_azure_driver.py +123 -0
- prompture/drivers/async_claude_driver.py +113 -0
- prompture/drivers/async_google_driver.py +316 -0
- prompture/drivers/async_grok_driver.py +97 -0
- prompture/drivers/async_groq_driver.py +90 -0
- prompture/drivers/async_hugging_driver.py +61 -0
- prompture/drivers/async_lmstudio_driver.py +148 -0
- prompture/drivers/async_local_http_driver.py +44 -0
- prompture/drivers/async_ollama_driver.py +135 -0
- prompture/drivers/async_openai_driver.py +102 -0
- prompture/drivers/async_openrouter_driver.py +102 -0
- prompture/drivers/async_registry.py +133 -0
- prompture/drivers/azure_driver.py +42 -9
- prompture/drivers/claude_driver.py +257 -34
- prompture/drivers/google_driver.py +295 -42
- prompture/drivers/grok_driver.py +35 -32
- prompture/drivers/groq_driver.py +33 -26
- prompture/drivers/hugging_driver.py +6 -6
- prompture/drivers/lmstudio_driver.py +97 -19
- prompture/drivers/local_http_driver.py +6 -6
- prompture/drivers/ollama_driver.py +168 -23
- prompture/drivers/openai_driver.py +184 -9
- prompture/drivers/openrouter_driver.py +37 -25
- prompture/drivers/registry.py +306 -0
- prompture/drivers/vision_helpers.py +153 -0
- prompture/field_definitions.py +106 -96
- prompture/group_types.py +147 -0
- prompture/groups.py +530 -0
- prompture/image.py +180 -0
- prompture/logging.py +80 -0
- prompture/model_rates.py +217 -0
- prompture/persistence.py +254 -0
- prompture/persona.py +482 -0
- prompture/runner.py +49 -47
- prompture/scaffold/__init__.py +1 -0
- prompture/scaffold/generator.py +84 -0
- prompture/scaffold/templates/Dockerfile.j2 +12 -0
- prompture/scaffold/templates/README.md.j2 +41 -0
- prompture/scaffold/templates/config.py.j2 +21 -0
- prompture/scaffold/templates/env.example.j2 +8 -0
- prompture/scaffold/templates/main.py.j2 +86 -0
- prompture/scaffold/templates/models.py.j2 +40 -0
- prompture/scaffold/templates/requirements.txt.j2 +5 -0
- prompture/serialization.py +218 -0
- prompture/server.py +183 -0
- prompture/session.py +117 -0
- prompture/settings.py +19 -1
- prompture/tools.py +219 -267
- prompture/tools_schema.py +254 -0
- prompture/validator.py +3 -3
- prompture-0.0.38.dev2.dist-info/METADATA +369 -0
- prompture-0.0.38.dev2.dist-info/RECORD +77 -0
- {prompture-0.0.29.dev8.dist-info → prompture-0.0.38.dev2.dist-info}/WHEEL +1 -1
- prompture-0.0.29.dev8.dist-info/METADATA +0 -368
- prompture-0.0.29.dev8.dist-info/RECORD +0 -27
- {prompture-0.0.29.dev8.dist-info → prompture-0.0.38.dev2.dist-info}/entry_points.txt +0 -0
- {prompture-0.0.29.dev8.dist-info → prompture-0.0.38.dev2.dist-info}/licenses/LICENSE +0 -0
- {prompture-0.0.29.dev8.dist-info → prompture-0.0.38.dev2.dist-info}/top_level.txt +0 -0
prompture/drivers/hugging_driver.py

@@ -1,14 +1,14 @@
 import os
+from typing import Any
+
 import requests
+
 from ..driver import Driver
-from typing import Any, Dict
 
 
 class HuggingFaceDriver(Driver):
     # Hugging Face is usage-based (credits/subscription), but we set costs to 0 for now.
-    MODEL_PRICING = {
-        "default": {"prompt": 0.0, "completion": 0.0}
-    }
+    MODEL_PRICING = {"default": {"prompt": 0.0, "completion": 0.0}}
 
     def __init__(self, endpoint: str | None = None, token: str | None = None, model: str = "bert-base-uncased"):
         self.endpoint = endpoint or os.getenv("HF_ENDPOINT")
@@ -22,7 +22,7 @@ class HuggingFaceDriver(Driver):
 
         self.headers = {"Authorization": f"Bearer {self.token}"}
 
-    def generate(self, prompt: str, options:
+    def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
         payload = {
             "inputs": prompt,
             "parameters": options, # HF allows temperature, max_new_tokens, etc. here
@@ -32,7 +32,7 @@ class HuggingFaceDriver(Driver):
             r.raise_for_status()
             response_data = r.json()
         except Exception as e:
-            raise RuntimeError(f"HuggingFaceDriver request failed: {e}")
+            raise RuntimeError(f"HuggingFaceDriver request failed: {e}") from e
 
         # Different HF models return slightly different response formats
         # Text-generation models usually return [{"generated_text": "..."}]
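Beyond the typing cleanup (built-in `dict` instead of `typing.Dict`, single-line `MODEL_PRICING`), the substantive change here is exception chaining: `raise RuntimeError(...) from e` keeps the underlying `requests` error attached as `__cause__` instead of discarding it. A small self-contained sketch of what that buys (illustrative only, not code from the package):

import json

def parse_config(raw: str) -> dict:
    try:
        return json.loads(raw)
    except ValueError as e:
        # "from e" keeps the original JSONDecodeError attached as __cause__,
        # so the traceback shows both the wrapper and the real parse failure.
        raise RuntimeError(f"config parse failed: {e}") from e

try:
    parse_config("{not valid json")
except RuntimeError as err:
    print(type(err.__cause__).__name__)  # prints: JSONDecodeError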
prompture/drivers/lmstudio_driver.py

@@ -1,59 +1,109 @@
-import os
 import json
-import requests
 import logging
+import os
+from typing import Any, Optional
+
+import requests
+
 from ..driver import Driver
-from typing import Any, Dict
 
 logger = logging.getLogger(__name__)
 
 
 class LMStudioDriver(Driver):
-
-
-
-    }
+    supports_json_mode = True
+    supports_json_schema = True
+    supports_vision = True
 
-
+    # LM Studio is local – costs are always zero.
+    MODEL_PRICING = {"default": {"prompt": 0.0, "completion": 0.0}}
+
+    def __init__(
+        self,
+        endpoint: str | None = None,
+        model: str = "deepseek/deepseek-r1-0528-qwen3-8b",
+        api_key: str | None = None,
+    ):
         # Allow override via env var
-        self.endpoint = endpoint or os.getenv(
-            "LMSTUDIO_ENDPOINT", "http://127.0.0.1:1234/v1/chat/completions"
-        )
+        self.endpoint = endpoint or os.getenv("LMSTUDIO_ENDPOINT", "http://127.0.0.1:1234/v1/chat/completions")
         self.model = model
-        self.options:
+        self.options: dict[str, Any] = {}
+
+        # Derive base_url once for reuse across management endpoints
+        self.base_url = self.endpoint.split("/v1/")[0]
+
+        # API key for LM Studio 0.4.0+ authentication
+        self.api_key = api_key or os.getenv("LMSTUDIO_API_KEY")
+        self._headers = self._build_headers()
 
         # Validate connection to LM Studio server
         self._validate_connection()
 
+    def _build_headers(self) -> dict[str, str]:
+        """Build request headers, including auth if an API key is configured."""
+        headers: dict[str, str] = {"Content-Type": "application/json"}
+        if self.api_key:
+            headers["Authorization"] = f"Bearer {self.api_key}"
+        return headers
+
     def _validate_connection(self):
         """Validate connection to the LM Studio server."""
         try:
-
-            health_url = f"{base_url}/v1/models"
+            health_url = f"{self.base_url}/v1/models"
 
             logger.debug(f"Validating connection to LM Studio server at: {health_url}")
-            response = requests.get(health_url, timeout=5)
+            response = requests.get(health_url, headers=self._headers, timeout=5)
             response.raise_for_status()
             logger.debug("Connection to LM Studio server validated successfully")
         except requests.exceptions.RequestException as e:
             logger.warning(f"Could not validate connection to LM Studio server: {e}")
 
-
+    supports_messages = True
+
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_openai_vision_messages
+
+        return _prepare_openai_vision_messages(messages)
+
+    def generate(self, prompt: str, options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
+        messages = [{"role": "user", "content": prompt}]
+        return self._do_generate(messages, options)
+
+    def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        return self._do_generate(self._prepare_messages(messages), options)
+
+    def _do_generate(self, messages: list[dict[str, str]], options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
         merged_options = self.options.copy()
         if options:
             merged_options.update(options)
 
         payload = {
             "model": merged_options.get("model", self.model),
-            "messages":
+            "messages": messages,
             "temperature": merged_options.get("temperature", 0.7),
         }
 
+        # Native JSON mode support (LM Studio requires json_schema, not json_object)
+        if merged_options.get("json_mode"):
+            json_schema = merged_options.get("json_schema")
+            if json_schema:
+                payload["response_format"] = {
+                    "type": "json_schema",
+                    "json_schema": {
+                        "name": "extraction",
+                        "schema": json_schema,
+                    },
+                }
+            else:
+                # No schema provided — omit response_format entirely;
+                # LM Studio rejects "json_object" type.
+                pass
+
         try:
             logger.debug(f"Sending request to LM Studio endpoint: {self.endpoint}")
             logger.debug(f"Request payload: {payload}")
 
-            r = requests.post(self.endpoint, json=payload, timeout=120)
+            r = requests.post(self.endpoint, json=payload, headers=self._headers, timeout=120)
             r.raise_for_status()
 
             response_data = r.json()
@@ -70,7 +120,7 @@ class LMStudioDriver(Driver):
             raise
         except Exception as e:
             logger.error(f"Unexpected error in LM Studio request: {e}")
-            raise RuntimeError(f"LM Studio request failed: {e}")
+            raise RuntimeError(f"LM Studio request failed: {e}") from e
 
         # Extract text
         text = response_data["choices"][0]["message"]["content"]
@@ -91,3 +141,31 @@ class LMStudioDriver(Driver):
         }
 
         return {"text": text, "meta": meta}
+
+    # -- Model management (LM Studio 0.4.0+) ----------------------------------
+
+    def list_models(self) -> list[dict[str, Any]]:
+        """List currently loaded models via GET /v1/models (OpenAI-compatible)."""
+        url = f"{self.base_url}/v1/models"
+        r = requests.get(url, headers=self._headers, timeout=10)
+        r.raise_for_status()
+        data = r.json()
+        return data.get("data", [])
+
+    def load_model(self, model: str, context_length: int | None = None) -> dict[str, Any]:
+        """Load a model into LM Studio via POST /api/v1/models/load."""
+        url = f"{self.base_url}/api/v1/models/load"
+        payload: dict[str, Any] = {"model": model}
+        if context_length is not None:
+            payload["context_length"] = context_length
+        r = requests.post(url, json=payload, headers=self._headers, timeout=120)
+        r.raise_for_status()
+        return r.json()
+
+    def unload_model(self, model: str) -> dict[str, Any]:
+        """Unload a model from LM Studio via POST /api/v1/models/unload."""
+        url = f"{self.base_url}/api/v1/models/unload"
+        payload = {"instance_id": model}
+        r = requests.post(url, json=payload, headers=self._headers, timeout=30)
+        r.raise_for_status()
+        return r.json()
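The rewritten LM Studio driver adds schema-constrained JSON output, optional bearer-token auth for LM Studio 0.4.0+, and model-management helpers. A rough usage sketch against the new surface (class, constructor, and method names are taken from the diff above; the schema, prompt, and API key are placeholders, and a reachable LM Studio server is assumed):

from prompture.drivers.lmstudio_driver import LMStudioDriver

# api_key is optional; the Authorization header is only sent when a key is set
driver = LMStudioDriver(api_key="example-key")

# Placeholder schema for illustration
person_schema = {
    "type": "object",
    "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
    "required": ["name", "age"],
}

# json_mode + json_schema is translated into a "json_schema" response_format,
# which is what LM Studio expects instead of "json_object".
result = driver.generate(
    "Extract the person from: 'Ada Lovelace, 36 years old'",
    {"json_mode": True, "json_schema": person_schema},
)
print(result["text"])

# Model-management endpoints added for LM Studio 0.4.0+
print([m.get("id") for m in driver.list_models()])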
prompture/drivers/local_http_driver.py

@@ -1,27 +1,27 @@
 import os
+from typing import Any
+
 import requests
+
 from ..driver import Driver
-from typing import Any, Dict
 
 
 class LocalHTTPDriver(Driver):
     # Default: no cost; extend if your local service has pricing logic
-    MODEL_PRICING = {
-        "default": {"prompt": 0.0, "completion": 0.0}
-    }
+    MODEL_PRICING = {"default": {"prompt": 0.0, "completion": 0.0}}
 
     def __init__(self, endpoint: str | None = None, model: str = "local-model"):
         self.endpoint = endpoint or os.getenv("LOCAL_HTTP_ENDPOINT", "http://localhost:8000/generate")
         self.model = model
 
-    def generate(self, prompt: str, options:
+    def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
         payload = {"prompt": prompt, "options": options}
         try:
             r = requests.post(self.endpoint, json=payload, timeout=options.get("timeout", 30))
             r.raise_for_status()
             response_data = r.json()
         except Exception as e:
-            raise RuntimeError(f"LocalHTTPDriver request failed: {e}")
+            raise RuntimeError(f"LocalHTTPDriver request failed: {e}") from e
 
         # If the local API already provides {"text": "...", "meta": {...}}, just return it
         if "text" in response_data and "meta" in response_data:
prompture/drivers/ollama_driver.py

@@ -1,38 +1,42 @@
-import os
 import json
-import requests
 import logging
+import os
+from collections.abc import Iterator
+from typing import Any, Optional
+
+import requests
+
 from ..driver import Driver
-from typing import Any, Dict
 
 logger = logging.getLogger(__name__)
 
 
 class OllamaDriver(Driver):
+    supports_json_mode = True
+    supports_json_schema = True
+    supports_streaming = True
+    supports_vision = True
+
     # Ollama is free – costs are always zero.
-    MODEL_PRICING = {
-        "default": {"prompt": 0.0, "completion": 0.0}
-    }
+    MODEL_PRICING = {"default": {"prompt": 0.0, "completion": 0.0}}
 
     def __init__(self, endpoint: str | None = None, model: str = "llama3"):
         # Allow override via env var
-        self.endpoint = endpoint or os.getenv(
-            "OLLAMA_ENDPOINT", "http://localhost:11434/api/generate"
-        )
+        self.endpoint = endpoint or os.getenv("OLLAMA_ENDPOINT", "http://localhost:11434/api/generate")
         self.model = model
         self.options = {} # Initialize empty options dict
-
+
         # Validate connection to Ollama server
         self._validate_connection()
-
+
     def _validate_connection(self):
         """Validate connection to the Ollama server."""
         try:
             # Send a simple HEAD request to check if server is accessible
             # Use the base API endpoint without the specific path
-            base_url = self.endpoint.split(
+            base_url = self.endpoint.split("/api/")[0]
             health_url = f"{base_url}/api/version"
-
+
             logger.debug(f"Validating connection to Ollama server at: {health_url}")
             response = requests.head(health_url, timeout=5)
             response.raise_for_status()
@@ -42,7 +46,14 @@ class OllamaDriver(Driver):
            # We don't raise an error here to allow for delayed server startup
            # The actual error will be raised when generate() is called
 
-
+    supports_messages = True
+
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_ollama_vision_messages
+
+        return _prepare_ollama_vision_messages(messages)
+
+    def generate(self, prompt: str, options: Optional[dict[str, Any]] = None) -> dict[str, Any]:
         # Merge instance options with call-specific options
         merged_options = self.options.copy()
         if options:
@@ -54,6 +65,11 @@ class OllamaDriver(Driver):
             "stream": False,
         }
 
+        # Native JSON mode / structured output support
+        if merged_options.get("json_mode"):
+            json_schema = merged_options.get("json_schema")
+            payload["format"] = json_schema if json_schema else "json"
+
         # Add any Ollama-specific options from merged_options
         if "temperature" in merged_options:
             payload["temperature"] = merged_options["temperature"]
@@ -65,21 +81,21 @@ class OllamaDriver(Driver):
         try:
             logger.debug(f"Sending request to Ollama endpoint: {self.endpoint}")
             logger.debug(f"Request payload: {payload}")
-
+
             r = requests.post(self.endpoint, json=payload, timeout=120)
             logger.debug(f"Response status code: {r.status_code}")
-
+
             r.raise_for_status()
-
+
             response_text = r.text
             logger.debug(f"Raw response text: {response_text}")
-
+
             response_data = r.json()
             logger.debug(f"Parsed response data: {response_data}")
-
+
             if not isinstance(response_data, dict):
                 raise ValueError(f"Expected dict response, got {type(response_data)}")
-
+
         except requests.exceptions.ConnectionError as e:
             logger.error(f"Connection error to Ollama endpoint: {e}")
             # Preserve original exception
@@ -91,11 +107,11 @@ class OllamaDriver(Driver):
         except json.JSONDecodeError as e:
             logger.error(f"Failed to decode JSON response: {e}")
             # Re-raise JSONDecodeError with more context
-            raise json.JSONDecodeError(f"Invalid JSON response from Ollama: {e.msg}", e.doc, e.pos)
+            raise json.JSONDecodeError(f"Invalid JSON response from Ollama: {e.msg}", e.doc, e.pos) from e
         except Exception as e:
             logger.error(f"Unexpected error in Ollama request: {e}")
             # Only wrap unknown exceptions in RuntimeError
-            raise RuntimeError(f"Ollama request failed: {e}")
+            raise RuntimeError(f"Ollama request failed: {e}") from e
 
         # Extract token counts
         prompt_tokens = response_data.get("prompt_eval_count", 0)
@@ -113,4 +129,133 @@ class OllamaDriver(Driver):
         }
 
         # Ollama returns text in "response"
-        return {"text": response_data.get("response", ""), "meta": meta}
+        return {"text": response_data.get("response", ""), "meta": meta}
+
+    # ------------------------------------------------------------------
+    # Streaming
+    # ------------------------------------------------------------------
+
+    def generate_messages_stream(
+        self,
+        messages: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> Iterator[dict[str, Any]]:
+        """Yield response chunks via Ollama streaming API."""
+        merged_options = self.options.copy()
+        if options:
+            merged_options.update(options)
+
+        chat_endpoint = self.endpoint.replace("/api/generate", "/api/chat")
+
+        payload: dict[str, Any] = {
+            "model": merged_options.get("model", self.model),
+            "messages": messages,
+            "stream": True,
+        }
+
+        if merged_options.get("json_mode"):
+            json_schema = merged_options.get("json_schema")
+            payload["format"] = json_schema if json_schema else "json"
+        if "temperature" in merged_options:
+            payload["temperature"] = merged_options["temperature"]
+        if "top_p" in merged_options:
+            payload["top_p"] = merged_options["top_p"]
+        if "top_k" in merged_options:
+            payload["top_k"] = merged_options["top_k"]
+
+        full_text = ""
+        prompt_tokens = 0
+        completion_tokens = 0
+
+        r = requests.post(chat_endpoint, json=payload, timeout=120, stream=True)
+        r.raise_for_status()
+
+        for line in r.iter_lines():
+            if not line:
+                continue
+            chunk = json.loads(line)
+            if chunk.get("done"):
+                prompt_tokens = chunk.get("prompt_eval_count", 0)
+                completion_tokens = chunk.get("eval_count", 0)
+            else:
+                content = chunk.get("message", {}).get("content", "")
+                if content:
+                    full_text += content
+                    yield {"type": "delta", "text": content}
+
+        total_tokens = prompt_tokens + completion_tokens
+        yield {
+            "type": "done",
+            "text": full_text,
+            "meta": {
+                "prompt_tokens": prompt_tokens,
+                "completion_tokens": completion_tokens,
+                "total_tokens": total_tokens,
+                "cost": 0.0,
+                "raw_response": {},
+                "model_name": merged_options.get("model", self.model),
+            },
+        }
+
+    def generate_messages(self, messages: list[dict[str, Any]], options: dict[str, Any]) -> dict[str, Any]:
+        """Use Ollama's /api/chat endpoint for multi-turn conversations."""
+        messages = self._prepare_messages(messages)
+        merged_options = self.options.copy()
+        if options:
+            merged_options.update(options)
+
+        # Derive the chat endpoint from the generate endpoint
+        chat_endpoint = self.endpoint.replace("/api/generate", "/api/chat")
+
+        payload: dict[str, Any] = {
+            "model": merged_options.get("model", self.model),
+            "messages": messages,
+            "stream": False,
+        }
+
+        # Native JSON mode / structured output support
+        if merged_options.get("json_mode"):
+            json_schema = merged_options.get("json_schema")
+            payload["format"] = json_schema if json_schema else "json"
+
+        if "temperature" in merged_options:
+            payload["temperature"] = merged_options["temperature"]
+        if "top_p" in merged_options:
+            payload["top_p"] = merged_options["top_p"]
+        if "top_k" in merged_options:
+            payload["top_k"] = merged_options["top_k"]
+
+        try:
+            logger.debug(f"Sending chat request to Ollama endpoint: {chat_endpoint}")
+            r = requests.post(chat_endpoint, json=payload, timeout=120)
+            r.raise_for_status()
+            response_data = r.json()
+
+            if not isinstance(response_data, dict):
+                raise ValueError(f"Expected dict response, got {type(response_data)}")
+        except requests.exceptions.ConnectionError:
+            raise
+        except requests.exceptions.HTTPError:
+            raise
+        except json.JSONDecodeError as e:
+            raise json.JSONDecodeError(f"Invalid JSON response from Ollama: {e.msg}", e.doc, e.pos) from e
+        except Exception as e:
+            raise RuntimeError(f"Ollama chat request failed: {e}") from e
+
+        prompt_tokens = response_data.get("prompt_eval_count", 0)
+        completion_tokens = response_data.get("eval_count", 0)
+        total_tokens = prompt_tokens + completion_tokens
+
+        meta = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": 0.0,
+            "raw_response": response_data,
+            "model_name": merged_options.get("model", self.model),
+        }
+
+        # Chat endpoint returns response in message.content
+        message = response_data.get("message", {})
+        text = message.get("content", "")
+        return {"text": text, "meta": meta}
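The new streaming generator yields incremental "delta" chunks followed by a final "done" chunk carrying aggregate token counts. A rough consumption sketch (the method name and chunk shapes come from the diff above; the model name and prompt are placeholders, and a running local Ollama server is assumed):

from prompture.drivers.ollama_driver import OllamaDriver

driver = OllamaDriver(model="llama3")
messages = [{"role": "user", "content": "Summarize the Prompture project in one sentence."}]

for chunk in driver.generate_messages_stream(messages, {"temperature": 0.2}):
    if chunk["type"] == "delta":
        # Incremental text as it arrives from /api/chat
        print(chunk["text"], end="", flush=True)
    elif chunk["type"] == "done":
        print()  # newline after the streamed text
        print("tokens:", chunk["meta"]["total_tokens"])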