arize-phoenix 10.14.0__py3-none-any.whl → 11.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of arize-phoenix has been flagged as potentially problematic.

Files changed (84)
  1. {arize_phoenix-10.14.0.dist-info → arize_phoenix-11.0.0.dist-info}/METADATA +3 -2
  2. {arize_phoenix-10.14.0.dist-info → arize_phoenix-11.0.0.dist-info}/RECORD +82 -50
  3. phoenix/config.py +5 -2
  4. phoenix/datetime_utils.py +8 -1
  5. phoenix/db/bulk_inserter.py +40 -1
  6. phoenix/db/facilitator.py +263 -4
  7. phoenix/db/insertion/helpers.py +15 -0
  8. phoenix/db/insertion/span.py +3 -1
  9. phoenix/db/migrations/versions/a20694b15f82_cost.py +196 -0
  10. phoenix/db/models.py +267 -9
  11. phoenix/db/types/model_provider.py +1 -0
  12. phoenix/db/types/token_price_customization.py +29 -0
  13. phoenix/server/api/context.py +38 -4
  14. phoenix/server/api/dataloaders/__init__.py +41 -5
  15. phoenix/server/api/dataloaders/last_used_times_by_generative_model_id.py +35 -0
  16. phoenix/server/api/dataloaders/span_cost_by_span.py +24 -0
  17. phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_generative_model.py +56 -0
  18. phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_project_session.py +57 -0
  19. phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_span.py +43 -0
  20. phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_trace.py +56 -0
  21. phoenix/server/api/dataloaders/span_cost_details_by_span_cost.py +27 -0
  22. phoenix/server/api/dataloaders/span_cost_summary_by_experiment.py +58 -0
  23. phoenix/server/api/dataloaders/span_cost_summary_by_experiment_run.py +58 -0
  24. phoenix/server/api/dataloaders/span_cost_summary_by_generative_model.py +55 -0
  25. phoenix/server/api/dataloaders/span_cost_summary_by_project.py +140 -0
  26. phoenix/server/api/dataloaders/span_cost_summary_by_project_session.py +56 -0
  27. phoenix/server/api/dataloaders/span_cost_summary_by_trace.py +55 -0
  28. phoenix/server/api/dataloaders/span_costs.py +35 -0
  29. phoenix/server/api/dataloaders/types.py +29 -0
  30. phoenix/server/api/helpers/playground_clients.py +562 -12
  31. phoenix/server/api/helpers/prompts/conversions/aws.py +83 -0
  32. phoenix/server/api/helpers/prompts/models.py +67 -0
  33. phoenix/server/api/input_types/GenerativeModelInput.py +2 -0
  34. phoenix/server/api/input_types/ProjectSessionSort.py +3 -0
  35. phoenix/server/api/input_types/SpanSort.py +17 -0
  36. phoenix/server/api/mutations/__init__.py +2 -0
  37. phoenix/server/api/mutations/chat_mutations.py +17 -0
  38. phoenix/server/api/mutations/model_mutations.py +208 -0
  39. phoenix/server/api/queries.py +82 -41
  40. phoenix/server/api/routers/v1/traces.py +11 -4
  41. phoenix/server/api/subscriptions.py +36 -2
  42. phoenix/server/api/types/CostBreakdown.py +15 -0
  43. phoenix/server/api/types/Experiment.py +59 -1
  44. phoenix/server/api/types/ExperimentRun.py +58 -4
  45. phoenix/server/api/types/GenerativeModel.py +143 -2
  46. phoenix/server/api/types/GenerativeProvider.py +33 -20
  47. phoenix/server/api/types/{Model.py → InferenceModel.py} +1 -1
  48. phoenix/server/api/types/ModelInterface.py +11 -0
  49. phoenix/server/api/types/PlaygroundModel.py +10 -0
  50. phoenix/server/api/types/Project.py +42 -0
  51. phoenix/server/api/types/ProjectSession.py +44 -0
  52. phoenix/server/api/types/Span.py +137 -0
  53. phoenix/server/api/types/SpanCostDetailSummaryEntry.py +10 -0
  54. phoenix/server/api/types/SpanCostSummary.py +10 -0
  55. phoenix/server/api/types/TokenPrice.py +16 -0
  56. phoenix/server/api/types/TokenUsage.py +3 -3
  57. phoenix/server/api/types/Trace.py +41 -0
  58. phoenix/server/app.py +59 -0
  59. phoenix/server/cost_tracking/cost_details_calculator.py +190 -0
  60. phoenix/server/cost_tracking/cost_model_lookup.py +151 -0
  61. phoenix/server/cost_tracking/helpers.py +68 -0
  62. phoenix/server/cost_tracking/model_cost_manifest.json +59 -329
  63. phoenix/server/cost_tracking/regex_specificity.py +397 -0
  64. phoenix/server/cost_tracking/token_cost_calculator.py +57 -0
  65. phoenix/server/daemons/__init__.py +0 -0
  66. phoenix/server/daemons/generative_model_store.py +51 -0
  67. phoenix/server/daemons/span_cost_calculator.py +103 -0
  68. phoenix/server/dml_event_handler.py +1 -0
  69. phoenix/server/static/.vite/manifest.json +36 -36
  70. phoenix/server/static/assets/components-BnK9kodr.js +5055 -0
  71. phoenix/server/static/assets/{index-qiubV_74.js → index-S3YKLmbo.js} +13 -13
  72. phoenix/server/static/assets/{pages-C4V07ozl.js → pages-BW6PBHZb.js} +809 -417
  73. phoenix/server/static/assets/{vendor-Bfsiga8H.js → vendor-DqQvHbPa.js} +147 -147
  74. phoenix/server/static/assets/{vendor-arizeai-CQOWsrzm.js → vendor-arizeai-CLX44PFA.js} +1 -1
  75. phoenix/server/static/assets/{vendor-codemirror-CrcGVhB2.js → vendor-codemirror-Du3XyJnB.js} +1 -1
  76. phoenix/server/static/assets/{vendor-recharts-Yyg3G-Rq.js → vendor-recharts-B2PJDrnX.js} +25 -25
  77. phoenix/server/static/assets/{vendor-shiki-OPjag7Hm.js → vendor-shiki-CNbrFjf9.js} +1 -1
  78. phoenix/version.py +1 -1
  79. phoenix/server/cost_tracking/cost_lookup.py +0 -255
  80. phoenix/server/static/assets/components-CUUWyAMo.js +0 -4509
  81. {arize_phoenix-10.14.0.dist-info → arize_phoenix-11.0.0.dist-info}/WHEEL +0 -0
  82. {arize_phoenix-10.14.0.dist-info → arize_phoenix-11.0.0.dist-info}/entry_points.txt +0 -0
  83. {arize_phoenix-10.14.0.dist-info → arize_phoenix-11.0.0.dist-info}/licenses/IP_NOTICE +0 -0
  84. {arize_phoenix-10.14.0.dist-info → arize_phoenix-11.0.0.dist-info}/licenses/LICENSE +0 -0
phoenix/server/api/types/Trace.py CHANGED
@@ -14,6 +14,7 @@ from typing_extensions import TypeAlias
 from phoenix.db import models
 from phoenix.server.api.context import Context
 from phoenix.server.api.input_types.TraceAnnotationSort import TraceAnnotationSort
+from phoenix.server.api.types.CostBreakdown import CostBreakdown
 from phoenix.server.api.types.pagination import (
     ConnectionArgs,
     CursorString,
@@ -21,6 +22,8 @@ from phoenix.server.api.types.pagination import (
 )
 from phoenix.server.api.types.SortDir import SortDir
 from phoenix.server.api.types.Span import Span
+from phoenix.server.api.types.SpanCostDetailSummaryEntry import SpanCostDetailSummaryEntry
+from phoenix.server.api.types.SpanCostSummary import SpanCostSummary
 from phoenix.server.api.types.TraceAnnotation import TraceAnnotation, to_gql_trace_annotation

 if TYPE_CHECKING:
@@ -226,6 +229,44 @@ class Trace(Node):
             annotations = await session.scalars(stmt)
         return [to_gql_trace_annotation(annotation) for annotation in annotations]

+    @strawberry.field
+    async def cost_summary(
+        self,
+        info: Info[Context, None],
+    ) -> SpanCostSummary:
+        loader = info.context.data_loaders.span_cost_summary_by_trace
+        summary = await loader.load(self.trace_rowid)
+        return SpanCostSummary(
+            prompt=CostBreakdown(
+                tokens=summary.prompt.tokens,
+                cost=summary.prompt.cost,
+            ),
+            completion=CostBreakdown(
+                tokens=summary.completion.tokens,
+                cost=summary.completion.cost,
+            ),
+            total=CostBreakdown(
+                tokens=summary.total.tokens,
+                cost=summary.total.cost,
+            ),
+        )
+
+    @strawberry.field
+    async def cost_detail_summary_entries(
+        self,
+        info: Info[Context, None],
+    ) -> list[SpanCostDetailSummaryEntry]:
+        loader = info.context.data_loaders.span_cost_detail_summary_entries_by_trace
+        entries = await loader.load(self.trace_rowid)
+        return [
+            SpanCostDetailSummaryEntry(
+                token_type=entry.token_type,
+                is_prompt=entry.is_prompt,
+                value=CostBreakdown(tokens=entry.value.tokens, cost=entry.value.cost),
+            )
+            for entry in entries
+        ]
+

 INPUT_VALUE = SpanAttributes.INPUT_VALUE.split(".")
 OUTPUT_VALUE = SpanAttributes.OUTPUT_VALUE.split(".")
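
Note (not part of the diff): the two resolvers above copy a loader result onto nested GraphQL value types, and the by-trace loader conceptually returns the element-wise sum of per-span cost summaries. A minimal standalone sketch of that aggregation, using plain dataclasses as hypothetical stand-ins for the strawberry types:

    from dataclasses import dataclass

    @dataclass
    class CostBreakdown:
        tokens: float = 0.0
        cost: float = 0.0

    @dataclass
    class SpanCostSummary:
        prompt: CostBreakdown
        completion: CostBreakdown
        total: CostBreakdown

    def combine(a: CostBreakdown, b: CostBreakdown) -> CostBreakdown:
        return CostBreakdown(a.tokens + b.tokens, a.cost + b.cost)

    def trace_summary(span_summaries: list[SpanCostSummary]) -> SpanCostSummary:
        # element-wise sum of the per-span summaries for all spans in a trace
        out = SpanCostSummary(CostBreakdown(), CostBreakdown(), CostBreakdown())
        for s in span_summaries:
            out = SpanCostSummary(
                combine(out.prompt, s.prompt),
                combine(out.completion, s.completion),
                combine(out.total, s.total),
            )
        return out
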
phoenix/server/app.py CHANGED
@@ -85,6 +85,7 @@ from phoenix.server.api.dataloaders import (
     ExperimentRunAnnotations,
     ExperimentRunCountsDataLoader,
     ExperimentSequenceNumberDataLoader,
+    LastUsedTimesByGenerativeModelIdDataLoader,
     LatencyMsQuantileDataLoader,
     MinStartOrMaxEndTimeDataLoader,
     NumChildSpansDataLoader,
@@ -100,6 +101,18 @@ from phoenix.server.api.dataloaders import (
     SessionTraceLatencyMsQuantileDataLoader,
     SpanAnnotationsDataLoader,
     SpanByIdDataLoader,
+    SpanCostBySpanDataLoader,
+    SpanCostDetailsBySpanCostDataLoader,
+    SpanCostDetailSummaryEntriesByGenerativeModelDataLoader,
+    SpanCostDetailSummaryEntriesByProjectSessionDataLoader,
+    SpanCostDetailSummaryEntriesBySpanDataLoader,
+    SpanCostDetailSummaryEntriesByTraceDataLoader,
+    SpanCostSummaryByExperimentDataLoader,
+    SpanCostSummaryByExperimentRunDataLoader,
+    SpanCostSummaryByGenerativeModelDataLoader,
+    SpanCostSummaryByProjectDataLoader,
+    SpanCostSummaryByProjectSessionDataLoader,
+    SpanCostSummaryByTraceDataLoader,
     SpanDatasetExamplesDataLoader,
     SpanDescendantsDataLoader,
     SpanProjectsDataLoader,
@@ -120,6 +133,8 @@ from phoenix.server.api.routers import (
 from phoenix.server.api.routers.v1 import REST_API_VERSION
 from phoenix.server.api.schema import build_graphql_schema
 from phoenix.server.bearer_auth import BearerTokenAuthBackend, is_authenticated
+from phoenix.server.daemons.generative_model_store import GenerativeModelStore
+from phoenix.server.daemons.span_cost_calculator import SpanCostCalculator
 from phoenix.server.dml_event import DmlEvent
 from phoenix.server.dml_event_handler import DmlEventHandler
 from phoenix.server.email.types import EmailSender
@@ -502,6 +517,8 @@ def _lifespan(
     bulk_inserter: BulkInserter,
     dml_event_handler: DmlEventHandler,
     trace_data_sweeper: Optional[TraceDataSweeper],
+    span_cost_calculator: SpanCostCalculator,
+    generative_model_store: GenerativeModelStore,
     token_store: Optional[TokenStore] = None,
     tracer_provider: Optional["TracerProvider"] = None,
     enable_prometheus: bool = False,
@@ -536,6 +553,8 @@
         await stack.enter_async_context(dml_event_handler)
         if trace_data_sweeper:
             await stack.enter_async_context(trace_data_sweeper)
+        await stack.enter_async_context(span_cost_calculator)
+        await stack.enter_async_context(generative_model_store)
         if scaffolder_config:
             scaffolder = Scaffolder(
                 config=scaffolder_config,
@@ -583,6 +602,7 @@ def create_graphql_router(
     export_path: Path,
     last_updated_at: CanGetLastUpdatedAt,
     authentication_enabled: bool,
+    span_cost_calculator: SpanCostCalculator,
     corpus: Optional[Model] = None,
     cache_for_dataloaders: Optional[CacheForDataLoaders] = None,
     event_queue: CanPutItem[DmlEvent],
@@ -600,6 +620,7 @@ def create_graphql_router(
         export_path (Path): the file path to export data to for download (legacy)
         last_updated_at (CanGetLastUpdatedAt): How to get the last updated timestamp for updates.
         authentication_enabled (bool): Whether authentication is enabled.
+        span_cost_calculator (SpanCostCalculator): The span cost calculator for calculating costs.
         event_queue (CanPutItem[DmlEvent]): The event queue for DML events.
         corpus (Optional[Model], optional): the corpus for UMAP projection. Defaults to None.
         cache_for_dataloaders (Optional[CacheForDataLoaders], optional): GraphQL data loaders.
@@ -645,6 +666,9 @@
             experiment_run_annotations=ExperimentRunAnnotations(db),
             experiment_run_counts=ExperimentRunCountsDataLoader(db),
             experiment_sequence_number=ExperimentSequenceNumberDataLoader(db),
+            last_used_times_by_generative_model_id=LastUsedTimesByGenerativeModelIdDataLoader(
+                db
+            ),
             latency_ms_quantile=LatencyMsQuantileDataLoader(
                 db,
                 cache_map=(
@@ -679,6 +703,31 @@
             span_annotations=SpanAnnotationsDataLoader(db),
             span_fields=TableFieldsDataLoader(db, models.Span),
             span_by_id=SpanByIdDataLoader(db),
+            span_cost_by_span=SpanCostBySpanDataLoader(db),
+            span_cost_detail_summary_entries_by_generative_model=SpanCostDetailSummaryEntriesByGenerativeModelDataLoader(
+                db
+            ),
+            span_cost_detail_summary_entries_by_project_session=SpanCostDetailSummaryEntriesByProjectSessionDataLoader(
+                db
+            ),
+            span_cost_detail_summary_entries_by_span=SpanCostDetailSummaryEntriesBySpanDataLoader(
+                db
+            ),
+            span_cost_detail_summary_entries_by_trace=SpanCostDetailSummaryEntriesByTraceDataLoader(
+                db
+            ),
+            span_cost_details_by_span_cost=SpanCostDetailsBySpanCostDataLoader(db),
+            span_cost_detail_fields=TableFieldsDataLoader(db, models.SpanCostDetail),
+            span_cost_fields=TableFieldsDataLoader(db, models.SpanCost),
+            span_cost_summary_by_generative_model=SpanCostSummaryByGenerativeModelDataLoader(
+                db
+            ),
+            span_cost_summary_by_project=SpanCostSummaryByProjectDataLoader(
+                db,
+                cache_map=cache_for_dataloaders.token_cost if cache_for_dataloaders else None,
+            ),
+            span_cost_summary_by_project_session=SpanCostSummaryByProjectSessionDataLoader(db),
+            span_cost_summary_by_trace=SpanCostSummaryByTraceDataLoader(db),
             span_dataset_examples=SpanDatasetExamplesDataLoader(db),
             span_descendants=SpanDescendantsDataLoader(db),
             span_projects=SpanProjectsDataLoader(db),
@@ -698,6 +747,8 @@
             project_by_name=ProjectByNameDataLoader(db),
             users=UsersDataLoader(db),
             user_roles=UserRolesDataLoader(db),
+            span_cost_summary_by_experiment=SpanCostSummaryByExperimentDataLoader(db),
+            span_cost_summary_by_experiment_run=SpanCostSummaryByExperimentRunDataLoader(db),
         ),
         cache_for_dataloaders=cache_for_dataloaders,
         read_only=read_only,
@@ -705,6 +756,7 @@
         secret=secret,
         token_store=token_store,
        email_sender=email_sender,
+        span_cost_calculator=span_cost_calculator,
     )

     return GraphQLRouter(
@@ -860,9 +912,12 @@ def create_app(
         db=db,
         dml_event_handler=dml_event_handler,
     )
+    generative_model_store = GenerativeModelStore(db)
+    span_cost_calculator = SpanCostCalculator(db, generative_model_store)
     bulk_inserter = bulk_inserter_factory(
         db,
         enable_prometheus=enable_prometheus,
+        span_cost_calculator=span_cost_calculator,
         event_queue=dml_event_handler,
         initial_batch_of_spans=initial_batch_of_spans,
         initial_batch_of_evaluations=initial_batch_of_evaluations,
@@ -904,6 +959,7 @@
         secret=secret,
         token_store=token_store,
         email_sender=email_sender,
+        span_cost_calculator=span_cost_calculator,
     )
     if enable_prometheus:
         from phoenix.server.prometheus import PrometheusMiddleware
@@ -918,6 +974,8 @@
         bulk_inserter=bulk_inserter,
         dml_event_handler=dml_event_handler,
         trace_data_sweeper=trace_data_sweeper,
+        span_cost_calculator=span_cost_calculator,
+        generative_model_store=generative_model_store,
         token_store=token_store,
         tracer_provider=tracer_provider,
         enable_prometheus=enable_prometheus,
@@ -981,6 +1039,7 @@
     app.state.oauth2_clients = OAuth2Clients.from_configs(oauth2_client_configs or [])
     app.state.db = db
     app.state.email_sender = email_sender
+    app.state.span_cost_calculator = span_cost_calculator
    app = _add_get_secret_method(app=app, secret=secret)
    app = _add_get_token_store_method(app=app, token_store=token_store)
    if tracer_provider:
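
Note (not part of the diff): SpanCostCalculator and GenerativeModelStore are entered as async context managers during the application lifespan (see the `_lifespan` hunk above), i.e. background daemons tied to server startup and shutdown. A minimal sketch of that pattern with a hypothetical daemon, not the package's actual implementation:

    import asyncio
    from typing import Optional

    class PeriodicDaemon:
        # hypothetical stand-in: start a background task on __aenter__,
        # cancel it on __aexit__
        def __init__(self, interval_seconds: float = 60.0) -> None:
            self._interval = interval_seconds
            self._task: Optional["asyncio.Task[None]"] = None

        async def _run(self) -> None:
            while True:
                # e.g. refresh cached cost models from the database
                await asyncio.sleep(self._interval)

        async def __aenter__(self) -> "PeriodicDaemon":
            self._task = asyncio.create_task(self._run())
            return self

        async def __aexit__(self, *exc_info: object) -> None:
            if self._task is not None:
                self._task.cancel()

    async def main() -> None:
        async with PeriodicDaemon(interval_seconds=1.0):
            await asyncio.sleep(0.1)  # daemon runs for the duration of the block

    asyncio.run(main())
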

phoenix/server/cost_tracking/cost_details_calculator.py ADDED
@@ -0,0 +1,190 @@
+from itertools import chain
+from typing import Any, Iterable, Mapping
+
+from typing_extensions import TypeAlias
+
+from phoenix.db import models
+from phoenix.server.cost_tracking.helpers import get_aggregated_tokens
+from phoenix.server.cost_tracking.token_cost_calculator import (
+    TokenCostCalculator,
+    create_token_cost_calculator,
+)
+from phoenix.trace.attributes import get_attribute_value
+
+_TokenType: TypeAlias = str
+
+
+class SpanCostDetailsCalculator:
+    """
+    Calculates detailed cost breakdowns for LLM spans based on token usage and pricing.
+
+    This calculator processes both detailed token counts (from span attributes) and
+    aggregated token totals to provide comprehensive cost analysis for prompt and
+    completion tokens. It handles multiple token types (e.g., "input", "output",
+    "image", "audio", "video", "document", "reasoning", etc.) and calculates costs
+    using configured pricing models with fallback behavior.
+
+    **Fallback Behavior:**
+    - If a specific token type has a configured calculator, it uses that calculator
+    - If no specific calculator exists, it falls back to the default calculator:
+      - Prompt tokens (is_prompt=True) fall back to the "input" calculator
+      - Completion tokens (is_prompt=False) fall back to the "output" calculator
+
+    This ensures all token types get cost calculations even if not explicitly configured.
+
+    The calculator expects token prices to include at least:
+    - An "input" token type for prompt tokens (used as fallback for unconfigured prompt token types)
+    - An "output" token type for completion tokens (used as fallback for unconfigured completion token types)
+
+    Additional token types can be configured for more granular cost tracking.
+    """  # noqa: E501
+
+    def __init__(
+        self,
+        prices: Iterable[models.TokenPrice],
+    ) -> None:
+        """
+        Initialize the cost calculator with token pricing configuration.
+
+        Args:
+            prices: Collection of token price configurations defining rates for
+                different token types and whether they're prompt or completion tokens.
+
+        Raises:
+            ValueError: If the required "input" (prompt) or "output" (completion)
+                token types are missing from the pricing configuration.
+        """
+        # Create calculators for prompt token types (is_prompt=True)
+        self._prompt: Mapping[_TokenType, TokenCostCalculator] = {
+            p.token_type: create_token_cost_calculator(p.base_rate, p.customization)
+            for p in prices
+            if p.is_prompt
+        }
+        if "input" not in self._prompt:
+            raise ValueError("Token prices for prompt must include an 'input' token type")
+
+        # Create calculators for completion token types (is_prompt=False)
+        self._completion: Mapping[_TokenType, TokenCostCalculator] = {
+            p.token_type: create_token_cost_calculator(p.base_rate, p.customization)
+            for p in prices
+            if not p.is_prompt
+        }
+        if "output" not in self._completion:
+            raise ValueError("Token prices for completion must include an 'output' token type")
+
+    def calculate_details(
+        self,
+        attributes: Mapping[str, Any],
+    ) -> list[models.SpanCostDetail]:
+        """
+        Calculate detailed cost breakdown for a given span.
+
+        This method processes token usage in two phases:
+
+        1. **Detailed token processing**: Extracts specific token counts from span attributes
+           (e.g., "llm.token_count.prompt_details", "llm.token_count.completion_details")
+           and calculates costs for each token type found. Uses fallback behavior for
+           token types without specific calculators.
+
+        2. **Aggregated token processing**: For default token types ("input"/"output") that
+           weren't found in detailed processing, calculates remaining tokens by subtracting
+           detailed counts from total aggregated tokens.
+
+        **Fallback Calculation Logic:**
+        - For each token type in detailed processing:
+          - If a specific calculator exists for the token type, use it
+          - Otherwise, fall back to the default calculator ("input" for prompt tokens,
+            "output" for completion tokens)
+        - This ensures all token types receive cost calculations regardless of
+          specific calculator configuration
+
+        Args:
+            attributes: Dictionary containing span attributes with token usage data.
+
+        Returns:
+            List of SpanCostDetail objects containing token counts, costs, and cost-per-token
+            for each token type found in the span.
+
+        Note:
+            - Token counts are validated and converted to non-negative integers
+            - All token types receive cost calculations via the fallback mechanism
+            - Cost-per-token is None when the token count is zero; in the aggregated
+              phase it is also None when the cost is 0.0 (falsy evaluation)
+        """
+        prompt_details: dict[_TokenType, models.SpanCostDetail] = {}
+        completion_details: dict[_TokenType, models.SpanCostDetail] = {}
+
+        # Phase 1: Process detailed token counts from span attributes
+        for is_prompt, prefix, calculators, results in (
+            (True, "prompt", self._prompt, prompt_details),
+            (False, "completion", self._completion, completion_details),
+        ):
+            # Extract detailed token counts from span attributes
+            details = get_attribute_value(attributes, f"llm.token_count.{prefix}_details")
+            if isinstance(details, dict) and details:
+                for token_type, token_count in details.items():
+                    # Validate token count is numeric
+                    if not isinstance(token_count, (int, float)):
+                        continue
+                    tokens = max(0, int(token_count))
+
+                    # Calculate cost using a specific calculator or fall back to the default
+                    if token_type in calculators:
+                        # Use the specific calculator for this token type
+                        calculator = calculators[token_type]
+                    else:
+                        # Fall back to the default calculator: "input" for prompts,
+                        # "output" for completions
+                        key = "input" if is_prompt else "output"
+                        calculator = calculators[key]
+                    cost = calculator.calculate_cost(attributes, tokens)
+
+                    # Calculate cost per token (avoid division by zero)
+                    cost_per_token = cost / tokens if tokens else None
+
+                    detail = models.SpanCostDetail(
+                        token_type=token_type,
+                        is_prompt=is_prompt,
+                        tokens=tokens,
+                        cost=cost,
+                        cost_per_token=cost_per_token,
+                    )
+                    results[token_type] = detail
+
+        # Get aggregated token totals for fallback calculations
+        prompt_tokens, completion_tokens, _ = get_aggregated_tokens(attributes)
+
+        # Phase 2: Process remaining tokens for default token types
+        for is_prompt, token_type, total, calculators, results in (
+            (True, "input", prompt_tokens, self._prompt, prompt_details),
+            (False, "output", completion_tokens, self._completion, completion_details),
+        ):
+            # Skip if this token type was already processed in the detailed phase
+            if token_type in results:
+                continue
+
+            # Calculate remaining tokens by subtracting detailed counts from the total
+            tokens = total - sum(
+                int(d.tokens or 0) for d in results.values() if d.is_prompt == is_prompt
+            )
+
+            # Skip if no remaining tokens or negative (shouldn't happen with valid data)
+            if tokens <= 0:
+                continue
+
+            # Calculate cost using the guaranteed default calculator ("input"/"output" are required)
+            cost = calculators[token_type].calculate_cost(attributes, tokens)
+
+            # Calculate cost per token (avoid division by zero)
+            cost_per_token = cost / tokens if cost and tokens else None
+
+            detail = models.SpanCostDetail(
+                token_type=token_type,
+                is_prompt=is_prompt,
+                tokens=tokens,
+                cost=cost,
+                cost_per_token=cost_per_token,
+            )
+            results[token_type] = detail
+
+        # Return combined results from both prompt and completion processing
+        return list(chain(prompt_details.values(), completion_details.values()))
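
Note (not part of the diff): the fallback rule in SpanCostDetailsCalculator reduces to a dictionary lookup with a required default key. A minimal sketch with hypothetical per-token rates standing in for the TokenCostCalculator objects:

    # is_prompt=True rates; "input" is the required default
    prompt_rates = {"input": 2.5e-6, "audio": 1.0e-5}
    # is_prompt=False rates; "output" is the required default
    completion_rates = {"output": 1.0e-5}

    def detail_cost(token_type: str, tokens: int, is_prompt: bool) -> float:
        rates = prompt_rates if is_prompt else completion_rates
        # fall back to the default rate when the token type is unconfigured
        rate = rates.get(token_type, rates["input" if is_prompt else "output"])
        return tokens * rate

    # "reasoning" has no completion-side rate, so it falls back to "output"
    assert detail_cost("reasoning", 100, is_prompt=False) == 100 * 1.0e-5
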

phoenix/server/cost_tracking/cost_model_lookup.py ADDED
@@ -0,0 +1,151 @@
+import re
+from datetime import datetime
+from typing import Any, Iterable, Mapping, Optional
+
+from openinference.semconv.trace import SpanAttributes
+from typing_extensions import TypeAlias
+
+from phoenix.datetime_utils import is_timezone_aware
+from phoenix.db import models
+from phoenix.server.cost_tracking import regex_specificity
+from phoenix.trace.attributes import get_attribute_value
+
+_RegexPatternStr: TypeAlias = str
+_RegexSpecificityScore: TypeAlias = int
+_TieBreakerId: TypeAlias = int
+
+
+class CostModelLookup:
+    def __init__(
+        self,
+        generative_models: Iterable[models.GenerativeModel] = (),
+    ) -> None:
+        self._models = tuple(generative_models)
+        self._model_priority: dict[
+            int, tuple[_RegexSpecificityScore, float, _TieBreakerId]
+        ] = {}  # higher is better
+        self._regex_specificity_score: dict[re.Pattern[str], _RegexSpecificityScore] = {}
+
+        for m in self._models:
+            self._regex_specificity_score[m.name_pattern] = regex_specificity.score(m.name_pattern)
+
+            # For built-in models, use negative ID so that earlier IDs win;
+            # for user-defined models, use positive ID so that later IDs win
+            tie_breaker = -m.id if m.is_built_in else m.id
+
+            self._model_priority[m.id] = (
+                self._regex_specificity_score[m.name_pattern],
+                m.start_time.timestamp() if m.start_time else 0.0,
+                tie_breaker,
+            )
+
+    def find_model(
+        self,
+        start_time: datetime,
+        attributes: Mapping[str, Any],
+    ) -> Optional[models.GenerativeModel]:
+        """
+        Find the most appropriate generative model for cost tracking based on attributes and time.
+
+        This method implements a model lookup that filters and prioritizes generative
+        models based on the provided attributes and timestamp. The lookup follows
+        a specific priority hierarchy to ensure consistent and predictable model selection.
+
+        Args:
+            start_time: The timestamp for which to find a model. Must be timezone-aware.
+                Models with start_time greater than this value will be excluded.
+            attributes: A mapping containing span attributes. Must include:
+                - SpanAttributes.LLM_MODEL_NAME: The name of the LLM model to match
+                - SpanAttributes.LLM_PROVIDER: (Optional) The provider of the LLM model
+
+        Raises:
+            TypeError: If start_time is not timezone-aware (tzinfo is None)
+
+        Returns:
+            The most appropriate GenerativeModel that matches the criteria, or None if no
+            suitable model is found.
+
+        Model Selection Logic:
+        1. **Input Validation**: Returns None if the model name is empty or whitespace-only
+        2. **Time and Regex Filtering**: Only keep models that satisfy both conditions:
+           - model.start_time <= start_time, or model.start_time is None (active models)
+           - name_pattern regex matches the model name from attributes
+        3. **Early Return Optimization**: If only one candidate remains, return it immediately
+        4. **Two-Tier Priority System**: Models are processed in tiers:
+           - User-defined models (is_built_in=False) are processed first
+           - Built-in models (is_built_in=True) are processed second
+           - If a tier has only one model, return it immediately
+        5. **Provider Filtering**: Within each tier, if a provider is specified:
+           - Prefer models with a matching provider
+           - Fall back to provider-agnostic models if no provider-specific matches exist
+        6. **Priority Selection**: Select the model with the highest priority tuple:
+           (regex_specificity_score, start_time.timestamp, tie_breaker)
+
+        Priority Tuple Components:
+        - regex_specificity_score: More specific regex patterns have higher priority
+        - start_time.timestamp: Models with later start times have higher priority
+        - tie_breaker: For built-in models, uses negative ID (lower IDs win);
+          for user-defined models, uses positive ID (higher IDs win)
+
+        Examples:
+            >>> lookup = CostModelLookup([model1, model2, model3])
+            >>> model = lookup.find_model(
+            ...     start_time=datetime(2024, 1, 1, tzinfo=timezone.utc),
+            ...     attributes={"llm": {"model_name": "gpt-3.5-turbo", "provider": "openai"}}
+            ... )
+        """  # noqa: E501
+        # 1. extract and validate inputs
+        if not is_timezone_aware(start_time):
+            raise TypeError("start_time must be timezone-aware")
+
+        model_name = str(
+            get_attribute_value(attributes, SpanAttributes.LLM_MODEL_NAME) or ""
+        ).strip()
+        if not model_name:
+            return None
+
+        # 2. only include models that are active and match the regex pattern
+        candidates = [
+            model
+            for model in self._models
+            if (not model.start_time or model.start_time <= start_time)
+            and model.name_pattern.match(model_name)
+        ]
+        if not candidates:
+            return None
+
+        # 3. early return: if only one candidate remains, return it
+        if len(candidates) == 1:
+            return candidates[0]
+
+        provider = str(get_attribute_value(attributes, SpanAttributes.LLM_PROVIDER) or "").strip()
+
+        # 4. priority-based selection: user-defined models first, then built-in models
+        for is_built_in in (False, True):  # False = user-defined, True = built-in
+            # get candidates for the current tier (user-defined or built-in)
+            tier_candidates = [model for model in candidates if model.is_built_in == is_built_in]
+
+            if not tier_candidates:
+                continue  # try next tier
+
+            # early return: if only one candidate in this tier, return it
+            if len(tier_candidates) == 1:
+                return tier_candidates[0]
+
+            # 5. provider filtering: if a provider is specified, prefer provider-specific models
+            if provider:
+                provider_specific_models = [
+                    model
+                    for model in tier_candidates
+                    if model.provider and model.provider == provider
+                ]
+                # only use provider-specific models if any exist;
+                # this allows fallback to provider-agnostic models when there is no match
+                if provider_specific_models:
+                    tier_candidates = provider_specific_models
+
+            # 6. select the best model in this tier
+            return max(tier_candidates, key=lambda model: self._model_priority[model.id])
+
+        # 7. no suitable model found
+        return None
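
Note (not part of the diff): within a tier, selection is just max() over the priority tuples built in __init__, compared lexicographically. A toy illustration with hypothetical values:

    # (regex_specificity, start_time_ts, tie_breaker) per candidate; higher wins
    candidates = {
        "built-in gpt-4o, dated price":   (120, 1704067200.0, -3),
        "built-in gpt-4o, undated price": (120, 0.0, -1),
        "built-in gpt-4.* (less specific pattern)": (80, 0.0, -2),
    }
    best = max(candidates, key=lambda name: candidates[name])
    print(best)  # the dated, more specific entry wins on the second tuple element
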

phoenix/server/cost_tracking/helpers.py ADDED
@@ -0,0 +1,68 @@
+import logging
+from typing import Any, Mapping
+
+from openinference.semconv.trace import SpanAttributes
+from typing_extensions import TypeAlias
+
+from phoenix.trace.attributes import get_attribute_value
+
+logger = logging.getLogger(__name__)
+
+_PromptTokens: TypeAlias = int
+_CompletionTokens: TypeAlias = int
+_TotalTokens: TypeAlias = int
+
+
+def get_aggregated_tokens(
+    attributes: Mapping[str, Any],
+) -> tuple[_PromptTokens, _CompletionTokens, _TotalTokens]:
+    """Return the prompt, completion, and total token counts from the span attributes."""
+    try:
+        prompt_tokens_value = get_attribute_value(
+            attributes,
+            SpanAttributes.LLM_TOKEN_COUNT_PROMPT,
+        )
+        prompt_tokens: int = (
+            0
+            if not isinstance(prompt_tokens_value, (int, float))
+            else max(0, int(prompt_tokens_value))
+        )
+
+        completion_tokens_value = get_attribute_value(
+            attributes,
+            SpanAttributes.LLM_TOKEN_COUNT_COMPLETION,
+        )
+        completion_tokens: int = (
+            0
+            if not isinstance(completion_tokens_value, (int, float))
+            else max(0, int(completion_tokens_value))
+        )
+
+        total_tokens_value = get_attribute_value(
+            attributes,
+            SpanAttributes.LLM_TOKEN_COUNT_TOTAL,
+        )
+        total_tokens: int = (
+            0
+            if not isinstance(total_tokens_value, (int, float))
+            else max(0, int(total_tokens_value))
+        )
+
+        assert prompt_tokens >= 0
+        assert completion_tokens >= 0
+        assert total_tokens >= 0
+
+        calculated_total = prompt_tokens + completion_tokens
+
+        if total_tokens > calculated_total:
+            if not prompt_tokens:
+                prompt_tokens = total_tokens - completion_tokens
+            else:
+                completion_tokens = total_tokens - prompt_tokens
+        else:
+            total_tokens = calculated_total
+
+        return prompt_tokens, completion_tokens, total_tokens
+    except Exception as e:
+        logger.error(f"Error getting aggregated tokens: {e}")
+        return 0, 0, 0
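
Note (not part of the diff): the reconciliation rule above trusts a reported total that exceeds prompt + completion and lets the missing side absorb the difference; otherwise it recomputes the total from the parts. The same rule, run on a couple of hypothetical inputs:

    def reconcile(prompt: int, completion: int, total: int) -> tuple[int, int, int]:
        if total > prompt + completion:
            if not prompt:
                prompt = total - completion
            else:
                completion = total - prompt
        else:
            total = prompt + completion
        return prompt, completion, total

    print(reconcile(0, 40, 100))  # (60, 40, 100): prompt inferred from the total
    print(reconcile(60, 40, 0))   # (60, 40, 100): total recomputed from the parts
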