prompture 0.0.33.dev1__py3-none-any.whl → 0.0.34__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- prompture/__init__.py +133 -49
- prompture/_version.py +34 -0
- prompture/aio/__init__.py +74 -0
- prompture/async_conversation.py +484 -0
- prompture/async_core.py +803 -0
- prompture/async_driver.py +131 -0
- prompture/cache.py +469 -0
- prompture/callbacks.py +50 -0
- prompture/cli.py +7 -3
- prompture/conversation.py +504 -0
- prompture/core.py +475 -352
- prompture/cost_mixin.py +51 -0
- prompture/discovery.py +50 -35
- prompture/driver.py +125 -5
- prompture/drivers/__init__.py +171 -73
- prompture/drivers/airllm_driver.py +13 -20
- prompture/drivers/async_airllm_driver.py +26 -0
- prompture/drivers/async_azure_driver.py +117 -0
- prompture/drivers/async_claude_driver.py +107 -0
- prompture/drivers/async_google_driver.py +132 -0
- prompture/drivers/async_grok_driver.py +91 -0
- prompture/drivers/async_groq_driver.py +84 -0
- prompture/drivers/async_hugging_driver.py +61 -0
- prompture/drivers/async_lmstudio_driver.py +79 -0
- prompture/drivers/async_local_http_driver.py +44 -0
- prompture/drivers/async_ollama_driver.py +125 -0
- prompture/drivers/async_openai_driver.py +96 -0
- prompture/drivers/async_openrouter_driver.py +96 -0
- prompture/drivers/async_registry.py +129 -0
- prompture/drivers/azure_driver.py +36 -9
- prompture/drivers/claude_driver.py +86 -34
- prompture/drivers/google_driver.py +87 -51
- prompture/drivers/grok_driver.py +29 -32
- prompture/drivers/groq_driver.py +27 -26
- prompture/drivers/hugging_driver.py +6 -6
- prompture/drivers/lmstudio_driver.py +26 -13
- prompture/drivers/local_http_driver.py +6 -6
- prompture/drivers/ollama_driver.py +90 -23
- prompture/drivers/openai_driver.py +36 -9
- prompture/drivers/openrouter_driver.py +31 -25
- prompture/drivers/registry.py +306 -0
- prompture/field_definitions.py +106 -96
- prompture/logging.py +80 -0
- prompture/model_rates.py +217 -0
- prompture/runner.py +49 -47
- prompture/session.py +117 -0
- prompture/settings.py +14 -1
- prompture/tools.py +172 -265
- prompture/validator.py +3 -3
- {prompture-0.0.33.dev1.dist-info → prompture-0.0.34.dist-info}/METADATA +18 -20
- prompture-0.0.34.dist-info/RECORD +55 -0
- prompture-0.0.33.dev1.dist-info/RECORD +0 -29
- {prompture-0.0.33.dev1.dist-info → prompture-0.0.34.dist-info}/WHEEL +0 -0
- {prompture-0.0.33.dev1.dist-info → prompture-0.0.34.dist-info}/entry_points.txt +0 -0
- {prompture-0.0.33.dev1.dist-info → prompture-0.0.34.dist-info}/licenses/LICENSE +0 -0
- {prompture-0.0.33.dev1.dist-info → prompture-0.0.34.dist-info}/top_level.txt +0 -0
prompture/async_driver.py
ADDED

```diff
@@ -0,0 +1,131 @@
+"""Async driver base class for LLM adapters."""
+
+from __future__ import annotations
+
+import logging
+import time
+from typing import Any
+
+from .callbacks import DriverCallbacks
+from .driver import Driver
+
+logger = logging.getLogger("prompture.async_driver")
+
+
+class AsyncDriver:
+    """Async adapter base. Implement ``async generate(prompt, options)``
+    returning ``{"text": ..., "meta": {...}}``.
+
+    The ``meta`` dict follows the same contract as :class:`Driver`:
+
+    .. code-block:: python
+
+        {
+            "prompt_tokens": int,
+            "completion_tokens": int,
+            "total_tokens": int,
+            "cost": float,
+            "raw_response": dict,
+        }
+    """
+
+    supports_json_mode: bool = False
+    supports_json_schema: bool = False
+    supports_messages: bool = False
+
+    callbacks: DriverCallbacks | None = None
+
+    async def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
+        raise NotImplementedError
+
+    async def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        """Generate a response from a list of conversation messages (async).
+
+        Default implementation flattens the messages into a single prompt
+        and delegates to :meth:`generate`. Drivers that natively support
+        message arrays should override this and set
+        ``supports_messages = True``.
+        """
+        prompt = Driver._flatten_messages(messages)
+        return await self.generate(prompt, options)
+
+    # ------------------------------------------------------------------
+    # Hook-aware wrappers
+    # ------------------------------------------------------------------
+
+    async def generate_with_hooks(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
+        """Wrap :meth:`generate` with on_request / on_response / on_error callbacks."""
+        driver_name = getattr(self, "model", self.__class__.__name__)
+        self._fire_callback(
+            "on_request",
+            {"prompt": prompt, "messages": None, "options": options, "driver": driver_name},
+        )
+        t0 = time.perf_counter()
+        try:
+            resp = await self.generate(prompt, options)
+        except Exception as exc:
+            self._fire_callback(
+                "on_error",
+                {"error": exc, "prompt": prompt, "messages": None, "options": options, "driver": driver_name},
+            )
+            raise
+        elapsed_ms = (time.perf_counter() - t0) * 1000
+        self._fire_callback(
+            "on_response",
+            {
+                "text": resp.get("text", ""),
+                "meta": resp.get("meta", {}),
+                "driver": driver_name,
+                "elapsed_ms": elapsed_ms,
+            },
+        )
+        return resp
+
+    async def generate_messages_with_hooks(
+        self, messages: list[dict[str, str]], options: dict[str, Any]
+    ) -> dict[str, Any]:
+        """Wrap :meth:`generate_messages` with callbacks."""
+        driver_name = getattr(self, "model", self.__class__.__name__)
+        self._fire_callback(
+            "on_request",
+            {"prompt": None, "messages": messages, "options": options, "driver": driver_name},
+        )
+        t0 = time.perf_counter()
+        try:
+            resp = await self.generate_messages(messages, options)
+        except Exception as exc:
+            self._fire_callback(
+                "on_error",
+                {"error": exc, "prompt": None, "messages": messages, "options": options, "driver": driver_name},
+            )
+            raise
+        elapsed_ms = (time.perf_counter() - t0) * 1000
+        self._fire_callback(
+            "on_response",
+            {
+                "text": resp.get("text", ""),
+                "meta": resp.get("meta", {}),
+                "driver": driver_name,
+                "elapsed_ms": elapsed_ms,
+            },
+        )
+        return resp
+
+    # ------------------------------------------------------------------
+    # Internal helpers
+    # ------------------------------------------------------------------
+
+    def _fire_callback(self, event: str, payload: dict[str, Any]) -> None:
+        """Invoke a single callback, swallowing and logging any exception."""
+        if self.callbacks is None:
+            return
+        cb = getattr(self.callbacks, event, None)
+        if cb is None:
+            return
+        try:
+            cb(payload)
+        except Exception:
+            logger.exception("Callback %s raised an exception", event)
+
+    # Re-export the static helper for convenience
+    _flatten_messages = staticmethod(Driver._flatten_messages)
```
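To make the base-class contract concrete, here is a minimal sketch of a custom adapter. `EchoDriver` and its canned token counts are hypothetical, not part of prompture; only `AsyncDriver`, the `generate` signature, and the `{"text": ..., "meta": {...}}` return shape come from the file above.

```python
# Hypothetical EchoDriver: satisfies the documented contract of
# AsyncDriver.generate() returning {"text": ..., "meta": {...}}.
import asyncio
from typing import Any

from prompture.async_driver import AsyncDriver


class EchoDriver(AsyncDriver):
    model = "echo-1"  # picked up by generate_with_hooks as the driver name

    async def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
        n = len(prompt.split())  # stand-in for a real token count
        return {
            "text": prompt.upper(),
            "meta": {
                "prompt_tokens": n,
                "completion_tokens": n,
                "total_tokens": 2 * n,
                "cost": 0.0,
                "raw_response": {},
            },
        }


async def main() -> None:
    driver = EchoDriver()
    resp = await driver.generate_with_hooks("hello world", {})
    print(resp["text"])  # -> HELLO WORLD


asyncio.run(main())
```

Because `callbacks` defaults to `None`, `generate_with_hooks` runs here without firing anything; attaching a `DriverCallbacks` instance (see `prompture/callbacks.py` below) activates the hooks.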
prompture/cache.py
ADDED
```diff
@@ -0,0 +1,469 @@
+"""Response caching layer for prompture.
+
+Provides pluggable cache backends (memory, SQLite, Redis) so repeated
+identical LLM calls can be served from cache. Disabled by default.
+"""
+
+from __future__ import annotations
+
+import hashlib
+import json
+import sqlite3
+import threading
+import time
+from abc import ABC, abstractmethod
+from collections import OrderedDict
+from pathlib import Path
+from typing import Any
+
+# ---------------------------------------------------------------------------
+# Cache key generation
+# ---------------------------------------------------------------------------
+
+_CACHE_RELEVANT_OPTIONS = frozenset(
+    {
+        "temperature",
+        "max_tokens",
+        "top_p",
+        "top_k",
+        "frequency_penalty",
+        "presence_penalty",
+        "stop",
+        "seed",
+        "json_mode",
+    }
+)
+
+
+def make_cache_key(
+    prompt: str,
+    model_name: str,
+    schema: dict[str, Any] | None = None,
+    options: dict[str, Any] | None = None,
+    output_format: str = "json",
+    pydantic_qualname: str | None = None,
+) -> str:
+    """Return a deterministic SHA-256 hex key for the given call parameters.
+
+    Only cache-relevant options (temperature, max_tokens, etc.) are included
+    so that unrelated option changes don't bust the cache.
+    """
+    filtered_opts: dict[str, Any] = {}
+    if options:
+        filtered_opts = {k: v for k, v in sorted(options.items()) if k in _CACHE_RELEVANT_OPTIONS}
+
+    parts: dict[str, Any] = {
+        "prompt": prompt,
+        "model_name": model_name,
+        "schema": schema,
+        "options": filtered_opts,
+        "output_format": output_format,
+    }
+    if pydantic_qualname is not None:
+        parts["pydantic_qualname"] = pydantic_qualname
+
+    blob = json.dumps(parts, sort_keys=True, default=str)
+    return hashlib.sha256(blob.encode()).hexdigest()
+
+
+# ---------------------------------------------------------------------------
+# Backend ABC
+# ---------------------------------------------------------------------------
+
+
+class CacheBackend(ABC):
+    """Abstract base class for cache storage backends."""
+
+    @abstractmethod
+    def get(self, key: str) -> Any | None:
+        """Return the cached value or ``None`` on miss."""
+
+    @abstractmethod
+    def set(self, key: str, value: Any, ttl: int | None = None) -> None:
+        """Store *value* under *key* with optional TTL in seconds."""
+
+    @abstractmethod
+    def delete(self, key: str) -> None:
+        """Remove a single key."""
+
+    @abstractmethod
+    def clear(self) -> None:
+        """Remove all entries."""
+
+    @abstractmethod
+    def has(self, key: str) -> bool:
+        """Return whether *key* exists and is not expired."""
+
+
+# ---------------------------------------------------------------------------
+# Memory backend
+# ---------------------------------------------------------------------------
+
+
+class MemoryCacheBackend(CacheBackend):
+    """In-process LRU cache backed by an ``OrderedDict``.
+
+    Parameters
+    ----------
+    maxsize:
+        Maximum number of entries before the least-recently-used item is
+        evicted. Defaults to 256.
+    """
+
+    def __init__(self, maxsize: int = 256) -> None:
+        self._maxsize = maxsize
+        self._data: OrderedDict[str, tuple[Any, float | None]] = OrderedDict()
+        self._lock = threading.Lock()
+
+    # -- helpers --
+    def _is_expired(self, entry: tuple[Any, float | None]) -> bool:
+        _value, expires_at = entry
+        if expires_at is None:
+            return False
+        return time.time() > expires_at
+
+    # -- public API --
+    def get(self, key: str) -> Any | None:
+        with self._lock:
+            entry = self._data.get(key)
+            if entry is None:
+                return None
+            if self._is_expired(entry):
+                del self._data[key]
+                return None
+            # Move to end (most-recently used)
+            self._data.move_to_end(key)
+            return entry[0]
+
+    def set(self, key: str, value: Any, ttl: int | None = None) -> None:
+        expires_at = (time.time() + ttl) if ttl else None
+        with self._lock:
+            if key in self._data:
+                self._data.move_to_end(key)
+            self._data[key] = (value, expires_at)
+            # Evict LRU entries
+            while len(self._data) > self._maxsize:
+                self._data.popitem(last=False)
+
+    def delete(self, key: str) -> None:
+        with self._lock:
+            self._data.pop(key, None)
+
+    def clear(self) -> None:
+        with self._lock:
+            self._data.clear()
+
+    def has(self, key: str) -> bool:
+        with self._lock:
+            entry = self._data.get(key)
+            if entry is None:
+                return False
+            if self._is_expired(entry):
+                del self._data[key]
+                return False
+            return True
+
+
+# ---------------------------------------------------------------------------
+# SQLite backend
+# ---------------------------------------------------------------------------
+
+_DEFAULT_SQLITE_PATH = Path.home() / ".prompture" / "cache" / "response_cache.db"
+
+
+class SQLiteCacheBackend(CacheBackend):
+    """Persistent cache using a local SQLite database.
+
+    Parameters
+    ----------
+    db_path:
+        Path to the SQLite file. Defaults to
+        ``~/.prompture/cache/response_cache.db``.
+    """
+
+    def __init__(self, db_path: str | None = None) -> None:
+        self._db_path = Path(db_path) if db_path else _DEFAULT_SQLITE_PATH
+        self._db_path.parent.mkdir(parents=True, exist_ok=True)
+        self._lock = threading.Lock()
+        self._init_db()
+
+    def _connect(self) -> sqlite3.Connection:
+        return sqlite3.connect(str(self._db_path), timeout=5)
+
+    def _init_db(self) -> None:
+        with self._lock:
+            conn = self._connect()
+            try:
+                conn.execute(
+                    """
+                    CREATE TABLE IF NOT EXISTS cache (
+                        key TEXT PRIMARY KEY,
+                        value TEXT NOT NULL,
+                        created_at REAL NOT NULL,
+                        ttl REAL
+                    )
+                    """
+                )
+                conn.commit()
+            finally:
+                conn.close()
+
+    def get(self, key: str) -> Any | None:
+        with self._lock:
+            conn = self._connect()
+            try:
+                row = conn.execute("SELECT value, created_at, ttl FROM cache WHERE key = ?", (key,)).fetchone()
+                if row is None:
+                    return None
+                value_json, created_at, ttl = row
+                if ttl is not None and time.time() > created_at + ttl:
+                    conn.execute("DELETE FROM cache WHERE key = ?", (key,))
+                    conn.commit()
+                    return None
+                return json.loads(value_json)
+            finally:
+                conn.close()
+
+    def set(self, key: str, value: Any, ttl: int | None = None) -> None:
+        value_json = json.dumps(value, default=str)
+        now = time.time()
+        with self._lock:
+            conn = self._connect()
+            try:
+                conn.execute(
+                    "INSERT OR REPLACE INTO cache (key, value, created_at, ttl) VALUES (?, ?, ?, ?)",
+                    (key, value_json, now, ttl),
+                )
+                conn.commit()
+            finally:
+                conn.close()
+
+    def delete(self, key: str) -> None:
+        with self._lock:
+            conn = self._connect()
+            try:
+                conn.execute("DELETE FROM cache WHERE key = ?", (key,))
+                conn.commit()
+            finally:
+                conn.close()
+
+    def clear(self) -> None:
+        with self._lock:
+            conn = self._connect()
+            try:
+                conn.execute("DELETE FROM cache")
+                conn.commit()
+            finally:
+                conn.close()
+
+    def has(self, key: str) -> bool:
+        return self.get(key) is not None
+
+
+# ---------------------------------------------------------------------------
+# Redis backend
+# ---------------------------------------------------------------------------
+
+
+class RedisCacheBackend(CacheBackend):
+    """Cache backend using Redis with native TTL support.
+
+    Requires the ``redis`` package (``pip install redis`` or
+    ``pip install prompture[redis]``).
+
+    Parameters
+    ----------
+    redis_url:
+        Redis connection URL (e.g. ``redis://localhost:6379/0``).
+    prefix:
+        Key prefix. Defaults to ``"prompture:cache:"``.
+    """
+
+    def __init__(self, redis_url: str = "redis://localhost:6379/0", prefix: str = "prompture:cache:") -> None:
+        try:
+            import redis as _redis
+        except ImportError:
+            raise RuntimeError(
+                "Redis cache backend requires the 'redis' package. "
+                "Install it with: pip install redis (or: pip install prompture[redis])"
+            ) from None
+
+        self._client = _redis.from_url(redis_url, decode_responses=True)
+        self._prefix = prefix
+
+    def _prefixed(self, key: str) -> str:
+        return f"{self._prefix}{key}"
+
+    def get(self, key: str) -> Any | None:
+        raw = self._client.get(self._prefixed(key))
+        if raw is None:
+            return None
+        return json.loads(raw)
+
+    def set(self, key: str, value: Any, ttl: int | None = None) -> None:
+        value_json = json.dumps(value, default=str)
+        if ttl:
+            self._client.setex(self._prefixed(key), ttl, value_json)
+        else:
+            self._client.set(self._prefixed(key), value_json)
+
+    def delete(self, key: str) -> None:
+        self._client.delete(self._prefixed(key))
+
+    def clear(self) -> None:
+        # Scan for keys with our prefix and delete them
+        cursor = 0
+        while True:
+            cursor, keys = self._client.scan(cursor, match=f"{self._prefix}*", count=100)
+            if keys:
+                self._client.delete(*keys)
+            if cursor == 0:
+                break
+
+    def has(self, key: str) -> bool:
+        return bool(self._client.exists(self._prefixed(key)))
+
+
+# ---------------------------------------------------------------------------
+# ResponseCache orchestrator
+# ---------------------------------------------------------------------------
+
+
+class ResponseCache:
+    """Orchestrator that wraps a :class:`CacheBackend` with hit/miss stats
+    and an ``enabled`` toggle.
+
+    Parameters
+    ----------
+    backend:
+        The storage backend to use.
+    enabled:
+        Whether caching is active. When ``False``, all lookups return
+        ``None`` and stores are no-ops.
+    default_ttl:
+        Default time-to-live in seconds for cached entries.
+    """
+
+    def __init__(
+        self,
+        backend: CacheBackend,
+        enabled: bool = True,
+        default_ttl: int = 3600,
+    ) -> None:
+        self.backend = backend
+        self.enabled = enabled
+        self.default_ttl = default_ttl
+        self._hits = 0
+        self._misses = 0
+        self._sets = 0
+        self._lock = threading.Lock()
+
+    def get(self, key: str, *, force: bool = False) -> Any | None:
+        if not self.enabled and not force:
+            with self._lock:
+                self._misses += 1
+            return None
+        value = self.backend.get(key)
+        with self._lock:
+            if value is not None:
+                self._hits += 1
+            else:
+                self._misses += 1
+        return value
+
+    def set(self, key: str, value: Any, ttl: int | None = None, *, force: bool = False) -> None:
+        if not self.enabled and not force:
+            return
+        self.backend.set(key, value, ttl or self.default_ttl)
+        with self._lock:
+            self._sets += 1
+
+    def invalidate(self, key: str) -> None:
+        self.backend.delete(key)
+
+    def clear(self) -> None:
+        self.backend.clear()
+        with self._lock:
+            self._hits = 0
+            self._misses = 0
+            self._sets = 0
+
+    def stats(self) -> dict[str, int]:
+        with self._lock:
+            return {"hits": self._hits, "misses": self._misses, "sets": self._sets}
+
+
+# ---------------------------------------------------------------------------
+# Module-level singleton
+# ---------------------------------------------------------------------------
+
+_cache_instance: ResponseCache | None = None
+_cache_lock = threading.Lock()
+
+
+def get_cache() -> ResponseCache:
+    """Return the module-level :class:`ResponseCache` singleton.
+
+    If :func:`configure_cache` has not been called, returns a disabled
+    cache backed by :class:`MemoryCacheBackend`.
+    """
+    global _cache_instance
+    with _cache_lock:
+        if _cache_instance is None:
+            _cache_instance = ResponseCache(
+                backend=MemoryCacheBackend(),
+                enabled=False,
+            )
+        return _cache_instance
+
+
+def configure_cache(
+    backend: str = "memory",
+    enabled: bool = True,
+    ttl: int = 3600,
+    maxsize: int = 256,
+    db_path: str | None = None,
+    redis_url: str | None = None,
+) -> ResponseCache:
+    """Create (or replace) the module-level cache singleton.
+
+    Parameters
+    ----------
+    backend:
+        ``"memory"``, ``"sqlite"``, or ``"redis"``.
+    enabled:
+        Whether the cache is active.
+    ttl:
+        Default TTL in seconds.
+    maxsize:
+        Maximum entries for the memory backend.
+    db_path:
+        SQLite database path (only for ``"sqlite"`` backend).
+    redis_url:
+        Redis connection URL (only for ``"redis"`` backend).
+
+    Returns
+    -------
+    The newly configured :class:`ResponseCache`.
+    """
+    global _cache_instance
+
+    if backend == "memory":
+        be = MemoryCacheBackend(maxsize=maxsize)
+    elif backend == "sqlite":
+        be = SQLiteCacheBackend(db_path=db_path)
+    elif backend == "redis":
+        be = RedisCacheBackend(redis_url=redis_url or "redis://localhost:6379/0")
+    else:
+        raise ValueError(f"Unknown cache backend '{backend}'. Choose 'memory', 'sqlite', or 'redis'.")
+
+    with _cache_lock:
+        _cache_instance = ResponseCache(backend=be, enabled=enabled, default_ttl=ttl)
+        return _cache_instance
+
+
+def _reset_cache() -> None:
+    """Reset the singleton to ``None``. **For testing only.**"""
+    global _cache_instance
+    with _cache_lock:
+        _cache_instance = None
```
prompture/callbacks.py
ADDED
```diff
@@ -0,0 +1,50 @@
+"""Callback hooks for driver-level observability.
+
+Provides :class:`DriverCallbacks`, a lightweight container for functions
+that are invoked before/after every driver call, giving full visibility
+into request/response payloads and errors without modifying driver code.
+
+Usage::
+
+    from prompture import DriverCallbacks
+
+    def log_request(info: dict) -> None:
+        print(f"-> {info['driver']} prompt length={len(info.get('prompt', ''))}")
+
+    def log_response(info: dict) -> None:
+        print(f"<- {info['driver']} {info['elapsed_ms']:.0f}ms")
+
+    callbacks = DriverCallbacks(on_request=log_request, on_response=log_response)
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from typing import Any, Callable
+
+# Type aliases for callback signatures.
+# Each callback receives a single ``dict[str, Any]`` payload and returns nothing.
+OnRequestCallback = Callable[[dict[str, Any]], None]
+OnResponseCallback = Callable[[dict[str, Any]], None]
+OnErrorCallback = Callable[[dict[str, Any]], None]
+
+
+@dataclass
+class DriverCallbacks:
+    """Optional callbacks fired around every driver call.
+
+    Payload shapes:
+
+    ``on_request``
+        ``{prompt, messages, options, driver}``
+
+    ``on_response``
+        ``{text, meta, driver, elapsed_ms}``
+
+    ``on_error``
+        ``{error, prompt, messages, options, driver}``
+    """
+
+    on_request: OnRequestCallback | None = field(default=None)
+    on_response: OnResponseCallback | None = field(default=None)
+    on_error: OnErrorCallback | None = field(default=None)
```
prompture/cli.py
CHANGED
```diff
@@ -1,23 +1,27 @@
 import json
+
 import click
-
+
 from .drivers import OllamaDriver
+from .runner import run_suite_from_spec
+
 
 @click.group()
 def cli():
     """Simple CLI to run JSON specs"""
     pass
 
+
 @cli.command()
 @click.argument("specfile", type=click.Path(exists=True))
 @click.argument("outfile", type=click.Path())
 def run(specfile, outfile):
     """Run a spec JSON and save report."""
-    with open(specfile,
+    with open(specfile, encoding="utf-8") as fh:
         spec = json.load(fh)
     # Use Ollama as default driver since it can run locally
     drivers = {"ollama": OllamaDriver(endpoint="http://localhost:11434", model="gemma:latest")}
     report = run_suite_from_spec(spec, drivers)
     with open(outfile, "w", encoding="utf-8") as fh:
         json.dump(report, fh, indent=2, ensure_ascii=False)
-    click.echo(f"Report saved to {outfile}")
+    click.echo(f"Report saved to {outfile}")
```