arize-phoenix 4.4.4rc5-py3-none-any.whl → 4.5.0-py3-none-any.whl

This diff shows the contents of publicly available package versions as published to their public registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that registry.

Potentially problematic release.



Files changed (118)
  1. {arize_phoenix-4.4.4rc5.dist-info → arize_phoenix-4.5.0.dist-info}/METADATA +5 -5
  2. {arize_phoenix-4.4.4rc5.dist-info → arize_phoenix-4.5.0.dist-info}/RECORD +56 -117
  3. {arize_phoenix-4.4.4rc5.dist-info → arize_phoenix-4.5.0.dist-info}/WHEEL +1 -1
  4. phoenix/__init__.py +27 -0
  5. phoenix/config.py +7 -21
  6. phoenix/core/model.py +25 -25
  7. phoenix/core/model_schema.py +62 -64
  8. phoenix/core/model_schema_adapter.py +25 -27
  9. phoenix/db/bulk_inserter.py +14 -54
  10. phoenix/db/insertion/evaluation.py +6 -6
  11. phoenix/db/insertion/helpers.py +2 -13
  12. phoenix/db/migrations/versions/cf03bd6bae1d_init.py +28 -2
  13. phoenix/db/models.py +4 -236
  14. phoenix/inferences/fixtures.py +23 -23
  15. phoenix/inferences/inferences.py +7 -7
  16. phoenix/inferences/validation.py +1 -1
  17. phoenix/server/api/context.py +0 -18
  18. phoenix/server/api/dataloaders/__init__.py +0 -18
  19. phoenix/server/api/dataloaders/span_descendants.py +3 -2
  20. phoenix/server/api/routers/v1/__init__.py +2 -77
  21. phoenix/server/api/routers/v1/evaluations.py +2 -4
  22. phoenix/server/api/routers/v1/spans.py +1 -3
  23. phoenix/server/api/routers/v1/traces.py +4 -1
  24. phoenix/server/api/schema.py +303 -2
  25. phoenix/server/api/types/Cluster.py +19 -19
  26. phoenix/server/api/types/Dataset.py +63 -282
  27. phoenix/server/api/types/DatasetRole.py +23 -0
  28. phoenix/server/api/types/Dimension.py +29 -30
  29. phoenix/server/api/types/EmbeddingDimension.py +34 -40
  30. phoenix/server/api/types/Event.py +16 -16
  31. phoenix/server/api/{mutations/export_events_mutations.py → types/ExportEventsMutation.py} +14 -17
  32. phoenix/server/api/types/Model.py +42 -43
  33. phoenix/server/api/types/Project.py +12 -26
  34. phoenix/server/api/types/Span.py +2 -79
  35. phoenix/server/api/types/TimeSeries.py +6 -6
  36. phoenix/server/api/types/Trace.py +4 -15
  37. phoenix/server/api/types/UMAPPoints.py +1 -1
  38. phoenix/server/api/types/node.py +111 -5
  39. phoenix/server/api/types/pagination.py +52 -10
  40. phoenix/server/app.py +49 -101
  41. phoenix/server/main.py +27 -49
  42. phoenix/server/openapi/docs.py +0 -3
  43. phoenix/server/static/index.js +2595 -3523
  44. phoenix/server/templates/index.html +0 -1
  45. phoenix/services.py +15 -15
  46. phoenix/session/client.py +21 -438
  47. phoenix/session/session.py +37 -47
  48. phoenix/trace/exporter.py +9 -14
  49. phoenix/trace/fixtures.py +7 -133
  50. phoenix/trace/schemas.py +2 -1
  51. phoenix/trace/span_evaluations.py +3 -3
  52. phoenix/trace/trace_dataset.py +6 -6
  53. phoenix/version.py +1 -1
  54. phoenix/datasets/__init__.py +0 -0
  55. phoenix/datasets/evaluators/__init__.py +0 -18
  56. phoenix/datasets/evaluators/code_evaluators.py +0 -99
  57. phoenix/datasets/evaluators/llm_evaluators.py +0 -244
  58. phoenix/datasets/evaluators/utils.py +0 -292
  59. phoenix/datasets/experiments.py +0 -550
  60. phoenix/datasets/tracing.py +0 -85
  61. phoenix/datasets/types.py +0 -178
  62. phoenix/db/insertion/dataset.py +0 -237
  63. phoenix/db/migrations/types.py +0 -29
  64. phoenix/db/migrations/versions/10460e46d750_datasets.py +0 -291
  65. phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -100
  66. phoenix/server/api/dataloaders/dataset_example_spans.py +0 -43
  67. phoenix/server/api/dataloaders/experiment_annotation_summaries.py +0 -85
  68. phoenix/server/api/dataloaders/experiment_error_rates.py +0 -43
  69. phoenix/server/api/dataloaders/experiment_run_counts.py +0 -42
  70. phoenix/server/api/dataloaders/experiment_sequence_number.py +0 -49
  71. phoenix/server/api/dataloaders/project_by_name.py +0 -31
  72. phoenix/server/api/dataloaders/span_projects.py +0 -33
  73. phoenix/server/api/dataloaders/trace_row_ids.py +0 -39
  74. phoenix/server/api/helpers/dataset_helpers.py +0 -179
  75. phoenix/server/api/input_types/AddExamplesToDatasetInput.py +0 -16
  76. phoenix/server/api/input_types/AddSpansToDatasetInput.py +0 -14
  77. phoenix/server/api/input_types/ClearProjectInput.py +0 -15
  78. phoenix/server/api/input_types/CreateDatasetInput.py +0 -12
  79. phoenix/server/api/input_types/DatasetExampleInput.py +0 -14
  80. phoenix/server/api/input_types/DatasetSort.py +0 -17
  81. phoenix/server/api/input_types/DatasetVersionSort.py +0 -16
  82. phoenix/server/api/input_types/DeleteDatasetExamplesInput.py +0 -13
  83. phoenix/server/api/input_types/DeleteDatasetInput.py +0 -7
  84. phoenix/server/api/input_types/DeleteExperimentsInput.py +0 -9
  85. phoenix/server/api/input_types/PatchDatasetExamplesInput.py +0 -35
  86. phoenix/server/api/input_types/PatchDatasetInput.py +0 -14
  87. phoenix/server/api/mutations/__init__.py +0 -13
  88. phoenix/server/api/mutations/auth.py +0 -11
  89. phoenix/server/api/mutations/dataset_mutations.py +0 -520
  90. phoenix/server/api/mutations/experiment_mutations.py +0 -65
  91. phoenix/server/api/mutations/project_mutations.py +0 -47
  92. phoenix/server/api/openapi/__init__.py +0 -0
  93. phoenix/server/api/openapi/main.py +0 -6
  94. phoenix/server/api/openapi/schema.py +0 -16
  95. phoenix/server/api/queries.py +0 -503
  96. phoenix/server/api/routers/v1/dataset_examples.py +0 -178
  97. phoenix/server/api/routers/v1/datasets.py +0 -965
  98. phoenix/server/api/routers/v1/experiment_evaluations.py +0 -66
  99. phoenix/server/api/routers/v1/experiment_runs.py +0 -108
  100. phoenix/server/api/routers/v1/experiments.py +0 -174
  101. phoenix/server/api/types/AnnotatorKind.py +0 -10
  102. phoenix/server/api/types/CreateDatasetPayload.py +0 -8
  103. phoenix/server/api/types/DatasetExample.py +0 -85
  104. phoenix/server/api/types/DatasetExampleRevision.py +0 -34
  105. phoenix/server/api/types/DatasetVersion.py +0 -14
  106. phoenix/server/api/types/ExampleRevisionInterface.py +0 -14
  107. phoenix/server/api/types/Experiment.py +0 -140
  108. phoenix/server/api/types/ExperimentAnnotationSummary.py +0 -13
  109. phoenix/server/api/types/ExperimentComparison.py +0 -19
  110. phoenix/server/api/types/ExperimentRun.py +0 -91
  111. phoenix/server/api/types/ExperimentRunAnnotation.py +0 -57
  112. phoenix/server/api/types/Inferences.py +0 -80
  113. phoenix/server/api/types/InferencesRole.py +0 -23
  114. phoenix/utilities/json.py +0 -61
  115. phoenix/utilities/re.py +0 -50
  116. {arize_phoenix-4.4.4rc5.dist-info → arize_phoenix-4.5.0.dist-info}/licenses/IP_NOTICE +0 -0
  117. {arize_phoenix-4.4.4rc5.dist-info → arize_phoenix-4.5.0.dist-info}/licenses/LICENSE +0 -0
  118. /phoenix/server/api/{helpers/__init__.py → helpers.py} +0 -0
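The listing above is the result of comparing the two wheel archives file by file. For anyone who wants to reproduce this kind of comparison locally, the following is a minimal sketch using only the Python standard library. It is not how this page was generated, and the wheel filenames and the pip download commands mentioned in the comments are assumptions about how the two artifacts could be obtained, not details taken from this release.

# Sketch: diff two wheels member by member. Assumes both .whl files are in the
# current directory, e.g. fetched with "pip download arize-phoenix==4.4.4rc5 --no-deps"
# and "pip download arize-phoenix==4.5.0 --no-deps"; exact filenames may differ.
import difflib
import zipfile
from typing import Dict


def wheel_texts(path: str) -> Dict[str, str]:
    """Map each file inside the wheel to its decoded text content."""
    with zipfile.ZipFile(path) as wheel:
        return {
            name: wheel.read(name).decode("utf-8", errors="replace")
            for name in wheel.namelist()
            if not name.endswith("/")
        }


OLD = "arize_phoenix-4.4.4rc5-py3-none-any.whl"  # assumed filename
NEW = "arize_phoenix-4.5.0-py3-none-any.whl"  # assumed filename

old_files, new_files = wheel_texts(OLD), wheel_texts(NEW)

# Files present in only one of the two wheels (added or removed).
for name in sorted(old_files.keys() ^ new_files.keys()):
    print(("removed: " if name in old_files else "added:   ") + name)

# Unified diffs for files present in both wheels.
for name in sorted(old_files.keys() & new_files.keys()):
    print(
        "".join(
            difflib.unified_diff(
                old_files[name].splitlines(keepends=True),
                new_files[name].splitlines(keepends=True),
                fromfile=f"{OLD}/{name}",
                tofile=f"{NEW}/{name}",
            )
        ),
        end="",
    )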
phoenix/server/api/queries.py (deleted, 503 lines)
@@ -1,503 +0,0 @@
- from collections import defaultdict
- from typing import DefaultDict, Dict, List, Optional, Set, Union
-
- import numpy as np
- import numpy.typing as npt
- import strawberry
- from sqlalchemy import and_, distinct, func, select
- from sqlalchemy.orm import joinedload
- from strawberry import ID, UNSET
- from strawberry.relay import Connection, GlobalID, Node
- from strawberry.types import Info
- from typing_extensions import Annotated, TypeAlias
-
- from phoenix.db import models
- from phoenix.db.models import (
-     DatasetExample as OrmExample,
- )
- from phoenix.db.models import (
-     DatasetExampleRevision as OrmRevision,
- )
- from phoenix.db.models import (
-     DatasetVersion as OrmVersion,
- )
- from phoenix.db.models import (
-     Experiment as OrmExperiment,
- )
- from phoenix.db.models import (
-     ExperimentRun as OrmRun,
- )
- from phoenix.db.models import (
-     Trace as OrmTrace,
- )
- from phoenix.pointcloud.clustering import Hdbscan
- from phoenix.server.api.context import Context
- from phoenix.server.api.helpers import ensure_list
- from phoenix.server.api.input_types.ClusterInput import ClusterInput
- from phoenix.server.api.input_types.Coordinates import (
-     InputCoordinate2D,
-     InputCoordinate3D,
- )
- from phoenix.server.api.input_types.DatasetSort import DatasetSort
- from phoenix.server.api.types.Cluster import Cluster, to_gql_clusters
- from phoenix.server.api.types.Dataset import Dataset, to_gql_dataset
- from phoenix.server.api.types.DatasetExample import DatasetExample
- from phoenix.server.api.types.Dimension import to_gql_dimension
- from phoenix.server.api.types.EmbeddingDimension import (
-     DEFAULT_CLUSTER_SELECTION_EPSILON,
-     DEFAULT_MIN_CLUSTER_SIZE,
-     DEFAULT_MIN_SAMPLES,
-     to_gql_embedding_dimension,
- )
- from phoenix.server.api.types.Event import create_event_id, unpack_event_id
- from phoenix.server.api.types.Experiment import Experiment
- from phoenix.server.api.types.ExperimentComparison import ExperimentComparison, RunComparisonItem
- from phoenix.server.api.types.ExperimentRun import ExperimentRun, to_gql_experiment_run
- from phoenix.server.api.types.Functionality import Functionality
- from phoenix.server.api.types.InferencesRole import AncillaryInferencesRole, InferencesRole
- from phoenix.server.api.types.Model import Model
- from phoenix.server.api.types.node import from_global_id, from_global_id_with_expected_type
- from phoenix.server.api.types.pagination import (
-     ConnectionArgs,
-     CursorString,
-     connection_from_list,
- )
- from phoenix.server.api.types.Project import Project
- from phoenix.server.api.types.SortDir import SortDir
- from phoenix.server.api.types.Span import Span, to_gql_span
- from phoenix.server.api.types.Trace import Trace
-
-
- @strawberry.type
- class Query:
-     @strawberry.field
-     async def projects(
-         self,
-         info: Info[Context, None],
-         first: Optional[int] = 50,
-         last: Optional[int] = UNSET,
-         after: Optional[CursorString] = UNSET,
-         before: Optional[CursorString] = UNSET,
-     ) -> Connection[Project]:
-         args = ConnectionArgs(
-             first=first,
-             after=after if isinstance(after, CursorString) else None,
-             last=last,
-             before=before if isinstance(before, CursorString) else None,
-         )
-         stmt = (
-             select(models.Project)
-             .outerjoin(
-                 models.Experiment,
-                 models.Project.name == models.Experiment.project_name,
-             )
-             .where(models.Experiment.project_name.is_(None))
-         )
-         async with info.context.db() as session:
-             projects = await session.stream_scalars(stmt)
-             data = [
-                 Project(
-                     id_attr=project.id,
-                     name=project.name,
-                     gradient_start_color=project.gradient_start_color,
-                     gradient_end_color=project.gradient_end_color,
-                 )
-                 async for project in projects
-             ]
-         return connection_from_list(data=data, args=args)
-
-     @strawberry.field
-     async def datasets(
-         self,
-         info: Info[Context, None],
-         first: Optional[int] = 50,
-         last: Optional[int] = UNSET,
-         after: Optional[CursorString] = UNSET,
-         before: Optional[CursorString] = UNSET,
-         sort: Optional[DatasetSort] = UNSET,
-     ) -> Connection[Dataset]:
-         args = ConnectionArgs(
-             first=first,
-             after=after if isinstance(after, CursorString) else None,
-             last=last,
-             before=before if isinstance(before, CursorString) else None,
-         )
-         stmt = select(models.Dataset)
-         if sort:
-             sort_col = getattr(models.Dataset, sort.col.value)
-             stmt = stmt.order_by(sort_col.desc() if sort.dir is SortDir.desc else sort_col.asc())
-         async with info.context.db() as session:
-             datasets = await session.scalars(stmt)
-         return connection_from_list(
-             data=[to_gql_dataset(dataset) for dataset in datasets], args=args
-         )
-
-     @strawberry.field
-     async def compare_experiments(
-         self,
-         info: Info[Context, None],
-         experiment_ids: List[GlobalID],
-     ) -> List[ExperimentComparison]:
-         experiment_ids_ = [
-             from_global_id_with_expected_type(experiment_id, OrmExperiment.__name__)
-             for experiment_id in experiment_ids
-         ]
-         if len(set(experiment_ids_)) != len(experiment_ids_):
-             raise ValueError("Experiment IDs must be unique.")
-
-         async with info.context.db() as session:
-             validation_result = (
-                 await session.execute(
-                     select(
-                         func.count(distinct(OrmVersion.dataset_id)),
-                         func.max(OrmVersion.dataset_id),
-                         func.max(OrmVersion.id),
-                         func.count(OrmExperiment.id),
-                     )
-                     .select_from(OrmVersion)
-                     .join(
-                         OrmExperiment,
-                         OrmExperiment.dataset_version_id == OrmVersion.id,
-                     )
-                     .where(
-                         OrmExperiment.id.in_(experiment_ids_),
-                     )
-                 )
-             ).first()
-             if validation_result is None:
-                 raise ValueError("No experiments could be found for input IDs.")
-
-             num_datasets, dataset_id, version_id, num_resolved_experiment_ids = validation_result
-             if num_datasets != 1:
-                 raise ValueError("Experiments must belong to the same dataset.")
-             if num_resolved_experiment_ids != len(experiment_ids_):
-                 raise ValueError("Unable to resolve one or more experiment IDs.")
-
-             revision_ids = (
-                 select(func.max(OrmRevision.id))
-                 .join(OrmExample, OrmExample.id == OrmRevision.dataset_example_id)
-                 .where(
-                     and_(
-                         OrmRevision.dataset_version_id <= version_id,
-                         OrmExample.dataset_id == dataset_id,
-                     )
-                 )
-                 .group_by(OrmRevision.dataset_example_id)
-                 .scalar_subquery()
-             )
-             examples = (
-                 await session.scalars(
-                     select(OrmExample)
-                     .join(OrmRevision, OrmExample.id == OrmRevision.dataset_example_id)
-                     .where(
-                         and_(
-                             OrmRevision.id.in_(revision_ids),
-                             OrmRevision.revision_kind != "DELETE",
-                         )
-                     )
-                     .order_by(OrmRevision.dataset_example_id.desc())
-                 )
-             ).all()
-
-             ExampleID: TypeAlias = int
-             ExperimentID: TypeAlias = int
-             runs: DefaultDict[ExampleID, DefaultDict[ExperimentID, List[OrmRun]]] = defaultdict(
-                 lambda: defaultdict(list)
-             )
-             async for run in await session.stream_scalars(
-                 select(OrmRun)
-                 .where(
-                     and_(
-                         OrmRun.dataset_example_id.in_(example.id for example in examples),
-                         OrmRun.experiment_id.in_(experiment_ids_),
-                     )
-                 )
-                 .options(joinedload(OrmRun.trace).load_only(OrmTrace.trace_id))
-             ):
-                 runs[run.dataset_example_id][run.experiment_id].append(run)
-
-         experiment_comparisons = []
-         for example in examples:
-             run_comparison_items = []
-             for experiment_id in experiment_ids_:
-                 run_comparison_items.append(
-                     RunComparisonItem(
-                         experiment_id=GlobalID(Experiment.__name__, str(experiment_id)),
-                         runs=[
-                             to_gql_experiment_run(run)
-                             for run in sorted(
-                                 runs[example.id][experiment_id], key=lambda run: run.id
-                             )
-                         ],
-                     )
-                 )
-             experiment_comparisons.append(
-                 ExperimentComparison(
-                     example=DatasetExample(
-                         id_attr=example.id,
-                         created_at=example.created_at,
-                         version_id=version_id,
-                     ),
-                     run_comparison_items=run_comparison_items,
-                 )
-             )
-         return experiment_comparisons
-
-     @strawberry.field
-     async def functionality(self, info: Info[Context, None]) -> "Functionality":
-         has_model_inferences = not info.context.model.is_empty
-         async with info.context.db() as session:
-             has_traces = (await session.scalar(select(models.Trace).limit(1))) is not None
-         return Functionality(
-             model_inferences=has_model_inferences,
-             tracing=has_traces,
-         )
-
-     @strawberry.field
-     def model(self) -> Model:
-         return Model()
-
-     @strawberry.field
-     async def node(self, id: GlobalID, info: Info[Context, None]) -> Node:
-         type_name, node_id = from_global_id(id)
-         if type_name == "Dimension":
-             dimension = info.context.model.scalar_dimensions[node_id]
-             return to_gql_dimension(node_id, dimension)
-         elif type_name == "EmbeddingDimension":
-             embedding_dimension = info.context.model.embedding_dimensions[node_id]
-             return to_gql_embedding_dimension(node_id, embedding_dimension)
-         elif type_name == "Project":
-             project_stmt = select(
-                 models.Project.id,
-                 models.Project.name,
-                 models.Project.gradient_start_color,
-                 models.Project.gradient_end_color,
-             ).where(models.Project.id == node_id)
-             async with info.context.db() as session:
-                 project = (await session.execute(project_stmt)).first()
-             if project is None:
-                 raise ValueError(f"Unknown project: {id}")
-             return Project(
-                 id_attr=project.id,
-                 name=project.name,
-                 gradient_start_color=project.gradient_start_color,
-                 gradient_end_color=project.gradient_end_color,
-             )
-         elif type_name == "Trace":
-             trace_stmt = select(
-                 models.Trace.id,
-                 models.Trace.project_rowid,
-             ).where(models.Trace.id == node_id)
-             async with info.context.db() as session:
-                 trace = (await session.execute(trace_stmt)).first()
-             if trace is None:
-                 raise ValueError(f"Unknown trace: {id}")
-             return Trace(
-                 id_attr=trace.id, trace_id=trace.trace_id, project_rowid=trace.project_rowid
-             )
-         elif type_name == Span.__name__:
-             span_stmt = (
-                 select(models.Span)
-                 .options(
-                     joinedload(models.Span.trace, innerjoin=True).load_only(models.Trace.trace_id)
-                 )
-                 .where(models.Span.id == node_id)
-             )
-             async with info.context.db() as session:
-                 span = await session.scalar(span_stmt)
-             if span is None:
-                 raise ValueError(f"Unknown span: {id}")
-             return to_gql_span(span)
-         elif type_name == Dataset.__name__:
-             dataset_stmt = select(models.Dataset).where(models.Dataset.id == node_id)
-             async with info.context.db() as session:
-                 if (dataset := await session.scalar(dataset_stmt)) is None:
-                     raise ValueError(f"Unknown dataset: {id}")
-             return to_gql_dataset(dataset)
-         elif type_name == DatasetExample.__name__:
-             example_id = node_id
-             latest_revision_id = (
-                 select(func.max(models.DatasetExampleRevision.id))
-                 .where(models.DatasetExampleRevision.dataset_example_id == example_id)
-                 .scalar_subquery()
-             )
-             async with info.context.db() as session:
-                 example = await session.scalar(
-                     select(models.DatasetExample)
-                     .join(
-                         models.DatasetExampleRevision,
-                         onclause=models.DatasetExampleRevision.dataset_example_id
-                         == models.DatasetExample.id,
-                     )
-                     .where(
-                         and_(
-                             models.DatasetExample.id == example_id,
-                             models.DatasetExampleRevision.id == latest_revision_id,
-                             models.DatasetExampleRevision.revision_kind != "DELETE",
-                         )
-                     )
-                 )
-             if not example:
-                 raise ValueError(f"Unknown dataset example: {id}")
-             return DatasetExample(
-                 id_attr=example.id,
-                 created_at=example.created_at,
-             )
-         elif type_name == Experiment.__name__:
-             async with info.context.db() as session:
-                 experiment = await session.scalar(
-                     select(models.Experiment).where(models.Experiment.id == node_id)
-                 )
-             if not experiment:
-                 raise ValueError(f"Unknown experiment: {id}")
-             return Experiment(
-                 id_attr=experiment.id,
-                 name=experiment.name,
-                 project_name=experiment.project_name,
-                 description=experiment.description,
-                 created_at=experiment.created_at,
-                 updated_at=experiment.updated_at,
-                 metadata=experiment.metadata_,
-             )
-         elif type_name == ExperimentRun.__name__:
-             async with info.context.db() as session:
-                 if not (
-                     run := await session.scalar(
-                         select(models.ExperimentRun)
-                         .where(models.ExperimentRun.id == node_id)
-                         .options(
-                             joinedload(models.ExperimentRun.trace).load_only(models.Trace.trace_id)
-                         )
-                     )
-                 ):
-                     raise ValueError(f"Unknown experiment run: {id}")
-             return to_gql_experiment_run(run)
-         raise Exception(f"Unknown node type: {type_name}")
-
-     @strawberry.field
-     def clusters(
-         self,
-         clusters: List[ClusterInput],
-     ) -> List[Cluster]:
-         clustered_events: Dict[str, Set[ID]] = defaultdict(set)
-         for i, cluster in enumerate(clusters):
-             clustered_events[cluster.id or str(i)].update(cluster.event_ids)
-         return to_gql_clusters(
-             clustered_events=clustered_events,
-         )
-
-     @strawberry.field
-     def hdbscan_clustering(
-         self,
-         info: Info[Context, None],
-         event_ids: Annotated[
-             List[ID],
-             strawberry.argument(
-                 description="Event ID of the coordinates",
-             ),
-         ],
-         coordinates_2d: Annotated[
-             Optional[List[InputCoordinate2D]],
-             strawberry.argument(
-                 description="Point coordinates. Must be either 2D or 3D.",
-             ),
-         ] = UNSET,
-         coordinates_3d: Annotated[
-             Optional[List[InputCoordinate3D]],
-             strawberry.argument(
-                 description="Point coordinates. Must be either 2D or 3D.",
-             ),
-         ] = UNSET,
-         min_cluster_size: Annotated[
-             int,
-             strawberry.argument(
-                 description="HDBSCAN minimum cluster size",
-             ),
-         ] = DEFAULT_MIN_CLUSTER_SIZE,
-         cluster_min_samples: Annotated[
-             int,
-             strawberry.argument(
-                 description="HDBSCAN minimum samples",
-             ),
-         ] = DEFAULT_MIN_SAMPLES,
-         cluster_selection_epsilon: Annotated[
-             float,
-             strawberry.argument(
-                 description="HDBSCAN cluster selection epsilon",
-             ),
-         ] = DEFAULT_CLUSTER_SELECTION_EPSILON,
-     ) -> List[Cluster]:
-         coordinates_3d = ensure_list(coordinates_3d)
-         coordinates_2d = ensure_list(coordinates_2d)
-
-         if len(coordinates_3d) > 0 and len(coordinates_2d) > 0:
-             raise ValueError("must specify only one of 2D or 3D coordinates")
-
-         if len(coordinates_3d) > 0:
-             coordinates = list(
-                 map(
-                     lambda coord: np.array(
-                         [coord.x, coord.y, coord.z],
-                     ),
-                     coordinates_3d,
-                 )
-             )
-         else:
-             coordinates = list(
-                 map(
-                     lambda coord: np.array(
-                         [coord.x, coord.y],
-                     ),
-                     coordinates_2d,
-                 )
-             )
-
-         if len(event_ids) != len(coordinates):
-             raise ValueError(
-                 f"length mismatch between "
-                 f"event_ids ({len(event_ids)}) "
-                 f"and coordinates ({len(coordinates)})"
-             )
-
-         if len(event_ids) == 0:
-             return []
-
-         grouped_event_ids: Dict[
-             Union[InferencesRole, AncillaryInferencesRole],
-             List[ID],
-         ] = defaultdict(list)
-         grouped_coordinates: Dict[
-             Union[InferencesRole, AncillaryInferencesRole],
-             List[npt.NDArray[np.float64]],
-         ] = defaultdict(list)
-
-         for event_id, coordinate in zip(event_ids, coordinates):
-             row_id, inferences_role = unpack_event_id(event_id)
-             grouped_coordinates[inferences_role].append(coordinate)
-             grouped_event_ids[inferences_role].append(create_event_id(row_id, inferences_role))
-
-         stacked_event_ids = (
-             grouped_event_ids[InferencesRole.primary]
-             + grouped_event_ids[InferencesRole.reference]
-             + grouped_event_ids[AncillaryInferencesRole.corpus]
-         )
-         stacked_coordinates = np.stack(
-             grouped_coordinates[InferencesRole.primary]
-             + grouped_coordinates[InferencesRole.reference]
-             + grouped_coordinates[AncillaryInferencesRole.corpus]
-         )
-
-         clusters = Hdbscan(
-             min_cluster_size=min_cluster_size,
-             min_samples=cluster_min_samples,
-             cluster_selection_epsilon=cluster_selection_epsilon,
-         ).find_clusters(stacked_coordinates)
-
-         clustered_events = {
-             str(i): {stacked_event_ids[row_idx] for row_idx in cluster}
-             for i, cluster in enumerate(clusters)
-         }
-
-         return to_gql_clusters(
-             clustered_events=clustered_events,
-         )
phoenix/server/api/routers/v1/dataset_examples.py (deleted, 178 lines)
@@ -1,178 +0,0 @@
- from sqlalchemy import and_, func, select
- from starlette.requests import Request
- from starlette.responses import JSONResponse, Response
- from starlette.status import HTTP_404_NOT_FOUND
- from strawberry.relay import GlobalID
-
- from phoenix.db.models import Dataset, DatasetExample, DatasetExampleRevision, DatasetVersion
-
-
- async def list_dataset_examples(request: Request) -> Response:
-     """
-     summary: Get dataset examples by dataset ID
-     operationId: getDatasetExamples
-     tags:
-       - datasets
-     parameters:
-       - in: path
-         name: id
-         required: true
-         schema:
-           type: string
-         description: Dataset ID
-       - in: query
-         name: version-id
-         schema:
-           type: string
-         description: Dataset version ID. If omitted, returns the latest version.
-     responses:
-       200:
-         description: Success
-         content:
-           application/json:
-             schema:
-               type: object
-               properties:
-                 data:
-                   type: object
-                   properties:
-                     dataset_id:
-                       type: string
-                       description: ID of the dataset
-                     version_id:
-                       type: string
-                       description: ID of the version
-                     examples:
-                       type: array
-                       items:
-                         type: object
-                         properties:
-                           id:
-                             type: string
-                             description: ID of the dataset example
-                           input:
-                             type: object
-                             description: Input data of the example
-                           output:
-                             type: object
-                             description: Output data of the example
-                           metadata:
-                             type: object
-                             description: Metadata of the example
-                           updated_at:
-                             type: string
-                             format: date-time
-                             description: ISO formatted timestamp of when the example was updated
-                         required:
-                           - id
-                           - input
-                           - output
-                           - metadata
-                           - updated_at
-                   required:
-                     - dataset_id
-                     - version_id
-                     - examples
-       403:
-         description: Forbidden
-       404:
-         description: Dataset does not exist.
-     """
-     dataset_id = GlobalID.from_id(request.path_params["id"])
-     raw_version_id = request.query_params.get("version-id")
-     version_id = GlobalID.from_id(raw_version_id) if raw_version_id else None
-
-     if (dataset_type := dataset_id.type_name) != "Dataset":
-         return Response(
-             content=f"ID {dataset_id} refers to a {dataset_type}", status_code=HTTP_404_NOT_FOUND
-         )
-
-     if version_id and (version_type := version_id.type_name) != "DatasetVersion":
-         return Response(
-             content=f"ID {version_id} refers to a {version_type}", status_code=HTTP_404_NOT_FOUND
-         )
-
-     async with request.app.state.db() as session:
-         if (
-             resolved_dataset_id := await session.scalar(
-                 select(Dataset.id).where(Dataset.id == int(dataset_id.node_id))
-             )
-         ) is None:
-             return Response(
-                 content=f"No dataset with id {dataset_id} can be found.",
-                 status_code=HTTP_404_NOT_FOUND,
-             )
-
-         # Subquery to find the maximum created_at for each dataset_example_id
-         # timestamp tiebreaks are resolved by the largest id
-         partial_subquery = select(
-             func.max(DatasetExampleRevision.id).label("max_id"),
-         ).group_by(DatasetExampleRevision.dataset_example_id)
-
-         if version_id:
-             if (
-                 resolved_version_id := await session.scalar(
-                     select(DatasetVersion.id).where(
-                         and_(
-                             DatasetVersion.dataset_id == resolved_dataset_id,
-                             DatasetVersion.id == int(version_id.node_id),
-                         )
-                     )
-                 )
-             ) is None:
-                 return Response(
-                     content=f"No dataset version with id {version_id} can be found.",
-                     status_code=HTTP_404_NOT_FOUND,
-                 )
-             # if a version_id is provided, filter the subquery to only include revisions from that
-             partial_subquery = partial_subquery.filter(
-                 DatasetExampleRevision.dataset_version_id <= resolved_version_id
-             )
-         else:
-             if (
-                 resolved_version_id := await session.scalar(
-                     select(func.max(DatasetVersion.id)).where(
-                         DatasetVersion.dataset_id == resolved_dataset_id
-                     )
-                 )
-             ) is None:
-                 return Response(
-                     content="Dataset has no versions.",
-                     status_code=HTTP_404_NOT_FOUND,
-                 )
-
-         subquery = partial_subquery.subquery()
-         # Query for the most recent example revisions that are not deleted
-         query = (
-             select(DatasetExample, DatasetExampleRevision)
-             .join(
-                 DatasetExampleRevision,
-                 DatasetExample.id == DatasetExampleRevision.dataset_example_id,
-             )
-             .join(
-                 subquery,
-                 (subquery.c.max_id == DatasetExampleRevision.id),
-             )
-             .filter(DatasetExample.dataset_id == resolved_dataset_id)
-             .filter(DatasetExampleRevision.revision_kind != "DELETE")
-             .order_by(DatasetExample.id.asc())
-         )
-         examples = [
-             {
-                 "id": str(GlobalID("DatasetExample", str(example.id))),
-                 "input": revision.input,
-                 "output": revision.output,
-                 "metadata": revision.metadata_,
-                 "updated_at": revision.created_at.isoformat(),
-             }
-             async for example, revision in await session.stream(query)
-         ]
-     return JSONResponse(
-         {
-             "data": {
-                 "dataset_id": str(GlobalID("Dataset", str(resolved_dataset_id))),
-                 "version_id": str(GlobalID("DatasetVersion", str(resolved_version_id))),
-                 "examples": examples,
-             }
-         }
-     )