arize-phoenix 4.14.1__py3-none-any.whl → 4.16.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic; see the registry's advisory for this release for more details.

Files changed (85)
  1. {arize_phoenix-4.14.1.dist-info → arize_phoenix-4.16.0.dist-info}/METADATA +5 -3
  2. {arize_phoenix-4.14.1.dist-info → arize_phoenix-4.16.0.dist-info}/RECORD +81 -71
  3. phoenix/db/bulk_inserter.py +131 -5
  4. phoenix/db/engines.py +2 -1
  5. phoenix/db/helpers.py +23 -1
  6. phoenix/db/insertion/constants.py +2 -0
  7. phoenix/db/insertion/document_annotation.py +157 -0
  8. phoenix/db/insertion/helpers.py +13 -0
  9. phoenix/db/insertion/span_annotation.py +144 -0
  10. phoenix/db/insertion/trace_annotation.py +144 -0
  11. phoenix/db/insertion/types.py +261 -0
  12. phoenix/experiments/functions.py +3 -2
  13. phoenix/experiments/types.py +3 -3
  14. phoenix/server/api/context.py +7 -9
  15. phoenix/server/api/dataloaders/__init__.py +2 -0
  16. phoenix/server/api/dataloaders/average_experiment_run_latency.py +3 -3
  17. phoenix/server/api/dataloaders/dataset_example_revisions.py +2 -4
  18. phoenix/server/api/dataloaders/dataset_example_spans.py +2 -4
  19. phoenix/server/api/dataloaders/document_evaluation_summaries.py +2 -4
  20. phoenix/server/api/dataloaders/document_evaluations.py +2 -4
  21. phoenix/server/api/dataloaders/document_retrieval_metrics.py +2 -4
  22. phoenix/server/api/dataloaders/evaluation_summaries.py +2 -4
  23. phoenix/server/api/dataloaders/experiment_annotation_summaries.py +2 -4
  24. phoenix/server/api/dataloaders/experiment_error_rates.py +2 -4
  25. phoenix/server/api/dataloaders/experiment_run_counts.py +2 -4
  26. phoenix/server/api/dataloaders/experiment_sequence_number.py +2 -4
  27. phoenix/server/api/dataloaders/latency_ms_quantile.py +2 -3
  28. phoenix/server/api/dataloaders/min_start_or_max_end_times.py +2 -4
  29. phoenix/server/api/dataloaders/project_by_name.py +3 -3
  30. phoenix/server/api/dataloaders/record_counts.py +2 -4
  31. phoenix/server/api/dataloaders/span_annotations.py +2 -4
  32. phoenix/server/api/dataloaders/span_dataset_examples.py +36 -0
  33. phoenix/server/api/dataloaders/span_descendants.py +2 -4
  34. phoenix/server/api/dataloaders/span_evaluations.py +2 -4
  35. phoenix/server/api/dataloaders/span_projects.py +3 -3
  36. phoenix/server/api/dataloaders/token_counts.py +2 -4
  37. phoenix/server/api/dataloaders/trace_evaluations.py +2 -4
  38. phoenix/server/api/dataloaders/trace_row_ids.py +2 -4
  39. phoenix/server/api/input_types/SpanAnnotationSort.py +17 -0
  40. phoenix/server/api/input_types/TraceAnnotationSort.py +17 -0
  41. phoenix/server/api/mutations/span_annotations_mutations.py +8 -3
  42. phoenix/server/api/mutations/trace_annotations_mutations.py +8 -3
  43. phoenix/server/api/openapi/main.py +18 -2
  44. phoenix/server/api/openapi/schema.py +12 -12
  45. phoenix/server/api/routers/v1/__init__.py +36 -83
  46. phoenix/server/api/routers/v1/datasets.py +515 -509
  47. phoenix/server/api/routers/v1/evaluations.py +164 -73
  48. phoenix/server/api/routers/v1/experiment_evaluations.py +68 -91
  49. phoenix/server/api/routers/v1/experiment_runs.py +98 -155
  50. phoenix/server/api/routers/v1/experiments.py +132 -181
  51. phoenix/server/api/routers/v1/pydantic_compat.py +78 -0
  52. phoenix/server/api/routers/v1/spans.py +164 -203
  53. phoenix/server/api/routers/v1/traces.py +134 -159
  54. phoenix/server/api/routers/v1/utils.py +95 -0
  55. phoenix/server/api/types/Span.py +27 -3
  56. phoenix/server/api/types/Trace.py +21 -4
  57. phoenix/server/api/utils.py +4 -4
  58. phoenix/server/app.py +172 -192
  59. phoenix/server/grpc_server.py +2 -2
  60. phoenix/server/main.py +5 -9
  61. phoenix/server/static/.vite/manifest.json +31 -31
  62. phoenix/server/static/assets/components-Ci5kMOk5.js +1175 -0
  63. phoenix/server/static/assets/{index-CQgXRwU0.js → index-BQG5WVX7.js} +2 -2
  64. phoenix/server/static/assets/{pages-hdjlFZhO.js → pages-BrevprVW.js} +451 -275
  65. phoenix/server/static/assets/{vendor-DPvSDRn3.js → vendor-CP0b0YG0.js} +2 -2
  66. phoenix/server/static/assets/{vendor-arizeai-CkvPT67c.js → vendor-arizeai-DTbiPGp6.js} +27 -27
  67. phoenix/server/static/assets/vendor-codemirror-DtdPDzrv.js +15 -0
  68. phoenix/server/static/assets/{vendor-recharts-5jlNaZuF.js → vendor-recharts-A0DA1O99.js} +1 -1
  69. phoenix/server/thread_server.py +2 -2
  70. phoenix/server/types.py +18 -0
  71. phoenix/session/client.py +5 -3
  72. phoenix/session/session.py +2 -2
  73. phoenix/trace/dsl/filter.py +2 -6
  74. phoenix/trace/fixtures.py +17 -23
  75. phoenix/trace/utils.py +23 -0
  76. phoenix/utilities/client.py +116 -0
  77. phoenix/utilities/project.py +1 -1
  78. phoenix/version.py +1 -1
  79. phoenix/server/api/routers/v1/dataset_examples.py +0 -178
  80. phoenix/server/openapi/docs.py +0 -221
  81. phoenix/server/static/assets/components-DeS0YEmv.js +0 -1142
  82. phoenix/server/static/assets/vendor-codemirror-Cqwpwlua.js +0 -12
  83. {arize_phoenix-4.14.1.dist-info → arize_phoenix-4.16.0.dist-info}/WHEEL +0 -0
  84. {arize_phoenix-4.14.1.dist-info → arize_phoenix-4.16.0.dist-info}/licenses/IP_NOTICE +0 -0
  85. {arize_phoenix-4.14.1.dist-info → arize_phoenix-4.16.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,18 +1,19 @@
1
1
  import gzip
2
2
  import zlib
3
- from typing import Any, Dict, List
3
+ from typing import Any, Dict, List, Literal, Optional
4
4
 
5
+ from fastapi import APIRouter, BackgroundTasks, Header, HTTPException, Query
5
6
  from google.protobuf.message import DecodeError
6
7
  from opentelemetry.proto.collector.trace.v1.trace_service_pb2 import (
7
8
  ExportTraceServiceRequest,
8
9
  )
10
+ from pydantic import Field
9
11
  from sqlalchemy import select
10
- from starlette.background import BackgroundTask
11
12
  from starlette.concurrency import run_in_threadpool
12
13
  from starlette.datastructures import State
13
14
  from starlette.requests import Request
14
- from starlette.responses import JSONResponse, Response
15
15
  from starlette.status import (
16
+ HTTP_204_NO_CONTENT,
16
17
  HTTP_404_NOT_FOUND,
17
18
  HTTP_415_UNSUPPORTED_MEDIA_TYPE,
18
19
  HTTP_422_UNPROCESSABLE_ENTITY,
@@ -21,45 +22,56 @@ from strawberry.relay import GlobalID
21
22
 
22
23
  from phoenix.db import models
23
24
  from phoenix.db.helpers import SupportedSQLDialect
24
- from phoenix.db.insertion.helpers import insert_on_conflict
25
- from phoenix.server.api.types.node import from_global_id_with_expected_type
25
+ from phoenix.db.insertion.helpers import as_kv, insert_on_conflict
26
+ from phoenix.db.insertion.types import Precursors
26
27
  from phoenix.trace.otel import decode_otlp_span
27
28
  from phoenix.utilities.project import get_project_name
28
29
 
29
-
30
- async def post_traces(request: Request) -> Response:
31
- """
32
- summary: Send traces to Phoenix
33
- operationId: addTraces
34
- tags:
35
- - private
36
- requestBody:
37
- required: true
38
- content:
39
- application/x-protobuf:
40
- schema:
41
- type: string
42
- format: binary
43
- responses:
44
- 200:
45
- description: Success
46
- 403:
47
- description: Forbidden
48
- 415:
49
- description: Unsupported content type, only gzipped protobuf
50
- 422:
51
- description: Request body is invalid
52
- """
53
- content_type = request.headers.get("content-type")
30
+ from .pydantic_compat import V1RoutesBaseModel
31
+ from .utils import RequestBody, ResponseBody, add_errors_to_responses
32
+
33
+ router = APIRouter(tags=["traces"], include_in_schema=False)
34
+
35
+
36
+ @router.post(
37
+ "/traces",
38
+ operation_id="addTraces",
39
+ summary="Send traces",
40
+ status_code=HTTP_204_NO_CONTENT,
41
+ responses=add_errors_to_responses(
42
+ [
43
+ {
44
+ "status_code": HTTP_415_UNSUPPORTED_MEDIA_TYPE,
45
+ "description": (
46
+ "Unsupported content type (only `application/x-protobuf` is supported)"
47
+ ),
48
+ },
49
+ {"status_code": HTTP_422_UNPROCESSABLE_ENTITY, "description": "Invalid request body"},
50
+ ]
51
+ ),
52
+ openapi_extra={
53
+ "requestBody": {
54
+ "required": True,
55
+ "content": {
56
+ "application/x-protobuf": {"schema": {"type": "string", "format": "binary"}}
57
+ },
58
+ }
59
+ },
60
+ )
61
+ async def post_traces(
62
+ request: Request,
63
+ background_tasks: BackgroundTasks,
64
+ content_type: Optional[str] = Header(default=None),
65
+ content_encoding: Optional[str] = Header(default=None),
66
+ ) -> None:
54
67
  if content_type != "application/x-protobuf":
55
- return Response(
56
- content=f"Unsupported content type: {content_type}",
68
+ raise HTTPException(
69
+ detail=f"Unsupported content type: {content_type}",
57
70
  status_code=HTTP_415_UNSUPPORTED_MEDIA_TYPE,
58
71
  )
59
- content_encoding = request.headers.get("content-encoding")
60
72
  if content_encoding and content_encoding not in ("gzip", "deflate"):
61
- return Response(
62
- content=f"Unsupported content encoding: {content_encoding}",
73
+ raise HTTPException(
74
+ detail=f"Unsupported content encoding: {content_encoding}",
63
75
  status_code=HTTP_415_UNSUPPORTED_MEDIA_TYPE,
64
76
  )
65
77
  body = await request.body()
@@ -71,139 +83,100 @@ async def post_traces(request: Request) -> Response:
71
83
  try:
72
84
  await run_in_threadpool(req.ParseFromString, body)
73
85
  except DecodeError:
74
- return Response(
75
- content="Request body is invalid ExportTraceServiceRequest",
86
+ raise HTTPException(
87
+ detail="Request body is invalid ExportTraceServiceRequest",
76
88
  status_code=HTTP_422_UNPROCESSABLE_ENTITY,
77
89
  )
78
- return Response(background=BackgroundTask(_add_spans, req, request.state))
79
-
80
-
81
- async def annotate_traces(request: Request) -> Response:
82
- """
83
- summary: Upsert annotations for traces
84
- operationId: annotateTraces
85
- tags:
86
- - private
87
- requestBody:
88
- description: List of trace annotations to be inserted
89
- required: true
90
- content:
91
- application/json:
92
- schema:
93
- type: object
94
- properties:
95
- data:
96
- type: array
97
- items:
98
- type: object
99
- properties:
100
- trace_id:
101
- type: string
102
- description: The ID of the trace being annotated
103
- name:
104
- type: string
105
- description: The name of the annotation
106
- annotator_kind:
107
- type: string
108
- description: The kind of annotator used for the annotation ("LLM" or "HUMAN")
109
- result:
110
- type: object
111
- description: The result of the annotation
112
- properties:
113
- label:
114
- type: string
115
- description: The label assigned by the annotation
116
- score:
117
- type: number
118
- format: float
119
- description: The score assigned by the annotation
120
- explanation:
121
- type: string
122
- description: Explanation of the annotation result
123
- error:
124
- type: string
125
- description: Optional error message if the annotation encountered an error
126
- metadata:
127
- type: object
128
- description: Metadata for the annotation
129
- additionalProperties:
130
- type: string
131
- required:
132
- - trace_id
133
- - name
134
- - annotator_kind
135
- responses:
136
- 200:
137
- description: Trace annotations inserted successfully
138
- content:
139
- application/json:
140
- schema:
141
- type: object
142
- properties:
143
- data:
144
- type: array
145
- items:
146
- type: object
147
- properties:
148
- id:
149
- type: string
150
- description: The ID of the inserted trace annotation
151
- 404:
152
- description: Trace not found
153
- """
154
- payload: List[Dict[str, Any]] = (await request.json()).get("data", [])
155
- trace_gids = [GlobalID.from_id(annotation["trace_id"]) for annotation in payload]
156
-
157
- resolved_trace_ids = []
158
- for trace_gid in trace_gids:
159
- try:
160
- resolved_trace_ids.append(from_global_id_with_expected_type(trace_gid, "Trace"))
161
- except ValueError:
162
- return Response(
163
- content="Trace with ID {trace_gid} does not exist",
164
- status_code=HTTP_404_NOT_FOUND,
165
- )
90
+ background_tasks.add_task(_add_spans, req, request.state)
91
+ return None
92
+
93
+
94
+ class TraceAnnotationResult(V1RoutesBaseModel):
95
+ label: Optional[str] = Field(default=None, description="The label assigned by the annotation")
96
+ score: Optional[float] = Field(default=None, description="The score assigned by the annotation")
97
+ explanation: Optional[str] = Field(
98
+ default=None, description="Explanation of the annotation result"
99
+ )
100
+
101
+
102
+ class TraceAnnotation(V1RoutesBaseModel):
103
+ trace_id: str = Field(description="OpenTelemetry Trace ID (hex format w/o 0x prefix)")
104
+ name: str = Field(description="The name of the annotation")
105
+ annotator_kind: Literal["LLM", "HUMAN"] = Field(
106
+ description="The kind of annotator used for the annotation"
107
+ )
108
+ result: Optional[TraceAnnotationResult] = Field(
109
+ default=None, description="The result of the annotation"
110
+ )
111
+ metadata: Optional[Dict[str, Any]] = Field(
112
+ default=None, description="Metadata for the annotation"
113
+ )
114
+
115
+ def as_precursor(self) -> Precursors.TraceAnnotation:
116
+ return Precursors.TraceAnnotation(
117
+ self.trace_id,
118
+ models.TraceAnnotation(
119
+ name=self.name,
120
+ annotator_kind=self.annotator_kind,
121
+ score=self.result.score if self.result else None,
122
+ label=self.result.label if self.result else None,
123
+ explanation=self.result.explanation if self.result else None,
124
+ metadata_=self.metadata or {},
125
+ ),
126
+ )
127
+
128
+
129
+ class AnnotateTracesRequestBody(RequestBody[List[TraceAnnotation]]):
130
+ data: List[TraceAnnotation] = Field(description="The trace annotations to be upserted")
131
+
166
132
 
133
+ class InsertedTraceAnnotation(V1RoutesBaseModel):
134
+ id: str = Field(description="The ID of the inserted trace annotation")
135
+
136
+
137
+ class AnnotateTracesResponseBody(ResponseBody[List[InsertedTraceAnnotation]]):
138
+ pass
139
+
140
+
141
+ @router.post(
142
+ "/trace_annotations",
143
+ operation_id="annotateTraces",
144
+ summary="Create or update trace annotations",
145
+ responses=add_errors_to_responses(
146
+ [{"status_code": HTTP_404_NOT_FOUND, "description": "Trace not found"}]
147
+ ),
148
+ )
149
+ async def annotate_traces(
150
+ request: Request,
151
+ request_body: AnnotateTracesRequestBody,
152
+ sync: bool = Query(default=True, description="If true, fulfill request synchronously."),
153
+ ) -> AnnotateTracesResponseBody:
154
+ precursors = [d.as_precursor() for d in request_body.data]
155
+ if not sync:
156
+ await request.state.enqueue(*precursors)
157
+ return AnnotateTracesResponseBody(data=[])
158
+
159
+ trace_ids = {p.trace_id for p in precursors}
167
160
  async with request.app.state.db() as session:
168
- traces = await session.execute(
169
- select(models.Trace).filter(models.Trace.id.in_(resolved_trace_ids))
170
- )
171
- existing_trace_ids = {trace.id for trace in traces.scalars()}
161
+ existing_traces = {
162
+ trace.trace_id: trace.id
163
+ async for trace in await session.stream_scalars(
164
+ select(models.Trace).filter(models.Trace.trace_id.in_(trace_ids))
165
+ )
166
+ }
172
167
 
173
- missing_trace_ids = set(resolved_trace_ids) - existing_trace_ids
168
+ missing_trace_ids = trace_ids - set(existing_traces.keys())
174
169
  if missing_trace_ids:
175
- missing_trace_gids = [
176
- str(GlobalID("Trace", str(trace_gid))) for trace_gid in missing_trace_ids
177
- ]
178
- return Response(
179
- content=f"Traces with IDs {', '.join(missing_trace_gids)} do not exist.",
170
+ raise HTTPException(
171
+ detail=f"Traces with IDs {', '.join(missing_trace_ids)} do not exist.",
180
172
  status_code=HTTP_404_NOT_FOUND,
181
173
  )
182
174
 
183
175
  inserted_annotations = []
184
176
 
185
- for annotation in payload:
186
- trace_gid = GlobalID.from_id(annotation["trace_id"])
187
- trace_id = from_global_id_with_expected_type(trace_gid, "Trace")
188
-
189
- name = annotation["name"]
190
- annotator_kind = annotation["annotator_kind"]
191
- result = annotation.get("result")
192
- label = result.get("label") if result else None
193
- score = result.get("score") if result else None
194
- explanation = result.get("explanation") if result else None
195
- metadata = annotation.get("metadata") or {}
196
-
197
- values = dict(
198
- trace_rowid=trace_id,
199
- name=name,
200
- label=label,
201
- score=score,
202
- explanation=explanation,
203
- annotator_kind=annotator_kind,
204
- metadata_=metadata,
205
- )
206
- dialect = SupportedSQLDialect(session.bind.dialect.name)
177
+ dialect = SupportedSQLDialect(session.bind.dialect.name)
178
+ for p in precursors:
179
+ values = dict(as_kv(p.as_insertable(existing_traces[p.trace_id]).row))
207
180
  trace_annotation_id = await session.scalar(
208
181
  insert_on_conflict(
209
182
  values,
@@ -213,10 +186,12 @@ async def annotate_traces(request: Request) -> Response:
213
186
  ).returning(models.TraceAnnotation.id)
214
187
  )
215
188
  inserted_annotations.append(
216
- {"id": str(GlobalID("TraceAnnotation", str(trace_annotation_id)))}
189
+ InsertedTraceAnnotation(
190
+ id=str(GlobalID("TraceAnnotation", str(trace_annotation_id)))
191
+ )
217
192
  )
218
193
 
219
- return JSONResponse(content={"data": inserted_annotations})
194
+ return AnnotateTracesResponseBody(data=inserted_annotations)
220
195
 
221
196
 
222
197
  async def _add_spans(req: ExportTraceServiceRequest, state: State) -> None:
@@ -0,0 +1,95 @@
1
+ from typing import Any, Dict, Generic, List, Optional, TypedDict, Union
2
+
3
+ from typing_extensions import TypeAlias, TypeVar, assert_never
4
+
5
+ from .pydantic_compat import V1RoutesBaseModel
6
+
7
+ StatusCode: TypeAlias = int
8
+ DataType = TypeVar("DataType")
9
+ Responses: TypeAlias = Dict[
10
+ Union[int, str], Dict[str, Any]
11
+ ] # input type for the `responses` parameter of a fastapi route
12
+
13
+
14
+ class StatusCodeWithDescription(TypedDict):
15
+ """
16
+ A duck type for a status code with a description detailing under what
17
+ conditions the status code is raised.
18
+ """
19
+
20
+ status_code: StatusCode
21
+ description: str
22
+
23
+
24
+ class RequestBody(V1RoutesBaseModel, Generic[DataType]):
25
+ # A generic request type accepted by V1 routes.
26
+ #
27
+ # Don't use """ for this docstring or it will be included as a description
28
+ # in the generated OpenAPI schema.
29
+ data: DataType
30
+
31
+
32
+ class ResponseBody(V1RoutesBaseModel, Generic[DataType]):
33
+ # A generic response type returned by V1 routes.
34
+ #
35
+ # Don't use """ for this docstring or it will be included as a description
36
+ # in the generated OpenAPI schema.
37
+
38
+ data: DataType
39
+
40
+
41
+ class PaginatedResponseBody(V1RoutesBaseModel, Generic[DataType]):
42
+ # A generic paginated response type returned by V1 routes.
43
+ #
44
+ # Don't use """ for this docstring or it will be included as a description
45
+ # in the generated OpenAPI schema.
46
+
47
+ data: List[DataType]
48
+ next_cursor: Optional[str]
49
+
50
+
51
+ def add_errors_to_responses(
52
+ errors: List[Union[StatusCode, StatusCodeWithDescription]],
53
+ /,
54
+ *,
55
+ responses: Optional[Responses] = None,
56
+ ) -> Responses:
57
+ """
58
+ Creates or updates a patch for an OpenAPI schema's `responses` section to
59
+ include status codes in the generated OpenAPI schema.
60
+ """
61
+ output_responses: Responses = responses or {}
62
+ for error in errors:
63
+ status_code: int
64
+ description: Optional[str] = None
65
+ if isinstance(error, StatusCode):
66
+ status_code = error
67
+ elif isinstance(error, dict):
68
+ status_code = error["status_code"]
69
+ description = error["description"]
70
+ else:
71
+ assert_never(error)
72
+ if status_code not in output_responses:
73
+ output_responses[status_code] = {
74
+ "content": {"text/plain": {"schema": {"type": "string"}}}
75
+ }
76
+ if description:
77
+ output_responses[status_code]["description"] = description
78
+ return output_responses
79
+
80
+
81
+ def add_text_csv_content_to_responses(
82
+ status_code: StatusCode, /, *, responses: Optional[Responses] = None
83
+ ) -> Responses:
84
+ """
85
+ Creates or updates a patch for an OpenAPI schema's `responses` section to
86
+ ensure that the response for the given status code is marked as text/csv in
87
+ the generated OpenAPI schema.
88
+ """
89
+ output_responses: Responses = responses or {}
90
+ if status_code not in output_responses:
91
+ output_responses[status_code] = {}
92
+ output_responses[status_code]["content"] = {
93
+ "text/csv": {"schema": {"type": "string", "contentMediaType": "text/csv"}}
94
+ }
95
+ return output_responses
@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, List, Mapping, Optional, Sized, cast
7
7
  import numpy as np
8
8
  import strawberry
9
9
  from openinference.semconv.trace import EmbeddingAttributes, SpanAttributes
10
+ from sqlalchemy import select
10
11
  from strawberry import ID, UNSET
11
12
  from strawberry.relay import Node, NodeID
12
13
  from strawberry.types import Info
@@ -19,6 +20,9 @@ from phoenix.server.api.helpers.dataset_helpers import (
19
20
  get_dataset_example_input,
20
21
  get_dataset_example_output,
21
22
  )
23
+ from phoenix.server.api.input_types.SpanAnnotationSort import SpanAnnotationSort
24
+ from phoenix.server.api.types.SortDir import SortDir
25
+ from phoenix.server.api.types.SpanAnnotation import to_gql_span_annotation
22
26
  from phoenix.trace.attributes import get_attribute_value
23
27
 
24
28
  from .DocumentRetrievalMetrics import DocumentRetrievalMetrics
@@ -177,12 +181,27 @@ class Span(Node):
177
181
 
178
182
  @strawberry.field(
179
183
  description=(
180
- "Annotations of the span's parent span. This encompasses both "
184
+ "Annotations associated with the span. This encompasses both "
181
185
  "LLM and human annotations."
182
186
  )
183
187
  ) # type: ignore
184
- async def span_annotations(self, info: Info[Context, None]) -> List[SpanAnnotation]:
185
- return await info.context.data_loaders.span_annotations.load(self.id_attr)
188
+ async def span_annotations(
189
+ self,
190
+ info: Info[Context, None],
191
+ sort: Optional[SpanAnnotationSort] = UNSET,
192
+ ) -> List[SpanAnnotation]:
193
+ async with info.context.db() as session:
194
+ stmt = select(models.SpanAnnotation).filter_by(span_rowid=self.id_attr)
195
+ if sort:
196
+ sort_col = getattr(models.SpanAnnotation, sort.col.value)
197
+ if sort.dir is SortDir.desc:
198
+ stmt = stmt.order_by(sort_col.desc(), models.SpanAnnotation.id.desc())
199
+ else:
200
+ stmt = stmt.order_by(sort_col.asc(), models.SpanAnnotation.id.asc())
201
+ else:
202
+ stmt = stmt.order_by(models.SpanAnnotation.created_at.desc())
203
+ annotations = await session.scalars(stmt)
204
+ return [to_gql_span_annotation(annotation) for annotation in annotations]
186
205
 
187
206
  @strawberry.field(
188
207
  description="Evaluations of the documents associated with the span, e.g. "
@@ -258,6 +277,11 @@ class Span(Node):
258
277
  project = await info.context.data_loaders.span_projects.load(span_id)
259
278
  return to_gql_project(project)
260
279
 
280
+ @strawberry.field(description="Indicates if the span is contained in any dataset") # type: ignore
281
+ async def contained_in_dataset(self, info: Info[Context, None]) -> bool:
282
+ examples = await info.context.data_loaders.span_dataset_examples.load(self.id_attr)
283
+ return bool(examples)
284
+
261
285
 
262
286
  def to_gql_span(span: models.Span) -> Span:
263
287
  events: List[SpanEvent] = list(map(SpanEvent.from_dict, span.events))
@@ -11,13 +11,15 @@ from strawberry.types import Info
11
11
 
12
12
  from phoenix.db import models
13
13
  from phoenix.server.api.context import Context
14
- from phoenix.server.api.types.Evaluation import TraceEvaluation
14
+ from phoenix.server.api.input_types.TraceAnnotationSort import TraceAnnotationSort
15
15
  from phoenix.server.api.types.pagination import (
16
16
  ConnectionArgs,
17
17
  CursorString,
18
18
  connection_from_list,
19
19
  )
20
+ from phoenix.server.api.types.SortDir import SortDir
20
21
  from phoenix.server.api.types.Span import Span, to_gql_span
22
+ from phoenix.server.api.types.TraceAnnotation import TraceAnnotation, to_gql_trace_annotation
21
23
 
22
24
 
23
25
  @strawberry.type
@@ -62,6 +64,21 @@ class Trace(Node):
62
64
  data = [to_gql_span(span) async for span in spans]
63
65
  return connection_from_list(data=data, args=args)
64
66
 
65
- @strawberry.field(description="Evaluations associated with the trace") # type: ignore
66
- async def trace_evaluations(self, info: Info[Context, None]) -> List[TraceEvaluation]:
67
- return await info.context.data_loaders.trace_evaluations.load(self.id_attr)
67
+ @strawberry.field(description="Annotations associated with the trace.") # type: ignore
68
+ async def span_annotations(
69
+ self,
70
+ info: Info[Context, None],
71
+ sort: Optional[TraceAnnotationSort] = None,
72
+ ) -> List[TraceAnnotation]:
73
+ async with info.context.db() as session:
74
+ stmt = select(models.TraceAnnotation).filter_by(span_rowid=self.id_attr)
75
+ if sort:
76
+ sort_col = getattr(models.TraceAnnotation, sort.col.value)
77
+ if sort.dir is SortDir.desc:
78
+ stmt = stmt.order_by(sort_col.desc(), models.TraceAnnotation.id.desc())
79
+ else:
80
+ stmt = stmt.order_by(sort_col.asc(), models.TraceAnnotation.id.asc())
81
+ else:
82
+ stmt = stmt.order_by(models.TraceAnnotation.created_at.desc())
83
+ annotations = await session.scalars(stmt)
84
+ return [to_gql_trace_annotation(annotation) for annotation in annotations]
@@ -1,13 +1,13 @@
1
- from typing import AsyncContextManager, Callable, List
1
+ from typing import List
2
2
 
3
3
  from sqlalchemy import delete
4
- from sqlalchemy.ext.asyncio import AsyncSession
5
4
 
6
5
  from phoenix.db import models
6
+ from phoenix.server.types import DbSessionFactory
7
7
 
8
8
 
9
9
  async def delete_projects(
10
- db: Callable[[], AsyncContextManager[AsyncSession]],
10
+ db: DbSessionFactory,
11
11
  *project_names: str,
12
12
  ) -> List[int]:
13
13
  if not project_names:
@@ -22,7 +22,7 @@ async def delete_projects(
22
22
 
23
23
 
24
24
  async def delete_traces(
25
- db: Callable[[], AsyncContextManager[AsyncSession]],
25
+ db: DbSessionFactory,
26
26
  *trace_ids: str,
27
27
  ) -> List[int]:
28
28
  if not trace_ids: