arize-phoenix 4.4.2__py3-none-any.whl → 4.4.4rc0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects each version exactly as it appears in its public registry.
- {arize_phoenix-4.4.2.dist-info → arize_phoenix-4.4.4rc0.dist-info}/METADATA +12 -11
- {arize_phoenix-4.4.2.dist-info → arize_phoenix-4.4.4rc0.dist-info}/RECORD +110 -57
- phoenix/__init__.py +0 -27
- phoenix/config.py +21 -7
- phoenix/core/model.py +25 -25
- phoenix/core/model_schema.py +66 -64
- phoenix/core/model_schema_adapter.py +27 -25
- phoenix/datasets/__init__.py +0 -0
- phoenix/datasets/evaluators.py +275 -0
- phoenix/datasets/experiments.py +469 -0
- phoenix/datasets/tracing.py +66 -0
- phoenix/datasets/types.py +212 -0
- phoenix/db/bulk_inserter.py +54 -14
- phoenix/db/insertion/dataset.py +234 -0
- phoenix/db/insertion/evaluation.py +6 -6
- phoenix/db/insertion/helpers.py +13 -2
- phoenix/db/migrations/types.py +29 -0
- phoenix/db/migrations/versions/10460e46d750_datasets.py +291 -0
- phoenix/db/migrations/versions/cf03bd6bae1d_init.py +2 -28
- phoenix/db/models.py +230 -3
- phoenix/inferences/fixtures.py +23 -23
- phoenix/inferences/inferences.py +7 -7
- phoenix/inferences/validation.py +1 -1
- phoenix/metrics/binning.py +2 -2
- phoenix/server/api/context.py +16 -0
- phoenix/server/api/dataloaders/__init__.py +16 -0
- phoenix/server/api/dataloaders/dataset_example_revisions.py +100 -0
- phoenix/server/api/dataloaders/dataset_example_spans.py +43 -0
- phoenix/server/api/dataloaders/experiment_annotation_summaries.py +85 -0
- phoenix/server/api/dataloaders/experiment_error_rates.py +43 -0
- phoenix/server/api/dataloaders/experiment_sequence_number.py +49 -0
- phoenix/server/api/dataloaders/project_by_name.py +31 -0
- phoenix/server/api/dataloaders/span_descendants.py +2 -3
- phoenix/server/api/dataloaders/span_projects.py +33 -0
- phoenix/server/api/dataloaders/trace_row_ids.py +39 -0
- phoenix/server/api/helpers/dataset_helpers.py +178 -0
- phoenix/server/api/input_types/AddExamplesToDatasetInput.py +16 -0
- phoenix/server/api/input_types/AddSpansToDatasetInput.py +14 -0
- phoenix/server/api/input_types/CreateDatasetInput.py +12 -0
- phoenix/server/api/input_types/DatasetExampleInput.py +14 -0
- phoenix/server/api/input_types/DatasetSort.py +17 -0
- phoenix/server/api/input_types/DatasetVersionSort.py +16 -0
- phoenix/server/api/input_types/DeleteDatasetExamplesInput.py +13 -0
- phoenix/server/api/input_types/DeleteDatasetInput.py +7 -0
- phoenix/server/api/input_types/DeleteExperimentsInput.py +9 -0
- phoenix/server/api/input_types/PatchDatasetExamplesInput.py +35 -0
- phoenix/server/api/input_types/PatchDatasetInput.py +14 -0
- phoenix/server/api/mutations/__init__.py +13 -0
- phoenix/server/api/mutations/auth.py +11 -0
- phoenix/server/api/mutations/dataset_mutations.py +520 -0
- phoenix/server/api/mutations/experiment_mutations.py +65 -0
- phoenix/server/api/{types/ExportEventsMutation.py → mutations/export_events_mutations.py} +17 -14
- phoenix/server/api/mutations/project_mutations.py +42 -0
- phoenix/server/api/queries.py +503 -0
- phoenix/server/api/routers/v1/__init__.py +77 -2
- phoenix/server/api/routers/v1/dataset_examples.py +178 -0
- phoenix/server/api/routers/v1/datasets.py +861 -0
- phoenix/server/api/routers/v1/evaluations.py +4 -2
- phoenix/server/api/routers/v1/experiment_evaluations.py +65 -0
- phoenix/server/api/routers/v1/experiment_runs.py +108 -0
- phoenix/server/api/routers/v1/experiments.py +174 -0
- phoenix/server/api/routers/v1/spans.py +3 -1
- phoenix/server/api/routers/v1/traces.py +1 -4
- phoenix/server/api/schema.py +2 -303
- phoenix/server/api/types/AnnotatorKind.py +10 -0
- phoenix/server/api/types/Cluster.py +19 -19
- phoenix/server/api/types/CreateDatasetPayload.py +8 -0
- phoenix/server/api/types/Dataset.py +282 -63
- phoenix/server/api/types/DatasetExample.py +85 -0
- phoenix/server/api/types/DatasetExampleRevision.py +34 -0
- phoenix/server/api/types/DatasetVersion.py +14 -0
- phoenix/server/api/types/Dimension.py +30 -29
- phoenix/server/api/types/EmbeddingDimension.py +40 -34
- phoenix/server/api/types/Event.py +16 -16
- phoenix/server/api/types/ExampleRevisionInterface.py +14 -0
- phoenix/server/api/types/Experiment.py +135 -0
- phoenix/server/api/types/ExperimentAnnotationSummary.py +13 -0
- phoenix/server/api/types/ExperimentComparison.py +19 -0
- phoenix/server/api/types/ExperimentRun.py +91 -0
- phoenix/server/api/types/ExperimentRunAnnotation.py +57 -0
- phoenix/server/api/types/Inferences.py +80 -0
- phoenix/server/api/types/InferencesRole.py +23 -0
- phoenix/server/api/types/Model.py +43 -42
- phoenix/server/api/types/Project.py +26 -12
- phoenix/server/api/types/Segments.py +1 -1
- phoenix/server/api/types/Span.py +78 -2
- phoenix/server/api/types/TimeSeries.py +6 -6
- phoenix/server/api/types/Trace.py +15 -4
- phoenix/server/api/types/UMAPPoints.py +1 -1
- phoenix/server/api/types/node.py +5 -111
- phoenix/server/api/types/pagination.py +10 -52
- phoenix/server/app.py +99 -49
- phoenix/server/main.py +49 -27
- phoenix/server/openapi/docs.py +3 -0
- phoenix/server/static/index.js +2246 -1368
- phoenix/server/templates/index.html +1 -0
- phoenix/services.py +15 -15
- phoenix/session/client.py +316 -21
- phoenix/session/session.py +47 -37
- phoenix/trace/exporter.py +14 -9
- phoenix/trace/fixtures.py +133 -7
- phoenix/trace/span_evaluations.py +3 -3
- phoenix/trace/trace_dataset.py +6 -6
- phoenix/utilities/json.py +61 -0
- phoenix/utilities/re.py +50 -0
- phoenix/version.py +1 -1
- phoenix/server/api/types/DatasetRole.py +0 -23
- {arize_phoenix-4.4.2.dist-info → arize_phoenix-4.4.4rc0.dist-info}/WHEEL +0 -0
- {arize_phoenix-4.4.2.dist-info → arize_phoenix-4.4.4rc0.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-4.4.2.dist-info → arize_phoenix-4.4.4rc0.dist-info}/licenses/LICENSE +0 -0
- /phoenix/server/api/{helpers.py → helpers/__init__.py} +0 -0
--- a/phoenix/server/api/types/Dataset.py
+++ b/phoenix/server/api/types/Dataset.py
@@ -1,80 +1,299 @@
 from datetime import datetime
-from typing import
+from typing import AsyncIterable, List, Optional, Tuple, cast
 
 import strawberry
-from
-from
+from sqlalchemy import and_, func, select
+from sqlalchemy.sql.functions import count
+from strawberry import UNSET
+from strawberry.relay import Connection, GlobalID, Node, NodeID
+from strawberry.scalars import JSON
+from strawberry.types import Info
 
-
-from phoenix.
-
-from
-from .
-from .
-from .
+from phoenix.db import models
+from phoenix.server.api.context import Context
+from phoenix.server.api.input_types.DatasetVersionSort import DatasetVersionSort
+from phoenix.server.api.types.DatasetExample import DatasetExample
+from phoenix.server.api.types.DatasetVersion import DatasetVersion
+from phoenix.server.api.types.Experiment import Experiment, to_gql_experiment
+from phoenix.server.api.types.ExperimentAnnotationSummary import ExperimentAnnotationSummary
+from phoenix.server.api.types.node import from_global_id_with_expected_type
+from phoenix.server.api.types.pagination import (
+    ConnectionArgs,
+    CursorString,
+    connection_from_list,
+)
+from phoenix.server.api.types.SortDir import SortDir
 
 
 @strawberry.type
-class Dataset:
-
-
-
-
-
-
-
-    # type ignored here to get around the following: https://github.com/strawberry-graphql/strawberry/issues/1929
-    @strawberry.field(description="Returns a human friendly name for the dataset.")  # type: ignore
-    def name(self) -> str:
-        return self.dataset.display_name
+class Dataset(Node):
+    id_attr: NodeID[int]
+    name: str
+    description: Optional[str]
+    metadata: JSON
+    created_at: datetime
+    updated_at: datetime
 
     @strawberry.field
-    def
+    async def versions(
         self,
-
-
-
-
-
-
-
-
-
-
-
-
-        events = self.dataset[row_ids[self.dataset_role]]
-        requested_gql_dimensions = _get_requested_features_and_tags(
-            core_dimensions=self.model.scalar_dimensions,
-            requested_dimension_names=set(dim.name for dim in dimensions)
-            if isinstance(dimensions, list)
-            else None,
+        info: Info[Context, None],
+        first: Optional[int] = 50,
+        last: Optional[int] = UNSET,
+        after: Optional[CursorString] = UNSET,
+        before: Optional[CursorString] = UNSET,
+        sort: Optional[DatasetVersionSort] = UNSET,
+    ) -> Connection[DatasetVersion]:
+        args = ConnectionArgs(
+            first=first,
+            after=after if isinstance(after, CursorString) else None,
+            last=last,
+            before=before if isinstance(before, CursorString) else None,
         )
-
-
-
-
-
-
+        async with info.context.db() as session:
+            stmt = select(models.DatasetVersion).filter_by(dataset_id=self.id_attr)
+            if sort:
+                # For now assume the the column names match 1:1 with the enum values
+                sort_col = getattr(models.DatasetVersion, sort.col.value)
+                if sort.dir is SortDir.desc:
+                    stmt = stmt.order_by(sort_col.desc(), models.DatasetVersion.id.desc())
+                else:
+                    stmt = stmt.order_by(sort_col.asc(), models.DatasetVersion.id.asc())
+            else:
+                stmt = stmt.order_by(models.DatasetVersion.created_at.desc())
+            versions = await session.scalars(stmt)
+            data = [
+                DatasetVersion(
+                    id_attr=version.id,
+                    description=version.description,
+                    metadata=version.metadata_,
+                    created_at=version.created_at,
                 )
-                for
+                for version in versions
             ]
+        return connection_from_list(data=data, args=args)
+
+    @strawberry.field(
+        description="Number of examples in a specific version if version is specified, or in the "
+        "latest version if version is not specified."
+    )  # type: ignore
+    async def example_count(
+        self,
+        info: Info[Context, None],
+        dataset_version_id: Optional[GlobalID] = UNSET,
+    ) -> int:
+        dataset_id = self.id_attr
+        version_id = (
+            from_global_id_with_expected_type(
+                global_id=dataset_version_id,
+                expected_type_name=DatasetVersion.__name__,
+            )
+            if dataset_version_id
+            else None
+        )
+        revision_ids = (
+            select(func.max(models.DatasetExampleRevision.id))
+            .join(models.DatasetExample)
+            .where(models.DatasetExample.dataset_id == dataset_id)
+            .group_by(models.DatasetExampleRevision.dataset_example_id)
+        )
+        if version_id:
+            version_id_subquery = (
+                select(models.DatasetVersion.id)
+                .where(models.DatasetVersion.dataset_id == dataset_id)
+                .where(models.DatasetVersion.id == version_id)
+                .scalar_subquery()
+            )
+            revision_ids = revision_ids.where(
+                models.DatasetExampleRevision.dataset_version_id <= version_id_subquery
+            )
+        stmt = (
+            select(count(models.DatasetExampleRevision.id))
+            .where(models.DatasetExampleRevision.id.in_(revision_ids))
+            .where(models.DatasetExampleRevision.revision_kind != "DELETE")
+        )
+        async with info.context.db() as session:
+            return (await session.scalar(stmt)) or 0
+
+    @strawberry.field
+    async def examples(
+        self,
+        info: Info[Context, None],
+        dataset_version_id: Optional[GlobalID] = UNSET,
+        first: Optional[int] = 50,
+        last: Optional[int] = UNSET,
+        after: Optional[CursorString] = UNSET,
+        before: Optional[CursorString] = UNSET,
+    ) -> Connection[DatasetExample]:
+        args = ConnectionArgs(
+            first=first,
+            after=after if isinstance(after, CursorString) else None,
+            last=last,
+            before=before if isinstance(before, CursorString) else None,
+        )
+        dataset_id = self.id_attr
+        version_id = (
+            from_global_id_with_expected_type(
+                global_id=dataset_version_id, expected_type_name=DatasetVersion.__name__
+            )
+            if dataset_version_id
+            else None
+        )
+        revision_ids = (
+            select(func.max(models.DatasetExampleRevision.id))
+            .join(models.DatasetExample)
+            .where(models.DatasetExample.dataset_id == dataset_id)
+            .group_by(models.DatasetExampleRevision.dataset_example_id)
+        )
+        if version_id:
+            version_id_subquery = (
+                select(models.DatasetVersion.id)
+                .where(models.DatasetVersion.dataset_id == dataset_id)
+                .where(models.DatasetVersion.id == version_id)
+                .scalar_subquery()
+            )
+            revision_ids = revision_ids.where(
+                models.DatasetExampleRevision.dataset_version_id <= version_id_subquery
+            )
+        query = (
+            select(models.DatasetExample)
+            .join(
+                models.DatasetExampleRevision,
+                onclause=models.DatasetExample.id
+                == models.DatasetExampleRevision.dataset_example_id,
+            )
+            .where(
+                and_(
+                    models.DatasetExampleRevision.id.in_(revision_ids),
+                    models.DatasetExampleRevision.revision_kind != "DELETE",
+                )
+            )
+            .order_by(models.DatasetExampleRevision.dataset_example_id.desc())
+        )
+        async with info.context.db() as session:
+            dataset_examples = [
+                DatasetExample(
+                    id_attr=example.id,
+                    version_id=version_id,
+                    created_at=example.created_at,
+                )
+                async for example in await session.stream_scalars(query)
+            ]
+        return connection_from_list(data=dataset_examples, args=args)
+
+    @strawberry.field(
+        description="Number of experiments for a specific version if version is specified, "
+        "or for all versions if version is not specified."
+    )  # type: ignore
+    async def experiment_count(
+        self,
+        info: Info[Context, None],
+        dataset_version_id: Optional[GlobalID] = UNSET,
+    ) -> int:
+        stmt = select(count(models.Experiment.id)).where(
+            models.Experiment.dataset_id == self.id_attr
+        )
+        version_id = (
+            from_global_id_with_expected_type(
+                global_id=dataset_version_id,
+                expected_type_name=DatasetVersion.__name__,
+            )
+            if dataset_version_id
+            else None
+        )
+        if version_id is not None:
+            stmt = stmt.where(models.Experiment.dataset_version_id == version_id)
+        async with info.context.db() as session:
+            return (await session.scalar(stmt)) or 0
+
+    @strawberry.field
+    async def experiments(
+        self,
+        info: Info[Context, None],
+        first: Optional[int] = 50,
+        last: Optional[int] = UNSET,
+        after: Optional[CursorString] = UNSET,
+        before: Optional[CursorString] = UNSET,
+    ) -> Connection[Experiment]:
+        args = ConnectionArgs(
+            first=first,
+            after=after if isinstance(after, CursorString) else None,
+            last=last,
+            before=before if isinstance(before, CursorString) else None,
+        )
+        dataset_id = self.id_attr
+        row_number = func.row_number().over(order_by=models.Experiment.id).label("row_number")
+        query = (
+            select(models.Experiment, row_number)
+            .where(models.Experiment.dataset_id == dataset_id)
+            .order_by(models.Experiment.id.desc())
+        )
+        async with info.context.db() as session:
+            experiments = [
+                to_gql_experiment(experiment, sequence_number)
+                async for experiment, sequence_number in cast(
+                    AsyncIterable[Tuple[models.Experiment, int]],
+                    await session.stream(query),
+                )
+            ]
+        return connection_from_list(data=experiments, args=args)
+
+    @strawberry.field
+    async def experiment_annotation_summaries(
+        self, info: Info[Context, None]
+    ) -> List[ExperimentAnnotationSummary]:
+        dataset_id = self.id_attr
+        query = (
+            select(
+                models.ExperimentRunAnnotation.name,
+                func.min(models.ExperimentRunAnnotation.score),
+                func.max(models.ExperimentRunAnnotation.score),
+                func.avg(models.ExperimentRunAnnotation.score),
+                func.count(),
+                func.count(models.ExperimentRunAnnotation.error),
+            )
+            .join(
+                models.ExperimentRun,
+                models.ExperimentRunAnnotation.experiment_run_id == models.ExperimentRun.id,
+            )
+            .join(
+                models.Experiment,
+                models.ExperimentRun.experiment_id == models.Experiment.id,
+            )
+            .where(models.Experiment.dataset_id == dataset_id)
+            .group_by(models.ExperimentRunAnnotation.name)
+            .order_by(models.ExperimentRunAnnotation.name)
+        )
+        async with info.context.db() as session:
+            return [
+                ExperimentAnnotationSummary(
+                    annotation_name=annotation_name,
+                    min_score=min_score,
+                    max_score=max_score,
+                    mean_score=mean_score,
+                    count=count_,
+                    error_count=error_count,
+                )
+                async for (
+                    annotation_name,
+                    min_score,
+                    max_score,
+                    mean_score,
+                    count_,
+                    error_count,
+                ) in await session.stream(query)
+            ]
 
 
-def
-    core_dimensions: Iterable[ScalarDimension],
-    requested_dimension_names: Optional[Set[str]] = UNSET,
-) -> List[Dimension]:
+def to_gql_dataset(dataset: models.Dataset) -> Dataset:
     """
-
-    dimensions are explicitly requested, returns all features and tags.
+    Converts an ORM dataset to a GraphQL dataset.
     """
-
-
-
-
-
-
-
-
-    return requested_features_and_tags
+    return Dataset(
+        id_attr=dataset.id,
+        name=dataset.name,
+        description=dataset.description,
+        metadata=dataset.metadata_,
+        created_at=dataset.created_at,
+        updated_at=dataset.updated_at,
+    )
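
Both example_count and examples above resolve a dataset's membership at a point in time the same way: for each example, take the highest revision id whose version is at or below the requested version, then drop examples whose surviving revision is a DELETE. Below is a minimal runnable sketch of that pattern; the Revision table is a simplified stand-in for the actual phoenix.db.models schema, and a synchronous SQLite session replaces the server's async one.

# Simplified stand-in schema (NOT the actual phoenix.db.models), synchronous
# SQLAlchemy for brevity. Demonstrates the "latest revision per example as of
# a version" query used by example_count above.
from sqlalchemy import Column, Integer, String, create_engine, func, select
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class Revision(Base):
    __tablename__ = "revisions"
    id = Column(Integer, primary_key=True)
    dataset_example_id = Column(Integer, nullable=False)
    dataset_version_id = Column(Integer, nullable=False)
    revision_kind = Column(String, nullable=False)  # CREATE / PATCH / DELETE


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all(
        [
            # example 1: created in version 1, patched in version 2
            Revision(dataset_example_id=1, dataset_version_id=1, revision_kind="CREATE"),
            Revision(dataset_example_id=1, dataset_version_id=2, revision_kind="PATCH"),
            # example 2: created in version 1, deleted in version 3
            Revision(dataset_example_id=2, dataset_version_id=1, revision_kind="CREATE"),
            Revision(dataset_example_id=2, dataset_version_id=3, revision_kind="DELETE"),
        ]
    )
    session.commit()

    as_of_version = 3
    # Latest revision id per example, considering only revisions at or before
    # the requested version -- mirrors the revision_ids subquery above.
    latest_revision_ids = (
        select(func.max(Revision.id))
        .where(Revision.dataset_version_id <= as_of_version)
        .group_by(Revision.dataset_example_id)
    )
    # Count surviving examples: latest revision exists and is not a DELETE.
    stmt = (
        select(func.count(Revision.id))
        .where(Revision.id.in_(latest_revision_ids))
        .where(Revision.revision_kind != "DELETE")
    )
    print(session.scalar(stmt))  # 1 -> example 2's latest revision is a DELETE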

--- /dev/null
+++ b/phoenix/server/api/types/DatasetExample.py
@@ -0,0 +1,85 @@
+from datetime import datetime
+from typing import Optional
+
+import strawberry
+from sqlalchemy import select
+from sqlalchemy.orm import joinedload
+from strawberry import UNSET
+from strawberry.relay.types import Connection, GlobalID, Node, NodeID
+from strawberry.types import Info
+
+from phoenix.db import models
+from phoenix.server.api.context import Context
+from phoenix.server.api.types.DatasetExampleRevision import DatasetExampleRevision
+from phoenix.server.api.types.DatasetVersion import DatasetVersion
+from phoenix.server.api.types.ExperimentRun import ExperimentRun, to_gql_experiment_run
+from phoenix.server.api.types.node import from_global_id_with_expected_type
+from phoenix.server.api.types.pagination import (
+    ConnectionArgs,
+    CursorString,
+    connection_from_list,
+)
+from phoenix.server.api.types.Span import Span, to_gql_span
+
+
+@strawberry.type
+class DatasetExample(Node):
+    id_attr: NodeID[int]
+    created_at: datetime
+    version_id: strawberry.Private[Optional[int]] = None
+
+    @strawberry.field
+    async def revision(
+        self,
+        info: Info[Context, None],
+        dataset_version_id: Optional[GlobalID] = UNSET,
+    ) -> DatasetExampleRevision:
+        example_id = self.id_attr
+        version_id: Optional[int] = None
+        if dataset_version_id:
+            version_id = from_global_id_with_expected_type(
+                global_id=dataset_version_id, expected_type_name=DatasetVersion.__name__
+            )
+        elif self.version_id is not None:
+            version_id = self.version_id
+        return await info.context.data_loaders.dataset_example_revisions.load(
+            (example_id, version_id)
+        )
+
+    @strawberry.field
+    async def span(
+        self,
+        info: Info[Context, None],
+    ) -> Optional[Span]:
+        return (
+            to_gql_span(span)
+            if (span := await info.context.data_loaders.dataset_example_spans.load(self.id_attr))
+            else None
+        )
+
+    @strawberry.field
+    async def experiment_runs(
+        self,
+        info: Info[Context, None],
+        first: Optional[int] = 50,
+        last: Optional[int] = UNSET,
+        after: Optional[CursorString] = UNSET,
+        before: Optional[CursorString] = UNSET,
+    ) -> Connection[ExperimentRun]:
+        args = ConnectionArgs(
+            first=first,
+            after=after if isinstance(after, CursorString) else None,
+            last=last,
+            before=before if isinstance(before, CursorString) else None,
+        )
+        example_id = self.id_attr
+        query = (
+            select(models.ExperimentRun)
+            .options(joinedload(models.ExperimentRun.trace).load_only(models.Trace.trace_id))
+            .join(models.Experiment, models.Experiment.id == models.ExperimentRun.experiment_id)
+            .where(models.ExperimentRun.dataset_example_id == example_id)
+            .order_by(models.Experiment.id.desc())
+        )
+        async with info.context.db() as session:
+            runs = (await session.scalars(query)).all()
+        return connection_from_list([to_gql_experiment_run(run) for run in runs], args)
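
Note that the revision resolver above never queries the database itself: it hands an (example_id, version_id) key to a DataLoader, so resolving many examples in a single GraphQL request coalesces their revision lookups into one batch. A minimal sketch of that contract using strawberry's DataLoader, with a plain dict standing in for the single SQL query the real dataset_example_revisions loader would issue:

# Minimal sketch of the DataLoader batching contract. FAKE_DB stands in for
# the batched SQL lookup performed by the real loader; only the batching
# behavior is demonstrated here.
import asyncio
from typing import List, Optional, Tuple

from strawberry.dataloader import DataLoader

Key = Tuple[int, Optional[int]]  # (example_id, version_id or None for latest)

FAKE_DB = {
    (1, None): {"input": "a", "revision_kind": "PATCH"},
    (2, None): {"input": "b", "revision_kind": "CREATE"},
}


async def load_revisions(keys: List[Key]) -> List[Optional[dict]]:
    # Called once per batch: a real implementation would issue a single SQL
    # query covering every requested (example_id, version_id) pair.
    print(f"batched lookup of {len(keys)} keys")
    return [FAKE_DB.get(key) for key in keys]


async def main() -> None:
    loader: DataLoader[Key, Optional[dict]] = DataLoader(load_fn=load_revisions)
    # Concurrent loads for different examples coalesce into one batch.
    r1, r2 = await asyncio.gather(loader.load((1, None)), loader.load((2, None)))
    print(r1, r2)


asyncio.run(main())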

--- /dev/null
+++ b/phoenix/server/api/types/DatasetExampleRevision.py
@@ -0,0 +1,34 @@
+from datetime import datetime
+from enum import Enum
+
+import strawberry
+
+from phoenix.db import models
+from phoenix.server.api.types.ExampleRevisionInterface import ExampleRevision
+
+
+@strawberry.enum
+class RevisionKind(Enum):
+    CREATE = "CREATE"
+    PATCH = "PATCH"
+    DELETE = "DELETE"
+
+
+@strawberry.type
+class DatasetExampleRevision(ExampleRevision):
+    """
+    Represents a revision (i.e., update or alteration) of a dataset example.
+    """
+
+    revision_kind: RevisionKind
+    created_at: datetime
+
+    @classmethod
+    def from_orm_revision(cls, revision: models.DatasetExampleRevision) -> "DatasetExampleRevision":
+        return cls(
+            input=revision.input,
+            output=revision.output,
+            metadata=revision.metadata_,
+            revision_kind=RevisionKind(revision.revision_kind),
+            created_at=revision.created_at,
+        )
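
Combined with the revision queries in Dataset.py above, the three revision kinds imply a simple rule: an example's current state is whatever its highest-id revision records, and a DELETE revision removes it from the dataset. An illustrative pure-Python version of that rule (the server derives the same answer in SQL):

# Illustrative fold of a revision history into an example's current state:
# CREATE/PATCH carry the full payload, DELETE removes the example. The
# dict-based history here is a stand-in for DatasetExampleRevision rows.
from typing import List, Optional


def current_state(revisions: List[dict]) -> Optional[dict]:
    """Return the example's latest payload, or None if it was deleted."""
    latest = max(revisions, key=lambda r: r["id"])  # highest revision id wins
    if latest["revision_kind"] == "DELETE":
        return None
    return {"input": latest["input"], "output": latest["output"]}


history = [
    {"id": 1, "revision_kind": "CREATE", "input": "q1", "output": "a1"},
    {"id": 5, "revision_kind": "PATCH", "input": "q1", "output": "a1-fixed"},
]
print(current_state(history))  # {'input': 'q1', 'output': 'a1-fixed'}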

--- /dev/null
+++ b/phoenix/server/api/types/DatasetVersion.py
@@ -0,0 +1,14 @@
+from datetime import datetime
+from typing import Optional
+
+import strawberry
+from strawberry.relay import Node, NodeID
+from strawberry.scalars import JSON
+
+
+@strawberry.type
+class DatasetVersion(Node):
+    id_attr: NodeID[int]
+    description: Optional[str]
+    metadata: JSON
+    created_at: datetime

--- a/phoenix/server/api/types/Dimension.py
+++ b/phoenix/server/api/types/Dimension.py
@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Optional
 import pandas as pd
 import strawberry
 from strawberry import UNSET
+from strawberry.relay import Node, NodeID
 from strawberry.types import Info
 from typing_extensions import Annotated
 
@@ -17,12 +18,11 @@ from ..context import Context
 from ..input_types.Granularity import Granularity
 from ..input_types.TimeRange import TimeRange
 from .DataQualityMetric import DataQualityMetric
-from .DatasetRole import DatasetRole
 from .DatasetValues import DatasetValues
 from .DimensionDataType import DimensionDataType
 from .DimensionShape import DimensionShape
 from .DimensionType import DimensionType
-from .
+from .InferencesRole import InferencesRole
 from .ScalarDriftMetricEnum import ScalarDriftMetric
 from .Segments import (
     GqlBinFactory,
@@ -40,6 +40,7 @@ from .TimeSeries import (
 
 @strawberry.type
 class Dimension(Node):
+    id_attr: NodeID[int]
     name: str = strawberry.field(description="The name of the dimension (a.k.a. the column name)")
     type: DimensionType = strawberry.field(
         description="Whether the dimension represents a feature, tag, prediction, or actual."
@@ -62,16 +63,16 @@ class Dimension(Node):
         """
         Computes a drift metric between all reference data and the primary data
         belonging to the input time range (inclusive of the time range start and
-        exclusive of the time range end). Returns None if no reference
-
+        exclusive of the time range end). Returns None if no reference inferences
+        exist, if no primary data exists in the input time range, or if the
         input time range is invalid.
         """
         model = info.context.model
         if model[REFERENCE].empty:
             return None
-
+        inferences = model[PRIMARY]
         time_range, granularity = ensure_timeseries_parameters(
-
+            inferences,
             time_range,
         )
         data = get_drift_timeseries_data(
@@ -92,18 +93,18 @@ class Dimension(Node):
         info: Info[Context, None],
         metric: DataQualityMetric,
         time_range: Optional[TimeRange] = UNSET,
-
-        Optional[
+        inferences_role: Annotated[
+            Optional[InferencesRole],
             strawberry.argument(
-                description="The
+                description="The inferences (primary or reference) to query",
             ),
-        ] =
+        ] = InferencesRole.primary,
     ) -> Optional[float]:
-        if not isinstance(
-
-
+        if not isinstance(inferences_role, InferencesRole):
+            inferences_role = InferencesRole.primary
+        inferences = info.context.model[inferences_role.value]
         time_range, granularity = ensure_timeseries_parameters(
-
+            inferences,
             time_range,
         )
         data = get_data_quality_timeseries_data(
@@ -111,7 +112,7 @@ class Dimension(Node):
             metric,
             time_range,
             granularity,
-
+            inferences_role,
         )
         return data[0].value if len(data) else None
 
@@ -139,18 +140,18 @@ class Dimension(Node):
         metric: DataQualityMetric,
         time_range: TimeRange,
         granularity: Granularity,
-
-        Optional[
+        inferences_role: Annotated[
+            Optional[InferencesRole],
             strawberry.argument(
-                description="The
+                description="The inferences (primary or reference) to query",
             ),
-        ] =
+        ] = InferencesRole.primary,
    ) -> DataQualityTimeSeries:
-        if not isinstance(
-
-
+        if not isinstance(inferences_role, InferencesRole):
+            inferences_role = InferencesRole.primary
+        inferences = info.context.model[inferences_role.value]
         time_range, granularity = ensure_timeseries_parameters(
-
+            inferences,
             time_range,
             granularity,
         )
@@ -160,7 +161,7 @@ class Dimension(Node):
                 metric,
                 time_range,
                 granularity,
-
+                inferences_role,
             )
         )
 
@@ -182,9 +183,9 @@ class Dimension(Node):
         model = info.context.model
         if model[REFERENCE].empty:
             return DriftTimeSeries(data=[])
-
+        inferences = model[PRIMARY]
         time_range, granularity = ensure_timeseries_parameters(
-
+            inferences,
             time_range,
             granularity,
         )
@@ -202,7 +203,7 @@ class Dimension(Node):
         )
 
     @strawberry.field(
-        description="
+        description="The segments across both inference sets and returns the counts per segment",
     )  # type: ignore
     def segments_comparison(
         self,
@@ -249,8 +250,8 @@ class Dimension(Node):
         if isinstance(binning_method, binning.IntervalBinning) and binning_method.bins is not None:
             all_bins = all_bins.union(binning_method.bins)
         for bin in all_bins:
-            values: Dict[ms.
-            for role in ms.
+            values: Dict[ms.InferencesRole, Any] = defaultdict(lambda: None)
+            for role in ms.InferencesRole:
             if model[role].empty:
                 continue
             try:
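
The recurring change throughout Dimension.py is the renamed enum argument: each resolver now declares inferences_role via typing.Annotated plus strawberry.argument, defaulting to the primary inferences. A minimal self-contained sketch of that argument pattern; the Query type below is illustrative and not part of the phoenix codebase:

# Illustrative schema (not from the phoenix codebase) showing an enum-typed
# resolver argument declared with Annotated + strawberry.argument and a
# default of InferencesRole.primary, as in Dimension.py above.
from enum import Enum
from typing import Optional

import strawberry
from typing_extensions import Annotated


@strawberry.enum
class InferencesRole(Enum):
    primary = "primary"
    reference = "reference"


@strawberry.type
class Query:
    @strawberry.field
    def data_quality_metric(
        self,
        inferences_role: Annotated[
            Optional[InferencesRole],
            strawberry.argument(description="The inferences (primary or reference) to query"),
        ] = InferencesRole.primary,
    ) -> str:
        # Explicit None/UNSET falls back to primary, mirroring the isinstance
        # guard in the resolvers above.
        if not isinstance(inferences_role, InferencesRole):
            inferences_role = InferencesRole.primary
        return inferences_role.value


schema = strawberry.Schema(query=Query)
print(schema.execute_sync("{ dataQualityMetric }").data)  # {'dataQualityMetric': 'primary'}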