arize-phoenix 11.32.1__py3-none-any.whl → 11.34.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of arize-phoenix might be problematic.
- {arize_phoenix-11.32.1.dist-info → arize_phoenix-11.34.0.dist-info}/METADATA +1 -1
- {arize_phoenix-11.32.1.dist-info → arize_phoenix-11.34.0.dist-info}/RECORD +57 -50
- phoenix/config.py +44 -0
- phoenix/db/bulk_inserter.py +111 -116
- phoenix/inferences/inferences.py +1 -2
- phoenix/server/api/context.py +20 -0
- phoenix/server/api/dataloaders/__init__.py +20 -0
- phoenix/server/api/dataloaders/average_experiment_repeated_run_group_latency.py +50 -0
- phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -1
- phoenix/server/api/dataloaders/dataset_examples_and_versions_by_experiment_run.py +47 -0
- phoenix/server/api/dataloaders/experiment_repeated_run_group_annotation_summaries.py +77 -0
- phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +59 -0
- phoenix/server/api/dataloaders/experiment_repetition_counts.py +39 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment_repeated_run_group.py +64 -0
- phoenix/server/api/helpers/playground_clients.py +4 -0
- phoenix/server/api/mutations/prompt_label_mutations.py +67 -58
- phoenix/server/api/queries.py +52 -37
- phoenix/server/api/routers/v1/documents.py +1 -1
- phoenix/server/api/routers/v1/evaluations.py +4 -4
- phoenix/server/api/routers/v1/experiment_runs.py +1 -1
- phoenix/server/api/routers/v1/experiments.py +1 -1
- phoenix/server/api/routers/v1/spans.py +2 -2
- phoenix/server/api/routers/v1/traces.py +18 -3
- phoenix/server/api/types/DatasetExample.py +49 -1
- phoenix/server/api/types/Experiment.py +12 -2
- phoenix/server/api/types/ExperimentComparison.py +3 -9
- phoenix/server/api/types/ExperimentRepeatedRunGroup.py +146 -0
- phoenix/server/api/types/ExperimentRepeatedRunGroupAnnotationSummary.py +9 -0
- phoenix/server/api/types/ExperimentRun.py +12 -19
- phoenix/server/api/types/Prompt.py +11 -0
- phoenix/server/api/types/PromptLabel.py +2 -19
- phoenix/server/api/types/node.py +10 -0
- phoenix/server/app.py +78 -20
- phoenix/server/cost_tracking/model_cost_manifest.json +1 -1
- phoenix/server/daemons/span_cost_calculator.py +10 -8
- phoenix/server/grpc_server.py +9 -9
- phoenix/server/prometheus.py +30 -6
- phoenix/server/static/.vite/manifest.json +43 -43
- phoenix/server/static/assets/components-CdQiQTvs.js +5778 -0
- phoenix/server/static/assets/{index-D1FDMBMV.js → index-B1VuXYRI.js} +12 -21
- phoenix/server/static/assets/pages-CnfZ3RhB.js +9163 -0
- phoenix/server/static/assets/vendor-BGzfc4EU.css +1 -0
- phoenix/server/static/assets/vendor-Cfrr9FCF.js +903 -0
- phoenix/server/static/assets/{vendor-arizeai-DsYDNOqt.js → vendor-arizeai-Dz0kN-lQ.js} +4 -4
- phoenix/server/static/assets/vendor-codemirror-ClqtONZQ.js +25 -0
- phoenix/server/static/assets/{vendor-recharts-BTHn5Y2R.js → vendor-recharts-D6kvOpmb.js} +2 -2
- phoenix/server/static/assets/{vendor-shiki-BAcocHFl.js → vendor-shiki-xSOiKxt0.js} +1 -1
- phoenix/session/client.py +55 -1
- phoenix/session/data_extractor.py +5 -0
- phoenix/session/evaluation.py +8 -4
- phoenix/session/session.py +13 -0
- phoenix/trace/projects.py +1 -2
- phoenix/version.py +1 -1
- phoenix/server/static/assets/components-Cs9c4Nxp.js +0 -5698
- phoenix/server/static/assets/pages-Cbj9SjBx.js +0 -8928
- phoenix/server/static/assets/vendor-CqDb5u4o.css +0 -1
- phoenix/server/static/assets/vendor-RdRDaQiR.js +0 -905
- phoenix/server/static/assets/vendor-codemirror-BzJDUbEx.js +0 -25
- phoenix/utilities/deprecation.py +0 -31
- {arize_phoenix-11.32.1.dist-info → arize_phoenix-11.34.0.dist-info}/WHEEL +0 -0
- {arize_phoenix-11.32.1.dist-info → arize_phoenix-11.34.0.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-11.32.1.dist-info → arize_phoenix-11.34.0.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-11.32.1.dist-info → arize_phoenix-11.34.0.dist-info}/licenses/LICENSE +0 -0
phoenix/server/api/routers/v1/evaluations.py
CHANGED
@@ -102,7 +102,7 @@ async def post_evaluations(
             detail="Evaluation name must not be blank/empty",
             status_code=HTTP_422_UNPROCESSABLE_ENTITY,
         )
-    await request.state.
+    await request.state.enqueue_evaluation(evaluation)
     return Response()


@@ -221,7 +221,7 @@ async def _add_evaluations(state: State, evaluations: Evaluations) -> None:
                 explanation=explanation,
                 metadata_={},
             )
-            await state.
+            await state.enqueue_annotations(document_annotation)
     elif len(names) == 1 and names[0] in ("context.span_id", "span_id"):
         for index, row in dataframe.iterrows():
             score, label, explanation = _get_annotation_result(row)
@@ -235,7 +235,7 @@ async def _add_evaluations(state: State, evaluations: Evaluations) -> None:
                 explanation=explanation,
                 metadata_={},
             )
-            await state.
+            await state.enqueue_annotations(span_annotation)
     elif len(names) == 1 and names[0] in ("context.trace_id", "trace_id"):
         for index, row in dataframe.iterrows():
             score, label, explanation = _get_annotation_result(row)
@@ -249,7 +249,7 @@ async def _add_evaluations(state: State, evaluations: Evaluations) -> None:
                 explanation=explanation,
                 metadata_={},
             )
-            await state.
+            await state.enqueue_annotations(trace_annotation)


 def _get_annotation_result(
phoenix/server/api/routers/v1/experiment_runs.py
CHANGED
@@ -27,7 +27,7 @@ class ExperimentRun(V1RoutesBaseModel):
         description="The ID of the dataset example used in the experiment run"
     )
     output: Any = Field(description="The output of the experiment task")
-    repetition_number: int = Field(description="The repetition number of the experiment run")
+    repetition_number: int = Field(description="The repetition number of the experiment run", gt=0)
     start_time: datetime = Field(description="The start time of the experiment run")
     end_time: datetime = Field(description="The end time of the experiment run")
     trace_id: Optional[str] = Field(
phoenix/server/api/routers/v1/experiments.py
CHANGED
@@ -46,7 +46,7 @@ class Experiment(V1RoutesBaseModel):
     dataset_version_id: str = Field(
         description="The ID of the dataset version associated with the experiment"
     )
-    repetitions: int = Field(description="Number of times the experiment is repeated")
+    repetitions: int = Field(description="Number of times the experiment is repeated", gt=0)
     metadata: dict[str, Any] = Field(description="Metadata of the experiment")
     project_name: Optional[str] = Field(
         description="The name of the project associated with the experiment"
phoenix/server/api/routers/v1/spans.py
CHANGED
@@ -897,7 +897,7 @@ async def annotate_spans(
     )
     precursors = [d.as_precursor(user_id=user_id) for d in filtered_span_annotations]
     if not sync:
-        await request.state.
+        await request.state.enqueue_annotations(*precursors)
         return AnnotateSpansResponseBody(data=[])

     span_ids = {p.span_id for p in precursors}
@@ -1072,7 +1072,7 @@ async def create_spans(

     # All spans are valid, queue them all
     for span_for_insertion, project_name in spans_to_queue:
-        await request.state.
+        await request.state.enqueue_span(span_for_insertion, project_name)

     return CreateSpansResponseBody(
         total_received=total_received,
phoenix/server/api/routers/v1/traces.py
CHANGED
@@ -18,6 +18,7 @@ from starlette.status import (
     HTTP_404_NOT_FOUND,
     HTTP_415_UNSUPPORTED_MEDIA_TYPE,
     HTTP_422_UNPROCESSABLE_ENTITY,
+    HTTP_503_SERVICE_UNAVAILABLE,
 )
 from strawberry.relay import GlobalID

@@ -29,6 +30,7 @@ from phoenix.server.api.types.node import from_global_id_with_expected_type
 from phoenix.server.authorization import is_not_locked
 from phoenix.server.bearer_auth import PhoenixUser
 from phoenix.server.dml_event import SpanDeleteEvent, TraceAnnotationInsertEvent
+from phoenix.server.prometheus import SPAN_QUEUE_REJECTIONS
 from phoenix.trace.otel import decode_otlp_span
 from phoenix.utilities.project import get_project_name

@@ -42,9 +44,18 @@ from .utils import (
 router = APIRouter(tags=["traces"])


+def is_not_at_capacity(request: Request) -> None:
+    if request.app.state.span_queue_is_full():
+        SPAN_QUEUE_REJECTIONS.inc()
+        raise HTTPException(
+            detail="Server is at capacity and cannot process more requests",
+            status_code=HTTP_503_SERVICE_UNAVAILABLE,
+        )
+
+
 @router.post(
     "/traces",
-    dependencies=[Depends(is_not_locked)],
+    dependencies=[Depends(is_not_locked), Depends(is_not_at_capacity)],
     operation_id="addTraces",
     summary="Send traces",
     responses=add_errors_to_responses(
@@ -56,6 +67,10 @@ router = APIRouter(tags=["traces"])
             ),
         },
         {"status_code": HTTP_422_UNPROCESSABLE_ENTITY, "description": "Invalid request body"},
+        {
+            "status_code": HTTP_503_SERVICE_UNAVAILABLE,
+            "description": "Server is at capacity and cannot process more requests",
+        },
     ]
     ),
     openapi_extra={
@@ -145,7 +160,7 @@ async def annotate_traces(

     precursors = [d.as_precursor(user_id=user_id) for d in request_body.data]
     if not sync:
-        await request.state.
+        await request.state.enqueue_annotations(*precursors)
         return AnnotateTracesResponseBody(data=[])

     trace_ids = {p.trace_id for p in precursors}
@@ -193,7 +208,7 @@ async def _add_spans(req: ExportTraceServiceRequest, state: State) -> None:
         for scope_span in resource_spans.scope_spans:
             for otlp_span in scope_span.spans:
                 span = await run_in_threadpool(decode_otlp_span, otlp_span)
-                await state.
+                await state.enqueue_span(span, project_name)


 @router.delete(
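The new is_not_at_capacity dependency rejects trace ingestion with a 503 before the request body is processed when the span queue is full, and each rejection increments the new SPAN_QUEUE_REJECTIONS Prometheus counter. A minimal self-contained sketch of the same FastAPI dependency pattern, with an illustrative in-memory queue and limit standing in for Phoenix's actual application state:

from fastapi import Depends, FastAPI, HTTPException, Request
from starlette.status import HTTP_503_SERVICE_UNAVAILABLE

app = FastAPI()
app.state.queue = []          # illustrative stand-in for the span queue
app.state.queue_limit = 1000  # illustrative capacity limit


def is_not_at_capacity(request: Request) -> None:
    # Reject early with 503 when the queue is full, mirroring the dependency added to /v1/traces.
    if len(request.app.state.queue) >= request.app.state.queue_limit:
        raise HTTPException(
            status_code=HTTP_503_SERVICE_UNAVAILABLE,
            detail="Server is at capacity and cannot process more requests",
        )


@app.post("/items", dependencies=[Depends(is_not_at_capacity)])
async def add_item(request: Request) -> dict:
    request.app.state.queue.append(object())
    return {"queued": len(request.app.state.queue)}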
phoenix/server/api/types/DatasetExample.py
CHANGED
@@ -10,8 +10,12 @@ from strawberry.types import Info

 from phoenix.db import models
 from phoenix.server.api.context import Context
+from phoenix.server.api.exceptions import BadRequest
 from phoenix.server.api.types.DatasetExampleRevision import DatasetExampleRevision
 from phoenix.server.api.types.DatasetVersion import DatasetVersion
+from phoenix.server.api.types.ExperimentRepeatedRunGroup import (
+    ExperimentRepeatedRunGroup,
+)
 from phoenix.server.api.types.ExperimentRun import ExperimentRun, to_gql_experiment_run
 from phoenix.server.api.types.node import from_global_id_with_expected_type
 from phoenix.server.api.types.pagination import (
@@ -65,6 +69,7 @@ class DatasetExample(Node):
         last: Optional[int] = UNSET,
         after: Optional[CursorString] = UNSET,
         before: Optional[CursorString] = UNSET,
+        experiment_ids: Optional[list[GlobalID]] = UNSET,
     ) -> Connection[ExperimentRun]:
         args = ConnectionArgs(
             first=first,
@@ -78,8 +83,51 @@
             .options(joinedload(models.ExperimentRun.trace).load_only(models.Trace.trace_id))
             .join(models.Experiment, models.Experiment.id == models.ExperimentRun.experiment_id)
             .where(models.ExperimentRun.dataset_example_id == example_id)
-            .order_by(
+            .order_by(
+                models.ExperimentRun.experiment_id.asc(),
+                models.ExperimentRun.repetition_number.asc(),
+            )
         )
+        if experiment_ids:
+            experiment_db_ids = [
+                from_global_id_with_expected_type(
+                    global_id=experiment_id,
+                    expected_type_name=models.Experiment.__name__,
+                )
+                for experiment_id in experiment_ids or []
+            ]
+            query = query.where(models.ExperimentRun.experiment_id.in_(experiment_db_ids))
         async with info.context.db() as session:
             runs = (await session.scalars(query)).all()
         return connection_from_list([to_gql_experiment_run(run) for run in runs], args)
+
+    @strawberry.field
+    async def experiment_repeated_run_groups(
+        self,
+        info: Info[Context, None],
+        experiment_ids: list[GlobalID],
+    ) -> list[ExperimentRepeatedRunGroup]:
+        example_rowid = self.id_attr
+        experiment_rowids = []
+        for experiment_id in experiment_ids:
+            try:
+                experiment_rowid = from_global_id_with_expected_type(
+                    global_id=experiment_id,
+                    expected_type_name=models.Experiment.__name__,
+                )
+            except Exception:
+                raise BadRequest(f"Invalid experiment ID: {experiment_id}")
+            experiment_rowids.append(experiment_rowid)
+        repeated_run_groups = (
+            await info.context.data_loaders.experiment_repeated_run_groups.load_many(
+                [(experiment_rowid, example_rowid) for experiment_rowid in experiment_rowids]
+            )
+        )
+        return [
+            ExperimentRepeatedRunGroup(
+                experiment_rowid=group.experiment_rowid,
+                dataset_example_rowid=group.dataset_example_rowid,
+                runs=[to_gql_experiment_run(run) for run in group.runs],
+            )
+            for group in repeated_run_groups
+        ]
phoenix/server/api/types/Experiment.py
CHANGED
@@ -5,13 +5,14 @@ import strawberry
 from sqlalchemy import func, select
 from sqlalchemy.orm import joinedload
 from strawberry import UNSET, Private
-from strawberry.relay import Connection, Node, NodeID
+from strawberry.relay import Connection, GlobalID, Node, NodeID
 from strawberry.scalars import JSON
 from strawberry.types import Info

 from phoenix.db import models
 from phoenix.server.api.context import Context
 from phoenix.server.api.types.CostBreakdown import CostBreakdown
+from phoenix.server.api.types.DatasetVersion import DatasetVersion
 from phoenix.server.api.types.ExperimentAnnotationSummary import ExperimentAnnotationSummary
 from phoenix.server.api.types.ExperimentRun import ExperimentRun, to_gql_experiment_run
 from phoenix.server.api.types.pagination import (
@@ -32,6 +33,7 @@ class Experiment(Node):
     name: str
     project_name: Optional[str]
     description: Optional[str]
+    dataset_version_id: GlobalID
     metadata: JSON
     created_at: datetime
     updated_at: datetime
@@ -71,7 +73,10 @@
                 await session.scalars(
                     select(models.ExperimentRun)
                     .where(models.ExperimentRun.experiment_id == experiment_id)
-                    .order_by(
+                    .order_by(
+                        models.ExperimentRun.dataset_example_id.asc(),
+                        models.ExperimentRun.repetition_number.asc(),
+                    )
                     .options(
                         joinedload(models.ExperimentRun.trace).load_only(models.Trace.trace_id)
                     )
@@ -187,6 +192,10 @@
             async for token_type, is_prompt, cost, tokens in data
         ]

+    @strawberry.field
+    async def repetition_count(self, info: Info[Context, None]) -> int:
+        return await info.context.data_loaders.experiment_repetition_counts.load(self.id_attr)
+

 def to_gql_experiment(
     experiment: models.Experiment,
@@ -201,6 +210,7 @@ def to_gql_experiment(
         name=experiment.name,
         project_name=experiment.project_name,
         description=experiment.description,
+        dataset_version_id=GlobalID(DatasetVersion.__name__, str(experiment.dataset_version_id)),
         metadata=experiment.metadata_,
         created_at=experiment.created_at,
         updated_at=experiment.updated_at,
phoenix/server/api/types/ExperimentComparison.py
CHANGED
@@ -1,18 +1,12 @@
 import strawberry
-from strawberry.relay import
+from strawberry.relay import Node, NodeID

 from phoenix.server.api.types.DatasetExample import DatasetExample
-from phoenix.server.api.types.
-
-
-@strawberry.type
-class RunComparisonItem:
-    experiment_id: GlobalID
-    runs: list[ExperimentRun]
+from phoenix.server.api.types.ExperimentRepeatedRunGroup import ExperimentRepeatedRunGroup


 @strawberry.type
 class ExperimentComparison(Node):
     id_attr: NodeID[int]
     example: DatasetExample
-
+    repeated_run_groups: list[ExperimentRepeatedRunGroup]
phoenix/server/api/types/ExperimentRepeatedRunGroup.py
ADDED
@@ -0,0 +1,146 @@
+import re
+from base64 import b64decode
+from typing import Optional
+
+import strawberry
+from sqlalchemy import func, select
+from strawberry.relay import GlobalID, Node
+from strawberry.types import Info
+from typing_extensions import Self, TypeAlias
+
+from phoenix.db import models
+from phoenix.server.api.context import Context
+from phoenix.server.api.types.CostBreakdown import CostBreakdown
+from phoenix.server.api.types.ExperimentRepeatedRunGroupAnnotationSummary import (
+    ExperimentRepeatedRunGroupAnnotationSummary,
+)
+from phoenix.server.api.types.ExperimentRun import ExperimentRun
+from phoenix.server.api.types.SpanCostDetailSummaryEntry import SpanCostDetailSummaryEntry
+from phoenix.server.api.types.SpanCostSummary import SpanCostSummary
+
+ExperimentRowId: TypeAlias = int
+DatasetExampleRowId: TypeAlias = int
+
+
+@strawberry.type
+class ExperimentRepeatedRunGroup(Node):
+    experiment_rowid: strawberry.Private[ExperimentRowId]
+    dataset_example_rowid: strawberry.Private[DatasetExampleRowId]
+    runs: list[ExperimentRun]
+
+    @classmethod
+    def resolve_id(
+        cls,
+        root: Self,
+        *,
+        info: Info,
+    ) -> str:
+        return (
+            f"experiment_id={root.experiment_rowid}:dataset_example_id={root.dataset_example_rowid}"
+        )
+
+    @strawberry.field
+    def experiment_id(self) -> strawberry.ID:
+        from phoenix.server.api.types.Experiment import Experiment
+
+        return strawberry.ID(str(GlobalID(Experiment.__name__, str(self.experiment_rowid))))
+
+    @strawberry.field
+    async def average_latency_ms(self, info: Info[Context, None]) -> Optional[float]:
+        return await info.context.data_loaders.average_experiment_repeated_run_group_latency.load(
+            (self.experiment_rowid, self.dataset_example_rowid)
+        )
+
+    @strawberry.field
+    async def cost_summary(self, info: Info[Context, None]) -> SpanCostSummary:
+        experiment_id = self.experiment_rowid
+        example_id = self.dataset_example_rowid
+        summary = (
+            await info.context.data_loaders.span_cost_summary_by_experiment_repeated_run_group.load(
+                (experiment_id, example_id)
+            )
+        )
+        return SpanCostSummary(
+            prompt=CostBreakdown(
+                tokens=summary.prompt.tokens,
+                cost=summary.prompt.cost,
+            ),
+            completion=CostBreakdown(
+                tokens=summary.completion.tokens,
+                cost=summary.completion.cost,
+            ),
+            total=CostBreakdown(
+                tokens=summary.total.tokens,
+                cost=summary.total.cost,
+            ),
+        )
+
+    @strawberry.field
+    async def cost_detail_summary_entries(
+        self, info: Info[Context, None]
+    ) -> list[SpanCostDetailSummaryEntry]:
+        experiment_id = self.experiment_rowid
+        example_id = self.dataset_example_rowid
+        stmt = (
+            select(
+                models.SpanCostDetail.token_type,
+                models.SpanCostDetail.is_prompt,
+                func.sum(models.SpanCostDetail.cost).label("cost"),
+                func.sum(models.SpanCostDetail.tokens).label("tokens"),
+            )
+            .select_from(models.SpanCostDetail)
+            .join(models.SpanCost, models.SpanCostDetail.span_cost_id == models.SpanCost.id)
+            .join(models.Trace, models.SpanCost.trace_rowid == models.Trace.id)
+            .join(models.ExperimentRun, models.ExperimentRun.trace_id == models.Trace.trace_id)
+            .where(models.ExperimentRun.experiment_id == experiment_id)
+            .where(models.ExperimentRun.dataset_example_id == example_id)
+            .group_by(models.SpanCostDetail.token_type, models.SpanCostDetail.is_prompt)
+        )
+
+        async with info.context.db() as session:
+            data = await session.stream(stmt)
+            return [
+                SpanCostDetailSummaryEntry(
+                    token_type=token_type,
+                    is_prompt=is_prompt,
+                    value=CostBreakdown(tokens=tokens, cost=cost),
+                )
+                async for token_type, is_prompt, cost, tokens in data
+            ]
+
+    @strawberry.field
+    async def annotation_summaries(
+        self,
+        info: Info[Context, None],
+    ) -> list[ExperimentRepeatedRunGroupAnnotationSummary]:
+        loader = info.context.data_loaders.experiment_repeated_run_group_annotation_summaries
+        summaries = await loader.load((self.experiment_rowid, self.dataset_example_rowid))
+        return [
+            ExperimentRepeatedRunGroupAnnotationSummary(
+                annotation_name=summary.annotation_name,
+                mean_score=summary.mean_score,
+            )
+            for summary in summaries
+        ]
+
+
+_EXPERIMENT_REPEATED_RUN_GROUP_NODE_ID_PATTERN = re.compile(
+    r"ExperimentRepeatedRunGroup:experiment_id=(\d+):dataset_example_id=(\d+)"
+)
+
+
+def parse_experiment_repeated_run_group_node_id(
+    node_id: str,
+) -> tuple[ExperimentRowId, DatasetExampleRowId]:
+    decoded_node_id = _base64_decode(node_id)
+    match = re.match(_EXPERIMENT_REPEATED_RUN_GROUP_NODE_ID_PATTERN, decoded_node_id)
+    if not match:
+        raise ValueError(f"Invalid node ID format: {node_id}")
+
+    experiment_id = int(match.group(1))
+    dataset_example_id = int(match.group(2))
+    return experiment_id, dataset_example_id
+
+
+def _base64_decode(string: str) -> str:
+    return b64decode(string.encode()).decode()
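The group's relay node ID is a composite of two row IDs rather than a single integer, which is why this module ships its own parser. A small round-trip illustration with made-up row IDs, assuming the usual relay convention of base64-encoding "TypeName:identifier":

import re
from base64 import b64decode, b64encode

# Hypothetical row IDs, for illustration only.
experiment_rowid, dataset_example_rowid = 7, 42

# resolve_id() yields "experiment_id=7:dataset_example_id=42"; the relay layer prepends
# the type name and base64-encodes the whole string to produce the opaque node ID.
raw = f"ExperimentRepeatedRunGroup:experiment_id={experiment_rowid}:dataset_example_id={dataset_example_rowid}"
node_id = b64encode(raw.encode()).decode()

# The same regex the new module uses recovers the two row IDs, mirroring what
# parse_experiment_repeated_run_group_node_id does.
pattern = re.compile(r"ExperimentRepeatedRunGroup:experiment_id=(\d+):dataset_example_id=(\d+)")
match = pattern.match(b64decode(node_id.encode()).decode())
assert match is not None
print(int(match.group(1)), int(match.group(2)))  # 7 42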
phoenix/server/api/types/ExperimentRun.py
CHANGED
@@ -3,7 +3,6 @@ from typing import TYPE_CHECKING, Annotated, Optional

 import strawberry
 from sqlalchemy import func, select
-from sqlalchemy.orm import load_only
 from sqlalchemy.sql.functions import coalesce
 from strawberry import UNSET
 from strawberry.relay import Connection, GlobalID, Node, NodeID
@@ -34,12 +33,17 @@ if TYPE_CHECKING:
 class ExperimentRun(Node):
     id_attr: NodeID[int]
     experiment_id: GlobalID
+    repetition_number: int
     trace_id: Optional[str]
     output: Optional[JSON]
     start_time: datetime
     end_time: datetime
     error: Optional[str]

+    @strawberry.field
+    def latency_ms(self) -> float:
+        return (self.end_time - self.start_time).total_seconds() * 1000
+
     @strawberry.field
     async def annotations(
         self,
@@ -78,24 +82,12 @@ class ExperimentRun(Node):
     ]:  # use lazy types to avoid circular import: https://strawberry.rocks/docs/types/lazy
         from phoenix.server.api.types.DatasetExample import DatasetExample

-
-
-
-
-
-
-                    models.DatasetExample,
-                    models.DatasetExample.id == models.ExperimentRun.dataset_example_id,
-                )
-                .join(
-                    models.Experiment,
-                    models.Experiment.id == models.ExperimentRun.experiment_id,
-                )
-                .where(models.ExperimentRun.id == self.id_attr)
-                .options(load_only(models.DatasetExample.id, models.DatasetExample.created_at))
-            )
-        ) is not None
-        example, version_id = result.first()
+        (
+            example,
+            version_id,
+        ) = await info.context.data_loaders.dataset_examples_and_versions_by_experiment_run.load(
+            self.id_attr
+        )
         return DatasetExample(
             id_attr=example.id,
             created_at=example.created_at,
@@ -165,6 +157,7 @@ def to_gql_experiment_run(run: models.ExperimentRun) -> ExperimentRun:
     return ExperimentRun(
         id_attr=run.id,
         experiment_id=GlobalID(Experiment.__name__, str(run.experiment_id)),
+        repetition_number=run.repetition_number,
         trace_id=run.trace.trace_id if run.trace else None,
         output=run.output.get("task_output"),
         start_time=run.start_time,
phoenix/server/api/types/Prompt.py
CHANGED
@@ -19,6 +19,7 @@ from phoenix.server.api.types.pagination import (
     connection_from_list,
 )

+from .PromptLabel import PromptLabel, to_gql_prompt_label
 from .PromptVersion import (
     PromptVersion,
     to_gql_prompt_version,
@@ -116,6 +117,16 @@ class Prompt(Node):
             raise NotFound(f"Source prompt not found: {self.source_prompt_id}")
         return to_gql_prompt_from_orm(source_prompt)

+    @strawberry.field
+    async def labels(self, info: Info[Context, None]) -> list["PromptLabel"]:
+        async with info.context.db() as session:
+            labels = await session.scalars(
+                select(models.PromptLabel)
+                .join(models.PromptPromptLabel)
+                .where(models.PromptPromptLabel.prompt_id == self.id_attr)
+            )
+        return [to_gql_prompt_label(label) for label in labels]
+

 def to_gql_prompt_from_orm(orm_model: "models.Prompt") -> Prompt:
     if not orm_model.source_prompt_id:
phoenix/server/api/types/PromptLabel.py
CHANGED
@@ -1,14 +1,10 @@
 from typing import Optional

 import strawberry
-from sqlalchemy import select
 from strawberry.relay import Node, NodeID
-from strawberry.types import Info

 from phoenix.db import models
-from phoenix.server.api.context import Context
 from phoenix.server.api.types.Identifier import Identifier
-from phoenix.server.api.types.Prompt import Prompt, to_gql_prompt_from_orm


 @strawberry.type
@@ -16,21 +12,7 @@ class PromptLabel(Node):
     id_attr: NodeID[int]
     name: Identifier
     description: Optional[str] = None
-
-    @strawberry.field
-    async def prompts(self, info: Info[Context, None]) -> list[Prompt]:
-        async with info.context.db() as session:
-            statement = (
-                select(models.Prompt)
-                .join(
-                    models.PromptPromptLabel, models.Prompt.id == models.PromptPromptLabel.prompt_id
-                )
-                .where(models.PromptPromptLabel.prompt_label_id == self.id_attr)
-            )
-            return [
-                to_gql_prompt_from_orm(prompt_orm)
-                async for prompt_orm in await session.stream_scalars(statement)
-            ]
+    color: str


 def to_gql_prompt_label(label_orm: models.PromptLabel) -> PromptLabel:
@@ -38,4 +20,5 @@ def to_gql_prompt_label(label_orm: models.PromptLabel) -> PromptLabel:
         id_attr=label_orm.id,
         name=Identifier(label_orm.name),
         description=label_orm.description,
+        color=label_orm.color,
     )
phoenix/server/api/types/node.py
CHANGED
@@ -1,5 +1,15 @@
+import re
+from base64 import b64decode
+
 from strawberry.relay import GlobalID

+_GLOBAL_ID_PATTERN = re.compile(r"[a-zA-Z]+:[0-9]+")
+
+
+def is_global_id(node_id: str) -> bool:
+    decoded_node_id = b64decode(node_id).decode()
+    return _GLOBAL_ID_PATTERN.match(decoded_node_id) is not None
+

 def from_global_id(global_id: GlobalID) -> tuple[str, int]:
     """