prompture 0.0.38.dev2__py3-none-any.whl → 0.0.42__py3-none-any.whl
This diff shows the changes between these two publicly released versions of the package, as they appear in the public registry.
- prompture/__init__.py +12 -1
- prompture/_version.py +2 -2
- prompture/agent.py +11 -11
- prompture/async_agent.py +11 -11
- prompture/async_conversation.py +9 -0
- prompture/async_core.py +16 -0
- prompture/async_driver.py +39 -0
- prompture/async_groups.py +63 -0
- prompture/conversation.py +9 -0
- prompture/core.py +16 -0
- prompture/cost_mixin.py +62 -0
- prompture/discovery.py +108 -43
- prompture/driver.py +39 -0
- prompture/drivers/__init__.py +39 -0
- prompture/drivers/async_azure_driver.py +7 -6
- prompture/drivers/async_claude_driver.py +177 -8
- prompture/drivers/async_google_driver.py +10 -0
- prompture/drivers/async_grok_driver.py +4 -4
- prompture/drivers/async_groq_driver.py +4 -4
- prompture/drivers/async_modelscope_driver.py +286 -0
- prompture/drivers/async_moonshot_driver.py +312 -0
- prompture/drivers/async_openai_driver.py +158 -6
- prompture/drivers/async_openrouter_driver.py +196 -7
- prompture/drivers/async_registry.py +30 -0
- prompture/drivers/async_zai_driver.py +303 -0
- prompture/drivers/azure_driver.py +6 -5
- prompture/drivers/claude_driver.py +10 -0
- prompture/drivers/google_driver.py +10 -0
- prompture/drivers/grok_driver.py +4 -4
- prompture/drivers/groq_driver.py +4 -4
- prompture/drivers/modelscope_driver.py +303 -0
- prompture/drivers/moonshot_driver.py +342 -0
- prompture/drivers/openai_driver.py +22 -12
- prompture/drivers/openrouter_driver.py +248 -44
- prompture/drivers/zai_driver.py +318 -0
- prompture/groups.py +42 -0
- prompture/ledger.py +252 -0
- prompture/model_rates.py +114 -2
- prompture/settings.py +16 -1
- {prompture-0.0.38.dev2.dist-info → prompture-0.0.42.dist-info}/METADATA +1 -1
- prompture-0.0.42.dist-info/RECORD +84 -0
- prompture-0.0.38.dev2.dist-info/RECORD +0 -77
- {prompture-0.0.38.dev2.dist-info → prompture-0.0.42.dist-info}/WHEEL +0 -0
- {prompture-0.0.38.dev2.dist-info → prompture-0.0.42.dist-info}/entry_points.txt +0 -0
- {prompture-0.0.38.dev2.dist-info → prompture-0.0.42.dist-info}/licenses/LICENSE +0 -0
- {prompture-0.0.38.dev2.dist-info → prompture-0.0.42.dist-info}/top_level.txt +0 -0
prompture/drivers/zai_driver.py
ADDED
@@ -0,0 +1,318 @@
+"""Z.ai (Zhipu AI) driver implementation.
+Requires the `requests` package. Uses ZHIPU_API_KEY env var.
+
+The Z.ai API is fully OpenAI-compatible (/chat/completions).
+All pricing comes from models.dev (provider: "zai") — no hardcoded pricing.
+"""
+
+import json
+import os
+from collections.abc import Iterator
+from typing import Any
+
+import requests
+
+from ..cost_mixin import CostMixin, prepare_strict_schema
+from ..driver import Driver
+
+
+class ZaiDriver(CostMixin, Driver):
+    supports_json_mode = True
+    supports_json_schema = True
+    supports_tool_use = True
+    supports_streaming = True
+    supports_vision = True
+
+    # All pricing resolved live from models.dev (provider: "zai")
+    MODEL_PRICING: dict[str, dict[str, Any]] = {}
+
+    def __init__(
+        self,
+        api_key: str | None = None,
+        model: str = "glm-4.7",
+        endpoint: str = "https://api.z.ai/api/paas/v4",
+    ):
+        """Initialize Z.ai driver.
+
+        Args:
+            api_key: Zhipu API key. If not provided, will look for ZHIPU_API_KEY env var.
+            model: Model to use. Defaults to glm-4.7.
+            endpoint: API base URL. Defaults to https://api.z.ai/api/paas/v4.
+        """
+        self.api_key = api_key or os.getenv("ZHIPU_API_KEY")
+        if not self.api_key:
+            raise ValueError("Zhipu API key not found. Set ZHIPU_API_KEY env var.")
+
+        self.model = model
+        self.base_url = endpoint.rstrip("/")
+
+        self.headers = {
+            "Authorization": f"Bearer {self.api_key}",
+            "Content-Type": "application/json",
+        }
+
+    supports_messages = True
+
+    def _prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        from .vision_helpers import _prepare_openai_vision_messages
+
+        return _prepare_openai_vision_messages(messages)
+
+    def generate(self, prompt: str, options: dict[str, Any]) -> dict[str, Any]:
+        messages = [{"role": "user", "content": prompt}]
+        return self._do_generate(messages, options)
+
+    def generate_messages(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        return self._do_generate(self._prepare_messages(messages), options)
+
+    def _do_generate(self, messages: list[dict[str, str]], options: dict[str, Any]) -> dict[str, Any]:
+        if not self.api_key:
+            raise RuntimeError("Zhipu API key not found")
+
+        model = options.get("model", self.model)
+
+        model_config = self._get_model_config("zai", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
+
+        self._validate_model_capabilities(
+            "zai",
+            model,
+            using_json_schema=bool(options.get("json_schema")),
+        )
+
+        opts = {"temperature": 1.0, "max_tokens": 512, **options}
+
+        data: dict[str, Any] = {
+            "model": model,
+            "messages": messages,
+        }
+        data[tokens_param] = opts.get("max_tokens", 512)
+
+        if supports_temperature and "temperature" in opts:
+            data["temperature"] = opts["temperature"]
+
+        # Native JSON mode support
+        if options.get("json_mode"):
+            json_schema = options.get("json_schema")
+            if json_schema:
+                schema_copy = prepare_strict_schema(json_schema)
+                data["response_format"] = {
+                    "type": "json_schema",
+                    "json_schema": {
+                        "name": "extraction",
+                        "strict": True,
+                        "schema": schema_copy,
+                    },
+                }
+            else:
+                data["response_format"] = {"type": "json_object"}
+
+        try:
+            response = requests.post(
+                f"{self.base_url}/chat/completions",
+                headers=self.headers,
+                json=data,
+                timeout=120,
+            )
+            response.raise_for_status()
+            resp = response.json()
+        except requests.exceptions.HTTPError as e:
+            error_msg = f"Z.ai API request failed: {e!s}"
+            raise RuntimeError(error_msg) from e
+        except requests.exceptions.RequestException as e:
+            raise RuntimeError(f"Z.ai API request failed: {e!s}") from e
+
+        usage = resp.get("usage", {})
+        prompt_tokens = usage.get("prompt_tokens", 0)
+        completion_tokens = usage.get("completion_tokens", 0)
+        total_tokens = usage.get("total_tokens", 0)
+
+        total_cost = self._calculate_cost("zai", model, prompt_tokens, completion_tokens)
+
+        meta = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": round(total_cost, 6),
+            "raw_response": resp,
+            "model_name": model,
+        }
+
+        text = resp["choices"][0]["message"]["content"]
+        return {"text": text, "meta": meta}
+
+    # ------------------------------------------------------------------
+    # Tool use
+    # ------------------------------------------------------------------
+
+    def generate_messages_with_tools(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> dict[str, Any]:
+        """Generate a response that may include tool calls."""
+        if not self.api_key:
+            raise RuntimeError("Zhipu API key not found")
+
+        model = options.get("model", self.model)
+        model_config = self._get_model_config("zai", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
+
+        self._validate_model_capabilities("zai", model, using_tool_use=True)
+
+        opts = {"temperature": 1.0, "max_tokens": 512, **options}
+
+        data: dict[str, Any] = {
+            "model": model,
+            "messages": messages,
+            "tools": tools,
+        }
+        data[tokens_param] = opts.get("max_tokens", 512)
+
+        if supports_temperature and "temperature" in opts:
+            data["temperature"] = opts["temperature"]
+
+        if "tool_choice" in options:
+            data["tool_choice"] = options["tool_choice"]
+
+        try:
+            response = requests.post(
+                f"{self.base_url}/chat/completions",
+                headers=self.headers,
+                json=data,
+                timeout=120,
+            )
+            response.raise_for_status()
+            resp = response.json()
+        except requests.exceptions.HTTPError as e:
+            error_msg = f"Z.ai API request failed: {e!s}"
+            raise RuntimeError(error_msg) from e
+        except requests.exceptions.RequestException as e:
+            raise RuntimeError(f"Z.ai API request failed: {e!s}") from e
+
+        usage = resp.get("usage", {})
+        prompt_tokens = usage.get("prompt_tokens", 0)
+        completion_tokens = usage.get("completion_tokens", 0)
+        total_tokens = usage.get("total_tokens", 0)
+        total_cost = self._calculate_cost("zai", model, prompt_tokens, completion_tokens)
+
+        meta = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": round(total_cost, 6),
+            "raw_response": resp,
+            "model_name": model,
+        }
+
+        choice = resp["choices"][0]
+        text = choice["message"].get("content") or ""
+        stop_reason = choice.get("finish_reason")
+
+        tool_calls_out: list[dict[str, Any]] = []
+        for tc in choice["message"].get("tool_calls", []):
+            try:
+                args = json.loads(tc["function"]["arguments"])
+            except (json.JSONDecodeError, TypeError):
+                args = {}
+            tool_calls_out.append(
+                {
+                    "id": tc["id"],
+                    "name": tc["function"]["name"],
+                    "arguments": args,
+                }
+            )
+
+        return {
+            "text": text,
+            "meta": meta,
+            "tool_calls": tool_calls_out,
+            "stop_reason": stop_reason,
+        }
+
+    # ------------------------------------------------------------------
+    # Streaming
+    # ------------------------------------------------------------------
+
+    def generate_messages_stream(
+        self,
+        messages: list[dict[str, Any]],
+        options: dict[str, Any],
+    ) -> Iterator[dict[str, Any]]:
+        """Yield response chunks via Z.ai streaming API."""
+        if not self.api_key:
+            raise RuntimeError("Zhipu API key not found")
+
+        model = options.get("model", self.model)
+        model_config = self._get_model_config("zai", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
+
+        opts = {"temperature": 1.0, "max_tokens": 512, **options}
+
+        data: dict[str, Any] = {
+            "model": model,
+            "messages": messages,
+            "stream": True,
+            "stream_options": {"include_usage": True},
+        }
+        data[tokens_param] = opts.get("max_tokens", 512)
+
+        if supports_temperature and "temperature" in opts:
+            data["temperature"] = opts["temperature"]
+
+        response = requests.post(
+            f"{self.base_url}/chat/completions",
+            headers=self.headers,
+            json=data,
+            stream=True,
+            timeout=120,
+        )
+        response.raise_for_status()
+
+        full_text = ""
+        prompt_tokens = 0
+        completion_tokens = 0
+
+        for line in response.iter_lines(decode_unicode=True):
+            if not line or not line.startswith("data: "):
+                continue
+            payload = line[len("data: ") :]
+            if payload.strip() == "[DONE]":
+                break
+            try:
+                chunk = json.loads(payload)
+            except json.JSONDecodeError:
+                continue
+
+            usage = chunk.get("usage")
+            if usage:
+                prompt_tokens = usage.get("prompt_tokens", 0)
+                completion_tokens = usage.get("completion_tokens", 0)
+
+            choices = chunk.get("choices", [])
+            if choices:
+                delta = choices[0].get("delta", {})
+                content = delta.get("content", "")
+                if content:
+                    full_text += content
+                    yield {"type": "delta", "text": content}
+
+        total_tokens = prompt_tokens + completion_tokens
+        total_cost = self._calculate_cost("zai", model, prompt_tokens, completion_tokens)
+
+        yield {
+            "type": "done",
+            "text": full_text,
+            "meta": {
+                "prompt_tokens": prompt_tokens,
+                "completion_tokens": completion_tokens,
+                "total_tokens": total_tokens,
+                "cost": round(total_cost, 6),
+                "raw_response": {},
+                "model_name": model,
+            },
+        }
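For orientation, a minimal usage sketch of the new driver follows. It relies only on what the diff above shows (the constructor defaults, generate, and generate_messages_stream); the direct module import path is an assumption, and actually running it requires a real ZHIPU_API_KEY plus pricing data resolved from models.dev.

# Sketch only: the import path and model availability are assumptions, not part of the diff.
from prompture.drivers.zai_driver import ZaiDriver

driver = ZaiDriver()  # reads ZHIPU_API_KEY from the environment; model defaults to glm-4.7

# Blocking call: returns {"text": ..., "meta": {...}} as assembled in _do_generate.
result = driver.generate("Reply with one word: ping", {"max_tokens": 32})
print(result["text"], result["meta"]["cost"])

# Streaming call: "delta" chunks carry incremental text; the final "done"
# chunk carries the full text plus token and cost metadata.
stream = driver.generate_messages_stream(
    [{"role": "user", "content": "Count to three."}],
    {"max_tokens": 64},
)
for chunk in stream:
    if chunk["type"] == "delta":
        print(chunk["text"], end="", flush=True)
    elif chunk["type"] == "done":
        print("\ntotal tokens:", chunk["meta"]["total_tokens"])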
prompture/groups.py
CHANGED
@@ -114,6 +114,27 @@ class SequentialGroup:
         """Request graceful shutdown after the current agent finishes."""
         self._stop_requested = True
 
+    @property
+    def shared_state(self) -> dict[str, Any]:
+        """Return a copy of the current shared execution state."""
+        return dict(self._state)
+
+    def inject_state(self, state: dict[str, Any], *, recursive: bool = False) -> None:
+        """Merge external key-value pairs into this group's shared state.
+
+        Existing keys are NOT overwritten (uses setdefault semantics).
+
+        Args:
+            state: Key-value pairs to inject.
+            recursive: If True, also inject into nested sub-groups.
+        """
+        for k, v in state.items():
+            self._state.setdefault(k, v)
+        if recursive:
+            for agent, _ in self._agents:
+                if hasattr(agent, "inject_state"):
+                    agent.inject_state(state, recursive=True)
+
     def save(self, path: str) -> None:
         """Run and save result to file. Convenience wrapper."""
         result = self.run()
@@ -267,6 +288,27 @@ class LoopGroup:
         """Request graceful shutdown."""
         self._stop_requested = True
 
+    @property
+    def shared_state(self) -> dict[str, Any]:
+        """Return a copy of the current shared execution state."""
+        return dict(self._state)
+
+    def inject_state(self, state: dict[str, Any], *, recursive: bool = False) -> None:
+        """Merge external key-value pairs into this group's shared state.
+
+        Existing keys are NOT overwritten (uses setdefault semantics).
+
+        Args:
+            state: Key-value pairs to inject.
+            recursive: If True, also inject into nested sub-groups.
+        """
+        for k, v in state.items():
+            self._state.setdefault(k, v)
+        if recursive:
+            for agent, _ in self._agents:
+                if hasattr(agent, "inject_state"):
+                    agent.inject_state(state, recursive=True)
+
     def run(self, prompt: str = "") -> GroupResult:
         """Execute the loop."""
         self._stop_requested = False
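Both SequentialGroup and LoopGroup gain the same pair of members. A short sketch of the intended flow, assuming `group` is a SequentialGroup that was already built the usual way (construction is untouched by this diff):

# `group` is an already-built SequentialGroup or LoopGroup; only the two
# new members introduced above are exercised here.
group.inject_state({"customer_id": "c-42", "locale": "en"})
group.inject_state({"locale": "fr"}, recursive=True)  # setdefault: "locale" stays "en";
                                                      # recursive=True also reaches nested sub-groups

snapshot = group.shared_state  # a copy, so mutating it...
snapshot["scratch"] = True     # ...does not touch the group's internal _state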
prompture/ledger.py
ADDED
@@ -0,0 +1,252 @@
+"""Persistent model usage ledger — tracks which LLM models have been used.
+
+Stores per-model usage stats (call count, tokens, cost, timestamps) in a
+SQLite database at ``~/.prompture/usage/model_ledger.db``. The public
+convenience functions are fire-and-forget: they never raise exceptions so
+they cannot break existing extraction/conversation flows.
+"""
+
+from __future__ import annotations
+
+import hashlib
+import logging
+import sqlite3
+import threading
+from datetime import datetime, timezone
+from pathlib import Path
+from typing import Any
+
+logger = logging.getLogger("prompture.ledger")
+
+_DEFAULT_DB_DIR = Path.home() / ".prompture" / "usage"
+_DEFAULT_DB_PATH = _DEFAULT_DB_DIR / "model_ledger.db"
+
+_SCHEMA_SQL = """
+CREATE TABLE IF NOT EXISTS model_usage (
+    model_name TEXT NOT NULL,
+    api_key_hash TEXT NOT NULL,
+    use_count INTEGER NOT NULL DEFAULT 1,
+    total_tokens INTEGER NOT NULL DEFAULT 0,
+    total_cost REAL NOT NULL DEFAULT 0.0,
+    first_used TEXT NOT NULL,
+    last_used TEXT NOT NULL,
+    last_status TEXT NOT NULL DEFAULT 'success',
+    PRIMARY KEY (model_name, api_key_hash)
+);
+"""
+
+
+class ModelUsageLedger:
+    """SQLite-backed model usage tracker.
+
+    Thread-safe via an internal :class:`threading.Lock`.
+
+    Args:
+        db_path: Path to the SQLite database file. Defaults to
+            ``~/.prompture/usage/model_ledger.db``.
+    """
+
+    def __init__(self, db_path: str | Path | None = None) -> None:
+        self._db_path = Path(db_path) if db_path else _DEFAULT_DB_PATH
+        self._db_path.parent.mkdir(parents=True, exist_ok=True)
+        self._lock = threading.Lock()
+        self._init_db()
+
+    def _init_db(self) -> None:
+        with self._lock:
+            conn = sqlite3.connect(str(self._db_path))
+            try:
+                conn.executescript(_SCHEMA_SQL)
+                conn.commit()
+            finally:
+                conn.close()
+
+    def _connect(self) -> sqlite3.Connection:
+        conn = sqlite3.connect(str(self._db_path))
+        conn.row_factory = sqlite3.Row
+        return conn
+
+    # ------------------------------------------------------------------ #
+    # Recording
+    # ------------------------------------------------------------------ #
+
+    def record_usage(
+        self,
+        model_name: str,
+        *,
+        api_key_hash: str = "",
+        tokens: int = 0,
+        cost: float = 0.0,
+        status: str = "success",
+    ) -> None:
+        """Record a model usage event (upsert).
+
+        On conflict the row's counters are incremented and ``last_used``
+        is updated.
+        """
+        now = datetime.now(timezone.utc).isoformat()
+        with self._lock:
+            conn = self._connect()
+            try:
+                conn.execute(
+                    """
+                    INSERT INTO model_usage
+                    (model_name, api_key_hash, use_count, total_tokens, total_cost,
+                     first_used, last_used, last_status)
+                    VALUES (?, ?, 1, ?, ?, ?, ?, ?)
+                    ON CONFLICT(model_name, api_key_hash) DO UPDATE SET
+                        use_count = use_count + 1,
+                        total_tokens = total_tokens + excluded.total_tokens,
+                        total_cost = total_cost + excluded.total_cost,
+                        last_used = excluded.last_used,
+                        last_status = excluded.last_status
+                    """,
+                    (model_name, api_key_hash, tokens, cost, now, now, status),
+                )
+                conn.commit()
+            finally:
+                conn.close()
+
+    # ------------------------------------------------------------------ #
+    # Queries
+    # ------------------------------------------------------------------ #
+
+    def get_model_stats(self, model_name: str, api_key_hash: str = "") -> dict[str, Any] | None:
+        """Return stats for a specific model + key combination, or ``None``."""
+        with self._lock:
+            conn = self._connect()
+            try:
+                row = conn.execute(
+                    "SELECT * FROM model_usage WHERE model_name = ? AND api_key_hash = ?",
+                    (model_name, api_key_hash),
+                ).fetchone()
+                if row is None:
+                    return None
+                return dict(row)
+            finally:
+                conn.close()
+
+    def get_verified_models(self) -> set[str]:
+        """Return model names that have at least one successful usage."""
+        with self._lock:
+            conn = self._connect()
+            try:
+                rows = conn.execute(
+                    "SELECT DISTINCT model_name FROM model_usage WHERE last_status = 'success'"
+                ).fetchall()
+                return {r["model_name"] for r in rows}
+            finally:
+                conn.close()
+
+    def get_recently_used(self, limit: int = 10) -> list[dict[str, Any]]:
+        """Return recent model usage rows ordered by ``last_used`` descending."""
+        with self._lock:
+            conn = self._connect()
+            try:
+                rows = conn.execute(
+                    "SELECT * FROM model_usage ORDER BY last_used DESC LIMIT ?",
+                    (limit,),
+                ).fetchall()
+                return [dict(r) for r in rows]
+            finally:
+                conn.close()
+
+    def get_all_stats(self) -> list[dict[str, Any]]:
+        """Return all usage rows."""
+        with self._lock:
+            conn = self._connect()
+            try:
+                rows = conn.execute("SELECT * FROM model_usage ORDER BY last_used DESC").fetchall()
+                return [dict(r) for r in rows]
+            finally:
+                conn.close()
+
+
+# ------------------------------------------------------------------
+# Module-level singleton
+# ------------------------------------------------------------------
+
+_ledger: ModelUsageLedger | None = None
+_ledger_lock = threading.Lock()
+
+
+def _get_ledger() -> ModelUsageLedger:
+    """Return (and lazily create) the module-level singleton ledger."""
+    global _ledger
+    if _ledger is None:
+        with _ledger_lock:
+            if _ledger is None:
+                _ledger = ModelUsageLedger()
+    return _ledger
+
+
+# ------------------------------------------------------------------
+# Public convenience functions (fire-and-forget)
+# ------------------------------------------------------------------
+
+
+def record_model_usage(
+    model_name: str,
+    *,
+    api_key_hash: str = "",
+    tokens: int = 0,
+    cost: float = 0.0,
+    status: str = "success",
+) -> None:
+    """Record a model usage event. Never raises — all exceptions are swallowed."""
+    try:
+        _get_ledger().record_usage(
+            model_name,
+            api_key_hash=api_key_hash,
+            tokens=tokens,
+            cost=cost,
+            status=status,
+        )
+    except Exception:
+        logger.debug("Failed to record model usage for %s", model_name, exc_info=True)
+
+
+def get_recently_used_models(limit: int = 10) -> list[dict[str, Any]]:
+    """Return recently used models. Returns empty list on error."""
+    try:
+        return _get_ledger().get_recently_used(limit)
+    except Exception:
+        logger.debug("Failed to get recently used models", exc_info=True)
+        return []
+
+
+# ------------------------------------------------------------------
+# API key hash helper
+# ------------------------------------------------------------------
+
+_LOCAL_PROVIDERS = frozenset({"ollama", "lmstudio", "local_http", "airllm"})
+
+
+def _resolve_api_key_hash(model_name: str) -> str:
+    """Derive an 8-char hex hash of the API key for the given model's provider.
+
+    Local providers (ollama, lmstudio, etc.) return ``""``.
+    """
+    try:
+        provider = model_name.split("/", 1)[0].lower() if "/" in model_name else model_name.lower()
+        if provider in _LOCAL_PROVIDERS:
+            return ""
+
+        from .settings import settings
+
+        key_map: dict[str, str | None] = {
+            "openai": settings.openai_api_key,
+            "claude": settings.claude_api_key,
+            "google": settings.google_api_key,
+            "groq": settings.groq_api_key,
+            "grok": settings.grok_api_key,
+            "openrouter": settings.openrouter_api_key,
+            "azure": settings.azure_api_key,
+            "huggingface": settings.hf_token,
+        }
+        api_key = key_map.get(provider)
+        if not api_key:
+            return ""
+        return hashlib.sha256(api_key.encode()).hexdigest()[:8]
+    except Exception:
+        return ""
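A minimal sketch of the ledger API as added above, pointed at a temporary path so it does not write to the default ~/.prompture/usage/model_ledger.db; the model names and the /tmp path are illustrative placeholders.

from prompture.ledger import ModelUsageLedger, get_recently_used_models, record_model_usage

ledger = ModelUsageLedger(db_path="/tmp/demo_ledger.db")  # placeholder path
ledger.record_usage("openai/gpt-4o-mini", tokens=1200, cost=0.0004)
ledger.record_usage("openai/gpt-4o-mini", tokens=800, cost=0.0003)  # upsert: use_count -> 2

stats = ledger.get_model_stats("openai/gpt-4o-mini")
print(stats["use_count"], stats["total_tokens"])  # 2 2000; total_cost accumulates too

# The module-level helpers write to the default singleton DB and never raise.
record_model_usage("claude/claude-3-haiku", tokens=500, cost=0.0001, status="success")
print(get_recently_used_models(limit=5))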