arize-phoenix 4.4.3__py3-none-any.whl → 4.4.4rc1__py3-none-any.whl

This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.

Potentially problematic release: this version of arize-phoenix might be problematic.

Files changed (112)
  1. {arize_phoenix-4.4.3.dist-info → arize_phoenix-4.4.4rc1.dist-info}/METADATA +4 -4
  2. {arize_phoenix-4.4.3.dist-info → arize_phoenix-4.4.4rc1.dist-info}/RECORD +111 -55
  3. {arize_phoenix-4.4.3.dist-info → arize_phoenix-4.4.4rc1.dist-info}/WHEEL +1 -1
  4. phoenix/__init__.py +0 -27
  5. phoenix/config.py +21 -7
  6. phoenix/core/model.py +25 -25
  7. phoenix/core/model_schema.py +64 -62
  8. phoenix/core/model_schema_adapter.py +27 -25
  9. phoenix/datasets/__init__.py +0 -0
  10. phoenix/datasets/evaluators.py +275 -0
  11. phoenix/datasets/experiments.py +469 -0
  12. phoenix/datasets/tracing.py +66 -0
  13. phoenix/datasets/types.py +212 -0
  14. phoenix/db/bulk_inserter.py +54 -14
  15. phoenix/db/insertion/dataset.py +234 -0
  16. phoenix/db/insertion/evaluation.py +6 -6
  17. phoenix/db/insertion/helpers.py +13 -2
  18. phoenix/db/migrations/types.py +29 -0
  19. phoenix/db/migrations/versions/10460e46d750_datasets.py +291 -0
  20. phoenix/db/migrations/versions/cf03bd6bae1d_init.py +2 -28
  21. phoenix/db/models.py +230 -3
  22. phoenix/inferences/fixtures.py +23 -23
  23. phoenix/inferences/inferences.py +7 -7
  24. phoenix/inferences/validation.py +1 -1
  25. phoenix/server/api/context.py +16 -0
  26. phoenix/server/api/dataloaders/__init__.py +16 -0
  27. phoenix/server/api/dataloaders/dataset_example_revisions.py +100 -0
  28. phoenix/server/api/dataloaders/dataset_example_spans.py +43 -0
  29. phoenix/server/api/dataloaders/experiment_annotation_summaries.py +85 -0
  30. phoenix/server/api/dataloaders/experiment_error_rates.py +43 -0
  31. phoenix/server/api/dataloaders/experiment_sequence_number.py +49 -0
  32. phoenix/server/api/dataloaders/project_by_name.py +31 -0
  33. phoenix/server/api/dataloaders/span_descendants.py +2 -3
  34. phoenix/server/api/dataloaders/span_projects.py +33 -0
  35. phoenix/server/api/dataloaders/trace_row_ids.py +39 -0
  36. phoenix/server/api/helpers/dataset_helpers.py +178 -0
  37. phoenix/server/api/input_types/AddExamplesToDatasetInput.py +16 -0
  38. phoenix/server/api/input_types/AddSpansToDatasetInput.py +14 -0
  39. phoenix/server/api/input_types/CreateDatasetInput.py +12 -0
  40. phoenix/server/api/input_types/DatasetExampleInput.py +14 -0
  41. phoenix/server/api/input_types/DatasetSort.py +17 -0
  42. phoenix/server/api/input_types/DatasetVersionSort.py +16 -0
  43. phoenix/server/api/input_types/DeleteDatasetExamplesInput.py +13 -0
  44. phoenix/server/api/input_types/DeleteDatasetInput.py +7 -0
  45. phoenix/server/api/input_types/DeleteExperimentsInput.py +9 -0
  46. phoenix/server/api/input_types/PatchDatasetExamplesInput.py +35 -0
  47. phoenix/server/api/input_types/PatchDatasetInput.py +14 -0
  48. phoenix/server/api/mutations/__init__.py +13 -0
  49. phoenix/server/api/mutations/auth.py +11 -0
  50. phoenix/server/api/mutations/dataset_mutations.py +520 -0
  51. phoenix/server/api/mutations/experiment_mutations.py +65 -0
  52. phoenix/server/api/{types/ExportEventsMutation.py → mutations/export_events_mutations.py} +17 -14
  53. phoenix/server/api/mutations/project_mutations.py +42 -0
  54. phoenix/server/api/openapi/__init__.py +0 -0
  55. phoenix/server/api/openapi/main.py +6 -0
  56. phoenix/server/api/openapi/schema.py +15 -0
  57. phoenix/server/api/queries.py +503 -0
  58. phoenix/server/api/routers/v1/__init__.py +77 -2
  59. phoenix/server/api/routers/v1/dataset_examples.py +178 -0
  60. phoenix/server/api/routers/v1/datasets.py +861 -0
  61. phoenix/server/api/routers/v1/evaluations.py +4 -2
  62. phoenix/server/api/routers/v1/experiment_evaluations.py +65 -0
  63. phoenix/server/api/routers/v1/experiment_runs.py +108 -0
  64. phoenix/server/api/routers/v1/experiments.py +174 -0
  65. phoenix/server/api/routers/v1/spans.py +3 -1
  66. phoenix/server/api/routers/v1/traces.py +1 -4
  67. phoenix/server/api/schema.py +2 -303
  68. phoenix/server/api/types/AnnotatorKind.py +10 -0
  69. phoenix/server/api/types/Cluster.py +19 -19
  70. phoenix/server/api/types/CreateDatasetPayload.py +8 -0
  71. phoenix/server/api/types/Dataset.py +282 -63
  72. phoenix/server/api/types/DatasetExample.py +85 -0
  73. phoenix/server/api/types/DatasetExampleRevision.py +34 -0
  74. phoenix/server/api/types/DatasetVersion.py +14 -0
  75. phoenix/server/api/types/Dimension.py +30 -29
  76. phoenix/server/api/types/EmbeddingDimension.py +40 -34
  77. phoenix/server/api/types/Event.py +16 -16
  78. phoenix/server/api/types/ExampleRevisionInterface.py +14 -0
  79. phoenix/server/api/types/Experiment.py +135 -0
  80. phoenix/server/api/types/ExperimentAnnotationSummary.py +13 -0
  81. phoenix/server/api/types/ExperimentComparison.py +19 -0
  82. phoenix/server/api/types/ExperimentRun.py +91 -0
  83. phoenix/server/api/types/ExperimentRunAnnotation.py +57 -0
  84. phoenix/server/api/types/Inferences.py +80 -0
  85. phoenix/server/api/types/InferencesRole.py +23 -0
  86. phoenix/server/api/types/Model.py +43 -42
  87. phoenix/server/api/types/Project.py +26 -12
  88. phoenix/server/api/types/Span.py +78 -2
  89. phoenix/server/api/types/TimeSeries.py +6 -6
  90. phoenix/server/api/types/Trace.py +15 -4
  91. phoenix/server/api/types/UMAPPoints.py +1 -1
  92. phoenix/server/api/types/node.py +5 -111
  93. phoenix/server/api/types/pagination.py +10 -52
  94. phoenix/server/app.py +99 -49
  95. phoenix/server/main.py +49 -27
  96. phoenix/server/openapi/docs.py +3 -0
  97. phoenix/server/static/index.js +2246 -1368
  98. phoenix/server/templates/index.html +1 -0
  99. phoenix/services.py +15 -15
  100. phoenix/session/client.py +316 -21
  101. phoenix/session/session.py +47 -37
  102. phoenix/trace/exporter.py +14 -9
  103. phoenix/trace/fixtures.py +133 -7
  104. phoenix/trace/span_evaluations.py +3 -3
  105. phoenix/trace/trace_dataset.py +6 -6
  106. phoenix/utilities/json.py +61 -0
  107. phoenix/utilities/re.py +50 -0
  108. phoenix/version.py +1 -1
  109. phoenix/server/api/types/DatasetRole.py +0 -23
  110. {arize_phoenix-4.4.3.dist-info → arize_phoenix-4.4.4rc1.dist-info}/licenses/IP_NOTICE +0 -0
  111. {arize_phoenix-4.4.3.dist-info → arize_phoenix-4.4.4rc1.dist-info}/licenses/LICENSE +0 -0
  112. phoenix/server/api/{helpers.py → helpers/__init__.py} +0 -0
phoenix/server/api/mutations/project_mutations.py
@@ -0,0 +1,42 @@
+ import strawberry
+ from sqlalchemy import delete, select
+ from sqlalchemy.orm import load_only
+ from strawberry.relay import GlobalID
+ from strawberry.types import Info
+
+ from phoenix.config import DEFAULT_PROJECT_NAME
+ from phoenix.db import models
+ from phoenix.db.insertion.span import ClearProjectSpansEvent
+ from phoenix.server.api.context import Context
+ from phoenix.server.api.mutations.auth import IsAuthenticated
+ from phoenix.server.api.queries import Query
+ from phoenix.server.api.types.node import from_global_id_with_expected_type
+
+
+ @strawberry.type
+ class ProjectMutationMixin:
+     @strawberry.mutation(permission_classes=[IsAuthenticated])  # type: ignore
+     async def delete_project(self, info: Info[Context, None], id: GlobalID) -> Query:
+         node_id = from_global_id_with_expected_type(global_id=id, expected_type_name="Project")
+         async with info.context.db() as session:
+             project = await session.scalar(
+                 select(models.Project)
+                 .where(models.Project.id == node_id)
+                 .options(load_only(models.Project.name))
+             )
+             if project is None:
+                 raise ValueError(f"Unknown project: {id}")
+             if project.name == DEFAULT_PROJECT_NAME:
+                 raise ValueError(f"Cannot delete the {DEFAULT_PROJECT_NAME} project")
+             await session.delete(project)
+         return Query()
+
+     @strawberry.mutation(permission_classes=[IsAuthenticated])  # type: ignore
+     async def clear_project(self, info: Info[Context, None], id: GlobalID) -> Query:
+         project_id = from_global_id_with_expected_type(global_id=id, expected_type_name="Project")
+         delete_statement = delete(models.Trace).where(models.Trace.project_rowid == project_id)
+         async with info.context.db() as session:
+             await session.execute(delete_statement)
+         if cache := info.context.cache_for_dataloaders:
+             cache.invalidate(ClearProjectSpansEvent(project_rowid=project_id))
+         return Query()
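For orientation, the new mutations can be driven like any other Strawberry GraphQL field. The sketch below is illustrative and not part of the diff: it assumes a Phoenix server on localhost:6006 exposing GraphQL at /graphql, Strawberry's default camelCase field naming, and a made-up relay ID.

# Hypothetical client-side sketch of the clearProject mutation added above.
import requests

CLEAR_PROJECT = """
mutation ($id: GlobalID!) {
  clearProject(id: $id) {
    __typename
  }
}
"""

response = requests.post(
    "http://localhost:6006/graphql",  # assumed endpoint, not shown in this diff
    json={"query": CLEAR_PROJECT, "variables": {"id": "UHJvamVjdDox"}},  # "Project:1", base64
)
response.raise_for_status()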
phoenix/server/api/openapi/main.py
@@ -0,0 +1,6 @@
+ from .schema import get_openapi_schema
+
+ if __name__ == "__main__":
+     import yaml  # type: ignore
+
+     print(yaml.dump(get_openapi_schema(), indent=2))
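As written, executing this module directly (e.g. python -m phoenix.server.api.openapi.main > schema.yaml) would print the schema as YAML, assuming PyYAML is importable; the redirect target is arbitrary.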
phoenix/server/api/openapi/schema.py
@@ -0,0 +1,15 @@
+ from typing import Any
+
+ from phoenix.server.api.routers.v1 import V1_ROUTES
+ from starlette.schemas import SchemaGenerator
+
+ OPENAPI_SCHEMA_GENERATOR = SchemaGenerator(
+     {"openapi": "3.0.0", "info": {"title": "Arize-Phoenix API", "version": "1.0"}}
+ )
+
+
+ def get_openapi_schema() -> Any:
+     """
+     Exports an OpenAPI schema for the Phoenix REST API as a JSON object.
+     """
+     return OPENAPI_SCHEMA_GENERATOR.get_schema(V1_ROUTES)  # type: ignore
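Because get_openapi_schema returns a plain JSON-compatible object, the schema can also be exported programmatically. A minimal sketch, assuming the package is importable (the output filename is arbitrary):

# Dump the v1 REST schema to disk using the helper added above.
import json

from phoenix.server.api.openapi.schema import get_openapi_schema

with open("openapi.json", "w") as f:
    json.dump(get_openapi_schema(), f, indent=2)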
phoenix/server/api/queries.py
@@ -0,0 +1,503 @@
+ from collections import defaultdict
+ from typing import DefaultDict, Dict, List, Optional, Set, Union
+
+ import numpy as np
+ import numpy.typing as npt
+ import strawberry
+ from sqlalchemy import and_, distinct, func, select
+ from sqlalchemy.orm import joinedload
+ from strawberry import ID, UNSET
+ from strawberry.relay import Connection, GlobalID, Node
+ from strawberry.types import Info
+ from typing_extensions import Annotated, TypeAlias
+
+ from phoenix.db import models
+ from phoenix.db.models import (
+     DatasetExample as OrmExample,
+ )
+ from phoenix.db.models import (
+     DatasetExampleRevision as OrmRevision,
+ )
+ from phoenix.db.models import (
+     DatasetVersion as OrmVersion,
+ )
+ from phoenix.db.models import (
+     Experiment as OrmExperiment,
+ )
+ from phoenix.db.models import (
+     ExperimentRun as OrmRun,
+ )
+ from phoenix.db.models import (
+     Trace as OrmTrace,
+ )
+ from phoenix.pointcloud.clustering import Hdbscan
+ from phoenix.server.api.context import Context
+ from phoenix.server.api.helpers import ensure_list
+ from phoenix.server.api.input_types.ClusterInput import ClusterInput
+ from phoenix.server.api.input_types.Coordinates import (
+     InputCoordinate2D,
+     InputCoordinate3D,
+ )
+ from phoenix.server.api.input_types.DatasetSort import DatasetSort
+ from phoenix.server.api.types.Cluster import Cluster, to_gql_clusters
+ from phoenix.server.api.types.Dataset import Dataset, to_gql_dataset
+ from phoenix.server.api.types.DatasetExample import DatasetExample
+ from phoenix.server.api.types.Dimension import to_gql_dimension
+ from phoenix.server.api.types.EmbeddingDimension import (
+     DEFAULT_CLUSTER_SELECTION_EPSILON,
+     DEFAULT_MIN_CLUSTER_SIZE,
+     DEFAULT_MIN_SAMPLES,
+     to_gql_embedding_dimension,
+ )
+ from phoenix.server.api.types.Event import create_event_id, unpack_event_id
+ from phoenix.server.api.types.Experiment import Experiment
+ from phoenix.server.api.types.ExperimentComparison import ExperimentComparison, RunComparisonItem
+ from phoenix.server.api.types.ExperimentRun import ExperimentRun, to_gql_experiment_run
+ from phoenix.server.api.types.Functionality import Functionality
+ from phoenix.server.api.types.InferencesRole import AncillaryInferencesRole, InferencesRole
+ from phoenix.server.api.types.Model import Model
+ from phoenix.server.api.types.node import from_global_id, from_global_id_with_expected_type
+ from phoenix.server.api.types.pagination import (
+     ConnectionArgs,
+     CursorString,
+     connection_from_list,
+ )
+ from phoenix.server.api.types.Project import Project
+ from phoenix.server.api.types.SortDir import SortDir
+ from phoenix.server.api.types.Span import Span, to_gql_span
+ from phoenix.server.api.types.Trace import Trace
+
+
+ @strawberry.type
+ class Query:
+     @strawberry.field
+     async def projects(
+         self,
+         info: Info[Context, None],
+         first: Optional[int] = 50,
+         last: Optional[int] = UNSET,
+         after: Optional[CursorString] = UNSET,
+         before: Optional[CursorString] = UNSET,
+     ) -> Connection[Project]:
+         args = ConnectionArgs(
+             first=first,
+             after=after if isinstance(after, CursorString) else None,
+             last=last,
+             before=before if isinstance(before, CursorString) else None,
+         )
+         stmt = (
+             select(models.Project)
+             .outerjoin(
+                 models.Experiment,
+                 models.Project.name == models.Experiment.project_name,
+             )
+             .where(models.Experiment.project_name.is_(None))
+         )
+         async with info.context.db() as session:
+             projects = await session.stream_scalars(stmt)
+             data = [
+                 Project(
+                     id_attr=project.id,
+                     name=project.name,
+                     gradient_start_color=project.gradient_start_color,
+                     gradient_end_color=project.gradient_end_color,
+                 )
+                 async for project in projects
+             ]
+         return connection_from_list(data=data, args=args)
+
+     @strawberry.field
+     async def datasets(
+         self,
+         info: Info[Context, None],
+         first: Optional[int] = 50,
+         last: Optional[int] = UNSET,
+         after: Optional[CursorString] = UNSET,
+         before: Optional[CursorString] = UNSET,
+         sort: Optional[DatasetSort] = UNSET,
+     ) -> Connection[Dataset]:
+         args = ConnectionArgs(
+             first=first,
+             after=after if isinstance(after, CursorString) else None,
+             last=last,
+             before=before if isinstance(before, CursorString) else None,
+         )
+         stmt = select(models.Dataset)
+         if sort:
+             sort_col = getattr(models.Dataset, sort.col.value)
+             stmt = stmt.order_by(sort_col.desc() if sort.dir is SortDir.desc else sort_col.asc())
+         async with info.context.db() as session:
+             datasets = await session.scalars(stmt)
+             return connection_from_list(
+                 data=[to_gql_dataset(dataset) for dataset in datasets], args=args
+             )
+
+     @strawberry.field
+     async def compare_experiments(
+         self,
+         info: Info[Context, None],
+         experiment_ids: List[GlobalID],
+     ) -> List[ExperimentComparison]:
+         experiment_ids_ = [
+             from_global_id_with_expected_type(experiment_id, OrmExperiment.__name__)
+             for experiment_id in experiment_ids
+         ]
+         if len(set(experiment_ids_)) != len(experiment_ids_):
+             raise ValueError("Experiment IDs must be unique.")
+
+         async with info.context.db() as session:
+             validation_result = (
+                 await session.execute(
+                     select(
+                         func.count(distinct(OrmVersion.dataset_id)),
+                         func.max(OrmVersion.dataset_id),
+                         func.max(OrmVersion.id),
+                         func.count(OrmExperiment.id),
+                     )
+                     .select_from(OrmVersion)
+                     .join(
+                         OrmExperiment,
+                         OrmExperiment.dataset_version_id == OrmVersion.id,
+                     )
+                     .where(
+                         OrmExperiment.id.in_(experiment_ids_),
+                     )
+                 )
+             ).first()
+             if validation_result is None:
+                 raise ValueError("No experiments could be found for input IDs.")
+
+             num_datasets, dataset_id, version_id, num_resolved_experiment_ids = validation_result
+             if num_datasets != 1:
+                 raise ValueError("Experiments must belong to the same dataset.")
+             if num_resolved_experiment_ids != len(experiment_ids_):
+                 raise ValueError("Unable to resolve one or more experiment IDs.")
+
+             revision_ids = (
+                 select(func.max(OrmRevision.id))
+                 .join(OrmExample, OrmExample.id == OrmRevision.dataset_example_id)
+                 .where(
+                     and_(
+                         OrmRevision.dataset_version_id <= version_id,
+                         OrmExample.dataset_id == dataset_id,
+                     )
+                 )
+                 .group_by(OrmRevision.dataset_example_id)
+                 .scalar_subquery()
+             )
+             examples = (
+                 await session.scalars(
+                     select(OrmExample)
+                     .join(OrmRevision, OrmExample.id == OrmRevision.dataset_example_id)
+                     .where(
+                         and_(
+                             OrmRevision.id.in_(revision_ids),
+                             OrmRevision.revision_kind != "DELETE",
+                         )
+                     )
+                     .order_by(OrmRevision.dataset_example_id.desc())
+                 )
+             ).all()
+
+             ExampleID: TypeAlias = int
+             ExperimentID: TypeAlias = int
+             runs: DefaultDict[ExampleID, DefaultDict[ExperimentID, List[OrmRun]]] = defaultdict(
+                 lambda: defaultdict(list)
+             )
+             async for run in await session.stream_scalars(
+                 select(OrmRun)
+                 .where(
+                     and_(
+                         OrmRun.dataset_example_id.in_(example.id for example in examples),
+                         OrmRun.experiment_id.in_(experiment_ids_),
+                     )
+                 )
+                 .options(joinedload(OrmRun.trace).load_only(OrmTrace.trace_id))
+             ):
+                 runs[run.dataset_example_id][run.experiment_id].append(run)
+
+         experiment_comparisons = []
+         for example in examples:
+             run_comparison_items = []
+             for experiment_id in experiment_ids_:
+                 run_comparison_items.append(
+                     RunComparisonItem(
+                         experiment_id=GlobalID(Experiment.__name__, str(experiment_id)),
+                         runs=[
+                             to_gql_experiment_run(run)
+                             for run in sorted(
+                                 runs[example.id][experiment_id], key=lambda run: run.id
+                             )
+                         ],
+                     )
+                 )
+             experiment_comparisons.append(
+                 ExperimentComparison(
+                     example=DatasetExample(
+                         id_attr=example.id,
+                         created_at=example.created_at,
+                         version_id=version_id,
+                     ),
+                     run_comparison_items=run_comparison_items,
+                 )
+             )
+         return experiment_comparisons
+
+     @strawberry.field
+     async def functionality(self, info: Info[Context, None]) -> "Functionality":
+         has_model_inferences = not info.context.model.is_empty
+         async with info.context.db() as session:
+             has_traces = (await session.scalar(select(models.Trace).limit(1))) is not None
+         return Functionality(
+             model_inferences=has_model_inferences,
+             tracing=has_traces,
+         )
+
+     @strawberry.field
+     def model(self) -> Model:
+         return Model()
+
+     @strawberry.field
+     async def node(self, id: GlobalID, info: Info[Context, None]) -> Node:
+         type_name, node_id = from_global_id(id)
+         if type_name == "Dimension":
+             dimension = info.context.model.scalar_dimensions[node_id]
+             return to_gql_dimension(node_id, dimension)
+         elif type_name == "EmbeddingDimension":
+             embedding_dimension = info.context.model.embedding_dimensions[node_id]
+             return to_gql_embedding_dimension(node_id, embedding_dimension)
+         elif type_name == "Project":
+             project_stmt = select(
+                 models.Project.id,
+                 models.Project.name,
+                 models.Project.gradient_start_color,
+                 models.Project.gradient_end_color,
+             ).where(models.Project.id == node_id)
+             async with info.context.db() as session:
+                 project = (await session.execute(project_stmt)).first()
+             if project is None:
+                 raise ValueError(f"Unknown project: {id}")
+             return Project(
+                 id_attr=project.id,
+                 name=project.name,
+                 gradient_start_color=project.gradient_start_color,
+                 gradient_end_color=project.gradient_end_color,
+             )
+         elif type_name == "Trace":
+             trace_stmt = select(
+                 models.Trace.id,
+                 models.Trace.project_rowid,
+             ).where(models.Trace.id == node_id)
+             async with info.context.db() as session:
+                 trace = (await session.execute(trace_stmt)).first()
+             if trace is None:
+                 raise ValueError(f"Unknown trace: {id}")
+             return Trace(
+                 id_attr=trace.id, trace_id=trace.trace_id, project_rowid=trace.project_rowid
+             )
+         elif type_name == Span.__name__:
+             span_stmt = (
+                 select(models.Span)
+                 .options(
+                     joinedload(models.Span.trace, innerjoin=True).load_only(models.Trace.trace_id)
+                 )
+                 .where(models.Span.id == node_id)
+             )
+             async with info.context.db() as session:
+                 span = await session.scalar(span_stmt)
+             if span is None:
+                 raise ValueError(f"Unknown span: {id}")
+             return to_gql_span(span)
+         elif type_name == Dataset.__name__:
+             dataset_stmt = select(models.Dataset).where(models.Dataset.id == node_id)
+             async with info.context.db() as session:
+                 if (dataset := await session.scalar(dataset_stmt)) is None:
+                     raise ValueError(f"Unknown dataset: {id}")
+             return to_gql_dataset(dataset)
+         elif type_name == DatasetExample.__name__:
+             example_id = node_id
+             latest_revision_id = (
+                 select(func.max(models.DatasetExampleRevision.id))
+                 .where(models.DatasetExampleRevision.dataset_example_id == example_id)
+                 .scalar_subquery()
+             )
+             async with info.context.db() as session:
+                 example = await session.scalar(
+                     select(models.DatasetExample)
+                     .join(
+                         models.DatasetExampleRevision,
+                         onclause=models.DatasetExampleRevision.dataset_example_id
+                         == models.DatasetExample.id,
+                     )
+                     .where(
+                         and_(
+                             models.DatasetExample.id == example_id,
+                             models.DatasetExampleRevision.id == latest_revision_id,
+                             models.DatasetExampleRevision.revision_kind != "DELETE",
+                         )
+                     )
+                 )
+             if not example:
+                 raise ValueError(f"Unknown dataset example: {id}")
+             return DatasetExample(
+                 id_attr=example.id,
+                 created_at=example.created_at,
+             )
+         elif type_name == Experiment.__name__:
+             async with info.context.db() as session:
+                 experiment = await session.scalar(
+                     select(models.Experiment).where(models.Experiment.id == node_id)
+                 )
+             if not experiment:
+                 raise ValueError(f"Unknown experiment: {id}")
+             return Experiment(
+                 id_attr=experiment.id,
+                 name=experiment.name,
+                 project_name=experiment.project_name,
+                 description=experiment.description,
+                 created_at=experiment.created_at,
+                 updated_at=experiment.updated_at,
+                 metadata=experiment.metadata_,
+             )
+         elif type_name == ExperimentRun.__name__:
+             async with info.context.db() as session:
+                 if not (
+                     run := await session.scalar(
+                         select(models.ExperimentRun)
+                         .where(models.ExperimentRun.id == node_id)
+                         .options(
+                             joinedload(models.ExperimentRun.trace).load_only(models.Trace.trace_id)
+                         )
+                     )
+                 ):
+                     raise ValueError(f"Unknown experiment run: {id}")
+             return to_gql_experiment_run(run)
+         raise Exception(f"Unknown node type: {type_name}")
+
+     @strawberry.field
+     def clusters(
+         self,
+         clusters: List[ClusterInput],
+     ) -> List[Cluster]:
+         clustered_events: Dict[str, Set[ID]] = defaultdict(set)
+         for i, cluster in enumerate(clusters):
+             clustered_events[cluster.id or str(i)].update(cluster.event_ids)
+         return to_gql_clusters(
+             clustered_events=clustered_events,
+         )
+
+     @strawberry.field
+     def hdbscan_clustering(
+         self,
+         info: Info[Context, None],
+         event_ids: Annotated[
+             List[ID],
+             strawberry.argument(
+                 description="Event ID of the coordinates",
+             ),
+         ],
+         coordinates_2d: Annotated[
+             Optional[List[InputCoordinate2D]],
+             strawberry.argument(
+                 description="Point coordinates. Must be either 2D or 3D.",
+             ),
+         ] = UNSET,
+         coordinates_3d: Annotated[
+             Optional[List[InputCoordinate3D]],
+             strawberry.argument(
+                 description="Point coordinates. Must be either 2D or 3D.",
+             ),
+         ] = UNSET,
+         min_cluster_size: Annotated[
+             int,
+             strawberry.argument(
+                 description="HDBSCAN minimum cluster size",
+             ),
+         ] = DEFAULT_MIN_CLUSTER_SIZE,
+         cluster_min_samples: Annotated[
+             int,
+             strawberry.argument(
+                 description="HDBSCAN minimum samples",
+             ),
+         ] = DEFAULT_MIN_SAMPLES,
+         cluster_selection_epsilon: Annotated[
+             float,
+             strawberry.argument(
+                 description="HDBSCAN cluster selection epsilon",
+             ),
+         ] = DEFAULT_CLUSTER_SELECTION_EPSILON,
+     ) -> List[Cluster]:
+         coordinates_3d = ensure_list(coordinates_3d)
+         coordinates_2d = ensure_list(coordinates_2d)
+
+         if len(coordinates_3d) > 0 and len(coordinates_2d) > 0:
+             raise ValueError("must specify only one of 2D or 3D coordinates")
+
+         if len(coordinates_3d) > 0:
+             coordinates = list(
+                 map(
+                     lambda coord: np.array(
+                         [coord.x, coord.y, coord.z],
+                     ),
+                     coordinates_3d,
+                 )
+             )
+         else:
+             coordinates = list(
+                 map(
+                     lambda coord: np.array(
+                         [coord.x, coord.y],
+                     ),
+                     coordinates_2d,
+                 )
+             )
+
+         if len(event_ids) != len(coordinates):
+             raise ValueError(
+                 f"length mismatch between "
+                 f"event_ids ({len(event_ids)}) "
+                 f"and coordinates ({len(coordinates)})"
+             )
+
+         if len(event_ids) == 0:
+             return []
+
+         grouped_event_ids: Dict[
+             Union[InferencesRole, AncillaryInferencesRole],
+             List[ID],
+         ] = defaultdict(list)
+         grouped_coordinates: Dict[
+             Union[InferencesRole, AncillaryInferencesRole],
+             List[npt.NDArray[np.float64]],
+         ] = defaultdict(list)
+
+         for event_id, coordinate in zip(event_ids, coordinates):
+             row_id, inferences_role = unpack_event_id(event_id)
+             grouped_coordinates[inferences_role].append(coordinate)
+             grouped_event_ids[inferences_role].append(create_event_id(row_id, inferences_role))
+
+         stacked_event_ids = (
+             grouped_event_ids[InferencesRole.primary]
+             + grouped_event_ids[InferencesRole.reference]
+             + grouped_event_ids[AncillaryInferencesRole.corpus]
+         )
+         stacked_coordinates = np.stack(
+             grouped_coordinates[InferencesRole.primary]
+             + grouped_coordinates[InferencesRole.reference]
+             + grouped_coordinates[AncillaryInferencesRole.corpus]
+         )
+
+         clusters = Hdbscan(
+             min_cluster_size=min_cluster_size,
+             min_samples=cluster_min_samples,
+             cluster_selection_epsilon=cluster_selection_epsilon,
+         ).find_clusters(stacked_coordinates)
+
+         clustered_events = {
+             str(i): {stacked_event_ids[row_idx] for row_idx in cluster}
+             for i, cluster in enumerate(clusters)
+         }
+
+         return to_gql_clusters(
+             clustered_events=clustered_events,
+         )
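The compare_experiments resolver above lines up, per dataset example, the runs of each requested experiment, after validating that all experiments belong to one dataset. A client-side sketch, again assuming a /graphql endpoint, Strawberry's default camelCase naming, and illustrative relay IDs (none of which are fixed by this diff):

import requests

COMPARE_EXPERIMENTS = """
query ($ids: [GlobalID!]!) {
  compareExperiments(experimentIds: $ids) {
    example { id }
    runComparisonItems { experimentId }
  }
}
"""

response = requests.post(
    "http://localhost:6006/graphql",  # assumed endpoint
    json={
        "query": COMPARE_EXPERIMENTS,
        # Both IDs must resolve to experiments on the same dataset,
        # per the validation logic in the resolver.
        "variables": {"ids": ["RXhwZXJpbWVudDox", "RXhwZXJpbWVudDoy"]},
    },
)
response.raise_for_status()
print(response.json())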
phoenix/server/api/routers/v1/__init__.py
@@ -1,6 +1,40 @@
- from starlette.routing import Route
+ from typing import Any, Awaitable, Callable, Mapping, Tuple
+
+ import wrapt
+ from starlette import routing
+ from starlette.requests import Request
+ from starlette.responses import Response
+ from starlette.status import HTTP_403_FORBIDDEN
+
+ from . import (
+     datasets,
+     evaluations,
+     experiment_evaluations,
+     experiment_runs,
+     experiments,
+     spans,
+     traces,
+ )
+ from .dataset_examples import list_dataset_examples
+
+
+ @wrapt.decorator  # type: ignore
+ async def forbid_if_readonly(
+     wrapped: Callable[[Request], Awaitable[Response]],
+     _: Any,
+     args: Tuple[Request],
+     kwargs: Mapping[str, Any],
+ ) -> Response:
+     request, *_ = args
+     if request.app.state.read_only:
+         return Response(status_code=HTTP_403_FORBIDDEN)
+     return await wrapped(*args, **kwargs)
+
+
+ class Route(routing.Route):
+     def __init__(self, path: str, endpoint: Callable[..., Any], **kwargs: Any) -> None:
+         super().__init__(path, forbid_if_readonly(endpoint), **kwargs)

- from . import evaluations, spans, traces

  V1_ROUTES = [
      Route("/v1/evaluations", evaluations.post_evaluations, methods=["POST"]),
@@ -8,4 +42,45 @@ V1_ROUTES = [
      Route("/v1/traces", traces.post_traces, methods=["POST"]),
      Route("/v1/spans", spans.query_spans_handler, methods=["POST"]),
      Route("/v1/spans", spans.get_spans_handler, methods=["GET"]),
+     Route("/v1/datasets/upload", datasets.post_datasets_upload, methods=["POST"]),
+     Route("/v1/datasets", datasets.list_datasets, methods=["GET"]),
+     Route("/v1/datasets/{id:str}", datasets.get_dataset_by_id, methods=["GET"]),
+     Route("/v1/datasets/{id:str}/csv", datasets.get_dataset_csv, methods=["GET"]),
+     Route(
+         "/v1/datasets/{id:str}/jsonl/openai_ft",
+         datasets.get_dataset_jsonl_openai_ft,
+         methods=["GET"],
+     ),
+     Route(
+         "/v1/datasets/{id:str}/jsonl/openai_evals",
+         datasets.get_dataset_jsonl_openai_evals,
+         methods=["GET"],
+     ),
+     Route("/v1/datasets/{id:str}/examples", list_dataset_examples, methods=["GET"]),
+     Route("/v1/datasets/{id:str}/versions", datasets.get_dataset_versions, methods=["GET"]),
+     Route(
+         "/v1/datasets/{dataset_id:str}/experiments",
+         experiments.create_experiment,
+         methods=["POST"],
+     ),
+     Route(
+         "/v1/experiments/{experiment_id:str}",
+         experiments.read_experiment,
+         methods=["GET"],
+     ),
+     Route(
+         "/v1/experiments/{experiment_id:str}/runs",
+         experiment_runs.create_experiment_run,
+         methods=["POST"],
+     ),
+     Route(
+         "/v1/experiments/{experiment_id:str}/runs",
+         experiment_runs.list_experiment_runs,
+         methods=["GET"],
+     ),
+     Route(
+         "/v1/experiment_evaluations",
+         experiment_evaluations.create_experiment_evaluation,
+         methods=["POST"],
+     ),
  ]
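Two things are worth noting about the rewritten router. First, because the Route subclass wraps every endpoint in forbid_if_readonly, all v1 routes, GETs included, answer 403 Forbidden when the app runs with read_only set. Second, the new dataset and experiment endpoints are plain REST paths, so they can be exercised without a GraphQL client. A sketch with an assumed host and an illustrative dataset ID; only the paths and methods come from this diff:

# Hypothetical walk through the new v1 dataset endpoints.
import requests

BASE = "http://localhost:6006"  # assumed Phoenix host/port

datasets = requests.get(f"{BASE}/v1/datasets").json()

dataset_id = "RGF0YXNldDox"  # illustrative ID; real ones come from the listing
examples = requests.get(f"{BASE}/v1/datasets/{dataset_id}/examples")
versions = requests.get(f"{BASE}/v1/datasets/{dataset_id}/versions")
csv_export = requests.get(f"{BASE}/v1/datasets/{dataset_id}/csv")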