arize-phoenix 4.4.4rc5__py3-none-any.whl → 4.5.0__py3-none-any.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of arize-phoenix might be problematic.
- {arize_phoenix-4.4.4rc5.dist-info → arize_phoenix-4.5.0.dist-info}/METADATA +5 -5
- {arize_phoenix-4.4.4rc5.dist-info → arize_phoenix-4.5.0.dist-info}/RECORD +56 -117
- {arize_phoenix-4.4.4rc5.dist-info → arize_phoenix-4.5.0.dist-info}/WHEEL +1 -1
- phoenix/__init__.py +27 -0
- phoenix/config.py +7 -21
- phoenix/core/model.py +25 -25
- phoenix/core/model_schema.py +62 -64
- phoenix/core/model_schema_adapter.py +25 -27
- phoenix/db/bulk_inserter.py +14 -54
- phoenix/db/insertion/evaluation.py +6 -6
- phoenix/db/insertion/helpers.py +2 -13
- phoenix/db/migrations/versions/cf03bd6bae1d_init.py +28 -2
- phoenix/db/models.py +4 -236
- phoenix/inferences/fixtures.py +23 -23
- phoenix/inferences/inferences.py +7 -7
- phoenix/inferences/validation.py +1 -1
- phoenix/server/api/context.py +0 -18
- phoenix/server/api/dataloaders/__init__.py +0 -18
- phoenix/server/api/dataloaders/span_descendants.py +3 -2
- phoenix/server/api/routers/v1/__init__.py +2 -77
- phoenix/server/api/routers/v1/evaluations.py +2 -4
- phoenix/server/api/routers/v1/spans.py +1 -3
- phoenix/server/api/routers/v1/traces.py +4 -1
- phoenix/server/api/schema.py +303 -2
- phoenix/server/api/types/Cluster.py +19 -19
- phoenix/server/api/types/Dataset.py +63 -282
- phoenix/server/api/types/DatasetRole.py +23 -0
- phoenix/server/api/types/Dimension.py +29 -30
- phoenix/server/api/types/EmbeddingDimension.py +34 -40
- phoenix/server/api/types/Event.py +16 -16
- phoenix/server/api/{mutations/export_events_mutations.py → types/ExportEventsMutation.py} +14 -17
- phoenix/server/api/types/Model.py +42 -43
- phoenix/server/api/types/Project.py +12 -26
- phoenix/server/api/types/Span.py +2 -79
- phoenix/server/api/types/TimeSeries.py +6 -6
- phoenix/server/api/types/Trace.py +4 -15
- phoenix/server/api/types/UMAPPoints.py +1 -1
- phoenix/server/api/types/node.py +111 -5
- phoenix/server/api/types/pagination.py +52 -10
- phoenix/server/app.py +49 -101
- phoenix/server/main.py +27 -49
- phoenix/server/openapi/docs.py +0 -3
- phoenix/server/static/index.js +2595 -3523
- phoenix/server/templates/index.html +0 -1
- phoenix/services.py +15 -15
- phoenix/session/client.py +21 -438
- phoenix/session/session.py +37 -47
- phoenix/trace/exporter.py +9 -14
- phoenix/trace/fixtures.py +7 -133
- phoenix/trace/schemas.py +2 -1
- phoenix/trace/span_evaluations.py +3 -3
- phoenix/trace/trace_dataset.py +6 -6
- phoenix/version.py +1 -1
- phoenix/datasets/__init__.py +0 -0
- phoenix/datasets/evaluators/__init__.py +0 -18
- phoenix/datasets/evaluators/code_evaluators.py +0 -99
- phoenix/datasets/evaluators/llm_evaluators.py +0 -244
- phoenix/datasets/evaluators/utils.py +0 -292
- phoenix/datasets/experiments.py +0 -550
- phoenix/datasets/tracing.py +0 -85
- phoenix/datasets/types.py +0 -178
- phoenix/db/insertion/dataset.py +0 -237
- phoenix/db/migrations/types.py +0 -29
- phoenix/db/migrations/versions/10460e46d750_datasets.py +0 -291
- phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -100
- phoenix/server/api/dataloaders/dataset_example_spans.py +0 -43
- phoenix/server/api/dataloaders/experiment_annotation_summaries.py +0 -85
- phoenix/server/api/dataloaders/experiment_error_rates.py +0 -43
- phoenix/server/api/dataloaders/experiment_run_counts.py +0 -42
- phoenix/server/api/dataloaders/experiment_sequence_number.py +0 -49
- phoenix/server/api/dataloaders/project_by_name.py +0 -31
- phoenix/server/api/dataloaders/span_projects.py +0 -33
- phoenix/server/api/dataloaders/trace_row_ids.py +0 -39
- phoenix/server/api/helpers/dataset_helpers.py +0 -179
- phoenix/server/api/input_types/AddExamplesToDatasetInput.py +0 -16
- phoenix/server/api/input_types/AddSpansToDatasetInput.py +0 -14
- phoenix/server/api/input_types/ClearProjectInput.py +0 -15
- phoenix/server/api/input_types/CreateDatasetInput.py +0 -12
- phoenix/server/api/input_types/DatasetExampleInput.py +0 -14
- phoenix/server/api/input_types/DatasetSort.py +0 -17
- phoenix/server/api/input_types/DatasetVersionSort.py +0 -16
- phoenix/server/api/input_types/DeleteDatasetExamplesInput.py +0 -13
- phoenix/server/api/input_types/DeleteDatasetInput.py +0 -7
- phoenix/server/api/input_types/DeleteExperimentsInput.py +0 -9
- phoenix/server/api/input_types/PatchDatasetExamplesInput.py +0 -35
- phoenix/server/api/input_types/PatchDatasetInput.py +0 -14
- phoenix/server/api/mutations/__init__.py +0 -13
- phoenix/server/api/mutations/auth.py +0 -11
- phoenix/server/api/mutations/dataset_mutations.py +0 -520
- phoenix/server/api/mutations/experiment_mutations.py +0 -65
- phoenix/server/api/mutations/project_mutations.py +0 -47
- phoenix/server/api/openapi/__init__.py +0 -0
- phoenix/server/api/openapi/main.py +0 -6
- phoenix/server/api/openapi/schema.py +0 -16
- phoenix/server/api/queries.py +0 -503
- phoenix/server/api/routers/v1/dataset_examples.py +0 -178
- phoenix/server/api/routers/v1/datasets.py +0 -965
- phoenix/server/api/routers/v1/experiment_evaluations.py +0 -66
- phoenix/server/api/routers/v1/experiment_runs.py +0 -108
- phoenix/server/api/routers/v1/experiments.py +0 -174
- phoenix/server/api/types/AnnotatorKind.py +0 -10
- phoenix/server/api/types/CreateDatasetPayload.py +0 -8
- phoenix/server/api/types/DatasetExample.py +0 -85
- phoenix/server/api/types/DatasetExampleRevision.py +0 -34
- phoenix/server/api/types/DatasetVersion.py +0 -14
- phoenix/server/api/types/ExampleRevisionInterface.py +0 -14
- phoenix/server/api/types/Experiment.py +0 -140
- phoenix/server/api/types/ExperimentAnnotationSummary.py +0 -13
- phoenix/server/api/types/ExperimentComparison.py +0 -19
- phoenix/server/api/types/ExperimentRun.py +0 -91
- phoenix/server/api/types/ExperimentRunAnnotation.py +0 -57
- phoenix/server/api/types/Inferences.py +0 -80
- phoenix/server/api/types/InferencesRole.py +0 -23
- phoenix/utilities/json.py +0 -61
- phoenix/utilities/re.py +0 -50
- {arize_phoenix-4.4.4rc5.dist-info → arize_phoenix-4.5.0.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-4.4.4rc5.dist-info → arize_phoenix-4.5.0.dist-info}/licenses/LICENSE +0 -0
- /phoenix/server/api/{helpers/__init__.py → helpers.py} +0 -0
phoenix/server/api/queries.py
DELETED
@@ -1,503 +0,0 @@
-from collections import defaultdict
-from typing import DefaultDict, Dict, List, Optional, Set, Union
-
-import numpy as np
-import numpy.typing as npt
-import strawberry
-from sqlalchemy import and_, distinct, func, select
-from sqlalchemy.orm import joinedload
-from strawberry import ID, UNSET
-from strawberry.relay import Connection, GlobalID, Node
-from strawberry.types import Info
-from typing_extensions import Annotated, TypeAlias
-
-from phoenix.db import models
-from phoenix.db.models import (
-    DatasetExample as OrmExample,
-)
-from phoenix.db.models import (
-    DatasetExampleRevision as OrmRevision,
-)
-from phoenix.db.models import (
-    DatasetVersion as OrmVersion,
-)
-from phoenix.db.models import (
-    Experiment as OrmExperiment,
-)
-from phoenix.db.models import (
-    ExperimentRun as OrmRun,
-)
-from phoenix.db.models import (
-    Trace as OrmTrace,
-)
-from phoenix.pointcloud.clustering import Hdbscan
-from phoenix.server.api.context import Context
-from phoenix.server.api.helpers import ensure_list
-from phoenix.server.api.input_types.ClusterInput import ClusterInput
-from phoenix.server.api.input_types.Coordinates import (
-    InputCoordinate2D,
-    InputCoordinate3D,
-)
-from phoenix.server.api.input_types.DatasetSort import DatasetSort
-from phoenix.server.api.types.Cluster import Cluster, to_gql_clusters
-from phoenix.server.api.types.Dataset import Dataset, to_gql_dataset
-from phoenix.server.api.types.DatasetExample import DatasetExample
-from phoenix.server.api.types.Dimension import to_gql_dimension
-from phoenix.server.api.types.EmbeddingDimension import (
-    DEFAULT_CLUSTER_SELECTION_EPSILON,
-    DEFAULT_MIN_CLUSTER_SIZE,
-    DEFAULT_MIN_SAMPLES,
-    to_gql_embedding_dimension,
-)
-from phoenix.server.api.types.Event import create_event_id, unpack_event_id
-from phoenix.server.api.types.Experiment import Experiment
-from phoenix.server.api.types.ExperimentComparison import ExperimentComparison, RunComparisonItem
-from phoenix.server.api.types.ExperimentRun import ExperimentRun, to_gql_experiment_run
-from phoenix.server.api.types.Functionality import Functionality
-from phoenix.server.api.types.InferencesRole import AncillaryInferencesRole, InferencesRole
-from phoenix.server.api.types.Model import Model
-from phoenix.server.api.types.node import from_global_id, from_global_id_with_expected_type
-from phoenix.server.api.types.pagination import (
-    ConnectionArgs,
-    CursorString,
-    connection_from_list,
-)
-from phoenix.server.api.types.Project import Project
-from phoenix.server.api.types.SortDir import SortDir
-from phoenix.server.api.types.Span import Span, to_gql_span
-from phoenix.server.api.types.Trace import Trace
-
-
-@strawberry.type
-class Query:
-    @strawberry.field
-    async def projects(
-        self,
-        info: Info[Context, None],
-        first: Optional[int] = 50,
-        last: Optional[int] = UNSET,
-        after: Optional[CursorString] = UNSET,
-        before: Optional[CursorString] = UNSET,
-    ) -> Connection[Project]:
-        args = ConnectionArgs(
-            first=first,
-            after=after if isinstance(after, CursorString) else None,
-            last=last,
-            before=before if isinstance(before, CursorString) else None,
-        )
-        stmt = (
-            select(models.Project)
-            .outerjoin(
-                models.Experiment,
-                models.Project.name == models.Experiment.project_name,
-            )
-            .where(models.Experiment.project_name.is_(None))
-        )
-        async with info.context.db() as session:
-            projects = await session.stream_scalars(stmt)
-            data = [
-                Project(
-                    id_attr=project.id,
-                    name=project.name,
-                    gradient_start_color=project.gradient_start_color,
-                    gradient_end_color=project.gradient_end_color,
-                )
-                async for project in projects
-            ]
-        return connection_from_list(data=data, args=args)
-
-    @strawberry.field
-    async def datasets(
-        self,
-        info: Info[Context, None],
-        first: Optional[int] = 50,
-        last: Optional[int] = UNSET,
-        after: Optional[CursorString] = UNSET,
-        before: Optional[CursorString] = UNSET,
-        sort: Optional[DatasetSort] = UNSET,
-    ) -> Connection[Dataset]:
-        args = ConnectionArgs(
-            first=first,
-            after=after if isinstance(after, CursorString) else None,
-            last=last,
-            before=before if isinstance(before, CursorString) else None,
-        )
-        stmt = select(models.Dataset)
-        if sort:
-            sort_col = getattr(models.Dataset, sort.col.value)
-            stmt = stmt.order_by(sort_col.desc() if sort.dir is SortDir.desc else sort_col.asc())
-        async with info.context.db() as session:
-            datasets = await session.scalars(stmt)
-        return connection_from_list(
-            data=[to_gql_dataset(dataset) for dataset in datasets], args=args
-        )
-
-    @strawberry.field
-    async def compare_experiments(
-        self,
-        info: Info[Context, None],
-        experiment_ids: List[GlobalID],
-    ) -> List[ExperimentComparison]:
-        experiment_ids_ = [
-            from_global_id_with_expected_type(experiment_id, OrmExperiment.__name__)
-            for experiment_id in experiment_ids
-        ]
-        if len(set(experiment_ids_)) != len(experiment_ids_):
-            raise ValueError("Experiment IDs must be unique.")
-
-        async with info.context.db() as session:
-            validation_result = (
-                await session.execute(
-                    select(
-                        func.count(distinct(OrmVersion.dataset_id)),
-                        func.max(OrmVersion.dataset_id),
-                        func.max(OrmVersion.id),
-                        func.count(OrmExperiment.id),
-                    )
-                    .select_from(OrmVersion)
-                    .join(
-                        OrmExperiment,
-                        OrmExperiment.dataset_version_id == OrmVersion.id,
-                    )
-                    .where(
-                        OrmExperiment.id.in_(experiment_ids_),
-                    )
-                )
-            ).first()
-            if validation_result is None:
-                raise ValueError("No experiments could be found for input IDs.")
-
-            num_datasets, dataset_id, version_id, num_resolved_experiment_ids = validation_result
-            if num_datasets != 1:
-                raise ValueError("Experiments must belong to the same dataset.")
-            if num_resolved_experiment_ids != len(experiment_ids_):
-                raise ValueError("Unable to resolve one or more experiment IDs.")
-
-            revision_ids = (
-                select(func.max(OrmRevision.id))
-                .join(OrmExample, OrmExample.id == OrmRevision.dataset_example_id)
-                .where(
-                    and_(
-                        OrmRevision.dataset_version_id <= version_id,
-                        OrmExample.dataset_id == dataset_id,
-                    )
-                )
-                .group_by(OrmRevision.dataset_example_id)
-                .scalar_subquery()
-            )
-            examples = (
-                await session.scalars(
-                    select(OrmExample)
-                    .join(OrmRevision, OrmExample.id == OrmRevision.dataset_example_id)
-                    .where(
-                        and_(
-                            OrmRevision.id.in_(revision_ids),
-                            OrmRevision.revision_kind != "DELETE",
-                        )
-                    )
-                    .order_by(OrmRevision.dataset_example_id.desc())
-                )
-            ).all()
-
-            ExampleID: TypeAlias = int
-            ExperimentID: TypeAlias = int
-            runs: DefaultDict[ExampleID, DefaultDict[ExperimentID, List[OrmRun]]] = defaultdict(
-                lambda: defaultdict(list)
-            )
-            async for run in await session.stream_scalars(
-                select(OrmRun)
-                .where(
-                    and_(
-                        OrmRun.dataset_example_id.in_(example.id for example in examples),
-                        OrmRun.experiment_id.in_(experiment_ids_),
-                    )
-                )
-                .options(joinedload(OrmRun.trace).load_only(OrmTrace.trace_id))
-            ):
-                runs[run.dataset_example_id][run.experiment_id].append(run)
-
-            experiment_comparisons = []
-            for example in examples:
-                run_comparison_items = []
-                for experiment_id in experiment_ids_:
-                    run_comparison_items.append(
-                        RunComparisonItem(
-                            experiment_id=GlobalID(Experiment.__name__, str(experiment_id)),
-                            runs=[
-                                to_gql_experiment_run(run)
-                                for run in sorted(
-                                    runs[example.id][experiment_id], key=lambda run: run.id
-                                )
-                            ],
-                        )
-                    )
-                experiment_comparisons.append(
-                    ExperimentComparison(
-                        example=DatasetExample(
-                            id_attr=example.id,
-                            created_at=example.created_at,
-                            version_id=version_id,
-                        ),
-                        run_comparison_items=run_comparison_items,
-                    )
-                )
-            return experiment_comparisons
-
-    @strawberry.field
-    async def functionality(self, info: Info[Context, None]) -> "Functionality":
-        has_model_inferences = not info.context.model.is_empty
-        async with info.context.db() as session:
-            has_traces = (await session.scalar(select(models.Trace).limit(1))) is not None
-        return Functionality(
-            model_inferences=has_model_inferences,
-            tracing=has_traces,
-        )
-
-    @strawberry.field
-    def model(self) -> Model:
-        return Model()
-
-    @strawberry.field
-    async def node(self, id: GlobalID, info: Info[Context, None]) -> Node:
-        type_name, node_id = from_global_id(id)
-        if type_name == "Dimension":
-            dimension = info.context.model.scalar_dimensions[node_id]
-            return to_gql_dimension(node_id, dimension)
-        elif type_name == "EmbeddingDimension":
-            embedding_dimension = info.context.model.embedding_dimensions[node_id]
-            return to_gql_embedding_dimension(node_id, embedding_dimension)
-        elif type_name == "Project":
-            project_stmt = select(
-                models.Project.id,
-                models.Project.name,
-                models.Project.gradient_start_color,
-                models.Project.gradient_end_color,
-            ).where(models.Project.id == node_id)
-            async with info.context.db() as session:
-                project = (await session.execute(project_stmt)).first()
-            if project is None:
-                raise ValueError(f"Unknown project: {id}")
-            return Project(
-                id_attr=project.id,
-                name=project.name,
-                gradient_start_color=project.gradient_start_color,
-                gradient_end_color=project.gradient_end_color,
-            )
-        elif type_name == "Trace":
-            trace_stmt = select(
-                models.Trace.id,
-                models.Trace.project_rowid,
-            ).where(models.Trace.id == node_id)
-            async with info.context.db() as session:
-                trace = (await session.execute(trace_stmt)).first()
-            if trace is None:
-                raise ValueError(f"Unknown trace: {id}")
-            return Trace(
-                id_attr=trace.id, trace_id=trace.trace_id, project_rowid=trace.project_rowid
-            )
-        elif type_name == Span.__name__:
-            span_stmt = (
-                select(models.Span)
-                .options(
-                    joinedload(models.Span.trace, innerjoin=True).load_only(models.Trace.trace_id)
-                )
-                .where(models.Span.id == node_id)
-            )
-            async with info.context.db() as session:
-                span = await session.scalar(span_stmt)
-            if span is None:
-                raise ValueError(f"Unknown span: {id}")
-            return to_gql_span(span)
-        elif type_name == Dataset.__name__:
-            dataset_stmt = select(models.Dataset).where(models.Dataset.id == node_id)
-            async with info.context.db() as session:
-                if (dataset := await session.scalar(dataset_stmt)) is None:
-                    raise ValueError(f"Unknown dataset: {id}")
-                return to_gql_dataset(dataset)
-        elif type_name == DatasetExample.__name__:
-            example_id = node_id
-            latest_revision_id = (
-                select(func.max(models.DatasetExampleRevision.id))
-                .where(models.DatasetExampleRevision.dataset_example_id == example_id)
-                .scalar_subquery()
-            )
-            async with info.context.db() as session:
-                example = await session.scalar(
-                    select(models.DatasetExample)
-                    .join(
-                        models.DatasetExampleRevision,
-                        onclause=models.DatasetExampleRevision.dataset_example_id
-                        == models.DatasetExample.id,
-                    )
-                    .where(
-                        and_(
-                            models.DatasetExample.id == example_id,
-                            models.DatasetExampleRevision.id == latest_revision_id,
-                            models.DatasetExampleRevision.revision_kind != "DELETE",
-                        )
-                    )
-                )
-            if not example:
-                raise ValueError(f"Unknown dataset example: {id}")
-            return DatasetExample(
-                id_attr=example.id,
-                created_at=example.created_at,
-            )
-        elif type_name == Experiment.__name__:
-            async with info.context.db() as session:
-                experiment = await session.scalar(
-                    select(models.Experiment).where(models.Experiment.id == node_id)
-                )
-            if not experiment:
-                raise ValueError(f"Unknown experiment: {id}")
-            return Experiment(
-                id_attr=experiment.id,
-                name=experiment.name,
-                project_name=experiment.project_name,
-                description=experiment.description,
-                created_at=experiment.created_at,
-                updated_at=experiment.updated_at,
-                metadata=experiment.metadata_,
-            )
-        elif type_name == ExperimentRun.__name__:
-            async with info.context.db() as session:
-                if not (
-                    run := await session.scalar(
-                        select(models.ExperimentRun)
-                        .where(models.ExperimentRun.id == node_id)
-                        .options(
-                            joinedload(models.ExperimentRun.trace).load_only(models.Trace.trace_id)
-                        )
-                    )
-                ):
-                    raise ValueError(f"Unknown experiment run: {id}")
-                return to_gql_experiment_run(run)
-        raise Exception(f"Unknown node type: {type_name}")
-
-    @strawberry.field
-    def clusters(
-        self,
-        clusters: List[ClusterInput],
-    ) -> List[Cluster]:
-        clustered_events: Dict[str, Set[ID]] = defaultdict(set)
-        for i, cluster in enumerate(clusters):
-            clustered_events[cluster.id or str(i)].update(cluster.event_ids)
-        return to_gql_clusters(
-            clustered_events=clustered_events,
-        )
-
-    @strawberry.field
-    def hdbscan_clustering(
-        self,
-        info: Info[Context, None],
-        event_ids: Annotated[
-            List[ID],
-            strawberry.argument(
-                description="Event ID of the coordinates",
-            ),
-        ],
-        coordinates_2d: Annotated[
-            Optional[List[InputCoordinate2D]],
-            strawberry.argument(
-                description="Point coordinates. Must be either 2D or 3D.",
-            ),
-        ] = UNSET,
-        coordinates_3d: Annotated[
-            Optional[List[InputCoordinate3D]],
-            strawberry.argument(
-                description="Point coordinates. Must be either 2D or 3D.",
-            ),
-        ] = UNSET,
-        min_cluster_size: Annotated[
-            int,
-            strawberry.argument(
-                description="HDBSCAN minimum cluster size",
-            ),
-        ] = DEFAULT_MIN_CLUSTER_SIZE,
-        cluster_min_samples: Annotated[
-            int,
-            strawberry.argument(
-                description="HDBSCAN minimum samples",
-            ),
-        ] = DEFAULT_MIN_SAMPLES,
-        cluster_selection_epsilon: Annotated[
-            float,
-            strawberry.argument(
-                description="HDBSCAN cluster selection epsilon",
-            ),
-        ] = DEFAULT_CLUSTER_SELECTION_EPSILON,
-    ) -> List[Cluster]:
-        coordinates_3d = ensure_list(coordinates_3d)
-        coordinates_2d = ensure_list(coordinates_2d)
-
-        if len(coordinates_3d) > 0 and len(coordinates_2d) > 0:
-            raise ValueError("must specify only one of 2D or 3D coordinates")
-
-        if len(coordinates_3d) > 0:
-            coordinates = list(
-                map(
-                    lambda coord: np.array(
-                        [coord.x, coord.y, coord.z],
-                    ),
-                    coordinates_3d,
-                )
-            )
-        else:
-            coordinates = list(
-                map(
-                    lambda coord: np.array(
-                        [coord.x, coord.y],
-                    ),
-                    coordinates_2d,
-                )
-            )
-
-        if len(event_ids) != len(coordinates):
-            raise ValueError(
-                f"length mismatch between "
-                f"event_ids ({len(event_ids)}) "
-                f"and coordinates ({len(coordinates)})"
-            )
-
-        if len(event_ids) == 0:
-            return []
-
-        grouped_event_ids: Dict[
-            Union[InferencesRole, AncillaryInferencesRole],
-            List[ID],
-        ] = defaultdict(list)
-        grouped_coordinates: Dict[
-            Union[InferencesRole, AncillaryInferencesRole],
-            List[npt.NDArray[np.float64]],
-        ] = defaultdict(list)
-
-        for event_id, coordinate in zip(event_ids, coordinates):
-            row_id, inferences_role = unpack_event_id(event_id)
-            grouped_coordinates[inferences_role].append(coordinate)
-            grouped_event_ids[inferences_role].append(create_event_id(row_id, inferences_role))
-
-        stacked_event_ids = (
-            grouped_event_ids[InferencesRole.primary]
-            + grouped_event_ids[InferencesRole.reference]
-            + grouped_event_ids[AncillaryInferencesRole.corpus]
-        )
-        stacked_coordinates = np.stack(
-            grouped_coordinates[InferencesRole.primary]
-            + grouped_coordinates[InferencesRole.reference]
-            + grouped_coordinates[AncillaryInferencesRole.corpus]
-        )
-
-        clusters = Hdbscan(
-            min_cluster_size=min_cluster_size,
-            min_samples=cluster_min_samples,
-            cluster_selection_epsilon=cluster_selection_epsilon,
-        ).find_clusters(stacked_coordinates)
-
-        clustered_events = {
-            str(i): {stacked_event_ids[row_idx] for row_idx in cluster}
-            for i, cluster in enumerate(clusters)
-        }
-
-        return to_gql_clusters(
-            clustered_events=clustered_events,
-        )
phoenix/server/api/routers/v1/dataset_examples.py
DELETED
@@ -1,178 +0,0 @@
-from sqlalchemy import and_, func, select
-from starlette.requests import Request
-from starlette.responses import JSONResponse, Response
-from starlette.status import HTTP_404_NOT_FOUND
-from strawberry.relay import GlobalID
-
-from phoenix.db.models import Dataset, DatasetExample, DatasetExampleRevision, DatasetVersion
-
-
-async def list_dataset_examples(request: Request) -> Response:
-    """
-    summary: Get dataset examples by dataset ID
-    operationId: getDatasetExamples
-    tags:
-      - datasets
-    parameters:
-      - in: path
-        name: id
-        required: true
-        schema:
-          type: string
-        description: Dataset ID
-      - in: query
-        name: version-id
-        schema:
-          type: string
-        description: Dataset version ID. If omitted, returns the latest version.
-    responses:
-      200:
-        description: Success
-        content:
-          application/json:
-            schema:
-              type: object
-              properties:
-                data:
-                  type: object
-                  properties:
-                    dataset_id:
-                      type: string
-                      description: ID of the dataset
-                    version_id:
-                      type: string
-                      description: ID of the version
-                    examples:
-                      type: array
-                      items:
-                        type: object
-                        properties:
-                          id:
-                            type: string
-                            description: ID of the dataset example
-                          input:
-                            type: object
-                            description: Input data of the example
-                          output:
-                            type: object
-                            description: Output data of the example
-                          metadata:
-                            type: object
-                            description: Metadata of the example
-                          updated_at:
-                            type: string
-                            format: date-time
-                            description: ISO formatted timestamp of when the example was updated
-                        required:
-                          - id
-                          - input
-                          - output
-                          - metadata
-                          - updated_at
-                  required:
-                    - dataset_id
-                    - version_id
-                    - examples
-      403:
-        description: Forbidden
-      404:
-        description: Dataset does not exist.
-    """
-    dataset_id = GlobalID.from_id(request.path_params["id"])
-    raw_version_id = request.query_params.get("version-id")
-    version_id = GlobalID.from_id(raw_version_id) if raw_version_id else None
-
-    if (dataset_type := dataset_id.type_name) != "Dataset":
-        return Response(
-            content=f"ID {dataset_id} refers to a {dataset_type}", status_code=HTTP_404_NOT_FOUND
-        )
-
-    if version_id and (version_type := version_id.type_name) != "DatasetVersion":
-        return Response(
-            content=f"ID {version_id} refers to a {version_type}", status_code=HTTP_404_NOT_FOUND
-        )
-
-    async with request.app.state.db() as session:
-        if (
-            resolved_dataset_id := await session.scalar(
-                select(Dataset.id).where(Dataset.id == int(dataset_id.node_id))
-            )
-        ) is None:
-            return Response(
-                content=f"No dataset with id {dataset_id} can be found.",
-                status_code=HTTP_404_NOT_FOUND,
-            )
-
-        # Subquery to find the maximum created_at for each dataset_example_id
-        # timestamp tiebreaks are resolved by the largest id
-        partial_subquery = select(
-            func.max(DatasetExampleRevision.id).label("max_id"),
-        ).group_by(DatasetExampleRevision.dataset_example_id)
-
-        if version_id:
-            if (
-                resolved_version_id := await session.scalar(
-                    select(DatasetVersion.id).where(
-                        and_(
-                            DatasetVersion.dataset_id == resolved_dataset_id,
-                            DatasetVersion.id == int(version_id.node_id),
-                        )
-                    )
-                )
-            ) is None:
-                return Response(
-                    content=f"No dataset version with id {version_id} can be found.",
-                    status_code=HTTP_404_NOT_FOUND,
-                )
-            # if a version_id is provided, filter the subquery to only include revisions from that
-            partial_subquery = partial_subquery.filter(
-                DatasetExampleRevision.dataset_version_id <= resolved_version_id
-            )
-        else:
-            if (
-                resolved_version_id := await session.scalar(
-                    select(func.max(DatasetVersion.id)).where(
-                        DatasetVersion.dataset_id == resolved_dataset_id
-                    )
-                )
-            ) is None:
-                return Response(
-                    content="Dataset has no versions.",
-                    status_code=HTTP_404_NOT_FOUND,
-                )
-
-        subquery = partial_subquery.subquery()
-        # Query for the most recent example revisions that are not deleted
-        query = (
-            select(DatasetExample, DatasetExampleRevision)
-            .join(
-                DatasetExampleRevision,
-                DatasetExample.id == DatasetExampleRevision.dataset_example_id,
-            )
-            .join(
-                subquery,
-                (subquery.c.max_id == DatasetExampleRevision.id),
-            )
-            .filter(DatasetExample.dataset_id == resolved_dataset_id)
-            .filter(DatasetExampleRevision.revision_kind != "DELETE")
-            .order_by(DatasetExample.id.asc())
-        )
-        examples = [
-            {
-                "id": str(GlobalID("DatasetExample", str(example.id))),
-                "input": revision.input,
-                "output": revision.output,
-                "metadata": revision.metadata_,
-                "updated_at": revision.created_at.isoformat(),
-            }
-            async for example, revision in await session.stream(query)
-        ]
-        return JSONResponse(
-            {
-                "data": {
-                    "dataset_id": str(GlobalID("Dataset", str(resolved_dataset_id))),
-                    "version_id": str(GlobalID("DatasetVersion", str(resolved_version_id))),
-                    "examples": examples,
-                }
-            }
-        )