arize-phoenix 10.15.0__py3-none-any.whl → 11.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of arize-phoenix might be problematic.
- {arize_phoenix-10.15.0.dist-info → arize_phoenix-11.0.0.dist-info}/METADATA +2 -2
- {arize_phoenix-10.15.0.dist-info → arize_phoenix-11.0.0.dist-info}/RECORD +77 -46
- phoenix/config.py +5 -2
- phoenix/datetime_utils.py +8 -1
- phoenix/db/bulk_inserter.py +40 -1
- phoenix/db/facilitator.py +263 -4
- phoenix/db/insertion/helpers.py +15 -0
- phoenix/db/insertion/span.py +3 -1
- phoenix/db/migrations/versions/a20694b15f82_cost.py +196 -0
- phoenix/db/models.py +267 -9
- phoenix/db/types/token_price_customization.py +29 -0
- phoenix/server/api/context.py +38 -4
- phoenix/server/api/dataloaders/__init__.py +41 -5
- phoenix/server/api/dataloaders/last_used_times_by_generative_model_id.py +35 -0
- phoenix/server/api/dataloaders/span_cost_by_span.py +24 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_generative_model.py +56 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_project_session.py +57 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_span.py +43 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_trace.py +56 -0
- phoenix/server/api/dataloaders/span_cost_details_by_span_cost.py +27 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment.py +58 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment_run.py +58 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_generative_model.py +55 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_project.py +140 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_project_session.py +56 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_trace.py +55 -0
- phoenix/server/api/dataloaders/span_costs.py +35 -0
- phoenix/server/api/dataloaders/types.py +29 -0
- phoenix/server/api/helpers/playground_clients.py +103 -12
- phoenix/server/api/input_types/ProjectSessionSort.py +3 -0
- phoenix/server/api/input_types/SpanSort.py +17 -0
- phoenix/server/api/mutations/__init__.py +2 -0
- phoenix/server/api/mutations/chat_mutations.py +17 -0
- phoenix/server/api/mutations/model_mutations.py +208 -0
- phoenix/server/api/queries.py +82 -41
- phoenix/server/api/routers/v1/traces.py +11 -4
- phoenix/server/api/subscriptions.py +36 -2
- phoenix/server/api/types/CostBreakdown.py +15 -0
- phoenix/server/api/types/Experiment.py +59 -1
- phoenix/server/api/types/ExperimentRun.py +58 -4
- phoenix/server/api/types/GenerativeModel.py +143 -2
- phoenix/server/api/types/{Model.py → InferenceModel.py} +1 -1
- phoenix/server/api/types/ModelInterface.py +11 -0
- phoenix/server/api/types/PlaygroundModel.py +10 -0
- phoenix/server/api/types/Project.py +42 -0
- phoenix/server/api/types/ProjectSession.py +44 -0
- phoenix/server/api/types/Span.py +137 -0
- phoenix/server/api/types/SpanCostDetailSummaryEntry.py +10 -0
- phoenix/server/api/types/SpanCostSummary.py +10 -0
- phoenix/server/api/types/TokenPrice.py +16 -0
- phoenix/server/api/types/TokenUsage.py +3 -3
- phoenix/server/api/types/Trace.py +41 -0
- phoenix/server/app.py +59 -0
- phoenix/server/cost_tracking/cost_details_calculator.py +190 -0
- phoenix/server/cost_tracking/cost_model_lookup.py +151 -0
- phoenix/server/cost_tracking/helpers.py +68 -0
- phoenix/server/cost_tracking/model_cost_manifest.json +59 -329
- phoenix/server/cost_tracking/regex_specificity.py +397 -0
- phoenix/server/cost_tracking/token_cost_calculator.py +57 -0
- phoenix/server/daemons/__init__.py +0 -0
- phoenix/server/daemons/generative_model_store.py +51 -0
- phoenix/server/daemons/span_cost_calculator.py +103 -0
- phoenix/server/dml_event_handler.py +1 -0
- phoenix/server/static/.vite/manifest.json +36 -36
- phoenix/server/static/assets/components-BnK9kodr.js +5055 -0
- phoenix/server/static/assets/{index-DIlhmbjB.js → index-S3YKLmbo.js} +13 -13
- phoenix/server/static/assets/{pages-YX47cEoQ.js → pages-BW6PBHZb.js} +811 -419
- phoenix/server/static/assets/{vendor-DCZoBorz.js → vendor-DqQvHbPa.js} +147 -147
- phoenix/server/static/assets/{vendor-arizeai-Ckci3irT.js → vendor-arizeai-CLX44PFA.js} +1 -1
- phoenix/server/static/assets/{vendor-codemirror-BODM513D.js → vendor-codemirror-Du3XyJnB.js} +1 -1
- phoenix/server/static/assets/{vendor-recharts-C9O2a-N3.js → vendor-recharts-B2PJDrnX.js} +25 -25
- phoenix/server/static/assets/{vendor-shiki-Dq54rRC7.js → vendor-shiki-CNbrFjf9.js} +1 -1
- phoenix/version.py +1 -1
- phoenix/server/cost_tracking/cost_lookup.py +0 -255
- phoenix/server/static/assets/components-SpUMF1qV.js +0 -4509
- {arize_phoenix-10.15.0.dist-info → arize_phoenix-11.0.0.dist-info}/WHEEL +0 -0
- {arize_phoenix-10.15.0.dist-info → arize_phoenix-11.0.0.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-10.15.0.dist-info → arize_phoenix-11.0.0.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-10.15.0.dist-info → arize_phoenix-11.0.0.dist-info}/licenses/LICENSE +0 -0
phoenix/server/app.py CHANGED

@@ -85,6 +85,7 @@ from phoenix.server.api.dataloaders import (
     ExperimentRunAnnotations,
     ExperimentRunCountsDataLoader,
     ExperimentSequenceNumberDataLoader,
+    LastUsedTimesByGenerativeModelIdDataLoader,
     LatencyMsQuantileDataLoader,
     MinStartOrMaxEndTimeDataLoader,
     NumChildSpansDataLoader,
@@ -100,6 +101,18 @@ from phoenix.server.api.dataloaders import (
     SessionTraceLatencyMsQuantileDataLoader,
     SpanAnnotationsDataLoader,
     SpanByIdDataLoader,
+    SpanCostBySpanDataLoader,
+    SpanCostDetailsBySpanCostDataLoader,
+    SpanCostDetailSummaryEntriesByGenerativeModelDataLoader,
+    SpanCostDetailSummaryEntriesByProjectSessionDataLoader,
+    SpanCostDetailSummaryEntriesBySpanDataLoader,
+    SpanCostDetailSummaryEntriesByTraceDataLoader,
+    SpanCostSummaryByExperimentDataLoader,
+    SpanCostSummaryByExperimentRunDataLoader,
+    SpanCostSummaryByGenerativeModelDataLoader,
+    SpanCostSummaryByProjectDataLoader,
+    SpanCostSummaryByProjectSessionDataLoader,
+    SpanCostSummaryByTraceDataLoader,
     SpanDatasetExamplesDataLoader,
     SpanDescendantsDataLoader,
     SpanProjectsDataLoader,
@@ -120,6 +133,8 @@ from phoenix.server.api.routers import (
 from phoenix.server.api.routers.v1 import REST_API_VERSION
 from phoenix.server.api.schema import build_graphql_schema
 from phoenix.server.bearer_auth import BearerTokenAuthBackend, is_authenticated
+from phoenix.server.daemons.generative_model_store import GenerativeModelStore
+from phoenix.server.daemons.span_cost_calculator import SpanCostCalculator
 from phoenix.server.dml_event import DmlEvent
 from phoenix.server.dml_event_handler import DmlEventHandler
 from phoenix.server.email.types import EmailSender
@@ -502,6 +517,8 @@ def _lifespan(
     bulk_inserter: BulkInserter,
     dml_event_handler: DmlEventHandler,
     trace_data_sweeper: Optional[TraceDataSweeper],
+    span_cost_calculator: SpanCostCalculator,
+    generative_model_store: GenerativeModelStore,
     token_store: Optional[TokenStore] = None,
     tracer_provider: Optional["TracerProvider"] = None,
     enable_prometheus: bool = False,
@@ -536,6 +553,8 @@ def _lifespan(
             await stack.enter_async_context(dml_event_handler)
             if trace_data_sweeper:
                 await stack.enter_async_context(trace_data_sweeper)
+            await stack.enter_async_context(span_cost_calculator)
+            await stack.enter_async_context(generative_model_store)
             if scaffolder_config:
                 scaffolder = Scaffolder(
                     config=scaffolder_config,
@@ -583,6 +602,7 @@ def create_graphql_router(
     export_path: Path,
     last_updated_at: CanGetLastUpdatedAt,
     authentication_enabled: bool,
+    span_cost_calculator: SpanCostCalculator,
     corpus: Optional[Model] = None,
     cache_for_dataloaders: Optional[CacheForDataLoaders] = None,
     event_queue: CanPutItem[DmlEvent],
@@ -600,6 +620,7 @@ def create_graphql_router(
         export_path (Path): the file path to export data to for download (legacy)
         last_updated_at (CanGetLastUpdatedAt): How to get the last updated timestamp for updates.
         authentication_enabled (bool): Whether authentication is enabled.
+        span_cost_calculator (SpanCostCalculator): The span cost calculator for calculating costs.
         event_queue (CanPutItem[DmlEvent]): The event queue for DML events.
         corpus (Optional[Model], optional): the corpus for UMAP projection. Defaults to None.
         cache_for_dataloaders (Optional[CacheForDataLoaders], optional): GraphQL data loaders.
@@ -645,6 +666,9 @@ def create_graphql_router(
             experiment_run_annotations=ExperimentRunAnnotations(db),
             experiment_run_counts=ExperimentRunCountsDataLoader(db),
             experiment_sequence_number=ExperimentSequenceNumberDataLoader(db),
+            last_used_times_by_generative_model_id=LastUsedTimesByGenerativeModelIdDataLoader(
+                db
+            ),
             latency_ms_quantile=LatencyMsQuantileDataLoader(
                 db,
                 cache_map=(
@@ -679,6 +703,31 @@ def create_graphql_router(
             span_annotations=SpanAnnotationsDataLoader(db),
             span_fields=TableFieldsDataLoader(db, models.Span),
             span_by_id=SpanByIdDataLoader(db),
+            span_cost_by_span=SpanCostBySpanDataLoader(db),
+            span_cost_detail_summary_entries_by_generative_model=SpanCostDetailSummaryEntriesByGenerativeModelDataLoader(
+                db
+            ),
+            span_cost_detail_summary_entries_by_project_session=SpanCostDetailSummaryEntriesByProjectSessionDataLoader(
+                db
+            ),
+            span_cost_detail_summary_entries_by_span=SpanCostDetailSummaryEntriesBySpanDataLoader(
+                db
+            ),
+            span_cost_detail_summary_entries_by_trace=SpanCostDetailSummaryEntriesByTraceDataLoader(
+                db
+            ),
+            span_cost_details_by_span_cost=SpanCostDetailsBySpanCostDataLoader(db),
+            span_cost_detail_fields=TableFieldsDataLoader(db, models.SpanCostDetail),
+            span_cost_fields=TableFieldsDataLoader(db, models.SpanCost),
+            span_cost_summary_by_generative_model=SpanCostSummaryByGenerativeModelDataLoader(
+                db
+            ),
+            span_cost_summary_by_project=SpanCostSummaryByProjectDataLoader(
+                db,
+                cache_map=cache_for_dataloaders.token_cost if cache_for_dataloaders else None,
+            ),
+            span_cost_summary_by_project_session=SpanCostSummaryByProjectSessionDataLoader(db),
+            span_cost_summary_by_trace=SpanCostSummaryByTraceDataLoader(db),
             span_dataset_examples=SpanDatasetExamplesDataLoader(db),
             span_descendants=SpanDescendantsDataLoader(db),
             span_projects=SpanProjectsDataLoader(db),
@@ -698,6 +747,8 @@ def create_graphql_router(
             project_by_name=ProjectByNameDataLoader(db),
             users=UsersDataLoader(db),
             user_roles=UserRolesDataLoader(db),
+            span_cost_summary_by_experiment=SpanCostSummaryByExperimentDataLoader(db),
+            span_cost_summary_by_experiment_run=SpanCostSummaryByExperimentRunDataLoader(db),
         ),
         cache_for_dataloaders=cache_for_dataloaders,
         read_only=read_only,
@@ -705,6 +756,7 @@ def create_graphql_router(
         secret=secret,
         token_store=token_store,
         email_sender=email_sender,
+        span_cost_calculator=span_cost_calculator,
     )

     return GraphQLRouter(
@@ -860,9 +912,12 @@ def create_app(
         db=db,
         dml_event_handler=dml_event_handler,
     )
+    generative_model_store = GenerativeModelStore(db)
+    span_cost_calculator = SpanCostCalculator(db, generative_model_store)
     bulk_inserter = bulk_inserter_factory(
         db,
         enable_prometheus=enable_prometheus,
+        span_cost_calculator=span_cost_calculator,
         event_queue=dml_event_handler,
         initial_batch_of_spans=initial_batch_of_spans,
         initial_batch_of_evaluations=initial_batch_of_evaluations,
@@ -904,6 +959,7 @@ def create_app(
         secret=secret,
         token_store=token_store,
         email_sender=email_sender,
+        span_cost_calculator=span_cost_calculator,
     )
     if enable_prometheus:
         from phoenix.server.prometheus import PrometheusMiddleware
@@ -918,6 +974,8 @@ def create_app(
         bulk_inserter=bulk_inserter,
         dml_event_handler=dml_event_handler,
         trace_data_sweeper=trace_data_sweeper,
+        span_cost_calculator=span_cost_calculator,
+        generative_model_store=generative_model_store,
         token_store=token_store,
         tracer_provider=tracer_provider,
         enable_prometheus=enable_prometheus,
@@ -981,6 +1039,7 @@ def create_app(
     app.state.oauth2_clients = OAuth2Clients.from_configs(oauth2_client_configs or [])
     app.state.db = db
     app.state.email_sender = email_sender
+    app.state.span_cost_calculator = span_cost_calculator
    app = _add_get_secret_method(app=app, secret=secret)
    app = _add_get_token_store_method(app=app, token_store=token_store)
    if tracer_provider:
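The wiring above follows one pattern: both new daemons (GenerativeModelStore and SpanCostCalculator) are constructed in create_app, threaded through to _lifespan, and entered on the lifespan's AsyncExitStack so they start with the app and are torn down in reverse order on shutdown. A minimal sketch of that pattern, with a hypothetical Daemon class standing in for the Phoenix daemon classes:

# Sketch of the daemon-wiring pattern used above. Each daemon is an async
# context manager entered on an AsyncExitStack; `Daemon` is a stand-in here,
# not one of the Phoenix classes.
import asyncio
import contextlib


class Daemon:
    def __init__(self, name: str) -> None:
        self.name = name

    async def __aenter__(self) -> "Daemon":
        print(f"{self.name} started")  # e.g. spawn a background task here
        return self

    async def __aexit__(self, *exc) -> None:
        print(f"{self.name} stopped")  # e.g. cancel the background task


async def lifespan() -> None:
    async with contextlib.AsyncExitStack() as stack:
        await stack.enter_async_context(Daemon("span_cost_calculator"))
        await stack.enter_async_context(Daemon("generative_model_store"))
        print("app is serving")  # the app runs while the stack is open


asyncio.run(lifespan())

Entering the daemons on the exit stack rather than starting them ad hoc guarantees cleanup in reverse order, even if startup fails partway through.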
phoenix/server/cost_tracking/cost_details_calculator.py ADDED

@@ -0,0 +1,190 @@
+from itertools import chain
+from typing import Any, Iterable, Mapping
+
+from typing_extensions import TypeAlias
+
+from phoenix.db import models
+from phoenix.server.cost_tracking.helpers import get_aggregated_tokens
+from phoenix.server.cost_tracking.token_cost_calculator import (
+    TokenCostCalculator,
+    create_token_cost_calculator,
+)
+from phoenix.trace.attributes import get_attribute_value
+
+_TokenType: TypeAlias = str
+
+
+class SpanCostDetailsCalculator:
+    """
+    Calculates detailed cost breakdowns for LLM spans based on token usage and pricing.
+
+    This calculator processes both detailed token counts (from span attributes) and
+    aggregated token totals to provide comprehensive cost analysis for prompt and
+    completion tokens. It handles multiple token types (e.g., "input", "output",
+    "image", "audio", "video", "document", "reasoning", etc.) and calculates costs
+    using configured pricing models with fallback behavior.
+
+    **Fallback Behavior:**
+    - If a specific token type has a configured calculator, it uses that calculator
+    - If no specific calculator exists, it falls back to the default calculator:
+      - Prompt tokens (is_prompt=True) fall back to "input" calculator
+      - Completion tokens (is_prompt=False) fall back to "output" calculator
+
+    This ensures all token types get cost calculations even if not explicitly configured.
+
+    The calculator expects token prices to include at least:
+    - An "input" token type for prompt tokens (used as fallback for unconfigured prompt token types)
+    - An "output" token type for completion tokens (used as fallback for unconfigured completion token types)
+
+    Additional token types can be configured for more granular cost tracking.
+    """  # noqa: E501
+
+    def __init__(
+        self,
+        prices: Iterable[models.TokenPrice],
+    ) -> None:
+        """
+        Initialize the cost calculator with token pricing configuration.
+
+        Args:
+            prices: Collection of token price configurations defining rates for
+                different token types and whether they're prompt or completion tokens.
+
+        Raises:
+            ValueError: If required "input" (prompt) or "output" (completion)
+                token types are missing from the pricing configuration.
+        """
+        # Create calculators for prompt token types (is_prompt=True)
+        self._prompt: Mapping[_TokenType, TokenCostCalculator] = {
+            p.token_type: create_token_cost_calculator(p.base_rate, p.customization)
+            for p in prices
+            if p.is_prompt
+        }
+        if "input" not in self._prompt:
+            raise ValueError("Token prices for prompt must include an 'input' token type")
+
+        # Create calculators for completion token types (is_prompt=False)
+        self._completion: Mapping[_TokenType, TokenCostCalculator] = {
+            p.token_type: create_token_cost_calculator(p.base_rate, p.customization)
+            for p in prices
+            if not p.is_prompt
+        }
+        if "output" not in self._completion:
+            raise ValueError("Token prices for completion must include an 'output' token type")
+
+    def calculate_details(
+        self,
+        attributes: Mapping[str, Any],
+    ) -> list[models.SpanCostDetail]:
+        """
+        Calculate detailed cost breakdown for a given span.
+
+        This method processes token usage in two phases:
+        1. **Detailed token processing**: Extracts specific token counts from span attributes
+           (e.g., "llm.token_count.prompt_details", "llm.token_count.completion_details")
+           and calculates costs for each token type found. Uses fallback behavior for
+           token types without specific calculators.
+
+        2. **Aggregated token processing**: For default token types ("input"/"output") that
+           weren't found in detailed processing, calculates remaining tokens by subtracting
+           detailed counts from total aggregated tokens.
+
+        **Fallback Calculation Logic:**
+        - For each token type in detailed processing:
+          - If a specific calculator exists for the token type, use it
+          - Otherwise, fall back to the default calculator ("input" for prompt tokens,
+            "output" for completion tokens)
+        - This ensures all token types receive cost calculations regardless of
+          specific calculator configuration
+
+        Args:
+            attributes: Dictionary containing span attributes with token usage data.
+
+        Returns:
+            List of SpanCostDetail objects containing token counts, costs, and cost-per-token
+            for each token type found in the span.
+
+        Note:
+            - Token counts are validated and converted to non-negative integers
+            - All token types receive cost calculations via fallback mechanism
+            - Cost-per-token is calculated only when both cost and token count are positive
+            - If cost is 0.0, cost-per-token will be None (not 0.0) due to falsy evaluation
+        """
+        prompt_details: dict[_TokenType, models.SpanCostDetail] = {}
+        completion_details: dict[_TokenType, models.SpanCostDetail] = {}
+
+        # Phase 1: Process detailed token counts from span attributes
+        for is_prompt, prefix, calculators, results in (
+            (True, "prompt", self._prompt, prompt_details),
+            (False, "completion", self._completion, completion_details),
+        ):
+            # Extract detailed token counts from span attributes
+            details = get_attribute_value(attributes, f"llm.token_count.{prefix}_details")
+            if isinstance(details, dict) and details:
+                for token_type, token_count in details.items():
+                    # Validate token count is numeric
+                    if not isinstance(token_count, (int, float)):
+                        continue
+                    tokens = max(0, int(token_count))
+
+                    # Calculate cost using specific calculator or fallback to default
+                    if token_type in calculators:
+                        # Use specific calculator for this token type
+                        calculator = calculators[token_type]
+                    else:
+                        # Fallback to default calculator: "input" for prompts,
+                        # "output" for completions
+                        key = "input" if is_prompt else "output"
+                        calculator = calculators[key]
+                    cost = calculator.calculate_cost(attributes, tokens)
+
+                    # Calculate cost per token (avoid division by zero)
+                    cost_per_token = cost / tokens if tokens else None
+
+                    detail = models.SpanCostDetail(
+                        token_type=token_type,
+                        is_prompt=is_prompt,
+                        tokens=tokens,
+                        cost=cost,
+                        cost_per_token=cost_per_token,
+                    )
+                    results[token_type] = detail
+
+        # Get aggregated token totals for fallback calculations
+        prompt_tokens, completion_tokens, _ = get_aggregated_tokens(attributes)
+
+        # Phase 2: Process remaining tokens for default token types
+        for is_prompt, token_type, total, calculators, results in (
+            (True, "input", prompt_tokens, self._prompt, prompt_details),
+            (False, "output", completion_tokens, self._completion, completion_details),
+        ):
+            # Skip if this token type was already processed in detailed phase
+            if token_type in results:
+                continue
+
+            # Calculate remaining tokens by subtracting detailed counts from total
+            tokens = total - sum(
+                int(d.tokens or 0) for d in results.values() if d.is_prompt == is_prompt
+            )
+
+            # Skip if no remaining tokens or negative (shouldn't happen with valid data)
+            if tokens <= 0:
+                continue
+
+            # Calculate cost using guaranteed default calculator (input/output are required)
+            cost = calculators[token_type].calculate_cost(attributes, tokens)
+
+            # Calculate cost per token (avoid division by zero)
+            cost_per_token = cost / tokens if cost and tokens else None
+
+            detail = models.SpanCostDetail(
+                token_type=token_type,
+                is_prompt=is_prompt,
+                tokens=tokens,
+                cost=cost,
+                cost_per_token=cost_per_token,
+            )
+            results[token_type] = detail
+
+        # Return combined results from both prompt and completion processing
+        return list(chain(prompt_details.values(), completion_details.values()))
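The two-phase logic is easiest to see with plain numbers. The sketch below mirrors only the arithmetic, substituting hypothetical flat per-token rates for the models.TokenPrice rows and the create_token_cost_calculator machinery used above:

# Self-contained sketch of the two-phase breakdown, with hypothetical rates.
prompt_rates = {"input": 2e-6, "audio": 1e-5}  # "input" is the required default
prompt_details = {"audio": 100, "cache_read": 300}  # llm.token_count.prompt_details
prompt_total = 1000  # llm.token_count.prompt

details = {}
# Phase 1: price each detailed token type, falling back to the "input" rate
# for types without their own rate (here: "cache_read").
for token_type, tokens in prompt_details.items():
    rate = prompt_rates.get(token_type, prompt_rates["input"])
    details[token_type] = (tokens, tokens * rate)

# Phase 2: the remaining, un-detailed prompt tokens are priced as "input".
remainder = prompt_total - sum(tokens for tokens, _ in details.values())
if remainder > 0:
    details["input"] = (remainder, remainder * prompt_rates["input"])

print(details)
# approximately: {'audio': (100, 0.001), 'cache_read': (300, 0.0006),
#                 'input': (600, 0.0012)}

The same remainder rule runs independently for the completion side with "output" as the default type.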
phoenix/server/cost_tracking/cost_model_lookup.py ADDED

@@ -0,0 +1,151 @@
+import re
+from datetime import datetime
+from typing import Any, Iterable, Mapping, Optional
+
+from openinference.semconv.trace import SpanAttributes
+from typing_extensions import TypeAlias
+
+from phoenix.datetime_utils import is_timezone_aware
+from phoenix.db import models
+from phoenix.server.cost_tracking import regex_specificity
+from phoenix.trace.attributes import get_attribute_value
+
+_RegexPatternStr: TypeAlias = str
+_RegexSpecificityScore: TypeAlias = int
+_TieBreakerId: TypeAlias = int
+
+
+class CostModelLookup:
+    def __init__(
+        self,
+        generative_models: Iterable[models.GenerativeModel] = (),
+    ) -> None:
+        self._models = tuple(generative_models)
+        self._model_priority: dict[
+            int, tuple[_RegexSpecificityScore, float, _TieBreakerId]
+        ] = {}  # higher is better
+        self._regex_specificity_score: dict[re.Pattern[str], _RegexSpecificityScore] = {}
+
+        for m in self._models:
+            self._regex_specificity_score[m.name_pattern] = regex_specificity.score(m.name_pattern)
+
+            # For built-in models, use negative ID so that earlier IDs win
+            # For user-defined models, use positive ID so later IDs win
+            tie_breaker = -m.id if m.is_built_in else m.id
+
+            self._model_priority[m.id] = (
+                self._regex_specificity_score[m.name_pattern],
+                m.start_time.timestamp() if m.start_time else 0.0,
+                tie_breaker,
+            )
+
+    def find_model(
+        self,
+        start_time: datetime,
+        attributes: Mapping[str, Any],
+    ) -> Optional[models.GenerativeModel]:
+        """
+        Find the most appropriate generative model for cost tracking based on attributes and time.
+
+        This method implements a sophisticated model lookup system that filters and prioritizes
+        generative models based on the provided attributes and timestamp. The lookup follows
+        a specific priority hierarchy to ensure consistent and predictable model selection.
+
+        Args:
+            start_time: The timestamp for which to find a model. Must be timezone-aware.
+                Models with start_time greater than this value will be excluded.
+            attributes: A mapping containing span attributes. Must include:
+                - SpanAttributes.LLM_MODEL_NAME: The name of the LLM model to match
+                - SpanAttributes.LLM_PROVIDER: (Optional) The provider of the LLM model
+
+        Raises:
+            TypeError: If start_time is not timezone-aware (tzinfo is None)
+
+        Returns:
+            The most appropriate GenerativeModel that matches the criteria, or None if no
+            suitable model is found.
+
+        Model Selection Logic:
+        1. **Input Validation**: Returns None if model name is empty or whitespace-only
+        2. **Time and Regex Filtering**: Only models that satisfy both conditions:
+           - start_time <= start_time or start_time=None (active models)
+           - name_pattern regex matches the model name from attributes
+        3. **Early Return Optimization**: If only one candidate remains, return it immediately
+        4. **Two-Tier Priority System**: Models are processed in tiers:
+           - User-defined models (is_built_in=False) are processed first
+           - Built-in models (is_built_in=True) are processed second
+           - If a tier has only one model, return it immediately
+        5. **Provider Filtering**: Within each tier, if provider is specified:
+           - Prefer models with matching provider
+           - Fall back to provider-agnostic models if no provider-specific matches exist
+        6. **Priority Selection**: Select the model with the highest priority tuple:
+           (regex_specificity_score, start_time.timestamp, tie_breaker)
+
+        Priority Tuple Components:
+        - regex_specificity_score: More specific regex patterns have higher priority
+        - start_time.timestamp: Models with later start times have higher priority
+        - tie_breaker: For built-in models, uses negative ID (lower IDs win);
+          for user-defined models, uses positive ID (higher IDs win)
+
+        Examples:
+            >>> lookup = CostModelLookup([model1, model2, model3])
+            >>> model = lookup.find_model(
+            ...     start_time=datetime(2024, 1, 1, tzinfo=timezone.utc),
+            ...     attributes={"llm": {"model_name": "gpt-3.5-turbo", "provider": "openai"}}
+            ... )
+        """  # noqa: E501
+        # 1. extract and validate inputs
+        if not is_timezone_aware(start_time):
+            raise TypeError("start_time must be timezone-aware")
+
+        model_name = str(
+            get_attribute_value(attributes, SpanAttributes.LLM_MODEL_NAME) or ""
+        ).strip()
+        if not model_name:
+            return None
+
+        # 2. only include models that are active and match the regex pattern
+        candidates = [
+            model
+            for model in self._models
+            if (not model.start_time or model.start_time <= start_time)
+            and model.name_pattern.match(model_name)
+        ]
+        if not candidates:
+            return None
+
+        # 3. early return: if only one candidate remains, return it
+        if len(candidates) == 1:
+            return candidates[0]
+
+        provider = str(get_attribute_value(attributes, SpanAttributes.LLM_PROVIDER) or "").strip()
+
+        # 4. priority-based selection: user-defined models first, then built-in models
+        for is_built_in in (False, True):  # False = user-defined, True = built-in
+            # get candidates for current tier (user-defined or built-in)
+            tier_candidates = [model for model in candidates if model.is_built_in == is_built_in]
+
+            if not tier_candidates:
+                continue  # try next tier
+
+            # early return: if only one candidate in this tier, return it
+            if len(tier_candidates) == 1:
+                return tier_candidates[0]
+
+            # 5. provider filtering: if provider specified, prefer provider-specific models
+            if provider:
+                provider_specific_models = [
+                    model
+                    for model in tier_candidates
+                    if model.provider and model.provider == provider
+                ]
+                # only use provider-specific models if any exist
+                # this allows fallback to provider-agnostic models when no match
+                if provider_specific_models:
+                    tier_candidates = provider_specific_models
+
+            # 6. select best model in this tier
+            return max(tier_candidates, key=lambda model: self._model_priority[model.id])
+
+        # 7. no suitable model found
+        return None
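The tier-then-priority selection can be illustrated with a stripped-down stand-in for models.GenerativeModel. The sketch below substitutes pattern length for the real regex_specificity.score (only a rough approximation of that module) and skips the time and provider filters:

# Illustrative sketch of the tiered priority selection; Model is a stand-in.
import re
from dataclasses import dataclass


@dataclass
class Model:
    id: int
    name_pattern: re.Pattern
    is_built_in: bool


catalog = [
    Model(1, re.compile(r"gpt-4.*"), is_built_in=True),      # broad built-in
    Model(2, re.compile(r"gpt-4o-mini"), is_built_in=True),  # specific built-in
    Model(3, re.compile(r"gpt-4o.*"), is_built_in=False),    # user-defined override
]

name = "gpt-4o-mini"
candidates = [m for m in catalog if m.name_pattern.match(name)]  # all three match

# The user-defined tier is tried first; built-ins are reached only if no
# user-defined model matched.
for is_built_in in (False, True):
    tier = [m for m in candidates if m.is_built_in == is_built_in]
    if not tier:
        continue
    # Higher specificity wins; the id tie-breaker is negated for built-ins so
    # earlier built-in ids win, mirroring the diff above.
    best = max(tier, key=lambda m: (len(m.name_pattern.pattern),
                                    -m.id if m.is_built_in else m.id))
    print(best.id)  # -> 3: the user-defined override beats both built-ins
    break

Note that the broad built-in pattern never gets a chance here: tiering takes precedence over specificity, so any matching user-defined model shadows every built-in.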
phoenix/server/cost_tracking/helpers.py ADDED

@@ -0,0 +1,68 @@
+import logging
+from typing import Any, Mapping
+
+from openinference.semconv.trace import SpanAttributes
+from typing_extensions import TypeAlias
+
+from phoenix.trace.attributes import get_attribute_value
+
+logger = logging.getLogger(__name__)
+
+_PromptTokens: TypeAlias = int
+_CompletionTokens: TypeAlias = int
+_TotalTokens: TypeAlias = int
+
+
+def get_aggregated_tokens(
+    attributes: Mapping[str, Any],
+) -> tuple[_PromptTokens, _CompletionTokens, _TotalTokens]:
+    """Return the total, prompt, and completion token counts from the span attributes."""
+    try:
+        prompt_tokens_value = get_attribute_value(
+            attributes,
+            SpanAttributes.LLM_TOKEN_COUNT_PROMPT,
+        )
+        prompt_tokens: int = (
+            0
+            if not isinstance(prompt_tokens_value, (int, float))
+            else max(0, int(prompt_tokens_value))
+        )
+
+        completion_tokens_value = get_attribute_value(
+            attributes,
+            SpanAttributes.LLM_TOKEN_COUNT_COMPLETION,
+        )
+        completion_tokens: int = (
+            0
+            if not isinstance(completion_tokens_value, (int, float))
+            else max(0, int(completion_tokens_value))
+        )
+
+        total_tokens_value = get_attribute_value(
+            attributes,
+            SpanAttributes.LLM_TOKEN_COUNT_TOTAL,
+        )
+        total_tokens: int = (
+            0
+            if not isinstance(total_tokens_value, (int, float))
+            else max(0, int(total_tokens_value))
+        )
+
+        assert prompt_tokens >= 0
+        assert completion_tokens >= 0
+        assert total_tokens >= 0
+
+        calculated_total = prompt_tokens + completion_tokens
+
+        if total_tokens > calculated_total:
+            if not prompt_tokens:
+                prompt_tokens = total_tokens - completion_tokens
+            else:
+                completion_tokens = total_tokens - prompt_tokens
+        else:
+            total_tokens = calculated_total
+
+        return prompt_tokens, completion_tokens, total_tokens
+    except Exception as e:
+        logger.error(f"Error getting aggregated tokens: {e}")
+        return 0, 0, 0