arize-phoenix 10.14.0__py3-none-any.whl → 11.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic. Click here for more details.

Files changed (84) hide show
  1. {arize_phoenix-10.14.0.dist-info → arize_phoenix-11.0.0.dist-info}/METADATA +3 -2
  2. {arize_phoenix-10.14.0.dist-info → arize_phoenix-11.0.0.dist-info}/RECORD +82 -50
  3. phoenix/config.py +5 -2
  4. phoenix/datetime_utils.py +8 -1
  5. phoenix/db/bulk_inserter.py +40 -1
  6. phoenix/db/facilitator.py +263 -4
  7. phoenix/db/insertion/helpers.py +15 -0
  8. phoenix/db/insertion/span.py +3 -1
  9. phoenix/db/migrations/versions/a20694b15f82_cost.py +196 -0
  10. phoenix/db/models.py +267 -9
  11. phoenix/db/types/model_provider.py +1 -0
  12. phoenix/db/types/token_price_customization.py +29 -0
  13. phoenix/server/api/context.py +38 -4
  14. phoenix/server/api/dataloaders/__init__.py +41 -5
  15. phoenix/server/api/dataloaders/last_used_times_by_generative_model_id.py +35 -0
  16. phoenix/server/api/dataloaders/span_cost_by_span.py +24 -0
  17. phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_generative_model.py +56 -0
  18. phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_project_session.py +57 -0
  19. phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_span.py +43 -0
  20. phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_trace.py +56 -0
  21. phoenix/server/api/dataloaders/span_cost_details_by_span_cost.py +27 -0
  22. phoenix/server/api/dataloaders/span_cost_summary_by_experiment.py +58 -0
  23. phoenix/server/api/dataloaders/span_cost_summary_by_experiment_run.py +58 -0
  24. phoenix/server/api/dataloaders/span_cost_summary_by_generative_model.py +55 -0
  25. phoenix/server/api/dataloaders/span_cost_summary_by_project.py +140 -0
  26. phoenix/server/api/dataloaders/span_cost_summary_by_project_session.py +56 -0
  27. phoenix/server/api/dataloaders/span_cost_summary_by_trace.py +55 -0
  28. phoenix/server/api/dataloaders/span_costs.py +35 -0
  29. phoenix/server/api/dataloaders/types.py +29 -0
  30. phoenix/server/api/helpers/playground_clients.py +562 -12
  31. phoenix/server/api/helpers/prompts/conversions/aws.py +83 -0
  32. phoenix/server/api/helpers/prompts/models.py +67 -0
  33. phoenix/server/api/input_types/GenerativeModelInput.py +2 -0
  34. phoenix/server/api/input_types/ProjectSessionSort.py +3 -0
  35. phoenix/server/api/input_types/SpanSort.py +17 -0
  36. phoenix/server/api/mutations/__init__.py +2 -0
  37. phoenix/server/api/mutations/chat_mutations.py +17 -0
  38. phoenix/server/api/mutations/model_mutations.py +208 -0
  39. phoenix/server/api/queries.py +82 -41
  40. phoenix/server/api/routers/v1/traces.py +11 -4
  41. phoenix/server/api/subscriptions.py +36 -2
  42. phoenix/server/api/types/CostBreakdown.py +15 -0
  43. phoenix/server/api/types/Experiment.py +59 -1
  44. phoenix/server/api/types/ExperimentRun.py +58 -4
  45. phoenix/server/api/types/GenerativeModel.py +143 -2
  46. phoenix/server/api/types/GenerativeProvider.py +33 -20
  47. phoenix/server/api/types/{Model.py → InferenceModel.py} +1 -1
  48. phoenix/server/api/types/ModelInterface.py +11 -0
  49. phoenix/server/api/types/PlaygroundModel.py +10 -0
  50. phoenix/server/api/types/Project.py +42 -0
  51. phoenix/server/api/types/ProjectSession.py +44 -0
  52. phoenix/server/api/types/Span.py +137 -0
  53. phoenix/server/api/types/SpanCostDetailSummaryEntry.py +10 -0
  54. phoenix/server/api/types/SpanCostSummary.py +10 -0
  55. phoenix/server/api/types/TokenPrice.py +16 -0
  56. phoenix/server/api/types/TokenUsage.py +3 -3
  57. phoenix/server/api/types/Trace.py +41 -0
  58. phoenix/server/app.py +59 -0
  59. phoenix/server/cost_tracking/cost_details_calculator.py +190 -0
  60. phoenix/server/cost_tracking/cost_model_lookup.py +151 -0
  61. phoenix/server/cost_tracking/helpers.py +68 -0
  62. phoenix/server/cost_tracking/model_cost_manifest.json +59 -329
  63. phoenix/server/cost_tracking/regex_specificity.py +397 -0
  64. phoenix/server/cost_tracking/token_cost_calculator.py +57 -0
  65. phoenix/server/daemons/__init__.py +0 -0
  66. phoenix/server/daemons/generative_model_store.py +51 -0
  67. phoenix/server/daemons/span_cost_calculator.py +103 -0
  68. phoenix/server/dml_event_handler.py +1 -0
  69. phoenix/server/static/.vite/manifest.json +36 -36
  70. phoenix/server/static/assets/components-BnK9kodr.js +5055 -0
  71. phoenix/server/static/assets/{index-qiubV_74.js → index-S3YKLmbo.js} +13 -13
  72. phoenix/server/static/assets/{pages-C4V07ozl.js → pages-BW6PBHZb.js} +809 -417
  73. phoenix/server/static/assets/{vendor-Bfsiga8H.js → vendor-DqQvHbPa.js} +147 -147
  74. phoenix/server/static/assets/{vendor-arizeai-CQOWsrzm.js → vendor-arizeai-CLX44PFA.js} +1 -1
  75. phoenix/server/static/assets/{vendor-codemirror-CrcGVhB2.js → vendor-codemirror-Du3XyJnB.js} +1 -1
  76. phoenix/server/static/assets/{vendor-recharts-Yyg3G-Rq.js → vendor-recharts-B2PJDrnX.js} +25 -25
  77. phoenix/server/static/assets/{vendor-shiki-OPjag7Hm.js → vendor-shiki-CNbrFjf9.js} +1 -1
  78. phoenix/version.py +1 -1
  79. phoenix/server/cost_tracking/cost_lookup.py +0 -255
  80. phoenix/server/static/assets/components-CUUWyAMo.js +0 -4509
  81. {arize_phoenix-10.14.0.dist-info → arize_phoenix-11.0.0.dist-info}/WHEEL +0 -0
  82. {arize_phoenix-10.14.0.dist-info → arize_phoenix-11.0.0.dist-info}/entry_points.txt +0 -0
  83. {arize_phoenix-10.14.0.dist-info → arize_phoenix-11.0.0.dist-info}/licenses/IP_NOTICE +0 -0
  84. {arize_phoenix-10.14.0.dist-info → arize_phoenix-11.0.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,3 +1,4 @@
1
+ import re
1
2
  from collections import defaultdict
2
3
  from datetime import datetime
3
4
  from typing import Iterable, Iterator, Optional, Union, cast
@@ -21,12 +22,6 @@ from phoenix.config import (
21
22
  from phoenix.db import models
22
23
  from phoenix.db.constants import DEFAULT_PROJECT_TRACE_RETENTION_POLICY_ID
23
24
  from phoenix.db.helpers import SupportedSQLDialect, exclude_experiment_projects
24
- from phoenix.db.models import DatasetExample as OrmExample
25
- from phoenix.db.models import DatasetExampleRevision as OrmRevision
26
- from phoenix.db.models import DatasetVersion as OrmVersion
27
- from phoenix.db.models import Experiment as OrmExperiment
28
- from phoenix.db.models import ExperimentRun as OrmExperimentRun
29
- from phoenix.db.models import Trace as OrmTrace
30
25
  from phoenix.pointcloud.clustering import Hdbscan
31
26
  from phoenix.server.api.auth import MSG_ADMIN_ONLY, IsAdmin
32
27
  from phoenix.server.api.context import Context
@@ -62,12 +57,13 @@ from phoenix.server.api.types.Experiment import Experiment
62
57
  from phoenix.server.api.types.ExperimentComparison import ExperimentComparison, RunComparisonItem
63
58
  from phoenix.server.api.types.ExperimentRun import ExperimentRun, to_gql_experiment_run
64
59
  from phoenix.server.api.types.Functionality import Functionality
65
- from phoenix.server.api.types.GenerativeModel import GenerativeModel
60
+ from phoenix.server.api.types.GenerativeModel import GenerativeModel, to_gql_generative_model
66
61
  from phoenix.server.api.types.GenerativeProvider import GenerativeProvider, GenerativeProviderKey
62
+ from phoenix.server.api.types.InferenceModel import InferenceModel
67
63
  from phoenix.server.api.types.InferencesRole import AncillaryInferencesRole, InferencesRole
68
- from phoenix.server.api.types.Model import Model
69
64
  from phoenix.server.api.types.node import from_global_id, from_global_id_with_expected_type
70
65
  from phoenix.server.api.types.pagination import ConnectionArgs, CursorString, connection_from_list
66
+ from phoenix.server.api.types.PlaygroundModel import PlaygroundModel
71
67
  from phoenix.server.api.types.Project import Project
72
68
  from phoenix.server.api.types.ProjectSession import ProjectSession, to_gql_project_session
73
69
  from phoenix.server.api.types.ProjectTraceRetentionPolicy import ProjectTraceRetentionPolicy
@@ -115,20 +111,39 @@ class Query:
115
111
  ]
116
112
 
117
113
  @strawberry.field
118
- async def models(self, input: Optional[ModelsInput] = None) -> list[GenerativeModel]:
114
+ async def generative_models(
115
+ self,
116
+ info: Info[Context, None],
117
+ ) -> list[GenerativeModel]:
118
+ async with info.context.db() as session:
119
+ result = await session.scalars(
120
+ select(models.GenerativeModel)
121
+ .where(models.GenerativeModel.deleted_at.is_(None))
122
+ .order_by(
123
+ models.GenerativeModel.is_built_in.asc(), # display custom models first
124
+ models.GenerativeModel.provider.nullslast(),
125
+ models.GenerativeModel.name,
126
+ )
127
+ .options(joinedload(models.GenerativeModel.token_prices))
128
+ )
129
+
130
+ return [to_gql_generative_model(model) for model in result.unique()]
131
+
132
+ @strawberry.field
133
+ async def playground_models(self, input: Optional[ModelsInput] = None) -> list[PlaygroundModel]:
119
134
  if input is not None and input.provider_key is not None:
120
135
  supported_model_names = PLAYGROUND_CLIENT_REGISTRY.list_models(input.provider_key)
121
136
  supported_models = [
122
- GenerativeModel(name=model_name, provider_key=input.provider_key)
137
+ PlaygroundModel(name=model_name, provider_key=input.provider_key)
123
138
  for model_name in supported_model_names
124
139
  ]
125
140
  return supported_models
126
141
 
127
142
  registered_models = PLAYGROUND_CLIENT_REGISTRY.list_all_models()
128
- all_models: list[GenerativeModel] = []
143
+ all_models: list[PlaygroundModel] = []
129
144
  for provider_key, model_name in registered_models:
130
145
  if model_name is not None and provider_key is not None:
131
- all_models.append(GenerativeModel(name=model_name, provider_key=provider_key))
146
+ all_models.append(PlaygroundModel(name=model_name, provider_key=provider_key))
132
147
  return all_models
133
148
 
134
149
  @strawberry.field
@@ -330,7 +345,7 @@ class Query:
330
345
  )
331
346
 
332
347
  experiment_ids_ = [
333
- from_global_id_with_expected_type(experiment_id, OrmExperiment.__name__)
348
+ from_global_id_with_expected_type(experiment_id, models.Experiment.__name__)
334
349
  for experiment_id in experiment_ids
335
350
  ]
336
351
  if len(set(experiment_ids_)) != len(experiment_ids_):
@@ -340,18 +355,18 @@ class Query:
340
355
  validation_result = (
341
356
  await session.execute(
342
357
  select(
343
- func.count(distinct(OrmVersion.dataset_id)),
344
- func.max(OrmVersion.dataset_id),
345
- func.max(OrmVersion.id),
346
- func.count(OrmExperiment.id),
358
+ func.count(distinct(models.DatasetVersion.dataset_id)),
359
+ func.max(models.DatasetVersion.dataset_id),
360
+ func.max(models.DatasetVersion.id),
361
+ func.count(models.Experiment.id),
347
362
  )
348
- .select_from(OrmVersion)
363
+ .select_from(models.DatasetVersion)
349
364
  .join(
350
- OrmExperiment,
351
- OrmExperiment.dataset_version_id == OrmVersion.id,
365
+ models.Experiment,
366
+ models.Experiment.dataset_version_id == models.DatasetVersion.id,
352
367
  )
353
368
  .where(
354
- OrmExperiment.id.in_(experiment_ids_),
369
+ models.Experiment.id.in_(experiment_ids_),
355
370
  )
356
371
  )
357
372
  ).first()
@@ -365,29 +380,33 @@ class Query:
365
380
  raise ValueError("Unable to resolve one or more experiment IDs.")
366
381
 
367
382
  revision_ids = (
368
- select(func.max(OrmRevision.id))
369
- .join(OrmExample, OrmExample.id == OrmRevision.dataset_example_id)
383
+ select(func.max(models.DatasetExampleRevision.id))
384
+ .join(
385
+ models.DatasetExample,
386
+ models.DatasetExample.id == models.DatasetExampleRevision.dataset_example_id,
387
+ )
370
388
  .where(
371
389
  and_(
372
- OrmRevision.dataset_version_id <= version_id,
373
- OrmExample.dataset_id == dataset_id,
390
+ models.DatasetExampleRevision.dataset_version_id <= version_id,
391
+ models.DatasetExample.dataset_id == dataset_id,
374
392
  )
375
393
  )
376
- .group_by(OrmRevision.dataset_example_id)
394
+ .group_by(models.DatasetExampleRevision.dataset_example_id)
377
395
  .scalar_subquery()
378
396
  )
379
397
  examples_query = (
380
- select(OrmExample)
381
- .distinct(OrmExample.id)
398
+ select(models.DatasetExample)
399
+ .distinct(models.DatasetExample.id)
382
400
  .join(
383
- OrmRevision,
401
+ models.DatasetExampleRevision,
384
402
  onclause=and_(
385
- OrmExample.id == OrmRevision.dataset_example_id,
386
- OrmRevision.id.in_(revision_ids),
387
- OrmRevision.revision_kind != "DELETE",
403
+ models.DatasetExample.id
404
+ == models.DatasetExampleRevision.dataset_example_id,
405
+ models.DatasetExampleRevision.id.in_(revision_ids),
406
+ models.DatasetExampleRevision.revision_kind != "DELETE",
388
407
  ),
389
408
  )
390
- .order_by(OrmExample.id.desc())
409
+ .order_by(models.DatasetExample.id.desc())
391
410
  )
392
411
 
393
412
  if filter_condition:
@@ -401,18 +420,20 @@ class Query:
401
420
 
402
421
  ExampleID: TypeAlias = int
403
422
  ExperimentID: TypeAlias = int
404
- runs: defaultdict[ExampleID, defaultdict[ExperimentID, list[OrmExperimentRun]]] = (
423
+ runs: defaultdict[ExampleID, defaultdict[ExperimentID, list[models.ExperimentRun]]] = (
405
424
  defaultdict(lambda: defaultdict(list))
406
425
  )
407
426
  async for run in await session.stream_scalars(
408
- select(OrmExperimentRun)
427
+ select(models.ExperimentRun)
409
428
  .where(
410
429
  and_(
411
- OrmExperimentRun.dataset_example_id.in_(example.id for example in examples),
412
- OrmExperimentRun.experiment_id.in_(experiment_ids_),
430
+ models.ExperimentRun.dataset_example_id.in_(
431
+ example.id for example in examples
432
+ ),
433
+ models.ExperimentRun.experiment_id.in_(experiment_ids_),
413
434
  )
414
435
  )
415
- .options(joinedload(OrmExperimentRun.trace).load_only(OrmTrace.trace_id))
436
+ .options(joinedload(models.ExperimentRun.trace).load_only(models.Trace.trace_id))
416
437
  ):
417
438
  runs[run.dataset_example_id][run.experiment_id].append(run)
418
439
 
@@ -460,7 +481,7 @@ class Query:
460
481
  compile_sqlalchemy_filter_condition(
461
482
  filter_condition=condition,
462
483
  experiment_ids=[
463
- from_global_id_with_expected_type(experiment_id, OrmExperiment.__name__)
484
+ from_global_id_with_expected_type(experiment_id, models.Experiment.__name__)
464
485
  for experiment_id in experiment_ids
465
486
  ],
466
487
  )
@@ -482,8 +503,8 @@ class Query:
482
503
  )
483
504
 
484
505
  @strawberry.field
485
- def model(self) -> Model:
486
- return Model()
506
+ def model(self) -> InferenceModel:
507
+ return InferenceModel()
487
508
 
488
509
  @strawberry.field
489
510
  async def node(self, id: GlobalID, info: Info[Context, None]) -> Node:
@@ -658,6 +679,18 @@ class Query:
658
679
  if not trace_annotation:
659
680
  raise NotFound(f"Unknown trace annotation: {id}")
660
681
  return to_gql_trace_annotation(trace_annotation)
682
+ elif type_name == GenerativeModel.__name__:
683
+ async with info.context.db() as session:
684
+ stmt = (
685
+ select(models.GenerativeModel)
686
+ .where(models.GenerativeModel.deleted_at.is_(None))
687
+ .where(models.GenerativeModel.id == node_id)
688
+ .options(joinedload(models.GenerativeModel.token_prices))
689
+ )
690
+ model = await session.scalar(stmt)
691
+ if not model:
692
+ raise NotFound(f"Unknown model: {id}")
693
+ return to_gql_generative_model(model)
661
694
  raise NotFound(f"Unknown node type: {type_name}")
662
695
 
663
696
  @strawberry.field
@@ -964,6 +997,14 @@ class Query:
964
997
  for table_name, num_bytes in stats
965
998
  ]
966
999
 
1000
+ @strawberry.field
1001
+ def validate_regular_expression(self, regex: str) -> ValidationResult:
1002
+ try:
1003
+ re.compile(regex)
1004
+ return ValidationResult(is_valid=True, error_message=None)
1005
+ except re.error as error:
1006
+ return ValidationResult(is_valid=False, error_message=str(error))
1007
+
967
1008
 
968
1009
  def _consolidate_sqlite_db_table_stats(
969
1010
  stats: Iterable[tuple[str, int]],
@@ -3,7 +3,6 @@ import zlib
3
3
  from typing import Any, Literal, Optional
4
4
 
5
5
  from fastapi import APIRouter, BackgroundTasks, Header, HTTPException, Query
6
- from google.protobuf.json_format import MessageToJson
7
6
  from google.protobuf.message import DecodeError
8
7
  from opentelemetry.proto.collector.trace.v1.trace_service_pb2 import (
9
8
  ExportTraceServiceRequest,
@@ -14,7 +13,7 @@ from sqlalchemy import insert, select
14
13
  from starlette.concurrency import run_in_threadpool
15
14
  from starlette.datastructures import State
16
15
  from starlette.requests import Request
17
- from starlette.responses import JSONResponse
16
+ from starlette.responses import Response
18
17
  from starlette.status import (
19
18
  HTTP_404_NOT_FOUND,
20
19
  HTTP_415_UNSUPPORTED_MEDIA_TYPE,
@@ -66,7 +65,7 @@ async def post_traces(
66
65
  background_tasks: BackgroundTasks,
67
66
  content_type: Optional[str] = Header(default=None),
68
67
  content_encoding: Optional[str] = Header(default=None),
69
- ) -> JSONResponse:
68
+ ) -> Response:
70
69
  if content_type != "application/x-protobuf":
71
70
  raise HTTPException(
72
71
  detail=f"Unsupported content type: {content_type}",
@@ -91,7 +90,15 @@ async def post_traces(
91
90
  status_code=HTTP_422_UNPROCESSABLE_ENTITY,
92
91
  )
93
92
  background_tasks.add_task(_add_spans, req, request.state)
94
- return JSONResponse(MessageToJson(ExportTraceServiceResponse()))
93
+
94
+ # "The server MUST use the same Content-Type in the response as it received in the request"
95
+ response_message = ExportTraceServiceResponse()
96
+ response_bytes = response_message.SerializeToString()
97
+ return Response(
98
+ content=response_bytes,
99
+ media_type="application/x-protobuf",
100
+ status_code=200,
101
+ )
95
102
 
96
103
 
97
104
  class TraceAnnotationResult(V1RoutesBaseModel):
@@ -62,6 +62,7 @@ from phoenix.server.api.types.Experiment import to_gql_experiment
62
62
  from phoenix.server.api.types.ExperimentRun import to_gql_experiment_run
63
63
  from phoenix.server.api.types.node import from_global_id_with_expected_type
64
64
  from phoenix.server.api.types.Span import Span
65
+ from phoenix.server.daemons.span_cost_calculator import SpanCostCalculator
65
66
  from phoenix.server.dml_event import SpanInsertEvent
66
67
  from phoenix.server.types import DbSessionFactory
67
68
  from phoenix.utilities.template_formatters import (
@@ -173,6 +174,19 @@ class Subscription:
173
174
  db_span = get_db_span(span, db_trace)
174
175
  session.add(db_span)
175
176
  await session.flush()
177
+ try:
178
+ span_cost = info.context.span_cost_calculator.calculate_cost(
179
+ start_time=db_span.start_time,
180
+ attributes=span.attributes,
181
+ )
182
+ except Exception as e:
183
+ logger.exception(f"Failed to calculate cost for span {db_span.id}: {e}")
184
+ span_cost = None
185
+ if span_cost:
186
+ span_cost.span_rowid = db_span.id
187
+ span_cost.trace_rowid = db_span.trace_rowid
188
+ session.add(span_cost)
189
+
176
190
  info.context.event_queue.put(SpanInsertEvent(ids=(playground_project_id,)))
177
191
  yield ChatCompletionSubscriptionResult(span=Span(span_rowid=db_span.id, db_span=db_span))
178
192
 
@@ -372,14 +386,18 @@ class Subscription:
372
386
  and not write_already_in_progress
373
387
  ):
374
388
  result_payloads_stream = _chat_completion_result_payloads(
375
- db=info.context.db, results=_drain_no_wait(results)
389
+ db=info.context.db,
390
+ results=_drain_no_wait(results),
391
+ span_cost_calculator=info.context.span_cost_calculator,
376
392
  )
377
393
  task = _create_task_with_timeout(result_payloads_stream)
378
394
  in_progress.append((None, result_payloads_stream, task))
379
395
  last_write_time = datetime.now()
380
396
  if remaining_results := await _drain(results):
381
397
  async for result_payload in _chat_completion_result_payloads(
382
- db=info.context.db, results=remaining_results
398
+ db=info.context.db,
399
+ results=remaining_results,
400
+ span_cost_calculator=info.context.span_cost_calculator,
383
401
  ):
384
402
  yield result_payload
385
403
 
@@ -463,6 +481,7 @@ async def _chat_completion_result_payloads(
463
481
  *,
464
482
  db: DbSessionFactory,
465
483
  results: Sequence[ChatCompletionResult],
484
+ span_cost_calculator: SpanCostCalculator,
466
485
  ) -> ChatStream:
467
486
  if not results:
468
487
  return
@@ -470,6 +489,19 @@ async def _chat_completion_result_payloads(
470
489
  for _, span, run in results:
471
490
  if span:
472
491
  session.add(span)
492
+ await session.flush()
493
+ try:
494
+ span_cost = span_cost_calculator.calculate_cost(
495
+ start_time=span.start_time,
496
+ attributes=span.attributes,
497
+ )
498
+ except Exception as e:
499
+ logger.exception(f"Failed to calculate cost for span {span.id}: {e}")
500
+ span_cost = None
501
+ if span_cost:
502
+ span_cost.span_rowid = span.id
503
+ span_cost.trace_rowid = span.trace_rowid
504
+ session.add(span_cost)
473
505
  session.add(run)
474
506
  await session.flush()
475
507
  for example_id, span, run in results:
@@ -594,3 +626,5 @@ LLM_OUTPUT_MESSAGES = SpanAttributes.LLM_OUTPUT_MESSAGES
594
626
  LLM_TOKEN_COUNT_COMPLETION = SpanAttributes.LLM_TOKEN_COUNT_COMPLETION
595
627
  LLM_TOKEN_COUNT_PROMPT = SpanAttributes.LLM_TOKEN_COUNT_PROMPT
596
628
  PROMPT_TEMPLATE_VARIABLES = SpanAttributes.LLM_PROMPT_TEMPLATE_VARIABLES
629
+ LLM_MODEL_NAME = SpanAttributes.LLM_MODEL_NAME
630
+ LLM_PROVIDER = SpanAttributes.LLM_PROVIDER
@@ -0,0 +1,15 @@
1
+ from typing import Optional
2
+
3
+ import strawberry
4
+
5
+
6
+ @strawberry.type
7
+ class CostBreakdown:
8
+ tokens: Optional[float] = None
9
+ cost: Optional[float] = None
10
+
11
+ @strawberry.field
12
+ def cost_per_token(self) -> Optional[float]:
13
+ if self.tokens and self.cost:
14
+ return self.cost / self.tokens
15
+ return None
@@ -2,8 +2,9 @@ from datetime import datetime
2
2
  from typing import ClassVar, Optional
3
3
 
4
4
  import strawberry
5
- from sqlalchemy import select
5
+ from sqlalchemy import func, select
6
6
  from sqlalchemy.orm import joinedload
7
+ from sqlalchemy.sql.functions import coalesce
7
8
  from strawberry import UNSET, Private
8
9
  from strawberry.relay import Connection, Node, NodeID
9
10
  from strawberry.scalars import JSON
@@ -11,6 +12,7 @@ from strawberry.types import Info
11
12
 
12
13
  from phoenix.db import models
13
14
  from phoenix.server.api.context import Context
15
+ from phoenix.server.api.types.CostBreakdown import CostBreakdown
14
16
  from phoenix.server.api.types.ExperimentAnnotationSummary import ExperimentAnnotationSummary
15
17
  from phoenix.server.api.types.ExperimentRun import ExperimentRun, to_gql_experiment_run
16
18
  from phoenix.server.api.types.pagination import (
@@ -19,6 +21,8 @@ from phoenix.server.api.types.pagination import (
19
21
  connection_from_list,
20
22
  )
21
23
  from phoenix.server.api.types.Project import Project
24
+ from phoenix.server.api.types.SpanCostDetailSummaryEntry import SpanCostDetailSummaryEntry
25
+ from phoenix.server.api.types.SpanCostSummary import SpanCostSummary
22
26
 
23
27
 
24
28
  @strawberry.type
@@ -130,6 +134,60 @@ class Experiment(Node):
130
134
  def last_updated_at(self, info: Info[Context, None]) -> Optional[datetime]:
131
135
  return info.context.last_updated_at.get(self._table, self.id_attr)
132
136
 
137
+ @strawberry.field
138
+ async def cost_summary(self, info: Info[Context, None]) -> SpanCostSummary:
139
+ experiment_id = self.id_attr
140
+ summary = await info.context.data_loaders.span_cost_summary_by_experiment.load(
141
+ experiment_id
142
+ )
143
+ return SpanCostSummary(
144
+ prompt=CostBreakdown(
145
+ tokens=summary.prompt.tokens,
146
+ cost=summary.prompt.cost,
147
+ ),
148
+ completion=CostBreakdown(
149
+ tokens=summary.completion.tokens,
150
+ cost=summary.completion.cost,
151
+ ),
152
+ total=CostBreakdown(
153
+ tokens=summary.total.tokens,
154
+ cost=summary.total.cost,
155
+ ),
156
+ )
157
+
158
+ @strawberry.field
159
+ async def cost_detail_summary_entries(
160
+ self, info: Info[Context, None]
161
+ ) -> list[SpanCostDetailSummaryEntry]:
162
+ experiment_id = self.id_attr
163
+
164
+ stmt = (
165
+ select(
166
+ models.SpanCostDetail.token_type,
167
+ models.SpanCostDetail.is_prompt,
168
+ coalesce(func.sum(models.SpanCostDetail.cost), 0).label("cost"),
169
+ coalesce(func.sum(models.SpanCostDetail.tokens), 0).label("tokens"),
170
+ )
171
+ .select_from(models.SpanCostDetail)
172
+ .join(models.SpanCost, models.SpanCostDetail.span_cost_id == models.SpanCost.id)
173
+ .join(models.Span, models.SpanCost.span_rowid == models.Span.id)
174
+ .join(models.Trace, models.Span.trace_rowid == models.Trace.id)
175
+ .join(models.ExperimentRun, models.ExperimentRun.trace_id == models.Trace.trace_id)
176
+ .where(models.ExperimentRun.experiment_id == experiment_id)
177
+ .group_by(models.SpanCostDetail.token_type, models.SpanCostDetail.is_prompt)
178
+ )
179
+
180
+ async with info.context.db() as session:
181
+ data = await session.stream(stmt)
182
+ return [
183
+ SpanCostDetailSummaryEntry(
184
+ token_type=token_type,
185
+ is_prompt=is_prompt,
186
+ value=CostBreakdown(tokens=tokens, cost=cost),
187
+ )
188
+ async for token_type, is_prompt, cost, tokens in data
189
+ ]
190
+
133
191
 
134
192
  def to_gql_experiment(
135
193
  experiment: models.Experiment,
@@ -2,8 +2,9 @@ from datetime import datetime
2
2
  from typing import TYPE_CHECKING, Annotated, Optional
3
3
 
4
4
  import strawberry
5
- from sqlalchemy import select
5
+ from sqlalchemy import func, select
6
6
  from sqlalchemy.orm import load_only
7
+ from sqlalchemy.sql.functions import coalesce
7
8
  from strawberry import UNSET
8
9
  from strawberry.relay import Connection, GlobalID, Node, NodeID
9
10
  from strawberry.scalars import JSON
@@ -11,6 +12,7 @@ from strawberry.types import Info
11
12
 
12
13
  from phoenix.db import models
13
14
  from phoenix.server.api.context import Context
15
+ from phoenix.server.api.types.CostBreakdown import CostBreakdown
14
16
  from phoenix.server.api.types.ExperimentRunAnnotation import (
15
17
  ExperimentRunAnnotation,
16
18
  to_gql_experiment_run_annotation,
@@ -20,6 +22,8 @@ from phoenix.server.api.types.pagination import (
20
22
  CursorString,
21
23
  connection_from_list,
22
24
  )
25
+ from phoenix.server.api.types.SpanCostDetailSummaryEntry import SpanCostDetailSummaryEntry
26
+ from phoenix.server.api.types.SpanCostSummary import SpanCostSummary
23
27
  from phoenix.server.api.types.Trace import Trace
24
28
 
25
29
  if TYPE_CHECKING:
@@ -98,6 +102,58 @@ class ExperimentRun(Node):
98
102
  version_id=version_id,
99
103
  )
100
104
 
105
+ @strawberry.field
106
+ async def cost_summary(self, info: Info[Context, None]) -> SpanCostSummary:
107
+ run_id = self.id_attr
108
+ summary = await info.context.data_loaders.span_cost_summary_by_experiment_run.load(run_id)
109
+ return SpanCostSummary(
110
+ prompt=CostBreakdown(
111
+ tokens=summary.prompt.tokens,
112
+ cost=summary.prompt.cost,
113
+ ),
114
+ completion=CostBreakdown(
115
+ tokens=summary.completion.tokens,
116
+ cost=summary.completion.cost,
117
+ ),
118
+ total=CostBreakdown(
119
+ tokens=summary.total.tokens,
120
+ cost=summary.total.cost,
121
+ ),
122
+ )
123
+
124
+ @strawberry.field
125
+ async def cost_detail_summary_entries(
126
+ self, info: Info[Context, None]
127
+ ) -> list[SpanCostDetailSummaryEntry]:
128
+ run_id = self.id_attr
129
+
130
+ stmt = (
131
+ select(
132
+ models.SpanCostDetail.token_type,
133
+ models.SpanCostDetail.is_prompt,
134
+ coalesce(func.sum(models.SpanCostDetail.cost), 0).label("cost"),
135
+ coalesce(func.sum(models.SpanCostDetail.tokens), 0).label("tokens"),
136
+ )
137
+ .select_from(models.SpanCostDetail)
138
+ .join(models.SpanCost, models.SpanCostDetail.span_cost_id == models.SpanCost.id)
139
+ .join(models.Span, models.SpanCost.span_rowid == models.Span.id)
140
+ .join(models.Trace, models.Span.trace_rowid == models.Trace.id)
141
+ .join(models.ExperimentRun, models.ExperimentRun.trace_id == models.Trace.trace_id)
142
+ .where(models.ExperimentRun.id == run_id)
143
+ .group_by(models.SpanCostDetail.token_type, models.SpanCostDetail.is_prompt)
144
+ )
145
+
146
+ async with info.context.db() as session:
147
+ data = await session.stream(stmt)
148
+ return [
149
+ SpanCostDetailSummaryEntry(
150
+ token_type=token_type,
151
+ is_prompt=is_prompt,
152
+ value=CostBreakdown(tokens=tokens, cost=cost),
153
+ )
154
+ async for token_type, is_prompt, cost, tokens in data
155
+ ]
156
+
101
157
 
102
158
  def to_gql_experiment_run(run: models.ExperimentRun) -> ExperimentRun:
103
159
  """
@@ -109,9 +165,7 @@ def to_gql_experiment_run(run: models.ExperimentRun) -> ExperimentRun:
109
165
  return ExperimentRun(
110
166
  id_attr=run.id,
111
167
  experiment_id=GlobalID(Experiment.__name__, str(run.experiment_id)),
112
- trace_id=trace_id
113
- if (trace := run.trace) and (trace_id := trace.trace_id) is not None
114
- else None,
168
+ trace_id=run.trace.trace_id if run.trace else None,
115
169
  output=run.output.get("task_output"),
116
170
  start_time=run.start_time,
117
171
  end_time=run.end_time,