arize-phoenix 4.4.4rc5__py3-none-any.whl → 4.5.0__py3-none-any.whl

This diff compares publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.

Potentially problematic release: this version of arize-phoenix might be problematic.

Files changed (118)
  1. {arize_phoenix-4.4.4rc5.dist-info → arize_phoenix-4.5.0.dist-info}/METADATA +5 -5
  2. {arize_phoenix-4.4.4rc5.dist-info → arize_phoenix-4.5.0.dist-info}/RECORD +56 -117
  3. {arize_phoenix-4.4.4rc5.dist-info → arize_phoenix-4.5.0.dist-info}/WHEEL +1 -1
  4. phoenix/__init__.py +27 -0
  5. phoenix/config.py +7 -21
  6. phoenix/core/model.py +25 -25
  7. phoenix/core/model_schema.py +62 -64
  8. phoenix/core/model_schema_adapter.py +25 -27
  9. phoenix/db/bulk_inserter.py +14 -54
  10. phoenix/db/insertion/evaluation.py +6 -6
  11. phoenix/db/insertion/helpers.py +2 -13
  12. phoenix/db/migrations/versions/cf03bd6bae1d_init.py +28 -2
  13. phoenix/db/models.py +4 -236
  14. phoenix/inferences/fixtures.py +23 -23
  15. phoenix/inferences/inferences.py +7 -7
  16. phoenix/inferences/validation.py +1 -1
  17. phoenix/server/api/context.py +0 -18
  18. phoenix/server/api/dataloaders/__init__.py +0 -18
  19. phoenix/server/api/dataloaders/span_descendants.py +3 -2
  20. phoenix/server/api/routers/v1/__init__.py +2 -77
  21. phoenix/server/api/routers/v1/evaluations.py +2 -4
  22. phoenix/server/api/routers/v1/spans.py +1 -3
  23. phoenix/server/api/routers/v1/traces.py +4 -1
  24. phoenix/server/api/schema.py +303 -2
  25. phoenix/server/api/types/Cluster.py +19 -19
  26. phoenix/server/api/types/Dataset.py +63 -282
  27. phoenix/server/api/types/DatasetRole.py +23 -0
  28. phoenix/server/api/types/Dimension.py +29 -30
  29. phoenix/server/api/types/EmbeddingDimension.py +34 -40
  30. phoenix/server/api/types/Event.py +16 -16
  31. phoenix/server/api/{mutations/export_events_mutations.py → types/ExportEventsMutation.py} +14 -17
  32. phoenix/server/api/types/Model.py +42 -43
  33. phoenix/server/api/types/Project.py +12 -26
  34. phoenix/server/api/types/Span.py +2 -79
  35. phoenix/server/api/types/TimeSeries.py +6 -6
  36. phoenix/server/api/types/Trace.py +4 -15
  37. phoenix/server/api/types/UMAPPoints.py +1 -1
  38. phoenix/server/api/types/node.py +111 -5
  39. phoenix/server/api/types/pagination.py +52 -10
  40. phoenix/server/app.py +49 -101
  41. phoenix/server/main.py +27 -49
  42. phoenix/server/openapi/docs.py +0 -3
  43. phoenix/server/static/index.js +2595 -3523
  44. phoenix/server/templates/index.html +0 -1
  45. phoenix/services.py +15 -15
  46. phoenix/session/client.py +21 -438
  47. phoenix/session/session.py +37 -47
  48. phoenix/trace/exporter.py +9 -14
  49. phoenix/trace/fixtures.py +7 -133
  50. phoenix/trace/schemas.py +2 -1
  51. phoenix/trace/span_evaluations.py +3 -3
  52. phoenix/trace/trace_dataset.py +6 -6
  53. phoenix/version.py +1 -1
  54. phoenix/datasets/__init__.py +0 -0
  55. phoenix/datasets/evaluators/__init__.py +0 -18
  56. phoenix/datasets/evaluators/code_evaluators.py +0 -99
  57. phoenix/datasets/evaluators/llm_evaluators.py +0 -244
  58. phoenix/datasets/evaluators/utils.py +0 -292
  59. phoenix/datasets/experiments.py +0 -550
  60. phoenix/datasets/tracing.py +0 -85
  61. phoenix/datasets/types.py +0 -178
  62. phoenix/db/insertion/dataset.py +0 -237
  63. phoenix/db/migrations/types.py +0 -29
  64. phoenix/db/migrations/versions/10460e46d750_datasets.py +0 -291
  65. phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -100
  66. phoenix/server/api/dataloaders/dataset_example_spans.py +0 -43
  67. phoenix/server/api/dataloaders/experiment_annotation_summaries.py +0 -85
  68. phoenix/server/api/dataloaders/experiment_error_rates.py +0 -43
  69. phoenix/server/api/dataloaders/experiment_run_counts.py +0 -42
  70. phoenix/server/api/dataloaders/experiment_sequence_number.py +0 -49
  71. phoenix/server/api/dataloaders/project_by_name.py +0 -31
  72. phoenix/server/api/dataloaders/span_projects.py +0 -33
  73. phoenix/server/api/dataloaders/trace_row_ids.py +0 -39
  74. phoenix/server/api/helpers/dataset_helpers.py +0 -179
  75. phoenix/server/api/input_types/AddExamplesToDatasetInput.py +0 -16
  76. phoenix/server/api/input_types/AddSpansToDatasetInput.py +0 -14
  77. phoenix/server/api/input_types/ClearProjectInput.py +0 -15
  78. phoenix/server/api/input_types/CreateDatasetInput.py +0 -12
  79. phoenix/server/api/input_types/DatasetExampleInput.py +0 -14
  80. phoenix/server/api/input_types/DatasetSort.py +0 -17
  81. phoenix/server/api/input_types/DatasetVersionSort.py +0 -16
  82. phoenix/server/api/input_types/DeleteDatasetExamplesInput.py +0 -13
  83. phoenix/server/api/input_types/DeleteDatasetInput.py +0 -7
  84. phoenix/server/api/input_types/DeleteExperimentsInput.py +0 -9
  85. phoenix/server/api/input_types/PatchDatasetExamplesInput.py +0 -35
  86. phoenix/server/api/input_types/PatchDatasetInput.py +0 -14
  87. phoenix/server/api/mutations/__init__.py +0 -13
  88. phoenix/server/api/mutations/auth.py +0 -11
  89. phoenix/server/api/mutations/dataset_mutations.py +0 -520
  90. phoenix/server/api/mutations/experiment_mutations.py +0 -65
  91. phoenix/server/api/mutations/project_mutations.py +0 -47
  92. phoenix/server/api/openapi/__init__.py +0 -0
  93. phoenix/server/api/openapi/main.py +0 -6
  94. phoenix/server/api/openapi/schema.py +0 -16
  95. phoenix/server/api/queries.py +0 -503
  96. phoenix/server/api/routers/v1/dataset_examples.py +0 -178
  97. phoenix/server/api/routers/v1/datasets.py +0 -965
  98. phoenix/server/api/routers/v1/experiment_evaluations.py +0 -66
  99. phoenix/server/api/routers/v1/experiment_runs.py +0 -108
  100. phoenix/server/api/routers/v1/experiments.py +0 -174
  101. phoenix/server/api/types/AnnotatorKind.py +0 -10
  102. phoenix/server/api/types/CreateDatasetPayload.py +0 -8
  103. phoenix/server/api/types/DatasetExample.py +0 -85
  104. phoenix/server/api/types/DatasetExampleRevision.py +0 -34
  105. phoenix/server/api/types/DatasetVersion.py +0 -14
  106. phoenix/server/api/types/ExampleRevisionInterface.py +0 -14
  107. phoenix/server/api/types/Experiment.py +0 -140
  108. phoenix/server/api/types/ExperimentAnnotationSummary.py +0 -13
  109. phoenix/server/api/types/ExperimentComparison.py +0 -19
  110. phoenix/server/api/types/ExperimentRun.py +0 -91
  111. phoenix/server/api/types/ExperimentRunAnnotation.py +0 -57
  112. phoenix/server/api/types/Inferences.py +0 -80
  113. phoenix/server/api/types/InferencesRole.py +0 -23
  114. phoenix/utilities/json.py +0 -61
  115. phoenix/utilities/re.py +0 -50
  116. {arize_phoenix-4.4.4rc5.dist-info → arize_phoenix-4.5.0.dist-info}/licenses/IP_NOTICE +0 -0
  117. {arize_phoenix-4.4.4rc5.dist-info → arize_phoenix-4.5.0.dist-info}/licenses/LICENSE +0 -0
  118. /phoenix/server/api/{helpers/__init__.py → helpers.py} +0 -0

phoenix/server/api/routers/v1/__init__.py
@@ -1,40 +1,6 @@
-from typing import Any, Awaitable, Callable, Mapping, Tuple
-
-import wrapt
-from starlette import routing
-from starlette.requests import Request
-from starlette.responses import Response
-from starlette.status import HTTP_403_FORBIDDEN
-
-from . import (
-    datasets,
-    evaluations,
-    experiment_evaluations,
-    experiment_runs,
-    experiments,
-    spans,
-    traces,
-)
-from .dataset_examples import list_dataset_examples
-
-
-@wrapt.decorator # type: ignore
-async def forbid_if_readonly(
-    wrapped: Callable[[Request], Awaitable[Response]],
-    _: Any,
-    args: Tuple[Request],
-    kwargs: Mapping[str, Any],
-) -> Response:
-    request, *_ = args
-    if request.app.state.read_only:
-        return Response(status_code=HTTP_403_FORBIDDEN)
-    return await wrapped(*args, **kwargs)
-
-
-class Route(routing.Route):
-    def __init__(self, path: str, endpoint: Callable[..., Any], **kwargs: Any) -> None:
-        super().__init__(path, forbid_if_readonly(endpoint), **kwargs)
+from starlette.routing import Route
 
+from . import evaluations, spans, traces
 
 V1_ROUTES = [
     Route("/v1/evaluations", evaluations.post_evaluations, methods=["POST"]),
@@ -42,45 +8,4 @@ V1_ROUTES = [
     Route("/v1/traces", traces.post_traces, methods=["POST"]),
     Route("/v1/spans", spans.query_spans_handler, methods=["POST"]),
     Route("/v1/spans", spans.get_spans_handler, methods=["GET"]),
-    Route("/v1/datasets/upload", datasets.post_datasets_upload, methods=["POST"]),
-    Route("/v1/datasets", datasets.list_datasets, methods=["GET"]),
-    Route("/v1/datasets/{id:str}", datasets.get_dataset_by_id, methods=["GET"]),
-    Route("/v1/datasets/{id:str}/csv", datasets.get_dataset_csv, methods=["GET"]),
-    Route(
-        "/v1/datasets/{id:str}/jsonl/openai_ft",
-        datasets.get_dataset_jsonl_openai_ft,
-        methods=["GET"],
-    ),
-    Route(
-        "/v1/datasets/{id:str}/jsonl/openai_evals",
-        datasets.get_dataset_jsonl_openai_evals,
-        methods=["GET"],
-    ),
-    Route("/v1/datasets/{id:str}/examples", list_dataset_examples, methods=["GET"]),
-    Route("/v1/datasets/{id:str}/versions", datasets.get_dataset_versions, methods=["GET"]),
-    Route(
-        "/v1/datasets/{dataset_id:str}/experiments",
-        experiments.create_experiment,
-        methods=["POST"],
-    ),
-    Route(
-        "/v1/experiments/{experiment_id:str}",
-        experiments.read_experiment,
-        methods=["GET"],
-    ),
-    Route(
-        "/v1/experiments/{experiment_id:str}/runs",
-        experiment_runs.create_experiment_run,
-        methods=["POST"],
-    ),
-    Route(
-        "/v1/experiments/{experiment_id:str}/runs",
-        experiment_runs.list_experiment_runs,
-        methods=["GET"],
-    ),
-    Route(
-        "/v1/experiment_evaluations",
-        experiment_evaluations.create_experiment_evaluation,
-        methods=["POST"],
-    ),
 ]

phoenix/server/api/routers/v1/evaluations.py
@@ -44,7 +44,7 @@ async def post_evaluations(request: Request) -> Response:
     summary: Add evaluations to a span, trace, or document
     operationId: addEvaluations
     tags:
-      - private
+      - evaluations
     parameters:
       - name: project-name
         in: query
@@ -105,7 +105,7 @@ async def get_evaluations(request: Request) -> Response:
     summary: Get evaluations from Phoenix
     operationId: getEvaluation
     tags:
-      - private
+      - evaluations
     parameters:
       - name: project-name
         in: query
@@ -116,8 +116,6 @@ async def get_evaluations(request: Request) -> Response:
     responses:
       200:
         description: Success
-      403:
-        description: Forbidden
       404:
         description: Not found
     """

phoenix/server/api/routers/v1/spans.py
@@ -19,7 +19,7 @@ async def query_spans_handler(request: Request) -> Response:
     summary: Query spans using query DSL
     operationId: querySpans
     tags:
-      - private
+      - spans
     parameters:
       - name: project-name
         in: query
@@ -68,8 +68,6 @@ async def query_spans_handler(request: Request) -> Response:
     responses:
       200:
         description: Success
-      403:
-        description: Forbidden
       404:
         description: Not found
       422:

phoenix/server/api/routers/v1/traces.py
@@ -11,6 +11,7 @@ from starlette.datastructures import State
 from starlette.requests import Request
 from starlette.responses import Response
 from starlette.status import (
+    HTTP_403_FORBIDDEN,
     HTTP_415_UNSUPPORTED_MEDIA_TYPE,
     HTTP_422_UNPROCESSABLE_ENTITY,
 )
@@ -24,7 +25,7 @@ async def post_traces(request: Request) -> Response:
     summary: Send traces to Phoenix
     operationId: addTraces
     tags:
-      - private
+      - traces
     requestBody:
       required: true
       content:
@@ -42,6 +43,8 @@ async def post_traces(request: Request) -> Response:
       422:
         description: Request body is invalid
     """
+    if request.app.state.read_only:
+        return Response(status_code=HTTP_403_FORBIDDEN)
     content_type = request.headers.get("content-type")
     if content_type != "application/x-protobuf":
         return Response(

phoenix/server/api/schema.py
@@ -1,7 +1,308 @@
+from collections import defaultdict
+from typing import Dict, List, Optional, Set, Union
+
+import numpy as np
+import numpy.typing as npt
 import strawberry
+from sqlalchemy import delete, select
+from sqlalchemy.orm import contains_eager, load_only
+from strawberry import ID, UNSET
+from strawberry.types import Info
+from typing_extensions import Annotated
+
+from phoenix.config import DEFAULT_PROJECT_NAME
+from phoenix.db import models
+from phoenix.db.insertion.span import ClearProjectSpansEvent
+from phoenix.pointcloud.clustering import Hdbscan
+from phoenix.server.api.context import Context
+from phoenix.server.api.helpers import ensure_list
+from phoenix.server.api.input_types.ClusterInput import ClusterInput
+from phoenix.server.api.input_types.Coordinates import (
+    InputCoordinate2D,
+    InputCoordinate3D,
+)
+from phoenix.server.api.types.Cluster import Cluster, to_gql_clusters
+from phoenix.server.api.types.DatasetRole import AncillaryDatasetRole, DatasetRole
+from phoenix.server.api.types.Dimension import to_gql_dimension
+from phoenix.server.api.types.EmbeddingDimension import (
+    DEFAULT_CLUSTER_SELECTION_EPSILON,
+    DEFAULT_MIN_CLUSTER_SIZE,
+    DEFAULT_MIN_SAMPLES,
+    to_gql_embedding_dimension,
+)
+from phoenix.server.api.types.Event import create_event_id, unpack_event_id
+from phoenix.server.api.types.ExportEventsMutation import ExportEventsMutation
+from phoenix.server.api.types.Functionality import Functionality
+from phoenix.server.api.types.Model import Model
+from phoenix.server.api.types.node import (
+    GlobalID,
+    Node,
+    from_global_id,
+    from_global_id_with_expected_type,
+)
+from phoenix.server.api.types.pagination import (
+    Connection,
+    ConnectionArgs,
+    CursorString,
+    connection_from_list,
+)
+from phoenix.server.api.types.Project import Project
+from phoenix.server.api.types.Span import to_gql_span
+from phoenix.server.api.types.Trace import Trace
+
+
+@strawberry.type
+class Query:
+    @strawberry.field
+    async def projects(
+        self,
+        info: Info[Context, None],
+        first: Optional[int] = 50,
+        last: Optional[int] = UNSET,
+        after: Optional[CursorString] = UNSET,
+        before: Optional[CursorString] = UNSET,
+    ) -> Connection[Project]:
+        args = ConnectionArgs(
+            first=first,
+            after=after if isinstance(after, CursorString) else None,
+            last=last,
+            before=before if isinstance(before, CursorString) else None,
+        )
+        async with info.context.db() as session:
+            projects = await session.scalars(select(models.Project))
+        data = [
+            Project(
+                id_attr=project.id,
+                name=project.name,
+                gradient_start_color=project.gradient_start_color,
+                gradient_end_color=project.gradient_end_color,
+            )
+            for project in projects
+        ]
+        return connection_from_list(data=data, args=args)
+
+    @strawberry.field
+    async def functionality(self, info: Info[Context, None]) -> "Functionality":
+        has_model_inferences = not info.context.model.is_empty
+        async with info.context.db() as session:
+            has_traces = (await session.scalar(select(models.Trace).limit(1))) is not None
+        return Functionality(
+            model_inferences=has_model_inferences,
+            tracing=has_traces,
+        )
+
+    @strawberry.field
+    def model(self) -> Model:
+        return Model()
+
+    @strawberry.field
+    async def node(self, id: GlobalID, info: Info[Context, None]) -> Node:
+        type_name, node_id = from_global_id(str(id))
+        if type_name == "Dimension":
+            dimension = info.context.model.scalar_dimensions[node_id]
+            return to_gql_dimension(node_id, dimension)
+        elif type_name == "EmbeddingDimension":
+            embedding_dimension = info.context.model.embedding_dimensions[node_id]
+            return to_gql_embedding_dimension(node_id, embedding_dimension)
+        elif type_name == "Project":
+            project_stmt = select(
+                models.Project.id,
+                models.Project.name,
+                models.Project.gradient_start_color,
+                models.Project.gradient_end_color,
+            ).where(models.Project.id == node_id)
+            async with info.context.db() as session:
+                project = (await session.execute(project_stmt)).first()
+            if project is None:
+                raise ValueError(f"Unknown project: {id}")
+            return Project(
+                id_attr=project.id,
+                name=project.name,
+                gradient_start_color=project.gradient_start_color,
+                gradient_end_color=project.gradient_end_color,
+            )
+        elif type_name == "Trace":
+            trace_stmt = select(models.Trace.id).where(models.Trace.id == node_id)
+            async with info.context.db() as session:
+                id_attr = await session.scalar(trace_stmt)
+            if id_attr is None:
+                raise ValueError(f"Unknown trace: {id}")
+            return Trace(id_attr=id_attr)
+        elif type_name == "Span":
+            span_stmt = (
+                select(models.Span)
+                .join(models.Trace)
+                .options(contains_eager(models.Span.trace))
+                .where(models.Span.id == node_id)
+            )
+            async with info.context.db() as session:
+                span = await session.scalar(span_stmt)
+            if span is None:
+                raise ValueError(f"Unknown span: {id}")
+            return to_gql_span(span)
+        raise Exception(f"Unknown node type: {type_name}")
+
+    @strawberry.field
+    def clusters(
+        self,
+        clusters: List[ClusterInput],
+    ) -> List[Cluster]:
+        clustered_events: Dict[str, Set[ID]] = defaultdict(set)
+        for i, cluster in enumerate(clusters):
+            clustered_events[cluster.id or str(i)].update(cluster.event_ids)
+        return to_gql_clusters(
+            clustered_events=clustered_events,
+        )
+
+    @strawberry.field
+    def hdbscan_clustering(
+        self,
+        info: Info[Context, None],
+        event_ids: Annotated[
+            List[ID],
+            strawberry.argument(
+                description="Event ID of the coordinates",
+            ),
+        ],
+        coordinates_2d: Annotated[
+            Optional[List[InputCoordinate2D]],
+            strawberry.argument(
+                description="Point coordinates. Must be either 2D or 3D.",
+            ),
+        ] = UNSET,
+        coordinates_3d: Annotated[
+            Optional[List[InputCoordinate3D]],
+            strawberry.argument(
+                description="Point coordinates. Must be either 2D or 3D.",
+            ),
+        ] = UNSET,
+        min_cluster_size: Annotated[
+            int,
+            strawberry.argument(
+                description="HDBSCAN minimum cluster size",
+            ),
+        ] = DEFAULT_MIN_CLUSTER_SIZE,
+        cluster_min_samples: Annotated[
+            int,
+            strawberry.argument(
+                description="HDBSCAN minimum samples",
+            ),
+        ] = DEFAULT_MIN_SAMPLES,
+        cluster_selection_epsilon: Annotated[
+            float,
+            strawberry.argument(
+                description="HDBSCAN cluster selection epsilon",
+            ),
+        ] = DEFAULT_CLUSTER_SELECTION_EPSILON,
+    ) -> List[Cluster]:
+        coordinates_3d = ensure_list(coordinates_3d)
+        coordinates_2d = ensure_list(coordinates_2d)
+
+        if len(coordinates_3d) > 0 and len(coordinates_2d) > 0:
+            raise ValueError("must specify only one of 2D or 3D coordinates")
+
+        if len(coordinates_3d) > 0:
+            coordinates = list(
+                map(
+                    lambda coord: np.array(
+                        [coord.x, coord.y, coord.z],
+                    ),
+                    coordinates_3d,
+                )
+            )
+        else:
+            coordinates = list(
+                map(
+                    lambda coord: np.array(
+                        [coord.x, coord.y],
+                    ),
+                    coordinates_2d,
+                )
+            )
+
+        if len(event_ids) != len(coordinates):
+            raise ValueError(
+                f"length mismatch between "
+                f"event_ids ({len(event_ids)}) "
+                f"and coordinates ({len(coordinates)})"
+            )
+
+        if len(event_ids) == 0:
+            return []
+
+        grouped_event_ids: Dict[
+            Union[DatasetRole, AncillaryDatasetRole],
+            List[ID],
+        ] = defaultdict(list)
+        grouped_coordinates: Dict[
+            Union[DatasetRole, AncillaryDatasetRole],
+            List[npt.NDArray[np.float64]],
+        ] = defaultdict(list)
+
+        for event_id, coordinate in zip(event_ids, coordinates):
+            row_id, dataset_role = unpack_event_id(event_id)
+            grouped_coordinates[dataset_role].append(coordinate)
+            grouped_event_ids[dataset_role].append(create_event_id(row_id, dataset_role))
+
+        stacked_event_ids = (
+            grouped_event_ids[DatasetRole.primary]
+            + grouped_event_ids[DatasetRole.reference]
+            + grouped_event_ids[AncillaryDatasetRole.corpus]
+        )
+        stacked_coordinates = np.stack(
+            grouped_coordinates[DatasetRole.primary]
+            + grouped_coordinates[DatasetRole.reference]
+            + grouped_coordinates[AncillaryDatasetRole.corpus]
+        )
+
+        clusters = Hdbscan(
+            min_cluster_size=min_cluster_size,
+            min_samples=cluster_min_samples,
+            cluster_selection_epsilon=cluster_selection_epsilon,
+        ).find_clusters(stacked_coordinates)
+
+        clustered_events = {
+            str(i): {stacked_event_ids[row_idx] for row_idx in cluster}
+            for i, cluster in enumerate(clusters)
+        }
+
+        return to_gql_clusters(
+            clustered_events=clustered_events,
+        )
+
+
+@strawberry.type
+class Mutation(ExportEventsMutation):
+    @strawberry.mutation
+    async def delete_project(self, info: Info[Context, None], id: GlobalID) -> Query:
+        if info.context.read_only:
+            return Query()
+        node_id = from_global_id_with_expected_type(str(id), "Project")
+        async with info.context.db() as session:
+            project = await session.scalar(
+                select(models.Project)
+                .where(models.Project.id == node_id)
+                .options(load_only(models.Project.name))
+            )
+            if project is None:
+                raise ValueError(f"Unknown project: {id}")
+            if project.name == DEFAULT_PROJECT_NAME:
+                raise ValueError(f"Cannot delete the {DEFAULT_PROJECT_NAME} project")
+            await session.delete(project)
+        return Query()
+
+    @strawberry.mutation
+    async def clear_project(self, info: Info[Context, None], id: GlobalID) -> Query:
+        if info.context.read_only:
+            return Query()
+        project_id = from_global_id_with_expected_type(str(id), "Project")
+        delete_statement = delete(models.Trace).where(models.Trace.project_rowid == project_id)
+        async with info.context.db() as session:
+            await session.execute(delete_statement)
+        if cache := info.context.cache_for_dataloaders:
+            cache.invalidate(ClearProjectSpansEvent(project_rowid=project_id))
+        return Query()
 
-from phoenix.server.api.mutations import Mutation
-from phoenix.server.api.queries import Query
 
 # This is the schema for generating `schema.graphql`.
 # See https://strawberry.rocks/docs/guides/schema-export

phoenix/server/api/types/Cluster.py
@@ -9,9 +9,9 @@ from phoenix.core.model_schema import PRIMARY, REFERENCE
 from phoenix.server.api.context import Context
 from phoenix.server.api.input_types.DataQualityMetricInput import DataQualityMetricInput
 from phoenix.server.api.input_types.PerformanceMetricInput import PerformanceMetricInput
+from phoenix.server.api.types.DatasetRole import AncillaryDatasetRole, DatasetRole
 from phoenix.server.api.types.DatasetValues import DatasetValues
 from phoenix.server.api.types.Event import unpack_event_id
-from phoenix.server.api.types.InferencesRole import AncillaryInferencesRole, InferencesRole
 
 
 @strawberry.type
@@ -36,8 +36,8 @@ class Cluster:
         """
         Calculates the drift score of the cluster. The score will be a value
         representing the balance of points between the primary and the reference
-        inferences, and will be on a scale between 1 (all primary) and -1 (all
-        reference), with 0 being an even balance between the two inference sets.
+        datasets, and will be on a scale between 1 (all primary) and -1 (all
+        reference), with 0 being an even balance between the two datasets.
 
         Returns
         -------
@@ -47,8 +47,8 @@
         if model[REFERENCE].empty:
             return None
         count_by_role = Counter(unpack_event_id(event_id)[1] for event_id in self.event_ids)
-        primary_count = count_by_role[InferencesRole.primary]
-        reference_count = count_by_role[InferencesRole.reference]
+        primary_count = count_by_role[DatasetRole.primary]
+        reference_count = count_by_role[DatasetRole.reference]
         return (
             None
             if not (denominator := (primary_count + reference_count))
@@ -76,8 +76,8 @@
         if corpus is None or corpus[PRIMARY].empty:
             return None
         count_by_role = Counter(unpack_event_id(event_id)[1] for event_id in self.event_ids)
-        primary_count = count_by_role[InferencesRole.primary]
-        corpus_count = count_by_role[AncillaryInferencesRole.corpus]
+        primary_count = count_by_role[DatasetRole.primary]
+        corpus_count = count_by_role[AncillaryDatasetRole.corpus]
         return (
             None
             if not (denominator := (primary_count + corpus_count))
@@ -94,19 +94,19 @@
         metric: DataQualityMetricInput,
     ) -> DatasetValues:
         model = info.context.model
-        row_ids: Dict[InferencesRole, List[int]] = defaultdict(list)
-        for row_id, inferences_role in map(unpack_event_id, self.event_ids):
-            if not isinstance(inferences_role, InferencesRole):
+        row_ids: Dict[DatasetRole, List[int]] = defaultdict(list)
+        for row_id, dataset_role in map(unpack_event_id, self.event_ids):
+            if not isinstance(dataset_role, DatasetRole):
                 continue
-            row_ids[inferences_role].append(row_id)
+            row_ids[dataset_role].append(row_id)
         return DatasetValues(
             primary_value=metric.metric_instance(
                 model[PRIMARY],
-                subset_rows=row_ids[InferencesRole.primary],
+                subset_rows=row_ids[DatasetRole.primary],
             ),
             reference_value=metric.metric_instance(
                 model[REFERENCE],
-                subset_rows=row_ids[InferencesRole.reference],
+                subset_rows=row_ids[DatasetRole.reference],
            ),
         )
 
@@ -120,20 +120,20 @@
         metric: PerformanceMetricInput,
     ) -> DatasetValues:
         model = info.context.model
-        row_ids: Dict[InferencesRole, List[int]] = defaultdict(list)
-        for row_id, inferences_role in map(unpack_event_id, self.event_ids):
-            if not isinstance(inferences_role, InferencesRole):
+        row_ids: Dict[DatasetRole, List[int]] = defaultdict(list)
+        for row_id, dataset_role in map(unpack_event_id, self.event_ids):
+            if not isinstance(dataset_role, DatasetRole):
                 continue
-            row_ids[inferences_role].append(row_id)
+            row_ids[dataset_role].append(row_id)
         metric_instance = metric.metric_instance(model)
         return DatasetValues(
             primary_value=metric_instance(
                 model[PRIMARY],
-                subset_rows=row_ids[InferencesRole.primary],
+                subset_rows=row_ids[DatasetRole.primary],
             ),
             reference_value=metric_instance(
                 model[REFERENCE],
-                subset_rows=row_ids[InferencesRole.reference],
+                subset_rows=row_ids[DatasetRole.reference],
             ),
         )
 