arize-phoenix 4.8.1__py3-none-any.whl → 4.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic; consult the registry's advisory for details.

Files changed (32):
  1. {arize_phoenix-4.8.1.dist-info → arize_phoenix-4.10.0.dist-info}/METADATA +1 -1
  2. {arize_phoenix-4.8.1.dist-info → arize_phoenix-4.10.0.dist-info}/RECORD +32 -24
  3. phoenix/db/insertion/evaluation.py +7 -25
  4. phoenix/db/insertion/helpers.py +54 -13
  5. phoenix/db/insertion/span.py +12 -16
  6. phoenix/inferences/inferences.py +16 -4
  7. phoenix/pointcloud/umap_parameters.py +52 -52
  8. phoenix/server/api/input_types/CreateSpanAnnotationsInput.py +16 -0
  9. phoenix/server/api/input_types/CreateTraceAnnotationsInput.py +16 -0
  10. phoenix/server/api/input_types/DeleteAnnotationsInput.py +9 -0
  11. phoenix/server/api/input_types/PatchAnnotationsInput.py +17 -0
  12. phoenix/server/api/mutations/__init__.py +8 -1
  13. phoenix/server/api/mutations/project_mutations.py +1 -1
  14. phoenix/server/api/mutations/span_annotations_mutations.py +108 -0
  15. phoenix/server/api/mutations/trace_annotations_mutations.py +108 -0
  16. phoenix/server/api/queries.py +1 -0
  17. phoenix/server/api/routers/v1/__init__.py +2 -0
  18. phoenix/server/api/routers/v1/experiment_evaluations.py +3 -10
  19. phoenix/server/api/routers/v1/experiments.py +4 -5
  20. phoenix/server/api/routers/v1/spans.py +147 -2
  21. phoenix/server/api/routers/v1/traces.py +150 -1
  22. phoenix/server/api/types/AnnotatorKind.py +7 -1
  23. phoenix/server/api/types/ExperimentRunAnnotation.py +3 -3
  24. phoenix/server/api/types/SpanAnnotation.py +45 -0
  25. phoenix/server/api/types/TraceAnnotation.py +45 -0
  26. phoenix/server/static/index.js +532 -524
  27. phoenix/session/client.py +10 -8
  28. phoenix/trace/trace_dataset.py +23 -15
  29. phoenix/version.py +1 -1
  30. {arize_phoenix-4.8.1.dist-info → arize_phoenix-4.10.0.dist-info}/WHEEL +0 -0
  31. {arize_phoenix-4.8.1.dist-info → arize_phoenix-4.10.0.dist-info}/licenses/IP_NOTICE +0 -0
  32. {arize_phoenix-4.8.1.dist-info → arize_phoenix-4.10.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,108 @@
1
+ from typing import List, Sequence
2
+
3
+ import strawberry
4
+ from sqlalchemy import delete, insert, update
5
+ from strawberry import UNSET
6
+ from strawberry.types import Info
7
+
8
+ from phoenix.db import models
9
+ from phoenix.server.api.context import Context
10
+ from phoenix.server.api.input_types.CreateSpanAnnotationsInput import CreateSpanAnnotationsInput
11
+ from phoenix.server.api.input_types.DeleteAnnotationsInput import DeleteAnnotationsInput
12
+ from phoenix.server.api.input_types.PatchAnnotationsInput import PatchAnnotationsInput
13
+ from phoenix.server.api.mutations.auth import IsAuthenticated
14
+ from phoenix.server.api.types.node import from_global_id_with_expected_type
15
+ from phoenix.server.api.types.SpanAnnotation import SpanAnnotation, to_gql_span_annotation
16
+
17
+
18
@strawberry.type
class SpanAnnotationMutationPayload:
    """Payload returned by the span-annotation mutations in this module."""

    # The annotations affected by the mutation: inserted rows for create,
    # updated rows for patch, removed rows for delete.
    span_annotations: List[SpanAnnotation]
21
+
22
+
23
@strawberry.type
class SpanAnnotationMutationMixin:
    """GraphQL mutations that create, patch, and delete span annotations.

    Every mutation requires an authenticated caller (``IsAuthenticated``) and
    returns a ``SpanAnnotationMutationPayload`` listing the affected annotations.
    """

    @strawberry.mutation(permission_classes=[IsAuthenticated])  # type: ignore
    async def create_span_annotations(
        self, info: Info[Context, None], input: List[CreateSpanAnnotationsInput]
    ) -> SpanAnnotationMutationPayload:
        """Insert one span annotation per input item and return the inserted rows.

        Each ``span_id`` is decoded with ``from_global_id_with_expected_type``
        before any SQL is issued, so an invalid ID aborts the mutation without
        opening a database session.
        """
        if not input:
            # ``insert(...).values([])`` is not a no-op (it would attempt an insert
            # with no explicit values), so short-circuit on an empty request.
            return SpanAnnotationMutationPayload(span_annotations=[])

        values_list = [
            dict(
                span_rowid=from_global_id_with_expected_type(annotation.span_id, "Span"),
                name=annotation.name,
                label=annotation.label,
                score=annotation.score,
                explanation=annotation.explanation,
                annotator_kind=annotation.annotator_kind,
                metadata_=annotation.metadata,
            )
            for annotation in input
        ]
        inserted_annotations: Sequence[models.SpanAnnotation] = []
        async with info.context.db() as session:
            stmt = (
                insert(models.SpanAnnotation).values(values_list).returning(models.SpanAnnotation)
            )
            result = await session.scalars(stmt)
            inserted_annotations = result.all()

        return SpanAnnotationMutationPayload(
            span_annotations=[
                to_gql_span_annotation(annotation) for annotation in inserted_annotations
            ]
        )

    @strawberry.mutation(permission_classes=[IsAuthenticated])  # type: ignore
    async def patch_span_annotations(
        self, info: Info[Context, None], input: List[PatchAnnotationsInput]
    ) -> SpanAnnotationMutationPayload:
        """Apply partial updates to existing span annotations.

        Fields left UNSET are not touched. ``None`` is written only to nullable
        columns (label, score, explanation); ``None`` for a non-nullable column
        is ignored. Annotations whose ID matches no row are silently omitted from
        the payload.
        """
        patched_annotations: List[SpanAnnotation] = []
        async with info.context.db() as session:
            for annotation in input:
                span_annotation_id = from_global_id_with_expected_type(
                    annotation.annotation_id, "SpanAnnotation"
                )
                patch = {
                    column.key: patch_value
                    for column, patch_value, column_is_nullable in (
                        (models.SpanAnnotation.name, annotation.name, False),
                        (models.SpanAnnotation.annotator_kind, annotation.annotator_kind, False),
                        (models.SpanAnnotation.label, annotation.label, True),
                        (models.SpanAnnotation.score, annotation.score, True),
                        (models.SpanAnnotation.explanation, annotation.explanation, True),
                        (models.SpanAnnotation.metadata_, annotation.metadata, False),
                    )
                    if patch_value is not UNSET and (patch_value is not None or column_is_nullable)
                }
                if patch:
                    span_annotation = await session.scalar(
                        update(models.SpanAnnotation)
                        .where(models.SpanAnnotation.id == span_annotation_id)
                        .values(**patch)
                        .returning(models.SpanAnnotation)
                    )
                else:
                    # Nothing to change: an UPDATE with an empty SET clause is invalid,
                    # so just look up the current row and report it unchanged.
                    span_annotation = await session.get(
                        models.SpanAnnotation, span_annotation_id
                    )
                if span_annotation is not None:
                    patched_annotations.append(to_gql_span_annotation(span_annotation))

        return SpanAnnotationMutationPayload(span_annotations=patched_annotations)

    @strawberry.mutation(permission_classes=[IsAuthenticated])  # type: ignore
    async def delete_span_annotations(
        self, info: Info[Context, None], input: DeleteAnnotationsInput
    ) -> SpanAnnotationMutationPayload:
        """Delete the span annotations with the given global IDs and return them.

        IDs that match no row are simply absent from the returned payload.
        """
        span_annotation_ids = [
            from_global_id_with_expected_type(global_id, "SpanAnnotation")
            for global_id in input.annotation_ids
        ]
        async with info.context.db() as session:
            stmt = (
                delete(models.SpanAnnotation)
                .where(models.SpanAnnotation.id.in_(span_annotation_ids))
                .returning(models.SpanAnnotation)
            )
            result = await session.scalars(stmt)
            deleted_annotations = result.all()

        deleted_annotations_gql = [
            to_gql_span_annotation(annotation) for annotation in deleted_annotations
        ]
        return SpanAnnotationMutationPayload(span_annotations=deleted_annotations_gql)
@@ -0,0 +1,108 @@
1
+ from typing import List, Sequence
2
+
3
+ import strawberry
4
+ from sqlalchemy import delete, insert, update
5
+ from strawberry import UNSET
6
+ from strawberry.types import Info
7
+
8
+ from phoenix.db import models
9
+ from phoenix.server.api.context import Context
10
+ from phoenix.server.api.input_types.CreateTraceAnnotationsInput import CreateTraceAnnotationsInput
11
+ from phoenix.server.api.input_types.DeleteAnnotationsInput import DeleteAnnotationsInput
12
+ from phoenix.server.api.input_types.PatchAnnotationsInput import PatchAnnotationsInput
13
+ from phoenix.server.api.mutations.auth import IsAuthenticated
14
+ from phoenix.server.api.types.node import from_global_id_with_expected_type
15
+ from phoenix.server.api.types.TraceAnnotation import TraceAnnotation, to_gql_trace_annotation
16
+
17
+
18
@strawberry.type
class TraceAnnotationMutationPayload:
    """Payload returned by the trace-annotation mutations in this module."""

    # The annotations affected by the mutation: inserted rows for create,
    # updated rows for patch, removed rows for delete.
    trace_annotations: List[TraceAnnotation]
21
+
22
+
23
@strawberry.type
class TraceAnnotationMutationMixin:
    """GraphQL mutations that create, patch, and delete trace annotations.

    Every mutation requires an authenticated caller (``IsAuthenticated``) and
    returns a ``TraceAnnotationMutationPayload`` listing the affected annotations.
    """

    @strawberry.mutation(permission_classes=[IsAuthenticated])  # type: ignore
    async def create_trace_annotations(
        self, info: Info[Context, None], input: List[CreateTraceAnnotationsInput]
    ) -> TraceAnnotationMutationPayload:
        """Insert one trace annotation per input item and return the inserted rows.

        Each ``trace_id`` is decoded with ``from_global_id_with_expected_type``
        before any SQL is issued, so an invalid ID aborts the mutation without
        opening a database session.
        """
        if not input:
            # ``insert(...).values([])`` is not a no-op (it would attempt an insert
            # with no explicit values), so short-circuit on an empty request.
            return TraceAnnotationMutationPayload(trace_annotations=[])

        values_list = [
            dict(
                trace_rowid=from_global_id_with_expected_type(annotation.trace_id, "Trace"),
                name=annotation.name,
                label=annotation.label,
                score=annotation.score,
                explanation=annotation.explanation,
                annotator_kind=annotation.annotator_kind,
                metadata_=annotation.metadata,
            )
            for annotation in input
        ]
        inserted_annotations: Sequence[models.TraceAnnotation] = []
        async with info.context.db() as session:
            stmt = (
                insert(models.TraceAnnotation).values(values_list).returning(models.TraceAnnotation)
            )
            result = await session.scalars(stmt)
            inserted_annotations = result.all()

        return TraceAnnotationMutationPayload(
            trace_annotations=[
                to_gql_trace_annotation(annotation) for annotation in inserted_annotations
            ]
        )

    @strawberry.mutation(permission_classes=[IsAuthenticated])  # type: ignore
    async def patch_trace_annotations(
        self, info: Info[Context, None], input: List[PatchAnnotationsInput]
    ) -> TraceAnnotationMutationPayload:
        """Apply partial updates to existing trace annotations.

        Fields left UNSET are not touched. ``None`` is written only to nullable
        columns (label, score, explanation); ``None`` for a non-nullable column
        is ignored. Annotations whose ID matches no row are silently omitted from
        the payload.
        """
        patched_annotations: List[TraceAnnotation] = []
        async with info.context.db() as session:
            for annotation in input:
                trace_annotation_id = from_global_id_with_expected_type(
                    annotation.annotation_id, "TraceAnnotation"
                )
                patch = {
                    column.key: patch_value
                    for column, patch_value, column_is_nullable in (
                        (models.TraceAnnotation.name, annotation.name, False),
                        (models.TraceAnnotation.annotator_kind, annotation.annotator_kind, False),
                        (models.TraceAnnotation.label, annotation.label, True),
                        (models.TraceAnnotation.score, annotation.score, True),
                        (models.TraceAnnotation.explanation, annotation.explanation, True),
                        (models.TraceAnnotation.metadata_, annotation.metadata, False),
                    )
                    if patch_value is not UNSET and (patch_value is not None or column_is_nullable)
                }
                if patch:
                    trace_annotation = await session.scalar(
                        update(models.TraceAnnotation)
                        .where(models.TraceAnnotation.id == trace_annotation_id)
                        .values(**patch)
                        .returning(models.TraceAnnotation)
                    )
                else:
                    # Nothing to change: an UPDATE with an empty SET clause is invalid,
                    # so just look up the current row and report it unchanged.
                    trace_annotation = await session.get(
                        models.TraceAnnotation, trace_annotation_id
                    )
                if trace_annotation is not None:
                    patched_annotations.append(to_gql_trace_annotation(trace_annotation))

        return TraceAnnotationMutationPayload(trace_annotations=patched_annotations)

    @strawberry.mutation(permission_classes=[IsAuthenticated])  # type: ignore
    async def delete_trace_annotations(
        self, info: Info[Context, None], input: DeleteAnnotationsInput
    ) -> TraceAnnotationMutationPayload:
        """Delete the trace annotations with the given global IDs and return them.

        IDs that match no row are simply absent from the returned payload.
        """
        trace_annotation_ids = [
            from_global_id_with_expected_type(global_id, "TraceAnnotation")
            for global_id in input.annotation_ids
        ]
        async with info.context.db() as session:
            stmt = (
                delete(models.TraceAnnotation)
                .where(models.TraceAnnotation.id.in_(trace_annotation_ids))
                .returning(models.TraceAnnotation)
            )
            result = await session.scalars(stmt)
            deleted_annotations = result.all()

        deleted_annotations_gql = [
            to_gql_trace_annotation(annotation) for annotation in deleted_annotations
        ]
        return TraceAnnotationMutationPayload(trace_annotations=deleted_annotations_gql)
@@ -92,6 +92,7 @@ class Query:
92
92
  models.Project.name == models.Experiment.project_name,
93
93
  )
94
94
  .where(models.Experiment.project_name.is_(None))
95
+ .order_by(models.Project.id)
95
96
  )
96
97
  async with info.context.db() as session:
97
98
  projects = await session.stream_scalars(stmt)
@@ -40,8 +40,10 @@ V1_ROUTES = [
40
40
  Route("/v1/evaluations", evaluations.post_evaluations, methods=["POST"]),
41
41
  Route("/v1/evaluations", evaluations.get_evaluations, methods=["GET"]),
42
42
  Route("/v1/traces", traces.post_traces, methods=["POST"]),
43
+ Route("/v1/trace_annotations", traces.annotate_traces, methods=["POST"]),
43
44
  Route("/v1/spans", spans.query_spans_handler, methods=["POST"]),
44
45
  Route("/v1/spans", spans.get_spans_handler, methods=["GET"]),
46
+ Route("/v1/span_annotations", spans.annotate_spans, methods=["POST"]),
45
47
  Route("/v1/datasets/upload", datasets.post_datasets_upload, methods=["POST"]),
46
48
  Route("/v1/datasets", datasets.list_datasets, methods=["GET"]),
47
49
  Route("/v1/datasets/{id:str}", datasets.delete_dataset_by_id, methods=["DELETE"]),
@@ -7,7 +7,7 @@ from strawberry.relay import GlobalID
7
7
 
8
8
  from phoenix.db import models
9
9
  from phoenix.db.helpers import SupportedSQLDialect
10
- from phoenix.db.insertion.helpers import OnConflict, insert_on_conflict
10
+ from phoenix.db.insertion.helpers import insert_on_conflict
11
11
  from phoenix.server.api.types.node import from_global_id_with_expected_type
12
12
 
13
13
 
@@ -123,20 +123,13 @@ async def upsert_experiment_evaluation(request: Request) -> Response:
123
123
  end_time=datetime.fromisoformat(end_time),
124
124
  trace_id=payload.get("trace_id"),
125
125
  )
126
- set_ = {
127
- **{k: v for k, v in values.items() if k != "metadata_"},
128
- "metadata": values["metadata_"], # `metadata` must match database
129
- }
130
126
  dialect = SupportedSQLDialect(session.bind.dialect.name)
131
127
  exp_eval_run = await session.scalar(
132
128
  insert_on_conflict(
129
+ values,
133
130
  dialect=dialect,
134
131
  table=models.ExperimentRunAnnotation,
135
- values=values,
136
- constraint="uq_experiment_run_annotations_experiment_run_id_name",
137
- column_names=("experiment_run_id", "name"),
138
- on_conflict=OnConflict.DO_UPDATE,
139
- set_=set_,
132
+ unique_by=("experiment_run_id", "name"),
140
133
  ).returning(models.ExperimentRunAnnotation)
141
134
  )
142
135
  evaluation_gid = GlobalID("ExperimentEvaluation", str(exp_eval_run.id))
@@ -180,16 +180,15 @@ async def create_experiment(request: Request) -> Response:
180
180
  dialect = SupportedSQLDialect(session.bind.dialect.name)
181
181
  project_rowid = await session.scalar(
182
182
  insert_on_conflict(
183
- dialect=dialect,
184
- table=models.Project,
185
- constraint="uq_projects_name",
186
- column_names=("name",),
187
- values=dict(
183
+ dict(
188
184
  name=project_name,
189
185
  description=project_description,
190
186
  created_at=experiment.created_at,
191
187
  updated_at=experiment.updated_at,
192
188
  ),
189
+ dialect=dialect,
190
+ table=models.Project,
191
+ unique_by=("name",),
193
192
  ).returning(models.Project.id)
194
193
  )
195
194
  assert project_rowid is not None
@@ -1,13 +1,19 @@
1
1
  from datetime import timezone
2
- from typing import AsyncIterator
2
+ from typing import Any, AsyncIterator, Dict, List
3
3
 
4
+ from sqlalchemy import select
4
5
  from starlette.requests import Request
5
- from starlette.responses import Response, StreamingResponse
6
+ from starlette.responses import JSONResponse, Response, StreamingResponse
6
7
  from starlette.status import HTTP_404_NOT_FOUND, HTTP_422_UNPROCESSABLE_ENTITY
8
+ from strawberry.relay import GlobalID
7
9
 
8
10
  from phoenix.config import DEFAULT_PROJECT_NAME
9
11
  from phoenix.datetime_utils import normalize_datetime
12
+ from phoenix.db import models
13
+ from phoenix.db.helpers import SupportedSQLDialect
14
+ from phoenix.db.insertion.helpers import insert_on_conflict
10
15
  from phoenix.server.api.routers.utils import df_to_bytes, from_iso_format
16
+ from phoenix.server.api.types.node import from_global_id_with_expected_type
11
17
  from phoenix.trace.dsl import SpanQuery
12
18
 
13
19
  DEFAULT_SPAN_LIMIT = 1000
@@ -128,3 +134,142 @@ async def query_spans_handler(request: Request) -> Response:
128
134
 
129
135
async def get_spans_handler(request: Request) -> Response:
    # GET variant of the span query endpoint: it reuses the POST handler as-is.
    # Intentionally no docstring -- other route handlers in this module keep
    # OpenAPI YAML in their docstrings, and prose here could be picked up as such.
    response = await query_spans_handler(request)
    return response
137
+
138
+
139
async def annotate_spans(request: Request) -> Response:
    """
    summary: Upsert annotations for spans
    operationId: annotateSpans
    tags:
      - private
    requestBody:
      description: List of span annotations to be inserted
      required: true
      content:
        application/json:
          schema:
            type: object
            properties:
              data:
                type: array
                items:
                  type: object
                  properties:
                    span_id:
                      type: string
                      description: The ID of the span being annotated
                    name:
                      type: string
                      description: The name of the annotation
                    annotator_kind:
                      type: string
                      description: The kind of annotator used for the annotation ("LLM" or "HUMAN")
                    result:
                      type: object
                      description: The result of the annotation
                      properties:
                        label:
                          type: string
                          description: The label assigned by the annotation
                        score:
                          type: number
                          format: float
                          description: The score assigned by the annotation
                        explanation:
                          type: string
                          description: Explanation of the annotation result
                    error:
                      type: string
                      description: Optional error message if the annotation encountered an error
                    metadata:
                      type: object
                      description: Metadata for the annotation
                      additionalProperties:
                        type: string
                  required:
                    - span_id
                    - name
                    - annotator_kind
    responses:
      200:
        description: Span annotations inserted successfully
        content:
          application/json:
            schema:
              type: object
              properties:
                data:
                  type: array
                  items:
                    type: object
                    properties:
                      id:
                        type: string
                        description: The ID of the inserted span annotation
      404:
        description: Span not found
      422:
        description: A span annotation is missing a required field
    """
    payload: List[Dict[str, Any]] = (await request.json()).get("data", [])

    # Validate and resolve every item before touching the database, so a bad item
    # rejects the whole request instead of failing partway through the upserts.
    resolved_span_ids = []
    for annotation in payload:
        missing_fields = [
            field for field in ("span_id", "name", "annotator_kind") if field not in annotation
        ]
        if missing_fields:
            return Response(
                content=f"Span annotation is missing required fields: {', '.join(missing_fields)}",
                status_code=HTTP_422_UNPROCESSABLE_ENTITY,
            )
        raw_span_id = annotation["span_id"]
        try:
            # Decoding happens inside the try so that malformed IDs, and IDs of a
            # node type other than Span, are both reported as 404.
            span_gid = GlobalID.from_id(raw_span_id)
            resolved_span_ids.append(from_global_id_with_expected_type(span_gid, "Span"))
        except (TypeError, ValueError):
            return Response(
                content=f"Span with ID {raw_span_id} does not exist",
                status_code=HTTP_404_NOT_FOUND,
            )

    async with request.app.state.db() as session:
        # Only the primary keys are needed to detect unknown spans.
        existing_span_ids = set(
            await session.scalars(
                select(models.Span.id).where(models.Span.id.in_(resolved_span_ids))
            )
        )
        missing_span_ids = set(resolved_span_ids) - existing_span_ids
        if missing_span_ids:
            missing_span_gids = [
                str(GlobalID("Span", str(span_id))) for span_id in missing_span_ids
            ]
            return Response(
                content=f"Spans with IDs {', '.join(missing_span_gids)} do not exist.",
                status_code=HTTP_404_NOT_FOUND,
            )

        dialect = SupportedSQLDialect(session.bind.dialect.name)
        inserted_annotations = []
        # resolved_span_ids is index-aligned with payload (one entry per item).
        for span_id, annotation in zip(resolved_span_ids, payload):
            result = annotation.get("result") or {}
            values = dict(
                span_rowid=span_id,
                name=annotation["name"],
                label=result.get("label"),
                score=result.get("score"),
                explanation=result.get("explanation"),
                annotator_kind=annotation["annotator_kind"],
                metadata_=annotation.get("metadata") or {},
            )
            # Upsert keyed on (name, span_rowid): re-annotating a span under the
            # same name goes through insert_on_conflict instead of a plain insert.
            span_annotation_id = await session.scalar(
                insert_on_conflict(
                    values,
                    dialect=dialect,
                    table=models.SpanAnnotation,
                    unique_by=("name", "span_rowid"),
                ).returning(models.SpanAnnotation.id)
            )
            inserted_annotations.append(
                {"id": str(GlobalID("SpanAnnotation", str(span_annotation_id)))}
            )

    return JSONResponse(content={"data": inserted_annotations})
@@ -1,20 +1,28 @@
1
1
  import gzip
2
2
  import zlib
3
+ from typing import Any, Dict, List
3
4
 
4
5
  from google.protobuf.message import DecodeError
5
6
  from opentelemetry.proto.collector.trace.v1.trace_service_pb2 import (
6
7
  ExportTraceServiceRequest,
7
8
  )
9
+ from sqlalchemy import select
8
10
  from starlette.background import BackgroundTask
9
11
  from starlette.concurrency import run_in_threadpool
10
12
  from starlette.datastructures import State
11
13
  from starlette.requests import Request
12
- from starlette.responses import Response
14
+ from starlette.responses import JSONResponse, Response
13
15
  from starlette.status import (
16
+ HTTP_404_NOT_FOUND,
14
17
  HTTP_415_UNSUPPORTED_MEDIA_TYPE,
15
18
  HTTP_422_UNPROCESSABLE_ENTITY,
16
19
  )
20
+ from strawberry.relay import GlobalID
17
21
 
22
+ from phoenix.db import models
23
+ from phoenix.db.helpers import SupportedSQLDialect
24
+ from phoenix.db.insertion.helpers import insert_on_conflict
25
+ from phoenix.server.api.types.node import from_global_id_with_expected_type
18
26
  from phoenix.trace.otel import decode_otlp_span
19
27
  from phoenix.utilities.project import get_project_name
20
28
 
@@ -70,6 +78,147 @@ async def post_traces(request: Request) -> Response:
70
78
  return Response(background=BackgroundTask(_add_spans, req, request.state))
71
79
 
72
80
 
81
async def annotate_traces(request: Request) -> Response:
    """
    summary: Upsert annotations for traces
    operationId: annotateTraces
    tags:
      - private
    requestBody:
      description: List of trace annotations to be inserted
      required: true
      content:
        application/json:
          schema:
            type: object
            properties:
              data:
                type: array
                items:
                  type: object
                  properties:
                    trace_id:
                      type: string
                      description: The ID of the trace being annotated
                    name:
                      type: string
                      description: The name of the annotation
                    annotator_kind:
                      type: string
                      description: The kind of annotator used for the annotation ("LLM" or "HUMAN")
                    result:
                      type: object
                      description: The result of the annotation
                      properties:
                        label:
                          type: string
                          description: The label assigned by the annotation
                        score:
                          type: number
                          format: float
                          description: The score assigned by the annotation
                        explanation:
                          type: string
                          description: Explanation of the annotation result
                    error:
                      type: string
                      description: Optional error message if the annotation encountered an error
                    metadata:
                      type: object
                      description: Metadata for the annotation
                      additionalProperties:
                        type: string
                  required:
                    - trace_id
                    - name
                    - annotator_kind
    responses:
      200:
        description: Trace annotations inserted successfully
        content:
          application/json:
            schema:
              type: object
              properties:
                data:
                  type: array
                  items:
                    type: object
                    properties:
                      id:
                        type: string
                        description: The ID of the inserted trace annotation
      404:
        description: Trace not found
      422:
        description: A trace annotation is missing a required field
    """
    payload: List[Dict[str, Any]] = (await request.json()).get("data", [])

    # Validate and resolve every item before touching the database, so a bad item
    # rejects the whole request instead of failing partway through the upserts.
    resolved_trace_ids = []
    for annotation in payload:
        missing_fields = [
            field for field in ("trace_id", "name", "annotator_kind") if field not in annotation
        ]
        if missing_fields:
            return Response(
                content=f"Trace annotation is missing required fields: {', '.join(missing_fields)}",
                status_code=HTTP_422_UNPROCESSABLE_ENTITY,
            )
        raw_trace_id = annotation["trace_id"]
        try:
            # Decoding happens inside the try so that malformed IDs, and IDs of a
            # node type other than Trace, are both reported as 404.
            trace_gid = GlobalID.from_id(raw_trace_id)
            resolved_trace_ids.append(from_global_id_with_expected_type(trace_gid, "Trace"))
        except (TypeError, ValueError):
            return Response(
                content=f"Trace with ID {raw_trace_id} does not exist",
                status_code=HTTP_404_NOT_FOUND,
            )

    async with request.app.state.db() as session:
        # Only the primary keys are needed to detect unknown traces.
        existing_trace_ids = set(
            await session.scalars(
                select(models.Trace.id).where(models.Trace.id.in_(resolved_trace_ids))
            )
        )
        missing_trace_ids = set(resolved_trace_ids) - existing_trace_ids
        if missing_trace_ids:
            missing_trace_gids = [
                str(GlobalID("Trace", str(trace_id))) for trace_id in missing_trace_ids
            ]
            return Response(
                content=f"Traces with IDs {', '.join(missing_trace_gids)} do not exist.",
                status_code=HTTP_404_NOT_FOUND,
            )

        dialect = SupportedSQLDialect(session.bind.dialect.name)
        inserted_annotations = []
        # resolved_trace_ids is index-aligned with payload (one entry per item).
        for trace_id, annotation in zip(resolved_trace_ids, payload):
            result = annotation.get("result") or {}
            values = dict(
                trace_rowid=trace_id,
                name=annotation["name"],
                label=result.get("label"),
                score=result.get("score"),
                explanation=result.get("explanation"),
                annotator_kind=annotation["annotator_kind"],
                metadata_=annotation.get("metadata") or {},
            )
            # Upsert keyed on (name, trace_rowid): re-annotating a trace under the
            # same name goes through insert_on_conflict instead of a plain insert.
            trace_annotation_id = await session.scalar(
                insert_on_conflict(
                    values,
                    dialect=dialect,
                    table=models.TraceAnnotation,
                    unique_by=("name", "trace_rowid"),
                ).returning(models.TraceAnnotation.id)
            )
            inserted_annotations.append(
                {"id": str(GlobalID("TraceAnnotation", str(trace_annotation_id)))}
            )

    return JSONResponse(content={"data": inserted_annotations})
220
+
221
+
73
222
  async def _add_spans(req: ExportTraceServiceRequest, state: State) -> None:
74
223
  for resource_spans in req.resource_spans:
75
224
  project_name = get_project_name(resource_spans.resource.attributes)
@@ -4,7 +4,13 @@ import strawberry
4
4
 
5
5
 
6
6
@strawberry.enum
class ExperimentRunAnnotatorKind(Enum):
    """Kind of annotator that produced an experiment run annotation.

    Exposed in GraphQL as the ``annotator_kind`` of ``ExperimentRunAnnotation``.
    Unlike ``AnnotatorKind``, it also allows ``CODE``.
    """

    LLM = "LLM"
    HUMAN = "HUMAN"
    CODE = "CODE"
11
+
12
+
13
@strawberry.enum
class AnnotatorKind(Enum):
    """Kind of annotator for annotations that allow only LLM or human annotators.

    There is no ``CODE`` member here; experiment runs use ``ExperimentRunAnnotatorKind``.
    NOTE(review): presumably the kind used by span/trace annotations -- confirm
    against the SpanAnnotation / TraceAnnotation types.
    """

    LLM = "LLM"
    HUMAN = "HUMAN"
@@ -7,7 +7,7 @@ from strawberry.relay import Node, NodeID
7
7
  from strawberry.scalars import JSON
8
8
 
9
9
  from phoenix.db import models
10
- from phoenix.server.api.types.AnnotatorKind import AnnotatorKind
10
+ from phoenix.server.api.types.AnnotatorKind import ExperimentRunAnnotatorKind
11
11
  from phoenix.server.api.types.Trace import Trace
12
12
 
13
13
 
@@ -15,7 +15,7 @@ from phoenix.server.api.types.Trace import Trace
15
15
  class ExperimentRunAnnotation(Node):
16
16
  id_attr: NodeID[int]
17
17
  name: str
18
- annotator_kind: AnnotatorKind
18
+ annotator_kind: ExperimentRunAnnotatorKind
19
19
  label: Optional[str]
20
20
  score: Optional[float]
21
21
  explanation: Optional[str]
@@ -45,7 +45,7 @@ def to_gql_experiment_run_annotation(
45
45
  return ExperimentRunAnnotation(
46
46
  id_attr=annotation.id,
47
47
  name=annotation.name,
48
- annotator_kind=AnnotatorKind(annotation.annotator_kind),
48
+ annotator_kind=ExperimentRunAnnotatorKind(annotation.annotator_kind),
49
49
  label=annotation.label,
50
50
  score=annotation.score,
51
51
  explanation=annotation.explanation,