prompture 0.0.32.dev1__py3-none-any.whl → 0.0.33__py3-none-any.whl
This diff shows the content of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in their public registries.
- prompture/__init__.py +7 -1
- prompture/discovery.py +11 -1
- prompture/drivers/__init__.py +6 -0
- prompture/drivers/airllm_driver.py +116 -0
- prompture/drivers/azure_driver.py +10 -4
- prompture/drivers/claude_driver.py +10 -4
- prompture/drivers/google_driver.py +14 -6
- prompture/drivers/grok_driver.py +10 -4
- prompture/drivers/groq_driver.py +10 -4
- prompture/drivers/openai_driver.py +10 -4
- prompture/drivers/openrouter_driver.py +10 -4
- prompture/model_rates.py +216 -0
- prompture/settings.py +7 -0
- {prompture-0.0.32.dev1.dist-info → prompture-0.0.33.dist-info}/METADATA +3 -1
- prompture-0.0.33.dist-info/RECORD +30 -0
- {prompture-0.0.32.dev1.dist-info → prompture-0.0.33.dist-info}/WHEEL +1 -1
- prompture-0.0.32.dev1.dist-info/RECORD +0 -28
- {prompture-0.0.32.dev1.dist-info → prompture-0.0.33.dist-info}/entry_points.txt +0 -0
- {prompture-0.0.32.dev1.dist-info → prompture-0.0.33.dist-info}/licenses/LICENSE +0 -0
- {prompture-0.0.32.dev1.dist-info → prompture-0.0.33.dist-info}/top_level.txt +0 -0
prompture/__init__.py
CHANGED

@@ -13,7 +13,7 @@ from .core import (
     extract_from_pandas,
     render_output,
 )
-from .drivers import get_driver, get_driver_for_model, OpenAIDriver, LocalHTTPDriver, OllamaDriver, ClaudeDriver, LMStudioDriver, AzureDriver, GoogleDriver, GroqDriver, OpenRouterDriver, GrokDriver
+from .drivers import get_driver, get_driver_for_model, OpenAIDriver, LocalHTTPDriver, OllamaDriver, ClaudeDriver, LMStudioDriver, AzureDriver, GoogleDriver, GroqDriver, OpenRouterDriver, GrokDriver, AirLLMDriver
 from .tools import clean_json_text, clean_toon_text
 from .field_definitions import (
     FIELD_DEFINITIONS, get_field_definition, get_required_fields, get_field_names,

@@ -24,6 +24,7 @@ from .field_definitions import (
 from .runner import run_suite_from_spec
 from .validator import validate_against_schema
 from .discovery import get_available_models
+from .model_rates import get_model_rates, get_model_info, refresh_rates_cache
 
 # Load environment variables from .env file
 load_dotenv()

@@ -87,6 +88,11 @@ __all__ = [
     "GroqDriver",
     "OpenRouterDriver",
     "GrokDriver",
+    "AirLLMDriver",
     # Discovery
     "get_available_models",
+    # Model Rates
+    "get_model_rates",
+    "get_model_info",
+    "refresh_rates_cache",
 ]
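With this release the rates helpers are re-exported at the package root. A minimal usage sketch (the model ID and returned numbers are illustrative; availability depends on the models.dev data):

    from prompture import get_model_rates, get_model_info, refresh_rates_cache

    # Live pricing in USD per 1M tokens, or None when the model is unknown.
    rates = get_model_rates("openai", "gpt-4o")   # e.g. {"input": 2.5, "output": 10.0}
    info = get_model_info("openai", "gpt-4o")     # full models.dev entry: cost, limits, capabilities
    refresh_rates_cache(force=True)               # re-fetch models.dev data, ignoring the TTL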
prompture/discovery.py
CHANGED

@@ -33,7 +33,8 @@ def get_available_models() -> List[str]:
         A list of unique model strings in the format "provider/model_id".
     """
     available_models: Set[str] = set()
-
+    configured_providers: Set[str] = set()
+
     # Map of provider name to driver class
     # We need to map the registry keys to the actual classes to check MODEL_PRICING
     # and instantiate for dynamic checks if needed.

@@ -96,6 +97,8 @@ def get_available_models() -> List[str]:
         if not is_configured:
             continue
 
+        configured_providers.add(provider)
+
         # 2. Static Detection: Get models from MODEL_PRICING
         if hasattr(driver_cls, "MODEL_PRICING"):
             pricing = driver_cls.MODEL_PRICING

@@ -146,4 +149,11 @@ def get_available_models() -> List[str]:
             logger.warning(f"Error detecting models for provider {provider}: {e}")
             continue
 
+    # Enrich with live model list from models.dev cache
+    from .model_rates import get_all_provider_models, PROVIDER_MAP
+    for prompture_name, api_name in PROVIDER_MAP.items():
+        if prompture_name in configured_providers:
+            for model_id in get_all_provider_models(api_name):
+                available_models.add(f"{prompture_name}/{model_id}")
+
     return sorted(list(available_models))
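The discovery change unions the static MODEL_PRICING tables with the live models.dev catalog, but only for providers whose credentials are configured. A sketch of the effect, assuming only an OpenAI key is set (the exact IDs depend on the cached models.dev data):

    from prompture import get_available_models

    models = get_available_models()
    # Sorted "provider/model_id" strings: static MODEL_PRICING entries plus
    # everything models.dev lists for the configured provider, e.g.
    # ["openai/gpt-4o", "openai/gpt-4o-mini", ...]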
prompture/drivers/__init__.py
CHANGED

@@ -8,6 +8,7 @@ from .google_driver import GoogleDriver
 from .groq_driver import GroqDriver
 from .openrouter_driver import OpenRouterDriver
 from .grok_driver import GrokDriver
+from .airllm_driver import AirLLMDriver
 from ..settings import settings
 
 

@@ -54,6 +55,10 @@ DRIVER_REGISTRY = {
         api_key=settings.grok_api_key,
         model=model or settings.grok_model
     ),
+    "airllm": lambda model=None: AirLLMDriver(
+        model=model or settings.airllm_model,
+        compression=settings.airllm_compression,
+    ),
 }
 
 

@@ -115,6 +120,7 @@ __all__ = [
     "GroqDriver",
     "OpenRouterDriver",
     "GrokDriver",
+    "AirLLMDriver",
    "get_driver",
    "get_driver_for_model",
 ]
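The new registry entry makes the driver resolvable by name; a sketch, assuming get_driver dispatches through DRIVER_REGISTRY the same way it does for the existing providers:

    from prompture import get_driver

    # model and compression default to settings.airllm_model and
    # settings.airllm_compression when not passed explicitly.
    driver = get_driver("airllm")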
prompture/drivers/airllm_driver.py
ADDED

@@ -0,0 +1,116 @@
+import logging
+from ..driver import Driver
+from typing import Any, Dict, Optional
+
+logger = logging.getLogger(__name__)
+
+
+class AirLLMDriver(Driver):
+    """Driver for AirLLM — run large models (70B+) on consumer GPUs via
+    layer-by-layer memory management.
+
+    The ``airllm`` package is a lazy dependency: it is imported on first
+    ``generate()`` call so the rest of Prompture works without it installed.
+    """
+
+    MODEL_PRICING = {
+        "default": {"prompt": 0.0, "completion": 0.0}
+    }
+
+    def __init__(self, model: str = "meta-llama/Llama-2-7b-hf",
+                 compression: Optional[str] = None):
+        """
+        Args:
+            model: HuggingFace repo ID (e.g. ``"meta-llama/Llama-2-70b-hf"``).
+            compression: Optional quantization mode — ``"4bit"`` or ``"8bit"``.
+        """
+        self.model = model
+        self.compression = compression
+        self.options: Dict[str, Any] = {}
+        self._llm = None
+        self._tokenizer = None
+
+    # ------------------------------------------------------------------
+    # Lazy model loading
+    # ------------------------------------------------------------------
+    def _ensure_loaded(self):
+        """Load the AirLLM model and tokenizer on first use."""
+        if self._llm is not None:
+            return
+
+        try:
+            from airllm import AutoModel
+        except ImportError:
+            raise ImportError(
+                "The 'airllm' package is required for the AirLLM driver. "
+                "Install it with: pip install prompture[airllm]"
+            )
+
+        try:
+            from transformers import AutoTokenizer
+        except ImportError:
+            raise ImportError(
+                "The 'transformers' package is required for the AirLLM driver. "
+                "Install it with: pip install transformers"
+            )
+
+        logger.info(f"Loading AirLLM model: {self.model} "
+                    f"(compression={self.compression})")
+
+        load_kwargs: Dict[str, Any] = {}
+        if self.compression:
+            load_kwargs["compression"] = self.compression
+
+        self._llm = AutoModel.from_pretrained(self.model, **load_kwargs)
+        self._tokenizer = AutoTokenizer.from_pretrained(self.model)
+        logger.info("AirLLM model loaded successfully")
+
+    # ------------------------------------------------------------------
+    # Driver interface
+    # ------------------------------------------------------------------
+    def generate(self, prompt: str, options: Dict[str, Any] = None) -> Dict[str, Any]:
+        self._ensure_loaded()
+
+        merged_options = self.options.copy()
+        if options:
+            merged_options.update(options)
+
+        max_new_tokens = merged_options.get("max_new_tokens", 256)
+
+        # Tokenize
+        input_ids = self._tokenizer(
+            prompt, return_tensors="pt"
+        ).input_ids
+
+        prompt_tokens = input_ids.shape[1]
+
+        logger.debug(f"AirLLM generating with max_new_tokens={max_new_tokens}, "
+                     f"prompt_tokens={prompt_tokens}")
+
+        # Generate
+        output_ids = self._llm.generate(
+            input_ids,
+            max_new_tokens=max_new_tokens,
+        )
+
+        # Decode only the newly generated tokens (strip the prompt prefix)
+        new_tokens = output_ids[0, prompt_tokens:]
+        completion_tokens = len(new_tokens)
+        text = self._tokenizer.decode(new_tokens, skip_special_tokens=True)
+
+        total_tokens = prompt_tokens + completion_tokens
+
+        meta = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": total_tokens,
+            "cost": 0.0,
+            "raw_response": {
+                "model": self.model,
+                "compression": self.compression,
+                "max_new_tokens": max_new_tokens,
+            },
+            "model_name": self.model,
+        }
+
+        return {"text": text, "meta": meta}
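A usage sketch for the new driver (the 70B repo ID is illustrative; the first generate() call downloads the weights and loads them layer by layer, so expect it to be slow):

    from prompture.drivers import AirLLMDriver

    driver = AirLLMDriver(model="meta-llama/Llama-2-70b-hf", compression="4bit")
    result = driver.generate("Summarize AirLLM in one sentence.", {"max_new_tokens": 64})
    print(result["text"])
    print(result["meta"]["total_tokens"])  # cost is always 0.0 for local inference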
prompture/drivers/azure_driver.py
CHANGED

@@ -111,10 +111,16 @@ class AzureDriver(Driver):
         completion_tokens = getattr(usage, "completion_tokens", 0)
         total_tokens = getattr(usage, "total_tokens", 0)
 
-        # Calculate cost
-        model_pricing = self.MODEL_PRICING.get(model, {"prompt": 0, "completion": 0})
-        prompt_cost = (prompt_tokens / 1000) * model_pricing["prompt"]
-        completion_cost = (completion_tokens / 1000) * model_pricing["completion"]
+        # Calculate cost — try live rates first (per 1M tokens), fall back to hardcoded (per 1K tokens)
+        from ..model_rates import get_model_rates
+        live_rates = get_model_rates("azure", model)
+        if live_rates:
+            prompt_cost = (prompt_tokens / 1_000_000) * live_rates["input"]
+            completion_cost = (completion_tokens / 1_000_000) * live_rates["output"]
+        else:
+            model_pricing = self.MODEL_PRICING.get(model, {"prompt": 0, "completion": 0})
+            prompt_cost = (prompt_tokens / 1000) * model_pricing["prompt"]
+            completion_cost = (completion_tokens / 1000) * model_pricing["completion"]
         total_cost = prompt_cost + completion_cost
 
         # Standardized meta object
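The unit difference is the point of the two branches: live rates are quoted in USD per 1M tokens, while the legacy MODEL_PRICING tables are per 1K tokens. A worked check with illustrative numbers shows the two conventions agree when they describe the same price:

    prompt_tokens, completion_tokens = 1_000, 500

    # Live path: $2.50 / 1M input tokens, $10.00 / 1M output tokens
    cost_live = (prompt_tokens / 1_000_000) * 2.50 + (completion_tokens / 1_000_000) * 10.00
    # = 0.0025 + 0.0050 = $0.0075

    # Fallback path: the same price expressed per 1K tokens
    cost_hard = (prompt_tokens / 1000) * 0.0025 + (completion_tokens / 1000) * 0.0100
    # = 0.0025 + 0.0050 = $0.0075

The same live-first/fallback pattern is applied to the Claude, Google, Grok, Groq, OpenAI, and OpenRouter drivers below.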
prompture/drivers/claude_driver.py
CHANGED

@@ -64,10 +64,16 @@ class ClaudeDriver(Driver):
         completion_tokens = resp.usage.output_tokens
         total_tokens = prompt_tokens + completion_tokens
 
-        # Calculate cost
-        model_pricing = self.MODEL_PRICING.get(model, {"prompt": 0, "completion": 0})
-        prompt_cost = (prompt_tokens / 1000) * model_pricing["prompt"]
-        completion_cost = (completion_tokens / 1000) * model_pricing["completion"]
+        # Calculate cost — try live rates first (per 1M tokens), fall back to hardcoded (per 1K tokens)
+        from ..model_rates import get_model_rates
+        live_rates = get_model_rates("claude", model)
+        if live_rates:
+            prompt_cost = (prompt_tokens / 1_000_000) * live_rates["input"]
+            completion_cost = (completion_tokens / 1_000_000) * live_rates["output"]
+        else:
+            model_pricing = self.MODEL_PRICING.get(model, {"prompt": 0, "completion": 0})
+            prompt_cost = (prompt_tokens / 1000) * model_pricing["prompt"]
+            completion_cost = (completion_tokens / 1000) * model_pricing["completion"]
         total_cost = prompt_cost + completion_cost
 
         # Create standardized meta object
prompture/drivers/google_driver.py
CHANGED

@@ -134,14 +134,22 @@ class GoogleDriver(Driver):
             raise ValueError("Empty response from model")
 
         # Calculate token usage and cost
-        # Note: Using character count as proxy since Google charges per character
         prompt_chars = len(prompt)
         completion_chars = len(response.text)
-
-        #
-        model_pricing = self.MODEL_PRICING.get(self.model, {"prompt": 0, "completion": 0})
-        prompt_cost = (prompt_chars / 1_000_000) * model_pricing["prompt"]
-        completion_cost = (completion_chars / 1_000_000) * model_pricing["completion"]
+
+        # Try live rates first (per 1M tokens), fall back to hardcoded character-based pricing
+        from ..model_rates import get_model_rates
+        live_rates = get_model_rates("google", self.model)
+        if live_rates:
+            # models.dev reports token-based pricing; estimate tokens from chars (~4 chars/token)
+            est_prompt_tokens = prompt_chars / 4
+            est_completion_tokens = completion_chars / 4
+            prompt_cost = (est_prompt_tokens / 1_000_000) * live_rates["input"]
+            completion_cost = (est_completion_tokens / 1_000_000) * live_rates["output"]
+        else:
+            model_pricing = self.MODEL_PRICING.get(self.model, {"prompt": 0, "completion": 0})
+            prompt_cost = (prompt_chars / 1_000_000) * model_pricing["prompt"]
+            completion_cost = (completion_chars / 1_000_000) * model_pricing["completion"]
         total_cost = prompt_cost + completion_cost
 
         meta = {
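The Google path differs from the other drivers because its legacy table is character-based, so the live branch first converts characters to an estimated token count. A quick check of the ~4 chars/token heuristic with illustrative numbers:

    prompt_chars = 2_000                    # ~500 tokens under the heuristic
    est_prompt_tokens = prompt_chars / 4    # 500.0
    # at an illustrative live input rate of $1.25 per 1M tokens:
    prompt_cost = (est_prompt_tokens / 1_000_000) * 1.25   # $0.000625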
prompture/drivers/grok_driver.py
CHANGED

@@ -133,10 +133,16 @@ class GrokDriver(Driver):
         completion_tokens = usage.get("completion_tokens", 0)
         total_tokens = usage.get("total_tokens", 0)
 
-        # Calculate cost
-        model_pricing = self.MODEL_PRICING.get(model, {"prompt": 0, "completion": 0})
-        prompt_cost = (prompt_tokens / 1_000_000) * model_pricing["prompt"]
-        completion_cost = (completion_tokens / 1_000_000) * model_pricing["completion"]
+        # Calculate cost — try live rates first (per 1M tokens), fall back to hardcoded (per 1M tokens)
+        from ..model_rates import get_model_rates
+        live_rates = get_model_rates("grok", model)
+        if live_rates:
+            prompt_cost = (prompt_tokens / 1_000_000) * live_rates["input"]
+            completion_cost = (completion_tokens / 1_000_000) * live_rates["output"]
+        else:
+            model_pricing = self.MODEL_PRICING.get(model, {"prompt": 0, "completion": 0})
+            prompt_cost = (prompt_tokens / 1_000_000) * model_pricing["prompt"]
+            completion_cost = (completion_tokens / 1_000_000) * model_pricing["completion"]
         total_cost = prompt_cost + completion_cost
 
         # Standardized meta object
prompture/drivers/groq_driver.py
CHANGED

@@ -96,10 +96,16 @@ class GroqDriver(Driver):
         completion_tokens = getattr(usage, "completion_tokens", 0)
         total_tokens = getattr(usage, "total_tokens", 0)
 
-        # Calculate costs
-        model_pricing = self.MODEL_PRICING.get(model, {"prompt": 0, "completion": 0})
-        prompt_cost = (prompt_tokens / 1000) * model_pricing["prompt"]
-        completion_cost = (completion_tokens / 1000) * model_pricing["completion"]
+        # Calculate costs — try live rates first (per 1M tokens), fall back to hardcoded (per 1K tokens)
+        from ..model_rates import get_model_rates
+        live_rates = get_model_rates("groq", model)
+        if live_rates:
+            prompt_cost = (prompt_tokens / 1_000_000) * live_rates["input"]
+            completion_cost = (completion_tokens / 1_000_000) * live_rates["output"]
+        else:
+            model_pricing = self.MODEL_PRICING.get(model, {"prompt": 0, "completion": 0})
+            prompt_cost = (prompt_tokens / 1000) * model_pricing["prompt"]
+            completion_cost = (completion_tokens / 1000) * model_pricing["completion"]
         total_cost = prompt_cost + completion_cost
 
         # Standard metadata object
prompture/drivers/openai_driver.py
CHANGED

@@ -97,10 +97,16 @@ class OpenAIDriver(Driver):
         completion_tokens = getattr(usage, "completion_tokens", 0)
         total_tokens = getattr(usage, "total_tokens", 0)
 
-        # Calculate cost
-        model_pricing = self.MODEL_PRICING.get(model, {"prompt": 0, "completion": 0})
-        prompt_cost = (prompt_tokens / 1000) * model_pricing["prompt"]
-        completion_cost = (completion_tokens / 1000) * model_pricing["completion"]
+        # Calculate cost — try live rates first (per 1M tokens), fall back to hardcoded (per 1K tokens)
+        from ..model_rates import get_model_rates
+        live_rates = get_model_rates("openai", model)
+        if live_rates:
+            prompt_cost = (prompt_tokens / 1_000_000) * live_rates["input"]
+            completion_cost = (completion_tokens / 1_000_000) * live_rates["output"]
+        else:
+            model_pricing = self.MODEL_PRICING.get(model, {"prompt": 0, "completion": 0})
+            prompt_cost = (prompt_tokens / 1000) * model_pricing["prompt"]
+            completion_cost = (completion_tokens / 1000) * model_pricing["completion"]
         total_cost = prompt_cost + completion_cost
 
         # Standardized meta object
prompture/drivers/openrouter_driver.py
CHANGED

@@ -110,10 +110,16 @@ class OpenRouterDriver(Driver):
         completion_tokens = usage.get("completion_tokens", 0)
         total_tokens = usage.get("total_tokens", 0)
 
-        # Calculate cost
-        model_pricing = self.MODEL_PRICING.get(model, {"prompt": 0, "completion": 0})
-        prompt_cost = (prompt_tokens / 1000) * model_pricing["prompt"]
-        completion_cost = (completion_tokens / 1000) * model_pricing["completion"]
+        # Calculate cost — try live rates first (per 1M tokens), fall back to hardcoded (per 1K tokens)
+        from ..model_rates import get_model_rates
+        live_rates = get_model_rates("openrouter", model)
+        if live_rates:
+            prompt_cost = (prompt_tokens / 1_000_000) * live_rates["input"]
+            completion_cost = (completion_tokens / 1_000_000) * live_rates["output"]
+        else:
+            model_pricing = self.MODEL_PRICING.get(model, {"prompt": 0, "completion": 0})
+            prompt_cost = (prompt_tokens / 1000) * model_pricing["prompt"]
+            completion_cost = (completion_tokens / 1000) * model_pricing["completion"]
         total_cost = prompt_cost + completion_cost
 
         # Standardized meta object
prompture/model_rates.py
ADDED

@@ -0,0 +1,216 @@
+"""Live model rates from models.dev API with local caching.
+
+Fetches pricing and metadata for LLM models from https://models.dev/api.json,
+caches locally with TTL-based auto-refresh, and provides lookup functions
+used by drivers for cost calculations.
+"""
+
+import json
+import logging
+import threading
+from datetime import datetime, timezone
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+
+logger = logging.getLogger(__name__)
+
+# Maps prompture provider names to models.dev provider names
+PROVIDER_MAP: Dict[str, str] = {
+    "openai": "openai",
+    "claude": "anthropic",
+    "google": "google",
+    "groq": "groq",
+    "grok": "xai",
+    "azure": "azure",
+    "openrouter": "openrouter",
+}
+
+_API_URL = "https://models.dev/api.json"
+_CACHE_DIR = Path.home() / ".prompture" / "cache"
+_CACHE_FILE = _CACHE_DIR / "models_dev.json"
+_META_FILE = _CACHE_DIR / "models_dev_meta.json"
+
+_lock = threading.Lock()
+_data: Optional[Dict[str, Any]] = None
+_loaded = False
+
+
+def _get_ttl_days() -> int:
+    """Get TTL from settings if available, otherwise default to 7."""
+    try:
+        from .settings import settings
+        return getattr(settings, "model_rates_ttl_days", 7)
+    except Exception:
+        return 7
+
+
+def _cache_is_valid() -> bool:
+    """Check whether the local cache exists and is within TTL."""
+    if not _CACHE_FILE.exists() or not _META_FILE.exists():
+        return False
+    try:
+        meta = json.loads(_META_FILE.read_text(encoding="utf-8"))
+        fetched_at = datetime.fromisoformat(meta["fetched_at"])
+        ttl_days = meta.get("ttl_days", _get_ttl_days())
+        age = datetime.now(timezone.utc) - fetched_at
+        return age.total_seconds() < ttl_days * 86400
+    except Exception:
+        return False
+
+
+def _write_cache(data: Dict[str, Any]) -> None:
+    """Write API data and metadata to local cache."""
+    try:
+        _CACHE_DIR.mkdir(parents=True, exist_ok=True)
+        _CACHE_FILE.write_text(json.dumps(data), encoding="utf-8")
+        meta = {
+            "fetched_at": datetime.now(timezone.utc).isoformat(),
+            "ttl_days": _get_ttl_days(),
+        }
+        _META_FILE.write_text(json.dumps(meta), encoding="utf-8")
+    except Exception as exc:
+        logger.debug("Failed to write model rates cache: %s", exc)
+
+
+def _read_cache() -> Optional[Dict[str, Any]]:
+    """Read cached API data from disk."""
+    try:
+        return json.loads(_CACHE_FILE.read_text(encoding="utf-8"))
+    except Exception:
+        return None
+
+
+def _fetch_from_api() -> Optional[Dict[str, Any]]:
+    """Fetch fresh data from models.dev API."""
+    try:
+        import requests
+        resp = requests.get(_API_URL, timeout=15)
+        resp.raise_for_status()
+        return resp.json()
+    except Exception as exc:
+        logger.debug("Failed to fetch model rates from %s: %s", _API_URL, exc)
+        return None
+
+
+def _ensure_loaded() -> Optional[Dict[str, Any]]:
+    """Lazy-load data: use cache if valid, otherwise fetch from API."""
+    global _data, _loaded
+    if _loaded:
+        return _data
+
+    with _lock:
+        # Double-check after acquiring lock
+        if _loaded:
+            return _data
+
+        if _cache_is_valid():
+            _data = _read_cache()
+            if _data is not None:
+                _loaded = True
+                return _data
+
+        # Cache missing or expired — fetch fresh
+        fresh = _fetch_from_api()
+        if fresh is not None:
+            _data = fresh
+            _write_cache(fresh)
+        else:
+            # Fetch failed — try stale cache as last resort
+            _data = _read_cache()
+
+        _loaded = True
+        return _data
+
+
+def _lookup_model(provider: str, model_id: str) -> Optional[Dict[str, Any]]:
+    """Find a model entry in the cached data.
+
+    The API structure is ``{provider: {model_id: {...}, ...}, ...}``.
+    """
+    data = _ensure_loaded()
+    if data is None:
+        return None
+
+    api_provider = PROVIDER_MAP.get(provider, provider)
+    provider_data = data.get(api_provider)
+    if not isinstance(provider_data, dict):
+        return None
+
+    return provider_data.get(model_id)
+
+
+# ── Public API ──────────────────────────────────────────────────────────────
+
+
+def get_model_rates(provider: str, model_id: str) -> Optional[Dict[str, float]]:
+    """Return pricing dict for a model, or ``None`` if unavailable.
+
+    Returned keys mirror models.dev cost fields (per 1M tokens):
+    ``input``, ``output``, and optionally ``cache_read``, ``cache_write``,
+    ``reasoning``.
+    """
+    entry = _lookup_model(provider, model_id)
+    if entry is None:
+        return None
+
+    cost = entry.get("cost")
+    if not isinstance(cost, dict):
+        return None
+
+    rates: Dict[str, float] = {}
+    for key in ("input", "output", "cache_read", "cache_write", "reasoning"):
+        val = cost.get(key)
+        if val is not None:
+            try:
+                rates[key] = float(val)
+            except (TypeError, ValueError):
+                pass
+
+    # Must have at least input and output to be useful
+    if "input" in rates and "output" in rates:
+        return rates
+    return None
+
+
+def get_model_info(provider: str, model_id: str) -> Optional[Dict[str, Any]]:
+    """Return full model metadata (cost, limits, capabilities), or ``None``."""
+    return _lookup_model(provider, model_id)
+
+
+def get_all_provider_models(provider: str) -> List[str]:
+    """Return list of model IDs available for a provider."""
+    data = _ensure_loaded()
+    if data is None:
+        return []
+
+    api_provider = PROVIDER_MAP.get(provider, provider)
+    provider_data = data.get(api_provider)
+    if not isinstance(provider_data, dict):
+        return []
+
+    return list(provider_data.keys())
+
+
+def refresh_rates_cache(force: bool = False) -> bool:
+    """Fetch fresh data from models.dev.
+
+    Args:
+        force: If ``True``, fetch even when the cache is still within TTL.
+
+    Returns:
+        ``True`` if fresh data was fetched and cached successfully.
+    """
+    global _data, _loaded
+
+    with _lock:
+        if not force and _cache_is_valid():
+            return False
+
+        fresh = _fetch_from_api()
+        if fresh is not None:
+            _data = fresh
+            _write_cache(fresh)
+            _loaded = True
+            return True
+
+        return False
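A sketch of the cache lifecycle from the caller's side (the model ID is illustrative):

    from prompture.model_rates import get_model_rates, refresh_rates_cache

    # The first lookup lazily loads ~/.prompture/cache/models_dev.json when it
    # is within TTL; otherwise it fetches https://models.dev/api.json, caches
    # it, and falls back to a stale cache if the network fetch fails.
    rates = get_model_rates("claude", "claude-3-5-sonnet-20241022")
    if rates:
        print(rates["input"], rates["output"])  # USD per 1M tokens

    # Force a re-fetch regardless of TTL; True only on a successful fetch.
    refresh_rates_cache(force=True)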
prompture/settings.py
CHANGED

@@ -48,6 +48,13 @@ class Settings(BaseSettings):
     grok_api_key: Optional[str] = None
     grok_model: str = "grok-4-fast-reasoning"
 
+    # AirLLM
+    airllm_model: str = "meta-llama/Llama-2-7b-hf"
+    airllm_compression: Optional[str] = None  # "4bit" or "8bit"
+
+    # Model rates cache
+    model_rates_ttl_days: int = 7  # How often to refresh models.dev cache
+
     model_config = SettingsConfigDict(
         env_file=".env",
         extra="ignore",
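Since Settings is a pydantic BaseSettings backed by .env, the new fields should also be settable through environment variables (AIRLLM_MODEL, AIRLLM_COMPRESSION, MODEL_RATES_TTL_DAYS) under the usual case-insensitive mapping; a sketch with illustrative values:

    from prompture.settings import Settings

    settings = Settings(
        airllm_model="meta-llama/Llama-2-70b-hf",
        airllm_compression="4bit",   # or "8bit", or None for full precision
        model_rates_ttl_days=14,     # refresh the models.dev cache every two weeks
    )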
{prompture-0.0.32.dev1.dist-info → prompture-0.0.33.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: prompture
-Version: 0.0.32.dev1
+Version: 0.0.33
 Summary: Ask LLMs to return structured JSON and run cross-model tests. API-first.
 Home-page: https://github.com/jhd3197/prompture
 Author: Juan Denis

@@ -29,6 +29,8 @@ Requires-Dist: tukuy>=0.0.6
 Requires-Dist: pyyaml>=6.0
 Provides-Extra: test
 Requires-Dist: pytest>=7.0; extra == "test"
+Provides-Extra: airllm
+Requires-Dist: airllm>=2.8.0; extra == "airllm"
 Dynamic: author
 Dynamic: author-email
 Dynamic: classifier
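The new extra installs the lazy dependency in one step, matching the hint in the driver's ImportError message:

    pip install "prompture[airllm]"   # pulls in airllm>=2.8.0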
prompture-0.0.33.dist-info/RECORD
ADDED

@@ -0,0 +1,30 @@
+prompture/__init__.py,sha256=gnZYWKiWV_FTUeG9A88nkgPrAwtx6U_23UrOYhNTiOw,2891
+prompture/cli.py,sha256=vA86GNjtKSHz8eRMl5YDaT9HHIWuhkeJtfx8jqTaqtM,809
+prompture/core.py,sha256=x_FhOY37ygQVHo4zHUyiWsV4BuOClkELsVhEV-K4jJ0,53689
+prompture/discovery.py,sha256=JbaOhZuf41yYOFdv6wZmRjfRPum_df5V5fVVbKyOPoY,7240
+prompture/driver.py,sha256=w8pdXHujImIGF3ee8rkG8f6-UD0h2jLHhucSPInRrYI,989
+prompture/field_definitions.py,sha256=6kDMYNedccTK5l2L_I8_NI3_av-iYHqGPwkKDy8214c,21731
+prompture/model_rates.py,sha256=B3VdFFIPaJ31xSIVq96bAD3P4dnrIguauyNrD7WHCgQ,6428
+prompture/runner.py,sha256=5xwal3iBQQj4_q7l3Rjr0e3RrUMJPaPDLiEchO0mmHo,4192
+prompture/settings.py,sha256=F4RQt4HB3rOUMoKs1r-Y7W55Dvk_LdXyTD88S8mMojM,1730
+prompture/tools.py,sha256=qyT8oJl_v9GolABkflW0SvEx22yNkEJZKTu-40nJbs0,40329
+prompture/validator.py,sha256=oLzVsNveHuF-N_uOd11_uDa9Q5rFyo0wrk_l1N4zqDk,996
+prompture/drivers/__init__.py,sha256=hi2u4Z2KQFfgqce1QvjRlDKRzB2xfJZpidGNMsQ82oI,4105
+prompture/drivers/airllm_driver.py,sha256=g1WmQDwSfK0BIyG96JrZY7W_VHXOS7wDSeegE7B1q4Y,3956
+prompture/drivers/azure_driver.py,sha256=t8RsGSexwPaM8VzakMRMpssh7Nf-StY-C5BfWmoXdzE,5016
+prompture/drivers/claude_driver.py,sha256=KcJRIcS9OPK6IBs8pUxxcKFlBH_eivgKLJcDuUk1_YU,3665
+prompture/drivers/google_driver.py,sha256=-fbnJ003VC01YApujNUC1lg7E4J9x-Jm8sEJfLX00cI,6876
+prompture/drivers/grok_driver.py,sha256=24FxmqiZNF8znIATn7CnFExqP_XvivXyvoxVFnC4iW8,5400
+prompture/drivers/groq_driver.py,sha256=FZPz1sPfYj86HjwtHX7U7YE60_oDAfr4TfS1I7NdKzI,4313
+prompture/drivers/hugging_driver.py,sha256=rngz7hIR7l-9M_xe4EjWPaBqdyPFHdQsqnDDy9gm5So,2357
+prompture/drivers/lmstudio_driver.py,sha256=Umy1kT211TAxxSPyQrtZnIGIZgqFeSV87FLTiPFF0CY,3455
+prompture/drivers/local_http_driver.py,sha256=S2diikvtQOQHF7fB07zU2X0QWkej4Of__rJgaU2C6FI,1669
+prompture/drivers/ollama_driver.py,sha256=fq_eFgwmCT3SK1D-ICHjxLjcm_An0suwkFIWC38xsS0,4681
+prompture/drivers/openai_driver.py,sha256=pO12D_4jmbCKkSDRLtk5olb7UqBqZyY0sh6IUJK1fjE,4371
+prompture/drivers/openrouter_driver.py,sha256=f4JWl3YApAgrvuskUz0athbdS82GZasclVKx1AA9-mA,5454
+prompture-0.0.33.dist-info/licenses/LICENSE,sha256=0HgDepH7aaHNFhHF-iXuW6_GqDfYPnVkjtiCAZ4yS8I,1060
+prompture-0.0.33.dist-info/METADATA,sha256=2hcu-U0S8qw5AIa5lPTwB6LV9RjhFhJfS6A2GC3X-sM,18109
+prompture-0.0.33.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
+prompture-0.0.33.dist-info/entry_points.txt,sha256=AFPG3lJR86g4IJMoWQUW5Ph7G6MLNWG3A2u2Tp9zkp8,48
+prompture-0.0.33.dist-info/top_level.txt,sha256=to86zq_kjfdoLeAxQNr420UWqT0WzkKoZ509J7Qr2t4,10
+prompture-0.0.33.dist-info/RECORD,,
prompture-0.0.32.dev1.dist-info/RECORD
REMOVED

@@ -1,28 +0,0 @@
-prompture/__init__.py,sha256=kCcOseMTHaJkl-vtzXVbbBdWRQlIWWBr-C-l9E2mScU,2689
-prompture/cli.py,sha256=vA86GNjtKSHz8eRMl5YDaT9HHIWuhkeJtfx8jqTaqtM,809
-prompture/core.py,sha256=x_FhOY37ygQVHo4zHUyiWsV4BuOClkELsVhEV-K4jJ0,53689
-prompture/discovery.py,sha256=qQ7Quz0Tqo0f2h9DqMlV7RqMP4XOeue_ZwzXq4bf6B8,6788
-prompture/driver.py,sha256=w8pdXHujImIGF3ee8rkG8f6-UD0h2jLHhucSPInRrYI,989
-prompture/field_definitions.py,sha256=6kDMYNedccTK5l2L_I8_NI3_av-iYHqGPwkKDy8214c,21731
-prompture/runner.py,sha256=5xwal3iBQQj4_q7l3Rjr0e3RrUMJPaPDLiEchO0mmHo,4192
-prompture/settings.py,sha256=vHRkBAZNP6yRsI2Sm4FMa_FCw0Zxy2VX97ooiVYWvks,1500
-prompture/tools.py,sha256=qyT8oJl_v9GolABkflW0SvEx22yNkEJZKTu-40nJbs0,40329
-prompture/validator.py,sha256=oLzVsNveHuF-N_uOd11_uDa9Q5rFyo0wrk_l1N4zqDk,996
-prompture/drivers/__init__.py,sha256=IQ7DsWC_FP45h2CprWRhQ7lKi3-9ZO6CgweNX6IxTUA,3896
-prompture/drivers/azure_driver.py,sha256=GROhK3hqMfMurnEgpAawa1DPS-FhOU0YQcgy9SNGTzM,4622
-prompture/drivers/claude_driver.py,sha256=ZEHQNqNThLZ0p-WmGVuKiNyiudGYGP07xIzbgZhLY1g,3293
-prompture/drivers/google_driver.py,sha256=bCsCSuCRise0L_HOmw-jBh1hrpd8glNBkVFlOZeP0DM,6338
-prompture/drivers/grok_driver.py,sha256=Xp6L75oL3dN8St8_m46C_5bM8FcaIdNKUASAt9kZ39w,5003
-prompture/drivers/groq_driver.py,sha256=91WGXP8G5dO0beuFO8FehZszlDC_X9hv_yPzQRGmcqw,3920
-prompture/drivers/hugging_driver.py,sha256=rngz7hIR7l-9M_xe4EjWPaBqdyPFHdQsqnDDy9gm5So,2357
-prompture/drivers/lmstudio_driver.py,sha256=Umy1kT211TAxxSPyQrtZnIGIZgqFeSV87FLTiPFF0CY,3455
-prompture/drivers/local_http_driver.py,sha256=S2diikvtQOQHF7fB07zU2X0QWkej4Of__rJgaU2C6FI,1669
-prompture/drivers/ollama_driver.py,sha256=fq_eFgwmCT3SK1D-ICHjxLjcm_An0suwkFIWC38xsS0,4681
-prompture/drivers/openai_driver.py,sha256=9q9OjQslquRFvIl1Hd9JVmFFFVh6OBIWrFulw1mkYWg,3976
-prompture/drivers/openrouter_driver.py,sha256=GKvLOFDhsyopH-k3iaD3VWllm7xbGuopRSA02MfCKoM,5031
-prompture-0.0.32.dev1.dist-info/licenses/LICENSE,sha256=0HgDepH7aaHNFhHF-iXuW6_GqDfYPnVkjtiCAZ4yS8I,1060
-prompture-0.0.32.dev1.dist-info/METADATA,sha256=3oNb4hhkYR7ZuLsrG5wrRxJjbuLnazBaQHKaW2yAM0Y,18043
-prompture-0.0.32.dev1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-prompture-0.0.32.dev1.dist-info/entry_points.txt,sha256=AFPG3lJR86g4IJMoWQUW5Ph7G6MLNWG3A2u2Tp9zkp8,48
-prompture-0.0.32.dev1.dist-info/top_level.txt,sha256=to86zq_kjfdoLeAxQNr420UWqT0WzkKoZ509J7Qr2t4,10
-prompture-0.0.32.dev1.dist-info/RECORD,,
{prompture-0.0.32.dev1.dist-info → prompture-0.0.33.dist-info}/entry_points.txt
File without changes

{prompture-0.0.32.dev1.dist-info → prompture-0.0.33.dist-info}/licenses/LICENSE
File without changes

{prompture-0.0.32.dev1.dist-info → prompture-0.0.33.dist-info}/top_level.txt
File without changes