arize-phoenix 11.37.0__py3-none-any.whl → 12.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of arize-phoenix might be problematic. Click here for more details.
- {arize_phoenix-11.37.0.dist-info → arize_phoenix-12.0.0.dist-info}/METADATA +3 -3
- {arize_phoenix-11.37.0.dist-info → arize_phoenix-12.0.0.dist-info}/RECORD +74 -53
- phoenix/config.py +1 -11
- phoenix/db/bulk_inserter.py +8 -0
- phoenix/db/facilitator.py +1 -1
- phoenix/db/helpers.py +202 -33
- phoenix/db/insertion/dataset.py +7 -0
- phoenix/db/insertion/helpers.py +2 -2
- phoenix/db/insertion/session_annotation.py +176 -0
- phoenix/db/insertion/types.py +30 -0
- phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
- phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
- phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
- phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
- phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
- phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
- phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
- phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
- phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
- phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
- phoenix/db/models.py +285 -46
- phoenix/server/api/context.py +13 -2
- phoenix/server/api/dataloaders/__init__.py +6 -2
- phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
- phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
- phoenix/server/api/dataloaders/table_fields.py +2 -2
- phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
- phoenix/server/api/helpers/playground_clients.py +65 -35
- phoenix/server/api/helpers/playground_spans.py +2 -1
- phoenix/server/api/helpers/playground_users.py +26 -0
- phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
- phoenix/server/api/input_types/ChatCompletionInput.py +2 -0
- phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
- phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
- phoenix/server/api/mutations/__init__.py +6 -0
- phoenix/server/api/mutations/chat_mutations.py +24 -9
- phoenix/server/api/mutations/dataset_mutations.py +5 -0
- phoenix/server/api/mutations/dataset_split_mutations.py +387 -0
- phoenix/server/api/mutations/project_session_annotations_mutations.py +161 -0
- phoenix/server/api/queries.py +32 -0
- phoenix/server/api/routers/v1/__init__.py +2 -0
- phoenix/server/api/routers/v1/annotations.py +320 -0
- phoenix/server/api/routers/v1/datasets.py +5 -0
- phoenix/server/api/routers/v1/experiments.py +10 -3
- phoenix/server/api/routers/v1/sessions.py +111 -0
- phoenix/server/api/routers/v1/traces.py +1 -2
- phoenix/server/api/routers/v1/users.py +7 -0
- phoenix/server/api/subscriptions.py +25 -7
- phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +1 -0
- phoenix/server/api/types/DatasetExample.py +11 -0
- phoenix/server/api/types/DatasetSplit.py +32 -0
- phoenix/server/api/types/Experiment.py +0 -4
- phoenix/server/api/types/Project.py +16 -0
- phoenix/server/api/types/ProjectSession.py +88 -3
- phoenix/server/api/types/ProjectSessionAnnotation.py +68 -0
- phoenix/server/api/types/Span.py +5 -5
- phoenix/server/api/types/Trace.py +61 -0
- phoenix/server/app.py +6 -2
- phoenix/server/cost_tracking/model_cost_manifest.json +132 -2
- phoenix/server/dml_event.py +13 -0
- phoenix/server/static/.vite/manifest.json +39 -39
- phoenix/server/static/assets/{components-CFzdBkk_.js → components-Dl9SUw1U.js} +371 -327
- phoenix/server/static/assets/{index-DayUA9lQ.js → index-CqQS0dTo.js} +2 -2
- phoenix/server/static/assets/{pages-CvUhOO9h.js → pages-DKSjVA_E.js} +771 -518
- phoenix/server/static/assets/{vendor-BdjZxMii.js → vendor-CtbHQYl8.js} +1 -1
- phoenix/server/static/assets/{vendor-arizeai-CHYlS8jV.js → vendor-arizeai-D-lWOwIS.js} +1 -1
- phoenix/server/static/assets/{vendor-codemirror-Di6t4HnH.js → vendor-codemirror-BRBpy3_z.js} +3 -3
- phoenix/server/static/assets/{vendor-recharts-C9wCDYj3.js → vendor-recharts--KdSwB3m.js} +1 -1
- phoenix/server/static/assets/{vendor-shiki-MNnmOotP.js → vendor-shiki-CvRzZnIo.js} +1 -1
- phoenix/version.py +1 -1
- phoenix/server/api/dataloaders/experiment_repetition_counts.py +0 -39
- {arize_phoenix-11.37.0.dist-info → arize_phoenix-12.0.0.dist-info}/WHEEL +0 -0
- {arize_phoenix-11.37.0.dist-info → arize_phoenix-12.0.0.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-11.37.0.dist-info → arize_phoenix-12.0.0.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-11.37.0.dist-info → arize_phoenix-12.0.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,111 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import warnings
|
|
4
|
+
from typing import Optional
|
|
5
|
+
|
|
6
|
+
from fastapi import APIRouter, Depends, HTTPException, Query
|
|
7
|
+
from pydantic import Field
|
|
8
|
+
from sqlalchemy import select
|
|
9
|
+
from starlette.requests import Request
|
|
10
|
+
from starlette.status import HTTP_404_NOT_FOUND
|
|
11
|
+
|
|
12
|
+
from phoenix.db import models
|
|
13
|
+
from phoenix.db.helpers import SupportedSQLDialect
|
|
14
|
+
from phoenix.db.insertion.helpers import as_kv, insert_on_conflict
|
|
15
|
+
from phoenix.server.api.routers.v1.models import V1RoutesBaseModel
|
|
16
|
+
from phoenix.server.authorization import is_not_locked
|
|
17
|
+
from phoenix.server.bearer_auth import PhoenixUser
|
|
18
|
+
|
|
19
|
+
from .annotations import SessionAnnotationData
|
|
20
|
+
from .utils import RequestBody, ResponseBody, add_errors_to_responses
|
|
21
|
+
|
|
22
|
+
router = APIRouter(tags=["sessions"])
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class InsertedSessionAnnotation(V1RoutesBaseModel):
|
|
26
|
+
id: str = Field(description="The ID of the inserted session annotation")
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class AnnotateSessionsRequestBody(RequestBody[list[SessionAnnotationData]]):
|
|
30
|
+
pass
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class AnnotateSessionsResponseBody(ResponseBody[list[InsertedSessionAnnotation]]):
|
|
34
|
+
pass
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
@router.post(
|
|
38
|
+
"/session_annotations",
|
|
39
|
+
dependencies=[Depends(is_not_locked)],
|
|
40
|
+
operation_id="annotateSessions",
|
|
41
|
+
summary="Create session annotations",
|
|
42
|
+
responses=add_errors_to_responses(
|
|
43
|
+
[{"status_code": HTTP_404_NOT_FOUND, "description": "Session not found"}]
|
|
44
|
+
),
|
|
45
|
+
response_description="Session annotations inserted successfully",
|
|
46
|
+
include_in_schema=True,
|
|
47
|
+
)
|
|
48
|
+
async def annotate_sessions(
|
|
49
|
+
request: Request,
|
|
50
|
+
request_body: AnnotateSessionsRequestBody,
|
|
51
|
+
sync: bool = Query(default=False, description="If true, fulfill request synchronously."),
|
|
52
|
+
) -> AnnotateSessionsResponseBody:
|
|
53
|
+
if not request_body.data:
|
|
54
|
+
return AnnotateSessionsResponseBody(data=[])
|
|
55
|
+
|
|
56
|
+
user_id: Optional[int] = None
|
|
57
|
+
if request.app.state.authentication_enabled and isinstance(request.user, PhoenixUser):
|
|
58
|
+
user_id = int(request.user.identity)
|
|
59
|
+
|
|
60
|
+
session_annotations = request_body.data
|
|
61
|
+
filtered_session_annotations = list(filter(lambda d: d.name != "note", session_annotations))
|
|
62
|
+
if len(filtered_session_annotations) != len(session_annotations):
|
|
63
|
+
warnings.warn(
|
|
64
|
+
(
|
|
65
|
+
"Session annotations with the name 'note' are not supported in this endpoint. "
|
|
66
|
+
"They will be ignored."
|
|
67
|
+
),
|
|
68
|
+
UserWarning,
|
|
69
|
+
)
|
|
70
|
+
precursors = [d.as_precursor(user_id=user_id) for d in filtered_session_annotations]
|
|
71
|
+
if not sync:
|
|
72
|
+
await request.state.enqueue_annotations(*precursors)
|
|
73
|
+
return AnnotateSessionsResponseBody(data=[])
|
|
74
|
+
|
|
75
|
+
session_ids = {p.session_id for p in precursors}
|
|
76
|
+
async with request.app.state.db() as session:
|
|
77
|
+
existing_sessions = {
|
|
78
|
+
session_id: rowid
|
|
79
|
+
async for session_id, rowid in await session.stream(
|
|
80
|
+
select(models.ProjectSession.session_id, models.ProjectSession.id).filter(
|
|
81
|
+
models.ProjectSession.session_id.in_(session_ids)
|
|
82
|
+
)
|
|
83
|
+
)
|
|
84
|
+
}
|
|
85
|
+
|
|
86
|
+
missing_session_ids = session_ids - set(existing_sessions.keys())
|
|
87
|
+
# We prefer to fail the entire operation if there are missing sessions in sync mode
|
|
88
|
+
if missing_session_ids:
|
|
89
|
+
raise HTTPException(
|
|
90
|
+
detail=f"Sessions with IDs {', '.join(missing_session_ids)} do not exist.",
|
|
91
|
+
status_code=HTTP_404_NOT_FOUND,
|
|
92
|
+
)
|
|
93
|
+
|
|
94
|
+
async with request.app.state.db() as session:
|
|
95
|
+
inserted_ids = []
|
|
96
|
+
dialect = SupportedSQLDialect(session.bind.dialect.name)
|
|
97
|
+
for p in precursors:
|
|
98
|
+
values = dict(as_kv(p.as_insertable(existing_sessions[p.session_id]).row))
|
|
99
|
+
session_annotation_id = await session.scalar(
|
|
100
|
+
insert_on_conflict(
|
|
101
|
+
values,
|
|
102
|
+
dialect=dialect,
|
|
103
|
+
table=models.ProjectSessionAnnotation,
|
|
104
|
+
unique_by=("name", "project_session_id", "identifier"),
|
|
105
|
+
).returning(models.ProjectSessionAnnotation.id)
|
|
106
|
+
)
|
|
107
|
+
inserted_ids.append(session_annotation_id)
|
|
108
|
+
|
|
109
|
+
return AnnotateSessionsResponseBody(
|
|
110
|
+
data=[InsertedSessionAnnotation(id=str(inserted_id)) for inserted_id in inserted_ids]
|
|
111
|
+
)
|
|
@@ -144,12 +144,11 @@ class AnnotateTracesResponseBody(ResponseBody[list[InsertedTraceAnnotation]]):
|
|
|
144
144
|
responses=add_errors_to_responses(
|
|
145
145
|
[{"status_code": HTTP_404_NOT_FOUND, "description": "Trace not found"}]
|
|
146
146
|
),
|
|
147
|
-
include_in_schema=False,
|
|
148
147
|
)
|
|
149
148
|
async def annotate_traces(
|
|
150
149
|
request: Request,
|
|
151
150
|
request_body: AnnotateTracesRequestBody,
|
|
152
|
-
sync: bool = Query(default=
|
|
151
|
+
sync: bool = Query(default=False, description="If true, fulfill request synchronously."),
|
|
153
152
|
) -> AnnotateTracesResponseBody:
|
|
154
153
|
if not request_body.data:
|
|
155
154
|
return AnnotateTracesResponseBody(data=[])
|
|
@@ -217,6 +217,13 @@ async def create_user(
|
|
|
217
217
|
detail="Cannot create users with SYSTEM role",
|
|
218
218
|
)
|
|
219
219
|
|
|
220
|
+
# TODO: Implement VIEWER role
|
|
221
|
+
if role == "VIEWER":
|
|
222
|
+
raise HTTPException(
|
|
223
|
+
status_code=HTTP_400_BAD_REQUEST,
|
|
224
|
+
detail="VIEWER role not yet implemented",
|
|
225
|
+
)
|
|
226
|
+
|
|
220
227
|
user: models.User
|
|
221
228
|
if isinstance(user_data, LocalUserData):
|
|
222
229
|
password = (user_data.password or secrets.token_hex()).strip()
|
|
@@ -26,6 +26,7 @@ from typing_extensions import TypeAlias, assert_never
|
|
|
26
26
|
from phoenix.config import PLAYGROUND_PROJECT_NAME
|
|
27
27
|
from phoenix.datetime_utils import local_now, normalize_datetime
|
|
28
28
|
from phoenix.db import models
|
|
29
|
+
from phoenix.db.helpers import insert_experiment_with_examples_snapshot
|
|
29
30
|
from phoenix.server.api.auth import IsLocked, IsNotReadOnly
|
|
30
31
|
from phoenix.server.api.context import Context
|
|
31
32
|
from phoenix.server.api.exceptions import BadRequest, CustomGraphQLError, NotFound
|
|
@@ -43,6 +44,7 @@ from phoenix.server.api.helpers.playground_spans import (
|
|
|
43
44
|
get_db_trace,
|
|
44
45
|
streaming_llm_span,
|
|
45
46
|
)
|
|
47
|
+
from phoenix.server.api.helpers.playground_users import get_user
|
|
46
48
|
from phoenix.server.api.helpers.prompts.models import PromptTemplateFormat
|
|
47
49
|
from phoenix.server.api.input_types.ChatCompletionInput import (
|
|
48
50
|
ChatCompletionInput,
|
|
@@ -302,18 +304,19 @@ class Subscription:
|
|
|
302
304
|
description="Traces from prompt playground",
|
|
303
305
|
)
|
|
304
306
|
)
|
|
307
|
+
user_id = get_user(info)
|
|
305
308
|
experiment = models.Experiment(
|
|
306
309
|
dataset_id=from_global_id_with_expected_type(input.dataset_id, Dataset.__name__),
|
|
307
310
|
dataset_version_id=resolved_version_id,
|
|
308
311
|
name=input.experiment_name
|
|
309
312
|
or _default_playground_experiment_name(input.prompt_name),
|
|
310
313
|
description=input.experiment_description,
|
|
311
|
-
repetitions=
|
|
314
|
+
repetitions=input.repetitions,
|
|
312
315
|
metadata_=input.experiment_metadata or dict(),
|
|
313
316
|
project_name=project_name,
|
|
317
|
+
user_id=user_id,
|
|
314
318
|
)
|
|
315
|
-
session
|
|
316
|
-
await session.flush()
|
|
319
|
+
await insert_experiment_with_examples_snapshot(session, experiment)
|
|
317
320
|
yield ChatCompletionSubscriptionExperiment(
|
|
318
321
|
experiment=to_gql_experiment(experiment)
|
|
319
322
|
) # eagerly yields experiment so it can be linked by consumers of the subscription
|
|
@@ -327,11 +330,13 @@ class Subscription:
|
|
|
327
330
|
llm_client=llm_client,
|
|
328
331
|
revision=revision,
|
|
329
332
|
results=results,
|
|
333
|
+
repetition_number=repetition_number,
|
|
330
334
|
experiment_id=experiment.id,
|
|
331
335
|
project_id=playground_project_id,
|
|
332
336
|
),
|
|
333
337
|
)
|
|
334
338
|
for revision in revisions
|
|
339
|
+
for repetition_number in range(1, input.repetitions + 1)
|
|
335
340
|
]
|
|
336
341
|
in_progress: list[
|
|
337
342
|
tuple[
|
|
@@ -409,6 +414,7 @@ async def _stream_chat_completion_over_dataset_example(
|
|
|
409
414
|
input: ChatCompletionOverDatasetInput,
|
|
410
415
|
llm_client: PlaygroundStreamingClient,
|
|
411
416
|
revision: models.DatasetExampleRevision,
|
|
417
|
+
repetition_number: int,
|
|
412
418
|
results: asyncio.Queue[ChatCompletionResult],
|
|
413
419
|
experiment_id: int,
|
|
414
420
|
project_id: int,
|
|
@@ -435,7 +441,11 @@ async def _stream_chat_completion_over_dataset_example(
|
|
|
435
441
|
)
|
|
436
442
|
except TemplateFormatterError as error:
|
|
437
443
|
format_end_time = cast(datetime, normalize_datetime(dt=local_now(), tz=timezone.utc))
|
|
438
|
-
yield ChatCompletionSubscriptionError(
|
|
444
|
+
yield ChatCompletionSubscriptionError(
|
|
445
|
+
message=str(error),
|
|
446
|
+
dataset_example_id=example_id,
|
|
447
|
+
repetition_number=repetition_number,
|
|
448
|
+
)
|
|
439
449
|
await results.put(
|
|
440
450
|
(
|
|
441
451
|
example_id,
|
|
@@ -445,7 +455,7 @@ async def _stream_chat_completion_over_dataset_example(
|
|
|
445
455
|
dataset_example_id=revision.dataset_example_id,
|
|
446
456
|
trace_id=None,
|
|
447
457
|
output={},
|
|
448
|
-
repetition_number=
|
|
458
|
+
repetition_number=repetition_number,
|
|
449
459
|
start_time=format_start_time,
|
|
450
460
|
end_time=format_end_time,
|
|
451
461
|
error=str(error),
|
|
@@ -465,17 +475,24 @@ async def _stream_chat_completion_over_dataset_example(
|
|
|
465
475
|
):
|
|
466
476
|
span.add_response_chunk(chunk)
|
|
467
477
|
chunk.dataset_example_id = example_id
|
|
478
|
+
chunk.repetition_number = repetition_number
|
|
468
479
|
yield chunk
|
|
469
480
|
span.set_attributes(llm_client.attributes)
|
|
470
481
|
db_trace = get_db_trace(span, project_id)
|
|
471
482
|
db_span = get_db_span(span, db_trace)
|
|
472
483
|
db_run = get_db_experiment_run(
|
|
473
|
-
db_span,
|
|
484
|
+
db_span,
|
|
485
|
+
db_trace,
|
|
486
|
+
experiment_id=experiment_id,
|
|
487
|
+
example_id=revision.dataset_example_id,
|
|
488
|
+
repetition_number=repetition_number,
|
|
474
489
|
)
|
|
475
490
|
await results.put((example_id, db_span, db_run))
|
|
476
491
|
if span.status_message is not None:
|
|
477
492
|
yield ChatCompletionSubscriptionError(
|
|
478
|
-
message=span.status_message,
|
|
493
|
+
message=span.status_message,
|
|
494
|
+
dataset_example_id=example_id,
|
|
495
|
+
repetition_number=repetition_number,
|
|
479
496
|
)
|
|
480
497
|
|
|
481
498
|
|
|
@@ -511,6 +528,7 @@ async def _chat_completion_result_payloads(
|
|
|
511
528
|
span=Span(span_rowid=span.id, db_span=span) if span else None,
|
|
512
529
|
experiment_run=to_gql_experiment_run(run),
|
|
513
530
|
dataset_example_id=example_id,
|
|
531
|
+
repetition_number=run.repetition_number,
|
|
514
532
|
)
|
|
515
533
|
|
|
516
534
|
|
|
@@ -12,6 +12,7 @@ from phoenix.db import models
|
|
|
12
12
|
from phoenix.server.api.context import Context
|
|
13
13
|
from phoenix.server.api.exceptions import BadRequest
|
|
14
14
|
from phoenix.server.api.types.DatasetExampleRevision import DatasetExampleRevision
|
|
15
|
+
from phoenix.server.api.types.DatasetSplit import DatasetSplit, to_gql_dataset_split
|
|
15
16
|
from phoenix.server.api.types.DatasetVersion import DatasetVersion
|
|
16
17
|
from phoenix.server.api.types.ExperimentRepeatedRunGroup import (
|
|
17
18
|
ExperimentRepeatedRunGroup,
|
|
@@ -131,3 +132,13 @@ class DatasetExample(Node):
|
|
|
131
132
|
)
|
|
132
133
|
for group in repeated_run_groups
|
|
133
134
|
]
|
|
135
|
+
|
|
136
|
+
@strawberry.field
|
|
137
|
+
async def dataset_splits(
|
|
138
|
+
self,
|
|
139
|
+
info: Info[Context, None],
|
|
140
|
+
) -> list[DatasetSplit]:
|
|
141
|
+
return [
|
|
142
|
+
to_gql_dataset_split(split)
|
|
143
|
+
for split in await info.context.data_loaders.dataset_example_splits.load(self.id_attr)
|
|
144
|
+
]
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
from datetime import datetime
|
|
2
|
+
from typing import ClassVar, Optional
|
|
3
|
+
|
|
4
|
+
import strawberry
|
|
5
|
+
from strawberry.relay import Node, NodeID
|
|
6
|
+
from strawberry.scalars import JSON
|
|
7
|
+
|
|
8
|
+
from phoenix.db import models
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@strawberry.type
|
|
12
|
+
class DatasetSplit(Node):
|
|
13
|
+
_table: ClassVar[type[models.Base]] = models.DatasetSplit
|
|
14
|
+
id_attr: NodeID[int]
|
|
15
|
+
name: str
|
|
16
|
+
description: Optional[str]
|
|
17
|
+
metadata: JSON
|
|
18
|
+
color: str
|
|
19
|
+
created_at: datetime
|
|
20
|
+
updated_at: datetime
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def to_gql_dataset_split(dataset_split: models.DatasetSplit) -> DatasetSplit:
|
|
24
|
+
return DatasetSplit(
|
|
25
|
+
id_attr=dataset_split.id,
|
|
26
|
+
name=dataset_split.name,
|
|
27
|
+
description=dataset_split.description,
|
|
28
|
+
color=dataset_split.color or "#ffffff",
|
|
29
|
+
metadata=dataset_split.metadata_,
|
|
30
|
+
created_at=dataset_split.created_at,
|
|
31
|
+
updated_at=dataset_split.updated_at,
|
|
32
|
+
)
|
|
@@ -193,10 +193,6 @@ class Experiment(Node):
|
|
|
193
193
|
async for token_type, is_prompt, cost, tokens in data
|
|
194
194
|
]
|
|
195
195
|
|
|
196
|
-
@strawberry.field
|
|
197
|
-
async def repetition_count(self, info: Info[Context, None]) -> int:
|
|
198
|
-
return await info.context.data_loaders.experiment_repetition_counts.load(self.id_attr)
|
|
199
|
-
|
|
200
196
|
|
|
201
197
|
def to_gql_experiment(
|
|
202
198
|
experiment: models.Experiment,
|
|
@@ -588,6 +588,22 @@ class Project(Node):
|
|
|
588
588
|
async with info.context.db() as session:
|
|
589
589
|
return list(await session.scalars(stmt))
|
|
590
590
|
|
|
591
|
+
@strawberry.field(
|
|
592
|
+
description="Names of all available annotations for sessions. "
|
|
593
|
+
"(The list contains no duplicates.)"
|
|
594
|
+
) # type: ignore
|
|
595
|
+
async def session_annotation_names(
|
|
596
|
+
self,
|
|
597
|
+
info: Info[Context, None],
|
|
598
|
+
) -> list[str]:
|
|
599
|
+
stmt = (
|
|
600
|
+
select(distinct(models.ProjectSessionAnnotation.name))
|
|
601
|
+
.join(models.ProjectSession)
|
|
602
|
+
.where(models.ProjectSession.project_id == self.project_rowid)
|
|
603
|
+
)
|
|
604
|
+
async with info.context.db() as session:
|
|
605
|
+
return list(await session.scalars(stmt))
|
|
606
|
+
|
|
591
607
|
@strawberry.field(
|
|
592
608
|
description="Names of available document evaluations.",
|
|
593
609
|
) # type: ignore
|
|
@@ -1,14 +1,19 @@
|
|
|
1
|
+
from collections import defaultdict
|
|
2
|
+
from dataclasses import asdict, dataclass
|
|
1
3
|
from datetime import datetime
|
|
2
4
|
from typing import TYPE_CHECKING, Annotated, ClassVar, Optional, Type
|
|
3
5
|
|
|
6
|
+
import pandas as pd
|
|
4
7
|
import strawberry
|
|
5
8
|
from openinference.semconv.trace import SpanAttributes
|
|
6
9
|
from sqlalchemy import select
|
|
7
10
|
from strawberry import UNSET, Info, Private, lazy
|
|
8
|
-
from strawberry.relay import Connection,
|
|
11
|
+
from strawberry.relay import Connection, Node, NodeID
|
|
9
12
|
|
|
10
13
|
from phoenix.db import models
|
|
11
14
|
from phoenix.server.api.context import Context
|
|
15
|
+
from phoenix.server.api.input_types.AnnotationFilter import AnnotationFilter, satisfies_filter
|
|
16
|
+
from phoenix.server.api.types.AnnotationSummary import AnnotationSummary
|
|
12
17
|
from phoenix.server.api.types.CostBreakdown import CostBreakdown
|
|
13
18
|
from phoenix.server.api.types.MimeType import MimeType
|
|
14
19
|
from phoenix.server.api.types.pagination import ConnectionArgs, CursorString, connection_from_list
|
|
@@ -18,6 +23,8 @@ from phoenix.server.api.types.SpanIOValue import SpanIOValue
|
|
|
18
23
|
from phoenix.server.api.types.TokenUsage import TokenUsage
|
|
19
24
|
|
|
20
25
|
if TYPE_CHECKING:
|
|
26
|
+
from phoenix.server.api.types.Project import Project
|
|
27
|
+
from phoenix.server.api.types.ProjectSessionAnnotation import ProjectSessionAnnotation
|
|
21
28
|
from phoenix.server.api.types.Trace import Trace
|
|
22
29
|
|
|
23
30
|
|
|
@@ -31,10 +38,13 @@ class ProjectSession(Node):
|
|
|
31
38
|
end_time: datetime
|
|
32
39
|
|
|
33
40
|
@strawberry.field
|
|
34
|
-
async def
|
|
41
|
+
async def project(
|
|
42
|
+
self,
|
|
43
|
+
info: Info[Context, None],
|
|
44
|
+
) -> Annotated["Project", lazy(".Project")]:
|
|
35
45
|
from phoenix.server.api.types.Project import Project
|
|
36
46
|
|
|
37
|
-
return
|
|
47
|
+
return Project(project_rowid=self.project_rowid)
|
|
38
48
|
|
|
39
49
|
@strawberry.field
|
|
40
50
|
async def num_traces(
|
|
@@ -165,6 +175,81 @@ class ProjectSession(Node):
|
|
|
165
175
|
for entry in summary
|
|
166
176
|
]
|
|
167
177
|
|
|
178
|
+
@strawberry.field
|
|
179
|
+
async def session_annotations(
|
|
180
|
+
self,
|
|
181
|
+
info: Info[Context, None],
|
|
182
|
+
) -> list[Annotated["ProjectSessionAnnotation", lazy(".ProjectSessionAnnotation")]]:
|
|
183
|
+
"""Get all annotations for this session."""
|
|
184
|
+
from phoenix.server.api.types.ProjectSessionAnnotation import (
|
|
185
|
+
to_gql_project_session_annotation,
|
|
186
|
+
)
|
|
187
|
+
|
|
188
|
+
stmt = select(models.ProjectSessionAnnotation).filter_by(project_session_id=self.id_attr)
|
|
189
|
+
async with info.context.db() as session:
|
|
190
|
+
annotations = await session.stream_scalars(stmt)
|
|
191
|
+
return [
|
|
192
|
+
to_gql_project_session_annotation(annotation) async for annotation in annotations
|
|
193
|
+
]
|
|
194
|
+
|
|
195
|
+
@strawberry.field(
|
|
196
|
+
description="Summarizes each annotation (by name) associated with the session"
|
|
197
|
+
) # type: ignore
|
|
198
|
+
async def session_annotation_summaries(
|
|
199
|
+
self,
|
|
200
|
+
info: Info[Context, None],
|
|
201
|
+
filter: Optional[AnnotationFilter] = None,
|
|
202
|
+
) -> list[AnnotationSummary]:
|
|
203
|
+
"""
|
|
204
|
+
Retrieves and summarizes annotations associated with this span.
|
|
205
|
+
|
|
206
|
+
This method aggregates annotation data by name and label, calculating metrics
|
|
207
|
+
such as count of occurrences and sum of scores. The results are organized
|
|
208
|
+
into a structured format that can be easily converted to a DataFrame.
|
|
209
|
+
|
|
210
|
+
Args:
|
|
211
|
+
info: GraphQL context information
|
|
212
|
+
filter: Optional filter to apply to annotations before processing
|
|
213
|
+
|
|
214
|
+
Returns:
|
|
215
|
+
A list of AnnotationSummary objects, each containing:
|
|
216
|
+
- name: The name of the annotation
|
|
217
|
+
- data: A list of dictionaries with label statistics
|
|
218
|
+
"""
|
|
219
|
+
# Load all annotations for this span from the data loader
|
|
220
|
+
annotations = await info.context.data_loaders.session_annotations_by_session.load(
|
|
221
|
+
self.id_attr
|
|
222
|
+
)
|
|
223
|
+
|
|
224
|
+
# Apply filter if provided to narrow down the annotations
|
|
225
|
+
if filter:
|
|
226
|
+
annotations = [
|
|
227
|
+
annotation for annotation in annotations if satisfies_filter(annotation, filter)
|
|
228
|
+
]
|
|
229
|
+
|
|
230
|
+
@dataclass
|
|
231
|
+
class Metrics:
|
|
232
|
+
record_count: int = 0
|
|
233
|
+
label_count: int = 0
|
|
234
|
+
score_sum: float = 0
|
|
235
|
+
score_count: int = 0
|
|
236
|
+
|
|
237
|
+
summaries: defaultdict[str, defaultdict[Optional[str], Metrics]] = defaultdict(
|
|
238
|
+
lambda: defaultdict(Metrics)
|
|
239
|
+
)
|
|
240
|
+
for annotation in annotations:
|
|
241
|
+
metrics = summaries[annotation.name][annotation.label]
|
|
242
|
+
metrics.record_count += 1
|
|
243
|
+
metrics.label_count += int(annotation.label is not None)
|
|
244
|
+
metrics.score_sum += annotation.score or 0
|
|
245
|
+
metrics.score_count += int(annotation.score is not None)
|
|
246
|
+
|
|
247
|
+
result: list[AnnotationSummary] = []
|
|
248
|
+
for name, label_metrics in summaries.items():
|
|
249
|
+
rows = [{"label": label, **asdict(metrics)} for label, metrics in label_metrics.items()]
|
|
250
|
+
result.append(AnnotationSummary(name=name, df=pd.DataFrame(rows), simple_avg=True))
|
|
251
|
+
return result
|
|
252
|
+
|
|
168
253
|
|
|
169
254
|
def to_gql_project_session(project_session: models.ProjectSession) -> ProjectSession:
|
|
170
255
|
return ProjectSession(
|
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
import strawberry
|
|
4
|
+
from strawberry import Private
|
|
5
|
+
from strawberry.relay import GlobalID, Node, NodeID
|
|
6
|
+
from strawberry.scalars import JSON
|
|
7
|
+
from strawberry.types import Info
|
|
8
|
+
|
|
9
|
+
from phoenix.db import models
|
|
10
|
+
from phoenix.server.api.context import Context
|
|
11
|
+
from phoenix.server.api.types.AnnotatorKind import AnnotatorKind
|
|
12
|
+
|
|
13
|
+
from .AnnotationSource import AnnotationSource
|
|
14
|
+
from .User import User, to_gql_user
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@strawberry.type
|
|
18
|
+
class ProjectSessionAnnotation(Node):
|
|
19
|
+
id_attr: NodeID[int]
|
|
20
|
+
user_id: Private[Optional[int]]
|
|
21
|
+
name: str
|
|
22
|
+
annotator_kind: AnnotatorKind
|
|
23
|
+
label: Optional[str]
|
|
24
|
+
score: Optional[float]
|
|
25
|
+
explanation: Optional[str]
|
|
26
|
+
metadata: JSON
|
|
27
|
+
_project_session_id: Private[Optional[int]]
|
|
28
|
+
identifier: str
|
|
29
|
+
source: AnnotationSource
|
|
30
|
+
|
|
31
|
+
@strawberry.field
|
|
32
|
+
async def project_session_id(self) -> GlobalID:
|
|
33
|
+
from phoenix.server.api.types.ProjectSession import ProjectSession
|
|
34
|
+
|
|
35
|
+
return GlobalID(type_name=ProjectSession.__name__, node_id=str(self._project_session_id))
|
|
36
|
+
|
|
37
|
+
@strawberry.field
|
|
38
|
+
async def user(
|
|
39
|
+
self,
|
|
40
|
+
info: Info[Context, None],
|
|
41
|
+
) -> Optional[User]:
|
|
42
|
+
if self.user_id is None:
|
|
43
|
+
return None
|
|
44
|
+
user = await info.context.data_loaders.users.load(self.user_id)
|
|
45
|
+
if user is None:
|
|
46
|
+
return None
|
|
47
|
+
return to_gql_user(user)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def to_gql_project_session_annotation(
|
|
51
|
+
annotation: models.ProjectSessionAnnotation,
|
|
52
|
+
) -> ProjectSessionAnnotation:
|
|
53
|
+
"""
|
|
54
|
+
Converts an ORM projectSession annotation to a GraphQL ProjectSessionAnnotation.
|
|
55
|
+
"""
|
|
56
|
+
return ProjectSessionAnnotation(
|
|
57
|
+
id_attr=annotation.id,
|
|
58
|
+
user_id=annotation.user_id,
|
|
59
|
+
_project_session_id=annotation.project_session_id,
|
|
60
|
+
name=annotation.name,
|
|
61
|
+
annotator_kind=AnnotatorKind(annotation.annotator_kind),
|
|
62
|
+
label=annotation.label,
|
|
63
|
+
score=annotation.score,
|
|
64
|
+
explanation=annotation.explanation,
|
|
65
|
+
metadata=JSON(annotation.metadata_),
|
|
66
|
+
identifier=annotation.identifier,
|
|
67
|
+
source=AnnotationSource(annotation.source),
|
|
68
|
+
)
|
phoenix/server/api/types/Span.py
CHANGED
|
@@ -23,11 +23,11 @@ from phoenix.server.api.helpers.dataset_helpers import (
|
|
|
23
23
|
get_dataset_example_input,
|
|
24
24
|
get_dataset_example_output,
|
|
25
25
|
)
|
|
26
|
-
from phoenix.server.api.input_types.
|
|
27
|
-
|
|
28
|
-
SpanAnnotationFilter,
|
|
26
|
+
from phoenix.server.api.input_types.AnnotationFilter import (
|
|
27
|
+
AnnotationFilter,
|
|
29
28
|
satisfies_filter,
|
|
30
29
|
)
|
|
30
|
+
from phoenix.server.api.input_types.InvocationParameters import InvocationParameter
|
|
31
31
|
from phoenix.server.api.input_types.SpanAnnotationSort import (
|
|
32
32
|
SpanAnnotationColumn,
|
|
33
33
|
SpanAnnotationSort,
|
|
@@ -547,7 +547,7 @@ class Span(Node):
|
|
|
547
547
|
self,
|
|
548
548
|
info: Info[Context, None],
|
|
549
549
|
sort: Optional[SpanAnnotationSort] = UNSET,
|
|
550
|
-
filter: Optional[
|
|
550
|
+
filter: Optional[AnnotationFilter] = None,
|
|
551
551
|
) -> list[SpanAnnotation]:
|
|
552
552
|
span_id = self.span_rowid
|
|
553
553
|
annotations = await info.context.data_loaders.span_annotations.load(span_id)
|
|
@@ -580,7 +580,7 @@ class Span(Node):
|
|
|
580
580
|
async def span_annotation_summaries(
|
|
581
581
|
self,
|
|
582
582
|
info: Info[Context, None],
|
|
583
|
-
filter: Optional[
|
|
583
|
+
filter: Optional[AnnotationFilter] = None,
|
|
584
584
|
) -> list[AnnotationSummary]:
|
|
585
585
|
"""
|
|
586
586
|
Retrieves and summarizes annotations associated with this span.
|
|
@@ -1,8 +1,11 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
from collections import defaultdict
|
|
4
|
+
from dataclasses import asdict, dataclass
|
|
3
5
|
from datetime import datetime
|
|
4
6
|
from typing import TYPE_CHECKING, Annotated, Optional, Union
|
|
5
7
|
|
|
8
|
+
import pandas as pd
|
|
6
9
|
import strawberry
|
|
7
10
|
from openinference.semconv.trace import SpanAttributes
|
|
8
11
|
from sqlalchemy import desc, select
|
|
@@ -13,7 +16,9 @@ from typing_extensions import TypeAlias
|
|
|
13
16
|
|
|
14
17
|
from phoenix.db import models
|
|
15
18
|
from phoenix.server.api.context import Context
|
|
19
|
+
from phoenix.server.api.input_types.AnnotationFilter import AnnotationFilter, satisfies_filter
|
|
16
20
|
from phoenix.server.api.input_types.TraceAnnotationSort import TraceAnnotationSort
|
|
21
|
+
from phoenix.server.api.types.AnnotationSummary import AnnotationSummary
|
|
17
22
|
from phoenix.server.api.types.CostBreakdown import CostBreakdown
|
|
18
23
|
from phoenix.server.api.types.pagination import (
|
|
19
24
|
ConnectionArgs,
|
|
@@ -229,6 +234,62 @@ class Trace(Node):
|
|
|
229
234
|
annotations = await session.scalars(stmt)
|
|
230
235
|
return [to_gql_trace_annotation(annotation) for annotation in annotations]
|
|
231
236
|
|
|
237
|
+
@strawberry.field(description="Summarizes each annotation (by name) associated with the trace") # type: ignore
|
|
238
|
+
async def trace_annotation_summaries(
|
|
239
|
+
self,
|
|
240
|
+
info: Info[Context, None],
|
|
241
|
+
filter: Optional[AnnotationFilter] = None,
|
|
242
|
+
) -> list[AnnotationSummary]:
|
|
243
|
+
"""
|
|
244
|
+
Retrieves and summarizes annotations associated with this span.
|
|
245
|
+
|
|
246
|
+
This method aggregates annotation data by name and label, calculating metrics
|
|
247
|
+
such as count of occurrences and sum of scores. The results are organized
|
|
248
|
+
into a structured format that can be easily converted to a DataFrame.
|
|
249
|
+
|
|
250
|
+
Args:
|
|
251
|
+
info: GraphQL context information
|
|
252
|
+
filter: Optional filter to apply to annotations before processing
|
|
253
|
+
|
|
254
|
+
Returns:
|
|
255
|
+
A list of AnnotationSummary objects, each containing:
|
|
256
|
+
- name: The name of the annotation
|
|
257
|
+
- data: A list of dictionaries with label statistics
|
|
258
|
+
"""
|
|
259
|
+
# Load all annotations for this span from the data loader
|
|
260
|
+
annotations = await info.context.data_loaders.trace_annotations_by_trace.load(
|
|
261
|
+
self.trace_rowid
|
|
262
|
+
)
|
|
263
|
+
|
|
264
|
+
# Apply filter if provided to narrow down the annotations
|
|
265
|
+
if filter:
|
|
266
|
+
annotations = [
|
|
267
|
+
annotation for annotation in annotations if satisfies_filter(annotation, filter)
|
|
268
|
+
]
|
|
269
|
+
|
|
270
|
+
@dataclass
|
|
271
|
+
class Metrics:
|
|
272
|
+
record_count: int = 0
|
|
273
|
+
label_count: int = 0
|
|
274
|
+
score_sum: float = 0
|
|
275
|
+
score_count: int = 0
|
|
276
|
+
|
|
277
|
+
summaries: defaultdict[str, defaultdict[Optional[str], Metrics]] = defaultdict(
|
|
278
|
+
lambda: defaultdict(Metrics)
|
|
279
|
+
)
|
|
280
|
+
for annotation in annotations:
|
|
281
|
+
metrics = summaries[annotation.name][annotation.label]
|
|
282
|
+
metrics.record_count += 1
|
|
283
|
+
metrics.label_count += int(annotation.label is not None)
|
|
284
|
+
metrics.score_sum += annotation.score or 0
|
|
285
|
+
metrics.score_count += int(annotation.score is not None)
|
|
286
|
+
|
|
287
|
+
result: list[AnnotationSummary] = []
|
|
288
|
+
for name, label_metrics in summaries.items():
|
|
289
|
+
rows = [{"label": label, **asdict(metrics)} for label, metrics in label_metrics.items()]
|
|
290
|
+
result.append(AnnotationSummary(name=name, df=pd.DataFrame(rows), simple_avg=True))
|
|
291
|
+
return result
|
|
292
|
+
|
|
232
293
|
@strawberry.field
|
|
233
294
|
async def cost_summary(
|
|
234
295
|
self,
|