themis-eval 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132)
  1. themis/cli/__init__.py +5 -0
  2. themis/cli/__main__.py +6 -0
  3. themis/cli/commands/__init__.py +19 -0
  4. themis/cli/commands/benchmarks.py +221 -0
  5. themis/cli/commands/comparison.py +394 -0
  6. themis/cli/commands/config_commands.py +244 -0
  7. themis/cli/commands/cost.py +214 -0
  8. themis/cli/commands/demo.py +68 -0
  9. themis/cli/commands/info.py +90 -0
  10. themis/cli/commands/leaderboard.py +362 -0
  11. themis/cli/commands/math_benchmarks.py +318 -0
  12. themis/cli/commands/mcq_benchmarks.py +207 -0
  13. themis/cli/commands/sample_run.py +244 -0
  14. themis/cli/commands/visualize.py +299 -0
  15. themis/cli/main.py +93 -0
  16. themis/cli/new_project.py +33 -0
  17. themis/cli/utils.py +51 -0
  18. themis/config/__init__.py +19 -0
  19. themis/config/loader.py +27 -0
  20. themis/config/registry.py +34 -0
  21. themis/config/runtime.py +214 -0
  22. themis/config/schema.py +112 -0
  23. themis/core/__init__.py +5 -0
  24. themis/core/conversation.py +354 -0
  25. themis/core/entities.py +164 -0
  26. themis/core/serialization.py +231 -0
  27. themis/core/tools.py +393 -0
  28. themis/core/types.py +141 -0
  29. themis/datasets/__init__.py +273 -0
  30. themis/datasets/base.py +264 -0
  31. themis/datasets/commonsense_qa.py +174 -0
  32. themis/datasets/competition_math.py +265 -0
  33. themis/datasets/coqa.py +133 -0
  34. themis/datasets/gpqa.py +190 -0
  35. themis/datasets/gsm8k.py +123 -0
  36. themis/datasets/gsm_symbolic.py +124 -0
  37. themis/datasets/math500.py +122 -0
  38. themis/datasets/med_qa.py +179 -0
  39. themis/datasets/medmcqa.py +169 -0
  40. themis/datasets/mmlu_pro.py +262 -0
  41. themis/datasets/piqa.py +146 -0
  42. themis/datasets/registry.py +201 -0
  43. themis/datasets/schema.py +245 -0
  44. themis/datasets/sciq.py +150 -0
  45. themis/datasets/social_i_qa.py +151 -0
  46. themis/datasets/super_gpqa.py +263 -0
  47. themis/evaluation/__init__.py +1 -0
  48. themis/evaluation/conditional.py +410 -0
  49. themis/evaluation/extractors/__init__.py +19 -0
  50. themis/evaluation/extractors/error_taxonomy_extractor.py +80 -0
  51. themis/evaluation/extractors/exceptions.py +7 -0
  52. themis/evaluation/extractors/identity_extractor.py +29 -0
  53. themis/evaluation/extractors/json_field_extractor.py +45 -0
  54. themis/evaluation/extractors/math_verify_extractor.py +37 -0
  55. themis/evaluation/extractors/regex_extractor.py +43 -0
  56. themis/evaluation/math_verify_utils.py +87 -0
  57. themis/evaluation/metrics/__init__.py +21 -0
  58. themis/evaluation/metrics/composite_metric.py +47 -0
  59. themis/evaluation/metrics/consistency_metric.py +80 -0
  60. themis/evaluation/metrics/exact_match.py +51 -0
  61. themis/evaluation/metrics/length_difference_tolerance.py +33 -0
  62. themis/evaluation/metrics/math_verify_accuracy.py +40 -0
  63. themis/evaluation/metrics/pairwise_judge_metric.py +141 -0
  64. themis/evaluation/metrics/response_length.py +33 -0
  65. themis/evaluation/metrics/rubric_judge_metric.py +134 -0
  66. themis/evaluation/pipeline.py +49 -0
  67. themis/evaluation/pipelines/__init__.py +15 -0
  68. themis/evaluation/pipelines/composable_pipeline.py +357 -0
  69. themis/evaluation/pipelines/standard_pipeline.py +288 -0
  70. themis/evaluation/reports.py +293 -0
  71. themis/evaluation/statistics/__init__.py +53 -0
  72. themis/evaluation/statistics/bootstrap.py +79 -0
  73. themis/evaluation/statistics/confidence_intervals.py +121 -0
  74. themis/evaluation/statistics/distributions.py +207 -0
  75. themis/evaluation/statistics/effect_sizes.py +124 -0
  76. themis/evaluation/statistics/hypothesis_tests.py +305 -0
  77. themis/evaluation/statistics/types.py +139 -0
  78. themis/evaluation/strategies/__init__.py +13 -0
  79. themis/evaluation/strategies/attempt_aware_evaluation_strategy.py +51 -0
  80. themis/evaluation/strategies/default_evaluation_strategy.py +25 -0
  81. themis/evaluation/strategies/evaluation_strategy.py +24 -0
  82. themis/evaluation/strategies/judge_evaluation_strategy.py +64 -0
  83. themis/experiment/__init__.py +5 -0
  84. themis/experiment/builder.py +151 -0
  85. themis/experiment/cache_manager.py +129 -0
  86. themis/experiment/comparison.py +631 -0
  87. themis/experiment/cost.py +310 -0
  88. themis/experiment/definitions.py +62 -0
  89. themis/experiment/export.py +690 -0
  90. themis/experiment/export_csv.py +159 -0
  91. themis/experiment/integration_manager.py +104 -0
  92. themis/experiment/math.py +192 -0
  93. themis/experiment/mcq.py +169 -0
  94. themis/experiment/orchestrator.py +373 -0
  95. themis/experiment/pricing.py +317 -0
  96. themis/experiment/storage.py +255 -0
  97. themis/experiment/visualization.py +588 -0
  98. themis/generation/__init__.py +1 -0
  99. themis/generation/agentic_runner.py +420 -0
  100. themis/generation/batching.py +254 -0
  101. themis/generation/clients.py +143 -0
  102. themis/generation/conversation_runner.py +236 -0
  103. themis/generation/plan.py +456 -0
  104. themis/generation/providers/litellm_provider.py +221 -0
  105. themis/generation/providers/vllm_provider.py +135 -0
  106. themis/generation/router.py +34 -0
  107. themis/generation/runner.py +207 -0
  108. themis/generation/strategies.py +98 -0
  109. themis/generation/templates.py +71 -0
  110. themis/generation/turn_strategies.py +393 -0
  111. themis/generation/types.py +9 -0
  112. themis/integrations/__init__.py +0 -0
  113. themis/integrations/huggingface.py +61 -0
  114. themis/integrations/wandb.py +65 -0
  115. themis/interfaces/__init__.py +83 -0
  116. themis/project/__init__.py +20 -0
  117. themis/project/definitions.py +98 -0
  118. themis/project/patterns.py +230 -0
  119. themis/providers/__init__.py +5 -0
  120. themis/providers/registry.py +39 -0
  121. themis/utils/api_generator.py +379 -0
  122. themis/utils/cost_tracking.py +376 -0
  123. themis/utils/dashboard.py +452 -0
  124. themis/utils/logging_utils.py +41 -0
  125. themis/utils/progress.py +58 -0
  126. themis/utils/tracing.py +320 -0
  127. {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/METADATA +1 -1
  128. themis_eval-0.1.1.dist-info/RECORD +134 -0
  129. themis_eval-0.1.0.dist-info/RECORD +0 -8
  130. {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/WHEEL +0 -0
  131. {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/licenses/LICENSE +0 -0
  132. {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/top_level.txt +0 -0
themis/experiment/orchestrator.py
@@ -0,0 +1,373 @@
+"""Experiment orchestrator primitives."""
+
+from __future__ import annotations
+
+from datetime import datetime, timezone
+from typing import Callable, Sequence
+
+from themis.config.schema import IntegrationsConfig
+from themis.core.entities import (
+    EvaluationRecord,
+    ExperimentFailure,
+    ExperimentReport,
+    GenerationRecord,
+    GenerationTask,
+    MetricScore,
+)
+from themis.evaluation import pipeline as evaluation_pipeline
+from themis.evaluation.reports import EvaluationFailure
+from themis.experiment import storage as experiment_storage
+from themis.experiment.cache_manager import CacheManager
+from themis.experiment.cost import CostTracker
+from themis.experiment.integration_manager import IntegrationManager
+from themis.experiment.pricing import calculate_cost, get_provider_pricing
+from themis.generation import plan as generation_plan
+from themis.generation import runner as generation_runner
+
+
+class ExperimentOrchestrator:
+    """Orchestrates experiment execution: generation → evaluation → reporting.
+
+    This class coordinates the experiment workflow using focused managers:
+    - CacheManager: Handles storage and resumability
+    - IntegrationManager: Handles WandB and HuggingFace Hub
+
+    Single Responsibility: Orchestration of experiment flow
+    """
+
+    def __init__(
+        self,
+        *,
+        generation_plan: generation_plan.GenerationPlan,
+        generation_runner: generation_runner.GenerationRunner,
+        evaluation_pipeline: evaluation_pipeline.EvaluationPipeline,
+        storage: experiment_storage.ExperimentStorage | None = None,
+        integrations_config: IntegrationsConfig | None = None,
+        cache_manager: CacheManager | None = None,
+        integration_manager: IntegrationManager | None = None,
+    ) -> None:
+        """Initialize experiment orchestrator.
+
+        Args:
+            generation_plan: Plan for expanding dataset into tasks
+            generation_runner: Runner for executing generation tasks
+            evaluation_pipeline: Pipeline for evaluating outputs
+            storage: Optional storage backend (deprecated, use cache_manager)
+            integrations_config: Integration config (deprecated, use integration_manager)
+            cache_manager: Manager for caching and resumability
+            integration_manager: Manager for external integrations
+        """
+        self._plan = generation_plan
+        self._runner = generation_runner
+        self._evaluation = evaluation_pipeline
+
+        # Support both new managers and legacy direct parameters for backward compatibility
+        self._cache = cache_manager or CacheManager(
+            storage=storage,
+            enable_resume=True,
+            enable_cache=True,
+        )
+        self._integrations = integration_manager or IntegrationManager(
+            config=integrations_config or IntegrationsConfig()
+        )
+
+        # Initialize cost tracker
+        self._cost_tracker = CostTracker()
+
+        # Keep legacy references for backward compatibility
+        self._storage = storage
+
+    def run(
+        self,
+        dataset: Sequence[dict[str, object]] | None = None,
+        *,
+        dataset_loader: Callable[[], Sequence[dict[str, object]]] | None = None,
+        max_samples: int | None = None,
+        run_id: str | None = None,
+        resume: bool = True,
+        cache_results: bool = True,
+        on_result: Callable[[GenerationRecord], None] | None = None,
+    ) -> ExperimentReport:
+        """Run experiment: generate responses, evaluate, and report results.
+
+        Args:
+            dataset: Optional dataset samples to use
+            dataset_loader: Optional callable to load dataset
+            max_samples: Optional limit on number of samples
+            run_id: Optional run identifier for caching
+            resume: Whether to resume from cached results
+            cache_results: Whether to cache new results
+            on_result: Optional callback for each generation result
+
+        Returns:
+            ExperimentReport with generation results, evaluation, and metadata
+        """
+        # Initialize integrations
+        self._integrations.initialize_run(
+            {
+                "max_samples": max_samples,
+                "run_id": run_id,
+                "resume": resume,
+            }
+        )
+
+        # Prepare dataset
+        dataset_list = self._resolve_dataset(
+            dataset=dataset, dataset_loader=dataset_loader, run_id=run_id
+        )
+        selected_dataset = (
+            dataset_list[:max_samples] if max_samples is not None else dataset_list
+        )
+        run_identifier = run_id or self._default_run_id()
+
+        # Cache dataset for resumability
+        if dataset_list:
+            self._cache.cache_dataset(run_identifier, dataset_list)
+
+        # Expand dataset into generation tasks
+        tasks = list(self._plan.expand(selected_dataset))
+
+        # Load cached results if resuming
+        cached_records = (
+            self._cache.load_cached_records(run_identifier) if resume else {}
+        )
+        cached_evaluations = (
+            self._cache.load_cached_evaluations(run_identifier) if resume else {}
+        )
+
+        # Process tasks: use cached or run new generations
+        generation_results: list[GenerationRecord] = []
+        failures: list[ExperimentFailure] = []
+        pending_tasks: list[GenerationTask] = []
+        pending_records: list[GenerationRecord] = []
+        pending_keys: list[str] = []
+        cached_eval_records: list[EvaluationRecord] = []
+
+        for task in tasks:
+            cache_key = experiment_storage.task_cache_key(task)
+            cached = cached_records.get(cache_key)
+            if cached is not None:
+                generation_results.append(cached)
+                if cached.error:
+                    failures.append(
+                        ExperimentFailure(
+                            sample_id=cached.task.metadata.get("dataset_id"),
+                            message=cached.error.message,
+                        )
+                    )
+                evaluation = cached_evaluations.get(cache_key)
+                if evaluation is not None:
+                    cached_eval_records.append(evaluation)
+                else:
+                    pending_records.append(cached)
+                    pending_keys.append(cache_key)
+                if on_result:
+                    on_result(cached)
+            else:
+                pending_tasks.append(task)
+
+        # Run pending generation tasks
+        if pending_tasks:
+            for record in self._runner.run(pending_tasks):
+                generation_results.append(record)
+
+                # Track cost for successful generations
+                if record.output and record.output.usage:
+                    usage = record.output.usage
+                    prompt_tokens = usage.get("prompt_tokens", 0)
+                    completion_tokens = usage.get("completion_tokens", 0)
+                    model = record.task.model.identifier
+
+                    # Calculate cost using pricing database
+                    cost = calculate_cost(model, prompt_tokens, completion_tokens)
+                    self._cost_tracker.record_generation(
+                        model=model,
+                        prompt_tokens=prompt_tokens,
+                        completion_tokens=completion_tokens,
+                        cost=cost,
+                    )
+
+                if record.error:
+                    failures.append(
+                        ExperimentFailure(
+                            sample_id=record.task.metadata.get("dataset_id"),
+                            message=record.error.message,
+                        )
+                    )
+                cache_key = experiment_storage.task_cache_key(record.task)
+                if cache_results:
+                    self._cache.save_generation_record(
+                        run_identifier, record, cache_key
+                    )
+                pending_records.append(record)
+                pending_keys.append(cache_key)
+                if on_result:
+                    on_result(record)
+
+        # Evaluate pending records
+        if pending_records:
+            new_evaluation_report = self._evaluation.evaluate(pending_records)
+        else:
+            new_evaluation_report = evaluation_pipeline.EvaluationReport(
+                metrics={}, failures=[], records=[]
+            )
+
+        # Cache evaluation results
+        for record, evaluation in zip(pending_records, new_evaluation_report.records):
+            self._cache.save_evaluation_record(run_identifier, record, evaluation)
+
+        # Combine cached and new evaluations
+        evaluation_report = self._combine_evaluations(
+            cached_eval_records, new_evaluation_report
+        )
+
+        # Get cost breakdown
+        cost_breakdown = self._cost_tracker.get_breakdown()
+
+        # Build metadata
+        metadata = {
+            "total_samples": len(selected_dataset),
+            "successful_generations": sum(
+                1 for result in generation_results if not result.error
+            ),
+            "failed_generations": sum(
+                1 for result in generation_results if result.error
+            ),
+            "run_id": run_identifier,
+            "evaluation_failures": sum(
+                1 for record in evaluation_report.records if record.failures
+            )
+            + len(evaluation_report.failures),
+            # Cost tracking
+            "cost": {
+                "total_cost": cost_breakdown.total_cost,
+                "generation_cost": cost_breakdown.generation_cost,
+                "evaluation_cost": cost_breakdown.evaluation_cost,
+                "currency": cost_breakdown.currency,
+                "token_counts": cost_breakdown.token_counts,
+                "api_calls": cost_breakdown.api_calls,
+                "per_model_costs": cost_breakdown.per_model_costs,
+            },
+        }
+
+        # Create final report
+        report = ExperimentReport(
+            generation_results=generation_results,
+            evaluation_report=evaluation_report,
+            failures=failures,
+            metadata=metadata,
+        )
+
+        # Log to integrations
+        self._integrations.log_results(report)
+
+        # Upload to HuggingFace Hub if enabled
+        run_path = self._cache.get_run_path(run_identifier)
+        self._integrations.upload_results(report, run_path)
+
+        # Save report.json for multi-experiment comparison
+        if cache_results:
+            self._save_report_json(report, run_identifier)
+
+        return report
+
+    def _default_run_id(self) -> str:
+        return datetime.now(timezone.utc).strftime("run-%Y%m%d-%H%M%S")
+
+    def _resolve_dataset(
+        self,
+        *,
+        dataset: Sequence[dict[str, object]] | None,
+        dataset_loader: Callable[[], Sequence[dict[str, object]]] | None,
+        run_id: str | None,
+    ) -> list[dict[str, object]]:
+        """Resolve dataset from various sources.
+
+        Args:
+            dataset: Direct dataset samples
+            dataset_loader: Callable to load dataset
+            run_id: Run ID to load cached dataset
+
+        Returns:
+            List of dataset samples
+
+        Raises:
+            ValueError: If no dataset source is available
+        """
+        if dataset is not None:
+            return list(dataset)
+        if dataset_loader is not None:
+            return list(dataset_loader())
+        # Try to load from cache (for backward compatibility, still use _storage directly)
+        if self._storage is not None and run_id is not None:
+            return self._storage.load_dataset(run_id)
+        raise ValueError(
+            "No dataset provided. Supply `dataset=` rows, a `dataset_loader`, "
+            "or set `run_id` with storage configured so cached data can be reloaded."
+        )
+
+    def _combine_evaluations(
+        self,
+        cached_records: list[EvaluationRecord],
+        new_report: evaluation_pipeline.EvaluationReport,
+    ) -> evaluation_pipeline.EvaluationReport:
+        all_records = list(cached_records) + list(new_report.records)
+        per_metric: dict[str, list[MetricScore]] = {}
+        for record in all_records:
+            for score in record.scores:
+                per_metric.setdefault(score.metric_name, []).append(score)
+
+        aggregates: dict[str, evaluation_pipeline.MetricAggregate] = {}
+        metric_names = set(per_metric.keys()) | set(new_report.metrics.keys())
+        for name in metric_names:
+            scores = per_metric.get(name, [])
+            mean = sum(score.value for score in scores) / len(scores) if scores else 0.0
+            aggregates[name] = evaluation_pipeline.MetricAggregate(
+                name=name,
+                count=len(scores),
+                mean=mean,
+                per_sample=scores,
+            )
+
+        failures = list(new_report.failures)
+        for record in cached_records:
+            for message in record.failures:
+                failures.append(
+                    EvaluationFailure(sample_id=record.sample_id, message=message)
+                )
+
+        return evaluation_pipeline.EvaluationReport(
+            metrics=aggregates,
+            failures=failures,
+            records=all_records,
+        )
+
+    def _save_report_json(self, report: ExperimentReport, run_id: str) -> None:
+        """Save experiment report as JSON for multi-experiment comparison.
+
+        Args:
+            report: Experiment report to save
+            run_id: Run identifier
+        """
+        from pathlib import Path
+
+        from themis.experiment.export import build_json_report
+
+        # Get run path from cache manager
+        run_path_str = self._cache.get_run_path(run_id)
+        if run_path_str is None:
+            # No storage configured, skip saving report.json
+            return
+
+        run_path = Path(run_path_str)
+        report_path = run_path / "report.json"
+
+        # Build JSON report
+        json_data = build_json_report(report, title=f"Experiment {run_id}")
+
+        # Save to file
+        import json
+
+        report_path.parent.mkdir(parents=True, exist_ok=True)
+        with report_path.open("w", encoding="utf-8") as f:
+            json.dump(json_data, f, indent=2)
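
For orientation, here is a minimal usage sketch for the orchestrator added above. It assumes `plan`, `runner`, and `pipeline` objects have already been built from the generation and evaluation modules (their construction is not part of this file), and that `ExperimentReport` exposes the `metadata` mapping as an attribute; treat the names below as illustrative, not as a verbatim recipe.

    from themis.experiment.orchestrator import ExperimentOrchestrator

    # plan, runner, and pipeline are assumed to come from themis.generation.plan,
    # themis.generation.runner, and themis.evaluation.pipeline respectively;
    # how they are configured is outside this diff.
    orchestrator = ExperimentOrchestrator(
        generation_plan=plan,
        generation_runner=runner,
        evaluation_pipeline=pipeline,
    )

    # run() expands the dataset into tasks, reuses cached generations when
    # resume=True, evaluates any pending records, and returns an ExperimentReport.
    report = orchestrator.run(
        dataset=[{"question": "2 + 2 = ?", "answer": "4"}],
        run_id="demo-run",
        resume=True,
        cache_results=True,
    )
    print(report.metadata["run_id"], report.metadata["cost"]["total_cost"])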
themis/experiment/pricing.py
@@ -0,0 +1,317 @@
+"""Provider pricing database and cost calculation utilities."""
+
+from __future__ import annotations
+
+from typing import Any
+
+# Pricing table for common LLM providers (prices per token in USD)
+# Updated as of November 2024
+PRICING_TABLE: dict[str, dict[str, float]] = {
+    # OpenAI models
+    "gpt-4": {
+        "prompt_tokens": 0.00003,  # $30 per 1M tokens
+        "completion_tokens": 0.00006,  # $60 per 1M tokens
+    },
+    "gpt-4-32k": {
+        "prompt_tokens": 0.00006,
+        "completion_tokens": 0.00012,
+    },
+    "gpt-4-turbo": {
+        "prompt_tokens": 0.00001,  # $10 per 1M tokens
+        "completion_tokens": 0.00003,  # $30 per 1M tokens
+    },
+    "gpt-4-turbo-preview": {
+        "prompt_tokens": 0.00001,
+        "completion_tokens": 0.00003,
+    },
+    "gpt-3.5-turbo": {
+        "prompt_tokens": 0.0000005,  # $0.50 per 1M tokens
+        "completion_tokens": 0.0000015,  # $1.50 per 1M tokens
+    },
+    "gpt-3.5-turbo-16k": {
+        "prompt_tokens": 0.000003,
+        "completion_tokens": 0.000004,
+    },
+    # Anthropic Claude models
+    "claude-3-5-sonnet-20241022": {
+        "prompt_tokens": 0.000003,  # $3 per 1M tokens
+        "completion_tokens": 0.000015,  # $15 per 1M tokens
+    },
+    "claude-3-opus-20240229": {
+        "prompt_tokens": 0.000015,  # $15 per 1M tokens
+        "completion_tokens": 0.000075,  # $75 per 1M tokens
+    },
+    "claude-3-sonnet-20240229": {
+        "prompt_tokens": 0.000003,
+        "completion_tokens": 0.000015,
+    },
+    "claude-3-haiku-20240307": {
+        "prompt_tokens": 0.00000025,  # $0.25 per 1M tokens
+        "completion_tokens": 0.00000125,  # $1.25 per 1M tokens
+    },
+    # Google models
+    "gemini-pro": {
+        "prompt_tokens": 0.00000025,
+        "completion_tokens": 0.0000005,
+    },
+    "gemini-1.5-pro": {
+        "prompt_tokens": 0.00000125,  # $1.25 per 1M tokens
+        "completion_tokens": 0.000005,  # $5 per 1M tokens
+    },
+    "gemini-1.5-flash": {
+        "prompt_tokens": 0.000000075,  # $0.075 per 1M tokens
+        "completion_tokens": 0.0000003,  # $0.30 per 1M tokens
+    },
+    # Mistral models
+    "mistral-large-latest": {
+        "prompt_tokens": 0.000002,  # $2 per 1M tokens
+        "completion_tokens": 0.000006,  # $6 per 1M tokens
+    },
+    "mistral-medium-latest": {
+        "prompt_tokens": 0.0000027,
+        "completion_tokens": 0.0000081,
+    },
+    "mistral-small-latest": {
+        "prompt_tokens": 0.000001,
+        "completion_tokens": 0.000003,
+    },
+    # Cohere models
+    "command-r-plus": {
+        "prompt_tokens": 0.000003,
+        "completion_tokens": 0.000015,
+    },
+    "command-r": {
+        "prompt_tokens": 0.0000005,
+        "completion_tokens": 0.0000015,
+    },
+    # Meta Llama (via various providers - using typical cloud pricing)
+    "llama-3.1-70b": {
+        "prompt_tokens": 0.00000088,
+        "completion_tokens": 0.00000088,
+    },
+    "llama-3.1-8b": {
+        "prompt_tokens": 0.0000002,
+        "completion_tokens": 0.0000002,
+    },
+    # Default fallback for unknown models
+    "default": {
+        "prompt_tokens": 0.000001,
+        "completion_tokens": 0.000002,
+    },
+}
+
+# Model aliases and variations
+MODEL_ALIASES: dict[str, str] = {
+    # OpenAI aliases
+    "gpt-4-0613": "gpt-4",
+    "gpt-4-0314": "gpt-4",
+    "gpt-4-1106-preview": "gpt-4-turbo-preview",
+    "gpt-4-0125-preview": "gpt-4-turbo-preview",
+    "gpt-3.5-turbo-0613": "gpt-3.5-turbo",
+    "gpt-3.5-turbo-0301": "gpt-3.5-turbo",
+    "gpt-3.5-turbo-1106": "gpt-3.5-turbo",
+    # Anthropic aliases
+    "claude-3-opus": "claude-3-opus-20240229",
+    "claude-3-sonnet": "claude-3-sonnet-20240229",
+    "claude-3-haiku": "claude-3-haiku-20240307",
+    "claude-3.5-sonnet": "claude-3-5-sonnet-20241022",
+    # Google aliases
+    "gemini-pro-1.0": "gemini-pro",
+    "gemini-1.5-pro-latest": "gemini-1.5-pro",
+    "gemini-1.5-flash-latest": "gemini-1.5-flash",
+}
+
+
+def normalize_model_name(model: str) -> str:
+    """Normalize model name to canonical form.
+
+    Args:
+        model: Model identifier (may include provider prefix)
+
+    Returns:
+        Normalized model name
+
+    Example:
+        >>> normalize_model_name("openai/gpt-4-0613")
+        'gpt-4'
+        >>> normalize_model_name("claude-3-opus")
+        'claude-3-opus-20240229'
+    """
+    # Remove provider prefix if present (e.g., "openai/gpt-4" -> "gpt-4")
+    if "/" in model:
+        model = model.split("/", 1)[1]
+
+    # Look up alias
+    model = MODEL_ALIASES.get(model, model)
+
+    return model
+
+
+def get_provider_pricing(model: str) -> dict[str, float]:
+    """Get pricing for a model.
+
+    Args:
+        model: Model identifier
+
+    Returns:
+        Dict with 'prompt_tokens' and 'completion_tokens' prices per token
+
+    Example:
+        >>> pricing = get_provider_pricing("gpt-4")
+        >>> print(f"Prompt: ${pricing['prompt_tokens'] * 1_000_000:.2f}/1M tokens")
+        Prompt: $30.00/1M tokens
+    """
+    normalized = normalize_model_name(model)
+
+    # Check if we have pricing for this model
+    if normalized in PRICING_TABLE:
+        return PRICING_TABLE[normalized].copy()
+
+    # Try to find a partial match (e.g., "gpt-4-turbo-2024-04-09" matches "gpt-4-turbo")
+    for known_model in PRICING_TABLE:
+        if known_model in normalized or normalized.startswith(known_model):
+            return PRICING_TABLE[known_model].copy()
+
+    # Fallback to default pricing
+    return PRICING_TABLE["default"].copy()
+
+
+def calculate_cost(
+    model: str,
+    prompt_tokens: int,
+    completion_tokens: int,
+    pricing: dict[str, float] | None = None,
+) -> float:
+    """Calculate cost for a model completion.
+
+    Args:
+        model: Model identifier
+        prompt_tokens: Number of prompt tokens
+        completion_tokens: Number of completion tokens
+        pricing: Optional custom pricing (if None, uses default pricing table)
+
+    Returns:
+        Total cost in USD
+
+    Example:
+        >>> cost = calculate_cost("gpt-4", 1000, 500)
+        >>> print(f"Cost: ${cost:.4f}")
+        Cost: $0.0600
+    """
+    if pricing is None:
+        pricing = get_provider_pricing(model)
+
+    prompt_cost = prompt_tokens * pricing["prompt_tokens"]
+    completion_cost = completion_tokens * pricing["completion_tokens"]
+
+    return prompt_cost + completion_cost
+
+
+def compare_provider_costs(
+    prompt_tokens: int,
+    completion_tokens: int,
+    models: list[str],
+) -> dict[str, float]:
+    """Compare costs across multiple providers for same workload.
+
+    Args:
+        prompt_tokens: Number of prompt tokens
+        completion_tokens: Number of completion tokens
+        models: List of model identifiers to compare
+
+    Returns:
+        Dict mapping model names to costs
+
+    Example:
+        >>> costs = compare_provider_costs(
+        ...     1000, 500, ["gpt-4", "gpt-3.5-turbo", "claude-3-haiku"]
+        ... )
+        >>> for model, cost in sorted(costs.items(), key=lambda x: x[1]):
+        ...     print(f"{model}: ${cost:.4f}")
+        claude-3-haiku: $0.0009
+        gpt-3.5-turbo: $0.0013
+        gpt-4: $0.0600
+    """
+    costs = {}
+    for model in models:
+        costs[model] = calculate_cost(model, prompt_tokens, completion_tokens)
+    return costs
+
+
+def estimate_tokens(text: str, chars_per_token: float = 4.0) -> int:
+    """Estimate number of tokens from text.
+
+    This is a rough approximation. For accurate token counts,
+    use the model's tokenizer.
+
+    Args:
+        text: Input text
+        chars_per_token: Average characters per token (default: 4.0)
+
+    Returns:
+        Estimated token count
+
+    Example:
+        >>> text = "This is a sample text for token estimation."
+        >>> tokens = estimate_tokens(text)
+        >>> print(f"Estimated tokens: {tokens}")
+        Estimated tokens: 11
+    """
+    if not text:
+        return 0
+    return max(1, int(len(text) / chars_per_token))
+
+
+def get_all_models() -> list[str]:
+    """Get list of all models with known pricing.
+
+    Returns:
+        List of model identifiers
+    """
+    return [k for k in PRICING_TABLE.keys() if k != "default"]
+
+
+def get_pricing_summary() -> dict[str, Any]:
+    """Get summary of pricing for all models.
+
+    Returns:
+        Dict with model pricing information
+
+    Example:
+        >>> summary = get_pricing_summary()
+        >>> print(f"Total models: {summary['total_models']}")
+        >>> print(f"Cheapest: {summary['cheapest_model']}")
+    """
+    models = get_all_models()
+
+    # Find cheapest and most expensive (based on prompt + completion average)
+    model_avg_costs = {}
+    for model in models:
+        pricing = PRICING_TABLE[model]
+        avg_cost = (pricing["prompt_tokens"] + pricing["completion_tokens"]) / 2
+        model_avg_costs[model] = avg_cost
+
+    cheapest = min(model_avg_costs.items(), key=lambda x: x[1])
+    most_expensive = max(model_avg_costs.items(), key=lambda x: x[1])
+
+    return {
+        "total_models": len(models),
+        "cheapest_model": cheapest[0],
+        "cheapest_avg_cost_per_token": cheapest[1],
+        "most_expensive_model": most_expensive[0],
+        "most_expensive_avg_cost_per_token": most_expensive[1],
+        "models": models,
+    }
+
+
+__all__ = [
+    "PRICING_TABLE",
+    "MODEL_ALIASES",
+    "normalize_model_name",
+    "get_provider_pricing",
+    "calculate_cost",
+    "compare_provider_costs",
+    "estimate_tokens",
+    "get_all_models",
+    "get_pricing_summary",
+]
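
As a quick illustration of how these helpers combine for pre-run budgeting, the sketch below uses only functions defined in this file; the prompt text and completion budget are made-up values, and the printed costs follow directly from PRICING_TABLE above.

    from themis.experiment.pricing import (
        calculate_cost,
        compare_provider_costs,
        estimate_tokens,
    )

    # Rough token estimate for a prompt (4 characters per token heuristic);
    # a real tokenizer will give different counts.
    prompt = "Summarize the following contract in three bullet points."
    prompt_tokens = estimate_tokens(prompt)
    completion_budget = 300  # assumed completion length, not measured

    cost = calculate_cost("gpt-4", prompt_tokens, completion_budget)
    print(f"Single gpt-4 call: ~${cost:.4f}")

    # Same workload priced across several models; aliases such as "claude-3-haiku"
    # resolve through MODEL_ALIASES, and unknown names fall back to "default".
    costs = compare_provider_costs(1000, 500, ["gpt-4", "gpt-3.5-turbo", "claude-3-haiku"])
    for model, model_cost in sorted(costs.items(), key=lambda item: item[1]):
        print(f"{model}: ${model_cost:.4f}")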