arize-phoenix 4.5.0__py3-none-any.whl → 4.6.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of arize-phoenix might be problematic.

Files changed (123)
  1. {arize_phoenix-4.5.0.dist-info → arize_phoenix-4.6.2.dist-info}/METADATA +16 -8
  2. {arize_phoenix-4.5.0.dist-info → arize_phoenix-4.6.2.dist-info}/RECORD +122 -58
  3. {arize_phoenix-4.5.0.dist-info → arize_phoenix-4.6.2.dist-info}/WHEEL +1 -1
  4. phoenix/__init__.py +0 -27
  5. phoenix/config.py +42 -7
  6. phoenix/core/model.py +25 -25
  7. phoenix/core/model_schema.py +64 -62
  8. phoenix/core/model_schema_adapter.py +27 -25
  9. phoenix/datetime_utils.py +4 -0
  10. phoenix/db/bulk_inserter.py +54 -14
  11. phoenix/db/insertion/dataset.py +237 -0
  12. phoenix/db/insertion/evaluation.py +10 -10
  13. phoenix/db/insertion/helpers.py +17 -14
  14. phoenix/db/insertion/span.py +3 -3
  15. phoenix/db/migrations/types.py +29 -0
  16. phoenix/db/migrations/versions/10460e46d750_datasets.py +291 -0
  17. phoenix/db/migrations/versions/cf03bd6bae1d_init.py +2 -28
  18. phoenix/db/models.py +236 -4
  19. phoenix/experiments/__init__.py +6 -0
  20. phoenix/experiments/evaluators/__init__.py +29 -0
  21. phoenix/experiments/evaluators/base.py +153 -0
  22. phoenix/experiments/evaluators/code_evaluators.py +99 -0
  23. phoenix/experiments/evaluators/llm_evaluators.py +244 -0
  24. phoenix/experiments/evaluators/utils.py +186 -0
  25. phoenix/experiments/functions.py +757 -0
  26. phoenix/experiments/tracing.py +85 -0
  27. phoenix/experiments/types.py +753 -0
  28. phoenix/experiments/utils.py +24 -0
  29. phoenix/inferences/fixtures.py +23 -23
  30. phoenix/inferences/inferences.py +7 -7
  31. phoenix/inferences/validation.py +1 -1
  32. phoenix/server/api/context.py +20 -0
  33. phoenix/server/api/dataloaders/__init__.py +20 -0
  34. phoenix/server/api/dataloaders/average_experiment_run_latency.py +54 -0
  35. phoenix/server/api/dataloaders/dataset_example_revisions.py +100 -0
  36. phoenix/server/api/dataloaders/dataset_example_spans.py +43 -0
  37. phoenix/server/api/dataloaders/experiment_annotation_summaries.py +85 -0
  38. phoenix/server/api/dataloaders/experiment_error_rates.py +43 -0
  39. phoenix/server/api/dataloaders/experiment_run_counts.py +42 -0
  40. phoenix/server/api/dataloaders/experiment_sequence_number.py +49 -0
  41. phoenix/server/api/dataloaders/project_by_name.py +31 -0
  42. phoenix/server/api/dataloaders/span_descendants.py +2 -3
  43. phoenix/server/api/dataloaders/span_projects.py +33 -0
  44. phoenix/server/api/dataloaders/trace_row_ids.py +39 -0
  45. phoenix/server/api/helpers/dataset_helpers.py +179 -0
  46. phoenix/server/api/input_types/AddExamplesToDatasetInput.py +16 -0
  47. phoenix/server/api/input_types/AddSpansToDatasetInput.py +14 -0
  48. phoenix/server/api/input_types/ClearProjectInput.py +15 -0
  49. phoenix/server/api/input_types/CreateDatasetInput.py +12 -0
  50. phoenix/server/api/input_types/DatasetExampleInput.py +14 -0
  51. phoenix/server/api/input_types/DatasetSort.py +17 -0
  52. phoenix/server/api/input_types/DatasetVersionSort.py +16 -0
  53. phoenix/server/api/input_types/DeleteDatasetExamplesInput.py +13 -0
  54. phoenix/server/api/input_types/DeleteDatasetInput.py +7 -0
  55. phoenix/server/api/input_types/DeleteExperimentsInput.py +9 -0
  56. phoenix/server/api/input_types/PatchDatasetExamplesInput.py +35 -0
  57. phoenix/server/api/input_types/PatchDatasetInput.py +14 -0
  58. phoenix/server/api/mutations/__init__.py +13 -0
  59. phoenix/server/api/mutations/auth.py +11 -0
  60. phoenix/server/api/mutations/dataset_mutations.py +520 -0
  61. phoenix/server/api/mutations/experiment_mutations.py +65 -0
  62. phoenix/server/api/{types/ExportEventsMutation.py → mutations/export_events_mutations.py} +17 -14
  63. phoenix/server/api/mutations/project_mutations.py +47 -0
  64. phoenix/server/api/openapi/__init__.py +0 -0
  65. phoenix/server/api/openapi/main.py +6 -0
  66. phoenix/server/api/openapi/schema.py +16 -0
  67. phoenix/server/api/queries.py +503 -0
  68. phoenix/server/api/routers/v1/__init__.py +77 -2
  69. phoenix/server/api/routers/v1/dataset_examples.py +178 -0
  70. phoenix/server/api/routers/v1/datasets.py +965 -0
  71. phoenix/server/api/routers/v1/evaluations.py +8 -13
  72. phoenix/server/api/routers/v1/experiment_evaluations.py +143 -0
  73. phoenix/server/api/routers/v1/experiment_runs.py +220 -0
  74. phoenix/server/api/routers/v1/experiments.py +302 -0
  75. phoenix/server/api/routers/v1/spans.py +9 -5
  76. phoenix/server/api/routers/v1/traces.py +1 -4
  77. phoenix/server/api/schema.py +2 -303
  78. phoenix/server/api/types/AnnotatorKind.py +10 -0
  79. phoenix/server/api/types/Cluster.py +19 -19
  80. phoenix/server/api/types/CreateDatasetPayload.py +8 -0
  81. phoenix/server/api/types/Dataset.py +282 -63
  82. phoenix/server/api/types/DatasetExample.py +85 -0
  83. phoenix/server/api/types/DatasetExampleRevision.py +34 -0
  84. phoenix/server/api/types/DatasetVersion.py +14 -0
  85. phoenix/server/api/types/Dimension.py +30 -29
  86. phoenix/server/api/types/EmbeddingDimension.py +40 -34
  87. phoenix/server/api/types/Event.py +16 -16
  88. phoenix/server/api/types/ExampleRevisionInterface.py +14 -0
  89. phoenix/server/api/types/Experiment.py +147 -0
  90. phoenix/server/api/types/ExperimentAnnotationSummary.py +13 -0
  91. phoenix/server/api/types/ExperimentComparison.py +19 -0
  92. phoenix/server/api/types/ExperimentRun.py +91 -0
  93. phoenix/server/api/types/ExperimentRunAnnotation.py +57 -0
  94. phoenix/server/api/types/Inferences.py +80 -0
  95. phoenix/server/api/types/InferencesRole.py +23 -0
  96. phoenix/server/api/types/Model.py +43 -42
  97. phoenix/server/api/types/Project.py +26 -12
  98. phoenix/server/api/types/Span.py +79 -2
  99. phoenix/server/api/types/TimeSeries.py +6 -6
  100. phoenix/server/api/types/Trace.py +15 -4
  101. phoenix/server/api/types/UMAPPoints.py +1 -1
  102. phoenix/server/api/types/node.py +5 -111
  103. phoenix/server/api/types/pagination.py +10 -52
  104. phoenix/server/app.py +103 -49
  105. phoenix/server/main.py +49 -27
  106. phoenix/server/openapi/docs.py +3 -0
  107. phoenix/server/static/index.js +2300 -1294
  108. phoenix/server/templates/index.html +1 -0
  109. phoenix/services.py +15 -15
  110. phoenix/session/client.py +581 -22
  111. phoenix/session/session.py +47 -37
  112. phoenix/trace/exporter.py +14 -9
  113. phoenix/trace/fixtures.py +133 -7
  114. phoenix/trace/schemas.py +1 -2
  115. phoenix/trace/span_evaluations.py +3 -3
  116. phoenix/trace/trace_dataset.py +6 -6
  117. phoenix/utilities/json.py +61 -0
  118. phoenix/utilities/re.py +50 -0
  119. phoenix/version.py +1 -1
  120. phoenix/server/api/types/DatasetRole.py +0 -23
  121. {arize_phoenix-4.5.0.dist-info → arize_phoenix-4.6.2.dist-info}/licenses/IP_NOTICE +0 -0
  122. {arize_phoenix-4.5.0.dist-info → arize_phoenix-4.6.2.dist-info}/licenses/LICENSE +0 -0
  123. /phoenix/server/api/{helpers.py → helpers/__init__.py} +0 -0

phoenix/server/api/routers/v1/experiments.py (new file)
@@ -0,0 +1,302 @@
+ from random import getrandbits
+
+ from sqlalchemy import select
+ from starlette.requests import Request
+ from starlette.responses import JSONResponse, Response
+ from starlette.status import HTTP_404_NOT_FOUND
+ from strawberry.relay import GlobalID
+
+ from phoenix.db import models
+ from phoenix.db.helpers import SupportedSQLDialect
+ from phoenix.db.insertion.helpers import insert_on_conflict
+ from phoenix.server.api.types.node import from_global_id_with_expected_type
+
+
+ def _short_uuid() -> str:
+ return str(getrandbits(32).to_bytes(4, "big").hex())
+
+
+ def _generate_experiment_name(dataset_name: str) -> str:
+ """
+ Generate a semi-unique name for the experiment.
+ """
+ short_ds_name = dataset_name[:8].replace(" ", "-")
+ return f"{short_ds_name}-{_short_uuid()}"
+
+
+ async def create_experiment(request: Request) -> Response:
+ """
+ summary: Create an experiment using a dataset
+ operationId: createExperiment
+ tags:
+ - private
+ parameters:
+ - in: path
+ name: dataset_id
+ required: true
+ description: The ID of the dataset to create an experiment for
+ schema:
+ type: string
+ requestBody:
+ description: Details of the experiment to be created
+ required: true
+ content:
+ application/json:
+ schema:
+ type: object
+ properties:
+ repetitions:
+ type: integer
+ description: Number of times the experiment should be repeated for each example
+ default: 1
+ metadata:
+ type: object
+ description: Metadata for the experiment
+ additionalProperties:
+ type: string
+ version_id:
+ type: string
+ description: ID of the dataset version to use
+ responses:
+ 200:
+ description: Experiment retrieved successfully
+ content:
+ application/json:
+ schema:
+ type: object
+ properties:
+ data:
+ type: object
+ properties:
+ id:
+ type: string
+ description: The ID of the experiment
+ dataset_id:
+ type: string
+ description: The ID of the dataset associated with the experiment
+ dataset_version_id:
+ type: string
+ description: The ID of the dataset version associated with the experiment
+ repetitions:
+ type: integer
+ description: Number of times the experiment is repeated
+ metadata:
+ type: object
+ description: Metadata of the experiment
+ additionalProperties:
+ type: string
+ project_name:
+ type: string
+ description: The name of the project associated with the experiment
+ created_at:
+ type: string
+ format: date-time
+ description: The creation timestamp of the experiment
+ updated_at:
+ type: string
+ format: date-time
+ description: The last update timestamp of the experiment
+ 404:
+ description: Dataset or DatasetVersion not found
+ """
+ dataset_globalid = GlobalID.from_id(request.path_params["dataset_id"])
+ try:
+ dataset_id = from_global_id_with_expected_type(dataset_globalid, "Dataset")
+ except ValueError:
+ return Response(
+ content="Dataset with ID {dataset_globalid} does not exist",
+ status_code=HTTP_404_NOT_FOUND,
+ )
+
+ payload = await request.json()
+ repetitions = payload.get("repetitions", 1)
+ metadata = payload.get("metadata") or {}
+ dataset_version_globalid_str = payload.get("version_id")
+ if dataset_version_globalid_str is not None:
+ try:
+ dataset_version_globalid = GlobalID.from_id(dataset_version_globalid_str)
+ dataset_version_id = from_global_id_with_expected_type(
+ dataset_version_globalid, "DatasetVersion"
+ )
+ except ValueError:
+ return Response(
+ content="DatasetVersion with ID {dataset_version_globalid} does not exist",
+ status_code=HTTP_404_NOT_FOUND,
+ )
+
+ async with request.app.state.db() as session:
+ result = (
+ await session.execute(select(models.Dataset).where(models.Dataset.id == dataset_id))
+ ).scalar()
+ if result is None:
+ return Response(
+ content=f"Dataset with ID {dataset_globalid} does not exist",
+ status_code=HTTP_404_NOT_FOUND,
+ )
+ dataset_name = result.name
+ if dataset_version_globalid_str is None:
+ dataset_version_result = await session.execute(
+ select(models.DatasetVersion)
+ .where(models.DatasetVersion.dataset_id == dataset_id)
+ .order_by(models.DatasetVersion.id.desc())
+ )
+ dataset_version = dataset_version_result.scalar()
+ if not dataset_version:
+ return Response(
+ content=f"Dataset {dataset_globalid} does not have any versions",
+ status_code=HTTP_404_NOT_FOUND,
+ )
+ dataset_version_id = dataset_version.id
+ dataset_version_globalid = GlobalID("DatasetVersion", str(dataset_version_id))
+ else:
+ dataset_version = await session.execute(
+ select(models.DatasetVersion).where(models.DatasetVersion.id == dataset_version_id)
+ )
+ dataset_version = dataset_version.scalar()
+ if not dataset_version:
+ return Response(
+ content=f"DatasetVersion with ID {dataset_version_globalid} does not exist",
+ status_code=HTTP_404_NOT_FOUND,
+ )
+
+ # generate a semi-unique name for the experiment
+ experiment_name = payload.get("name") or _generate_experiment_name(dataset_name)
+ project_name = f"Experiment-{getrandbits(96).to_bytes(12, 'big').hex()}"
+ project_description = (
+ f"dataset_id: {dataset_globalid}\ndataset_version_id: {dataset_version_globalid}"
+ )
+ experiment = models.Experiment(
+ dataset_id=int(dataset_id),
+ dataset_version_id=int(dataset_version_id),
+ name=experiment_name,
+ description=payload.get("description"),
+ repetitions=repetitions,
+ metadata_=metadata,
+ project_name=project_name,
+ )
+ session.add(experiment)
+ await session.flush()
+
+ dialect = SupportedSQLDialect(session.bind.dialect.name)
+ project_rowid = await session.scalar(
+ insert_on_conflict(
+ dialect=dialect,
+ table=models.Project,
+ constraint="uq_projects_name",
+ column_names=("name",),
+ values=dict(
+ name=project_name,
+ description=project_description,
+ created_at=experiment.created_at,
+ updated_at=experiment.updated_at,
+ ),
+ ).returning(models.Project.id)
+ )
+ assert project_rowid is not None
+
+ experiment_globalid = GlobalID("Experiment", str(experiment.id))
+ if dataset_version_globalid_str is None:
+ dataset_version_globalid = GlobalID(
+ "DatasetVersion", str(experiment.dataset_version_id)
+ )
+ experiment_payload = {
+ "id": str(experiment_globalid),
+ "dataset_id": str(dataset_globalid),
+ "dataset_version_id": str(dataset_version_globalid),
+ "repetitions": experiment.repetitions,
+ "metadata": experiment.metadata_,
+ "project_name": experiment.project_name,
+ "created_at": experiment.created_at.isoformat(),
+ "updated_at": experiment.updated_at.isoformat(),
+ }
+ return JSONResponse(content={"data": experiment_payload})
+
+
+ async def read_experiment(request: Request) -> Response:
+ """
+ summary: Get details of a specific experiment
+ operationId: getExperiment
+ tags:
+ - private
+ parameters:
+ - in: path
+ name: experiment_id
+ required: true
+ description: The ID of the experiment to retrieve
+ schema:
+ type: string
+ responses:
+ 200:
+ description: Experiment retrieved successfully
+ content:
+ application/json:
+ schema:
+ type: object
+ properties:
+ data:
+ type: object
+ properties:
+ id:
+ type: string
+ description: The ID of the experiment
+ dataset_id:
+ type: string
+ description: The ID of the dataset associated with the experiment
+ dataset_version_id:
+ type: string
+ description: The ID of the dataset version associated with the experiment
+ repetitions:
+ type: integer
+ description: Number of times the experiment is repeated
+ metadata:
+ type: object
+ description: Metadata of the experiment
+ additionalProperties:
+ type: string
+ project_name:
+ type: string
+ description: The name of the project associated with the experiment
+ created_at:
+ type: string
+ format: date-time
+ description: The creation timestamp of the experiment
+ updated_at:
+ type: string
+ format: date-time
+ description: The last update timestamp of the experiment
+ 404:
+ description: Experiment not found
+ """
+ experiment_globalid = GlobalID.from_id(request.path_params["experiment_id"])
+ try:
+ experiment_id = from_global_id_with_expected_type(experiment_globalid, "Experiment")
+ except ValueError:
+ return Response(
+ content="Experiment with ID {experiment_globalid} does not exist",
+ status_code=HTTP_404_NOT_FOUND,
+ )
+
+ async with request.app.state.db() as session:
+ experiment = await session.execute(
+ select(models.Experiment).where(models.Experiment.id == experiment_id)
+ )
+ experiment = experiment.scalar()
+ if not experiment:
+ return Response(
+ content=f"Experiment with ID {experiment_globalid} does not exist",
+ status_code=HTTP_404_NOT_FOUND,
+ )
+
+ dataset_globalid = GlobalID("Dataset", str(experiment.dataset_id))
+ dataset_version_globalid = GlobalID("DatasetVersion", str(experiment.dataset_version_id))
+ experiment_payload = {
+ "id": str(experiment_globalid),
+ "dataset_id": str(dataset_globalid),
+ "dataset_version_id": str(dataset_version_globalid),
+ "repetitions": experiment.repetitions,
+ "metadata": experiment.metadata_,
+ "project_name": experiment.project_name,
+ "created_at": experiment.created_at.isoformat(),
+ "updated_at": experiment.updated_at.isoformat(),
+ }
+ return JSONResponse(content={"data": experiment_payload})
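
For orientation, here is a minimal client-side sketch of calling the new private create-experiment route. The URL path is an assumption (the route registration lives in phoenix/server/api/routers/v1/__init__.py, which is not shown in this hunk); the body fields (name, description, repetitions, metadata, version_id) and the {"data": ...} response envelope come from the handler above.

import requests

# Hypothetical sketch: the exact route path is assumed, not taken from this hunk.
PHOENIX_URL = "http://localhost:6006"
dataset_global_id = "RGF0YXNldDox"  # a relay GlobalID for a Dataset (illustrative value)

response = requests.post(
    f"{PHOENIX_URL}/v1/datasets/{dataset_global_id}/experiments",
    json={
        "name": "my-experiment",       # optional; a semi-unique name is generated otherwise
        "repetitions": 1,              # optional; defaults to 1
        "metadata": {"run_by": "ci"},  # optional experiment metadata
        # "version_id": "...",         # optional DatasetVersion GlobalID; latest version is used if omitted
    },
    timeout=30,
)
response.raise_for_status()
experiment = response.json()["data"]
print(experiment["id"], experiment["project_name"])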

phoenix/server/api/routers/v1/spans.py
@@ -19,9 +19,9 @@ async def query_spans_handler(request: Request) -> Response:
  summary: Query spans using query DSL
  operationId: querySpans
  tags:
- - spans
+ - private
  parameters:
- - name: project-name
+ - name: project_name
  in: query
  schema:
  type: string
@@ -68,6 +68,8 @@ async def query_spans_handler(request: Request) -> Response:
  responses:
  200:
  description: Success
+ 403:
+ description: Forbidden
  404:
  description: Not found
  422:
@@ -76,9 +78,11 @@ async def query_spans_handler(request: Request) -> Response:
  payload = await request.json()
  queries = payload.pop("queries", [])
  project_name = (
- request.query_params.get("project-name")
- # read from headers/payload for backward-compatibility
- or request.headers.get("project-name")
+ request.query_params.get("project_name")
+ or request.query_params.get("project-name") # for backward compatibility
+ or request.headers.get(
+ "project-name"
+ ) # read from headers/payload for backward-compatibility
  or payload.get("project_name")
  or DEFAULT_PROJECT_NAME
  )
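
The net effect of this hunk is that the span-query endpoint now prefers a project_name query parameter while still honoring the legacy project-name query parameter and header (and the project_name field in the JSON payload). A hedged sketch of a client call, assuming the handler is mounted at /v1/spans:

import requests

# Assumed mount point /v1/spans; "queries" carries the span-query DSL payload.
response = requests.post(
    "http://localhost:6006/v1/spans",
    params={"project_name": "default"},  # new spelling; "project-name" still works for older clients
    json={"queries": []},
    timeout=30,
)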

phoenix/server/api/routers/v1/traces.py
@@ -11,7 +11,6 @@ from starlette.datastructures import State
  from starlette.requests import Request
  from starlette.responses import Response
  from starlette.status import (
- HTTP_403_FORBIDDEN,
  HTTP_415_UNSUPPORTED_MEDIA_TYPE,
  HTTP_422_UNPROCESSABLE_ENTITY,
  )
@@ -25,7 +24,7 @@ async def post_traces(request: Request) -> Response:
  summary: Send traces to Phoenix
  operationId: addTraces
  tags:
- - traces
+ - private
  requestBody:
  required: true
  content:
@@ -43,8 +42,6 @@
  422:
  description: Request body is invalid
  """
- if request.app.state.read_only:
- return Response(status_code=HTTP_403_FORBIDDEN)
  content_type = request.headers.get("content-type")
  if content_type != "application/x-protobuf":
  return Response(

phoenix/server/api/schema.py
@@ -1,308 +1,7 @@
- from collections import defaultdict
- from typing import Dict, List, Optional, Set, Union
-
- import numpy as np
- import numpy.typing as npt
  import strawberry
- from sqlalchemy import delete, select
- from sqlalchemy.orm import contains_eager, load_only
- from strawberry import ID, UNSET
- from strawberry.types import Info
- from typing_extensions import Annotated
-
- from phoenix.config import DEFAULT_PROJECT_NAME
- from phoenix.db import models
- from phoenix.db.insertion.span import ClearProjectSpansEvent
- from phoenix.pointcloud.clustering import Hdbscan
- from phoenix.server.api.context import Context
- from phoenix.server.api.helpers import ensure_list
- from phoenix.server.api.input_types.ClusterInput import ClusterInput
- from phoenix.server.api.input_types.Coordinates import (
- InputCoordinate2D,
- InputCoordinate3D,
- )
- from phoenix.server.api.types.Cluster import Cluster, to_gql_clusters
- from phoenix.server.api.types.DatasetRole import AncillaryDatasetRole, DatasetRole
- from phoenix.server.api.types.Dimension import to_gql_dimension
- from phoenix.server.api.types.EmbeddingDimension import (
- DEFAULT_CLUSTER_SELECTION_EPSILON,
- DEFAULT_MIN_CLUSTER_SIZE,
- DEFAULT_MIN_SAMPLES,
- to_gql_embedding_dimension,
- )
- from phoenix.server.api.types.Event import create_event_id, unpack_event_id
- from phoenix.server.api.types.ExportEventsMutation import ExportEventsMutation
- from phoenix.server.api.types.Functionality import Functionality
- from phoenix.server.api.types.Model import Model
- from phoenix.server.api.types.node import (
- GlobalID,
- Node,
- from_global_id,
- from_global_id_with_expected_type,
- )
- from phoenix.server.api.types.pagination import (
- Connection,
- ConnectionArgs,
- CursorString,
- connection_from_list,
- )
- from phoenix.server.api.types.Project import Project
- from phoenix.server.api.types.Span import to_gql_span
- from phoenix.server.api.types.Trace import Trace
-
-
- @strawberry.type
- class Query:
- @strawberry.field
- async def projects(
- self,
- info: Info[Context, None],
- first: Optional[int] = 50,
- last: Optional[int] = UNSET,
- after: Optional[CursorString] = UNSET,
- before: Optional[CursorString] = UNSET,
- ) -> Connection[Project]:
- args = ConnectionArgs(
- first=first,
- after=after if isinstance(after, CursorString) else None,
- last=last,
- before=before if isinstance(before, CursorString) else None,
- )
- async with info.context.db() as session:
- projects = await session.scalars(select(models.Project))
- data = [
- Project(
- id_attr=project.id,
- name=project.name,
- gradient_start_color=project.gradient_start_color,
- gradient_end_color=project.gradient_end_color,
- )
- for project in projects
- ]
- return connection_from_list(data=data, args=args)
-
- @strawberry.field
- async def functionality(self, info: Info[Context, None]) -> "Functionality":
- has_model_inferences = not info.context.model.is_empty
- async with info.context.db() as session:
- has_traces = (await session.scalar(select(models.Trace).limit(1))) is not None
- return Functionality(
- model_inferences=has_model_inferences,
- tracing=has_traces,
- )
-
- @strawberry.field
- def model(self) -> Model:
- return Model()
-
- @strawberry.field
- async def node(self, id: GlobalID, info: Info[Context, None]) -> Node:
- type_name, node_id = from_global_id(str(id))
- if type_name == "Dimension":
- dimension = info.context.model.scalar_dimensions[node_id]
- return to_gql_dimension(node_id, dimension)
- elif type_name == "EmbeddingDimension":
- embedding_dimension = info.context.model.embedding_dimensions[node_id]
- return to_gql_embedding_dimension(node_id, embedding_dimension)
- elif type_name == "Project":
- project_stmt = select(
- models.Project.id,
- models.Project.name,
- models.Project.gradient_start_color,
- models.Project.gradient_end_color,
- ).where(models.Project.id == node_id)
- async with info.context.db() as session:
- project = (await session.execute(project_stmt)).first()
- if project is None:
- raise ValueError(f"Unknown project: {id}")
- return Project(
- id_attr=project.id,
- name=project.name,
- gradient_start_color=project.gradient_start_color,
- gradient_end_color=project.gradient_end_color,
- )
- elif type_name == "Trace":
- trace_stmt = select(models.Trace.id).where(models.Trace.id == node_id)
- async with info.context.db() as session:
- id_attr = await session.scalar(trace_stmt)
- if id_attr is None:
- raise ValueError(f"Unknown trace: {id}")
- return Trace(id_attr=id_attr)
- elif type_name == "Span":
- span_stmt = (
- select(models.Span)
- .join(models.Trace)
- .options(contains_eager(models.Span.trace))
- .where(models.Span.id == node_id)
- )
- async with info.context.db() as session:
- span = await session.scalar(span_stmt)
- if span is None:
- raise ValueError(f"Unknown span: {id}")
- return to_gql_span(span)
- raise Exception(f"Unknown node type: {type_name}")
-
- @strawberry.field
- def clusters(
- self,
- clusters: List[ClusterInput],
- ) -> List[Cluster]:
- clustered_events: Dict[str, Set[ID]] = defaultdict(set)
- for i, cluster in enumerate(clusters):
- clustered_events[cluster.id or str(i)].update(cluster.event_ids)
- return to_gql_clusters(
- clustered_events=clustered_events,
- )
-
- @strawberry.field
- def hdbscan_clustering(
- self,
- info: Info[Context, None],
- event_ids: Annotated[
- List[ID],
- strawberry.argument(
- description="Event ID of the coordinates",
- ),
- ],
- coordinates_2d: Annotated[
- Optional[List[InputCoordinate2D]],
- strawberry.argument(
- description="Point coordinates. Must be either 2D or 3D.",
- ),
- ] = UNSET,
- coordinates_3d: Annotated[
- Optional[List[InputCoordinate3D]],
- strawberry.argument(
- description="Point coordinates. Must be either 2D or 3D.",
- ),
- ] = UNSET,
- min_cluster_size: Annotated[
- int,
- strawberry.argument(
- description="HDBSCAN minimum cluster size",
- ),
- ] = DEFAULT_MIN_CLUSTER_SIZE,
- cluster_min_samples: Annotated[
- int,
- strawberry.argument(
- description="HDBSCAN minimum samples",
- ),
- ] = DEFAULT_MIN_SAMPLES,
- cluster_selection_epsilon: Annotated[
- float,
- strawberry.argument(
- description="HDBSCAN cluster selection epsilon",
- ),
- ] = DEFAULT_CLUSTER_SELECTION_EPSILON,
- ) -> List[Cluster]:
- coordinates_3d = ensure_list(coordinates_3d)
- coordinates_2d = ensure_list(coordinates_2d)
-
- if len(coordinates_3d) > 0 and len(coordinates_2d) > 0:
- raise ValueError("must specify only one of 2D or 3D coordinates")
-
- if len(coordinates_3d) > 0:
- coordinates = list(
- map(
- lambda coord: np.array(
- [coord.x, coord.y, coord.z],
- ),
- coordinates_3d,
- )
- )
- else:
- coordinates = list(
- map(
- lambda coord: np.array(
- [coord.x, coord.y],
- ),
- coordinates_2d,
- )
- )
-
- if len(event_ids) != len(coordinates):
- raise ValueError(
- f"length mismatch between "
- f"event_ids ({len(event_ids)}) "
- f"and coordinates ({len(coordinates)})"
- )
-
- if len(event_ids) == 0:
- return []
-
- grouped_event_ids: Dict[
- Union[DatasetRole, AncillaryDatasetRole],
- List[ID],
- ] = defaultdict(list)
- grouped_coordinates: Dict[
- Union[DatasetRole, AncillaryDatasetRole],
- List[npt.NDArray[np.float64]],
- ] = defaultdict(list)
-
- for event_id, coordinate in zip(event_ids, coordinates):
- row_id, dataset_role = unpack_event_id(event_id)
- grouped_coordinates[dataset_role].append(coordinate)
- grouped_event_ids[dataset_role].append(create_event_id(row_id, dataset_role))
-
- stacked_event_ids = (
- grouped_event_ids[DatasetRole.primary]
- + grouped_event_ids[DatasetRole.reference]
- + grouped_event_ids[AncillaryDatasetRole.corpus]
- )
- stacked_coordinates = np.stack(
- grouped_coordinates[DatasetRole.primary]
- + grouped_coordinates[DatasetRole.reference]
- + grouped_coordinates[AncillaryDatasetRole.corpus]
- )
-
- clusters = Hdbscan(
- min_cluster_size=min_cluster_size,
- min_samples=cluster_min_samples,
- cluster_selection_epsilon=cluster_selection_epsilon,
- ).find_clusters(stacked_coordinates)
-
- clustered_events = {
- str(i): {stacked_event_ids[row_idx] for row_idx in cluster}
- for i, cluster in enumerate(clusters)
- }
-
- return to_gql_clusters(
- clustered_events=clustered_events,
- )
-
-
- @strawberry.type
- class Mutation(ExportEventsMutation):
- @strawberry.mutation
- async def delete_project(self, info: Info[Context, None], id: GlobalID) -> Query:
- if info.context.read_only:
- return Query()
- node_id = from_global_id_with_expected_type(str(id), "Project")
- async with info.context.db() as session:
- project = await session.scalar(
- select(models.Project)
- .where(models.Project.id == node_id)
- .options(load_only(models.Project.name))
- )
- if project is None:
- raise ValueError(f"Unknown project: {id}")
- if project.name == DEFAULT_PROJECT_NAME:
- raise ValueError(f"Cannot delete the {DEFAULT_PROJECT_NAME} project")
- await session.delete(project)
- return Query()
-
- @strawberry.mutation
- async def clear_project(self, info: Info[Context, None], id: GlobalID) -> Query:
- if info.context.read_only:
- return Query()
- project_id = from_global_id_with_expected_type(str(id), "Project")
- delete_statement = delete(models.Trace).where(models.Trace.project_rowid == project_id)
- async with info.context.db() as session:
- await session.execute(delete_statement)
- if cache := info.context.cache_for_dataloaders:
- cache.invalidate(ClearProjectSpansEvent(project_rowid=project_id))
- return Query()

+ from phoenix.server.api.mutations import Mutation
+ from phoenix.server.api.queries import Query

  # This is the schema for generating `schema.graphql`.
  # See https://strawberry.rocks/docs/guides/schema-export
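
With Query and Mutation relocated to phoenix.server.api.queries and phoenix.server.api.mutations, schema.py shrinks to the imports above plus the schema object used for exporting schema.graphql. A minimal sketch of that assembly (any extensions or extra configuration the real module may pass are omitted here as an assumption):

import strawberry

from phoenix.server.api.mutations import Mutation
from phoenix.server.api.queries import Query

# Used to generate `schema.graphql` via Strawberry's schema-export tooling.
schema = strawberry.Schema(query=Query, mutation=Mutation)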

phoenix/server/api/types/AnnotatorKind.py (new file)
@@ -0,0 +1,10 @@
+ from enum import Enum
+
+ import strawberry
+
+
+ @strawberry.enum
+ class AnnotatorKind(Enum):
+ LLM = "LLM"
+ HUMAN = "HUMAN"
+ CODE = "CODE"
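
AnnotatorKind is a plain Python Enum registered with Strawberry, so it can be referenced directly as a GraphQL field type by the new experiment annotation types. A small illustrative sketch (the type below is hypothetical, not taken from this diff):

import strawberry

from phoenix.server.api.types.AnnotatorKind import AnnotatorKind


@strawberry.type
class ExampleAnnotation:  # hypothetical GraphQL type, for illustration only
    name: str
    annotator_kind: AnnotatorKind  # serialized as one of LLM, HUMAN, CODE


annotation = ExampleAnnotation(name="correctness", annotator_kind=AnnotatorKind.LLM)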