prompture-0.0.35-py3-none-any.whl → prompture-0.0.40.dev1-py3-none-any.whl
This diff compares the contents of two publicly released package versions as they appear in their public registry. It is provided for informational purposes only.
- prompture/__init__.py +132 -3
- prompture/_version.py +2 -2
- prompture/agent.py +924 -0
- prompture/agent_types.py +156 -0
- prompture/async_agent.py +880 -0
- prompture/async_conversation.py +208 -17
- prompture/async_core.py +16 -0
- prompture/async_driver.py +63 -0
- prompture/async_groups.py +551 -0
- prompture/conversation.py +222 -18
- prompture/core.py +46 -12
- prompture/cost_mixin.py +37 -0
- prompture/discovery.py +132 -44
- prompture/driver.py +77 -0
- prompture/drivers/__init__.py +5 -1
- prompture/drivers/async_azure_driver.py +11 -5
- prompture/drivers/async_claude_driver.py +184 -9
- prompture/drivers/async_google_driver.py +222 -28
- prompture/drivers/async_grok_driver.py +11 -5
- prompture/drivers/async_groq_driver.py +11 -5
- prompture/drivers/async_lmstudio_driver.py +74 -5
- prompture/drivers/async_ollama_driver.py +13 -3
- prompture/drivers/async_openai_driver.py +162 -5
- prompture/drivers/async_openrouter_driver.py +11 -5
- prompture/drivers/async_registry.py +5 -1
- prompture/drivers/azure_driver.py +10 -4
- prompture/drivers/claude_driver.py +17 -1
- prompture/drivers/google_driver.py +227 -33
- prompture/drivers/grok_driver.py +11 -5
- prompture/drivers/groq_driver.py +11 -5
- prompture/drivers/lmstudio_driver.py +73 -8
- prompture/drivers/ollama_driver.py +16 -5
- prompture/drivers/openai_driver.py +26 -11
- prompture/drivers/openrouter_driver.py +11 -5
- prompture/drivers/vision_helpers.py +153 -0
- prompture/group_types.py +147 -0
- prompture/groups.py +530 -0
- prompture/image.py +180 -0
- prompture/ledger.py +252 -0
- prompture/model_rates.py +112 -2
- prompture/persistence.py +254 -0
- prompture/persona.py +482 -0
- prompture/serialization.py +218 -0
- prompture/settings.py +1 -0
- prompture-0.0.40.dev1.dist-info/METADATA +369 -0
- prompture-0.0.40.dev1.dist-info/RECORD +78 -0
- prompture-0.0.35.dist-info/METADATA +0 -464
- prompture-0.0.35.dist-info/RECORD +0 -66
- {prompture-0.0.35.dist-info → prompture-0.0.40.dev1.dist-info}/WHEEL +0 -0
- {prompture-0.0.35.dist-info → prompture-0.0.40.dev1.dist-info}/entry_points.txt +0 -0
- {prompture-0.0.35.dist-info → prompture-0.0.40.dev1.dist-info}/licenses/LICENSE +0 -0
- {prompture-0.0.35.dist-info → prompture-0.0.40.dev1.dist-info}/top_level.txt +0 -0
prompture/image.py
ADDED
@@ -0,0 +1,180 @@
"""Image handling utilities for vision-capable LLM drivers."""

from __future__ import annotations

import base64
import mimetypes
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Union


@dataclass(frozen=True)
class ImageContent:
    """Normalized image representation for vision-capable drivers.

    Attributes:
        data: Base64-encoded image data.
        media_type: MIME type (e.g. ``"image/png"``, ``"image/jpeg"``).
        source_type: How the image is delivered — ``"base64"`` or ``"url"``.
        url: Original URL when ``source_type`` is ``"url"``.
    """

    data: str
    media_type: str
    source_type: str = "base64"
    url: str | None = None


# Public type alias accepted by all image-aware APIs.
ImageInput = Union[bytes, str, Path, ImageContent]

# Known data-URI prefix pattern
_DATA_URI_RE = re.compile(r"^data:(image/[a-zA-Z0-9.+-]+);base64,(.+)$", re.DOTALL)

# Base64 detection heuristic — must look like pure base64 of reasonable length
_BASE64_RE = re.compile(r"^[A-Za-z0-9+/\n\r]+=*$")

_MIME_FROM_EXT: dict[str, str] = {
    ".jpg": "image/jpeg",
    ".jpeg": "image/jpeg",
    ".png": "image/png",
    ".gif": "image/gif",
    ".webp": "image/webp",
    ".bmp": "image/bmp",
    ".svg": "image/svg+xml",
    ".tiff": "image/tiff",
    ".tif": "image/tiff",
}

_MAGIC_BYTES: list[tuple[bytes, str]] = [
    (b"\x89PNG", "image/png"),
    (b"\xff\xd8\xff", "image/jpeg"),
    (b"GIF87a", "image/gif"),
    (b"GIF89a", "image/gif"),
    (b"RIFF", "image/webp"),  # WebP starts with RIFF...WEBP
    (b"BM", "image/bmp"),
]


def _guess_media_type_from_bytes(data: bytes) -> str:
    """Guess MIME type from the first few bytes of image data."""
    for magic, mime in _MAGIC_BYTES:
        if data[: len(magic)] == magic:
            return mime
    return "image/png"  # safe fallback


def _guess_media_type(path: str) -> str:
    """Guess MIME type from a file path or URL."""
    # Strip query strings for URLs
    clean = path.split("?")[0].split("#")[0]
    ext = Path(clean).suffix.lower()
    if ext in _MIME_FROM_EXT:
        return _MIME_FROM_EXT[ext]
    guessed = mimetypes.guess_type(clean)[0]
    return guessed or "image/png"


# ------------------------------------------------------------------
# Constructor functions
# ------------------------------------------------------------------


def image_from_bytes(data: bytes, media_type: str | None = None) -> ImageContent:
    """Create an :class:`ImageContent` from raw bytes.

    Args:
        data: Raw image bytes.
        media_type: MIME type. Auto-detected from magic bytes when *None*.
    """
    if not data:
        raise ValueError("Image data cannot be empty")
    b64 = base64.b64encode(data).decode("ascii")
    mt = media_type or _guess_media_type_from_bytes(data)
    return ImageContent(data=b64, media_type=mt)


def image_from_base64(b64: str, media_type: str = "image/png") -> ImageContent:
    """Create an :class:`ImageContent` from a base64-encoded string.

    Accepts both raw base64 and ``data:`` URIs.
    """
    m = _DATA_URI_RE.match(b64)
    if m:
        return ImageContent(data=m.group(2), media_type=m.group(1))
    return ImageContent(data=b64, media_type=media_type)


def image_from_file(path: str | Path, media_type: str | None = None) -> ImageContent:
    """Create an :class:`ImageContent` by reading a local file.

    Args:
        path: Path to an image file.
        media_type: MIME type. Guessed from extension when *None*.
    """
    p = Path(path)
    if not p.exists():
        raise FileNotFoundError(f"Image file not found: {p}")
    raw = p.read_bytes()
    mt = media_type or _guess_media_type(str(p))
    return image_from_bytes(raw, mt)


def image_from_url(url: str, media_type: str | None = None) -> ImageContent:
    """Create an :class:`ImageContent` referencing a remote URL.

    The image is **not** downloaded — the URL is stored directly so
    drivers that accept URL-based images can pass it through. For
    drivers that require base64, the URL is embedded as a data URI.

    Args:
        url: Publicly-accessible image URL.
        media_type: MIME type. Guessed from the URL when *None*.
    """
    mt = media_type or _guess_media_type(url)
    return ImageContent(data="", media_type=mt, source_type="url", url=url)


# ------------------------------------------------------------------
# Smart constructor
# ------------------------------------------------------------------


def make_image(source: ImageInput) -> ImageContent:
    """Auto-detect the source type and return an :class:`ImageContent`.

    Accepts:
    - ``ImageContent`` — returned as-is.
    - ``bytes`` — base64-encoded with auto-detected MIME.
    - ``str`` — tries (in order): data URI, URL, file path, raw base64.
    - ``pathlib.Path`` — read from disk.
    """
    if isinstance(source, ImageContent):
        return source

    if isinstance(source, bytes):
        return image_from_bytes(source)

    if isinstance(source, Path):
        return image_from_file(source)

    if isinstance(source, str):
        # 1. data URI
        if source.startswith("data:"):
            return image_from_base64(source)

        # 2. URL
        if source.startswith(("http://", "https://")):
            return image_from_url(source)

        # 3. File path (if exists on disk)
        p = Path(source)
        if p.exists():
            return image_from_file(p)

        # 4. Assume raw base64
        return image_from_base64(source)

    raise TypeError(f"Unsupported image source type: {type(source).__name__}")
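
Taken together, make_image is the single entry point and the image_from_* constructors are its fallbacks. A minimal usage sketch, importing the module by its file path as prompture.image (whether these names are also re-exported from the package root is not visible in this diff; file names and the base64 string are illustrative only):

    # Usage sketch; "chart.png", the URL, and the base64 string are placeholders.
    from pathlib import Path

    from prompture.image import ImageContent, make_image

    img_a = make_image(Path("chart.png"))              # existing file -> base64, MIME from extension
    img_b = make_image("https://example.com/cat.jpg")  # URL -> source_type="url", not downloaded
    img_c = make_image(b"\x89PNG\r\n\x1a\n")           # raw bytes -> magic-byte MIME detection
    img_d = make_image("iVBORw0KGgoAAAANSUhEUg==")     # bare string -> treated as raw base64

    assert isinstance(img_b, ImageContent) and img_b.url is not None

One consequence of the string fallback order is that a mistyped file path that does not exist on disk is silently treated as raw base64 rather than raising, which callers may want to guard against.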
prompture/ledger.py
ADDED
@@ -0,0 +1,252 @@
"""Persistent model usage ledger — tracks which LLM models have been used.

Stores per-model usage stats (call count, tokens, cost, timestamps) in a
SQLite database at ``~/.prompture/usage/model_ledger.db``. The public
convenience functions are fire-and-forget: they never raise exceptions so
they cannot break existing extraction/conversation flows.
"""

from __future__ import annotations

import hashlib
import logging
import sqlite3
import threading
from datetime import datetime, timezone
from pathlib import Path
from typing import Any

logger = logging.getLogger("prompture.ledger")

_DEFAULT_DB_DIR = Path.home() / ".prompture" / "usage"
_DEFAULT_DB_PATH = _DEFAULT_DB_DIR / "model_ledger.db"

_SCHEMA_SQL = """
CREATE TABLE IF NOT EXISTS model_usage (
    model_name TEXT NOT NULL,
    api_key_hash TEXT NOT NULL,
    use_count INTEGER NOT NULL DEFAULT 1,
    total_tokens INTEGER NOT NULL DEFAULT 0,
    total_cost REAL NOT NULL DEFAULT 0.0,
    first_used TEXT NOT NULL,
    last_used TEXT NOT NULL,
    last_status TEXT NOT NULL DEFAULT 'success',
    PRIMARY KEY (model_name, api_key_hash)
);
"""


class ModelUsageLedger:
    """SQLite-backed model usage tracker.

    Thread-safe via an internal :class:`threading.Lock`.

    Args:
        db_path: Path to the SQLite database file. Defaults to
            ``~/.prompture/usage/model_ledger.db``.
    """

    def __init__(self, db_path: str | Path | None = None) -> None:
        self._db_path = Path(db_path) if db_path else _DEFAULT_DB_PATH
        self._db_path.parent.mkdir(parents=True, exist_ok=True)
        self._lock = threading.Lock()
        self._init_db()

    def _init_db(self) -> None:
        with self._lock:
            conn = sqlite3.connect(str(self._db_path))
            try:
                conn.executescript(_SCHEMA_SQL)
                conn.commit()
            finally:
                conn.close()

    def _connect(self) -> sqlite3.Connection:
        conn = sqlite3.connect(str(self._db_path))
        conn.row_factory = sqlite3.Row
        return conn

    # ------------------------------------------------------------------ #
    # Recording
    # ------------------------------------------------------------------ #

    def record_usage(
        self,
        model_name: str,
        *,
        api_key_hash: str = "",
        tokens: int = 0,
        cost: float = 0.0,
        status: str = "success",
    ) -> None:
        """Record a model usage event (upsert).

        On conflict the row's counters are incremented and ``last_used``
        is updated.
        """
        now = datetime.now(timezone.utc).isoformat()
        with self._lock:
            conn = self._connect()
            try:
                conn.execute(
                    """
                    INSERT INTO model_usage
                        (model_name, api_key_hash, use_count, total_tokens, total_cost,
                         first_used, last_used, last_status)
                    VALUES (?, ?, 1, ?, ?, ?, ?, ?)
                    ON CONFLICT(model_name, api_key_hash) DO UPDATE SET
                        use_count = use_count + 1,
                        total_tokens = total_tokens + excluded.total_tokens,
                        total_cost = total_cost + excluded.total_cost,
                        last_used = excluded.last_used,
                        last_status = excluded.last_status
                    """,
                    (model_name, api_key_hash, tokens, cost, now, now, status),
                )
                conn.commit()
            finally:
                conn.close()

    # ------------------------------------------------------------------ #
    # Queries
    # ------------------------------------------------------------------ #

    def get_model_stats(self, model_name: str, api_key_hash: str = "") -> dict[str, Any] | None:
        """Return stats for a specific model + key combination, or ``None``."""
        with self._lock:
            conn = self._connect()
            try:
                row = conn.execute(
                    "SELECT * FROM model_usage WHERE model_name = ? AND api_key_hash = ?",
                    (model_name, api_key_hash),
                ).fetchone()
                if row is None:
                    return None
                return dict(row)
            finally:
                conn.close()

    def get_verified_models(self) -> set[str]:
        """Return model names that have at least one successful usage."""
        with self._lock:
            conn = self._connect()
            try:
                rows = conn.execute(
                    "SELECT DISTINCT model_name FROM model_usage WHERE last_status = 'success'"
                ).fetchall()
                return {r["model_name"] for r in rows}
            finally:
                conn.close()

    def get_recently_used(self, limit: int = 10) -> list[dict[str, Any]]:
        """Return recent model usage rows ordered by ``last_used`` descending."""
        with self._lock:
            conn = self._connect()
            try:
                rows = conn.execute(
                    "SELECT * FROM model_usage ORDER BY last_used DESC LIMIT ?",
                    (limit,),
                ).fetchall()
                return [dict(r) for r in rows]
            finally:
                conn.close()

    def get_all_stats(self) -> list[dict[str, Any]]:
        """Return all usage rows."""
        with self._lock:
            conn = self._connect()
            try:
                rows = conn.execute("SELECT * FROM model_usage ORDER BY last_used DESC").fetchall()
                return [dict(r) for r in rows]
            finally:
                conn.close()


# ------------------------------------------------------------------
# Module-level singleton
# ------------------------------------------------------------------

_ledger: ModelUsageLedger | None = None
_ledger_lock = threading.Lock()


def _get_ledger() -> ModelUsageLedger:
    """Return (and lazily create) the module-level singleton ledger."""
    global _ledger
    if _ledger is None:
        with _ledger_lock:
            if _ledger is None:
                _ledger = ModelUsageLedger()
    return _ledger


# ------------------------------------------------------------------
# Public convenience functions (fire-and-forget)
# ------------------------------------------------------------------


def record_model_usage(
    model_name: str,
    *,
    api_key_hash: str = "",
    tokens: int = 0,
    cost: float = 0.0,
    status: str = "success",
) -> None:
    """Record a model usage event. Never raises — all exceptions are swallowed."""
    try:
        _get_ledger().record_usage(
            model_name,
            api_key_hash=api_key_hash,
            tokens=tokens,
            cost=cost,
            status=status,
        )
    except Exception:
        logger.debug("Failed to record model usage for %s", model_name, exc_info=True)


def get_recently_used_models(limit: int = 10) -> list[dict[str, Any]]:
    """Return recently used models. Returns empty list on error."""
    try:
        return _get_ledger().get_recently_used(limit)
    except Exception:
        logger.debug("Failed to get recently used models", exc_info=True)
        return []


# ------------------------------------------------------------------
# API key hash helper
# ------------------------------------------------------------------

_LOCAL_PROVIDERS = frozenset({"ollama", "lmstudio", "local_http", "airllm"})


def _resolve_api_key_hash(model_name: str) -> str:
    """Derive an 8-char hex hash of the API key for the given model's provider.

    Local providers (ollama, lmstudio, etc.) return ``""``.
    """
    try:
        provider = model_name.split("/", 1)[0].lower() if "/" in model_name else model_name.lower()
        if provider in _LOCAL_PROVIDERS:
            return ""

        from .settings import settings

        key_map: dict[str, str | None] = {
            "openai": settings.openai_api_key,
            "claude": settings.claude_api_key,
            "google": settings.google_api_key,
            "groq": settings.groq_api_key,
            "grok": settings.grok_api_key,
            "openrouter": settings.openrouter_api_key,
            "azure": settings.azure_api_key,
            "huggingface": settings.hf_token,
        }
        api_key = key_map.get(provider)
        if not api_key:
            return ""
        return hashlib.sha256(api_key.encode()).hexdigest()[:8]
    except Exception:
        return ""
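
The intended call pattern splits into the never-raising module-level helpers for hot paths and the class itself for direct control. A short sketch, assuming direct imports from prompture.ledger (the model identifiers and db_path below are placeholders):

    # Illustrative sketch; model names and the database path are placeholders.
    from prompture.ledger import (
        ModelUsageLedger,
        get_recently_used_models,
        record_model_usage,
    )

    # Fire-and-forget: swallows every exception, so it is safe mid-request.
    record_model_usage("openai/gpt-4o-mini", tokens=1234, cost=0.0021)

    for row in get_recently_used_models(limit=5):
        print(row["model_name"], row["use_count"], row["last_used"])

    # Tests can point the class at a throwaway database instead of ~/.prompture.
    ledger = ModelUsageLedger(db_path="/tmp/test_ledger.db")
    ledger.record_usage("gpt-4o-mini", tokens=500, status="success")
    assert "gpt-4o-mini" in ledger.get_verified_models()

Note that the upsert keys on (model_name, api_key_hash), so the same model used with two different API keys produces two rows; get_verified_models collapses them back to distinct model names.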
prompture/model_rates.py
CHANGED
@@ -9,6 +9,7 @@ import contextlib
 import json
 import logging
 import threading
+from dataclasses import dataclass
 from datetime import datetime, timezone
 from pathlib import Path
 from typing import Any, Optional
@@ -139,7 +140,12 @@ def _lookup_model(provider: str, model_id: str) -> Optional[dict[str, Any]]:
     if not isinstance(provider_data, dict):
         return None
 
-    return provider_data.get(model_id)
+    # models.dev nests actual models under a "models" key
+    models = provider_data.get("models", provider_data)
+    if not isinstance(models, dict):
+        return None
+
+    return models.get(model_id)
 
 
 # ── Public API ──────────────────────────────────────────────────────────────
@@ -189,7 +195,12 @@ def get_all_provider_models(provider: str) -> list[str]:
     if not isinstance(provider_data, dict):
         return []
 
-    return list(provider_data.keys())
+    # models.dev nests actual models under a "models" key
+    models = provider_data.get("models", provider_data)
+    if not isinstance(models, dict):
+        return []
+
+    return list(models.keys())
 
 
 def refresh_rates_cache(force: bool = False) -> bool:
@@ -215,3 +226,102 @@
         return True
 
     return False
+
+
+# ── Model Capabilities ─────────────────────────────────────────────────────
+
+
+@dataclass(frozen=True)
+class ModelCapabilities:
+    """Normalized capability metadata for an LLM model from models.dev.
+
+    All fields default to ``None`` (unknown) so callers can distinguish
+    "the model doesn't support X" from "we have no data about X".
+    """
+
+    supports_temperature: Optional[bool] = None
+    supports_tool_use: Optional[bool] = None
+    supports_structured_output: Optional[bool] = None
+    supports_vision: Optional[bool] = None
+    is_reasoning: Optional[bool] = None
+    context_window: Optional[int] = None
+    max_output_tokens: Optional[int] = None
+    modalities_input: tuple[str, ...] = ()
+    modalities_output: tuple[str, ...] = ()
+
+
+def get_model_capabilities(provider: str, model_id: str) -> Optional[ModelCapabilities]:
+    """Return capability metadata for a model, or ``None`` if unavailable.
+
+    Maps models.dev fields to a :class:`ModelCapabilities` instance:
+
+    - ``temperature`` → ``supports_temperature``
+    - ``tool_call`` → ``supports_tool_use``
+    - ``structured_output`` → ``supports_structured_output``
+    - ``"image" in modalities.input`` → ``supports_vision``
+    - ``reasoning`` → ``is_reasoning``
+    - ``limit.context`` → ``context_window``
+    - ``limit.output`` → ``max_output_tokens``
+    """
+    entry = _lookup_model(provider, model_id)
+    if entry is None:
+        return None
+
+    # Boolean capabilities (True/False/None)
+    supports_temperature: Optional[bool] = None
+    if "temperature" in entry:
+        supports_temperature = bool(entry["temperature"])
+
+    supports_tool_use: Optional[bool] = None
+    if "tool_call" in entry:
+        supports_tool_use = bool(entry["tool_call"])
+
+    supports_structured_output: Optional[bool] = None
+    if "structured_output" in entry:
+        supports_structured_output = bool(entry["structured_output"])
+
+    is_reasoning: Optional[bool] = None
+    if "reasoning" in entry:
+        is_reasoning = bool(entry["reasoning"])
+
+    # Modalities
+    modalities = entry.get("modalities", {})
+    modalities_input: tuple[str, ...] = ()
+    modalities_output: tuple[str, ...] = ()
+    if isinstance(modalities, dict):
+        raw_in = modalities.get("input")
+        if isinstance(raw_in, (list, tuple)):
+            modalities_input = tuple(str(m) for m in raw_in)
+        raw_out = modalities.get("output")
+        if isinstance(raw_out, (list, tuple)):
+            modalities_output = tuple(str(m) for m in raw_out)
+
+    supports_vision: Optional[bool] = None
+    if modalities_input:
+        supports_vision = "image" in modalities_input
+
+    # Limits
+    context_window: Optional[int] = None
+    max_output_tokens: Optional[int] = None
+    limits = entry.get("limit", {})
+    if isinstance(limits, dict):
+        ctx = limits.get("context")
+        if ctx is not None:
+            with contextlib.suppress(TypeError, ValueError):
+                context_window = int(ctx)
+        out = limits.get("output")
+        if out is not None:
+            with contextlib.suppress(TypeError, ValueError):
+                max_output_tokens = int(out)
+
+    return ModelCapabilities(
+        supports_temperature=supports_temperature,
+        supports_tool_use=supports_tool_use,
+        supports_structured_output=supports_structured_output,
+        supports_vision=supports_vision,
+        is_reasoning=is_reasoning,
+        context_window=context_window,
+        max_output_tokens=max_output_tokens,
+        modalities_input=modalities_input,
+        modalities_output=modalities_output,
+    )
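
The tri-state design (True/False/None per capability) matters at call sites. A minimal sketch, where the provider and model IDs are examples and may not exist in a given rates cache:

    # Illustrative sketch; "openai"/"gpt-4o" are example IDs, not guaranteed entries.
    from prompture.model_rates import get_model_capabilities

    caps = get_model_capabilities("openai", "gpt-4o")
    if caps is None:
        print("model not found in the cached models.dev data")
    elif caps.supports_vision is None:
        print("no modality data; treat vision support as unknown")
    elif caps.supports_vision:
        print(f"vision OK, context window: {caps.context_window}")

Checking "is None" before truthiness is the point of the Optional[bool] fields; collapsing None to False would misreport models that simply lack metadata.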