arize-phoenix 4.8.1__py3-none-any.whl → 4.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of arize-phoenix might be problematic. Click here for more details.
- {arize_phoenix-4.8.1.dist-info → arize_phoenix-4.10.0.dist-info}/METADATA +1 -1
- {arize_phoenix-4.8.1.dist-info → arize_phoenix-4.10.0.dist-info}/RECORD +32 -24
- phoenix/db/insertion/evaluation.py +7 -25
- phoenix/db/insertion/helpers.py +54 -13
- phoenix/db/insertion/span.py +12 -16
- phoenix/inferences/inferences.py +16 -4
- phoenix/pointcloud/umap_parameters.py +52 -52
- phoenix/server/api/input_types/CreateSpanAnnotationsInput.py +16 -0
- phoenix/server/api/input_types/CreateTraceAnnotationsInput.py +16 -0
- phoenix/server/api/input_types/DeleteAnnotationsInput.py +9 -0
- phoenix/server/api/input_types/PatchAnnotationsInput.py +17 -0
- phoenix/server/api/mutations/__init__.py +8 -1
- phoenix/server/api/mutations/project_mutations.py +1 -1
- phoenix/server/api/mutations/span_annotations_mutations.py +108 -0
- phoenix/server/api/mutations/trace_annotations_mutations.py +108 -0
- phoenix/server/api/queries.py +1 -0
- phoenix/server/api/routers/v1/__init__.py +2 -0
- phoenix/server/api/routers/v1/experiment_evaluations.py +3 -10
- phoenix/server/api/routers/v1/experiments.py +4 -5
- phoenix/server/api/routers/v1/spans.py +147 -2
- phoenix/server/api/routers/v1/traces.py +150 -1
- phoenix/server/api/types/AnnotatorKind.py +7 -1
- phoenix/server/api/types/ExperimentRunAnnotation.py +3 -3
- phoenix/server/api/types/SpanAnnotation.py +45 -0
- phoenix/server/api/types/TraceAnnotation.py +45 -0
- phoenix/server/static/index.js +532 -524
- phoenix/session/client.py +10 -8
- phoenix/trace/trace_dataset.py +23 -15
- phoenix/version.py +1 -1
- {arize_phoenix-4.8.1.dist-info → arize_phoenix-4.10.0.dist-info}/WHEEL +0 -0
- {arize_phoenix-4.8.1.dist-info → arize_phoenix-4.10.0.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-4.8.1.dist-info → arize_phoenix-4.10.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,108 @@
|
|
|
1
|
+
from typing import List, Sequence
|
|
2
|
+
|
|
3
|
+
import strawberry
|
|
4
|
+
from sqlalchemy import delete, insert, update
|
|
5
|
+
from strawberry import UNSET
|
|
6
|
+
from strawberry.types import Info
|
|
7
|
+
|
|
8
|
+
from phoenix.db import models
|
|
9
|
+
from phoenix.server.api.context import Context
|
|
10
|
+
from phoenix.server.api.input_types.CreateSpanAnnotationsInput import CreateSpanAnnotationsInput
|
|
11
|
+
from phoenix.server.api.input_types.DeleteAnnotationsInput import DeleteAnnotationsInput
|
|
12
|
+
from phoenix.server.api.input_types.PatchAnnotationsInput import PatchAnnotationsInput
|
|
13
|
+
from phoenix.server.api.mutations.auth import IsAuthenticated
|
|
14
|
+
from phoenix.server.api.types.node import from_global_id_with_expected_type
|
|
15
|
+
from phoenix.server.api.types.SpanAnnotation import SpanAnnotation, to_gql_span_annotation
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@strawberry.type
class SpanAnnotationMutationPayload:
    # The span annotations affected by a mutation: the rows that were inserted,
    # updated, or deleted, depending on which mutation produced the payload.
    span_annotations: List[SpanAnnotation]


@strawberry.type
class SpanAnnotationMutationMixin:
    # GraphQL mutations that create, patch, and delete span annotations.
    # Every mutation is guarded by the IsAuthenticated permission class.

    @strawberry.mutation(permission_classes=[IsAuthenticated])  # type: ignore
    async def create_span_annotations(
        self, info: Info[Context, None], input: List[CreateSpanAnnotationsInput]
    ) -> SpanAnnotationMutationPayload:
        # Insert one span annotation per input item with a single multi-row
        # INSERT ... RETURNING, and return the inserted rows.
        if not input:
            # Nothing to insert. A multi-row VALUES clause needs at least one
            # row, so short-circuit instead of issuing an empty INSERT.
            return SpanAnnotationMutationPayload(span_annotations=[])
        inserted_annotations: Sequence[models.SpanAnnotation] = []
        async with info.context.db() as session:
            values_list = [
                dict(
                    span_rowid=from_global_id_with_expected_type(annotation.span_id, "Span"),
                    name=annotation.name,
                    label=annotation.label,
                    score=annotation.score,
                    explanation=annotation.explanation,
                    annotator_kind=annotation.annotator_kind,
                    metadata_=annotation.metadata,
                )
                for annotation in input
            ]
            stmt = (
                insert(models.SpanAnnotation).values(values_list).returning(models.SpanAnnotation)
            )
            result = await session.scalars(stmt)
            inserted_annotations = result.all()

        return SpanAnnotationMutationPayload(
            span_annotations=[
                to_gql_span_annotation(annotation) for annotation in inserted_annotations
            ]
        )

    @strawberry.mutation(permission_classes=[IsAuthenticated])  # type: ignore
    async def patch_span_annotations(
        self, info: Info[Context, None], input: List[PatchAnnotationsInput]
    ) -> SpanAnnotationMutationPayload:
        # Apply partial updates to existing span annotations. Returns the rows
        # as they stand afterwards; ids that match no row are omitted.
        patched_annotations = []
        async with info.context.db() as session:
            for annotation in input:
                span_annotation_id = from_global_id_with_expected_type(
                    annotation.annotation_id, "SpanAnnotation"
                )
                # Only fields the client actually supplied (not UNSET) are written.
                # An explicit None clears nullable columns (label, score,
                # explanation) but is ignored for the non-nullable ones.
                patch = {
                    column.key: patch_value
                    for column, patch_value, column_is_nullable in (
                        (models.SpanAnnotation.name, annotation.name, False),
                        (models.SpanAnnotation.annotator_kind, annotation.annotator_kind, False),
                        (models.SpanAnnotation.label, annotation.label, True),
                        (models.SpanAnnotation.score, annotation.score, True),
                        (models.SpanAnnotation.explanation, annotation.explanation, True),
                        (models.SpanAnnotation.metadata_, annotation.metadata, False),
                    )
                    if patch_value is not UNSET and (patch_value is not None or column_is_nullable)
                }
                if not patch:
                    # Nothing to change. An UPDATE without SET values is invalid,
                    # so just read the current row to report it back unchanged.
                    span_annotation = await session.get(models.SpanAnnotation, span_annotation_id)
                else:
                    span_annotation = await session.scalar(
                        update(models.SpanAnnotation)
                        .where(models.SpanAnnotation.id == span_annotation_id)
                        .values(**patch)
                        .returning(models.SpanAnnotation)
                    )
                if span_annotation is not None:
                    patched_annotations.append(to_gql_span_annotation(span_annotation))

        return SpanAnnotationMutationPayload(span_annotations=patched_annotations)

    @strawberry.mutation(permission_classes=[IsAuthenticated])  # type: ignore
    async def delete_span_annotations(
        self, info: Info[Context, None], input: DeleteAnnotationsInput
    ) -> SpanAnnotationMutationPayload:
        # Delete the span annotations with the given global ids and return the
        # rows that were actually deleted (unknown ids are silently skipped).
        span_annotation_ids = [
            from_global_id_with_expected_type(global_id, "SpanAnnotation")
            for global_id in input.annotation_ids
        ]
        async with info.context.db() as session:
            stmt = (
                delete(models.SpanAnnotation)
                .where(models.SpanAnnotation.id.in_(span_annotation_ids))
                .returning(models.SpanAnnotation)
            )
            result = await session.scalars(stmt)
            deleted_annotations = result.all()

            deleted_annotations_gql = [
                to_gql_span_annotation(annotation) for annotation in deleted_annotations
            ]
        return SpanAnnotationMutationPayload(span_annotations=deleted_annotations_gql)
|
|
@@ -0,0 +1,108 @@
|
|
|
1
|
+
from typing import List, Sequence
|
|
2
|
+
|
|
3
|
+
import strawberry
|
|
4
|
+
from sqlalchemy import delete, insert, update
|
|
5
|
+
from strawberry import UNSET
|
|
6
|
+
from strawberry.types import Info
|
|
7
|
+
|
|
8
|
+
from phoenix.db import models
|
|
9
|
+
from phoenix.server.api.context import Context
|
|
10
|
+
from phoenix.server.api.input_types.CreateTraceAnnotationsInput import CreateTraceAnnotationsInput
|
|
11
|
+
from phoenix.server.api.input_types.DeleteAnnotationsInput import DeleteAnnotationsInput
|
|
12
|
+
from phoenix.server.api.input_types.PatchAnnotationsInput import PatchAnnotationsInput
|
|
13
|
+
from phoenix.server.api.mutations.auth import IsAuthenticated
|
|
14
|
+
from phoenix.server.api.types.node import from_global_id_with_expected_type
|
|
15
|
+
from phoenix.server.api.types.TraceAnnotation import TraceAnnotation, to_gql_trace_annotation
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@strawberry.type
class TraceAnnotationMutationPayload:
    # The trace annotations affected by a mutation: the rows that were inserted,
    # updated, or deleted, depending on which mutation produced the payload.
    trace_annotations: List[TraceAnnotation]


@strawberry.type
class TraceAnnotationMutationMixin:
    # GraphQL mutations that create, patch, and delete trace annotations.
    # Every mutation is guarded by the IsAuthenticated permission class.

    @strawberry.mutation(permission_classes=[IsAuthenticated])  # type: ignore
    async def create_trace_annotations(
        self, info: Info[Context, None], input: List[CreateTraceAnnotationsInput]
    ) -> TraceAnnotationMutationPayload:
        # Insert one trace annotation per input item with a single multi-row
        # INSERT ... RETURNING, and return the inserted rows.
        if not input:
            # Nothing to insert. A multi-row VALUES clause needs at least one
            # row, so short-circuit instead of issuing an empty INSERT.
            return TraceAnnotationMutationPayload(trace_annotations=[])
        inserted_annotations: Sequence[models.TraceAnnotation] = []
        async with info.context.db() as session:
            values_list = [
                dict(
                    trace_rowid=from_global_id_with_expected_type(annotation.trace_id, "Trace"),
                    name=annotation.name,
                    label=annotation.label,
                    score=annotation.score,
                    explanation=annotation.explanation,
                    annotator_kind=annotation.annotator_kind,
                    metadata_=annotation.metadata,
                )
                for annotation in input
            ]
            stmt = (
                insert(models.TraceAnnotation).values(values_list).returning(models.TraceAnnotation)
            )
            result = await session.scalars(stmt)
            inserted_annotations = result.all()

        return TraceAnnotationMutationPayload(
            trace_annotations=[
                to_gql_trace_annotation(annotation) for annotation in inserted_annotations
            ]
        )

    @strawberry.mutation(permission_classes=[IsAuthenticated])  # type: ignore
    async def patch_trace_annotations(
        self, info: Info[Context, None], input: List[PatchAnnotationsInput]
    ) -> TraceAnnotationMutationPayload:
        # Apply partial updates to existing trace annotations. Returns the rows
        # as they stand afterwards; ids that match no row are omitted.
        patched_annotations = []
        async with info.context.db() as session:
            for annotation in input:
                trace_annotation_id = from_global_id_with_expected_type(
                    annotation.annotation_id, "TraceAnnotation"
                )
                # Only fields the client actually supplied (not UNSET) are written.
                # An explicit None clears nullable columns (label, score,
                # explanation) but is ignored for the non-nullable ones.
                patch = {
                    column.key: patch_value
                    for column, patch_value, column_is_nullable in (
                        (models.TraceAnnotation.name, annotation.name, False),
                        (models.TraceAnnotation.annotator_kind, annotation.annotator_kind, False),
                        (models.TraceAnnotation.label, annotation.label, True),
                        (models.TraceAnnotation.score, annotation.score, True),
                        (models.TraceAnnotation.explanation, annotation.explanation, True),
                        (models.TraceAnnotation.metadata_, annotation.metadata, False),
                    )
                    if patch_value is not UNSET and (patch_value is not None or column_is_nullable)
                }
                if not patch:
                    # Nothing to change. An UPDATE without SET values is invalid,
                    # so just read the current row to report it back unchanged.
                    trace_annotation = await session.get(
                        models.TraceAnnotation, trace_annotation_id
                    )
                else:
                    trace_annotation = await session.scalar(
                        update(models.TraceAnnotation)
                        .where(models.TraceAnnotation.id == trace_annotation_id)
                        .values(**patch)
                        .returning(models.TraceAnnotation)
                    )
                # Explicit None check, consistent with the span-annotation mutations.
                if trace_annotation is not None:
                    patched_annotations.append(to_gql_trace_annotation(trace_annotation))

        return TraceAnnotationMutationPayload(trace_annotations=patched_annotations)

    @strawberry.mutation(permission_classes=[IsAuthenticated])  # type: ignore
    async def delete_trace_annotations(
        self, info: Info[Context, None], input: DeleteAnnotationsInput
    ) -> TraceAnnotationMutationPayload:
        # Delete the trace annotations with the given global ids and return the
        # rows that were actually deleted (unknown ids are silently skipped).
        trace_annotation_ids = [
            from_global_id_with_expected_type(global_id, "TraceAnnotation")
            for global_id in input.annotation_ids
        ]
        async with info.context.db() as session:
            stmt = (
                delete(models.TraceAnnotation)
                .where(models.TraceAnnotation.id.in_(trace_annotation_ids))
                .returning(models.TraceAnnotation)
            )
            result = await session.scalars(stmt)
            deleted_annotations = result.all()

            deleted_annotations_gql = [
                to_gql_trace_annotation(annotation) for annotation in deleted_annotations
            ]
        return TraceAnnotationMutationPayload(trace_annotations=deleted_annotations_gql)
|
phoenix/server/api/queries.py
CHANGED
|
@@ -40,8 +40,10 @@ V1_ROUTES = [
|
|
|
40
40
|
Route("/v1/evaluations", evaluations.post_evaluations, methods=["POST"]),
|
|
41
41
|
Route("/v1/evaluations", evaluations.get_evaluations, methods=["GET"]),
|
|
42
42
|
Route("/v1/traces", traces.post_traces, methods=["POST"]),
|
|
43
|
+
Route("/v1/trace_annotations", traces.annotate_traces, methods=["POST"]),
|
|
43
44
|
Route("/v1/spans", spans.query_spans_handler, methods=["POST"]),
|
|
44
45
|
Route("/v1/spans", spans.get_spans_handler, methods=["GET"]),
|
|
46
|
+
Route("/v1/span_annotations", spans.annotate_spans, methods=["POST"]),
|
|
45
47
|
Route("/v1/datasets/upload", datasets.post_datasets_upload, methods=["POST"]),
|
|
46
48
|
Route("/v1/datasets", datasets.list_datasets, methods=["GET"]),
|
|
47
49
|
Route("/v1/datasets/{id:str}", datasets.delete_dataset_by_id, methods=["DELETE"]),
|
|
@@ -7,7 +7,7 @@ from strawberry.relay import GlobalID
|
|
|
7
7
|
|
|
8
8
|
from phoenix.db import models
|
|
9
9
|
from phoenix.db.helpers import SupportedSQLDialect
|
|
10
|
-
from phoenix.db.insertion.helpers import
|
|
10
|
+
from phoenix.db.insertion.helpers import insert_on_conflict
|
|
11
11
|
from phoenix.server.api.types.node import from_global_id_with_expected_type
|
|
12
12
|
|
|
13
13
|
|
|
@@ -123,20 +123,13 @@ async def upsert_experiment_evaluation(request: Request) -> Response:
|
|
|
123
123
|
end_time=datetime.fromisoformat(end_time),
|
|
124
124
|
trace_id=payload.get("trace_id"),
|
|
125
125
|
)
|
|
126
|
-
set_ = {
|
|
127
|
-
**{k: v for k, v in values.items() if k != "metadata_"},
|
|
128
|
-
"metadata": values["metadata_"], # `metadata` must match database
|
|
129
|
-
}
|
|
130
126
|
dialect = SupportedSQLDialect(session.bind.dialect.name)
|
|
131
127
|
exp_eval_run = await session.scalar(
|
|
132
128
|
insert_on_conflict(
|
|
129
|
+
values,
|
|
133
130
|
dialect=dialect,
|
|
134
131
|
table=models.ExperimentRunAnnotation,
|
|
135
|
-
|
|
136
|
-
constraint="uq_experiment_run_annotations_experiment_run_id_name",
|
|
137
|
-
column_names=("experiment_run_id", "name"),
|
|
138
|
-
on_conflict=OnConflict.DO_UPDATE,
|
|
139
|
-
set_=set_,
|
|
132
|
+
unique_by=("experiment_run_id", "name"),
|
|
140
133
|
).returning(models.ExperimentRunAnnotation)
|
|
141
134
|
)
|
|
142
135
|
evaluation_gid = GlobalID("ExperimentEvaluation", str(exp_eval_run.id))
|
|
@@ -180,16 +180,15 @@ async def create_experiment(request: Request) -> Response:
|
|
|
180
180
|
dialect = SupportedSQLDialect(session.bind.dialect.name)
|
|
181
181
|
project_rowid = await session.scalar(
|
|
182
182
|
insert_on_conflict(
|
|
183
|
-
|
|
184
|
-
table=models.Project,
|
|
185
|
-
constraint="uq_projects_name",
|
|
186
|
-
column_names=("name",),
|
|
187
|
-
values=dict(
|
|
183
|
+
dict(
|
|
188
184
|
name=project_name,
|
|
189
185
|
description=project_description,
|
|
190
186
|
created_at=experiment.created_at,
|
|
191
187
|
updated_at=experiment.updated_at,
|
|
192
188
|
),
|
|
189
|
+
dialect=dialect,
|
|
190
|
+
table=models.Project,
|
|
191
|
+
unique_by=("name",),
|
|
193
192
|
).returning(models.Project.id)
|
|
194
193
|
)
|
|
195
194
|
assert project_rowid is not None
|
|
@@ -1,13 +1,19 @@
|
|
|
1
1
|
from datetime import timezone
|
|
2
|
-
from typing import AsyncIterator
|
|
2
|
+
from typing import Any, AsyncIterator, Dict, List
|
|
3
3
|
|
|
4
|
+
from sqlalchemy import select
|
|
4
5
|
from starlette.requests import Request
|
|
5
|
-
from starlette.responses import Response, StreamingResponse
|
|
6
|
+
from starlette.responses import JSONResponse, Response, StreamingResponse
|
|
6
7
|
from starlette.status import HTTP_404_NOT_FOUND, HTTP_422_UNPROCESSABLE_ENTITY
|
|
8
|
+
from strawberry.relay import GlobalID
|
|
7
9
|
|
|
8
10
|
from phoenix.config import DEFAULT_PROJECT_NAME
|
|
9
11
|
from phoenix.datetime_utils import normalize_datetime
|
|
12
|
+
from phoenix.db import models
|
|
13
|
+
from phoenix.db.helpers import SupportedSQLDialect
|
|
14
|
+
from phoenix.db.insertion.helpers import insert_on_conflict
|
|
10
15
|
from phoenix.server.api.routers.utils import df_to_bytes, from_iso_format
|
|
16
|
+
from phoenix.server.api.types.node import from_global_id_with_expected_type
|
|
11
17
|
from phoenix.trace.dsl import SpanQuery
|
|
12
18
|
|
|
13
19
|
DEFAULT_SPAN_LIMIT = 1000
|
|
@@ -128,3 +134,142 @@ async def query_spans_handler(request: Request) -> Response:
|
|
|
128
134
|
|
|
129
135
|
async def get_spans_handler(request: Request) -> Response:
    # The GET route shares its implementation with the POST span-query handler;
    # no docstring here because route docstrings are parsed as OpenAPI YAML.
    response = await query_spans_handler(request)
    return response
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
async def annotate_spans(request: Request) -> Response:
    """
    summary: Upsert annotations for spans
    operationId: annotateSpans
    tags:
      - private
    requestBody:
      description: List of span annotations to be inserted
      required: true
      content:
        application/json:
          schema:
            type: object
            properties:
              data:
                type: array
                items:
                  type: object
                  properties:
                    span_id:
                      type: string
                      description: The ID of the span being annotated
                    name:
                      type: string
                      description: The name of the annotation
                    annotator_kind:
                      type: string
                      description: The kind of annotator used for the annotation ("LLM" or "HUMAN")
                    result:
                      type: object
                      description: The result of the annotation
                      properties:
                        label:
                          type: string
                          description: The label assigned by the annotation
                        score:
                          type: number
                          format: float
                          description: The score assigned by the annotation
                        explanation:
                          type: string
                          description: Explanation of the annotation result
                    error:
                      type: string
                      description: Optional error message if the annotation encountered an error
                    metadata:
                      type: object
                      description: Metadata for the annotation
                      additionalProperties:
                        type: string
                  required:
                    - span_id
                    - name
                    - annotator_kind
    responses:
      200:
        description: Span annotations inserted successfully
        content:
          application/json:
            schema:
              type: object
              properties:
                data:
                  type: array
                  items:
                    type: object
                    properties:
                      id:
                        type: string
                        description: The ID of the inserted span annotation
      404:
        description: Span not found
      422:
        description: Malformed request payload
    """
    # NOTE: the docstring above is parsed as the OpenAPI spec for this route;
    # keep it valid YAML and keep explanatory comments out of it.
    body = await request.json()
    payload: List[Dict[str, Any]] = body.get("data", []) if isinstance(body, dict) else None
    if not isinstance(payload, list):
        return Response(
            content="Request body must be a JSON object with a `data` list",
            status_code=HTTP_422_UNPROCESSABLE_ENTITY,
        )

    # Validate every entry and resolve its span global ID before touching the
    # database, so a single bad entry rejects the whole request with no writes.
    resolved_span_ids: List[int] = []
    for annotation in payload:
        if (
            not isinstance(annotation, dict)
            or any(key not in annotation for key in ("span_id", "name", "annotator_kind"))
            or not isinstance(annotation["span_id"], str)
        ):
            return Response(
                content="Each span annotation requires span_id, name, and annotator_kind",
                status_code=HTTP_422_UNPROCESSABLE_ENTITY,
            )
        raw_span_id = annotation["span_id"]
        try:
            # Both a malformed global ID and a global ID of the wrong node type
            # raise ValueError; either way the referenced span cannot exist.
            span_gid = GlobalID.from_id(raw_span_id)
            resolved_span_ids.append(from_global_id_with_expected_type(span_gid, "Span"))
        except ValueError:
            return Response(
                content=f"Span with ID {raw_span_id} does not exist",
                status_code=HTTP_404_NOT_FOUND,
            )

    async with request.app.state.db() as session:
        # Only the ids are needed to check existence.
        existing_span_ids = set(
            await session.scalars(
                select(models.Span.id).where(models.Span.id.in_(resolved_span_ids))
            )
        )
        missing_span_ids = set(resolved_span_ids) - existing_span_ids
        if missing_span_ids:
            # Sorted so the error message is deterministic.
            missing_span_gids = [
                str(GlobalID("Span", str(span_rowid))) for span_rowid in sorted(missing_span_ids)
            ]
            return Response(
                content=f"Spans with IDs {', '.join(missing_span_gids)} do not exist.",
                status_code=HTTP_404_NOT_FOUND,
            )

        dialect = SupportedSQLDialect(session.bind.dialect.name)
        inserted_annotations = []
        # resolved_span_ids is index-aligned with payload (built in the loop above).
        for annotation, span_rowid in zip(payload, resolved_span_ids):
            result = annotation.get("result")
            values = dict(
                span_rowid=span_rowid,
                name=annotation["name"],
                label=result.get("label") if result else None,
                score=result.get("score") if result else None,
                explanation=result.get("explanation") if result else None,
                annotator_kind=annotation["annotator_kind"],
                metadata_=annotation.get("metadata") or {},
            )
            # Upsert keyed on (name, span_rowid): re-annotating a span with the
            # same annotation name goes through insert_on_conflict.
            span_annotation_id = await session.scalar(
                insert_on_conflict(
                    values,
                    dialect=dialect,
                    table=models.SpanAnnotation,
                    unique_by=("name", "span_rowid"),
                ).returning(models.SpanAnnotation.id)
            )
            inserted_annotations.append(
                {"id": str(GlobalID("SpanAnnotation", str(span_annotation_id)))}
            )

    return JSONResponse(content={"data": inserted_annotations})
|
|
@@ -1,20 +1,28 @@
|
|
|
1
1
|
import gzip
|
|
2
2
|
import zlib
|
|
3
|
+
from typing import Any, Dict, List
|
|
3
4
|
|
|
4
5
|
from google.protobuf.message import DecodeError
|
|
5
6
|
from opentelemetry.proto.collector.trace.v1.trace_service_pb2 import (
|
|
6
7
|
ExportTraceServiceRequest,
|
|
7
8
|
)
|
|
9
|
+
from sqlalchemy import select
|
|
8
10
|
from starlette.background import BackgroundTask
|
|
9
11
|
from starlette.concurrency import run_in_threadpool
|
|
10
12
|
from starlette.datastructures import State
|
|
11
13
|
from starlette.requests import Request
|
|
12
|
-
from starlette.responses import Response
|
|
14
|
+
from starlette.responses import JSONResponse, Response
|
|
13
15
|
from starlette.status import (
|
|
16
|
+
HTTP_404_NOT_FOUND,
|
|
14
17
|
HTTP_415_UNSUPPORTED_MEDIA_TYPE,
|
|
15
18
|
HTTP_422_UNPROCESSABLE_ENTITY,
|
|
16
19
|
)
|
|
20
|
+
from strawberry.relay import GlobalID
|
|
17
21
|
|
|
22
|
+
from phoenix.db import models
|
|
23
|
+
from phoenix.db.helpers import SupportedSQLDialect
|
|
24
|
+
from phoenix.db.insertion.helpers import insert_on_conflict
|
|
25
|
+
from phoenix.server.api.types.node import from_global_id_with_expected_type
|
|
18
26
|
from phoenix.trace.otel import decode_otlp_span
|
|
19
27
|
from phoenix.utilities.project import get_project_name
|
|
20
28
|
|
|
@@ -70,6 +78,147 @@ async def post_traces(request: Request) -> Response:
|
|
|
70
78
|
return Response(background=BackgroundTask(_add_spans, req, request.state))
|
|
71
79
|
|
|
72
80
|
|
|
81
|
+
async def annotate_traces(request: Request) -> Response:
    """
    summary: Upsert annotations for traces
    operationId: annotateTraces
    tags:
      - private
    requestBody:
      description: List of trace annotations to be inserted
      required: true
      content:
        application/json:
          schema:
            type: object
            properties:
              data:
                type: array
                items:
                  type: object
                  properties:
                    trace_id:
                      type: string
                      description: The ID of the trace being annotated
                    name:
                      type: string
                      description: The name of the annotation
                    annotator_kind:
                      type: string
                      description: The kind of annotator used for the annotation ("LLM" or "HUMAN")
                    result:
                      type: object
                      description: The result of the annotation
                      properties:
                        label:
                          type: string
                          description: The label assigned by the annotation
                        score:
                          type: number
                          format: float
                          description: The score assigned by the annotation
                        explanation:
                          type: string
                          description: Explanation of the annotation result
                    error:
                      type: string
                      description: Optional error message if the annotation encountered an error
                    metadata:
                      type: object
                      description: Metadata for the annotation
                      additionalProperties:
                        type: string
                  required:
                    - trace_id
                    - name
                    - annotator_kind
    responses:
      200:
        description: Trace annotations inserted successfully
        content:
          application/json:
            schema:
              type: object
              properties:
                data:
                  type: array
                  items:
                    type: object
                    properties:
                      id:
                        type: string
                        description: The ID of the inserted trace annotation
      404:
        description: Trace not found
      422:
        description: Malformed request payload
    """
    # NOTE: the docstring above is parsed as the OpenAPI spec for this route;
    # keep it valid YAML and keep explanatory comments out of it.
    body = await request.json()
    payload: List[Dict[str, Any]] = body.get("data", []) if isinstance(body, dict) else None
    if not isinstance(payload, list):
        return Response(
            content="Request body must be a JSON object with a `data` list",
            status_code=HTTP_422_UNPROCESSABLE_ENTITY,
        )

    # Validate every entry and resolve its trace global ID before touching the
    # database, so a single bad entry rejects the whole request with no writes.
    resolved_trace_ids: List[int] = []
    for annotation in payload:
        if (
            not isinstance(annotation, dict)
            or any(key not in annotation for key in ("trace_id", "name", "annotator_kind"))
            or not isinstance(annotation["trace_id"], str)
        ):
            return Response(
                content="Each trace annotation requires trace_id, name, and annotator_kind",
                status_code=HTTP_422_UNPROCESSABLE_ENTITY,
            )
        raw_trace_id = annotation["trace_id"]
        try:
            # Both a malformed global ID and a global ID of the wrong node type
            # raise ValueError; either way the referenced trace cannot exist.
            trace_gid = GlobalID.from_id(raw_trace_id)
            resolved_trace_ids.append(from_global_id_with_expected_type(trace_gid, "Trace"))
        except ValueError:
            return Response(
                content=f"Trace with ID {raw_trace_id} does not exist",
                status_code=HTTP_404_NOT_FOUND,
            )

    async with request.app.state.db() as session:
        # Only the ids are needed to check existence.
        existing_trace_ids = set(
            await session.scalars(
                select(models.Trace.id).where(models.Trace.id.in_(resolved_trace_ids))
            )
        )
        missing_trace_ids = set(resolved_trace_ids) - existing_trace_ids
        if missing_trace_ids:
            # Sorted so the error message is deterministic.
            missing_trace_gids = [
                str(GlobalID("Trace", str(trace_rowid)))
                for trace_rowid in sorted(missing_trace_ids)
            ]
            return Response(
                content=f"Traces with IDs {', '.join(missing_trace_gids)} do not exist.",
                status_code=HTTP_404_NOT_FOUND,
            )

        dialect = SupportedSQLDialect(session.bind.dialect.name)
        inserted_annotations = []
        # resolved_trace_ids is index-aligned with payload (built in the loop above).
        for annotation, trace_rowid in zip(payload, resolved_trace_ids):
            result = annotation.get("result")
            values = dict(
                trace_rowid=trace_rowid,
                name=annotation["name"],
                label=result.get("label") if result else None,
                score=result.get("score") if result else None,
                explanation=result.get("explanation") if result else None,
                annotator_kind=annotation["annotator_kind"],
                metadata_=annotation.get("metadata") or {},
            )
            # Upsert keyed on (name, trace_rowid): re-annotating a trace with the
            # same annotation name goes through insert_on_conflict.
            trace_annotation_id = await session.scalar(
                insert_on_conflict(
                    values,
                    dialect=dialect,
                    table=models.TraceAnnotation,
                    unique_by=("name", "trace_rowid"),
                ).returning(models.TraceAnnotation.id)
            )
            inserted_annotations.append(
                {"id": str(GlobalID("TraceAnnotation", str(trace_annotation_id)))}
            )

    return JSONResponse(content={"data": inserted_annotations})
|
|
220
|
+
|
|
221
|
+
|
|
73
222
|
async def _add_spans(req: ExportTraceServiceRequest, state: State) -> None:
|
|
74
223
|
for resource_spans in req.resource_spans:
|
|
75
224
|
project_name = get_project_name(resource_spans.resource.attributes)
|
|
@@ -7,7 +7,7 @@ from strawberry.relay import Node, NodeID
|
|
|
7
7
|
from strawberry.scalars import JSON
|
|
8
8
|
|
|
9
9
|
from phoenix.db import models
|
|
10
|
-
from phoenix.server.api.types.AnnotatorKind import
|
|
10
|
+
from phoenix.server.api.types.AnnotatorKind import ExperimentRunAnnotatorKind
|
|
11
11
|
from phoenix.server.api.types.Trace import Trace
|
|
12
12
|
|
|
13
13
|
|
|
@@ -15,7 +15,7 @@ from phoenix.server.api.types.Trace import Trace
|
|
|
15
15
|
class ExperimentRunAnnotation(Node):
|
|
16
16
|
id_attr: NodeID[int]
|
|
17
17
|
name: str
|
|
18
|
-
annotator_kind:
|
|
18
|
+
annotator_kind: ExperimentRunAnnotatorKind
|
|
19
19
|
label: Optional[str]
|
|
20
20
|
score: Optional[float]
|
|
21
21
|
explanation: Optional[str]
|
|
@@ -45,7 +45,7 @@ def to_gql_experiment_run_annotation(
|
|
|
45
45
|
return ExperimentRunAnnotation(
|
|
46
46
|
id_attr=annotation.id,
|
|
47
47
|
name=annotation.name,
|
|
48
|
-
annotator_kind=
|
|
48
|
+
annotator_kind=ExperimentRunAnnotatorKind(annotation.annotator_kind),
|
|
49
49
|
label=annotation.label,
|
|
50
50
|
score=annotation.score,
|
|
51
51
|
explanation=annotation.explanation,
|