prompture 0.0.38.dev2__py3-none-any.whl → 0.0.40.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32)
  1. prompture/__init__.py +12 -1
  2. prompture/_version.py +2 -2
  3. prompture/async_conversation.py +9 -0
  4. prompture/async_core.py +16 -0
  5. prompture/async_driver.py +39 -0
  6. prompture/conversation.py +9 -0
  7. prompture/core.py +16 -0
  8. prompture/cost_mixin.py +37 -0
  9. prompture/discovery.py +108 -43
  10. prompture/driver.py +39 -0
  11. prompture/drivers/async_azure_driver.py +4 -4
  12. prompture/drivers/async_claude_driver.py +177 -8
  13. prompture/drivers/async_google_driver.py +10 -0
  14. prompture/drivers/async_grok_driver.py +4 -4
  15. prompture/drivers/async_groq_driver.py +4 -4
  16. prompture/drivers/async_openai_driver.py +155 -4
  17. prompture/drivers/async_openrouter_driver.py +4 -4
  18. prompture/drivers/azure_driver.py +3 -3
  19. prompture/drivers/claude_driver.py +10 -0
  20. prompture/drivers/google_driver.py +10 -0
  21. prompture/drivers/grok_driver.py +4 -4
  22. prompture/drivers/groq_driver.py +4 -4
  23. prompture/drivers/openai_driver.py +19 -10
  24. prompture/drivers/openrouter_driver.py +4 -4
  25. prompture/ledger.py +252 -0
  26. prompture/model_rates.py +112 -2
  27. {prompture-0.0.38.dev2.dist-info → prompture-0.0.40.dev1.dist-info}/METADATA +1 -1
  28. {prompture-0.0.38.dev2.dist-info → prompture-0.0.40.dev1.dist-info}/RECORD +32 -31
  29. {prompture-0.0.38.dev2.dist-info → prompture-0.0.40.dev1.dist-info}/WHEEL +0 -0
  30. {prompture-0.0.38.dev2.dist-info → prompture-0.0.40.dev1.dist-info}/entry_points.txt +0 -0
  31. {prompture-0.0.38.dev2.dist-info → prompture-0.0.40.dev1.dist-info}/licenses/LICENSE +0 -0
  32. {prompture-0.0.38.dev2.dist-info → prompture-0.0.40.dev1.dist-info}/top_level.txt +0 -0
prompture/__init__.py CHANGED
@@ -110,8 +110,15 @@ from .image import (
     image_from_url,
     make_image,
 )
+from .ledger import ModelUsageLedger, get_recently_used_models
 from .logging import JSONFormatter, configure_logging
-from .model_rates import get_model_info, get_model_rates, refresh_rates_cache
+from .model_rates import (
+    ModelCapabilities,
+    get_model_capabilities,
+    get_model_info,
+    get_model_rates,
+    refresh_rates_cache,
+)
 from .persistence import ConversationStore
 from .persona import (
     PERSONAS,
@@ -213,7 +220,9 @@ __all__ = [
     "LocalHTTPDriver",
     "LoopGroup",
     "MemoryCacheBackend",
+    "ModelCapabilities",
     "ModelRetry",
+    "ModelUsageLedger",
     "OllamaDriver",
     "OpenAIDriver",
     "OpenRouterDriver",
@@ -255,11 +264,13 @@ __all__ = [
     "get_driver_for_model",
     "get_field_definition",
     "get_field_names",
+    "get_model_capabilities",
     "get_model_info",
     "get_model_rates",
     "get_persona",
     "get_persona_names",
     "get_persona_registry_snapshot",
+    "get_recently_used_models",
     "get_registry_snapshot",
     "get_required_fields",
     "get_trait",
prompture/_version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID
 
-__version__ = version = '0.0.38.dev2'
-__version_tuple__ = version_tuple = (0, 0, 38, 'dev2')
+__version__ = version = '0.0.40.dev1'
+__version_tuple__ = version_tuple = (0, 0, 40, 'dev1')
 
 __commit_id__ = commit_id = None
prompture/async_conversation.py CHANGED
@@ -304,6 +304,15 @@ class AsyncConversation:
         self._usage["turns"] += 1
         self._maybe_auto_save()
 
+        from .ledger import _resolve_api_key_hash, record_model_usage
+
+        record_model_usage(
+            self._model_name,
+            api_key_hash=_resolve_api_key_hash(self._model_name),
+            tokens=meta.get("total_tokens", 0),
+            cost=meta.get("cost", 0.0),
+        )
+
     async def ask(
         self,
         content: str,
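`ledger.py` itself (+252 lines) is not expanded in this diff, so the snippet below is only a sketch of what the `record_model_usage` call above implies: an upsert keyed by `(model_name, api_key_hash)` that accumulates use count, tokens, and cost. Table and column names are invented for illustration:

```python
# Sketch only; the real ledger implementation is not shown in this diff.
# Assumes a UNIQUE(model_name, api_key_hash) constraint on the table.
import sqlite3
import time


def record_model_usage_sketch(
    conn: sqlite3.Connection,
    model_name: str,
    *,
    api_key_hash: str | None,
    tokens: int = 0,
    cost: float = 0.0,
) -> None:
    # Insert a fresh row, or fold the new usage into the existing one.
    conn.execute(
        """
        INSERT INTO model_usage (model_name, api_key_hash, use_count, total_tokens, total_cost, last_used)
        VALUES (?, ?, 1, ?, ?, ?)
        ON CONFLICT(model_name, api_key_hash) DO UPDATE SET
            use_count = use_count + 1,
            total_tokens = total_tokens + excluded.total_tokens,
            total_cost = total_cost + excluded.total_cost,
            last_used = excluded.last_used
        """,
        (model_name, api_key_hash, tokens, cost, time.time()),
    )
    conn.commit()
```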
prompture/async_core.py CHANGED
@@ -35,6 +35,18 @@ from .tools import (
 logger = logging.getLogger("prompture.async_core")
 
 
+def _record_usage_to_ledger(model_name: str, meta: dict[str, Any]) -> None:
+    """Fire-and-forget ledger recording for standalone async core functions."""
+    from .ledger import _resolve_api_key_hash, record_model_usage
+
+    record_model_usage(
+        model_name,
+        api_key_hash=_resolve_api_key_hash(model_name),
+        tokens=meta.get("total_tokens", 0),
+        cost=meta.get("cost", 0.0),
+    )
+
+
 async def clean_json_text_with_ai(
     driver: AsyncDriver, text: str, model_name: str = "", options: dict[str, Any] | None = None
 ) -> str:
@@ -117,6 +129,8 @@ async def render_output(
         "model_name": model_name or getattr(driver, "model", ""),
     }
 
+    _record_usage_to_ledger(model_name, resp.get("meta", {}))
+
     return {"text": raw, "usage": usage, "output_format": output_format}
 
 
@@ -211,6 +225,8 @@ async def ask_for_json(
     raw = resp.get("text", "")
     cleaned = clean_json_text(raw)
 
+    _record_usage_to_ledger(model_name, resp.get("meta", {}))
+
     try:
         json_obj = json.loads(cleaned)
         json_string = cleaned
prompture/async_driver.py CHANGED
@@ -166,6 +166,45 @@ class AsyncDriver:
         except Exception:
             logger.exception("Callback %s raised an exception", event)
 
+    def _validate_model_capabilities(
+        self,
+        provider: str,
+        model: str,
+        *,
+        using_tool_use: bool = False,
+        using_json_schema: bool = False,
+        using_vision: bool = False,
+    ) -> None:
+        """Log warnings when the model may not support a requested feature.
+
+        Uses models.dev metadata as a secondary signal. Warnings only — the
+        API is the final authority and models.dev data may be stale.
+        """
+        from .model_rates import get_model_capabilities
+
+        caps = get_model_capabilities(provider, model)
+        if caps is None:
+            return
+
+        if using_tool_use and caps.supports_tool_use is False:
+            logger.warning(
+                "Model %s/%s may not support tool use according to models.dev metadata",
+                provider,
+                model,
+            )
+        if using_json_schema and caps.supports_structured_output is False:
+            logger.warning(
+                "Model %s/%s may not support structured output / JSON schema according to models.dev metadata",
+                provider,
+                model,
+            )
+        if using_vision and caps.supports_vision is False:
+            logger.warning(
+                "Model %s/%s may not support vision/image inputs according to models.dev metadata",
+                provider,
+                model,
+            )
+
     def _check_vision_support(self, messages: list[dict[str, Any]]) -> None:
         """Raise if messages contain image blocks and the driver lacks vision support."""
         if self.supports_vision:
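The capability fields read here are tri-state (`True`, `False`, or `None` for unknown), and every check uses `is False`, so missing or stale models.dev metadata never produces a warning. A standalone illustration of that pattern:

```python
# Standalone illustration of the warn-don't-raise pattern used above.
# Only an explicit False triggers a warning; None (unknown) stays silent.
import logging

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger("example")


def warn_if_unsupported(supports_tool_use: bool | None, provider: str, model: str) -> None:
    # `is False` rather than `not supports_tool_use`, so None never warns.
    if supports_tool_use is False:
        logger.warning("Model %s/%s may not support tool use", provider, model)


warn_if_unsupported(None, "openai", "some-model")   # unknown metadata: silent
warn_if_unsupported(False, "openai", "some-model")  # explicit False: warns
```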
prompture/conversation.py CHANGED
@@ -311,6 +311,15 @@ class Conversation:
         self._usage["turns"] += 1
         self._maybe_auto_save()
 
+        from .ledger import _resolve_api_key_hash, record_model_usage
+
+        record_model_usage(
+            self._model_name,
+            api_key_hash=_resolve_api_key_hash(self._model_name),
+            tokens=meta.get("total_tokens", 0),
+            cost=meta.get("cost", 0.0),
+        )
+
     def ask(
         self,
         content: str,
prompture/core.py CHANGED
@@ -31,6 +31,18 @@ from .tools import (
 logger = logging.getLogger("prompture.core")
 
 
+def _record_usage_to_ledger(model_name: str, meta: dict[str, Any]) -> None:
+    """Fire-and-forget ledger recording for standalone core functions."""
+    from .ledger import _resolve_api_key_hash, record_model_usage
+
+    record_model_usage(
+        model_name,
+        api_key_hash=_resolve_api_key_hash(model_name),
+        tokens=meta.get("total_tokens", 0),
+        cost=meta.get("cost", 0.0),
+    )
+
+
 def _build_content_with_images(text: str, images: list[ImageInput] | None = None) -> str | list[dict[str, Any]]:
     """Return plain string when no images, or a list of content blocks."""
     if not images:
@@ -231,6 +243,8 @@ def render_output(
         "model_name": model_name or getattr(driver, "model", ""),
     }
 
+    _record_usage_to_ledger(model_name, resp.get("meta", {}))
+
     return {"text": raw, "usage": usage, "output_format": output_format}
 
 
@@ -353,6 +367,8 @@ def ask_for_json(
     raw = resp.get("text", "")
    cleaned = clean_json_text(raw)
 
+    _record_usage_to_ledger(model_name, resp.get("meta", {}))
+
     try:
         json_obj = json.loads(cleaned)
         json_string = cleaned
prompture/cost_mixin.py CHANGED
@@ -49,3 +49,40 @@ class CostMixin:
         completion_cost = (completion_tokens / unit) * model_pricing["completion"]
 
         return round(prompt_cost + completion_cost, 6)
+
+    def _get_model_config(self, provider: str, model: str) -> dict[str, Any]:
+        """Merge live models.dev capabilities with hardcoded ``MODEL_PRICING``.
+
+        Returns a dict with:
+        - ``tokens_param`` — always from hardcoded ``MODEL_PRICING`` (API-specific)
+        - ``supports_temperature`` — prefers live data, falls back to hardcoded, default ``True``
+        - ``context_window`` — from live data only (``None`` if unavailable)
+        - ``max_output_tokens`` — from live data only (``None`` if unavailable)
+        """
+        from .model_rates import get_model_capabilities
+
+        hardcoded = self.MODEL_PRICING.get(model, {})
+
+        # tokens_param is always from hardcoded config (API-specific, not in models.dev)
+        tokens_param = hardcoded.get("tokens_param", "max_tokens")
+
+        # Start with hardcoded supports_temperature, default True
+        supports_temperature = hardcoded.get("supports_temperature", True)
+
+        context_window: int | None = None
+        max_output_tokens: int | None = None
+
+        # Override with live data when available
+        caps = get_model_capabilities(provider, model)
+        if caps is not None:
+            if caps.supports_temperature is not None:
+                supports_temperature = caps.supports_temperature
+            context_window = caps.context_window
+            max_output_tokens = caps.max_output_tokens
+
+        return {
+            "tokens_param": tokens_param,
+            "supports_temperature": supports_temperature,
+            "context_window": context_window,
+            "max_output_tokens": max_output_tokens,
+        }
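The precedence rules in `_get_model_config` are easy to check in isolation. A standalone sketch with a stand-in dataclass (the real `ModelCapabilities` lives in `model_rates.py`, which this diff does not expand):

```python
# Standalone sketch of the merge precedence in _get_model_config.
# CapsStub is a stand-in; it only models the fields used here.
from dataclasses import dataclass


@dataclass
class CapsStub:
    supports_temperature: bool | None = None
    context_window: int | None = None
    max_output_tokens: int | None = None


def merge(hardcoded: dict, caps: CapsStub | None) -> dict:
    supports_temperature = hardcoded.get("supports_temperature", True)
    if caps is not None and caps.supports_temperature is not None:
        supports_temperature = caps.supports_temperature  # live data wins
    return {
        "tokens_param": hardcoded.get("tokens_param", "max_tokens"),  # hardcoded only
        "supports_temperature": supports_temperature,
        "context_window": caps.context_window if caps else None,  # live only
        "max_output_tokens": caps.max_output_tokens if caps else None,
    }


# Hardcoded says no temperature support; live metadata overrides it to True.
print(merge({"supports_temperature": False}, CapsStub(supports_temperature=True)))
```

Keeping `tokens_param` hardcoded is deliberate: per the docstring, it names a provider API parameter, which models.dev does not track.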
prompture/discovery.py CHANGED
@@ -1,7 +1,11 @@
 """Discovery module for auto-detecting available models."""
 
+from __future__ import annotations
+
+import dataclasses
 import logging
 import os
+from typing import Any, overload
 
 import requests
 
@@ -22,23 +26,40 @@ from .settings import settings
 logger = logging.getLogger(__name__)
 
 
-def get_available_models() -> list[str]:
-    """
-    Auto-detects all available models based on configured drivers and environment variables.
+@overload
+def get_available_models(*, include_capabilities: bool = False, verified_only: bool = False) -> list[str]: ...
+
+
+@overload
+def get_available_models(*, include_capabilities: bool = True, verified_only: bool = False) -> list[dict[str, Any]]: ...
+
+
+def get_available_models(
+    *,
+    include_capabilities: bool = False,
+    verified_only: bool = False,
+) -> list[str] | list[dict[str, Any]]:
+    """Auto-detect available models based on configured drivers and environment variables.
+
+    Iterates through supported providers and checks if they are configured
+    (e.g. API key present). For static drivers, returns models from their
+    ``MODEL_PRICING`` keys. For dynamic drivers (like Ollama), attempts to
+    fetch available models from the endpoint.
 
-    Iterates through supported providers and checks if they are configured (e.g. API key present).
-    For static drivers, returns models from their MODEL_PRICING keys.
-    For dynamic drivers (like Ollama), attempts to fetch available models from the endpoint.
+    Args:
+        include_capabilities: When ``True``, return enriched dicts with
+            ``model``, ``provider``, ``model_id``, and ``capabilities``
+            fields instead of plain ``"provider/model_id"`` strings.
+        verified_only: When ``True``, only return models that have been
+            successfully used (as recorded by the usage ledger).
 
     Returns:
-        A list of unique model strings in the format "provider/model_id".
+        A sorted list of unique model strings (default) or enriched dicts.
     """
     available_models: set[str] = set()
     configured_providers: set[str] = set()
 
     # Map of provider name to driver class
-    # We need to map the registry keys to the actual classes to check MODEL_PRICING
-    # and instantiate for dynamic checks if needed.
     provider_classes = {
         "openai": OpenAIDriver,
         "azure": AzureDriver,
@@ -54,11 +75,6 @@ def get_available_models() -> list[str]:
 
     for provider, driver_cls in provider_classes.items():
         try:
-            # 1. Check if the provider is configured (has API key or endpoint)
-            # We can check this by looking at the settings or env vars that the driver uses.
-            # A simple way is to try to instantiate it with defaults, but that might fail if keys are missing.
-            # Instead, let's check the specific requirements for each known provider.
-
             is_configured = False
 
             if provider == "openai":
@@ -86,14 +102,11 @@
             elif provider == "grok":
                 if settings.grok_api_key or os.getenv("GROK_API_KEY"):
                     is_configured = True
-            elif provider == "ollama":
-                # Ollama is always considered "configured" as it defaults to localhost
-                # We will check connectivity later
-                is_configured = True
-            elif provider == "lmstudio":
-                # LM Studio is similar to Ollama, defaults to localhost
-                is_configured = True
-            elif provider == "local_http" and (settings.local_http_endpoint or os.getenv("LOCAL_HTTP_ENDPOINT")):
+            elif (
+                provider == "ollama"
+                or provider == "lmstudio"
+                or (provider == "local_http" and os.getenv("LOCAL_HTTP_ENDPOINT"))
+            ):
                 is_configured = True
 
             if not is_configured:
@@ -101,36 +114,20 @@
                 continue
 
             configured_providers.add(provider)
 
-            # 2. Static Detection: Get models from MODEL_PRICING
+            # Static Detection: Get models from MODEL_PRICING
             if hasattr(driver_cls, "MODEL_PRICING"):
                 pricing = driver_cls.MODEL_PRICING
                 for model_id in pricing:
-                    # Skip "default" or generic keys if they exist
                     if model_id == "default":
                         continue
-
-                    # For Azure, the model_id in pricing is usually the base model name,
-                    # but the user needs to use the deployment ID.
-                    # However, our Azure driver implementation uses the deployment_id from init
-                    # as the "model" for the request, but expects the user to pass a model name
-                    # that maps to pricing?
-                    # Looking at AzureDriver:
-                    #     kwargs = {"model": self.deployment_id, ...}
-                    #     model = options.get("model", self.model) -> used for pricing lookup
-                    # So we should list the keys in MODEL_PRICING as available "models"
-                    # even though for Azure specifically it's a bit weird because of deployment IDs.
-                    # But for general discovery, listing supported models is correct.
-
                     available_models.add(f"{provider}/{model_id}")
 
-            # 3. Dynamic Detection: Specific logic for Ollama
+            # Dynamic Detection: Specific logic for Ollama
            if provider == "ollama":
                 try:
                     endpoint = settings.ollama_endpoint or os.getenv(
                         "OLLAMA_ENDPOINT", "http://localhost:11434/api/generate"
                     )
-                    # We need the base URL for tags, usually http://localhost:11434/api/tags
-                    # The configured endpoint might be .../api/generate or .../api/chat
                     base_url = endpoint.split("/api/")[0]
                     tags_url = f"{base_url}/api/tags"
 
@@ -141,8 +138,6 @@
                     for model in models:
                         name = model.get("name")
                         if name:
-                            # Ollama model names often include tags like "llama3:latest"
-                            # We can keep them as is.
                             available_models.add(f"ollama/{name}")
                 except Exception as e:
                     logger.debug(f"Failed to fetch Ollama models: {e}")
@@ -184,4 +179,74 @@
         for model_id in get_all_provider_models(api_name):
             available_models.add(f"{prompture_name}/{model_id}")
 
-    return sorted(list(available_models))
+    sorted_models = sorted(available_models)
+
+    # --- verified_only filtering ---
+    verified_set: set[str] | None = None
+    if verified_only or include_capabilities:
+        try:
+            from .ledger import _get_ledger
+
+            ledger = _get_ledger()
+            verified_set = ledger.get_verified_models()
+        except Exception:
+            logger.debug("Could not load ledger for verified models", exc_info=True)
+            verified_set = set()
+
+    if verified_only and verified_set is not None:
+        sorted_models = [m for m in sorted_models if m in verified_set]
+
+    if not include_capabilities:
+        return sorted_models
+
+    # Build enriched dicts with capabilities from models.dev
+    from .model_rates import get_model_capabilities
+
+    # Fetch all ledger stats for annotation (keyed by model_name)
+    ledger_stats: dict[str, dict[str, Any]] = {}
+    try:
+        from .ledger import _get_ledger
+
+        for row in _get_ledger().get_all_stats():
+            name = row["model_name"]
+            if name not in ledger_stats:
+                ledger_stats[name] = row
+            else:
+                # Aggregate across API key hashes
+                existing = ledger_stats[name]
+                existing["use_count"] += row["use_count"]
+                existing["total_tokens"] += row["total_tokens"]
+                existing["total_cost"] += row["total_cost"]
+                if row["last_used"] > existing["last_used"]:
+                    existing["last_used"] = row["last_used"]
+    except Exception:
+        logger.debug("Could not load ledger stats for enrichment", exc_info=True)
+
+    enriched: list[dict[str, Any]] = []
+    for model_str in sorted_models:
+        parts = model_str.split("/", 1)
+        provider = parts[0]
+        model_id = parts[1] if len(parts) > 1 else parts[0]
+
+        caps = get_model_capabilities(provider, model_id)
+        caps_dict = dataclasses.asdict(caps) if caps is not None else None
+
+        entry: dict[str, Any] = {
+            "model": model_str,
+            "provider": provider,
+            "model_id": model_id,
+            "capabilities": caps_dict,
+            "verified": verified_set is not None and model_str in verified_set,
+        }
+
+        stats = ledger_stats.get(model_str)
+        if stats:
+            entry["last_used"] = stats["last_used"]
+            entry["use_count"] = stats["use_count"]
+        else:
+            entry["last_used"] = None
+            entry["use_count"] = 0
+
+        enriched.append(entry)
+
+    return enriched
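Based on the overloads above, callers choose the return shape with keyword flags, and the no-argument call keeps its old behavior. A hypothetical session (assuming `get_available_models` is importable from the package root):

```python
# Hedged sketch; the import path from the package root is an assumption.
from prompture import get_available_models

# Plain "provider/model_id" strings (default, unchanged behavior):
models = get_available_models()

# Only models the ledger has recorded as successfully used:
verified = get_available_models(verified_only=True)

# Enriched dicts with models.dev capabilities and ledger stats:
for entry in get_available_models(include_capabilities=True):
    caps = entry["capabilities"] or {}
    print(entry["model"], entry["verified"], entry["use_count"], caps.get("context_window"))
```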
prompture/driver.py CHANGED
@@ -173,6 +173,45 @@ class Driver:
         except Exception:
             logger.exception("Callback %s raised an exception", event)
 
+    def _validate_model_capabilities(
+        self,
+        provider: str,
+        model: str,
+        *,
+        using_tool_use: bool = False,
+        using_json_schema: bool = False,
+        using_vision: bool = False,
+    ) -> None:
+        """Log warnings when the model may not support a requested feature.
+
+        Uses models.dev metadata as a secondary signal. Warnings only — the
+        API is the final authority and models.dev data may be stale.
+        """
+        from .model_rates import get_model_capabilities
+
+        caps = get_model_capabilities(provider, model)
+        if caps is None:
+            return
+
+        if using_tool_use and caps.supports_tool_use is False:
+            logger.warning(
+                "Model %s/%s may not support tool use according to models.dev metadata",
+                provider,
+                model,
+            )
+        if using_json_schema and caps.supports_structured_output is False:
+            logger.warning(
+                "Model %s/%s may not support structured output / JSON schema according to models.dev metadata",
+                provider,
+                model,
+            )
+        if using_vision and caps.supports_vision is False:
+            logger.warning(
+                "Model %s/%s may not support vision/image inputs according to models.dev metadata",
+                provider,
+                model,
+            )
+
     def _check_vision_support(self, messages: list[dict[str, Any]]) -> None:
         """Raise if messages contain image blocks and the driver lacks vision support."""
         if self.supports_vision:
prompture/drivers/async_azure_driver.py CHANGED
@@ -70,9 +70,9 @@ class AsyncAzureDriver(CostMixin, AsyncDriver):
            raise RuntimeError("openai package (>=1.0.0) with AsyncAzureOpenAI not installed")
 
         model = options.get("model", self.model)
-        model_info = self.MODEL_PRICING.get(model, {})
-        tokens_param = model_info.get("tokens_param", "max_tokens")
-        supports_temperature = model_info.get("supports_temperature", True)
+        model_config = self._get_model_config("azure", model)
+        tokens_param = model_config["tokens_param"]
+        supports_temperature = model_config["supports_temperature"]
 
         opts = {"temperature": 1.0, "max_tokens": 512, **options}
 
@@ -113,7 +113,7 @@ class AsyncAzureDriver(CostMixin, AsyncDriver):
             "prompt_tokens": prompt_tokens,
             "completion_tokens": completion_tokens,
             "total_tokens": total_tokens,
-            "cost": total_cost,
+            "cost": round(total_cost, 6),
             "raw_response": resp.model_dump(),
             "model_name": model,
             "deployment_id": self.deployment_id,