arize-phoenix 10.0.4__py3-none-any.whl → 12.28.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/METADATA +124 -72
- arize_phoenix-12.28.1.dist-info/RECORD +499 -0
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/WHEEL +1 -1
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/IP_NOTICE +1 -1
- phoenix/__generated__/__init__.py +0 -0
- phoenix/__generated__/classification_evaluator_configs/__init__.py +20 -0
- phoenix/__generated__/classification_evaluator_configs/_document_relevance_classification_evaluator_config.py +17 -0
- phoenix/__generated__/classification_evaluator_configs/_hallucination_classification_evaluator_config.py +17 -0
- phoenix/__generated__/classification_evaluator_configs/_models.py +18 -0
- phoenix/__generated__/classification_evaluator_configs/_tool_selection_classification_evaluator_config.py +17 -0
- phoenix/__init__.py +5 -4
- phoenix/auth.py +39 -2
- phoenix/config.py +1763 -91
- phoenix/datetime_utils.py +120 -2
- phoenix/db/README.md +595 -25
- phoenix/db/bulk_inserter.py +145 -103
- phoenix/db/engines.py +140 -33
- phoenix/db/enums.py +3 -12
- phoenix/db/facilitator.py +302 -35
- phoenix/db/helpers.py +1000 -65
- phoenix/db/iam_auth.py +64 -0
- phoenix/db/insertion/dataset.py +135 -2
- phoenix/db/insertion/document_annotation.py +9 -6
- phoenix/db/insertion/evaluation.py +2 -3
- phoenix/db/insertion/helpers.py +17 -2
- phoenix/db/insertion/session_annotation.py +176 -0
- phoenix/db/insertion/span.py +15 -11
- phoenix/db/insertion/span_annotation.py +3 -4
- phoenix/db/insertion/trace_annotation.py +3 -4
- phoenix/db/insertion/types.py +50 -20
- phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
- phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
- phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
- phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
- phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
- phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
- phoenix/db/migrations/versions/a20694b15f82_cost.py +196 -0
- phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
- phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
- phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
- phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
- phoenix/db/models.py +669 -56
- phoenix/db/pg_config.py +10 -0
- phoenix/db/types/model_provider.py +4 -0
- phoenix/db/types/token_price_customization.py +29 -0
- phoenix/db/types/trace_retention.py +23 -15
- phoenix/experiments/evaluators/utils.py +3 -3
- phoenix/experiments/functions.py +160 -52
- phoenix/experiments/tracing.py +2 -2
- phoenix/experiments/types.py +1 -1
- phoenix/inferences/inferences.py +1 -2
- phoenix/server/api/auth.py +38 -7
- phoenix/server/api/auth_messages.py +46 -0
- phoenix/server/api/context.py +100 -4
- phoenix/server/api/dataloaders/__init__.py +79 -5
- phoenix/server/api/dataloaders/annotation_configs_by_project.py +31 -0
- phoenix/server/api/dataloaders/annotation_summaries.py +60 -8
- phoenix/server/api/dataloaders/average_experiment_repeated_run_group_latency.py +50 -0
- phoenix/server/api/dataloaders/average_experiment_run_latency.py +17 -24
- phoenix/server/api/dataloaders/cache/two_tier_cache.py +1 -2
- phoenix/server/api/dataloaders/dataset_dataset_splits.py +52 -0
- phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -1
- phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
- phoenix/server/api/dataloaders/dataset_examples_and_versions_by_experiment_run.py +47 -0
- phoenix/server/api/dataloaders/dataset_labels.py +36 -0
- phoenix/server/api/dataloaders/document_evaluation_summaries.py +2 -2
- phoenix/server/api/dataloaders/document_evaluations.py +6 -9
- phoenix/server/api/dataloaders/experiment_annotation_summaries.py +88 -34
- phoenix/server/api/dataloaders/experiment_dataset_splits.py +43 -0
- phoenix/server/api/dataloaders/experiment_error_rates.py +21 -28
- phoenix/server/api/dataloaders/experiment_repeated_run_group_annotation_summaries.py +77 -0
- phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +57 -0
- phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
- phoenix/server/api/dataloaders/last_used_times_by_generative_model_id.py +35 -0
- phoenix/server/api/dataloaders/latency_ms_quantile.py +40 -8
- phoenix/server/api/dataloaders/record_counts.py +37 -10
- phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
- phoenix/server/api/dataloaders/span_cost_by_span.py +24 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_generative_model.py +56 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_project_session.py +57 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_span.py +43 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_trace.py +56 -0
- phoenix/server/api/dataloaders/span_cost_details_by_span_cost.py +27 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment.py +57 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment_repeated_run_group.py +64 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment_run.py +58 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_generative_model.py +55 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_project.py +152 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_project_session.py +56 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_trace.py +55 -0
- phoenix/server/api/dataloaders/span_costs.py +29 -0
- phoenix/server/api/dataloaders/table_fields.py +2 -2
- phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
- phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
- phoenix/server/api/dataloaders/types.py +29 -0
- phoenix/server/api/exceptions.py +11 -1
- phoenix/server/api/helpers/dataset_helpers.py +5 -1
- phoenix/server/api/helpers/playground_clients.py +1243 -292
- phoenix/server/api/helpers/playground_registry.py +2 -2
- phoenix/server/api/helpers/playground_spans.py +8 -4
- phoenix/server/api/helpers/playground_users.py +26 -0
- phoenix/server/api/helpers/prompts/conversions/aws.py +83 -0
- phoenix/server/api/helpers/prompts/conversions/google.py +103 -0
- phoenix/server/api/helpers/prompts/models.py +205 -22
- phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
- phoenix/server/api/input_types/ChatCompletionInput.py +6 -2
- phoenix/server/api/input_types/CreateProjectInput.py +27 -0
- phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
- phoenix/server/api/input_types/DatasetFilter.py +17 -0
- phoenix/server/api/input_types/ExperimentRunSort.py +237 -0
- phoenix/server/api/input_types/GenerativeCredentialInput.py +9 -0
- phoenix/server/api/input_types/GenerativeModelInput.py +5 -0
- phoenix/server/api/input_types/ProjectSessionSort.py +161 -1
- phoenix/server/api/input_types/PromptFilter.py +14 -0
- phoenix/server/api/input_types/PromptVersionInput.py +52 -1
- phoenix/server/api/input_types/SpanSort.py +44 -7
- phoenix/server/api/input_types/TimeBinConfig.py +23 -0
- phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
- phoenix/server/api/input_types/UserRoleInput.py +1 -0
- phoenix/server/api/mutations/__init__.py +10 -0
- phoenix/server/api/mutations/annotation_config_mutations.py +8 -8
- phoenix/server/api/mutations/api_key_mutations.py +19 -23
- phoenix/server/api/mutations/chat_mutations.py +154 -47
- phoenix/server/api/mutations/dataset_label_mutations.py +243 -0
- phoenix/server/api/mutations/dataset_mutations.py +21 -16
- phoenix/server/api/mutations/dataset_split_mutations.py +351 -0
- phoenix/server/api/mutations/experiment_mutations.py +2 -2
- phoenix/server/api/mutations/export_events_mutations.py +3 -3
- phoenix/server/api/mutations/model_mutations.py +210 -0
- phoenix/server/api/mutations/project_mutations.py +49 -10
- phoenix/server/api/mutations/project_session_annotations_mutations.py +158 -0
- phoenix/server/api/mutations/project_trace_retention_policy_mutations.py +8 -4
- phoenix/server/api/mutations/prompt_label_mutations.py +74 -65
- phoenix/server/api/mutations/prompt_mutations.py +65 -129
- phoenix/server/api/mutations/prompt_version_tag_mutations.py +11 -8
- phoenix/server/api/mutations/span_annotations_mutations.py +15 -10
- phoenix/server/api/mutations/trace_annotations_mutations.py +14 -10
- phoenix/server/api/mutations/trace_mutations.py +47 -3
- phoenix/server/api/mutations/user_mutations.py +66 -41
- phoenix/server/api/queries.py +768 -293
- phoenix/server/api/routers/__init__.py +2 -2
- phoenix/server/api/routers/auth.py +154 -88
- phoenix/server/api/routers/ldap.py +229 -0
- phoenix/server/api/routers/oauth2.py +369 -106
- phoenix/server/api/routers/v1/__init__.py +24 -4
- phoenix/server/api/routers/v1/annotation_configs.py +23 -31
- phoenix/server/api/routers/v1/annotations.py +481 -17
- phoenix/server/api/routers/v1/datasets.py +395 -81
- phoenix/server/api/routers/v1/documents.py +142 -0
- phoenix/server/api/routers/v1/evaluations.py +24 -31
- phoenix/server/api/routers/v1/experiment_evaluations.py +19 -8
- phoenix/server/api/routers/v1/experiment_runs.py +337 -59
- phoenix/server/api/routers/v1/experiments.py +479 -48
- phoenix/server/api/routers/v1/models.py +7 -0
- phoenix/server/api/routers/v1/projects.py +18 -49
- phoenix/server/api/routers/v1/prompts.py +54 -40
- phoenix/server/api/routers/v1/sessions.py +108 -0
- phoenix/server/api/routers/v1/spans.py +1091 -81
- phoenix/server/api/routers/v1/traces.py +132 -78
- phoenix/server/api/routers/v1/users.py +389 -0
- phoenix/server/api/routers/v1/utils.py +3 -7
- phoenix/server/api/subscriptions.py +305 -88
- phoenix/server/api/types/Annotation.py +90 -23
- phoenix/server/api/types/ApiKey.py +13 -17
- phoenix/server/api/types/AuthMethod.py +1 -0
- phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +1 -0
- phoenix/server/api/types/CostBreakdown.py +12 -0
- phoenix/server/api/types/Dataset.py +226 -72
- phoenix/server/api/types/DatasetExample.py +88 -18
- phoenix/server/api/types/DatasetExperimentAnnotationSummary.py +10 -0
- phoenix/server/api/types/DatasetLabel.py +57 -0
- phoenix/server/api/types/DatasetSplit.py +98 -0
- phoenix/server/api/types/DatasetVersion.py +49 -4
- phoenix/server/api/types/DocumentAnnotation.py +212 -0
- phoenix/server/api/types/Experiment.py +264 -59
- phoenix/server/api/types/ExperimentComparison.py +5 -10
- phoenix/server/api/types/ExperimentRepeatedRunGroup.py +155 -0
- phoenix/server/api/types/ExperimentRepeatedRunGroupAnnotationSummary.py +9 -0
- phoenix/server/api/types/ExperimentRun.py +169 -65
- phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
- phoenix/server/api/types/GenerativeModel.py +245 -3
- phoenix/server/api/types/GenerativeProvider.py +70 -11
- phoenix/server/api/types/{Model.py → InferenceModel.py} +1 -1
- phoenix/server/api/types/ModelInterface.py +16 -0
- phoenix/server/api/types/PlaygroundModel.py +20 -0
- phoenix/server/api/types/Project.py +1278 -216
- phoenix/server/api/types/ProjectSession.py +188 -28
- phoenix/server/api/types/ProjectSessionAnnotation.py +187 -0
- phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
- phoenix/server/api/types/Prompt.py +119 -39
- phoenix/server/api/types/PromptLabel.py +42 -25
- phoenix/server/api/types/PromptVersion.py +11 -8
- phoenix/server/api/types/PromptVersionTag.py +65 -25
- phoenix/server/api/types/ServerStatus.py +6 -0
- phoenix/server/api/types/Span.py +167 -123
- phoenix/server/api/types/SpanAnnotation.py +189 -42
- phoenix/server/api/types/SpanCostDetailSummaryEntry.py +10 -0
- phoenix/server/api/types/SpanCostSummary.py +10 -0
- phoenix/server/api/types/SystemApiKey.py +65 -1
- phoenix/server/api/types/TokenPrice.py +16 -0
- phoenix/server/api/types/TokenUsage.py +3 -3
- phoenix/server/api/types/Trace.py +223 -51
- phoenix/server/api/types/TraceAnnotation.py +149 -50
- phoenix/server/api/types/User.py +137 -32
- phoenix/server/api/types/UserApiKey.py +73 -26
- phoenix/server/api/types/node.py +10 -0
- phoenix/server/api/types/pagination.py +11 -2
- phoenix/server/app.py +290 -45
- phoenix/server/authorization.py +38 -3
- phoenix/server/bearer_auth.py +34 -24
- phoenix/server/cost_tracking/cost_details_calculator.py +196 -0
- phoenix/server/cost_tracking/cost_model_lookup.py +179 -0
- phoenix/server/cost_tracking/helpers.py +68 -0
- phoenix/server/cost_tracking/model_cost_manifest.json +3657 -830
- phoenix/server/cost_tracking/regex_specificity.py +397 -0
- phoenix/server/cost_tracking/token_cost_calculator.py +57 -0
- phoenix/server/daemons/__init__.py +0 -0
- phoenix/server/daemons/db_disk_usage_monitor.py +214 -0
- phoenix/server/daemons/generative_model_store.py +103 -0
- phoenix/server/daemons/span_cost_calculator.py +99 -0
- phoenix/server/dml_event.py +17 -0
- phoenix/server/dml_event_handler.py +5 -0
- phoenix/server/email/sender.py +56 -3
- phoenix/server/email/templates/db_disk_usage_notification.html +19 -0
- phoenix/server/email/types.py +11 -0
- phoenix/server/experiments/__init__.py +0 -0
- phoenix/server/experiments/utils.py +14 -0
- phoenix/server/grpc_server.py +11 -11
- phoenix/server/jwt_store.py +17 -15
- phoenix/server/ldap.py +1449 -0
- phoenix/server/main.py +26 -10
- phoenix/server/oauth2.py +330 -12
- phoenix/server/prometheus.py +66 -6
- phoenix/server/rate_limiters.py +4 -9
- phoenix/server/retention.py +33 -20
- phoenix/server/session_filters.py +49 -0
- phoenix/server/static/.vite/manifest.json +55 -51
- phoenix/server/static/assets/components-BreFUQQa.js +6702 -0
- phoenix/server/static/assets/{index-E0M82BdE.js → index-CTQoemZv.js} +140 -56
- phoenix/server/static/assets/pages-DBE5iYM3.js +9524 -0
- phoenix/server/static/assets/vendor-BGzfc4EU.css +1 -0
- phoenix/server/static/assets/vendor-DCE4v-Ot.js +920 -0
- phoenix/server/static/assets/vendor-codemirror-D5f205eT.js +25 -0
- phoenix/server/static/assets/vendor-recharts-V9cwpXsm.js +37 -0
- phoenix/server/static/assets/vendor-shiki-Do--csgv.js +5 -0
- phoenix/server/static/assets/vendor-three-CmB8bl_y.js +3840 -0
- phoenix/server/templates/index.html +40 -6
- phoenix/server/thread_server.py +1 -2
- phoenix/server/types.py +14 -4
- phoenix/server/utils.py +74 -0
- phoenix/session/client.py +56 -3
- phoenix/session/data_extractor.py +5 -0
- phoenix/session/evaluation.py +14 -5
- phoenix/session/session.py +45 -9
- phoenix/settings.py +5 -0
- phoenix/trace/attributes.py +80 -13
- phoenix/trace/dsl/helpers.py +90 -1
- phoenix/trace/dsl/query.py +8 -6
- phoenix/trace/projects.py +5 -0
- phoenix/utilities/template_formatters.py +1 -1
- phoenix/version.py +1 -1
- arize_phoenix-10.0.4.dist-info/RECORD +0 -405
- phoenix/server/api/types/Evaluation.py +0 -39
- phoenix/server/cost_tracking/cost_lookup.py +0 -255
- phoenix/server/static/assets/components-DULKeDfL.js +0 -4365
- phoenix/server/static/assets/pages-Cl0A-0U2.js +0 -7430
- phoenix/server/static/assets/vendor-WIZid84E.css +0 -1
- phoenix/server/static/assets/vendor-arizeai-Dy-0mSNw.js +0 -649
- phoenix/server/static/assets/vendor-codemirror-DBtifKNr.js +0 -33
- phoenix/server/static/assets/vendor-oB4u9zuV.js +0 -905
- phoenix/server/static/assets/vendor-recharts-D-T4KPz2.js +0 -59
- phoenix/server/static/assets/vendor-shiki-BMn4O_9F.js +0 -5
- phoenix/server/static/assets/vendor-three-C5WAXd5r.js +0 -2998
- phoenix/utilities/deprecation.py +0 -31
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,255 +0,0 @@
|
|
|
1
|
-
import json
|
|
2
|
-
import os
|
|
3
|
-
import re
|
|
4
|
-
from collections import defaultdict
|
|
5
|
-
from dataclasses import dataclass
|
|
6
|
-
from pathlib import Path
|
|
7
|
-
from typing import Any, Iterator, Optional, Union
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
@dataclass
class ModelTokenCost:
    """Token cost figures for a model, expressed in USD.

    Every field is optional; ``None`` means no cost is recorded for that
    token category.
    """

    input: Optional[float] = None
    output: Optional[float] = None
    cache_write: Optional[float] = None
    cache_read: Optional[float] = None
    audio: Optional[float] = None
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
class RegexDict:
    """Insertion-ordered mapping from regular-expression patterns to values.

    Keys are patterns (given as ``str`` or compiled ``re.Pattern``); lookups
    take a plain string and return the value of the *first* stored pattern
    that fully matches it.
    """

    __slots__ = ("_entries",)

    def __init__(self) -> None:
        # (compiled pattern, value) pairs in insertion order.
        self._entries: list[tuple[re.Pattern[str], Any]] = []

    def __setitem__(self, pattern: Union[str, re.Pattern[str]], value: Any) -> None:
        """Store ``value`` under ``pattern``.

        An existing entry with the same pattern text *and* flags is replaced
        in place (keeping its position); otherwise the entry is appended.

        Raises:
            TypeError: if ``pattern`` is neither a ``str`` nor an ``re.Pattern``.
        """
        if isinstance(pattern, str):
            compiled = re.compile(pattern)
        elif isinstance(pattern, re.Pattern):
            compiled = pattern
        else:
            raise TypeError("RegexDict key must be a str or re.Pattern")

        for idx, (existing_pat, _) in enumerate(self._entries):
            if existing_pat.pattern == compiled.pattern and existing_pat.flags == compiled.flags:
                self._entries[idx] = (compiled, value)
                return
        self._entries.append((compiled, value))

    def __delitem__(self, pattern: Union[str, re.Pattern[str]]) -> None:
        """Remove the entry registered under ``pattern``.

        A ``str`` key removes the first entry whose pattern text equals it.
        A compiled key must additionally have the same flags as the stored
        pattern, mirroring how ``__setitem__`` tells entries apart, so that
        two entries sharing the same text but different flags cannot be
        confused with one another.

        Raises:
            TypeError: if ``pattern`` is neither a ``str`` nor an ``re.Pattern``.
            KeyError: if no matching entry exists.
        """
        if isinstance(pattern, str):
            target_text = pattern
            target_flags = None  # text-only match for plain strings
        elif isinstance(pattern, re.Pattern):
            target_text = pattern.pattern
            target_flags = pattern.flags
        else:
            raise TypeError("RegexDict key must be a str or re.Pattern")

        for idx, (existing_pat, _) in enumerate(self._entries):
            if existing_pat.pattern != target_text:
                continue
            if target_flags is not None and existing_pat.flags != target_flags:
                continue
            del self._entries[idx]
            return
        raise KeyError(pattern)

    def __getitem__(self, key: str) -> Any:
        """Return the value of the first pattern that fully matches ``key``.

        Raises:
            KeyError: if no stored pattern fully matches ``key``.
        """
        for pattern, value in self._entries:
            if pattern.fullmatch(key):
                return value
        raise KeyError(key)

    def __contains__(self, key: str) -> bool:
        """Return ``True`` if any stored pattern fully matches ``key``."""
        return any(pattern.fullmatch(key) for pattern, _ in self._entries)

    def __iter__(self) -> Iterator[tuple[str, Any]]:
        """Yield ``(pattern text, value)`` pairs in insertion order."""
        for pattern, value in self._entries:
            yield pattern.pattern, value

    def __len__(self) -> int:
        """Return the number of stored patterns."""
        return len(self._entries)
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
class ModelCostLookup:
    """Registry that resolves a ``ModelTokenCost`` for a model name.

    Base costs are registered per provider as compiled regex patterns via
    ``add_pattern``. Overrides registered via ``add_override`` take precedence
    over base costs, and the most recently added matching override wins.
    Results of ``get_cost`` are memoized in a small least-recently-used cache
    that is cleared whenever patterns or overrides change.
    """

    __slots__ = ("_provider_model_map", "_model_map", "_overrides", "_cache", "_max_cache_size")

    def __init__(self) -> None:
        # Each provider maps to a *RegexDict* of (pattern -> cost).
        self._provider_model_map: defaultdict[Optional[str], RegexDict] = defaultdict(RegexDict)
        # Map from *compiled pattern* to the set of providers that registered it.
        self._model_map: defaultdict[re.Pattern[str], set[Optional[str]]] = defaultdict(set)
        # A prioritized list of cost overrides (later overrides have higher priority).
        self._overrides: list[tuple[Optional[str], re.Pattern[str], ModelTokenCost]] = []
        # Cache for computed costs keyed by (provider, model_name).
        self._cache: dict[tuple[Optional[str], str], list[tuple[str, ModelTokenCost]]] = {}
        self._max_cache_size = 100

    @staticmethod
    def _require_compiled(pattern: Any) -> None:
        """Raise ``TypeError`` unless ``pattern`` is a compiled regex."""
        # Explicit raise rather than ``assert`` so the check survives ``python -O``.
        if not isinstance(pattern, re.Pattern):
            raise TypeError("pattern must be a compiled regex")

    def add_pattern(
        self, provider: Optional[str], pattern: re.Pattern[str], cost: ModelTokenCost
    ) -> None:
        """Register a model pattern with its cost.

        Raises:
            TypeError: if ``pattern`` is not a compiled regex.
        """
        self._require_compiled(pattern)
        self._provider_model_map[provider][pattern] = cost
        self._model_map[pattern].add(provider)
        self._cache.clear()

    def remove_pattern(self, provider: Optional[str], pattern: re.Pattern[str]) -> None:
        """Remove a previously-registered model pattern.

        An unknown ``provider`` is ignored silently.

        Raises:
            TypeError: if ``pattern`` is not a compiled regex.
            KeyError: if ``provider`` is known but ``pattern`` is not
                registered for it.
        """
        self._require_compiled(pattern)
        if provider not in self._provider_model_map:
            return
        del self._provider_model_map[provider][pattern]

        # ``.get`` avoids creating an empty set in the defaultdict as a side effect.
        providers = self._model_map.get(pattern)
        if providers is not None:
            providers.discard(provider)
            if not providers:
                del self._model_map[pattern]

        # Drop the provider entirely once its last pattern is gone.
        if not self._provider_model_map[provider]:
            del self._provider_model_map[provider]
        self._cache.clear()

    def get_cost(
        self, provider: Optional[str], model_name: str
    ) -> list[tuple[str, ModelTokenCost]]:
        """Return ``(provider, cost)`` pairs for ``model_name``.

        With a ``provider`` the list holds a single pair; with ``provider=None``
        it holds one pair per provider that has a matching cost.

        Raises:
            KeyError: if no cost can be resolved.
            TypeError: if ``model_name`` is not a ``str``.
        """
        key = (provider, model_name)
        if key in self._cache:
            # Re-insert so the key becomes the most recently used one.
            value = self._cache.pop(key)
            self._cache[key] = value
            return value

        result = self._lookup_cost(provider, model_name)

        if len(self._cache) >= self._max_cache_size:
            # Dicts preserve insertion order, so the first key is the least recently used.
            self._cache.pop(next(iter(self._cache)))

        self._cache[key] = result
        return result

    def has_model(self, provider: Optional[str], model_name: str) -> bool:
        """Return ``True`` if a cost (either base or overridden) exists for the model."""
        return self._contains(provider, model_name)

    def pattern_count(self) -> int:
        """Return the number of registered *base* patterns (overrides not counted)."""
        return sum(len(regex_dict) for regex_dict in self._provider_model_map.values())

    def _lookup_cost(
        self, provider: Optional[str], model_name: str
    ) -> list[tuple[str, ModelTokenCost]]:
        """Resolve costs without consulting the cache (see ``get_cost``)."""
        if not isinstance(model_name, str):
            raise TypeError("Lookup key must be a str")

        # 1) Provider-specific lookup: overrides first, then the provider's base patterns.
        if provider is not None:
            override_cost = self._lookup_override(provider, model_name)
            if override_cost is not None:
                return [(provider, override_cost)]

            regex_dict = self._provider_model_map.get(provider)
            if regex_dict is None:
                raise KeyError(provider)
            return [(provider, regex_dict[model_name])]

        # 2) Provider-agnostic lookup: collect the base cost from every provider that matches.
        provider_cost_map: dict[Optional[str], ModelTokenCost] = {}
        for p, regex_dict in self._provider_model_map.items():
            try:
                provider_cost_map[p] = regex_dict[model_name]
            except KeyError:
                continue

        # Apply overrides oldest-first so that later (higher-priority) ones win.
        for override_provider, override_pattern, override_cost in self._overrides:
            if not override_pattern.fullmatch(model_name):
                continue
            if override_provider is None:
                # A provider-agnostic override replaces every cost collected so far.
                for p in provider_cost_map:
                    provider_cost_map[p] = override_cost
            else:
                provider_cost_map[override_provider] = override_cost

        if not provider_cost_map:
            # Only provider-agnostic overrides (if anything) matched. ``has_model``
            # reports such a model as known, so return the highest-priority one
            # instead of raising.
            fallback = self._lookup_override(None, model_name)
            if fallback is None:
                raise KeyError(model_name)
            return [(None, fallback)]  # type: ignore[list-item]
        return list(provider_cost_map.items())  # type: ignore[arg-type]

    def _contains(self, provider: Optional[str], model_name: str) -> bool:
        """Return whether any override or base pattern covers ``model_name``."""
        if provider is None:
            if any(pat.fullmatch(model_name) for _, pat, _ in self._overrides):
                return True
            return any(model_name in regex_dict for regex_dict in self._provider_model_map.values())

        if self._lookup_override(provider, model_name) is not None:
            return True

        regex_dict = self._provider_model_map.get(provider)
        if not regex_dict:
            return False
        return model_name in regex_dict

    def add_override(
        self, provider: Optional[str], pattern: re.Pattern[str], cost: ModelTokenCost
    ) -> None:
        """Register a *prioritized* cost override.

        Overrides added later take priority over earlier ones: the most
        recently added matching override wins.

        Raises:
            TypeError: if ``pattern`` is not a compiled regex.
        """
        self._require_compiled(pattern)
        self._overrides.append((provider, pattern, cost))
        self._cache.clear()

    def _lookup_override(
        self, provider: Optional[str], model_name: str
    ) -> Optional[ModelTokenCost]:
        """Return the cost from the highest-priority override that matches, or *None*."""
        # Newest first: the last override added has the highest priority.
        for override_provider, override_pattern, override_cost in reversed(self._overrides):
            provider_matches = override_provider is None or override_provider == provider
            if provider_matches and override_pattern.fullmatch(model_name):
                return override_cost
        return None
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
def create_cost_table(
    manifest_path: Optional[Union[str, "os.PathLike[str]"]] = None,
) -> "ModelCostLookup":
    """Build a ``ModelCostLookup`` from a JSON cost manifest.

    Args:
        manifest_path: Path to the manifest. Defaults to
            ``model_cost_manifest.json`` located next to this module.

    Returns:
        A lookup with one base pattern registered per manifest entry.

    Raises:
        FileNotFoundError: if the manifest file does not exist.
        ValueError: if the manifest is not valid JSON or an entry's ``regex``
            does not compile.
        KeyError: if an entry has no ``regex`` key.
    """
    if manifest_path is None:
        manifest_path = Path(__file__).with_name("model_cost_manifest.json")

    manifest_path = Path(manifest_path)

    if not manifest_path.exists():
        raise FileNotFoundError(f"Model cost manifest not found: {manifest_path}")

    with manifest_path.open("r", encoding="utf-8") as fp:
        try:
            manifest_entries: list[dict[str, Any]] = json.load(fp)
        except json.JSONDecodeError as exc:
            raise ValueError(f"Failed to parse manifest JSON: {manifest_path}") from exc

    lookup = ModelCostLookup()

    for entry in manifest_entries:
        # A missing "provider" registers the pattern under the ``None`` provider.
        provider: Optional[str] = entry.get("provider")

        try:
            pattern = re.compile(entry["regex"])
        except re.error as exc:
            raise ValueError(
                f"Invalid regex in manifest for model {entry.get('model')}: {entry['regex']}"
            ) from exc

        cost = ModelTokenCost(
            input=entry.get("input"),
            output=entry.get("output"),
            cache_write=entry.get("cache_write"),
            cache_read=entry.get("cache_read"),
            # NOTE(review): assumes the manifest key is "audio", mirroring the
            # other fields; absent keys simply yield ``None``. Confirm against
            # the manifest schema.
            audio=entry.get("audio"),
        )

        lookup.add_pattern(provider, pattern, cost)

    return lookup
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
# Module-level cost lookup, built once at import time from the default manifest
# (``model_cost_manifest.json`` next to this file). Importing this module
# therefore reads that file and raises if it is missing or malformed.
COST_TABLE = create_cost_table()
|