prompture 0.0.33.dev1__py3-none-any.whl → 0.0.34__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.
Files changed (56)
  1. prompture/__init__.py +133 -49
  2. prompture/_version.py +34 -0
  3. prompture/aio/__init__.py +74 -0
  4. prompture/async_conversation.py +484 -0
  5. prompture/async_core.py +803 -0
  6. prompture/async_driver.py +131 -0
  7. prompture/cache.py +469 -0
  8. prompture/callbacks.py +50 -0
  9. prompture/cli.py +7 -3
  10. prompture/conversation.py +504 -0
  11. prompture/core.py +475 -352
  12. prompture/cost_mixin.py +51 -0
  13. prompture/discovery.py +50 -35
  14. prompture/driver.py +125 -5
  15. prompture/drivers/__init__.py +171 -73
  16. prompture/drivers/airllm_driver.py +13 -20
  17. prompture/drivers/async_airllm_driver.py +26 -0
  18. prompture/drivers/async_azure_driver.py +117 -0
  19. prompture/drivers/async_claude_driver.py +107 -0
  20. prompture/drivers/async_google_driver.py +132 -0
  21. prompture/drivers/async_grok_driver.py +91 -0
  22. prompture/drivers/async_groq_driver.py +84 -0
  23. prompture/drivers/async_hugging_driver.py +61 -0
  24. prompture/drivers/async_lmstudio_driver.py +79 -0
  25. prompture/drivers/async_local_http_driver.py +44 -0
  26. prompture/drivers/async_ollama_driver.py +125 -0
  27. prompture/drivers/async_openai_driver.py +96 -0
  28. prompture/drivers/async_openrouter_driver.py +96 -0
  29. prompture/drivers/async_registry.py +129 -0
  30. prompture/drivers/azure_driver.py +36 -9
  31. prompture/drivers/claude_driver.py +86 -34
  32. prompture/drivers/google_driver.py +87 -51
  33. prompture/drivers/grok_driver.py +29 -32
  34. prompture/drivers/groq_driver.py +27 -26
  35. prompture/drivers/hugging_driver.py +6 -6
  36. prompture/drivers/lmstudio_driver.py +26 -13
  37. prompture/drivers/local_http_driver.py +6 -6
  38. prompture/drivers/ollama_driver.py +90 -23
  39. prompture/drivers/openai_driver.py +36 -9
  40. prompture/drivers/openrouter_driver.py +31 -25
  41. prompture/drivers/registry.py +306 -0
  42. prompture/field_definitions.py +106 -96
  43. prompture/logging.py +80 -0
  44. prompture/model_rates.py +217 -0
  45. prompture/runner.py +49 -47
  46. prompture/session.py +117 -0
  47. prompture/settings.py +14 -1
  48. prompture/tools.py +172 -265
  49. prompture/validator.py +3 -3
  50. {prompture-0.0.33.dev1.dist-info → prompture-0.0.34.dist-info}/METADATA +18 -20
  51. prompture-0.0.34.dist-info/RECORD +55 -0
  52. prompture-0.0.33.dev1.dist-info/RECORD +0 -29
  53. {prompture-0.0.33.dev1.dist-info → prompture-0.0.34.dist-info}/WHEEL +0 -0
  54. {prompture-0.0.33.dev1.dist-info → prompture-0.0.34.dist-info}/entry_points.txt +0 -0
  55. {prompture-0.0.33.dev1.dist-info → prompture-0.0.34.dist-info}/licenses/LICENSE +0 -0
  56. {prompture-0.0.33.dev1.dist-info → prompture-0.0.34.dist-info}/top_level.txt +0 -0
prompture/model_rates.py ADDED
@@ -0,0 +1,217 @@
+"""Live model rates from models.dev API with local caching.
+
+Fetches pricing and metadata for LLM models from https://models.dev/api.json,
+caches locally with TTL-based auto-refresh, and provides lookup functions
+used by drivers for cost calculations.
+"""
+
+import contextlib
+import json
+import logging
+import threading
+from datetime import datetime, timezone
+from pathlib import Path
+from typing import Any, Optional
+
+logger = logging.getLogger(__name__)
+
+# Maps prompture provider names to models.dev provider names
+PROVIDER_MAP: dict[str, str] = {
+    "openai": "openai",
+    "claude": "anthropic",
+    "google": "google",
+    "groq": "groq",
+    "grok": "xai",
+    "azure": "azure",
+    "openrouter": "openrouter",
+}
+
+_API_URL = "https://models.dev/api.json"
+_CACHE_DIR = Path.home() / ".prompture" / "cache"
+_CACHE_FILE = _CACHE_DIR / "models_dev.json"
+_META_FILE = _CACHE_DIR / "models_dev_meta.json"
+
+_lock = threading.Lock()
+_data: Optional[dict[str, Any]] = None
+_loaded = False
+
+
+def _get_ttl_days() -> int:
+    """Get TTL from settings if available, otherwise default to 7."""
+    try:
+        from .settings import settings
+
+        return getattr(settings, "model_rates_ttl_days", 7)
+    except Exception:
+        return 7
+
+
+def _cache_is_valid() -> bool:
+    """Check whether the local cache exists and is within TTL."""
+    if not _CACHE_FILE.exists() or not _META_FILE.exists():
+        return False
+    try:
+        meta = json.loads(_META_FILE.read_text(encoding="utf-8"))
+        fetched_at = datetime.fromisoformat(meta["fetched_at"])
+        ttl_days = meta.get("ttl_days", _get_ttl_days())
+        age = datetime.now(timezone.utc) - fetched_at
+        return age.total_seconds() < ttl_days * 86400
+    except Exception:
+        return False
+
+
+def _write_cache(data: dict[str, Any]) -> None:
+    """Write API data and metadata to local cache."""
+    try:
+        _CACHE_DIR.mkdir(parents=True, exist_ok=True)
+        _CACHE_FILE.write_text(json.dumps(data), encoding="utf-8")
+        meta = {
+            "fetched_at": datetime.now(timezone.utc).isoformat(),
+            "ttl_days": _get_ttl_days(),
+        }
+        _META_FILE.write_text(json.dumps(meta), encoding="utf-8")
+    except Exception as exc:
+        logger.debug("Failed to write model rates cache: %s", exc)
+
+
+def _read_cache() -> Optional[dict[str, Any]]:
+    """Read cached API data from disk."""
+    try:
+        return json.loads(_CACHE_FILE.read_text(encoding="utf-8"))
+    except Exception:
+        return None
+
+
+def _fetch_from_api() -> Optional[dict[str, Any]]:
+    """Fetch fresh data from models.dev API."""
+    try:
+        import requests
+
+        resp = requests.get(_API_URL, timeout=15)
+        resp.raise_for_status()
+        return resp.json()
+    except Exception as exc:
+        logger.debug("Failed to fetch model rates from %s: %s", _API_URL, exc)
+        return None
+
+
+def _ensure_loaded() -> Optional[dict[str, Any]]:
+    """Lazy-load data: use cache if valid, otherwise fetch from API."""
+    global _data, _loaded
+    if _loaded:
+        return _data
+
+    with _lock:
+        # Double-check after acquiring lock
+        if _loaded:
+            return _data
+
+        if _cache_is_valid():
+            _data = _read_cache()
+            if _data is not None:
+                _loaded = True
+                return _data
+
+        # Cache missing or expired — fetch fresh
+        fresh = _fetch_from_api()
+        if fresh is not None:
+            _data = fresh
+            _write_cache(fresh)
+        else:
+            # Fetch failed — try stale cache as last resort
+            _data = _read_cache()
+
+        _loaded = True
+        return _data
+
+
+def _lookup_model(provider: str, model_id: str) -> Optional[dict[str, Any]]:
+    """Find a model entry in the cached data.
+
+    The API structure is ``{provider: {model_id: {...}, ...}, ...}``.
+    """
+    data = _ensure_loaded()
+    if data is None:
+        return None
+
+    api_provider = PROVIDER_MAP.get(provider, provider)
+    provider_data = data.get(api_provider)
+    if not isinstance(provider_data, dict):
+        return None
+
+    return provider_data.get(model_id)
+
+
+# ── Public API ──────────────────────────────────────────────────────────────
+
+
+def get_model_rates(provider: str, model_id: str) -> Optional[dict[str, float]]:
+    """Return pricing dict for a model, or ``None`` if unavailable.
+
+    Returned keys mirror models.dev cost fields (per 1M tokens):
+    ``input``, ``output``, and optionally ``cache_read``, ``cache_write``,
+    ``reasoning``.
+    """
+    entry = _lookup_model(provider, model_id)
+    if entry is None:
+        return None
+
+    cost = entry.get("cost")
+    if not isinstance(cost, dict):
+        return None
+
+    rates: dict[str, float] = {}
+    for key in ("input", "output", "cache_read", "cache_write", "reasoning"):
+        val = cost.get(key)
+        if val is not None:
+            with contextlib.suppress(TypeError, ValueError):
+                rates[key] = float(val)
+
+    # Must have at least input and output to be useful
+    if "input" in rates and "output" in rates:
+        return rates
+    return None
+
+
+def get_model_info(provider: str, model_id: str) -> Optional[dict[str, Any]]:
+    """Return full model metadata (cost, limits, capabilities), or ``None``."""
+    return _lookup_model(provider, model_id)
+
+
+def get_all_provider_models(provider: str) -> list[str]:
+    """Return list of model IDs available for a provider."""
+    data = _ensure_loaded()
+    if data is None:
+        return []
+
+    api_provider = PROVIDER_MAP.get(provider, provider)
+    provider_data = data.get(api_provider)
+    if not isinstance(provider_data, dict):
+        return []
+
+    return list(provider_data.keys())
+
+
+def refresh_rates_cache(force: bool = False) -> bool:
+    """Fetch fresh data from models.dev.

+    Args:
+        force: If ``True``, fetch even when the cache is still within TTL.
+
+    Returns:
+        ``True`` if fresh data was fetched and cached successfully.
+    """
+    global _data, _loaded
+
+    with _lock:
+        if not force and _cache_is_valid():
+            return False
+
+        fresh = _fetch_from_api()
+        if fresh is not None:
+            _data = fresh
+            _write_cache(fresh)
+            _loaded = True
+            return True
+
+        return False
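
Reviewer note: a quick sketch of how the new public helpers compose. The provider/model pair below is illustrative, and per-1M-token pricing follows the get_model_rates docstring above:

    from prompture.model_rates import get_model_rates, refresh_rates_cache

    # Illustrative provider/model pair; any entry known to models.dev works.
    rates = get_model_rates("claude", "claude-sonnet-4-20250514")
    if rates is not None:
        # Prices are quoted per 1M tokens, so scale real usage down.
        cost = (12_000 * rates["input"] + 800 * rates["output"]) / 1_000_000
        print(f"estimated call cost: ${cost:.6f}")

    # Force-refresh the on-disk cache regardless of TTL.
    refresh_rates_cache(force=True)
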
prompture/runner.py CHANGED
@@ -1,13 +1,15 @@
 """Test suite runner for executing JSON validation tests across multiple models."""
-from typing import Dict, Any, List
 
-from .core import ask_for_json, Driver
+from typing import Any
+
 from prompture.validator import validate_against_schema
 
+from .core import Driver, ask_for_json
+
 
-def run_suite_from_spec(spec: Dict[str, Any], drivers: Dict[str, Driver]) -> Dict[str, Any]:
+def run_suite_from_spec(spec: dict[str, Any], drivers: dict[str, Driver]) -> dict[str, Any]:
     """Run a test suite specified by a spec dictionary across multiple models.
-
+
     Args:
         spec: A dictionary containing the test suite specification with the structure:
             {
@@ -21,7 +23,7 @@ def run_suite_from_spec(spec: Dict[str, Any], drivers: Dict[str, Driver]) -> Dict[str, Any]:
                 }, ...]
             }
         drivers: A dictionary mapping driver names to driver instances
-
+
     Returns:
         A dictionary containing test results with the structure:
         {
@@ -42,67 +44,67 @@ def run_suite_from_spec(spec: Dict[str, Any], drivers: Dict[str, Driver]) -> Dict[str, Any]:
         }
     """
     results = []
-
+
     for test in spec["tests"]:
         for model in spec["models"]:
             driver = drivers.get(model["driver"])
             if not driver:
                 continue
-
+
             # Run test for each input
             for input_data in test["inputs"]:
                 # Format prompt template with input data
                 try:
                     prompt = test["prompt_template"].format(**input_data)
                 except KeyError as e:
-                    results.append({
-                        "test_id": test["id"],
-                        "model_id": model["id"],
-                        "input": input_data,
-                        "prompt": test["prompt_template"],
-                        "error": f"Template formatting error: missing key {e}",
-                        "validation": {"ok": False, "error": "Prompt formatting failed", "data": None},
-                        "usage": {"total_tokens": 0, "cost": 0}
-                    })
+                    results.append(
+                        {
+                            "test_id": test["id"],
+                            "model_id": model["id"],
+                            "input": input_data,
+                            "prompt": test["prompt_template"],
+                            "error": f"Template formatting error: missing key {e}",
+                            "validation": {"ok": False, "error": "Prompt formatting failed", "data": None},
+                            "usage": {"total_tokens": 0, "cost": 0},
+                        }
+                    )
                     continue
-
+
                 # Get JSON response from model
                 try:
                     response = ask_for_json(
                         driver=driver,
                         content_prompt=prompt,
                         json_schema=test["schema"],
-                        options=model.get("options", {})
+                        options=model.get("options", {}),
                     )
-
+
                     # Validate response against schema
-                    validation = validate_against_schema(
-                        response["json_string"],
-                        test["schema"]
+                    validation = validate_against_schema(response["json_string"], test["schema"])
+
+                    results.append(
+                        {
+                            "test_id": test["id"],
+                            "model_id": model["id"],
+                            "input": input_data,
+                            "prompt": prompt,
+                            "response": response["json_object"],
+                            "validation": validation,
+                            "usage": response["usage"],
+                        }
                     )
-
-                    results.append({
-                        "test_id": test["id"],
-                        "model_id": model["id"],
-                        "input": input_data,
-                        "prompt": prompt,
-                        "response": response["json_object"],
-                        "validation": validation,
-                        "usage": response["usage"]
-                    })
-
+
                 except Exception as e:
-                    results.append({
-                        "test_id": test["id"],
-                        "model_id": model["id"],
-                        "input": input_data,
-                        "prompt": prompt,
-                        "error": str(e),
-                        "validation": {"ok": False, "error": "Model response error", "data": None},
-                        "usage": {"total_tokens": 0, "cost": 0}
-                    })
-
-    return {
-        "meta": spec.get("meta", {}),
-        "results": results
-    }
+                    results.append(
+                        {
+                            "test_id": test["id"],
+                            "model_id": model["id"],
+                            "input": input_data,
+                            "prompt": prompt,
+                            "error": str(e),
+                            "validation": {"ok": False, "error": "Model response error", "data": None},
+                            "usage": {"total_tokens": 0, "cost": 0},
+                        }
+                    )
+
+    return {"meta": spec.get("meta", {}), "results": results}
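
Reviewer note: for orientation, a minimal spec that satisfies the keys this function reads (``tests``, ``models``, ``meta``, and the per-entry fields) might look like the sketch below. The driver name and ``make_driver`` factory are hypothetical stand-ins for however driver instances are actually constructed:

    from prompture.runner import run_suite_from_spec

    spec = {
        "meta": {"name": "smoke-suite"},
        "models": [{"driver": "ollama", "id": "llama3", "options": {"temperature": 0}}],
        "tests": [
            {
                "id": "extract-person",
                "prompt_template": "Extract name and age from: {text}",
                "inputs": [{"text": "Alice is 30 years old."}],
                "schema": {
                    "type": "object",
                    "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
                    "required": ["name", "age"],
                },
            }
        ],
    }

    # make_driver() is a hypothetical factory; any name -> Driver mapping works.
    report = run_suite_from_spec(spec, drivers={"ollama": make_driver("ollama")})
    print(len(report["results"]), "result rows")
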
prompture/session.py ADDED
@@ -0,0 +1,117 @@
+"""Usage session tracking for Prompture.
+
+Provides :class:`UsageSession` which accumulates token counts, costs, and
+errors across multiple driver calls. A session instance is compatible as
+both an ``on_response`` and ``on_error`` callback, so you can wire it
+directly into :class:`~prompture.callbacks.DriverCallbacks`.
+
+Usage::
+
+    from prompture import UsageSession, DriverCallbacks
+
+    session = UsageSession()
+    callbacks = DriverCallbacks(
+        on_response=session.record,
+        on_error=session.record_error,
+    )
+
+    # ... pass *callbacks* to your driver / conversation ...
+
+    print(session.summary()["formatted"])
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from typing import Any
+
+
+@dataclass
+class UsageSession:
+    """Accumulates usage statistics across multiple driver calls."""
+
+    prompt_tokens: int = 0
+    completion_tokens: int = 0
+    total_tokens: int = 0
+    total_cost: float = 0.0
+    call_count: int = 0
+    errors: int = 0
+    _per_model: dict[str, dict[str, Any]] = field(default_factory=dict, repr=False)
+
+    # ------------------------------------------------------------------ #
+    # Recording
+    # ------------------------------------------------------------------ #
+
+    def record(self, response_info: dict[str, Any]) -> None:
+        """Record a successful driver response.
+
+        Compatible as an ``on_response`` callback for
+        :class:`~prompture.callbacks.DriverCallbacks`.
+
+        Args:
+            response_info: Payload dict with at least ``meta`` and
+                optionally ``driver`` keys.
+        """
+        meta = response_info.get("meta", {})
+        pt = meta.get("prompt_tokens", 0)
+        ct = meta.get("completion_tokens", 0)
+        tt = meta.get("total_tokens", 0)
+        cost = meta.get("cost", 0.0)
+
+        self.prompt_tokens += pt
+        self.completion_tokens += ct
+        self.total_tokens += tt
+        self.total_cost += cost
+        self.call_count += 1
+
+        model = response_info.get("driver", "unknown")
+        bucket = self._per_model.setdefault(
+            model,
+            {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0, "cost": 0.0, "calls": 0},
+        )
+        bucket["prompt_tokens"] += pt
+        bucket["completion_tokens"] += ct
+        bucket["total_tokens"] += tt
+        bucket["cost"] += cost
+        bucket["calls"] += 1
+
+    def record_error(self, error_info: dict[str, Any]) -> None:
+        """Record a driver error.
+
+        Compatible as an ``on_error`` callback for
+        :class:`~prompture.callbacks.DriverCallbacks`.
+        """
+        self.errors += 1
+
+    # ------------------------------------------------------------------ #
+    # Reporting
+    # ------------------------------------------------------------------ #
+
+    def summary(self) -> dict[str, Any]:
+        """Return a machine-readable summary with a ``formatted`` string."""
+        formatted = (
+            f"Session: {self.total_tokens:,} tokens across {self.call_count} call(s) costing ${self.total_cost:.4f}"
+        )
+        if self.errors:
+            formatted += f" ({self.errors} error(s))"
+
+        return {
+            "prompt_tokens": self.prompt_tokens,
+            "completion_tokens": self.completion_tokens,
+            "total_tokens": self.total_tokens,
+            "total_cost": self.total_cost,
+            "call_count": self.call_count,
+            "errors": self.errors,
+            "per_model": dict(self._per_model),
+            "formatted": formatted,
+        }
+
+    def reset(self) -> None:
+        """Clear all accumulated counters."""
+        self.prompt_tokens = 0
+        self.completion_tokens = 0
+        self.total_tokens = 0
+        self.total_cost = 0.0
+        self.call_count = 0
+        self.errors = 0
+        self._per_model.clear()
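
Reviewer note: because ``record`` only reads the ``meta`` and ``driver`` keys of its payload, the session can be exercised directly, without a live driver. A small sketch under that assumption, with made-up token counts:

    from prompture import UsageSession

    session = UsageSession()
    session.record({"driver": "openai", "meta": {"prompt_tokens": 120, "completion_tokens": 30, "total_tokens": 150, "cost": 0.0004}})
    session.record({"driver": "claude", "meta": {"prompt_tokens": 90, "completion_tokens": 60, "total_tokens": 150, "cost": 0.0011}})
    session.record_error({"error": "timeout"})

    print(session.summary()["formatted"])
    # Session: 300 tokens across 2 call(s) costing $0.0015 (1 error(s))
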
prompture/settings.py CHANGED
@@ -1,6 +1,8 @@
-from pydantic_settings import BaseSettings, SettingsConfigDict
 from typing import Optional
 
+from pydantic_settings import BaseSettings, SettingsConfigDict
+
+
 class Settings(BaseSettings):
     """Application settings loaded from environment variables or .env file."""
 
@@ -52,6 +54,17 @@ class Settings(BaseSettings):
     airllm_model: str = "meta-llama/Llama-2-7b-hf"
     airllm_compression: Optional[str] = None  # "4bit" or "8bit"
 
+    # Model rates cache
+    model_rates_ttl_days: int = 7  # How often to refresh models.dev cache
+
+    # Response cache
+    cache_enabled: bool = False
+    cache_backend: str = "memory"
+    cache_ttl_seconds: int = 3600
+    cache_memory_maxsize: int = 256
+    cache_sqlite_path: Optional[str] = None
+    cache_redis_url: Optional[str] = None
+
     model_config = SettingsConfigDict(
         env_file=".env",
         extra="ignore",
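
Reviewer note: since ``Settings`` is a pydantic-settings model, the new fields should be configurable via ``.env`` or environment variables under pydantic-settings' default field-name matching. A hypothetical sketch, with illustrative values and assuming no ``env_prefix`` is configured elsewhere in the class:

    import os

    # Hypothetical values; pydantic-settings matches env vars to field
    # names case-insensitively by default.
    os.environ["CACHE_ENABLED"] = "true"
    os.environ["CACHE_BACKEND"] = "sqlite"
    os.environ["CACHE_TTL_SECONDS"] = "86400"
    os.environ["MODEL_RATES_TTL_DAYS"] = "14"

    from prompture.settings import Settings

    settings = Settings()
    assert settings.cache_enabled is True
    assert settings.cache_backend == "sqlite"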