prompture 0.0.29.dev8__py3-none-any.whl → 0.0.35__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. prompture/__init__.py +146 -23
  2. prompture/_version.py +34 -0
  3. prompture/aio/__init__.py +74 -0
  4. prompture/async_conversation.py +607 -0
  5. prompture/async_core.py +803 -0
  6. prompture/async_driver.py +169 -0
  7. prompture/cache.py +469 -0
  8. prompture/callbacks.py +55 -0
  9. prompture/cli.py +63 -4
  10. prompture/conversation.py +631 -0
  11. prompture/core.py +876 -263
  12. prompture/cost_mixin.py +51 -0
  13. prompture/discovery.py +164 -0
  14. prompture/driver.py +168 -5
  15. prompture/drivers/__init__.py +173 -69
  16. prompture/drivers/airllm_driver.py +109 -0
  17. prompture/drivers/async_airllm_driver.py +26 -0
  18. prompture/drivers/async_azure_driver.py +117 -0
  19. prompture/drivers/async_claude_driver.py +107 -0
  20. prompture/drivers/async_google_driver.py +132 -0
  21. prompture/drivers/async_grok_driver.py +91 -0
  22. prompture/drivers/async_groq_driver.py +84 -0
  23. prompture/drivers/async_hugging_driver.py +61 -0
  24. prompture/drivers/async_lmstudio_driver.py +79 -0
  25. prompture/drivers/async_local_http_driver.py +44 -0
  26. prompture/drivers/async_ollama_driver.py +125 -0
  27. prompture/drivers/async_openai_driver.py +96 -0
  28. prompture/drivers/async_openrouter_driver.py +96 -0
  29. prompture/drivers/async_registry.py +129 -0
  30. prompture/drivers/azure_driver.py +36 -9
  31. prompture/drivers/claude_driver.py +251 -34
  32. prompture/drivers/google_driver.py +107 -38
  33. prompture/drivers/grok_driver.py +29 -32
  34. prompture/drivers/groq_driver.py +27 -26
  35. prompture/drivers/hugging_driver.py +6 -6
  36. prompture/drivers/lmstudio_driver.py +26 -13
  37. prompture/drivers/local_http_driver.py +6 -6
  38. prompture/drivers/ollama_driver.py +157 -23
  39. prompture/drivers/openai_driver.py +178 -9
  40. prompture/drivers/openrouter_driver.py +31 -25
  41. prompture/drivers/registry.py +306 -0
  42. prompture/field_definitions.py +106 -96
  43. prompture/logging.py +80 -0
  44. prompture/model_rates.py +217 -0
  45. prompture/runner.py +49 -47
  46. prompture/scaffold/__init__.py +1 -0
  47. prompture/scaffold/generator.py +84 -0
  48. prompture/scaffold/templates/Dockerfile.j2 +12 -0
  49. prompture/scaffold/templates/README.md.j2 +41 -0
  50. prompture/scaffold/templates/config.py.j2 +21 -0
  51. prompture/scaffold/templates/env.example.j2 +8 -0
  52. prompture/scaffold/templates/main.py.j2 +86 -0
  53. prompture/scaffold/templates/models.py.j2 +40 -0
  54. prompture/scaffold/templates/requirements.txt.j2 +5 -0
  55. prompture/server.py +183 -0
  56. prompture/session.py +117 -0
  57. prompture/settings.py +18 -1
  58. prompture/tools.py +219 -267
  59. prompture/tools_schema.py +254 -0
  60. prompture/validator.py +3 -3
  61. {prompture-0.0.29.dev8.dist-info → prompture-0.0.35.dist-info}/METADATA +117 -21
  62. prompture-0.0.35.dist-info/RECORD +66 -0
  63. {prompture-0.0.29.dev8.dist-info → prompture-0.0.35.dist-info}/WHEEL +1 -1
  64. prompture-0.0.29.dev8.dist-info/RECORD +0 -27
  65. {prompture-0.0.29.dev8.dist-info → prompture-0.0.35.dist-info}/entry_points.txt +0 -0
  66. {prompture-0.0.29.dev8.dist-info → prompture-0.0.35.dist-info}/licenses/LICENSE +0 -0
  67. {prompture-0.0.29.dev8.dist-info → prompture-0.0.35.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,169 @@
1
+ """Async driver base class for LLM adapters."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ import time
7
+ from collections.abc import AsyncIterator
8
+ from typing import Any
9
+
10
+ from .callbacks import DriverCallbacks
11
+ from .driver import Driver
12
+
13
logger = logging.getLogger("prompture.async_driver")


class AsyncDriver:
    """Async adapter base. Implement ``async generate(prompt, options)``
    returning ``{"text": ..., "meta": {...}}``.

    The ``meta`` dict follows the same contract as :class:`Driver`:

    .. code-block:: python

        {
            "prompt_tokens": int,
            "completion_tokens": int,
            "total_tokens": int,
            "cost": float,
            "raw_response": dict,
        }
    """

    # Capability flags -- concrete drivers flip these to advertise support.
    supports_json_mode: bool = False
    supports_json_schema: bool = False
    supports_messages: bool = False
    supports_tool_use: bool = False
    supports_streaming: bool = False

    # Optional observability hooks fired around every call (see callbacks.py).
    callbacks: DriverCallbacks | None = None

    async def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
        """Produce a completion for *prompt*. Concrete drivers must override."""
        raise NotImplementedError

    async def generate_messages(self, messages: list[dict[str, Any]], options: dict[str, Any]) -> dict[str, Any]:
        """Generate a response from a list of conversation messages (async).

        The base implementation flattens the messages into a single prompt
        string and defers to :meth:`generate`. Drivers with native
        message-array support should override this and set
        ``supports_messages = True``.
        """
        flattened = Driver._flatten_messages(messages)
        return await self.generate(flattened, options)

    # ------------------------------------------------------------------
    # Tool use
    # ------------------------------------------------------------------

    async def generate_messages_with_tools(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]],
        options: dict[str, Any],
    ) -> dict[str, Any]:
        """Generate a response that may include tool calls (async).

        Returns a dict with keys: ``text``, ``meta``, ``tool_calls``, ``stop_reason``.
        """
        raise NotImplementedError(f"{self.__class__.__name__} does not support tool use")

    # ------------------------------------------------------------------
    # Streaming
    # ------------------------------------------------------------------

    async def generate_messages_stream(
        self,
        messages: list[dict[str, Any]],
        options: dict[str, Any],
    ) -> AsyncIterator[dict[str, Any]]:
        """Yield response chunks incrementally (async).

        Each chunk is a dict:
        - ``{"type": "delta", "text": str}``
        - ``{"type": "done", "text": str, "meta": dict}``
        """
        raise NotImplementedError(f"{self.__class__.__name__} does not support streaming")
        # The unreachable yield makes this an async generator, so overriding
        # and non-overriding drivers share the same calling convention.
        yield  # pragma: no cover

    # ------------------------------------------------------------------
    # Hook-aware wrappers
    # ------------------------------------------------------------------

    async def generate_with_hooks(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
        """Wrap :meth:`generate` with on_request / on_response / on_error callbacks."""
        return await self._call_with_hooks(
            lambda: self.generate(prompt, options),
            prompt=prompt,
            messages=None,
            options=options,
        )

    async def generate_messages_with_hooks(
        self, messages: list[dict[str, Any]], options: dict[str, Any]
    ) -> dict[str, Any]:
        """Wrap :meth:`generate_messages` with callbacks."""
        return await self._call_with_hooks(
            lambda: self.generate_messages(messages, options),
            prompt=None,
            messages=messages,
            options=options,
        )

    async def _call_with_hooks(self, invoke, *, prompt, messages, options) -> dict[str, Any]:
        """Shared plumbing for the hook wrappers.

        Fires ``on_request``, awaits *invoke*, then fires ``on_response``
        with the elapsed time; on failure fires ``on_error`` and re-raises.
        """
        driver_name = getattr(self, "model", self.__class__.__name__)
        self._fire_callback(
            "on_request",
            {"prompt": prompt, "messages": messages, "options": options, "driver": driver_name},
        )
        started = time.perf_counter()
        try:
            result = await invoke()
        except Exception as exc:
            self._fire_callback(
                "on_error",
                {"error": exc, "prompt": prompt, "messages": messages, "options": options, "driver": driver_name},
            )
            raise
        elapsed_ms = (time.perf_counter() - started) * 1000
        self._fire_callback(
            "on_response",
            {
                "text": result.get("text", ""),
                "meta": result.get("meta", {}),
                "driver": driver_name,
                "elapsed_ms": elapsed_ms,
            },
        )
        return result

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _fire_callback(self, event: str, payload: dict[str, Any]) -> None:
        """Invoke a single callback if configured, logging (never propagating) errors."""
        hooks = self.callbacks
        if hooks is None:
            return
        handler = getattr(hooks, event, None)
        if handler is None:
            return
        try:
            handler(payload)
        except Exception:
            logger.exception("Callback %s raised an exception", event)

    # Re-export the static helper for convenience
    _flatten_messages = staticmethod(Driver._flatten_messages)
prompture/cache.py ADDED
@@ -0,0 +1,469 @@
1
+ """Response caching layer for prompture.
2
+
3
+ Provides pluggable cache backends (memory, SQLite, Redis) so repeated
4
+ identical LLM calls can be served from cache. Disabled by default.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import hashlib
10
+ import json
11
+ import sqlite3
12
+ import threading
13
+ import time
14
+ from abc import ABC, abstractmethod
15
+ from collections import OrderedDict
16
+ from pathlib import Path
17
+ from typing import Any
18
+
19
+ # ---------------------------------------------------------------------------
20
+ # Cache key generation
21
+ # ---------------------------------------------------------------------------
22
+
23
_CACHE_RELEVANT_OPTIONS = frozenset(
    {
        "temperature",
        "max_tokens",
        "top_p",
        "top_k",
        "frequency_penalty",
        "presence_penalty",
        "stop",
        "seed",
        "json_mode",
    }
)


def make_cache_key(
    prompt: str,
    model_name: str,
    schema: dict[str, Any] | None = None,
    options: dict[str, Any] | None = None,
    output_format: str = "json",
    pydantic_qualname: str | None = None,
) -> str:
    """Return a deterministic SHA-256 hex key for the given call parameters.

    Only cache-relevant options (temperature, max_tokens, etc.) contribute
    to the key, so unrelated option changes never bust the cache. The key
    is the digest of a canonical (sorted-keys) JSON encoding of the inputs.
    """
    relevant: dict[str, Any] = {}
    if options:
        for opt_name in sorted(options):
            if opt_name in _CACHE_RELEVANT_OPTIONS:
                relevant[opt_name] = options[opt_name]

    payload: dict[str, Any] = {
        "prompt": prompt,
        "model_name": model_name,
        "schema": schema,
        "options": relevant,
        "output_format": output_format,
    }
    if pydantic_qualname is not None:
        payload["pydantic_qualname"] = pydantic_qualname

    # default=str keeps non-JSON types hashable; exact text is irrelevant
    # as long as it is deterministic.
    canonical = json.dumps(payload, sort_keys=True, default=str)
    return hashlib.sha256(canonical.encode()).hexdigest()
67
+
68
+
69
+ # ---------------------------------------------------------------------------
70
+ # Backend ABC
71
+ # ---------------------------------------------------------------------------
72
+
73
+
74
class CacheBackend(ABC):
    """Interface that every cache storage backend must implement."""

    @abstractmethod
    def get(self, key: str) -> Any | None:
        """Return the cached value for *key*, or ``None`` on miss/expiry."""

    @abstractmethod
    def set(self, key: str, value: Any, ttl: int | None = None) -> None:
        """Store *value* under *key*, optionally expiring after *ttl* seconds."""

    @abstractmethod
    def delete(self, key: str) -> None:
        """Drop a single key (no-op when absent)."""

    @abstractmethod
    def clear(self) -> None:
        """Drop every entry."""

    @abstractmethod
    def has(self, key: str) -> bool:
        """Report whether *key* is present and not yet expired."""


class MemoryCacheBackend(CacheBackend):
    """In-process LRU cache backed by an ``OrderedDict``.

    Entries are ``(value, expires_at)`` tuples where ``expires_at`` is an
    absolute ``time.time()`` timestamp, or ``None`` for "never expires".

    Parameters
    ----------
    maxsize:
        Maximum number of entries before the least-recently-used item is
        evicted. Defaults to 256.
    """

    def __init__(self, maxsize: int = 256) -> None:
        self._maxsize = maxsize
        self._data: OrderedDict[str, tuple[Any, float | None]] = OrderedDict()
        self._lock = threading.Lock()

    # -- helpers --
    @staticmethod
    def _is_expired(entry: tuple[Any, float | None]) -> bool:
        """Return True when the entry carries a deadline that has passed."""
        deadline = entry[1]
        return deadline is not None and time.time() > deadline

    # -- public API --
    def get(self, key: str) -> Any | None:
        with self._lock:
            try:
                stored, deadline = self._data[key]
            except KeyError:
                return None
            if deadline is not None and time.time() > deadline:
                # Lazy expiry: purge the stale entry on read.
                del self._data[key]
                return None
            # Touch the key so it becomes the most-recently used.
            self._data.move_to_end(key)
            return stored

    def set(self, key: str, value: Any, ttl: int | None = None) -> None:
        # NOTE: a falsy ttl (None or 0) means "no expiry", matching the
        # original contract; ResponseCache substitutes its default for 0.
        deadline = (time.time() + ttl) if ttl else None
        with self._lock:
            if key in self._data:
                self._data.move_to_end(key)
            self._data[key] = (value, deadline)
            # Evict least-recently-used entries past capacity.
            while len(self._data) > self._maxsize:
                self._data.popitem(last=False)

    def delete(self, key: str) -> None:
        with self._lock:
            self._data.pop(key, None)

    def clear(self) -> None:
        with self._lock:
            self._data.clear()

    def has(self, key: str) -> bool:
        # Deliberately does NOT refresh recency -- only get() does.
        with self._lock:
            entry = self._data.get(key)
            if entry is None:
                return False
            if self._is_expired(entry):
                del self._data[key]
                return False
            return True
165
+
166
+
167
+ # ---------------------------------------------------------------------------
168
+ # SQLite backend
169
+ # ---------------------------------------------------------------------------
170
+
171
_DEFAULT_SQLITE_PATH = Path.home() / ".prompture" / "cache" / "response_cache.db"


class SQLiteCacheBackend(CacheBackend):
    """Persistent cache stored in a local SQLite database.

    Every operation opens a short-lived connection and runs under a
    process-wide lock, so one instance may be shared across threads.

    Parameters
    ----------
    db_path:
        Path to the SQLite file. Defaults to
        ``~/.prompture/cache/response_cache.db``.
    """

    def __init__(self, db_path: str | None = None) -> None:
        self._db_path = Path(db_path) if db_path else _DEFAULT_SQLITE_PATH
        # Ensure the parent directory exists before the first connect.
        self._db_path.parent.mkdir(parents=True, exist_ok=True)
        self._lock = threading.Lock()
        self._init_db()

    def _connect(self) -> sqlite3.Connection:
        # timeout=5 lets a concurrent writer finish before "database is locked".
        return sqlite3.connect(str(self._db_path), timeout=5)

    def _init_db(self) -> None:
        """Create the cache table on first use (idempotent)."""
        with self._lock:
            conn = self._connect()
            try:
                conn.execute(
                    """
                    CREATE TABLE IF NOT EXISTS cache (
                        key TEXT PRIMARY KEY,
                        value TEXT NOT NULL,
                        created_at REAL NOT NULL,
                        ttl REAL
                    )
                    """
                )
                conn.commit()
            finally:
                conn.close()

    def get(self, key: str) -> Any | None:
        """Return the JSON-decoded value for *key*, purging it when expired."""
        with self._lock:
            conn = self._connect()
            try:
                row = conn.execute(
                    "SELECT value, created_at, ttl FROM cache WHERE key = ?", (key,)
                ).fetchone()
                if row is None:
                    return None
                stored, created_at, ttl = row
                if ttl is not None and time.time() > created_at + ttl:
                    # Lazy expiry: delete the stale row on read.
                    conn.execute("DELETE FROM cache WHERE key = ?", (key,))
                    conn.commit()
                    return None
                return json.loads(stored)
            finally:
                conn.close()

    def set(self, key: str, value: Any, ttl: int | None = None) -> None:
        """Upsert the JSON-encoded *value* together with the current timestamp."""
        encoded = json.dumps(value, default=str)
        stamp = time.time()
        with self._lock:
            conn = self._connect()
            try:
                conn.execute(
                    "INSERT OR REPLACE INTO cache (key, value, created_at, ttl) VALUES (?, ?, ?, ?)",
                    (key, encoded, stamp, ttl),
                )
                conn.commit()
            finally:
                conn.close()

    def delete(self, key: str) -> None:
        """Remove a single key (no-op when absent)."""
        with self._lock:
            conn = self._connect()
            try:
                conn.execute("DELETE FROM cache WHERE key = ?", (key,))
                conn.commit()
            finally:
                conn.close()

    def clear(self) -> None:
        """Remove every row from the cache table."""
        with self._lock:
            conn = self._connect()
            try:
                conn.execute("DELETE FROM cache")
                conn.commit()
            finally:
                conn.close()

    def has(self, key: str) -> bool:
        """Delegate to :meth:`get` so expiry is handled consistently."""
        return self.get(key) is not None
261
+
262
+
263
+ # ---------------------------------------------------------------------------
264
+ # Redis backend
265
+ # ---------------------------------------------------------------------------
266
+
267
+
268
class RedisCacheBackend(CacheBackend):
    """Cache backend using Redis with native TTL support.

    Requires the ``redis`` package (``pip install redis`` or
    ``pip install prompture[redis]``).

    Parameters
    ----------
    redis_url:
        Redis connection URL (e.g. ``redis://localhost:6379/0``).
    prefix:
        Key prefix. Defaults to ``"prompture:cache:"``.
    """

    def __init__(self, redis_url: str = "redis://localhost:6379/0", prefix: str = "prompture:cache:") -> None:
        try:
            import redis as _redis
        except ImportError:
            raise RuntimeError(
                "Redis cache backend requires the 'redis' package. "
                "Install it with: pip install redis (or: pip install prompture[redis])"
            ) from None

        # decode_responses=True makes GET return str, ready for json.loads.
        self._client = _redis.from_url(redis_url, decode_responses=True)
        self._prefix = prefix

    def _prefixed(self, key: str) -> str:
        """Namespace *key* under this backend's prefix."""
        return f"{self._prefix}{key}"

    def get(self, key: str) -> Any | None:
        stored = self._client.get(self._prefixed(key))
        return None if stored is None else json.loads(stored)

    def set(self, key: str, value: Any, ttl: int | None = None) -> None:
        encoded = json.dumps(value, default=str)
        target = self._prefixed(key)
        if ttl:
            # SETEX stores the value and its expiry atomically.
            self._client.setex(target, ttl, encoded)
        else:
            self._client.set(target, encoded)

    def delete(self, key: str) -> None:
        self._client.delete(self._prefixed(key))

    def clear(self) -> None:
        """Delete every key under our prefix via incremental SCAN."""
        cursor = 0
        pattern = f"{self._prefix}*"
        while True:
            cursor, batch = self._client.scan(cursor, match=pattern, count=100)
            if batch:
                self._client.delete(*batch)
            if cursor == 0:
                break

    def has(self, key: str) -> bool:
        return bool(self._client.exists(self._prefixed(key)))
325
+
326
+
327
+ # ---------------------------------------------------------------------------
328
+ # ResponseCache orchestrator
329
+ # ---------------------------------------------------------------------------
330
+
331
+
332
class ResponseCache:
    """Wrap a :class:`CacheBackend` with hit/miss statistics and an
    ``enabled`` toggle.

    Parameters
    ----------
    backend:
        The storage backend to use.
    enabled:
        Whether caching is active. When ``False``, lookups return ``None``
        and stores are no-ops (unless ``force=True`` is passed).
    default_ttl:
        TTL in seconds applied when :meth:`set` receives no explicit TTL.
    """

    def __init__(
        self,
        backend: CacheBackend,
        enabled: bool = True,
        default_ttl: int = 3600,
    ) -> None:
        self.backend = backend
        self.enabled = enabled
        self.default_ttl = default_ttl
        self._hits = 0
        self._misses = 0
        self._sets = 0
        # The lock guards only the counters; backends manage their own safety.
        self._lock = threading.Lock()

    def get(self, key: str, *, force: bool = False) -> Any | None:
        """Look up *key*; a disabled cache records a miss and returns ``None``."""
        if not (self.enabled or force):
            with self._lock:
                self._misses += 1
            return None
        found = self.backend.get(key)
        with self._lock:
            if found is None:
                self._misses += 1
            else:
                self._hits += 1
        return found

    def set(self, key: str, value: Any, ttl: int | None = None, *, force: bool = False) -> None:
        """Store *value*; no-op while disabled (unless *force*)."""
        if not (self.enabled or force):
            return
        # Falsy ttl (None or 0) falls back to the configured default.
        self.backend.set(key, value, ttl or self.default_ttl)
        with self._lock:
            self._sets += 1

    def invalidate(self, key: str) -> None:
        """Drop a single entry from the backend."""
        self.backend.delete(key)

    def clear(self) -> None:
        """Empty the backend and zero the statistics."""
        self.backend.clear()
        with self._lock:
            self._hits = 0
            self._misses = 0
            self._sets = 0

    def stats(self) -> dict[str, int]:
        """Return a snapshot of the hit/miss/set counters."""
        with self._lock:
            return {"hits": self._hits, "misses": self._misses, "sets": self._sets}
394
+
395
+
396
+ # ---------------------------------------------------------------------------
397
+ # Module-level singleton
398
+ # ---------------------------------------------------------------------------
399
+
400
_cache_instance: ResponseCache | None = None
_cache_lock = threading.Lock()


def get_cache() -> ResponseCache:
    """Return the module-level :class:`ResponseCache` singleton.

    Lazily builds a **disabled** memory-backed cache when
    :func:`configure_cache` has not been called yet.
    """
    global _cache_instance
    with _cache_lock:
        if _cache_instance is None:
            _cache_instance = ResponseCache(backend=MemoryCacheBackend(), enabled=False)
        return _cache_instance
418
+
419
+
420
def configure_cache(
    backend: str = "memory",
    enabled: bool = True,
    ttl: int = 3600,
    maxsize: int = 256,
    db_path: str | None = None,
    redis_url: str | None = None,
) -> ResponseCache:
    """Create (or replace) the module-level cache singleton.

    Parameters
    ----------
    backend:
        ``"memory"``, ``"sqlite"``, or ``"redis"``.
    enabled:
        Whether the cache is active.
    ttl:
        Default TTL in seconds.
    maxsize:
        Maximum entries for the memory backend.
    db_path:
        SQLite database path (only for ``"sqlite"`` backend).
    redis_url:
        Redis connection URL (only for ``"redis"`` backend).

    Returns
    -------
    The newly configured :class:`ResponseCache`.

    Raises
    ------
    ValueError
        If *backend* is not one of the supported names.
    """
    global _cache_instance

    # Dispatch table; each factory only touches the options it needs.
    factories = {
        "memory": lambda: MemoryCacheBackend(maxsize=maxsize),
        "sqlite": lambda: SQLiteCacheBackend(db_path=db_path),
        "redis": lambda: RedisCacheBackend(redis_url=redis_url or "redis://localhost:6379/0"),
    }
    if backend not in factories:
        raise ValueError(f"Unknown cache backend '{backend}'. Choose 'memory', 'sqlite', or 'redis'.")
    storage = factories[backend]()

    with _cache_lock:
        _cache_instance = ResponseCache(backend=storage, enabled=enabled, default_ttl=ttl)
        return _cache_instance
463
+
464
+
465
def _reset_cache() -> None:
    """Reset the module-level singleton to ``None``.

    **For testing only.** The next :func:`get_cache` call will lazily
    rebuild a disabled, memory-backed cache.
    """
    global _cache_instance
    with _cache_lock:
        _cache_instance = None
prompture/callbacks.py ADDED
@@ -0,0 +1,55 @@
1
+ """Callback hooks for driver-level observability.
2
+
3
+ Provides :class:`DriverCallbacks`, a lightweight container for functions
4
+ that are invoked before/after every driver call, giving full visibility
5
+ into request/response payloads and errors without modifying driver code.
6
+
7
+ Usage::
8
+
9
+ from prompture import DriverCallbacks
10
+
11
+ def log_request(info: dict) -> None:
12
+ print(f"-> {info['driver']} prompt length={len(info.get('prompt', ''))}")
13
+
14
+ def log_response(info: dict) -> None:
15
+ print(f"<- {info['driver']} {info['elapsed_ms']:.0f}ms")
16
+
17
+ callbacks = DriverCallbacks(on_request=log_request, on_response=log_response)
18
+ """
19
+
20
+ from __future__ import annotations
21
+
22
+ from dataclasses import dataclass, field
23
+ from typing import Any, Callable
24
+
25
# Type aliases for callback signatures.
# Each callback receives a single ``dict[str, Any]`` payload and returns nothing.
OnRequestCallback = Callable[[dict[str, Any]], None]
OnResponseCallback = Callable[[dict[str, Any]], None]
OnErrorCallback = Callable[[dict[str, Any]], None]
OnStreamDeltaCallback = Callable[[dict[str, Any]], None]


@dataclass
class DriverCallbacks:
    """Optional callbacks fired around every driver call.

    Every field defaults to ``None`` (hook disabled). Payload shapes:

    ``on_request``
        ``{prompt, messages, options, driver}``

    ``on_response``
        ``{text, meta, driver, elapsed_ms}``

    ``on_error``
        ``{error, prompt, messages, options, driver}``

    ``on_stream_delta``
        ``{text, driver}``
    """

    on_request: OnRequestCallback | None = None
    on_response: OnResponseCallback | None = None
    on_error: OnErrorCallback | None = None
    on_stream_delta: OnStreamDeltaCallback | None = None