arize-phoenix 11.30.0__py3-none-any.whl → 11.32.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic. Click here for more details.

Files changed (27)
  1. {arize_phoenix-11.30.0.dist-info → arize_phoenix-11.32.0.dist-info}/METADATA +17 -17
  2. {arize_phoenix-11.30.0.dist-info → arize_phoenix-11.32.0.dist-info}/RECORD +26 -25
  3. phoenix/db/types/trace_retention.py +1 -1
  4. phoenix/experiments/functions.py +69 -19
  5. phoenix/server/api/dataloaders/document_evaluations.py +6 -9
  6. phoenix/server/api/routers/v1/annotations.py +128 -5
  7. phoenix/server/api/routers/v1/documents.py +47 -79
  8. phoenix/server/api/routers/v1/experiment_runs.py +71 -31
  9. phoenix/server/api/routers/v1/spans.py +2 -48
  10. phoenix/server/api/routers/v1/traces.py +19 -55
  11. phoenix/server/api/types/Dataset.py +8 -66
  12. phoenix/server/api/types/DatasetExperimentAnnotationSummary.py +10 -0
  13. phoenix/server/api/types/DocumentAnnotation.py +92 -0
  14. phoenix/server/api/types/Span.py +8 -2
  15. phoenix/server/api/types/TraceAnnotation.py +8 -5
  16. phoenix/server/cost_tracking/model_cost_manifest.json +91 -0
  17. phoenix/server/static/.vite/manifest.json +9 -9
  18. phoenix/server/static/assets/{components-BBwXqJXQ.js → components-Cs9c4Nxp.js} +1 -1
  19. phoenix/server/static/assets/{index-C_gU3x10.js → index-D1FDMBMV.js} +1 -1
  20. phoenix/server/static/assets/{pages-YmQb55Uo.js → pages-Cbj9SjBx.js} +488 -463
  21. phoenix/trace/projects.py +6 -0
  22. phoenix/version.py +1 -1
  23. phoenix/server/api/types/Evaluation.py +0 -40
  24. {arize_phoenix-11.30.0.dist-info → arize_phoenix-11.32.0.dist-info}/WHEEL +0 -0
  25. {arize_phoenix-11.30.0.dist-info → arize_phoenix-11.32.0.dist-info}/entry_points.txt +0 -0
  26. {arize_phoenix-11.30.0.dist-info → arize_phoenix-11.32.0.dist-info}/licenses/IP_NOTICE +0 -0
  27. {arize_phoenix-11.30.0.dist-info → arize_phoenix-11.32.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,5 +1,4 @@
1
- from datetime import datetime, timezone
2
- from typing import Any, Literal, Optional
1
+ from typing import Optional
3
2
 
4
3
  from fastapi import APIRouter, Depends, HTTPException, Query
5
4
  from pydantic import Field
@@ -11,64 +10,19 @@ from strawberry.relay import GlobalID
11
10
  from phoenix.db import models
12
11
  from phoenix.db.helpers import SupportedSQLDialect
13
12
  from phoenix.db.insertion.helpers import as_kv, insert_on_conflict
14
- from phoenix.db.insertion.types import Precursors
15
- from phoenix.server.api.types.Evaluation import DocumentAnnotation
13
+ from phoenix.server.api.routers.v1.annotations import SpanDocumentAnnotationData
14
+ from phoenix.server.api.types.DocumentAnnotation import DocumentAnnotation
16
15
  from phoenix.server.authorization import is_not_locked
17
16
  from phoenix.server.bearer_auth import PhoenixUser
18
17
  from phoenix.server.dml_event import DocumentAnnotationInsertEvent
19
18
 
20
19
  from .models import V1RoutesBaseModel
21
- from .spans import SpanAnnotationResult
22
20
  from .utils import RequestBody, ResponseBody, add_errors_to_responses
23
21
 
24
22
  # Since the document annotations are spans related, we place it under spans
25
23
  router = APIRouter(tags=["spans"])
26
24
 
27
25
 
28
- class SpanDocumentAnnotationData(V1RoutesBaseModel):
29
- span_id: str = Field(description="OpenTelemetry Span ID (hex format w/o 0x prefix)")
30
- name: str = Field(description="The name of the document annotation. E.x. relevance")
31
- annotator_kind: Literal["LLM", "CODE", "HUMAN"] = Field(
32
- description="The kind of annotator. E.g. llm judge, a heuristic piece of code, or a human"
33
- )
34
- document_position: int = Field(
35
- description="A 0 based index of the document. E.x. the first document during retrieval is 0"
36
- )
37
- result: Optional[SpanAnnotationResult] = Field(
38
- default=None, description="The score and or label of the annotation"
39
- )
40
- metadata: Optional[dict[str, Any]] = Field(
41
- default=None, description="Metadata for custom values of the annotation"
42
- )
43
- identifier: str = Field(
44
- default="",
45
- description=(
46
- "An custom ID for the annotation. "
47
- "If provided, the annotation will be updated if it already exists."
48
- ),
49
- )
50
-
51
- # Precursor here means a value to add to a queue for processing async
52
- def as_precursor(self, *, user_id: Optional[int] = None) -> Precursors.DocumentAnnotation:
53
- return Precursors.DocumentAnnotation(
54
- datetime.now(timezone.utc),
55
- self.span_id,
56
- self.document_position,
57
- models.DocumentAnnotation(
58
- name=self.name,
59
- annotator_kind=self.annotator_kind,
60
- document_position=self.document_position,
61
- score=self.result.score if self.result else None,
62
- label=self.result.label if self.result else None,
63
- explanation=self.result.explanation if self.result else None,
64
- metadata_=self.metadata or {},
65
- identifier=self.identifier,
66
- source="API",
67
- user_id=user_id,
68
- ),
69
- )
70
-
71
-
72
26
  class AnnotateSpanDocumentsRequestBody(RequestBody[list[SpanDocumentAnnotationData]]):
73
27
  pass
74
28
 
@@ -90,7 +44,11 @@ class AnnotateSpanDocumentsResponseBody(ResponseBody[list[InsertedSpanDocumentAn
90
44
  {
91
45
  "status_code": HTTP_404_NOT_FOUND,
92
46
  "description": "Span not found",
93
- }
47
+ },
48
+ {
49
+ "status_code": 422,
50
+ "description": "Invalid request - non-empty identifier not supported",
51
+ },
94
52
  ]
95
53
  ),
96
54
  response_description="Span document annotation inserted successfully",
@@ -106,6 +64,14 @@ async def annotate_span_documents(
106
64
  if not request_body.data:
107
65
  return AnnotateSpanDocumentsResponseBody(data=[])
108
66
 
67
+ # Validate that identifiers are empty or only whitespace
68
+ for annotation in request_body.data:
69
+ if annotation.identifier.strip():
70
+ raise HTTPException(
71
+ detail=f"Non-empty identifier '{annotation.identifier}' is not supported",
72
+ status_code=422, # Unprocessable Entity
73
+ )
74
+
109
75
  user_id: Optional[int] = None
110
76
  if request.app.state.authentication_enabled and isinstance(request.user, PhoenixUser):
111
77
  user_id = int(request.user.identity)
@@ -117,6 +83,7 @@ async def annotate_span_documents(
117
83
  ]
118
84
  if not sync:
119
85
  await request.state.enqueue(*precursors)
86
+ return AnnotateSpanDocumentsResponseBody(data=[])
120
87
 
121
88
  span_ids = {p.span_id for p in precursors}
122
89
  # Account for the fact that the spans could arrive after the annotation
@@ -130,38 +97,39 @@ async def annotate_span_documents(
130
97
  )
131
98
  }
132
99
 
133
- missing_span_ids = span_ids - set(existing_spans.keys())
134
- # We prefer to fail the entire operation if there are missing spans in sync mode
135
- if missing_span_ids:
136
- raise HTTPException(
137
- detail=f"Spans with IDs {', '.join(missing_span_ids)} do not exist.",
138
- status_code=HTTP_404_NOT_FOUND,
139
- )
140
-
141
- # Validate that document positions are within bounds
142
- for annotation in span_document_annotations:
143
- _, num_docs = existing_spans[annotation.span_id]
144
- if annotation.document_position not in range(num_docs):
100
+ missing_span_ids = span_ids - set(existing_spans.keys())
101
+ # We prefer to fail the entire operation if there are missing spans in sync mode
102
+ if missing_span_ids:
145
103
  raise HTTPException(
146
- detail=f"Document position {annotation.document_position} is out of bounds for "
147
- f"span {annotation.span_id} (max: {num_docs - 1})",
148
- status_code=422, # Unprocessable Entity
104
+ detail=f"Spans with IDs {', '.join(missing_span_ids)} do not exist.",
105
+ status_code=HTTP_404_NOT_FOUND,
149
106
  )
150
107
 
151
- inserted_document_annotation_ids = []
152
- dialect = SupportedSQLDialect(session.bind.dialect.name)
153
- for anno in precursors:
154
- span_rowid, _ = existing_spans[anno.span_id]
155
- values = dict(as_kv(anno.as_insertable(span_rowid).row))
156
- span_document_annotation_id = await session.scalar(
157
- insert_on_conflict(
158
- values,
159
- dialect=dialect,
160
- table=models.DocumentAnnotation,
161
- unique_by=("name", "span_rowid", "identifier", "document_position"),
162
- ).returning(models.DocumentAnnotation.id)
163
- )
164
- inserted_document_annotation_ids.append(span_document_annotation_id)
108
+ # Validate that document positions are within bounds
109
+ for annotation in span_document_annotations:
110
+ _, num_docs = existing_spans[annotation.span_id]
111
+ if annotation.document_position not in range(num_docs):
112
+ raise HTTPException(
113
+ detail=f"Document position {annotation.document_position} is out of bounds for "
114
+ f"span {annotation.span_id} (max: {num_docs - 1})",
115
+ status_code=422, # Unprocessable Entity
116
+ )
117
+
118
+ inserted_document_annotation_ids = []
119
+ dialect = SupportedSQLDialect(session.bind.dialect.name)
120
+ for anno in precursors:
121
+ span_rowid, _ = existing_spans[anno.span_id]
122
+ values = dict(as_kv(anno.as_insertable(span_rowid).row))
123
+ span_document_annotation_id = await session.scalar(
124
+ insert_on_conflict(
125
+ values,
126
+ dialect=dialect,
127
+ table=models.DocumentAnnotation,
128
+ unique_by=("name", "span_rowid", "identifier", "document_position"),
129
+ constraint_name="uq_document_annotations_name_span_rowid_document_pos_identifier",
130
+ ).returning(models.DocumentAnnotation.id)
131
+ )
132
+ inserted_document_annotation_ids.append(span_document_annotation_id)
165
133
 
166
134
  # We queue an event to let the application know that annotations have changed
167
135
  request.state.event_queue.put(
@@ -1,13 +1,13 @@
1
1
  from datetime import datetime
2
2
  from typing import Any, Optional
3
3
 
4
- from fastapi import APIRouter, Depends, HTTPException
4
+ from fastapi import APIRouter, Depends, HTTPException, Query
5
5
  from pydantic import Field
6
6
  from sqlalchemy import select
7
7
  from sqlalchemy.exc import IntegrityError as PostgreSQLIntegrityError
8
8
  from sqlean.dbapi2 import IntegrityError as SQLiteIntegrityError # type: ignore[import-untyped]
9
9
  from starlette.requests import Request
10
- from starlette.status import HTTP_404_NOT_FOUND, HTTP_409_CONFLICT
10
+ from starlette.status import HTTP_404_NOT_FOUND, HTTP_409_CONFLICT, HTTP_422_UNPROCESSABLE_ENTITY
11
11
  from strawberry.relay import GlobalID
12
12
 
13
13
  from phoenix.db import models
@@ -17,7 +17,7 @@ from phoenix.server.authorization import is_not_locked
17
17
  from phoenix.server.dml_event import ExperimentRunInsertEvent
18
18
 
19
19
  from .models import V1RoutesBaseModel
20
- from .utils import ResponseBody, add_errors_to_responses
20
+ from .utils import PaginatedResponseBody, ResponseBody, add_errors_to_responses
21
21
 
22
22
  router = APIRouter(tags=["experiments"], include_in_schema=True)
23
23
 
@@ -129,7 +129,7 @@ class ExperimentRunResponse(ExperimentRun):
129
129
  experiment_id: str = Field(description="The ID of the experiment")
130
130
 
131
131
 
132
- class ListExperimentRunsResponseBody(ResponseBody[list[ExperimentRunResponse]]):
132
+ class ListExperimentRunsResponseBody(PaginatedResponseBody[ExperimentRunResponse]):
133
133
  pass
134
134
 
135
135
 
@@ -137,13 +137,28 @@ class ListExperimentRunsResponseBody(ResponseBody[list[ExperimentRunResponse]]):
137
137
  "/experiments/{experiment_id}/runs",
138
138
  operation_id="listExperimentRuns",
139
139
  summary="List runs for an experiment",
140
+ description="Retrieve a paginated list of runs for an experiment",
140
141
  response_description="Experiment runs retrieved successfully",
141
142
  responses=add_errors_to_responses(
142
- [{"status_code": HTTP_404_NOT_FOUND, "description": "Experiment not found"}]
143
+ [
144
+ {"status_code": HTTP_404_NOT_FOUND, "description": "Experiment not found"},
145
+ {"status_code": HTTP_422_UNPROCESSABLE_ENTITY, "description": "Invalid cursor format"},
146
+ ]
143
147
  ),
144
148
  )
145
149
  async def list_experiment_runs(
146
- request: Request, experiment_id: str
150
+ request: Request,
151
+ experiment_id: str,
152
+ cursor: Optional[str] = Query(
153
+ default=None,
154
+ description="Cursor for pagination (base64-encoded experiment run ID)",
155
+ ),
156
+ limit: Optional[int] = Query(
157
+ default=None,
158
+ description="The max number of experiment runs to return at a time. "
159
+ "If not specified, returns all results.",
160
+ gt=0,
161
+ ),
147
162
  ) -> ListExperimentRunsResponseBody:
148
163
  experiment_gid = GlobalID.from_id(experiment_id)
149
164
  try:
@@ -154,30 +169,55 @@ async def list_experiment_runs(
154
169
  status_code=HTTP_404_NOT_FOUND,
155
170
  )
156
171
 
172
+ stmt = (
173
+ select(models.ExperimentRun)
174
+ .filter_by(experiment_id=experiment_rowid)
175
+ .order_by(models.ExperimentRun.id.desc())
176
+ )
177
+
178
+ if cursor:
179
+ try:
180
+ cursor_id = GlobalID.from_id(cursor).node_id
181
+ stmt = stmt.where(models.ExperimentRun.id <= int(cursor_id))
182
+ except ValueError:
183
+ raise HTTPException(
184
+ detail=f"Invalid cursor format: {cursor}",
185
+ status_code=HTTP_422_UNPROCESSABLE_ENTITY,
186
+ )
187
+
188
+ # Apply limit only if specified for pagination
189
+ if limit is not None:
190
+ stmt = stmt.limit(limit + 1)
191
+
157
192
  async with request.app.state.db() as session:
158
- experiment_runs = await session.execute(
159
- select(models.ExperimentRun)
160
- .where(models.ExperimentRun.experiment_id == experiment_rowid)
161
- # order by dataset_example_id to be consistent with `list_dataset_examples`
162
- .order_by(models.ExperimentRun.dataset_example_id.asc())
163
- )
164
- experiment_runs = experiment_runs.scalars().all()
165
- runs = []
166
- for exp_run in experiment_runs:
167
- run_gid = GlobalID("ExperimentRun", str(exp_run.id))
168
- experiment_gid = GlobalID("Experiment", str(exp_run.experiment_id))
169
- example_gid = GlobalID("DatasetExample", str(exp_run.dataset_example_id))
170
- runs.append(
171
- ExperimentRunResponse(
172
- start_time=exp_run.start_time,
173
- end_time=exp_run.end_time,
174
- experiment_id=str(experiment_gid),
175
- dataset_example_id=str(example_gid),
176
- repetition_number=exp_run.repetition_number,
177
- output=exp_run.output.get("task_output"),
178
- error=exp_run.error,
179
- id=str(run_gid),
180
- trace_id=exp_run.trace_id,
181
- )
193
+ experiment_runs = (await session.scalars(stmt)).all()
194
+
195
+ if not experiment_runs:
196
+ return ListExperimentRunsResponseBody(next_cursor=None, data=[])
197
+
198
+ next_cursor = None
199
+ # Only check for next cursor if limit was specified
200
+ if limit is not None and len(experiment_runs) == limit + 1:
201
+ last_run = experiment_runs[-1]
202
+ next_cursor = str(GlobalID("ExperimentRun", str(last_run.id)))
203
+ experiment_runs = experiment_runs[:-1]
204
+
205
+ runs = []
206
+ for exp_run in experiment_runs:
207
+ run_gid = GlobalID("ExperimentRun", str(exp_run.id))
208
+ experiment_gid = GlobalID("Experiment", str(exp_run.experiment_id))
209
+ example_gid = GlobalID("DatasetExample", str(exp_run.dataset_example_id))
210
+ runs.append(
211
+ ExperimentRunResponse(
212
+ start_time=exp_run.start_time,
213
+ end_time=exp_run.end_time,
214
+ experiment_id=str(experiment_gid),
215
+ dataset_example_id=str(example_gid),
216
+ repetition_number=exp_run.repetition_number,
217
+ output=exp_run.output.get("task_output"),
218
+ error=exp_run.error,
219
+ id=str(run_gid),
220
+ trace_id=exp_run.trace_id,
182
221
  )
183
- return ListExperimentRunsResponseBody(data=runs)
222
+ )
223
+ return ListExperimentRunsResponseBody(data=runs, next_cursor=next_cursor)
@@ -5,7 +5,7 @@ from collections.abc import AsyncIterator
5
5
  from datetime import datetime, timezone
6
6
  from enum import Enum
7
7
  from secrets import token_urlsafe
8
- from typing import Annotated, Any, Literal, Optional, Union
8
+ from typing import Annotated, Any, Optional, Union
9
9
 
10
10
  import pandas as pd
11
11
  import sqlalchemy as sa
@@ -27,8 +27,8 @@ from phoenix.datetime_utils import normalize_datetime
27
27
  from phoenix.db import models
28
28
  from phoenix.db.helpers import SupportedSQLDialect, get_ancestor_span_rowids
29
29
  from phoenix.db.insertion.helpers import as_kv, insert_on_conflict
30
- from phoenix.db.insertion.types import Precursors
31
30
  from phoenix.server.api.routers.utils import df_to_bytes
31
+ from phoenix.server.api.routers.v1.annotations import SpanAnnotationData
32
32
  from phoenix.server.api.types.node import from_global_id_with_expected_type
33
33
  from phoenix.server.authorization import is_not_locked
34
34
  from phoenix.server.bearer_auth import PhoenixUser
@@ -850,52 +850,6 @@ async def get_spans_handler(
850
850
  return await query_spans_handler(request, request_body, project_name)
851
851
 
852
852
 
853
- class SpanAnnotationResult(V1RoutesBaseModel):
854
- label: Optional[str] = Field(default=None, description="The label assigned by the annotation")
855
- score: Optional[float] = Field(default=None, description="The score assigned by the annotation")
856
- explanation: Optional[str] = Field(
857
- default=None, description="Explanation of the annotation result"
858
- )
859
-
860
-
861
- class SpanAnnotationData(V1RoutesBaseModel):
862
- span_id: str = Field(description="OpenTelemetry Span ID (hex format w/o 0x prefix)")
863
- name: str = Field(description="The name of the annotation")
864
- annotator_kind: Literal["LLM", "CODE", "HUMAN"] = Field(
865
- description="The kind of annotator used for the annotation"
866
- )
867
- result: Optional[SpanAnnotationResult] = Field(
868
- default=None, description="The result of the annotation"
869
- )
870
- metadata: Optional[dict[str, Any]] = Field(
871
- default=None, description="Metadata for the annotation"
872
- )
873
- identifier: str = Field(
874
- default="",
875
- description=(
876
- "The identifier of the annotation. "
877
- "If provided, the annotation will be updated if it already exists."
878
- ),
879
- )
880
-
881
- def as_precursor(self, *, user_id: Optional[int] = None) -> Precursors.SpanAnnotation:
882
- return Precursors.SpanAnnotation(
883
- datetime.now(timezone.utc),
884
- self.span_id,
885
- models.SpanAnnotation(
886
- name=self.name,
887
- annotator_kind=self.annotator_kind,
888
- score=self.result.score if self.result else None,
889
- label=self.result.label if self.result else None,
890
- explanation=self.result.explanation if self.result else None,
891
- metadata_=self.metadata or {},
892
- identifier=self.identifier,
893
- source="API",
894
- user_id=user_id,
895
- ),
896
- )
897
-
898
-
899
853
  class AnnotateSpansRequestBody(RequestBody[list[SpanAnnotationData]]):
900
854
  data: list[SpanAnnotationData]
901
855
 
@@ -1,7 +1,6 @@
1
1
  import gzip
2
2
  import zlib
3
- from datetime import datetime, timezone
4
- from typing import Any, Literal, Optional
3
+ from typing import Optional
5
4
 
6
5
  from fastapi import APIRouter, BackgroundTasks, Depends, Header, HTTPException, Path, Query
7
6
  from google.protobuf.message import DecodeError
@@ -10,7 +9,7 @@ from opentelemetry.proto.collector.trace.v1.trace_service_pb2 import (
10
9
  ExportTraceServiceResponse,
11
10
  )
12
11
  from pydantic import Field
13
- from sqlalchemy import delete, insert, select
12
+ from sqlalchemy import delete, select
14
13
  from starlette.concurrency import run_in_threadpool
15
14
  from starlette.datastructures import State
16
15
  from starlette.requests import Request
@@ -23,8 +22,9 @@ from starlette.status import (
23
22
  from strawberry.relay import GlobalID
24
23
 
25
24
  from phoenix.db import models
26
- from phoenix.db.insertion.helpers import as_kv
27
- from phoenix.db.insertion.types import Precursors
25
+ from phoenix.db.helpers import SupportedSQLDialect
26
+ from phoenix.db.insertion.helpers import as_kv, insert_on_conflict
27
+ from phoenix.server.api.routers.v1.annotations import TraceAnnotationData
28
28
  from phoenix.server.api.types.node import from_global_id_with_expected_type
29
29
  from phoenix.server.authorization import is_not_locked
30
30
  from phoenix.server.bearer_auth import PhoenixUser
@@ -33,7 +33,11 @@ from phoenix.trace.otel import decode_otlp_span
33
33
  from phoenix.utilities.project import get_project_name
34
34
 
35
35
  from .models import V1RoutesBaseModel
36
- from .utils import RequestBody, ResponseBody, add_errors_to_responses
36
+ from .utils import (
37
+ RequestBody,
38
+ ResponseBody,
39
+ add_errors_to_responses,
40
+ )
37
41
 
38
42
  router = APIRouter(tags=["traces"])
39
43
 
@@ -105,54 +109,8 @@ async def post_traces(
105
109
  )
106
110
 
107
111
 
108
- class TraceAnnotationResult(V1RoutesBaseModel):
109
- label: Optional[str] = Field(default=None, description="The label assigned by the annotation")
110
- score: Optional[float] = Field(default=None, description="The score assigned by the annotation")
111
- explanation: Optional[str] = Field(
112
- default=None, description="Explanation of the annotation result"
113
- )
114
-
115
-
116
- class TraceAnnotation(V1RoutesBaseModel):
117
- trace_id: str = Field(description="OpenTelemetry Trace ID (hex format w/o 0x prefix)")
118
- name: str = Field(description="The name of the annotation")
119
- annotator_kind: Literal["LLM", "HUMAN"] = Field(
120
- description="The kind of annotator used for the annotation"
121
- )
122
- result: Optional[TraceAnnotationResult] = Field(
123
- default=None, description="The result of the annotation"
124
- )
125
- metadata: Optional[dict[str, Any]] = Field(
126
- default=None, description="Metadata for the annotation"
127
- )
128
- identifier: str = Field(
129
- default="",
130
- description=(
131
- "The identifier of the annotation. "
132
- "If provided, the annotation will be updated if it already exists."
133
- ),
134
- )
135
-
136
- def as_precursor(self, *, user_id: Optional[int] = None) -> Precursors.TraceAnnotation:
137
- return Precursors.TraceAnnotation(
138
- datetime.now(timezone.utc),
139
- self.trace_id,
140
- models.TraceAnnotation(
141
- name=self.name,
142
- annotator_kind=self.annotator_kind,
143
- score=self.result.score if self.result else None,
144
- label=self.result.label if self.result else None,
145
- explanation=self.result.explanation if self.result else None,
146
- metadata_=self.metadata or {},
147
- identifier=self.identifier,
148
- source="APP",
149
- user_id=user_id,
150
- ),
151
- )
152
-
153
-
154
- class AnnotateTracesRequestBody(RequestBody[list[TraceAnnotation]]):
155
- data: list[TraceAnnotation] = Field(description="The trace annotations to be upserted")
112
+ class AnnotateTracesRequestBody(RequestBody[list[TraceAnnotationData]]):
113
+ data: list[TraceAnnotationData] = Field(description="The trace annotations to be upserted")
156
114
 
157
115
 
158
116
  class InsertedTraceAnnotation(V1RoutesBaseModel):
@@ -208,10 +166,16 @@ async def annotate_traces(
208
166
  status_code=HTTP_404_NOT_FOUND,
209
167
  )
210
168
  inserted_ids = []
169
+ dialect = SupportedSQLDialect(session.bind.dialect.name)
211
170
  for p in precursors:
212
171
  values = dict(as_kv(p.as_insertable(existing_traces[p.trace_id]).row))
213
172
  trace_annotation_id = await session.scalar(
214
- insert(models.TraceAnnotation).values(**values).returning(models.TraceAnnotation.id)
173
+ insert_on_conflict(
174
+ values,
175
+ dialect=dialect,
176
+ table=models.TraceAnnotation,
177
+ unique_by=("name", "trace_rowid", "identifier"),
178
+ ).returning(models.TraceAnnotation.id)
215
179
  )
216
180
  inserted_ids.append(trace_annotation_id)
217
181
  request.state.event_queue.put(TraceAnnotationInsertEvent(tuple(inserted_ids)))
@@ -15,9 +15,11 @@ from phoenix.server.api.context import Context
15
15
  from phoenix.server.api.exceptions import BadRequest
16
16
  from phoenix.server.api.input_types.DatasetVersionSort import DatasetVersionSort
17
17
  from phoenix.server.api.types.DatasetExample import DatasetExample
18
+ from phoenix.server.api.types.DatasetExperimentAnnotationSummary import (
19
+ DatasetExperimentAnnotationSummary,
20
+ )
18
21
  from phoenix.server.api.types.DatasetVersion import DatasetVersion
19
22
  from phoenix.server.api.types.Experiment import Experiment, to_gql_experiment
20
- from phoenix.server.api.types.ExperimentAnnotationSummary import ExperimentAnnotationSummary
21
23
  from phoenix.server.api.types.node import from_global_id_with_expected_type
22
24
  from phoenix.server.api.types.pagination import (
23
25
  ConnectionArgs,
@@ -270,53 +272,13 @@ class Dataset(Node):
270
272
  @strawberry.field
271
273
  async def experiment_annotation_summaries(
272
274
  self, info: Info[Context, None]
273
- ) -> list[ExperimentAnnotationSummary]:
275
+ ) -> list[DatasetExperimentAnnotationSummary]:
274
276
  dataset_id = self.id_attr
275
- repetition_mean_scores_by_example_subquery = (
276
- select(
277
- models.ExperimentRunAnnotation.name.label("annotation_name"),
278
- func.avg(models.ExperimentRunAnnotation.score).label("mean_repetition_score"),
279
- )
280
- .select_from(models.ExperimentRunAnnotation)
281
- .join(
282
- models.ExperimentRun,
283
- models.ExperimentRunAnnotation.experiment_run_id == models.ExperimentRun.id,
284
- )
285
- .join(
286
- models.Experiment,
287
- models.ExperimentRun.experiment_id == models.Experiment.id,
288
- )
289
- .where(models.Experiment.dataset_id == dataset_id)
290
- .group_by(
291
- models.ExperimentRun.dataset_example_id,
292
- models.ExperimentRunAnnotation.name,
293
- )
294
- .subquery()
295
- .alias("repetition_mean_scores_by_example")
296
- )
297
- repetition_mean_scores_subquery = (
298
- select(
299
- repetition_mean_scores_by_example_subquery.c.annotation_name.label(
300
- "annotation_name"
301
- ),
302
- func.avg(repetition_mean_scores_by_example_subquery.c.mean_repetition_score).label(
303
- "mean_score"
304
- ),
305
- )
306
- .select_from(repetition_mean_scores_by_example_subquery)
307
- .group_by(
308
- repetition_mean_scores_by_example_subquery.c.annotation_name,
309
- )
310
- .subquery()
311
- .alias("repetition_mean_scores")
312
- )
313
- repetitions_subquery = (
277
+ query = (
314
278
  select(
315
279
  models.ExperimentRunAnnotation.name.label("annotation_name"),
316
280
  func.min(models.ExperimentRunAnnotation.score).label("min_score"),
317
281
  func.max(models.ExperimentRunAnnotation.score).label("max_score"),
318
- func.count().label("count"),
319
- func.count(models.ExperimentRunAnnotation.error).label("error_count"),
320
282
  )
321
283
  .select_from(models.ExperimentRunAnnotation)
322
284
  .join(
@@ -329,36 +291,16 @@ class Dataset(Node):
329
291
  )
330
292
  .where(models.Experiment.dataset_id == dataset_id)
331
293
  .group_by(models.ExperimentRunAnnotation.name)
332
- .subquery()
333
- )
334
- run_scores_query = (
335
- select(
336
- repetition_mean_scores_subquery.c.annotation_name.label("annotation_name"),
337
- repetition_mean_scores_subquery.c.mean_score.label("mean_score"),
338
- repetitions_subquery.c.min_score.label("min_score"),
339
- repetitions_subquery.c.max_score.label("max_score"),
340
- repetitions_subquery.c.count.label("count_"),
341
- repetitions_subquery.c.error_count.label("error_count"),
342
- )
343
- .select_from(repetition_mean_scores_subquery)
344
- .join(
345
- repetitions_subquery,
346
- repetitions_subquery.c.annotation_name
347
- == repetition_mean_scores_subquery.c.annotation_name,
348
- )
349
- .order_by(repetition_mean_scores_subquery.c.annotation_name)
294
+ .order_by(models.ExperimentRunAnnotation.name)
350
295
  )
351
296
  async with info.context.db() as session:
352
297
  return [
353
- ExperimentAnnotationSummary(
298
+ DatasetExperimentAnnotationSummary(
354
299
  annotation_name=scores_tuple.annotation_name,
355
300
  min_score=scores_tuple.min_score,
356
301
  max_score=scores_tuple.max_score,
357
- mean_score=scores_tuple.mean_score,
358
- count=scores_tuple.count_,
359
- error_count=scores_tuple.error_count,
360
302
  )
361
- async for scores_tuple in await session.stream(run_scores_query)
303
+ async for scores_tuple in await session.stream(query)
362
304
  ]
363
305
 
364
306
  @strawberry.field
@@ -0,0 +1,10 @@
1
+ from typing import Optional
2
+
3
+ import strawberry
4
+
5
+
6
+ @strawberry.type
7
+ class DatasetExperimentAnnotationSummary:
8
+ annotation_name: str
9
+ min_score: Optional[float]
10
+ max_score: Optional[float]