prompture 0.0.29.dev8__py3-none-any.whl → 0.0.38.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- prompture/__init__.py +264 -23
- prompture/_version.py +34 -0
- prompture/agent.py +924 -0
- prompture/agent_types.py +156 -0
- prompture/aio/__init__.py +74 -0
- prompture/async_agent.py +880 -0
- prompture/async_conversation.py +789 -0
- prompture/async_core.py +803 -0
- prompture/async_driver.py +193 -0
- prompture/async_groups.py +551 -0
- prompture/cache.py +469 -0
- prompture/callbacks.py +55 -0
- prompture/cli.py +63 -4
- prompture/conversation.py +826 -0
- prompture/core.py +894 -263
- prompture/cost_mixin.py +51 -0
- prompture/discovery.py +187 -0
- prompture/driver.py +206 -5
- prompture/drivers/__init__.py +175 -67
- prompture/drivers/airllm_driver.py +109 -0
- prompture/drivers/async_airllm_driver.py +26 -0
- prompture/drivers/async_azure_driver.py +123 -0
- prompture/drivers/async_claude_driver.py +113 -0
- prompture/drivers/async_google_driver.py +316 -0
- prompture/drivers/async_grok_driver.py +97 -0
- prompture/drivers/async_groq_driver.py +90 -0
- prompture/drivers/async_hugging_driver.py +61 -0
- prompture/drivers/async_lmstudio_driver.py +148 -0
- prompture/drivers/async_local_http_driver.py +44 -0
- prompture/drivers/async_ollama_driver.py +135 -0
- prompture/drivers/async_openai_driver.py +102 -0
- prompture/drivers/async_openrouter_driver.py +102 -0
- prompture/drivers/async_registry.py +133 -0
- prompture/drivers/azure_driver.py +42 -9
- prompture/drivers/claude_driver.py +257 -34
- prompture/drivers/google_driver.py +295 -42
- prompture/drivers/grok_driver.py +35 -32
- prompture/drivers/groq_driver.py +33 -26
- prompture/drivers/hugging_driver.py +6 -6
- prompture/drivers/lmstudio_driver.py +97 -19
- prompture/drivers/local_http_driver.py +6 -6
- prompture/drivers/ollama_driver.py +168 -23
- prompture/drivers/openai_driver.py +184 -9
- prompture/drivers/openrouter_driver.py +37 -25
- prompture/drivers/registry.py +306 -0
- prompture/drivers/vision_helpers.py +153 -0
- prompture/field_definitions.py +106 -96
- prompture/group_types.py +147 -0
- prompture/groups.py +530 -0
- prompture/image.py +180 -0
- prompture/logging.py +80 -0
- prompture/model_rates.py +217 -0
- prompture/persistence.py +254 -0
- prompture/persona.py +482 -0
- prompture/runner.py +49 -47
- prompture/scaffold/__init__.py +1 -0
- prompture/scaffold/generator.py +84 -0
- prompture/scaffold/templates/Dockerfile.j2 +12 -0
- prompture/scaffold/templates/README.md.j2 +41 -0
- prompture/scaffold/templates/config.py.j2 +21 -0
- prompture/scaffold/templates/env.example.j2 +8 -0
- prompture/scaffold/templates/main.py.j2 +86 -0
- prompture/scaffold/templates/models.py.j2 +40 -0
- prompture/scaffold/templates/requirements.txt.j2 +5 -0
- prompture/serialization.py +218 -0
- prompture/server.py +183 -0
- prompture/session.py +117 -0
- prompture/settings.py +19 -1
- prompture/tools.py +219 -267
- prompture/tools_schema.py +254 -0
- prompture/validator.py +3 -3
- prompture-0.0.38.dev2.dist-info/METADATA +369 -0
- prompture-0.0.38.dev2.dist-info/RECORD +77 -0
- {prompture-0.0.29.dev8.dist-info → prompture-0.0.38.dev2.dist-info}/WHEEL +1 -1
- prompture-0.0.29.dev8.dist-info/METADATA +0 -368
- prompture-0.0.29.dev8.dist-info/RECORD +0 -27
- {prompture-0.0.29.dev8.dist-info → prompture-0.0.38.dev2.dist-info}/entry_points.txt +0 -0
- {prompture-0.0.29.dev8.dist-info → prompture-0.0.38.dev2.dist-info}/licenses/LICENSE +0 -0
- {prompture-0.0.29.dev8.dist-info → prompture-0.0.38.dev2.dist-info}/top_level.txt +0 -0
prompture/logging.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
"""Logging configuration for the Prompture library.
|
|
2
|
+
|
|
3
|
+
Provides a structured JSON formatter and a convenience function for users
|
|
4
|
+
to enable Prompture's internal logging with a single call.
|
|
5
|
+
|
|
6
|
+
Usage::
|
|
7
|
+
|
|
8
|
+
from prompture import configure_logging
|
|
9
|
+
import logging
|
|
10
|
+
|
|
11
|
+
# Simple: enable DEBUG-level output to stderr
|
|
12
|
+
configure_logging(logging.DEBUG)
|
|
13
|
+
|
|
14
|
+
# Structured JSON lines (useful for log aggregation)
|
|
15
|
+
configure_logging(logging.DEBUG, json_format=True)
|
|
16
|
+
|
|
17
|
+
# Provide your own handler
|
|
18
|
+
fh = logging.FileHandler("prompture.log")
|
|
19
|
+
configure_logging(logging.INFO, handler=fh)
|
|
20
|
+
"""
|
|
21
|
+
|
|
22
|
+
from __future__ import annotations
|
|
23
|
+
|
|
24
|
+
import json
|
|
25
|
+
import logging
|
|
26
|
+
from datetime import datetime, timezone
|
|
27
|
+
from typing import Any
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class JSONFormatter(logging.Formatter):
    """Render every log record as one JSON object per line.

    The emitted object always carries ``timestamp`` (UTC, ISO-8601),
    ``level``, ``logger`` and ``message``.  When the record has a
    ``prompture_data`` attribute (set via ``extra={"prompture_data": ...}``)
    whose value is not ``None``, that value is added under ``data``.
    """

    def format(self, record: logging.LogRecord) -> str:
        """Serialize *record* to a JSON string."""
        moment = datetime.fromtimestamp(record.created, tz=timezone.utc)
        entry: dict[str, Any] = dict(
            timestamp=moment.isoformat(),
            level=record.levelname,
            logger=record.name,
            message=record.getMessage(),
        )

        extra = getattr(record, "prompture_data", None)
        if extra is not None:
            entry["data"] = extra

        # default=str stringifies anything json cannot encode natively.
        return json.dumps(entry, default=str, ensure_ascii=False)
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def configure_logging(
    level: int = logging.DEBUG,
    handler: logging.Handler | None = None,
    json_format: bool = False,
) -> None:
    """Set up Prompture's library logger for application-level visibility.

    Calling this function repeatedly is safe: a handler created here by
    default replaces the one created by an earlier call, and a caller-supplied
    handler is never attached twice.

    Args:
        level: Minimum severity to emit (e.g. ``logging.DEBUG``).  Applied to
            both the ``"prompture"`` logger and the handler.
        handler: Custom :class:`logging.Handler`.  When *None*, a
            :class:`logging.StreamHandler` writing to *stderr* is created.
        json_format: When *True*, messages are formatted as JSON lines
            via :class:`JSONFormatter`; otherwise a plain text format is
            used.  The formatter is set on the handler in both cases.
    """
    logger = logging.getLogger("prompture")
    logger.setLevel(level)

    created_here = handler is None
    if created_here:
        handler = logging.StreamHandler()
        # Tag handlers we create so a later call can find and replace them;
        # identity comparison alone cannot, because each call builds a new one.
        handler._prompture_default = True

    if json_format:
        handler.setFormatter(JSONFormatter())
    else:
        handler.setFormatter(logging.Formatter("%(asctime)s %(name)s %(levelname)s %(message)s"))

    handler.setLevel(level)

    # Avoid adding duplicate handlers when called multiple times: drop the
    # very same handler object, and -- when we are installing a fresh default
    # handler -- any default handler left over from a previous call.
    for existing in list(logger.handlers):
        is_same = existing is handler
        is_old_default = created_here and getattr(existing, "_prompture_default", False)
        if is_same or is_old_default:
            logger.removeHandler(existing)
    logger.addHandler(handler)
|
prompture/model_rates.py
ADDED
|
@@ -0,0 +1,217 @@
|
|
|
1
|
+
"""Live model rates from models.dev API with local caching.
|
|
2
|
+
|
|
3
|
+
Fetches pricing and metadata for LLM models from https://models.dev/api.json,
|
|
4
|
+
caches locally with TTL-based auto-refresh, and provides lookup functions
|
|
5
|
+
used by drivers for cost calculations.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import contextlib
|
|
9
|
+
import json
|
|
10
|
+
import logging
|
|
11
|
+
import threading
|
|
12
|
+
from datetime import datetime, timezone
|
|
13
|
+
from pathlib import Path
|
|
14
|
+
from typing import Any, Optional
|
|
15
|
+
|
|
16
|
+
logger = logging.getLogger(__name__)
|
|
17
|
+
|
|
18
|
+
# Maps prompture provider names to models.dev provider names
|
|
19
|
+
PROVIDER_MAP: dict[str, str] = {
|
|
20
|
+
"openai": "openai",
|
|
21
|
+
"claude": "anthropic",
|
|
22
|
+
"google": "google",
|
|
23
|
+
"groq": "groq",
|
|
24
|
+
"grok": "xai",
|
|
25
|
+
"azure": "azure",
|
|
26
|
+
"openrouter": "openrouter",
|
|
27
|
+
}
|
|
28
|
+
|
|
29
|
+
_API_URL = "https://models.dev/api.json"
|
|
30
|
+
_CACHE_DIR = Path.home() / ".prompture" / "cache"
|
|
31
|
+
_CACHE_FILE = _CACHE_DIR / "models_dev.json"
|
|
32
|
+
_META_FILE = _CACHE_DIR / "models_dev_meta.json"
|
|
33
|
+
|
|
34
|
+
_lock = threading.Lock()
|
|
35
|
+
_data: Optional[dict[str, Any]] = None
|
|
36
|
+
_loaded = False
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def _get_ttl_days() -> int:
    """Return the cache TTL in days.

    Reads ``model_rates_ttl_days`` from the package settings object; falls
    back to 7 when the attribute is absent or when the settings cannot be
    imported or read for any reason.
    """
    fallback = 7
    try:
        from .settings import settings as cfg

        ttl = getattr(cfg, "model_rates_ttl_days", fallback)
    except Exception:
        ttl = fallback
    return ttl
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def _cache_is_valid() -> bool:
    """Report whether the on-disk cache is present and younger than its TTL.

    Both the data file and the metadata file must exist.  The TTL stored in
    the metadata is preferred; the configured TTL is the fallback.  Any error
    while reading or interpreting the metadata counts as "not valid".
    """
    if not (_CACHE_FILE.exists() and _META_FILE.exists()):
        return False

    try:
        raw_meta = _META_FILE.read_text(encoding="utf-8")
        meta = json.loads(raw_meta)
        fetched_at = datetime.fromisoformat(meta["fetched_at"])
        ttl_days = meta.get("ttl_days", _get_ttl_days())
        elapsed = datetime.now(timezone.utc) - fetched_at
        max_age_seconds = ttl_days * 86400  # 86400 seconds per day
        return elapsed.total_seconds() < max_age_seconds
    except Exception:
        return False
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def _write_cache(data: dict[str, Any]) -> None:
    """Persist *data* plus a small metadata record to the cache directory.

    The metadata records the fetch time (UTC, ISO-8601) and the TTL in days.
    This is best-effort: any failure is logged at DEBUG level and swallowed.
    """
    try:
        _CACHE_DIR.mkdir(parents=True, exist_ok=True)

        payload_text = json.dumps(data)
        _CACHE_FILE.write_text(payload_text, encoding="utf-8")

        stamp = datetime.now(timezone.utc).isoformat()
        meta_text = json.dumps({"fetched_at": stamp, "ttl_days": _get_ttl_days()})
        _META_FILE.write_text(meta_text, encoding="utf-8")
    except Exception as exc:
        logger.debug("Failed to write model rates cache: %s", exc)
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def _read_cache() -> Optional[dict[str, Any]]:
    """Load the cached API payload from disk.

    Returns ``None`` when the file cannot be read or parsed.
    """
    try:
        raw = _CACHE_FILE.read_text(encoding="utf-8")
        return json.loads(raw)
    except Exception:
        return None
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def _fetch_from_api() -> Optional[dict[str, Any]]:
    """Download the current payload from the models.dev API.

    Returns the decoded JSON body, or ``None`` when anything goes wrong
    (``requests`` missing, network error, HTTP error status, invalid JSON).
    Failures are logged at DEBUG level.
    """
    try:
        # Imported lazily so the module can be imported without requests.
        import requests

        response = requests.get(_API_URL, timeout=15)
        response.raise_for_status()
        payload = response.json()
    except Exception as exc:
        logger.debug("Failed to fetch model rates from %s: %s", _API_URL, exc)
        return None
    return payload
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def _ensure_loaded() -> Optional[dict[str, Any]]:
    """Return the rates payload, loading it on first use.

    Load order: a cache that is still within TTL, then a fresh API fetch
    (which is written back to the cache), then whatever stale cache exists.
    The outcome -- even ``None`` -- is remembered, so later calls return
    immediately without touching disk or network.
    """
    global _data, _loaded

    # Fast path without the lock once loading has completed.
    if _loaded:
        return _data

    with _lock:
        # Another thread may have finished loading while we waited.
        if not _loaded:
            snapshot = _read_cache() if _cache_is_valid() else None
            if snapshot is None:
                # Cache missing, expired or unreadable -- fetch fresh.
                snapshot = _fetch_from_api()
                if snapshot is None:
                    # Fetch failed -- try stale cache as last resort.
                    snapshot = _read_cache()
                else:
                    _write_cache(snapshot)
            _data = snapshot
            _loaded = True
        return _data
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
def _lookup_model(provider: str, model_id: str) -> Optional[dict[str, Any]]:
    """Find a model entry in the cached data.

    Two payload layouts are understood:

    * nested -- ``{provider: {"models": {model_id: {...}}, ...}, ...}``,
      where the provider object also carries provider-level metadata;
    * flat   -- ``{provider: {model_id: {...}, ...}, ...}``.

    Args:
        provider: Prompture provider name; translated through
            ``PROVIDER_MAP`` and used as-is when it has no mapping.
        model_id: Model identifier as listed by the provider.

    Returns:
        The model's metadata dict, or ``None`` when data is unavailable, the
        provider or model is unknown, or the entry is not a dict.
    """
    data = _ensure_loaded()
    if not isinstance(data, dict):
        return None

    api_provider = PROVIDER_MAP.get(provider, provider)
    provider_data = data.get(api_provider)
    if not isinstance(provider_data, dict):
        return None

    # Prefer the nested "models" mapping; fall back to the flat layout.
    models = provider_data.get("models")
    if not isinstance(models, dict):
        models = provider_data

    entry = models.get(model_id)
    # Guard against provider-level scalars (e.g. "name") being mistaken for
    # a model entry when the flat fallback is in effect.
    return entry if isinstance(entry, dict) else None
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
# ── Public API ──────────────────────────────────────────────────────────────
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
def get_model_rates(provider: str, model_id: str) -> Optional[dict[str, float]]:
    """Look up the pricing of a model.

    The result mirrors the models.dev ``cost`` fields (per 1M tokens):
    ``input`` and ``output`` are always present; ``cache_read``,
    ``cache_write`` and ``reasoning`` appear only when the source provides a
    value convertible to ``float``.

    Returns:
        The pricing dict, or ``None`` when the model is unknown, has no
        ``cost`` mapping, or lacks a usable ``input`` or ``output`` rate.
    """
    entry = _lookup_model(provider, model_id)
    if entry is None:
        return None

    cost = entry.get("cost")
    if not isinstance(cost, dict):
        return None

    rates: dict[str, float] = {}
    for field in ("input", "output", "cache_read", "cache_write", "reasoning"):
        raw = cost.get(field)
        if raw is None:
            continue
        try:
            rates[field] = float(raw)
        except (TypeError, ValueError):
            # Unconvertible values are simply left out.
            continue

    # Must have at least input and output to be useful
    has_required = "input" in rates and "output" in rates
    return rates if has_required else None
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
def get_model_info(provider: str, model_id: str) -> Optional[dict[str, Any]]:
    """Return the complete metadata entry for a model, or ``None`` if not found.

    This is the public counterpart of the internal lookup; the entry is
    returned exactly as stored in the loaded data.
    """
    info = _lookup_model(provider, model_id)
    return info
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
def get_all_provider_models(provider: str) -> list[str]:
    """Return list of model IDs available for a provider.

    Understands both the nested payload layout, where model entries live
    under a per-provider ``"models"`` mapping, and the flat layout, where the
    provider object maps model IDs directly.

    Args:
        provider: Prompture provider name; translated through
            ``PROVIDER_MAP`` and used as-is when it has no mapping.

    Returns:
        Model IDs in payload order; an empty list when data is unavailable
        or the provider is unknown.
    """
    data = _ensure_loaded()
    if not isinstance(data, dict):
        return []

    api_provider = PROVIDER_MAP.get(provider, provider)
    provider_data = data.get(api_provider)
    if not isinstance(provider_data, dict):
        return []

    # With the nested layout the provider's own keys are metadata
    # ("id", "name", ...), not model IDs -- list the "models" mapping instead.
    models = provider_data.get("models")
    if isinstance(models, dict):
        return list(models)
    return list(provider_data)
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
def refresh_rates_cache(force: bool = False) -> bool:
    """Re-download model data from models.dev and update the cache.

    Args:
        force: If ``True``, fetch even when the cache is still within TTL.

    Returns:
        ``True`` when fresh data was fetched and installed; ``False`` when
        the cache was still valid (and *force* was not set) or the fetch
        failed.
    """
    global _data, _loaded

    with _lock:
        should_fetch = force or not _cache_is_valid()
        if not should_fetch:
            return False

        fetched = _fetch_from_api()
        if fetched is None:
            return False

        _data = fetched
        _write_cache(fetched)
        _loaded = True
        return True
|
prompture/persistence.py
ADDED
|
@@ -0,0 +1,254 @@
|
|
|
1
|
+
"""Conversation persistence — file and SQLite storage backends.
|
|
2
|
+
|
|
3
|
+
Provides:
|
|
4
|
+
|
|
5
|
+
- :func:`save_to_file` / :func:`load_from_file` for simple JSON file storage.
|
|
6
|
+
- :class:`ConversationStore` for SQLite-backed storage with tag search.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
import json
|
|
12
|
+
import sqlite3
|
|
13
|
+
import threading
|
|
14
|
+
from datetime import datetime, timezone
|
|
15
|
+
from pathlib import Path
|
|
16
|
+
from typing import Any
|
|
17
|
+
|
|
18
|
+
# ------------------------------------------------------------------
|
|
19
|
+
# File-based persistence
|
|
20
|
+
# ------------------------------------------------------------------
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def save_to_file(data: dict[str, Any], path: str | Path) -> None:
    """Serialize a conversation export dict to a JSON file at *path*.

    Missing parent directories are created first.  The file is written as
    UTF-8 with two-space indentation and non-ASCII characters kept verbatim.
    """
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)
    serialized = json.dumps(data, indent=2, ensure_ascii=False)
    target.write_text(serialized, encoding="utf-8")
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def load_from_file(path: str | Path) -> dict[str, Any]:
    """Load a conversation export dict from the JSON file at *path*.

    Raises:
        FileNotFoundError: If *path* does not exist.
    """
    source = Path(path)
    if source.exists():
        contents = source.read_text(encoding="utf-8")
        return json.loads(contents)
    raise FileNotFoundError(f"Conversation file not found: {source}")
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
# ------------------------------------------------------------------
|
|
46
|
+
# SQLite-backed ConversationStore
|
|
47
|
+
# ------------------------------------------------------------------
|
|
48
|
+
|
|
49
|
+
_DEFAULT_DB_DIR = Path.home() / ".prompture" / "conversations"
|
|
50
|
+
_DEFAULT_DB_PATH = _DEFAULT_DB_DIR / "conversations.db"
|
|
51
|
+
|
|
52
|
+
_SCHEMA_SQL = """
|
|
53
|
+
CREATE TABLE IF NOT EXISTS conversations (
|
|
54
|
+
id TEXT PRIMARY KEY,
|
|
55
|
+
model_name TEXT NOT NULL,
|
|
56
|
+
data TEXT NOT NULL,
|
|
57
|
+
created_at TEXT NOT NULL,
|
|
58
|
+
last_active TEXT NOT NULL,
|
|
59
|
+
turn_count INTEGER NOT NULL DEFAULT 0
|
|
60
|
+
);
|
|
61
|
+
|
|
62
|
+
CREATE TABLE IF NOT EXISTS conversation_tags (
|
|
63
|
+
conversation_id TEXT NOT NULL,
|
|
64
|
+
tag TEXT NOT NULL,
|
|
65
|
+
PRIMARY KEY (conversation_id, tag),
|
|
66
|
+
FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE
|
|
67
|
+
);
|
|
68
|
+
|
|
69
|
+
CREATE INDEX IF NOT EXISTS idx_tags_tag ON conversation_tags(tag);
|
|
70
|
+
CREATE INDEX IF NOT EXISTS idx_conversations_last_active ON conversations(last_active);
|
|
71
|
+
"""
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
class ConversationStore:
    """SQLite-backed conversation storage with tag search.

    Every public operation opens its own short-lived connection and runs
    while holding an internal :class:`threading.Lock`, so a single store
    instance serializes its database access across threads.

    Args:
        db_path: Path to the SQLite database file.  Defaults to
            ``~/.prompture/conversations/conversations.db``.  Parent
            directories are created if needed.
    """

    def __init__(self, db_path: str | Path | None = None) -> None:
        self._db_path = Path(db_path) if db_path else _DEFAULT_DB_PATH
        self._db_path.parent.mkdir(parents=True, exist_ok=True)
        self._lock = threading.Lock()
        self._init_db()

    def _init_db(self) -> None:
        """Create the tables and indexes if they do not exist yet."""
        with self._lock:
            conn = sqlite3.connect(str(self._db_path))
            try:
                conn.executescript(_SCHEMA_SQL)
                conn.commit()
            finally:
                conn.close()

    def _connect(self) -> sqlite3.Connection:
        """Open a connection with foreign keys enforced and name-based row access."""
        conn = sqlite3.connect(str(self._db_path))
        # Foreign-key enforcement is per-connection in SQLite; it is what
        # makes ON DELETE CASCADE remove a conversation's tags.
        conn.execute("PRAGMA foreign_keys = ON")
        conn.row_factory = sqlite3.Row
        return conn

    # ------------------------------------------------------------------ #
    # CRUD
    # ------------------------------------------------------------------ #

    def save(self, conversation_id: str, data: dict[str, Any]) -> None:
        """Upsert a conversation and replace its tags.

        Summary columns are taken from ``data["model_name"]`` and
        ``data["metadata"]`` (``created_at``, ``last_active``, ``turn_count``,
        ``tags``); missing timestamps default to the current UTC time.  On
        update the stored ``created_at`` is left untouched.  Duplicate tags
        are collapsed (first occurrence wins) so they cannot violate the
        ``(conversation_id, tag)`` primary key, and a null ``metadata`` or
        ``tags`` value is treated as empty.

        Args:
            conversation_id: Primary key of the conversation.
            data: JSON-serializable conversation export dict; stored whole.
        """
        meta = data.get("metadata") or {}
        model_name = data.get("model_name", "")
        created_at = meta.get("created_at", datetime.now(timezone.utc).isoformat())
        last_active = meta.get("last_active", datetime.now(timezone.utc).isoformat())
        turn_count = meta.get("turn_count", 0)
        # dict.fromkeys de-duplicates while preserving the original order.
        tags = list(dict.fromkeys(meta.get("tags") or []))

        data_json = json.dumps(data, ensure_ascii=False)

        with self._lock:
            conn = self._connect()
            try:
                conn.execute(
                    """
                    INSERT INTO conversations (id, model_name, data, created_at, last_active, turn_count)
                    VALUES (?, ?, ?, ?, ?, ?)
                    ON CONFLICT(id) DO UPDATE SET
                        model_name = excluded.model_name,
                        data = excluded.data,
                        last_active = excluded.last_active,
                        turn_count = excluded.turn_count
                    """,
                    (conversation_id, model_name, data_json, created_at, last_active, turn_count),
                )
                # Replace tags
                conn.execute(
                    "DELETE FROM conversation_tags WHERE conversation_id = ?",
                    (conversation_id,),
                )
                if tags:
                    conn.executemany(
                        "INSERT INTO conversation_tags (conversation_id, tag) VALUES (?, ?)",
                        [(conversation_id, t) for t in tags],
                    )
                conn.commit()
            finally:
                # Closing without commit discards a half-applied upsert.
                conn.close()

    def load(self, conversation_id: str) -> dict[str, Any] | None:
        """Load a conversation by ID.  Returns ``None`` if not found."""
        with self._lock:
            conn = self._connect()
            try:
                row = conn.execute(
                    "SELECT data FROM conversations WHERE id = ?",
                    (conversation_id,),
                ).fetchone()
                if row is None:
                    return None
                return json.loads(row["data"])
            finally:
                conn.close()

    def delete(self, conversation_id: str) -> bool:
        """Delete a conversation.  Returns *True* if it existed."""
        with self._lock:
            conn = self._connect()
            try:
                cursor = conn.execute(
                    "DELETE FROM conversations WHERE id = ?",
                    (conversation_id,),
                )
                conn.commit()
                return cursor.rowcount > 0
            finally:
                conn.close()

    # ------------------------------------------------------------------ #
    # Search / listing
    # ------------------------------------------------------------------ #

    def find_by_tag(self, tag: str) -> list[dict[str, Any]]:
        """Return summary dicts for all conversations with the given tag.

        Results are ordered by ``last_active`` descending.
        """
        with self._lock:
            conn = self._connect()
            try:
                rows = conn.execute(
                    """
                    SELECT c.id, c.model_name, c.created_at, c.last_active, c.turn_count
                    FROM conversations c
                    INNER JOIN conversation_tags ct ON c.id = ct.conversation_id
                    WHERE ct.tag = ?
                    ORDER BY c.last_active DESC
                    """,
                    (tag,),
                ).fetchall()
                return [self._row_to_summary(conn, r) for r in rows]
            finally:
                conn.close()

    def find_by_id(self, conversation_id: str) -> dict[str, Any] | None:
        """Return a summary dict (with tags) for a conversation, or ``None``."""
        with self._lock:
            conn = self._connect()
            try:
                row = conn.execute(
                    "SELECT id, model_name, created_at, last_active, turn_count FROM conversations WHERE id = ?",
                    (conversation_id,),
                ).fetchone()
                if row is None:
                    return None
                return self._row_to_summary(conn, row)
            finally:
                conn.close()

    def list_all(self, limit: int = 100, offset: int = 0) -> list[dict[str, Any]]:
        """Return summary dicts ordered by ``last_active`` descending.

        Args:
            limit: Maximum number of summaries to return.
            offset: Number of leading rows to skip (for paging).
        """
        with self._lock:
            conn = self._connect()
            try:
                rows = conn.execute(
                    """
                    SELECT id, model_name, created_at, last_active, turn_count
                    FROM conversations
                    ORDER BY last_active DESC
                    LIMIT ? OFFSET ?
                    """,
                    (limit, offset),
                ).fetchall()
                return [self._row_to_summary(conn, r) for r in rows]
            finally:
                conn.close()

    # ------------------------------------------------------------------ #
    # Internal
    # ------------------------------------------------------------------ #

    @staticmethod
    def _row_to_summary(conn: sqlite3.Connection, row: sqlite3.Row) -> dict[str, Any]:
        """Build a summary dict from a DB row, including tags.

        Issues one extra query on *conn* to collect the conversation's tags.
        """
        cid = row["id"]
        tag_rows = conn.execute(
            "SELECT tag FROM conversation_tags WHERE conversation_id = ?",
            (cid,),
        ).fetchall()
        return {
            "id": cid,
            "model_name": row["model_name"],
            "created_at": row["created_at"],
            "last_active": row["last_active"],
            "turn_count": row["turn_count"],
            "tags": [tr["tag"] for tr in tag_rows],
        }
|