tokenable 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,485 @@
1
+ """Pricing database — model lookup, cost calculation, capability inference."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ import re
7
+ from importlib.resources import files
8
+ from typing import TYPE_CHECKING
9
+
10
+ from tokenable.models import (
11
+ Capability,
12
+ ModelPricing,
13
+ ModelStatus,
14
+ ModelTier,
15
+ Provider,
16
+ )
17
+
18
+ # ── Tier thresholds ──────────────────────────────────────────────────
19
+
20
# Output price per million tokens at or below which compute_tier() returns BUDGET.
_TIER_BUDGET_MAX = 5.0
# Output price per million tokens at or above which compute_tier() returns PREMIUM;
# prices strictly between the two thresholds are MID.
_TIER_PREMIUM_MIN = 20.0
22
+
23
+
24
def compute_tier(output_cost_per_million: float) -> ModelTier:
    """Classify a model into a cost tier from its output price.

    Args:
        output_cost_per_million: Output price per million tokens.

    Returns:
        ``ModelTier.BUDGET`` when the price is at or below ``_TIER_BUDGET_MAX``,
        ``ModelTier.MID`` when it is below ``_TIER_PREMIUM_MIN``, and
        ``ModelTier.PREMIUM`` otherwise.
    """
    price = output_cost_per_million
    # The comparison order matters: anything failing both tests (including a
    # value that compares false to everything) ends up PREMIUM.
    if price <= _TIER_BUDGET_MAX:
        tier = ModelTier.BUDGET
    elif price < _TIER_PREMIUM_MIN:
        tier = ModelTier.MID
    else:
        tier = ModelTier.PREMIUM
    return tier
31
+
32
+
33
+ # ── Data loading ─────────────────────────────────────────────────────
34
+
35
# Parsed pricing entries per provider; populated on first use by _ensure_loaded().
_PROVIDER_DATA: dict[Provider, list[ModelPricing]] = {}
# Benchmark records taken from the "models" object of benchmarks.json; looked up
# with "<provider value>/<model id>" keys by get_model_benchmarks().
_BENCHMARKS: dict[str, dict] = {}
37
+
38
+
39
def _load_provider_json(provider: Provider) -> list[ModelPricing]:
    """Read ``<provider.value>.json`` from the packaged data directory and parse it.

    Args:
        provider: Provider whose ``value`` names the JSON resource to read.

    Returns:
        One ``ModelPricing`` per entry of the document's ``"models"`` array,
        in file order (empty when the array is absent).

    Raises:
        KeyError: If an entry lacks ``id``, ``name``, ``input_cost_per_million``,
            ``output_cost_per_million``, ``context_window`` or
            ``max_output_tokens``.
    """
    resource = files("tokenable.providers") / "data" / f"{provider.value}.json"
    payload = json.loads(resource.read_text("utf-8"))

    # Optional pricing keys: copied through as-is, ``None`` when absent.
    optional_rate_keys = (
        "cache_read_input_cost_per_million",
        "cache_write_input_cost_per_million",
        "batch_input_cost_per_million",
        "batch_output_cost_per_million",
        "fast_input_cost_per_million",
        "fast_output_cost_per_million",
        "input_cost_above_200k_per_million",
        "output_cost_above_200k_per_million",
    )
    # Feature flags: ``False`` when absent.
    flag_keys = (
        "supports_vision",
        "supports_tools",
        "supports_prompt_caching",
        "supports_reasoning",
        "supports_computer_use",
    )

    parsed: list[ModelPricing] = []
    for entry in payload.get("models", []):
        # Enum conversions and the tier are resolved before any required key is
        # read, so an invalid capability/status surfaces first.
        capabilities = [Capability(raw) for raw in entry.get("capabilities", [])]
        lifecycle = ModelStatus(entry.get("status", "current"))
        price_tier = compute_tier(entry.get("output_cost_per_million", 0))
        optional_rates = {key: entry.get(key) for key in optional_rate_keys}
        flags = {key: entry.get(key, False) for key in flag_keys}
        parsed.append(
            ModelPricing(
                id=entry["id"],
                name=entry["name"],
                provider=provider,
                aliases=entry.get("aliases", []),
                status=lifecycle,
                input_cost_per_million=entry["input_cost_per_million"],
                output_cost_per_million=entry["output_cost_per_million"],
                context_window=entry["context_window"],
                max_output_tokens=entry["max_output_tokens"],
                tier=price_tier,
                capabilities=capabilities,
                knowledge_cutoff=entry.get("knowledge_cutoff"),
                **optional_rates,
                **flags,
            )
        )
    return parsed
79
+
80
+
81
def _ensure_loaded() -> None:
    """Populate the provider and benchmark caches on first use.

    A non-empty ``_PROVIDER_DATA`` is the "already loaded" marker, so nothing
    is written to the module-level caches until every provider file and
    ``benchmarks.json`` have been read and parsed. If any of them fails, the
    exception propagates, the caches stay empty, and the next call retries
    instead of silently serving partial data.
    """
    if _PROVIDER_DATA:
        return

    # Parse everything into locals first; a failure here leaves the caches untouched.
    loaded = {provider: _load_provider_json(provider) for provider in Provider}
    bench_resource = files("tokenable.providers") / "data" / "benchmarks.json"
    benchmarks = json.loads(bench_resource.read_text("utf-8")).get("models", {})

    # Publish benchmarks first so the "loaded" marker only becomes truthy last.
    _BENCHMARKS.update(benchmarks)
    _PROVIDER_DATA.update(loaded)
92
+
93
+
94
+ # ── Public API ───────────────────────────────────────────────────────
95
+
96
+
97
def get_all_providers() -> list[Provider]:
    """Return every member of the ``Provider`` enum, in definition order."""
    return [member for member in Provider]
100
+
101
+
102
def get_provider_models(provider: Provider) -> list[ModelPricing]:
    """Return the cached model list for ``provider``.

    Loads the pricing data on first use. The list returned for a known
    provider is the cache's own list object, not a copy; an unknown provider
    yields a fresh empty list.
    """
    _ensure_loaded()
    try:
        return _PROVIDER_DATA[provider]
    except KeyError:
        return []
106
+
107
+
108
def normalize_model_id(model_id: str) -> str:
    """Reduce a platform-qualified model ID to its bare form.

    Each rule below is applied once, in order, to the result of the previous
    one: routing and vendor prefixes are removed from the front, then a
    trailing ``-v<digits>:<digits>`` version suffix is removed.

    Args:
        model_id: Model identifier, possibly carrying platform prefixes.

    Returns:
        The identifier with all matching prefixes and suffix stripped.
    """
    strip_rules = (
        r"^openrouter/",
        r"^(aws|gcp|azure)/(anthropic|openai|google|meta|mistralai|cohere|ai21)/",
        r"^(bedrock/|azure/|vertex_ai/|azure_ai/)",
        r"^(models/|gemini/|xai/|openai/|perplexity/|anthropic/|google/|meta-llama/|meta/|mistralai/)",
        r"^bedrock-",
        r"^(anthropic|amazon|meta|cohere|ai21|mistral|stability)\.",
        r"-v\d+:\d+$",
    )
    cleaned = model_id
    for rule in strip_rules:
        cleaned = re.sub(rule, "", cleaned)
    return cleaned
123
+
124
+
125
def get_model(provider: Provider, model_id: str) -> ModelPricing | None:
    """Find a provider's model by ID or alias, retrying with the normalized ID.

    Lookup order: exact ID, exact alias, then — only if normalization changes
    the identifier — normalized ID and normalized alias.

    Args:
        provider: Provider whose models are searched.
        model_id: Canonical ID, alias, or platform-prefixed identifier.

    Returns:
        The first matching ``ModelPricing``, or ``None`` when nothing matches.
    """
    _ensure_loaded()
    pool = _PROVIDER_DATA.get(provider, [])

    def _resolve(identifier: str) -> ModelPricing | None:
        # An ID match anywhere in the list wins over any alias match.
        by_id = next((entry for entry in pool if entry.id == identifier), None)
        if by_id is not None:
            return by_id
        return next((entry for entry in pool if identifier in entry.aliases), None)

    found = _resolve(model_id)
    if found is not None:
        return found

    canonical = normalize_model_id(model_id)
    if canonical == model_id:
        # Normalization changed nothing, so a second pass cannot succeed.
        return None
    return _resolve(canonical)
151
+
152
+
153
def get_all_models() -> list[ModelPricing]:
    """Return a new list holding every provider's models, in cache order."""
    _ensure_loaded()
    return [entry for entries in _PROVIDER_DATA.values() for entry in entries]
160
+
161
+
162
def calculate_cost(
    model: ModelPricing,
    input_tokens: int,
    output_tokens: int,
    cached_input_tokens: int = 0,
    use_batch: bool = False,
    use_fast: bool = False,
    is_long_context: bool | None = None,
) -> float:
    """Calculate total USD cost for a request.

    Rate selection, applied independently to input and output: the fast rate
    when ``use_fast`` is set and the model defines one, otherwise the batch
    rate when ``use_batch`` is set and defined, otherwise the above-200k rate
    when the request is long-context and that rate is defined, otherwise the
    base rate.

    Args:
        model: Pricing record supplying the per-million-token rates.
        input_tokens: Total input tokens, including any cached ones.
        output_tokens: Output tokens.
        cached_input_tokens: Portion of ``input_tokens`` served from cache.
            Clamped to the range ``[0, input_tokens]``.
        use_batch: Prefer batch rates when the model defines them.
        use_fast: Prefer fast rates when the model defines them.
        is_long_context: Force long-context pricing on or off; when ``None``
            it is inferred as ``input_tokens > 200_000``.

    Returns:
        Sum of uncached-input, cached-input and output cost.
    """
    if is_long_context is None:
        is_long_context = input_tokens > 200_000

    # Clamp to [0, input_tokens]: a negative cached count must not inflate the
    # uncached portion, and it cannot exceed the total input.
    effective_cached = max(0, min(cached_input_tokens, input_tokens))
    uncached_input = max(0, input_tokens - effective_cached)

    # Input rate
    if use_fast and model.fast_input_cost_per_million is not None:
        input_rate = model.fast_input_cost_per_million
    elif use_batch and model.batch_input_cost_per_million is not None:
        input_rate = model.batch_input_cost_per_million
    elif is_long_context and model.input_cost_above_200k_per_million is not None:
        input_rate = model.input_cost_above_200k_per_million
    else:
        input_rate = model.input_cost_per_million

    # Output rate
    if use_fast and model.fast_output_cost_per_million is not None:
        output_rate = model.fast_output_cost_per_million
    elif use_batch and model.batch_output_cost_per_million is not None:
        output_rate = model.batch_output_cost_per_million
    elif is_long_context and model.output_cost_above_200k_per_million is not None:
        output_rate = model.output_cost_above_200k_per_million
    else:
        output_rate = model.output_cost_per_million

    # Without a cache-read rate, cached tokens are still input tokens: bill
    # them at the selected input rate rather than treating them as free.
    cached_rate = model.cache_read_input_cost_per_million
    if cached_rate is None:
        cached_rate = input_rate

    input_cost = (uncached_input / 1_000_000) * input_rate
    cached_cost = (effective_cached / 1_000_000) * cached_rate
    output_cost = (output_tokens / 1_000_000) * output_rate

    return input_cost + cached_cost + output_cost
207
+
208
+
209
+ # ── Capability inference ─────────────────────────────────────────────
210
+
211
# Ordered (capability, case-insensitive keyword pattern) pairs scanned by
# infer_required_capabilities(). Every pattern is a whole-word alternation.
_CAPABILITY_KEYWORDS: list[tuple[Capability, re.Pattern[str]]] = [
    (capability, re.compile(keywords, re.IGNORECASE))
    for capability, keywords in (
        (
            Capability.CODE,
            r"\b(code|coding|debug|refactor|function|typescript|python|javascript|programming|syntax|compile|lint|regex|sql|api endpoint)\b",
        ),
        (
            Capability.REASONING,
            r"\b(step[- ]by[- ]step|analyze|reason|think carefully|evaluate|logic|math|proof|calculate|deduce|chain of thought)\b",
        ),
        (
            Capability.CREATIVE,
            r"\b(story|creative|poem|narrative|fiction|blog post|essay|copywriting)\b",
        ),
        (
            Capability.VISION,
            r"\b(image|screenshot|photo|picture|diagram|chart|visual|ocr|describe the image)\b",
        ),
        (
            Capability.SEARCH,
            r"\b(search|find information|look up|latest news|real[- ]time)\b",
        ),
        (
            Capability.AUDIO,
            r"\b(audio|transcribe|speech|voice|listen)\b",
        ),
    )
]
251
+
252
+
253
def infer_required_capabilities(text: str) -> list[Capability]:
    """Infer required capabilities from free-form text.

    Args:
        text: Task description to scan for capability keywords.

    Returns:
        The capabilities whose keyword pattern occurs in ``text``, without
        duplicates and in the order they appear in ``_CAPABILITY_KEYWORDS``,
        so the result is the same on every run. ``[Capability.GENERAL]`` when
        the text is blank or nothing matches.
    """
    if not text.strip():
        return [Capability.GENERAL]
    # dict.fromkeys de-duplicates while keeping table order; going through a
    # set would make the order depend on per-process hash randomisation.
    matched = list(
        dict.fromkeys(cap for cap, pattern in _CAPABILITY_KEYWORDS if pattern.search(text))
    )
    return matched or [Capability.GENERAL]
262
+
263
+
264
def get_models_by_capabilities(
    required: list[Capability],
    provider: Provider | None = None,
    status: ModelStatus = ModelStatus.CURRENT,
    max_output_cost_per_million: float | None = None,
) -> list[ModelPricing]:
    """List models that offer every capability in ``required``.

    Args:
        required: Capabilities a model must all list.
        provider: Restrict the search to this provider; all providers when falsy.
        status: Only models with exactly this status are considered.
        max_output_cost_per_million: When given, models whose output price
            exceeds it are excluded.

    Returns:
        A new list of matching models sorted by ascending output price.
    """
    _ensure_loaded()
    pool = get_provider_models(provider) if provider else get_all_models()

    def _qualifies(entry: ModelPricing) -> bool:
        if entry.status != status:
            return False
        over_budget = (
            max_output_cost_per_million is not None
            and entry.output_cost_per_million > max_output_cost_per_million
        )
        if over_budget:
            return False
        return all(cap in entry.capabilities for cap in required)

    return sorted(
        (entry for entry in pool if _qualifies(entry)),
        key=lambda entry: entry.output_cost_per_million,
    )
286
+
287
+
288
+ # ── Benchmarks ───────────────────────────────────────────────────────
289
+
290
# Benchmark-record key that get_quality_score() reads for each capability.
# Capabilities with no dedicated key here are scored by the "overall" key.
_CAP_TO_BENCHMARK: dict[Capability, str] = {
    Capability.CODE: "coding",
    Capability.REASONING: "reasoning",
    Capability.GENERAL: "overall",
    Capability.CREATIVE: "creative_writing",
    Capability.VISION: "overall",
    Capability.SEARCH: "overall",
    Capability.AUDIO: "overall",
}
299
+
300
+
301
def get_model_benchmarks(provider: Provider, model_id: str) -> dict | None:
    """Return the benchmark record stored for a model, if any.

    Records are keyed ``"<provider.value>/<id>"``. The ID as given is tried
    first, then its normalized form, then the canonical ID of the model that
    ``get_model`` resolves it to; later forms are only computed after the
    earlier ones miss.

    Returns:
        The stored record for the first key that exists, otherwise ``None``.
    """
    _ensure_loaded()

    def _candidate_ids():
        # Lazily yield each spelling so later lookups are skipped on a hit.
        yield model_id
        stripped = normalize_model_id(model_id)
        if stripped != model_id:
            yield stripped
        resolved = get_model(provider, model_id)
        if resolved and resolved.id != model_id:
            yield resolved.id

    for candidate in _candidate_ids():
        lookup = f"{provider.value}/{candidate}"
        if lookup in _BENCHMARKS:
            return _BENCHMARKS[lookup]
    return None
318
+
319
+
320
def get_quality_score(provider: Provider, model_id: str, capability: Capability) -> float | None:
    """Get quality score for a model on a specific capability.

    The capability is mapped to a benchmark key via ``_CAP_TO_BENCHMARK``
    (``"overall"`` for unmapped capabilities). If that key holds no numeric
    value, the record's ``"overall"`` value is used instead.

    Returns:
        The score as a ``float``, or ``None`` when the model has no benchmark
        record or neither key holds a number.
    """
    bench = get_model_benchmarks(provider, model_id)
    if not bench:
        return None
    field = _CAP_TO_BENCHMARK.get(capability, "overall")
    # Run the fallback through the same numeric check as the primary key so
    # the result is always a float or None, never a raw JSON value.
    for key in (field, "overall"):
        score = bench.get(key)
        if isinstance(score, (int, float)):
            return float(score)
    return None
328
+
329
+
330
def get_min_quality_score(
    provider: Provider, model_id: str, capabilities: list[Capability]
) -> float | None:
    """Return the lowest quality score the model has across ``capabilities``.

    Returns:
        The minimum score, or ``None`` when ``capabilities`` is empty or any
        single capability has no score.
    """
    lowest: float | None = None
    for capability in capabilities:
        value = get_quality_score(provider, model_id, capability)
        if value is None:
            # One unknown score makes the overall minimum unknown.
            return None
        if lowest is None or value < lowest:
            lowest = value
    return lowest
341
+
342
+
343
+ # ── Model suggestion ─────────────────────────────────────────────────
344
+
345
# Numeric rank per tier; suggest_alternatives() subtracts a candidate's rank
# from the current model's rank to measure how many tiers a switch drops.
_TIER_ORDER: dict[ModelTier, int] = {ModelTier.BUDGET: 0, ModelTier.MID: 1, ModelTier.PREMIUM: 2}
# Default min_quality_ratio for suggest_alternatives(): a candidate with a known
# score below this fraction of the current model's score is rejected.
_DEFAULT_MIN_QUALITY_RATIO = 0.7
347
+
348
+
349
class AlternativeSuggestion:
    """A suggested alternative model.

    Attributes are plain values set by the constructor:
    ``model`` (the suggested model), ``reasoning`` (human-readable
    explanation), ``savings_percent`` (whole-number percentage),
    ``quality_score`` (the suggestion's score, if known) and
    ``current_quality_score`` (the current model's score, if known).
    """

    def __init__(
        self,
        model: ModelPricing,
        reasoning: str,
        savings_percent: int,
        quality_score: float | None = None,
        current_quality_score: float | None = None,
    ) -> None:
        self.model = model
        self.reasoning = reasoning
        self.savings_percent = savings_percent
        self.quality_score = quality_score
        self.current_quality_score = current_quality_score

    def __repr__(self) -> str:
        # Show field values instead of the default "<... object at 0x...>".
        return (
            f"{type(self).__name__}("
            f"model={self.model!r}, "
            f"reasoning={self.reasoning!r}, "
            f"savings_percent={self.savings_percent!r}, "
            f"quality_score={self.quality_score!r}, "
            f"current_quality_score={self.current_quality_score!r})"
        )
365
+
366
+
367
def suggest_alternatives(
    current_model_id: str,
    current_provider: Provider,
    required_capabilities: list[Capability],
    confidence: str = "medium",
    min_quality_ratio: float = _DEFAULT_MIN_QUALITY_RATIO,
) -> list[AlternativeSuggestion]:
    """Suggest cheaper alternatives that match required capabilities.

    A candidate must support every required capability, have a lower output
    price than the current model, and cost less on the reference workload
    (4096 input / 1024 output tokens) used for the reported savings figure.

    Args:
        current_model_id: ID or alias of the model in use.
        current_provider: Provider of the model in use.
        required_capabilities: Capabilities every suggestion must support.
        confidence: ``"low"`` keeps only same-provider models that do not drop
            a tier; ``"medium"`` allows any provider and at most one tier
            drop; any other value applies no provider or tier restriction.
        min_quality_ratio: When both scores are known, a candidate scoring
            below ``current score * min_quality_ratio`` is rejected.

    Returns:
        Up to three suggestions ordered by quality-adjusted reference cost
        (cost divided by ``score / 100`` when a score is known). Empty when
        the current model cannot be resolved.
    """
    current = get_model(current_provider, current_model_id)
    if not current:
        return []

    current_quality = get_min_quality_score(
        current_provider, current_model_id, required_capabilities
    )

    # Reference workload used for both ranking and the savings percentage.
    ref_input, ref_output = 4096, 1024
    current_cost = calculate_cost(current, ref_input, ref_output)
    current_tier = _TIER_ORDER[current.tier]

    # (model, reference cost, quality) — cost and quality are computed once
    # per candidate and reused for gating, sorting and the final message.
    shortlisted: list[tuple[ModelPricing, float, float | None]] = []
    for m in get_models_by_capabilities(required_capabilities):
        if m.id == current.id and m.provider == current.provider:
            continue
        if not m.output_cost_per_million < current.output_cost_per_million:
            continue

        # Filter by confidence
        tier_drop = current_tier - _TIER_ORDER[m.tier]
        if confidence == "low" and (m.provider != current_provider or tier_drop > 0):
            continue
        if confidence == "medium" and tier_drop > 1:
            continue

        # A lower output price alone does not guarantee a cheaper request: a
        # higher input price can outweigh it, which would otherwise surface as
        # a negative "% lower cost". Keep only real savings.
        candidate_cost = calculate_cost(m, ref_input, ref_output)
        if candidate_cost >= current_cost:
            continue

        # Quality gate
        cq = get_min_quality_score(m.provider, m.id, required_capabilities)
        if (
            current_quality is not None
            and cq is not None
            and cq < current_quality * min_quality_ratio
        ):
            continue
        shortlisted.append((m, candidate_cost, cq))

    def sort_key(item: tuple[ModelPricing, float, float | None]) -> float:
        # Quality-adjusted cost; an unknown or zero score falls back to raw cost.
        _, cost, q = item
        return cost / (q / 100) if q else cost

    shortlisted.sort(key=sort_key)

    caps = f"[{', '.join(c.value for c in required_capabilities)}]"
    results: list[AlternativeSuggestion] = []
    for m, candidate_cost, cq in shortlisted[:3]:
        savings = (
            round(((current_cost - candidate_cost) / current_cost) * 100) if current_cost > 0 else 0
        )
        quality_suffix = (
            f" (quality: {cq} vs {current_quality})"
            if cq is not None and current_quality is not None
            else ""
        )
        if m.provider == current_provider:
            reasoning = f"Task requires {caps} — {m.name} handles that at {savings}% lower cost{quality_suffix}"
        else:
            reasoning = f"Task requires {caps} — {m.provider.value}/{m.name} handles that at {savings}% lower cost{quality_suffix}"
        results.append(
            AlternativeSuggestion(
                model=m,
                reasoning=reasoning,
                savings_percent=savings,
                quality_score=cq,
                current_quality_score=current_quality,
            )
        )
    return results
447
+
448
+
449
class TaskSuggestion:
    """Result of a task-based model suggestion.

    Holds the chosen ``model``, the ``inferred_capabilities`` that drove the
    choice, and a human-readable ``reasoning`` string.
    """

    def __init__(
        self, model: ModelPricing, inferred_capabilities: list[Capability], reasoning: str
    ) -> None:
        self.model = model
        self.inferred_capabilities = inferred_capabilities
        self.reasoning = reasoning

    def __repr__(self) -> str:
        # Show field values instead of the default "<... object at 0x...>".
        return (
            f"{type(self).__name__}("
            f"model={self.model!r}, "
            f"inferred_capabilities={self.inferred_capabilities!r}, "
            f"reasoning={self.reasoning!r})"
        )
458
+
459
+
460
def suggest_model_for_task(
    text: str,
    provider: Provider | None = None,
    max_cost_per_million: float | None = None,
) -> TaskSuggestion | None:
    """Pick the best-value model for a free-form task description.

    Capabilities are inferred from ``text``; among the models supporting all
    of them, the one with the lowest output price divided by ``score / 100``
    wins (raw output price when no score is available).

    Args:
        text: Free-form description of the task.
        provider: Restrict candidates to this provider when given.
        max_cost_per_million: Upper bound on a candidate's output price.

    Returns:
        A ``TaskSuggestion`` for the winning model, or ``None`` when no model
        qualifies.
    """
    needed = infer_required_capabilities(text)
    pool = get_models_by_capabilities(
        needed,
        provider=provider,
        max_output_cost_per_million=max_cost_per_million,
    )
    if not pool:
        return None

    def _value_rank(entry: ModelPricing) -> float:
        score = get_min_quality_score(entry.provider, entry.id, needed)
        price = entry.output_cost_per_million
        if score:
            return price / (score / 100)
        return price

    ranked = sorted(pool, key=_value_rank)
    winner = ranked[0]

    capability_label = f"[{', '.join(c.value for c in needed)}]"
    winner_quality = get_min_quality_score(winner.provider, winner.id, needed)
    if winner_quality is not None:
        quality_note = f" — quality: {winner_quality}/100"
    else:
        quality_note = ""
    explanation = f"Inferred capabilities: {capability_label} — {winner.provider.value}/{winner.name} (${winner.output_cost_per_million}/M output) is the best value model{quality_note}"
    return TaskSuggestion(model=winner, inferred_capabilities=needed, reasoning=explanation)