tokenable 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tokenable/__init__.py +3 -0
- tokenable/__main__.py +5 -0
- tokenable/calibration/__init__.py +129 -0
- tokenable/calibration/providers.py +202 -0
- tokenable/cli/__init__.py +695 -0
- tokenable/config/__init__.py +155 -0
- tokenable/enforcer/__init__.py +124 -0
- tokenable/estimator/__init__.py +192 -0
- tokenable/fixer/__init__.py +101 -0
- tokenable/mcp/__init__.py +249 -0
- tokenable/models/__init__.py +308 -0
- tokenable/pricing_sync.py +145 -0
- tokenable/providers/__init__.py +485 -0
- tokenable/providers/data/anthropic.json +452 -0
- tokenable/providers/data/benchmarks.json +324 -0
- tokenable/providers/data/google.json +318 -0
- tokenable/providers/data/openai.json +507 -0
- tokenable/providers/data/perplexity.json +88 -0
- tokenable/providers/data/xai.json +263 -0
- tokenable/py.typed +1 -0
- tokenable/recommender/__init__.py +209 -0
- tokenable/scanner/__init__.py +303 -0
- tokenable/telemetry/__init__.py +92 -0
- tokenable/utils/__init__.py +19 -0
- tokenable-1.0.0.dist-info/METADATA +196 -0
- tokenable-1.0.0.dist-info/RECORD +29 -0
- tokenable-1.0.0.dist-info/WHEEL +4 -0
- tokenable-1.0.0.dist-info/entry_points.txt +2 -0
- tokenable-1.0.0.dist-info/licenses/LICENSE +190 -0
|
@@ -0,0 +1,485 @@
|
|
|
1
|
+
"""Pricing database — model lookup, cost calculation, capability inference."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
import re
|
|
7
|
+
from importlib.resources import files
|
|
8
|
+
from typing import TYPE_CHECKING
|
|
9
|
+
|
|
10
|
+
from tokenable.models import (
|
|
11
|
+
Capability,
|
|
12
|
+
ModelPricing,
|
|
13
|
+
ModelStatus,
|
|
14
|
+
ModelTier,
|
|
15
|
+
Provider,
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
# ── Tier thresholds ──────────────────────────────────────────────────
|
|
19
|
+
|
|
20
|
+
# Output-price boundaries (per million output tokens) used by compute_tier():
# a price at or below _TIER_BUDGET_MAX is BUDGET, a price below
# _TIER_PREMIUM_MIN is MID, and anything else is PREMIUM.
_TIER_BUDGET_MAX = 5.0
_TIER_PREMIUM_MIN = 20.0
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def compute_tier(output_cost_per_million: float) -> ModelTier:
    """Classify a model into a cost tier from its output price.

    Args:
        output_cost_per_million: Price per million output tokens.

    Returns:
        ``ModelTier.BUDGET`` when the price is at or below the budget
        ceiling, ``ModelTier.MID`` when it is below the premium floor, and
        ``ModelTier.PREMIUM`` otherwise.
    """
    price = output_cost_per_million
    if price <= _TIER_BUDGET_MAX:
        tier = ModelTier.BUDGET
    elif price < _TIER_PREMIUM_MIN:
        tier = ModelTier.MID
    else:
        tier = ModelTier.PREMIUM
    return tier
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
# ── Data loading ─────────────────────────────────────────────────────
|
|
34
|
+
|
|
35
|
+
# Module-level caches, filled by _ensure_loaded() on first use.
# Maps each provider to the models parsed from its bundled pricing JSON.
_PROVIDER_DATA: dict[Provider, list[ModelPricing]] = {}
# Entries from the "models" section of benchmarks.json; lookups elsewhere in
# this module use "<provider value>/<model id>" keys.
_BENCHMARKS: dict[str, dict] = {}
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def _load_provider_json(provider: Provider) -> list[ModelPricing]:
    """Read the bundled ``<provider>.json`` file and build its model list.

    Args:
        provider: Provider whose ``value`` names the JSON file under the
            package's ``data`` directory.

    Returns:
        One ``ModelPricing`` per entry of the file's ``"models"`` array
        (empty when the key is absent). The tier is derived from the output
        price rather than read from the file.

    Raises:
        KeyError: If an entry lacks one of the required keys (``id``,
            ``name``, the two base prices, ``context_window``,
            ``max_output_tokens``).
    """
    # Optional per-million rates: absent keys become None.
    optional_rate_keys = (
        "cache_read_input_cost_per_million",
        "cache_write_input_cost_per_million",
        "batch_input_cost_per_million",
        "batch_output_cost_per_million",
        "fast_input_cost_per_million",
        "fast_output_cost_per_million",
        "input_cost_above_200k_per_million",
        "output_cost_above_200k_per_million",
    )
    # Feature flags: absent keys become False.
    feature_flag_keys = (
        "supports_vision",
        "supports_tools",
        "supports_prompt_caching",
        "supports_reasoning",
        "supports_computer_use",
    )

    resource = files("tokenable.providers") / "data" / f"{provider.value}.json"
    payload = json.loads(resource.read_text("utf-8"))

    parsed: list[ModelPricing] = []
    for entry in payload.get("models", []):
        capability_list = [Capability(name) for name in entry.get("capabilities", [])]
        model_status = ModelStatus(entry.get("status", "current"))
        model_tier = compute_tier(entry.get("output_cost_per_million", 0))
        optional_rates = {key: entry.get(key) for key in optional_rate_keys}
        feature_flags = {key: entry.get(key, False) for key in feature_flag_keys}
        parsed.append(
            ModelPricing(
                id=entry["id"],
                name=entry["name"],
                provider=provider,
                aliases=entry.get("aliases", []),
                status=model_status,
                input_cost_per_million=entry["input_cost_per_million"],
                output_cost_per_million=entry["output_cost_per_million"],
                **optional_rates,
                context_window=entry["context_window"],
                max_output_tokens=entry["max_output_tokens"],
                **feature_flags,
                tier=model_tier,
                capabilities=capability_list,
                knowledge_cutoff=entry.get("knowledge_cutoff"),
            )
        )
    return parsed
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def _ensure_loaded() -> None:
    """Lazy-load all provider pricing data and the benchmark table.

    A non-empty ``_PROVIDER_DATA`` is treated as "already loaded". Everything
    is parsed into local containers first and published to the module-level
    caches only after every file has loaded, so a failure part-way through
    (missing file, malformed JSON, unknown enum value) leaves the caches
    empty and the next call retries instead of serving partial data.
    """
    if _PROVIDER_DATA:
        return

    loaded = {p: _load_provider_json(p) for p in Provider}

    data_dir = files("tokenable.providers") / "data"
    raw = (data_dir / "benchmarks.json").read_text("utf-8")
    benchmarks = json.loads(raw).get("models", {})

    # Publish benchmarks before provider data: a populated _PROVIDER_DATA is
    # the "loaded" marker, so it must be set last.
    _BENCHMARKS.update(benchmarks)
    _PROVIDER_DATA.update(loaded)
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
# ── Public API ───────────────────────────────────────────────────────
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def get_all_providers() -> list[Provider]:
    """Return every member of the ``Provider`` enum, in definition order."""
    return [member for member in Provider]
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def get_provider_models(provider: Provider) -> list[ModelPricing]:
    """Return the cached model list for ``provider``.

    Triggers the lazy data load first. A provider with no cached entry
    yields a fresh empty list.
    """
    _ensure_loaded()
    if provider in _PROVIDER_DATA:
        return _PROVIDER_DATA[provider]
    return []
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def normalize_model_id(model_id: str) -> str:
    """Reduce a routed or platform-qualified model ID to its bare form.

    The substitutions run in a fixed order, each on the result of the
    previous one, so stacked prefixes are peeled off one layer at a time
    before a trailing ``-v<N>:<M>`` version suffix is dropped.

    Args:
        model_id: Model identifier, possibly carrying routing prefixes.

    Returns:
        The identifier with all recognised prefixes and the version suffix
        removed; unchanged when nothing matches.
    """
    strip_patterns = (
        # Router prefix.
        r"^openrouter/",
        # Cloud platform followed by vendor.
        r"^(aws|gcp|azure)/(anthropic|openai|google|meta|mistralai|cohere|ai21)/",
        # Platform prefix on its own.
        r"^(bedrock/|azure/|vertex_ai/|azure_ai/)",
        # Vendor / namespace prefix.
        r"^(models/|gemini/|xai/|openai/|perplexity/|anthropic/|google/|meta-llama/|meta/|mistralai/)",
        # Dash-style platform prefix.
        r"^bedrock-",
        # Dotted vendor prefix.
        r"^(anthropic|amazon|meta|cohere|ai21|mistral|stability)\.",
        # Trailing version suffix such as "-v1:0".
        r"-v\d+:\d+$",
    )
    cleaned = model_id
    for pattern in strip_patterns:
        cleaned = re.sub(pattern, "", cleaned)
    return cleaned
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
def get_model(provider: Provider, model_id: str) -> ModelPricing | None:
    """Resolve ``model_id`` to a model of ``provider``.

    Matching is attempted in this order: exact ID, exact alias, and then,
    only when normalisation actually changes the string, the normalised ID
    followed by the normalised alias.

    Returns:
        The first matching model, or ``None`` when nothing matches.
    """
    _ensure_loaded()
    catalog = _PROVIDER_DATA.get(provider, [])

    def _find(wanted: str) -> ModelPricing | None:
        # IDs take priority over aliases across the whole catalog.
        by_id = next((entry for entry in catalog if entry.id == wanted), None)
        if by_id is not None:
            return by_id
        return next((entry for entry in catalog if wanted in entry.aliases), None)

    direct_hit = _find(model_id)
    if direct_hit is not None:
        return direct_hit

    stripped = normalize_model_id(model_id)
    if stripped == model_id:
        # Normalisation changed nothing, so a second pass cannot succeed.
        return None
    return _find(stripped)
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
def get_all_models() -> list[ModelPricing]:
    """Return a new flat list of every cached model across all providers.

    Triggers the lazy data load first; models appear grouped by provider in
    the cache's insertion order.
    """
    _ensure_loaded()
    return [entry for group in _PROVIDER_DATA.values() for entry in group]
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def calculate_cost(
    model: ModelPricing,
    input_tokens: int,
    output_tokens: int,
    cached_input_tokens: int = 0,
    use_batch: bool = False,
    use_fast: bool = False,
    is_long_context: bool | None = None,
) -> float:
    """Calculate total USD cost for a request.

    Rate selection, applied independently to input and output, takes the
    first option that is both requested and priced on the model: fast, then
    batch, then long-context (above 200k), then the standard rate.

    Args:
        model: Pricing record supplying the per-million-token rates.
        input_tokens: Total input tokens, including any cached portion.
        output_tokens: Output tokens.
        cached_input_tokens: Input tokens served from the prompt cache.
            Clamped to the range ``0..input_tokens``.
        use_batch: Prefer batch rates when the model defines them.
        use_fast: Prefer fast rates when the model defines them.
        is_long_context: Force or suppress long-context rates. ``None``
            means "detect": long context when ``input_tokens > 200_000``.

    Returns:
        The sum of uncached-input, cached-input and output cost.
    """
    if is_long_context is None:
        is_long_context = input_tokens > 200_000

    # Input rate
    if use_fast and model.fast_input_cost_per_million is not None:
        input_rate = model.fast_input_cost_per_million
    elif use_batch and model.batch_input_cost_per_million is not None:
        input_rate = model.batch_input_cost_per_million
    elif is_long_context and model.input_cost_above_200k_per_million is not None:
        input_rate = model.input_cost_above_200k_per_million
    else:
        input_rate = model.input_cost_per_million

    # Output rate
    if use_fast and model.fast_output_cost_per_million is not None:
        output_rate = model.fast_output_cost_per_million
    elif use_batch and model.batch_output_cost_per_million is not None:
        output_rate = model.batch_output_cost_per_million
    elif is_long_context and model.output_cost_above_200k_per_million is not None:
        output_rate = model.output_cost_above_200k_per_million
    else:
        output_rate = model.output_cost_per_million

    cache_rate = model.cache_read_input_cost_per_million
    if cache_rate is None:
        # No cache pricing on this model: bill every input token at the
        # input rate instead of silently treating cached tokens as free.
        effective_cached = 0
    else:
        # A negative cached count must not inflate the uncached count.
        effective_cached = max(0, min(cached_input_tokens, input_tokens))
    uncached_input = max(0, input_tokens - effective_cached)

    input_cost = (uncached_input / 1_000_000) * input_rate
    cached_cost = (effective_cached / 1_000_000) * cache_rate if effective_cached > 0 else 0.0
    output_cost = (output_tokens / 1_000_000) * output_rate

    return input_cost + cached_cost + output_cost
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
# ── Capability inference ─────────────────────────────────────────────
|
|
210
|
+
|
|
211
|
+
# Keyword table for infer_required_capabilities(): each capability is paired
# with a case-insensitive, whole-word pattern that signals the capability
# when it occurs anywhere in the task text.
_CAPABILITY_KEYWORDS: list[tuple[Capability, re.Pattern[str]]] = [
    (capability, re.compile(keyword_pattern, re.IGNORECASE))
    for capability, keyword_pattern in (
        (
            Capability.CODE,
            r"\b(code|coding|debug|refactor|function|typescript|python|javascript|programming|syntax|compile|lint|regex|sql|api endpoint)\b",
        ),
        (
            Capability.REASONING,
            r"\b(step[- ]by[- ]step|analyze|reason|think carefully|evaluate|logic|math|proof|calculate|deduce|chain of thought)\b",
        ),
        (
            Capability.CREATIVE,
            r"\b(story|creative|poem|narrative|fiction|blog post|essay|copywriting)\b",
        ),
        (
            Capability.VISION,
            r"\b(image|screenshot|photo|picture|diagram|chart|visual|ocr|describe the image)\b",
        ),
        (
            Capability.SEARCH,
            r"\b(search|find information|look up|latest news|real[- ]time)\b",
        ),
        (
            Capability.AUDIO,
            r"\b(audio|transcribe|speech|voice|listen)\b",
        ),
    )
]
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
def infer_required_capabilities(text: str) -> list[Capability]:
    """Infer required capabilities from free-form text.

    Args:
        text: Task description or prompt to scan for capability keywords.

    Returns:
        Each capability whose keyword pattern occurs in ``text``, listed once
        and in the order of ``_CAPABILITY_KEYWORDS``, so the result is stable
        between runs. Blank text or text with no keyword match yields
        ``[Capability.GENERAL]``.
    """
    if not text.strip():
        return [Capability.GENERAL]
    # dict.fromkeys de-duplicates while keeping table order; a set would make
    # the order depend on hashing and vary from one process to the next.
    matched = list(
        dict.fromkeys(cap for cap, pattern in _CAPABILITY_KEYWORDS if pattern.search(text))
    )
    return matched if matched else [Capability.GENERAL]
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
def get_models_by_capabilities(
    required: list[Capability],
    provider: Provider | None = None,
    status: ModelStatus = ModelStatus.CURRENT,
    max_output_cost_per_million: float | None = None,
) -> list[ModelPricing]:
    """Return the models that offer every capability in ``required``.

    Args:
        required: Capabilities a model must all list.
        provider: Restrict the search to one provider; all providers when
            omitted.
        status: Only models with exactly this status are considered.
        max_output_cost_per_million: Optional inclusive ceiling on the
            output price.

    Returns:
        Matching models ordered by ascending output price.
    """
    _ensure_loaded()
    pool = get_provider_models(provider) if provider else get_all_models()

    def _qualifies(entry: ModelPricing) -> bool:
        if entry.status != status:
            return False
        if (
            max_output_cost_per_million is not None
            and entry.output_cost_per_million > max_output_cost_per_million
        ):
            return False
        return all(cap in entry.capabilities for cap in required)

    return sorted(
        (entry for entry in pool if _qualifies(entry)),
        key=lambda entry: entry.output_cost_per_million,
    )
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
# ── Benchmarks ───────────────────────────────────────────────────────
|
|
289
|
+
|
|
290
|
+
# Benchmark field that get_quality_score() reads for each capability.
# Several capabilities map to the aggregate "overall" field rather than a
# field of their own.
_CAP_TO_BENCHMARK: dict[Capability, str] = {
    Capability.CODE: "coding",
    Capability.REASONING: "reasoning",
    Capability.GENERAL: "overall",
    Capability.CREATIVE: "creative_writing",
    Capability.VISION: "overall",
    Capability.SEARCH: "overall",
    Capability.AUDIO: "overall",
}
|
|
299
|
+
|
|
300
|
+
|
|
301
|
+
def get_model_benchmarks(provider: Provider, model_id: str) -> dict | None:
    """Find the benchmark entry for a model.

    Benchmark keys have the form ``"<provider value>/<model id>"``. The ID
    is tried as given, then in normalised form (if that differs), then as
    the canonical ID of the resolved model (if that differs).

    Returns:
        The stored benchmark entry for the first key that exists, or
        ``None`` when none of them does.
    """
    _ensure_loaded()

    def _candidate_ids():
        # Lazy on purpose: normalisation and model resolution only run when
        # the earlier, cheaper candidates have missed.
        yield model_id
        stripped = normalize_model_id(model_id)
        if stripped != model_id:
            yield stripped
        resolved = get_model(provider, model_id)
        if resolved and resolved.id != model_id:
            yield resolved.id

    for candidate in _candidate_ids():
        lookup_key = f"{provider.value}/{candidate}"
        if lookup_key in _BENCHMARKS:
            return _BENCHMARKS[lookup_key]
    return None
|
|
318
|
+
|
|
319
|
+
|
|
320
|
+
def get_quality_score(provider: Provider, model_id: str, capability: Capability) -> float | None:
    """Get quality score for a model on a specific capability.

    The capability is mapped to a benchmark field through
    ``_CAP_TO_BENCHMARK`` (unmapped capabilities use ``"overall"``). When
    that field is missing or not numeric, ``"overall"`` is used instead.

    Returns:
        The score as a ``float``, or ``None`` when the model has no
        benchmark entry or neither field holds a number.
    """
    bench = get_model_benchmarks(provider, model_id)
    if not bench:
        return None
    field = _CAP_TO_BENCHMARK.get(capability, "overall")
    # Validate the fallback exactly like the primary field so the function
    # never returns an int or a non-numeric value.
    for key in (field, "overall"):
        score = bench.get(key)
        if isinstance(score, (int, float)):
            return float(score)
    return None
|
|
328
|
+
|
|
329
|
+
|
|
330
|
+
def get_min_quality_score(
    provider: Provider, model_id: str, capabilities: list[Capability]
) -> float | None:
    """Return the weakest quality score of a model over ``capabilities``.

    Returns:
        The smallest per-capability score, or ``None`` when ``capabilities``
        is empty or any single capability has no score.
    """
    lowest = None
    for capability in capabilities:
        value = get_quality_score(provider, model_id, capability)
        if value is None:
            # One unknown score makes the overall minimum unknown.
            return None
        if lowest is None or value < lowest:
            lowest = value
    return lowest
|
|
341
|
+
|
|
342
|
+
|
|
343
|
+
# ── Model suggestion ─────────────────────────────────────────────────
|
|
344
|
+
|
|
345
|
+
# Numeric rank of each tier; suggest_alternatives() subtracts the candidate's
# rank from the current model's rank to measure how many tiers it drops.
_TIER_ORDER: dict[ModelTier, int] = {ModelTier.BUDGET: 0, ModelTier.MID: 1, ModelTier.PREMIUM: 2}
# Default min_quality_ratio for suggest_alternatives(): a candidate scoring
# below this fraction of the current model's quality score is rejected.
_DEFAULT_MIN_QUALITY_RATIO = 0.7
|
|
347
|
+
|
|
348
|
+
|
|
349
|
+
class AlternativeSuggestion:
    """A suggested alternative model.

    Attributes are plain instance attributes set from the constructor
    arguments of the same name.
    """

    def __init__(
        self,
        model: ModelPricing,
        reasoning: str,
        savings_percent: int,
        quality_score: float | None = None,
        current_quality_score: float | None = None,
    ):
        """Store the suggestion's fields.

        Args:
            model: The suggested model.
            reasoning: Human-readable explanation of the suggestion.
            savings_percent: Savings relative to the current model, in
                percent.
            quality_score: Quality score of the suggested model, if known.
            current_quality_score: Quality score of the current model, if
                known.
        """
        self.model = model
        self.reasoning = reasoning
        self.savings_percent = savings_percent
        self.quality_score = quality_score
        self.current_quality_score = current_quality_score

    def __repr__(self) -> str:
        """Return a constructor-style representation for logs and debugging."""
        return (
            f"{type(self).__name__}("
            f"model={self.model!r}, "
            f"reasoning={self.reasoning!r}, "
            f"savings_percent={self.savings_percent!r}, "
            f"quality_score={self.quality_score!r}, "
            f"current_quality_score={self.current_quality_score!r})"
        )
|
|
365
|
+
|
|
366
|
+
|
|
367
|
+
def suggest_alternatives(
    current_model_id: str,
    current_provider: Provider,
    required_capabilities: list[Capability],
    confidence: str = "medium",
    min_quality_ratio: float = _DEFAULT_MIN_QUALITY_RATIO,
) -> list[AlternativeSuggestion]:
    """Suggest cheaper alternatives that match required capabilities.

    Candidates must list every required capability, have a lower output
    price than the current model, and cost less than it on a fixed reference
    request (4096 input / 1024 output tokens). The last check keeps models
    with a lower output price but a much higher input price from being
    presented with negative savings.

    Args:
        current_model_id: ID or alias of the model in use.
        current_provider: Provider of the model in use.
        required_capabilities: Capabilities every suggestion must support.
        confidence: ``"low"`` keeps only same-provider candidates that do
            not drop a tier; ``"medium"`` allows a drop of at most one tier;
            any other value applies no tier or provider restriction.
        min_quality_ratio: When both quality scores are known, a candidate
            scoring below ``current score * min_quality_ratio`` is rejected.

    Returns:
        Up to three suggestions ordered by quality-adjusted reference cost
        (cheapest first); empty when the current model cannot be resolved.
    """
    current = get_model(current_provider, current_model_id)
    if not current:
        return []

    current_quality = get_min_quality_score(
        current_provider, current_model_id, required_capabilities
    )

    # Reference request used to compare models on blended input/output cost.
    ref_input, ref_output = 4096, 1024
    current_cost = calculate_cost(current, ref_input, ref_output)
    current_tier = _TIER_ORDER[current.tier]

    # (quality-adjusted cost, model, reference cost, quality); quality and
    # cost are computed once per candidate and reused for ranking and output.
    ranked: list[tuple[float, ModelPricing, float, float | None]] = []
    for m in get_models_by_capabilities(required_capabilities):
        if m.id == current.id and m.provider == current.provider:
            continue
        if not (m.output_cost_per_million < current.output_cost_per_million):
            continue

        # Filter by confidence
        tier_drop = current_tier - _TIER_ORDER[m.tier]
        if confidence == "low" and (m.provider != current_provider or tier_drop > 0):
            continue
        if confidence == "medium" and tier_drop > 1:
            continue

        # Quality gate
        cq = get_min_quality_score(m.provider, m.id, required_capabilities)
        if (
            current_quality is not None
            and cq is not None
            and cq < current_quality * min_quality_ratio
        ):
            continue

        candidate_cost = calculate_cost(m, ref_input, ref_output)
        if current_cost > 0 and candidate_cost >= current_cost:
            # Not actually cheaper on the reference request.
            continue

        adjusted_cost = candidate_cost / (cq / 100) if cq else candidate_cost
        ranked.append((adjusted_cost, m, candidate_cost, cq))

    # Sort by quality-adjusted cost (stable, so ties keep output-price order).
    ranked.sort(key=lambda item: item[0])

    caps = f"[{', '.join(c.value for c in required_capabilities)}]"
    results: list[AlternativeSuggestion] = []
    for _, m, candidate_cost, cq in ranked[:3]:
        savings = (
            round(((current_cost - candidate_cost) / current_cost) * 100) if current_cost > 0 else 0
        )
        quality_suffix = (
            f" (quality: {cq} vs {current_quality})"
            if cq is not None and current_quality is not None
            else ""
        )
        if m.provider == current_provider:
            reasoning = f"Task requires {caps} — {m.name} handles that at {savings}% lower cost{quality_suffix}"
        else:
            reasoning = f"Task requires {caps} — {m.provider.value}/{m.name} handles that at {savings}% lower cost{quality_suffix}"
        results.append(
            AlternativeSuggestion(
                model=m,
                reasoning=reasoning,
                savings_percent=savings,
                quality_score=cq,
                current_quality_score=current_quality,
            )
        )
    return results
|
|
447
|
+
|
|
448
|
+
|
|
449
|
+
class TaskSuggestion:
    """Result of a task-based model suggestion.

    Attributes are plain instance attributes set from the constructor
    arguments of the same name.
    """

    def __init__(
        self, model: ModelPricing, inferred_capabilities: list[Capability], reasoning: str
    ):
        """Store the suggestion's fields.

        Args:
            model: The suggested model.
            inferred_capabilities: Capabilities inferred from the task text.
            reasoning: Human-readable explanation of the suggestion.
        """
        self.model = model
        self.inferred_capabilities = inferred_capabilities
        self.reasoning = reasoning

    def __repr__(self) -> str:
        """Return a constructor-style representation for logs and debugging."""
        return (
            f"{type(self).__name__}("
            f"model={self.model!r}, "
            f"inferred_capabilities={self.inferred_capabilities!r}, "
            f"reasoning={self.reasoning!r})"
        )
|
|
458
|
+
|
|
459
|
+
|
|
460
|
+
def suggest_model_for_task(
    text: str,
    provider: Provider | None = None,
    max_cost_per_million: float | None = None,
) -> TaskSuggestion | None:
    """Pick the best-value model for a free-form task description.

    Capabilities are inferred from ``text``; the models supporting all of
    them are ranked by output price divided by quality (plain output price
    when no quality score is available) and the cheapest is returned.

    Args:
        text: Free-form description of the task.
        provider: Restrict the search to one provider; all when omitted.
        max_cost_per_million: Optional ceiling on the output price per
            million tokens.

    Returns:
        A ``TaskSuggestion`` for the top-ranked model, or ``None`` when no
        model qualifies.
    """
    capabilities = infer_required_capabilities(text)
    pool = get_models_by_capabilities(
        capabilities,
        provider=provider,
        max_output_cost_per_million=max_cost_per_million,
    )
    if not pool:
        return None

    def _value_score(entry: ModelPricing) -> float:
        score = get_min_quality_score(entry.provider, entry.id, capabilities)
        if score:
            return entry.output_cost_per_million / (score / 100)
        return entry.output_cost_per_million

    ranked = list(pool)
    ranked.sort(key=_value_score)
    best = ranked[0]

    capability_names = ", ".join(c.value for c in capabilities)
    caps = f"[{capability_names}]"
    quality = get_min_quality_score(best.provider, best.id, capabilities)
    if quality is not None:
        quality_suffix = f" — quality: {quality}/100"
    else:
        quality_suffix = ""
    reasoning = f"Inferred capabilities: {caps} — {best.provider.value}/{best.name} (${best.output_cost_per_million}/M output) is the best value model{quality_suffix}"
    return TaskSuggestion(model=best, inferred_capabilities=capabilities, reasoning=reasoning)
|