arize-phoenix 10.14.0__py3-none-any.whl → 11.0.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
Potentially problematic release: this version of arize-phoenix might be problematic.
- {arize_phoenix-10.14.0.dist-info → arize_phoenix-11.0.0.dist-info}/METADATA +3 -2
- {arize_phoenix-10.14.0.dist-info → arize_phoenix-11.0.0.dist-info}/RECORD +82 -50
- phoenix/config.py +5 -2
- phoenix/datetime_utils.py +8 -1
- phoenix/db/bulk_inserter.py +40 -1
- phoenix/db/facilitator.py +263 -4
- phoenix/db/insertion/helpers.py +15 -0
- phoenix/db/insertion/span.py +3 -1
- phoenix/db/migrations/versions/a20694b15f82_cost.py +196 -0
- phoenix/db/models.py +267 -9
- phoenix/db/types/model_provider.py +1 -0
- phoenix/db/types/token_price_customization.py +29 -0
- phoenix/server/api/context.py +38 -4
- phoenix/server/api/dataloaders/__init__.py +41 -5
- phoenix/server/api/dataloaders/last_used_times_by_generative_model_id.py +35 -0
- phoenix/server/api/dataloaders/span_cost_by_span.py +24 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_generative_model.py +56 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_project_session.py +57 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_span.py +43 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_trace.py +56 -0
- phoenix/server/api/dataloaders/span_cost_details_by_span_cost.py +27 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment.py +58 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment_run.py +58 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_generative_model.py +55 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_project.py +140 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_project_session.py +56 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_trace.py +55 -0
- phoenix/server/api/dataloaders/span_costs.py +35 -0
- phoenix/server/api/dataloaders/types.py +29 -0
- phoenix/server/api/helpers/playground_clients.py +562 -12
- phoenix/server/api/helpers/prompts/conversions/aws.py +83 -0
- phoenix/server/api/helpers/prompts/models.py +67 -0
- phoenix/server/api/input_types/GenerativeModelInput.py +2 -0
- phoenix/server/api/input_types/ProjectSessionSort.py +3 -0
- phoenix/server/api/input_types/SpanSort.py +17 -0
- phoenix/server/api/mutations/__init__.py +2 -0
- phoenix/server/api/mutations/chat_mutations.py +17 -0
- phoenix/server/api/mutations/model_mutations.py +208 -0
- phoenix/server/api/queries.py +82 -41
- phoenix/server/api/routers/v1/traces.py +11 -4
- phoenix/server/api/subscriptions.py +36 -2
- phoenix/server/api/types/CostBreakdown.py +15 -0
- phoenix/server/api/types/Experiment.py +59 -1
- phoenix/server/api/types/ExperimentRun.py +58 -4
- phoenix/server/api/types/GenerativeModel.py +143 -2
- phoenix/server/api/types/GenerativeProvider.py +33 -20
- phoenix/server/api/types/{Model.py → InferenceModel.py} +1 -1
- phoenix/server/api/types/ModelInterface.py +11 -0
- phoenix/server/api/types/PlaygroundModel.py +10 -0
- phoenix/server/api/types/Project.py +42 -0
- phoenix/server/api/types/ProjectSession.py +44 -0
- phoenix/server/api/types/Span.py +137 -0
- phoenix/server/api/types/SpanCostDetailSummaryEntry.py +10 -0
- phoenix/server/api/types/SpanCostSummary.py +10 -0
- phoenix/server/api/types/TokenPrice.py +16 -0
- phoenix/server/api/types/TokenUsage.py +3 -3
- phoenix/server/api/types/Trace.py +41 -0
- phoenix/server/app.py +59 -0
- phoenix/server/cost_tracking/cost_details_calculator.py +190 -0
- phoenix/server/cost_tracking/cost_model_lookup.py +151 -0
- phoenix/server/cost_tracking/helpers.py +68 -0
- phoenix/server/cost_tracking/model_cost_manifest.json +59 -329
- phoenix/server/cost_tracking/regex_specificity.py +397 -0
- phoenix/server/cost_tracking/token_cost_calculator.py +57 -0
- phoenix/server/daemons/__init__.py +0 -0
- phoenix/server/daemons/generative_model_store.py +51 -0
- phoenix/server/daemons/span_cost_calculator.py +103 -0
- phoenix/server/dml_event_handler.py +1 -0
- phoenix/server/static/.vite/manifest.json +36 -36
- phoenix/server/static/assets/components-BnK9kodr.js +5055 -0
- phoenix/server/static/assets/{index-qiubV_74.js → index-S3YKLmbo.js} +13 -13
- phoenix/server/static/assets/{pages-C4V07ozl.js → pages-BW6PBHZb.js} +809 -417
- phoenix/server/static/assets/{vendor-Bfsiga8H.js → vendor-DqQvHbPa.js} +147 -147
- phoenix/server/static/assets/{vendor-arizeai-CQOWsrzm.js → vendor-arizeai-CLX44PFA.js} +1 -1
- phoenix/server/static/assets/{vendor-codemirror-CrcGVhB2.js → vendor-codemirror-Du3XyJnB.js} +1 -1
- phoenix/server/static/assets/{vendor-recharts-Yyg3G-Rq.js → vendor-recharts-B2PJDrnX.js} +25 -25
- phoenix/server/static/assets/{vendor-shiki-OPjag7Hm.js → vendor-shiki-CNbrFjf9.js} +1 -1
- phoenix/version.py +1 -1
- phoenix/server/cost_tracking/cost_lookup.py +0 -255
- phoenix/server/static/assets/components-CUUWyAMo.js +0 -4509
- {arize_phoenix-10.14.0.dist-info → arize_phoenix-11.0.0.dist-info}/WHEEL +0 -0
- {arize_phoenix-10.14.0.dist-info → arize_phoenix-11.0.0.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-10.14.0.dist-info → arize_phoenix-11.0.0.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-10.14.0.dist-info → arize_phoenix-11.0.0.dist-info}/licenses/LICENSE +0 -0
phoenix/server/api/queries.py
CHANGED
@@ -1,3 +1,4 @@
+import re
 from collections import defaultdict
 from datetime import datetime
 from typing import Iterable, Iterator, Optional, Union, cast
@@ -21,12 +22,6 @@ from phoenix.config import (
 from phoenix.db import models
 from phoenix.db.constants import DEFAULT_PROJECT_TRACE_RETENTION_POLICY_ID
 from phoenix.db.helpers import SupportedSQLDialect, exclude_experiment_projects
-from phoenix.db.models import DatasetExample as OrmExample
-from phoenix.db.models import DatasetExampleRevision as OrmRevision
-from phoenix.db.models import DatasetVersion as OrmVersion
-from phoenix.db.models import Experiment as OrmExperiment
-from phoenix.db.models import ExperimentRun as OrmExperimentRun
-from phoenix.db.models import Trace as OrmTrace
 from phoenix.pointcloud.clustering import Hdbscan
 from phoenix.server.api.auth import MSG_ADMIN_ONLY, IsAdmin
 from phoenix.server.api.context import Context
@@ -62,12 +57,13 @@ from phoenix.server.api.types.Experiment import Experiment
 from phoenix.server.api.types.ExperimentComparison import ExperimentComparison, RunComparisonItem
 from phoenix.server.api.types.ExperimentRun import ExperimentRun, to_gql_experiment_run
 from phoenix.server.api.types.Functionality import Functionality
-from phoenix.server.api.types.GenerativeModel import GenerativeModel
+from phoenix.server.api.types.GenerativeModel import GenerativeModel, to_gql_generative_model
 from phoenix.server.api.types.GenerativeProvider import GenerativeProvider, GenerativeProviderKey
+from phoenix.server.api.types.InferenceModel import InferenceModel
 from phoenix.server.api.types.InferencesRole import AncillaryInferencesRole, InferencesRole
-from phoenix.server.api.types.Model import Model
 from phoenix.server.api.types.node import from_global_id, from_global_id_with_expected_type
 from phoenix.server.api.types.pagination import ConnectionArgs, CursorString, connection_from_list
+from phoenix.server.api.types.PlaygroundModel import PlaygroundModel
 from phoenix.server.api.types.Project import Project
 from phoenix.server.api.types.ProjectSession import ProjectSession, to_gql_project_session
 from phoenix.server.api.types.ProjectTraceRetentionPolicy import ProjectTraceRetentionPolicy
@@ -115,20 +111,39 @@ class Query:
     ]

     @strawberry.field
-    async def
+    async def generative_models(
+        self,
+        info: Info[Context, None],
+    ) -> list[GenerativeModel]:
+        async with info.context.db() as session:
+            result = await session.scalars(
+                select(models.GenerativeModel)
+                .where(models.GenerativeModel.deleted_at.is_(None))
+                .order_by(
+                    models.GenerativeModel.is_built_in.asc(),  # display custom models first
+                    models.GenerativeModel.provider.nullslast(),
+                    models.GenerativeModel.name,
+                )
+                .options(joinedload(models.GenerativeModel.token_prices))
+            )
+
+        return [to_gql_generative_model(model) for model in result.unique()]
+
+    @strawberry.field
+    async def playground_models(self, input: Optional[ModelsInput] = None) -> list[PlaygroundModel]:
         if input is not None and input.provider_key is not None:
             supported_model_names = PLAYGROUND_CLIENT_REGISTRY.list_models(input.provider_key)
             supported_models = [
-
+                PlaygroundModel(name=model_name, provider_key=input.provider_key)
                 for model_name in supported_model_names
             ]
             return supported_models

         registered_models = PLAYGROUND_CLIENT_REGISTRY.list_all_models()
-        all_models: list[
+        all_models: list[PlaygroundModel] = []
         for provider_key, model_name in registered_models:
             if model_name is not None and provider_key is not None:
-                all_models.append(
+                all_models.append(PlaygroundModel(name=model_name, provider_key=provider_key))
         return all_models

     @strawberry.field
@@ -330,7 +345,7 @@ class Query:
         )

         experiment_ids_ = [
-            from_global_id_with_expected_type(experiment_id,
+            from_global_id_with_expected_type(experiment_id, models.Experiment.__name__)
             for experiment_id in experiment_ids
         ]
         if len(set(experiment_ids_)) != len(experiment_ids_):
@@ -340,18 +355,18 @@ class Query:
         validation_result = (
             await session.execute(
                 select(
-                    func.count(distinct(
-                    func.max(
-                    func.max(
-                    func.count(
+                    func.count(distinct(models.DatasetVersion.dataset_id)),
+                    func.max(models.DatasetVersion.dataset_id),
+                    func.max(models.DatasetVersion.id),
+                    func.count(models.Experiment.id),
                 )
-                .select_from(
+                .select_from(models.DatasetVersion)
                 .join(
-
-
+                    models.Experiment,
+                    models.Experiment.dataset_version_id == models.DatasetVersion.id,
                )
                .where(
-
+                    models.Experiment.id.in_(experiment_ids_),
                )
            )
        ).first()
@@ -365,29 +380,33 @@ class Query:
             raise ValueError("Unable to resolve one or more experiment IDs.")

         revision_ids = (
-            select(func.max(
-            .join(
+            select(func.max(models.DatasetExampleRevision.id))
+            .join(
+                models.DatasetExample,
+                models.DatasetExample.id == models.DatasetExampleRevision.dataset_example_id,
+            )
             .where(
                 and_(
-
-
+                    models.DatasetExampleRevision.dataset_version_id <= version_id,
+                    models.DatasetExample.dataset_id == dataset_id,
                 )
             )
-            .group_by(
+            .group_by(models.DatasetExampleRevision.dataset_example_id)
             .scalar_subquery()
         )
         examples_query = (
-            select(
-            .distinct(
+            select(models.DatasetExample)
+            .distinct(models.DatasetExample.id)
             .join(
-
+                models.DatasetExampleRevision,
                 onclause=and_(
-
-
-
+                    models.DatasetExample.id
+                    == models.DatasetExampleRevision.dataset_example_id,
+                    models.DatasetExampleRevision.id.in_(revision_ids),
+                    models.DatasetExampleRevision.revision_kind != "DELETE",
                 ),
             )
-            .order_by(
+            .order_by(models.DatasetExample.id.desc())
         )

         if filter_condition:
@@ -401,18 +420,20 @@ class Query:

         ExampleID: TypeAlias = int
         ExperimentID: TypeAlias = int
-        runs: defaultdict[ExampleID, defaultdict[ExperimentID, list[
+        runs: defaultdict[ExampleID, defaultdict[ExperimentID, list[models.ExperimentRun]]] = (
             defaultdict(lambda: defaultdict(list))
         )
         async for run in await session.stream_scalars(
-            select(
+            select(models.ExperimentRun)
             .where(
                 and_(
-
-
+                    models.ExperimentRun.dataset_example_id.in_(
+                        example.id for example in examples
+                    ),
+                    models.ExperimentRun.experiment_id.in_(experiment_ids_),
                 )
             )
-            .options(joinedload(
+            .options(joinedload(models.ExperimentRun.trace).load_only(models.Trace.trace_id))
         ):
             runs[run.dataset_example_id][run.experiment_id].append(run)

@@ -460,7 +481,7 @@ class Query:
            compile_sqlalchemy_filter_condition(
                filter_condition=condition,
                experiment_ids=[
-                    from_global_id_with_expected_type(experiment_id,
+                    from_global_id_with_expected_type(experiment_id, models.Experiment.__name__)
                    for experiment_id in experiment_ids
                ],
            )
@@ -482,8 +503,8 @@ class Query:
         )

     @strawberry.field
-    def model(self) ->
-        return
+    def model(self) -> InferenceModel:
+        return InferenceModel()

     @strawberry.field
     async def node(self, id: GlobalID, info: Info[Context, None]) -> Node:
@@ -658,6 +679,18 @@ class Query:
            if not trace_annotation:
                raise NotFound(f"Unknown trace annotation: {id}")
            return to_gql_trace_annotation(trace_annotation)
+        elif type_name == GenerativeModel.__name__:
+            async with info.context.db() as session:
+                stmt = (
+                    select(models.GenerativeModel)
+                    .where(models.GenerativeModel.deleted_at.is_(None))
+                    .where(models.GenerativeModel.id == node_id)
+                    .options(joinedload(models.GenerativeModel.token_prices))
+                )
+                model = await session.scalar(stmt)
+                if not model:
+                    raise NotFound(f"Unknown model: {id}")
+            return to_gql_generative_model(model)
        raise NotFound(f"Unknown node type: {type_name}")

     @strawberry.field
@@ -964,6 +997,14 @@ class Query:
            for table_name, num_bytes in stats
        ]

+    @strawberry.field
+    def validate_regular_expression(self, regex: str) -> ValidationResult:
+        try:
+            re.compile(regex)
+            return ValidationResult(is_valid=True, error_message=None)
+        except re.error as error:
+            return ValidationResult(is_valid=False, error_message=str(error))
+


 def _consolidate_sqlite_db_table_stats(
     stats: Iterable[tuple[str, int]],
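The new validate_regular_expression field lets the UI check a user-supplied pattern, presumably the model-name regexes used by the new cost-tracking configuration, before saving it. A standalone sketch of the same check outside Strawberry (the validate_pattern helper below is illustrative, not part of the package):

import re
from typing import Optional


def validate_pattern(pattern: str) -> tuple[bool, Optional[str]]:
    # Mirrors the resolver above: a pattern is valid iff re.compile accepts it,
    # otherwise the re.error message is returned for display.
    try:
        re.compile(pattern)
        return True, None
    except re.error as error:
        return False, str(error)


print(validate_pattern(r"^gpt-4.*"))  # (True, None)
print(validate_pattern(r"^gpt-(4"))   # (False, 'missing ), unterminated subpattern at position 5')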
phoenix/server/api/routers/v1/traces.py
CHANGED
@@ -3,7 +3,6 @@ import zlib
 from typing import Any, Literal, Optional

 from fastapi import APIRouter, BackgroundTasks, Header, HTTPException, Query
-from google.protobuf.json_format import MessageToJson
 from google.protobuf.message import DecodeError
 from opentelemetry.proto.collector.trace.v1.trace_service_pb2 import (
     ExportTraceServiceRequest,
@@ -14,7 +13,7 @@ from sqlalchemy import insert, select
 from starlette.concurrency import run_in_threadpool
 from starlette.datastructures import State
 from starlette.requests import Request
-from starlette.responses import
+from starlette.responses import Response
 from starlette.status import (
     HTTP_404_NOT_FOUND,
     HTTP_415_UNSUPPORTED_MEDIA_TYPE,
@@ -66,7 +65,7 @@ async def post_traces(
     background_tasks: BackgroundTasks,
     content_type: Optional[str] = Header(default=None),
     content_encoding: Optional[str] = Header(default=None),
-) ->
+) -> Response:
     if content_type != "application/x-protobuf":
         raise HTTPException(
             detail=f"Unsupported content type: {content_type}",
@@ -91,7 +90,15 @@ async def post_traces(
             status_code=HTTP_422_UNPROCESSABLE_ENTITY,
         )
     background_tasks.add_task(_add_spans, req, request.state)
-
+
+    # "The server MUST use the same Content-Type in the response as it received in the request"
+    response_message = ExportTraceServiceResponse()
+    response_bytes = response_message.SerializeToString()
+    return Response(
+        content=response_bytes,
+        media_type="application/x-protobuf",
+        status_code=200,
+    )


 class TraceAnnotationResult(V1RoutesBaseModel):
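The quoted requirement comes from the OTLP/HTTP spec, and a client can decode the new response body with the same generated protobuf class. A hedged client-side sketch (the parse_otlp_response helper is illustrative and not part of Phoenix):

from opentelemetry.proto.collector.trace.v1.trace_service_pb2 import (
    ExportTraceServiceResponse,
)


def parse_otlp_response(body: bytes, content_type: str) -> ExportTraceServiceResponse:
    # The endpoint now echoes the request's protobuf content type; a response
    # with partial_success unset indicates full success per the OTLP/HTTP spec.
    if content_type != "application/x-protobuf":
        raise ValueError(f"unexpected content type: {content_type}")
    return ExportTraceServiceResponse.FromString(body)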
phoenix/server/api/subscriptions.py
CHANGED
@@ -62,6 +62,7 @@ from phoenix.server.api.types.Experiment import to_gql_experiment
 from phoenix.server.api.types.ExperimentRun import to_gql_experiment_run
 from phoenix.server.api.types.node import from_global_id_with_expected_type
 from phoenix.server.api.types.Span import Span
+from phoenix.server.daemons.span_cost_calculator import SpanCostCalculator
 from phoenix.server.dml_event import SpanInsertEvent
 from phoenix.server.types import DbSessionFactory
 from phoenix.utilities.template_formatters import (
@@ -173,6 +174,19 @@ class Subscription:
            db_span = get_db_span(span, db_trace)
            session.add(db_span)
            await session.flush()
+            try:
+                span_cost = info.context.span_cost_calculator.calculate_cost(
+                    start_time=db_span.start_time,
+                    attributes=span.attributes,
+                )
+            except Exception as e:
+                logger.exception(f"Failed to calculate cost for span {db_span.id}: {e}")
+                span_cost = None
+            if span_cost:
+                span_cost.span_rowid = db_span.id
+                span_cost.trace_rowid = db_span.trace_rowid
+                session.add(span_cost)
+
        info.context.event_queue.put(SpanInsertEvent(ids=(playground_project_id,)))
        yield ChatCompletionSubscriptionResult(span=Span(span_rowid=db_span.id, db_span=db_span))

@@ -372,14 +386,18 @@ class Subscription:
                and not write_already_in_progress
            ):
                result_payloads_stream = _chat_completion_result_payloads(
-                    db=info.context.db,
+                    db=info.context.db,
+                    results=_drain_no_wait(results),
+                    span_cost_calculator=info.context.span_cost_calculator,
                )
                task = _create_task_with_timeout(result_payloads_stream)
                in_progress.append((None, result_payloads_stream, task))
                last_write_time = datetime.now()
        if remaining_results := await _drain(results):
            async for result_payload in _chat_completion_result_payloads(
-                db=info.context.db,
+                db=info.context.db,
+                results=remaining_results,
+                span_cost_calculator=info.context.span_cost_calculator,
            ):
                yield result_payload

@@ -463,6 +481,7 @@ async def _chat_completion_result_payloads(
     *,
     db: DbSessionFactory,
     results: Sequence[ChatCompletionResult],
+    span_cost_calculator: SpanCostCalculator,
 ) -> ChatStream:
     if not results:
         return
@@ -470,6 +489,19 @@ async def _chat_completion_result_payloads(
        for _, span, run in results:
            if span:
                session.add(span)
+                await session.flush()
+                try:
+                    span_cost = span_cost_calculator.calculate_cost(
+                        start_time=span.start_time,
+                        attributes=span.attributes,
+                    )
+                except Exception as e:
+                    logger.exception(f"Failed to calculate cost for span {span.id}: {e}")
+                    span_cost = None
+                if span_cost:
+                    span_cost.span_rowid = span.id
+                    span_cost.trace_rowid = span.trace_rowid
+                    session.add(span_cost)
            session.add(run)
            await session.flush()
        for example_id, span, run in results:
@@ -594,3 +626,5 @@ LLM_OUTPUT_MESSAGES = SpanAttributes.LLM_OUTPUT_MESSAGES
 LLM_TOKEN_COUNT_COMPLETION = SpanAttributes.LLM_TOKEN_COUNT_COMPLETION
 LLM_TOKEN_COUNT_PROMPT = SpanAttributes.LLM_TOKEN_COUNT_PROMPT
 PROMPT_TEMPLATE_VARIABLES = SpanAttributes.LLM_PROMPT_TEMPLATE_VARIABLES
+LLM_MODEL_NAME = SpanAttributes.LLM_MODEL_NAME
+LLM_PROVIDER = SpanAttributes.LLM_PROVIDER
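The same calculate-and-attach sequence now appears twice in subscriptions.py. A minimal sketch of that shared pattern, assuming only what the diff shows about SpanCostCalculator.calculate_cost (the attach_span_cost helper itself is hypothetical, not code from the package):

import logging

logger = logging.getLogger(__name__)


def attach_span_cost(session, span_cost_calculator, db_span, attributes) -> None:
    # Compute a cost from the span's start time and attributes; a failing
    # calculator must not break span ingestion, so errors are logged and skipped.
    try:
        span_cost = span_cost_calculator.calculate_cost(
            start_time=db_span.start_time,
            attributes=attributes,
        )
    except Exception as error:
        logger.exception(f"Failed to calculate cost for span {db_span.id}: {error}")
        return
    if span_cost:
        # Link the SpanCost row to its span and trace before persisting it.
        span_cost.span_rowid = db_span.id
        span_cost.trace_rowid = db_span.trace_rowid
        session.add(span_cost)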
phoenix/server/api/types/CostBreakdown.py
ADDED
@@ -0,0 +1,15 @@
+from typing import Optional
+
+import strawberry
+
+
+@strawberry.type
+class CostBreakdown:
+    tokens: Optional[float] = None
+    cost: Optional[float] = None
+
+    @strawberry.field
+    def cost_per_token(self) -> Optional[float]:
+        if self.tokens and self.cost:
+            return self.cost / self.tokens
+        return None
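The cost_per_token guard returns None instead of dividing by zero when either component is missing. The same logic in plain Python, outside the Strawberry type (illustrative only):

from typing import Optional


def cost_per_token(tokens: Optional[float], cost: Optional[float]) -> Optional[float]:
    # Falsy tokens or cost (None or 0) short-circuits to None instead of raising.
    if tokens and cost:
        return cost / tokens
    return None


print(cost_per_token(1000, 0.002))  # 2e-06
print(cost_per_token(0, 0.002))     # None, avoids ZeroDivisionError
print(cost_per_token(None, None))   # None, no data means no rate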
phoenix/server/api/types/Experiment.py
CHANGED
@@ -2,8 +2,9 @@ from datetime import datetime
 from typing import ClassVar, Optional

 import strawberry
-from sqlalchemy import select
+from sqlalchemy import func, select
 from sqlalchemy.orm import joinedload
+from sqlalchemy.sql.functions import coalesce
 from strawberry import UNSET, Private
 from strawberry.relay import Connection, Node, NodeID
 from strawberry.scalars import JSON
@@ -11,6 +12,7 @@ from strawberry.types import Info

 from phoenix.db import models
 from phoenix.server.api.context import Context
+from phoenix.server.api.types.CostBreakdown import CostBreakdown
 from phoenix.server.api.types.ExperimentAnnotationSummary import ExperimentAnnotationSummary
 from phoenix.server.api.types.ExperimentRun import ExperimentRun, to_gql_experiment_run
 from phoenix.server.api.types.pagination import (
@@ -19,6 +21,8 @@ from phoenix.server.api.types.pagination import (
     connection_from_list,
 )
 from phoenix.server.api.types.Project import Project
+from phoenix.server.api.types.SpanCostDetailSummaryEntry import SpanCostDetailSummaryEntry
+from phoenix.server.api.types.SpanCostSummary import SpanCostSummary


 @strawberry.type
@@ -130,6 +134,60 @@ class Experiment(Node):
     def last_updated_at(self, info: Info[Context, None]) -> Optional[datetime]:
         return info.context.last_updated_at.get(self._table, self.id_attr)

+    @strawberry.field
+    async def cost_summary(self, info: Info[Context, None]) -> SpanCostSummary:
+        experiment_id = self.id_attr
+        summary = await info.context.data_loaders.span_cost_summary_by_experiment.load(
+            experiment_id
+        )
+        return SpanCostSummary(
+            prompt=CostBreakdown(
+                tokens=summary.prompt.tokens,
+                cost=summary.prompt.cost,
+            ),
+            completion=CostBreakdown(
+                tokens=summary.completion.tokens,
+                cost=summary.completion.cost,
+            ),
+            total=CostBreakdown(
+                tokens=summary.total.tokens,
+                cost=summary.total.cost,
+            ),
+        )
+
+    @strawberry.field
+    async def cost_detail_summary_entries(
+        self, info: Info[Context, None]
+    ) -> list[SpanCostDetailSummaryEntry]:
+        experiment_id = self.id_attr
+
+        stmt = (
+            select(
+                models.SpanCostDetail.token_type,
+                models.SpanCostDetail.is_prompt,
+                coalesce(func.sum(models.SpanCostDetail.cost), 0).label("cost"),
+                coalesce(func.sum(models.SpanCostDetail.tokens), 0).label("tokens"),
+            )
+            .select_from(models.SpanCostDetail)
+            .join(models.SpanCost, models.SpanCostDetail.span_cost_id == models.SpanCost.id)
+            .join(models.Span, models.SpanCost.span_rowid == models.Span.id)
+            .join(models.Trace, models.Span.trace_rowid == models.Trace.id)
+            .join(models.ExperimentRun, models.ExperimentRun.trace_id == models.Trace.trace_id)
+            .where(models.ExperimentRun.experiment_id == experiment_id)
+            .group_by(models.SpanCostDetail.token_type, models.SpanCostDetail.is_prompt)
+        )
+
+        async with info.context.db() as session:
+            data = await session.stream(stmt)
+            return [
+                SpanCostDetailSummaryEntry(
+                    token_type=token_type,
+                    is_prompt=is_prompt,
+                    value=CostBreakdown(tokens=tokens, cost=cost),
+                )
+                async for token_type, is_prompt, cost, tokens in data
+            ]
+


 def to_gql_experiment(
     experiment: models.Experiment,
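With these resolvers in place, an experiment's costs become queryable over GraphQL. A hedged usage sketch, assuming Phoenix's default port 6006, a /graphql endpoint, Strawberry's default camelCase field naming, and a placeholder experiment node ID:

import httpx

EXPERIMENT_NODE_ID = "RXhwZXJpbWVudDox"  # placeholder Relay GlobalID for illustration

query = """
query ExperimentCosts($id: GlobalID!) {
  node(id: $id) {
    ... on Experiment {
      costSummary {
        prompt { tokens cost }
        completion { tokens cost }
        total { tokens cost costPerToken }
      }
      costDetailSummaryEntries {
        tokenType
        isPrompt
        value { tokens cost }
      }
    }
  }
}
"""

response = httpx.post(
    "http://localhost:6006/graphql",
    json={"query": query, "variables": {"id": EXPERIMENT_NODE_ID}},
)
print(response.json())

The same fields are added to ExperimentRun below, so the per-run breakdown can be fetched the same way.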
phoenix/server/api/types/ExperimentRun.py
CHANGED
@@ -2,8 +2,9 @@ from datetime import datetime
 from typing import TYPE_CHECKING, Annotated, Optional

 import strawberry
-from sqlalchemy import select
+from sqlalchemy import func, select
 from sqlalchemy.orm import load_only
+from sqlalchemy.sql.functions import coalesce
 from strawberry import UNSET
 from strawberry.relay import Connection, GlobalID, Node, NodeID
 from strawberry.scalars import JSON
@@ -11,6 +12,7 @@ from strawberry.types import Info

 from phoenix.db import models
 from phoenix.server.api.context import Context
+from phoenix.server.api.types.CostBreakdown import CostBreakdown
 from phoenix.server.api.types.ExperimentRunAnnotation import (
     ExperimentRunAnnotation,
     to_gql_experiment_run_annotation,
@@ -20,6 +22,8 @@ from phoenix.server.api.types.pagination import (
     CursorString,
     connection_from_list,
 )
+from phoenix.server.api.types.SpanCostDetailSummaryEntry import SpanCostDetailSummaryEntry
+from phoenix.server.api.types.SpanCostSummary import SpanCostSummary
 from phoenix.server.api.types.Trace import Trace

 if TYPE_CHECKING:
@@ -98,6 +102,58 @@ class ExperimentRun(Node):
             version_id=version_id,
         )

+    @strawberry.field
+    async def cost_summary(self, info: Info[Context, None]) -> SpanCostSummary:
+        run_id = self.id_attr
+        summary = await info.context.data_loaders.span_cost_summary_by_experiment_run.load(run_id)
+        return SpanCostSummary(
+            prompt=CostBreakdown(
+                tokens=summary.prompt.tokens,
+                cost=summary.prompt.cost,
+            ),
+            completion=CostBreakdown(
+                tokens=summary.completion.tokens,
+                cost=summary.completion.cost,
+            ),
+            total=CostBreakdown(
+                tokens=summary.total.tokens,
+                cost=summary.total.cost,
+            ),
+        )
+
+    @strawberry.field
+    async def cost_detail_summary_entries(
+        self, info: Info[Context, None]
+    ) -> list[SpanCostDetailSummaryEntry]:
+        run_id = self.id_attr
+
+        stmt = (
+            select(
+                models.SpanCostDetail.token_type,
+                models.SpanCostDetail.is_prompt,
+                coalesce(func.sum(models.SpanCostDetail.cost), 0).label("cost"),
+                coalesce(func.sum(models.SpanCostDetail.tokens), 0).label("tokens"),
+            )
+            .select_from(models.SpanCostDetail)
+            .join(models.SpanCost, models.SpanCostDetail.span_cost_id == models.SpanCost.id)
+            .join(models.Span, models.SpanCost.span_rowid == models.Span.id)
+            .join(models.Trace, models.Span.trace_rowid == models.Trace.id)
+            .join(models.ExperimentRun, models.ExperimentRun.trace_id == models.Trace.trace_id)
+            .where(models.ExperimentRun.id == run_id)
+            .group_by(models.SpanCostDetail.token_type, models.SpanCostDetail.is_prompt)
+        )
+
+        async with info.context.db() as session:
+            data = await session.stream(stmt)
+            return [
+                SpanCostDetailSummaryEntry(
+                    token_type=token_type,
+                    is_prompt=is_prompt,
+                    value=CostBreakdown(tokens=tokens, cost=cost),
+                )
+                async for token_type, is_prompt, cost, tokens in data
+            ]
+

 def to_gql_experiment_run(run: models.ExperimentRun) -> ExperimentRun:
     """
@@ -109,9 +165,7 @@ def to_gql_experiment_run(run: models.ExperimentRun) -> ExperimentRun:
     return ExperimentRun(
         id_attr=run.id,
         experiment_id=GlobalID(Experiment.__name__, str(run.experiment_id)),
-        trace_id=trace_id
-        if (trace := run.trace) and (trace_id := trace.trace_id) is not None
-        else None,
+        trace_id=run.trace.trace_id if run.trace else None,
         output=run.output.get("task_output"),
         start_time=run.start_time,
         end_time=run.end_time,