arize-phoenix 4.4.3__py3-none-any.whl → 4.4.4rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of arize-phoenix might be problematic; see the package's registry page for more details (the original link was not preserved in this copy).
- {arize_phoenix-4.4.3.dist-info → arize_phoenix-4.4.4rc1.dist-info}/METADATA +4 -4
- {arize_phoenix-4.4.3.dist-info → arize_phoenix-4.4.4rc1.dist-info}/RECORD +111 -55
- {arize_phoenix-4.4.3.dist-info → arize_phoenix-4.4.4rc1.dist-info}/WHEEL +1 -1
- phoenix/__init__.py +0 -27
- phoenix/config.py +21 -7
- phoenix/core/model.py +25 -25
- phoenix/core/model_schema.py +64 -62
- phoenix/core/model_schema_adapter.py +27 -25
- phoenix/datasets/__init__.py +0 -0
- phoenix/datasets/evaluators.py +275 -0
- phoenix/datasets/experiments.py +469 -0
- phoenix/datasets/tracing.py +66 -0
- phoenix/datasets/types.py +212 -0
- phoenix/db/bulk_inserter.py +54 -14
- phoenix/db/insertion/dataset.py +234 -0
- phoenix/db/insertion/evaluation.py +6 -6
- phoenix/db/insertion/helpers.py +13 -2
- phoenix/db/migrations/types.py +29 -0
- phoenix/db/migrations/versions/10460e46d750_datasets.py +291 -0
- phoenix/db/migrations/versions/cf03bd6bae1d_init.py +2 -28
- phoenix/db/models.py +230 -3
- phoenix/inferences/fixtures.py +23 -23
- phoenix/inferences/inferences.py +7 -7
- phoenix/inferences/validation.py +1 -1
- phoenix/server/api/context.py +16 -0
- phoenix/server/api/dataloaders/__init__.py +16 -0
- phoenix/server/api/dataloaders/dataset_example_revisions.py +100 -0
- phoenix/server/api/dataloaders/dataset_example_spans.py +43 -0
- phoenix/server/api/dataloaders/experiment_annotation_summaries.py +85 -0
- phoenix/server/api/dataloaders/experiment_error_rates.py +43 -0
- phoenix/server/api/dataloaders/experiment_sequence_number.py +49 -0
- phoenix/server/api/dataloaders/project_by_name.py +31 -0
- phoenix/server/api/dataloaders/span_descendants.py +2 -3
- phoenix/server/api/dataloaders/span_projects.py +33 -0
- phoenix/server/api/dataloaders/trace_row_ids.py +39 -0
- phoenix/server/api/helpers/dataset_helpers.py +178 -0
- phoenix/server/api/input_types/AddExamplesToDatasetInput.py +16 -0
- phoenix/server/api/input_types/AddSpansToDatasetInput.py +14 -0
- phoenix/server/api/input_types/CreateDatasetInput.py +12 -0
- phoenix/server/api/input_types/DatasetExampleInput.py +14 -0
- phoenix/server/api/input_types/DatasetSort.py +17 -0
- phoenix/server/api/input_types/DatasetVersionSort.py +16 -0
- phoenix/server/api/input_types/DeleteDatasetExamplesInput.py +13 -0
- phoenix/server/api/input_types/DeleteDatasetInput.py +7 -0
- phoenix/server/api/input_types/DeleteExperimentsInput.py +9 -0
- phoenix/server/api/input_types/PatchDatasetExamplesInput.py +35 -0
- phoenix/server/api/input_types/PatchDatasetInput.py +14 -0
- phoenix/server/api/mutations/__init__.py +13 -0
- phoenix/server/api/mutations/auth.py +11 -0
- phoenix/server/api/mutations/dataset_mutations.py +520 -0
- phoenix/server/api/mutations/experiment_mutations.py +65 -0
- phoenix/server/api/{types/ExportEventsMutation.py → mutations/export_events_mutations.py} +17 -14
- phoenix/server/api/mutations/project_mutations.py +42 -0
- phoenix/server/api/openapi/__init__.py +0 -0
- phoenix/server/api/openapi/main.py +6 -0
- phoenix/server/api/openapi/schema.py +15 -0
- phoenix/server/api/queries.py +503 -0
- phoenix/server/api/routers/v1/__init__.py +77 -2
- phoenix/server/api/routers/v1/dataset_examples.py +178 -0
- phoenix/server/api/routers/v1/datasets.py +861 -0
- phoenix/server/api/routers/v1/evaluations.py +4 -2
- phoenix/server/api/routers/v1/experiment_evaluations.py +65 -0
- phoenix/server/api/routers/v1/experiment_runs.py +108 -0
- phoenix/server/api/routers/v1/experiments.py +174 -0
- phoenix/server/api/routers/v1/spans.py +3 -1
- phoenix/server/api/routers/v1/traces.py +1 -4
- phoenix/server/api/schema.py +2 -303
- phoenix/server/api/types/AnnotatorKind.py +10 -0
- phoenix/server/api/types/Cluster.py +19 -19
- phoenix/server/api/types/CreateDatasetPayload.py +8 -0
- phoenix/server/api/types/Dataset.py +282 -63
- phoenix/server/api/types/DatasetExample.py +85 -0
- phoenix/server/api/types/DatasetExampleRevision.py +34 -0
- phoenix/server/api/types/DatasetVersion.py +14 -0
- phoenix/server/api/types/Dimension.py +30 -29
- phoenix/server/api/types/EmbeddingDimension.py +40 -34
- phoenix/server/api/types/Event.py +16 -16
- phoenix/server/api/types/ExampleRevisionInterface.py +14 -0
- phoenix/server/api/types/Experiment.py +135 -0
- phoenix/server/api/types/ExperimentAnnotationSummary.py +13 -0
- phoenix/server/api/types/ExperimentComparison.py +19 -0
- phoenix/server/api/types/ExperimentRun.py +91 -0
- phoenix/server/api/types/ExperimentRunAnnotation.py +57 -0
- phoenix/server/api/types/Inferences.py +80 -0
- phoenix/server/api/types/InferencesRole.py +23 -0
- phoenix/server/api/types/Model.py +43 -42
- phoenix/server/api/types/Project.py +26 -12
- phoenix/server/api/types/Span.py +78 -2
- phoenix/server/api/types/TimeSeries.py +6 -6
- phoenix/server/api/types/Trace.py +15 -4
- phoenix/server/api/types/UMAPPoints.py +1 -1
- phoenix/server/api/types/node.py +5 -111
- phoenix/server/api/types/pagination.py +10 -52
- phoenix/server/app.py +99 -49
- phoenix/server/main.py +49 -27
- phoenix/server/openapi/docs.py +3 -0
- phoenix/server/static/index.js +2246 -1368
- phoenix/server/templates/index.html +1 -0
- phoenix/services.py +15 -15
- phoenix/session/client.py +316 -21
- phoenix/session/session.py +47 -37
- phoenix/trace/exporter.py +14 -9
- phoenix/trace/fixtures.py +133 -7
- phoenix/trace/span_evaluations.py +3 -3
- phoenix/trace/trace_dataset.py +6 -6
- phoenix/utilities/json.py +61 -0
- phoenix/utilities/re.py +50 -0
- phoenix/version.py +1 -1
- phoenix/server/api/types/DatasetRole.py +0 -23
- {arize_phoenix-4.4.3.dist-info → arize_phoenix-4.4.4rc1.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-4.4.3.dist-info → arize_phoenix-4.4.4rc1.dist-info}/licenses/LICENSE +0 -0
- /phoenix/server/api/{helpers.py → helpers/__init__.py} +0 -0
phoenix/server/api/schema.py
CHANGED
|
@@ -1,308 +1,7 @@
|
|
|
1
|
-
from collections import defaultdict
|
|
2
|
-
from typing import Dict, List, Optional, Set, Union
|
|
3
|
-
|
|
4
|
-
import numpy as np
|
|
5
|
-
import numpy.typing as npt
|
|
6
1
|
import strawberry
|
|
7
|
-
from sqlalchemy import delete, select
|
|
8
|
-
from sqlalchemy.orm import contains_eager, load_only
|
|
9
|
-
from strawberry import ID, UNSET
|
|
10
|
-
from strawberry.types import Info
|
|
11
|
-
from typing_extensions import Annotated
|
|
12
|
-
|
|
13
|
-
from phoenix.config import DEFAULT_PROJECT_NAME
|
|
14
|
-
from phoenix.db import models
|
|
15
|
-
from phoenix.db.insertion.span import ClearProjectSpansEvent
|
|
16
|
-
from phoenix.pointcloud.clustering import Hdbscan
|
|
17
|
-
from phoenix.server.api.context import Context
|
|
18
|
-
from phoenix.server.api.helpers import ensure_list
|
|
19
|
-
from phoenix.server.api.input_types.ClusterInput import ClusterInput
|
|
20
|
-
from phoenix.server.api.input_types.Coordinates import (
|
|
21
|
-
InputCoordinate2D,
|
|
22
|
-
InputCoordinate3D,
|
|
23
|
-
)
|
|
24
|
-
from phoenix.server.api.types.Cluster import Cluster, to_gql_clusters
|
|
25
|
-
from phoenix.server.api.types.DatasetRole import AncillaryDatasetRole, DatasetRole
|
|
26
|
-
from phoenix.server.api.types.Dimension import to_gql_dimension
|
|
27
|
-
from phoenix.server.api.types.EmbeddingDimension import (
|
|
28
|
-
DEFAULT_CLUSTER_SELECTION_EPSILON,
|
|
29
|
-
DEFAULT_MIN_CLUSTER_SIZE,
|
|
30
|
-
DEFAULT_MIN_SAMPLES,
|
|
31
|
-
to_gql_embedding_dimension,
|
|
32
|
-
)
|
|
33
|
-
from phoenix.server.api.types.Event import create_event_id, unpack_event_id
|
|
34
|
-
from phoenix.server.api.types.ExportEventsMutation import ExportEventsMutation
|
|
35
|
-
from phoenix.server.api.types.Functionality import Functionality
|
|
36
|
-
from phoenix.server.api.types.Model import Model
|
|
37
|
-
from phoenix.server.api.types.node import (
|
|
38
|
-
GlobalID,
|
|
39
|
-
Node,
|
|
40
|
-
from_global_id,
|
|
41
|
-
from_global_id_with_expected_type,
|
|
42
|
-
)
|
|
43
|
-
from phoenix.server.api.types.pagination import (
|
|
44
|
-
Connection,
|
|
45
|
-
ConnectionArgs,
|
|
46
|
-
CursorString,
|
|
47
|
-
connection_from_list,
|
|
48
|
-
)
|
|
49
|
-
from phoenix.server.api.types.Project import Project
|
|
50
|
-
from phoenix.server.api.types.Span import to_gql_span
|
|
51
|
-
from phoenix.server.api.types.Trace import Trace
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
@strawberry.type
|
|
55
|
-
class Query:
|
|
56
|
-
@strawberry.field
|
|
57
|
-
async def projects(
|
|
58
|
-
self,
|
|
59
|
-
info: Info[Context, None],
|
|
60
|
-
first: Optional[int] = 50,
|
|
61
|
-
last: Optional[int] = UNSET,
|
|
62
|
-
after: Optional[CursorString] = UNSET,
|
|
63
|
-
before: Optional[CursorString] = UNSET,
|
|
64
|
-
) -> Connection[Project]:
|
|
65
|
-
args = ConnectionArgs(
|
|
66
|
-
first=first,
|
|
67
|
-
after=after if isinstance(after, CursorString) else None,
|
|
68
|
-
last=last,
|
|
69
|
-
before=before if isinstance(before, CursorString) else None,
|
|
70
|
-
)
|
|
71
|
-
async with info.context.db() as session:
|
|
72
|
-
projects = await session.scalars(select(models.Project))
|
|
73
|
-
data = [
|
|
74
|
-
Project(
|
|
75
|
-
id_attr=project.id,
|
|
76
|
-
name=project.name,
|
|
77
|
-
gradient_start_color=project.gradient_start_color,
|
|
78
|
-
gradient_end_color=project.gradient_end_color,
|
|
79
|
-
)
|
|
80
|
-
for project in projects
|
|
81
|
-
]
|
|
82
|
-
return connection_from_list(data=data, args=args)
|
|
83
|
-
|
|
84
|
-
@strawberry.field
|
|
85
|
-
async def functionality(self, info: Info[Context, None]) -> "Functionality":
|
|
86
|
-
has_model_inferences = not info.context.model.is_empty
|
|
87
|
-
async with info.context.db() as session:
|
|
88
|
-
has_traces = (await session.scalar(select(models.Trace).limit(1))) is not None
|
|
89
|
-
return Functionality(
|
|
90
|
-
model_inferences=has_model_inferences,
|
|
91
|
-
tracing=has_traces,
|
|
92
|
-
)
|
|
93
|
-
|
|
94
|
-
@strawberry.field
|
|
95
|
-
def model(self) -> Model:
|
|
96
|
-
return Model()
|
|
97
|
-
|
|
98
|
-
@strawberry.field
|
|
99
|
-
async def node(self, id: GlobalID, info: Info[Context, None]) -> Node:
|
|
100
|
-
type_name, node_id = from_global_id(str(id))
|
|
101
|
-
if type_name == "Dimension":
|
|
102
|
-
dimension = info.context.model.scalar_dimensions[node_id]
|
|
103
|
-
return to_gql_dimension(node_id, dimension)
|
|
104
|
-
elif type_name == "EmbeddingDimension":
|
|
105
|
-
embedding_dimension = info.context.model.embedding_dimensions[node_id]
|
|
106
|
-
return to_gql_embedding_dimension(node_id, embedding_dimension)
|
|
107
|
-
elif type_name == "Project":
|
|
108
|
-
project_stmt = select(
|
|
109
|
-
models.Project.id,
|
|
110
|
-
models.Project.name,
|
|
111
|
-
models.Project.gradient_start_color,
|
|
112
|
-
models.Project.gradient_end_color,
|
|
113
|
-
).where(models.Project.id == node_id)
|
|
114
|
-
async with info.context.db() as session:
|
|
115
|
-
project = (await session.execute(project_stmt)).first()
|
|
116
|
-
if project is None:
|
|
117
|
-
raise ValueError(f"Unknown project: {id}")
|
|
118
|
-
return Project(
|
|
119
|
-
id_attr=project.id,
|
|
120
|
-
name=project.name,
|
|
121
|
-
gradient_start_color=project.gradient_start_color,
|
|
122
|
-
gradient_end_color=project.gradient_end_color,
|
|
123
|
-
)
|
|
124
|
-
elif type_name == "Trace":
|
|
125
|
-
trace_stmt = select(models.Trace.id).where(models.Trace.id == node_id)
|
|
126
|
-
async with info.context.db() as session:
|
|
127
|
-
id_attr = await session.scalar(trace_stmt)
|
|
128
|
-
if id_attr is None:
|
|
129
|
-
raise ValueError(f"Unknown trace: {id}")
|
|
130
|
-
return Trace(id_attr=id_attr)
|
|
131
|
-
elif type_name == "Span":
|
|
132
|
-
span_stmt = (
|
|
133
|
-
select(models.Span)
|
|
134
|
-
.join(models.Trace)
|
|
135
|
-
.options(contains_eager(models.Span.trace))
|
|
136
|
-
.where(models.Span.id == node_id)
|
|
137
|
-
)
|
|
138
|
-
async with info.context.db() as session:
|
|
139
|
-
span = await session.scalar(span_stmt)
|
|
140
|
-
if span is None:
|
|
141
|
-
raise ValueError(f"Unknown span: {id}")
|
|
142
|
-
return to_gql_span(span)
|
|
143
|
-
raise Exception(f"Unknown node type: {type_name}")
|
|
144
|
-
|
|
145
|
-
@strawberry.field
|
|
146
|
-
def clusters(
|
|
147
|
-
self,
|
|
148
|
-
clusters: List[ClusterInput],
|
|
149
|
-
) -> List[Cluster]:
|
|
150
|
-
clustered_events: Dict[str, Set[ID]] = defaultdict(set)
|
|
151
|
-
for i, cluster in enumerate(clusters):
|
|
152
|
-
clustered_events[cluster.id or str(i)].update(cluster.event_ids)
|
|
153
|
-
return to_gql_clusters(
|
|
154
|
-
clustered_events=clustered_events,
|
|
155
|
-
)
|
|
156
|
-
|
|
157
|
-
@strawberry.field
|
|
158
|
-
def hdbscan_clustering(
|
|
159
|
-
self,
|
|
160
|
-
info: Info[Context, None],
|
|
161
|
-
event_ids: Annotated[
|
|
162
|
-
List[ID],
|
|
163
|
-
strawberry.argument(
|
|
164
|
-
description="Event ID of the coordinates",
|
|
165
|
-
),
|
|
166
|
-
],
|
|
167
|
-
coordinates_2d: Annotated[
|
|
168
|
-
Optional[List[InputCoordinate2D]],
|
|
169
|
-
strawberry.argument(
|
|
170
|
-
description="Point coordinates. Must be either 2D or 3D.",
|
|
171
|
-
),
|
|
172
|
-
] = UNSET,
|
|
173
|
-
coordinates_3d: Annotated[
|
|
174
|
-
Optional[List[InputCoordinate3D]],
|
|
175
|
-
strawberry.argument(
|
|
176
|
-
description="Point coordinates. Must be either 2D or 3D.",
|
|
177
|
-
),
|
|
178
|
-
] = UNSET,
|
|
179
|
-
min_cluster_size: Annotated[
|
|
180
|
-
int,
|
|
181
|
-
strawberry.argument(
|
|
182
|
-
description="HDBSCAN minimum cluster size",
|
|
183
|
-
),
|
|
184
|
-
] = DEFAULT_MIN_CLUSTER_SIZE,
|
|
185
|
-
cluster_min_samples: Annotated[
|
|
186
|
-
int,
|
|
187
|
-
strawberry.argument(
|
|
188
|
-
description="HDBSCAN minimum samples",
|
|
189
|
-
),
|
|
190
|
-
] = DEFAULT_MIN_SAMPLES,
|
|
191
|
-
cluster_selection_epsilon: Annotated[
|
|
192
|
-
float,
|
|
193
|
-
strawberry.argument(
|
|
194
|
-
description="HDBSCAN cluster selection epsilon",
|
|
195
|
-
),
|
|
196
|
-
] = DEFAULT_CLUSTER_SELECTION_EPSILON,
|
|
197
|
-
) -> List[Cluster]:
|
|
198
|
-
coordinates_3d = ensure_list(coordinates_3d)
|
|
199
|
-
coordinates_2d = ensure_list(coordinates_2d)
|
|
200
|
-
|
|
201
|
-
if len(coordinates_3d) > 0 and len(coordinates_2d) > 0:
|
|
202
|
-
raise ValueError("must specify only one of 2D or 3D coordinates")
|
|
203
|
-
|
|
204
|
-
if len(coordinates_3d) > 0:
|
|
205
|
-
coordinates = list(
|
|
206
|
-
map(
|
|
207
|
-
lambda coord: np.array(
|
|
208
|
-
[coord.x, coord.y, coord.z],
|
|
209
|
-
),
|
|
210
|
-
coordinates_3d,
|
|
211
|
-
)
|
|
212
|
-
)
|
|
213
|
-
else:
|
|
214
|
-
coordinates = list(
|
|
215
|
-
map(
|
|
216
|
-
lambda coord: np.array(
|
|
217
|
-
[coord.x, coord.y],
|
|
218
|
-
),
|
|
219
|
-
coordinates_2d,
|
|
220
|
-
)
|
|
221
|
-
)
|
|
222
|
-
|
|
223
|
-
if len(event_ids) != len(coordinates):
|
|
224
|
-
raise ValueError(
|
|
225
|
-
f"length mismatch between "
|
|
226
|
-
f"event_ids ({len(event_ids)}) "
|
|
227
|
-
f"and coordinates ({len(coordinates)})"
|
|
228
|
-
)
|
|
229
|
-
|
|
230
|
-
if len(event_ids) == 0:
|
|
231
|
-
return []
|
|
232
|
-
|
|
233
|
-
grouped_event_ids: Dict[
|
|
234
|
-
Union[DatasetRole, AncillaryDatasetRole],
|
|
235
|
-
List[ID],
|
|
236
|
-
] = defaultdict(list)
|
|
237
|
-
grouped_coordinates: Dict[
|
|
238
|
-
Union[DatasetRole, AncillaryDatasetRole],
|
|
239
|
-
List[npt.NDArray[np.float64]],
|
|
240
|
-
] = defaultdict(list)
|
|
241
|
-
|
|
242
|
-
for event_id, coordinate in zip(event_ids, coordinates):
|
|
243
|
-
row_id, dataset_role = unpack_event_id(event_id)
|
|
244
|
-
grouped_coordinates[dataset_role].append(coordinate)
|
|
245
|
-
grouped_event_ids[dataset_role].append(create_event_id(row_id, dataset_role))
|
|
246
|
-
|
|
247
|
-
stacked_event_ids = (
|
|
248
|
-
grouped_event_ids[DatasetRole.primary]
|
|
249
|
-
+ grouped_event_ids[DatasetRole.reference]
|
|
250
|
-
+ grouped_event_ids[AncillaryDatasetRole.corpus]
|
|
251
|
-
)
|
|
252
|
-
stacked_coordinates = np.stack(
|
|
253
|
-
grouped_coordinates[DatasetRole.primary]
|
|
254
|
-
+ grouped_coordinates[DatasetRole.reference]
|
|
255
|
-
+ grouped_coordinates[AncillaryDatasetRole.corpus]
|
|
256
|
-
)
|
|
257
|
-
|
|
258
|
-
clusters = Hdbscan(
|
|
259
|
-
min_cluster_size=min_cluster_size,
|
|
260
|
-
min_samples=cluster_min_samples,
|
|
261
|
-
cluster_selection_epsilon=cluster_selection_epsilon,
|
|
262
|
-
).find_clusters(stacked_coordinates)
|
|
263
|
-
|
|
264
|
-
clustered_events = {
|
|
265
|
-
str(i): {stacked_event_ids[row_idx] for row_idx in cluster}
|
|
266
|
-
for i, cluster in enumerate(clusters)
|
|
267
|
-
}
|
|
268
|
-
|
|
269
|
-
return to_gql_clusters(
|
|
270
|
-
clustered_events=clustered_events,
|
|
271
|
-
)
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
@strawberry.type
|
|
275
|
-
class Mutation(ExportEventsMutation):
|
|
276
|
-
@strawberry.mutation
|
|
277
|
-
async def delete_project(self, info: Info[Context, None], id: GlobalID) -> Query:
|
|
278
|
-
if info.context.read_only:
|
|
279
|
-
return Query()
|
|
280
|
-
node_id = from_global_id_with_expected_type(str(id), "Project")
|
|
281
|
-
async with info.context.db() as session:
|
|
282
|
-
project = await session.scalar(
|
|
283
|
-
select(models.Project)
|
|
284
|
-
.where(models.Project.id == node_id)
|
|
285
|
-
.options(load_only(models.Project.name))
|
|
286
|
-
)
|
|
287
|
-
if project is None:
|
|
288
|
-
raise ValueError(f"Unknown project: {id}")
|
|
289
|
-
if project.name == DEFAULT_PROJECT_NAME:
|
|
290
|
-
raise ValueError(f"Cannot delete the {DEFAULT_PROJECT_NAME} project")
|
|
291
|
-
await session.delete(project)
|
|
292
|
-
return Query()
|
|
293
|
-
|
|
294
|
-
@strawberry.mutation
|
|
295
|
-
async def clear_project(self, info: Info[Context, None], id: GlobalID) -> Query:
|
|
296
|
-
if info.context.read_only:
|
|
297
|
-
return Query()
|
|
298
|
-
project_id = from_global_id_with_expected_type(str(id), "Project")
|
|
299
|
-
delete_statement = delete(models.Trace).where(models.Trace.project_rowid == project_id)
|
|
300
|
-
async with info.context.db() as session:
|
|
301
|
-
await session.execute(delete_statement)
|
|
302
|
-
if cache := info.context.cache_for_dataloaders:
|
|
303
|
-
cache.invalidate(ClearProjectSpansEvent(project_rowid=project_id))
|
|
304
|
-
return Query()
|
|
305
2
|
|
|
3
|
+
from phoenix.server.api.mutations import Mutation
|
|
4
|
+
from phoenix.server.api.queries import Query
|
|
306
5
|
|
|
307
6
|
# This is the schema for generating `schema.graphql`.
|
|
308
7
|
# See https://strawberry.rocks/docs/guides/schema-export
|
|
@@ -9,9 +9,9 @@ from phoenix.core.model_schema import PRIMARY, REFERENCE
|
|
|
9
9
|
from phoenix.server.api.context import Context
|
|
10
10
|
from phoenix.server.api.input_types.DataQualityMetricInput import DataQualityMetricInput
|
|
11
11
|
from phoenix.server.api.input_types.PerformanceMetricInput import PerformanceMetricInput
|
|
12
|
-
from phoenix.server.api.types.DatasetRole import AncillaryDatasetRole, DatasetRole
|
|
13
12
|
from phoenix.server.api.types.DatasetValues import DatasetValues
|
|
14
13
|
from phoenix.server.api.types.Event import unpack_event_id
|
|
14
|
+
from phoenix.server.api.types.InferencesRole import AncillaryInferencesRole, InferencesRole
|
|
15
15
|
|
|
16
16
|
|
|
17
17
|
@strawberry.type
|
|
@@ -36,8 +36,8 @@ class Cluster:
|
|
|
36
36
|
"""
|
|
37
37
|
Calculates the drift score of the cluster. The score will be a value
|
|
38
38
|
representing the balance of points between the primary and the reference
|
|
39
|
-
datasets, and will be on a scale between 1 (all primary) and -1 (all
|
40
|
-
reference), with 0 being an even balance between the two datasets.
|
|
39
|
+
inferences, and will be on a scale between 1 (all primary) and -1 (all
|
|
40
|
+
reference), with 0 being an even balance between the two inference sets.
|
|
41
41
|
|
|
42
42
|
Returns
|
|
43
43
|
-------
|
|
@@ -47,8 +47,8 @@ class Cluster:
|
|
|
47
47
|
if model[REFERENCE].empty:
|
|
48
48
|
return None
|
|
49
49
|
count_by_role = Counter(unpack_event_id(event_id)[1] for event_id in self.event_ids)
|
|
50
|
-
primary_count = count_by_role[DatasetRole.primary]
|
|
51
|
-
reference_count = count_by_role[DatasetRole.reference]
|
|
50
|
+
primary_count = count_by_role[InferencesRole.primary]
|
|
51
|
+
reference_count = count_by_role[InferencesRole.reference]
|
|
52
52
|
return (
|
|
53
53
|
None
|
|
54
54
|
if not (denominator := (primary_count + reference_count))
|
|
@@ -76,8 +76,8 @@ class Cluster:
|
|
|
76
76
|
if corpus is None or corpus[PRIMARY].empty:
|
|
77
77
|
return None
|
|
78
78
|
count_by_role = Counter(unpack_event_id(event_id)[1] for event_id in self.event_ids)
|
|
79
|
-
primary_count = count_by_role[DatasetRole.primary]
|
|
80
|
-
corpus_count = count_by_role[AncillaryDatasetRole.corpus]
|
|
79
|
+
primary_count = count_by_role[InferencesRole.primary]
|
|
80
|
+
corpus_count = count_by_role[AncillaryInferencesRole.corpus]
|
|
81
81
|
return (
|
|
82
82
|
None
|
|
83
83
|
if not (denominator := (primary_count + corpus_count))
|
|
@@ -94,19 +94,19 @@ class Cluster:
|
|
|
94
94
|
metric: DataQualityMetricInput,
|
|
95
95
|
) -> DatasetValues:
|
|
96
96
|
model = info.context.model
|
|
97
|
-
row_ids: Dict[DatasetRole, List[int]] = defaultdict(list)
|
|
98
|
-
for row_id, dataset_role in map(unpack_event_id, self.event_ids):
|
|
99
|
-
if not isinstance(dataset_role, DatasetRole):
|
|
97
|
+
row_ids: Dict[InferencesRole, List[int]] = defaultdict(list)
|
|
98
|
+
for row_id, inferences_role in map(unpack_event_id, self.event_ids):
|
|
99
|
+
if not isinstance(inferences_role, InferencesRole):
|
|
100
100
|
continue
|
|
101
|
-
row_ids[dataset_role].append(row_id)
|
|
101
|
+
row_ids[inferences_role].append(row_id)
|
|
102
102
|
return DatasetValues(
|
|
103
103
|
primary_value=metric.metric_instance(
|
|
104
104
|
model[PRIMARY],
|
|
105
|
-
subset_rows=row_ids[DatasetRole.primary],
|
|
105
|
+
subset_rows=row_ids[InferencesRole.primary],
|
|
106
106
|
),
|
|
107
107
|
reference_value=metric.metric_instance(
|
|
108
108
|
model[REFERENCE],
|
|
109
|
-
subset_rows=row_ids[DatasetRole.reference],
|
|
109
|
+
subset_rows=row_ids[InferencesRole.reference],
|
|
110
110
|
),
|
|
111
111
|
)
|
|
112
112
|
|
|
@@ -120,20 +120,20 @@ class Cluster:
|
|
|
120
120
|
metric: PerformanceMetricInput,
|
|
121
121
|
) -> DatasetValues:
|
|
122
122
|
model = info.context.model
|
|
123
|
-
row_ids: Dict[DatasetRole, List[int]] = defaultdict(list)
|
|
124
|
-
for row_id, dataset_role in map(unpack_event_id, self.event_ids):
|
|
125
|
-
if not isinstance(dataset_role, DatasetRole):
|
|
123
|
+
row_ids: Dict[InferencesRole, List[int]] = defaultdict(list)
|
|
124
|
+
for row_id, inferences_role in map(unpack_event_id, self.event_ids):
|
|
125
|
+
if not isinstance(inferences_role, InferencesRole):
|
|
126
126
|
continue
|
|
127
|
-
row_ids[dataset_role].append(row_id)
|
|
127
|
+
row_ids[inferences_role].append(row_id)
|
|
128
128
|
metric_instance = metric.metric_instance(model)
|
|
129
129
|
return DatasetValues(
|
|
130
130
|
primary_value=metric_instance(
|
|
131
131
|
model[PRIMARY],
|
|
132
|
-
subset_rows=row_ids[DatasetRole.primary],
|
|
132
|
+
subset_rows=row_ids[InferencesRole.primary],
|
|
133
133
|
),
|
|
134
134
|
reference_value=metric_instance(
|
|
135
135
|
model[REFERENCE],
|
|
136
|
-
subset_rows=row_ids[DatasetRole.reference],
|
|
136
|
+
subset_rows=row_ids[InferencesRole.reference],
|
|
137
137
|
),
|
|
138
138
|
)
|
|
139
139
|
|