arize-phoenix 4.4.3__py3-none-any.whl → 4.4.4rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of arize-phoenix might be problematic.
- {arize_phoenix-4.4.3.dist-info → arize_phoenix-4.4.4rc1.dist-info}/METADATA +4 -4
- {arize_phoenix-4.4.3.dist-info → arize_phoenix-4.4.4rc1.dist-info}/RECORD +111 -55
- {arize_phoenix-4.4.3.dist-info → arize_phoenix-4.4.4rc1.dist-info}/WHEEL +1 -1
- phoenix/__init__.py +0 -27
- phoenix/config.py +21 -7
- phoenix/core/model.py +25 -25
- phoenix/core/model_schema.py +64 -62
- phoenix/core/model_schema_adapter.py +27 -25
- phoenix/datasets/__init__.py +0 -0
- phoenix/datasets/evaluators.py +275 -0
- phoenix/datasets/experiments.py +469 -0
- phoenix/datasets/tracing.py +66 -0
- phoenix/datasets/types.py +212 -0
- phoenix/db/bulk_inserter.py +54 -14
- phoenix/db/insertion/dataset.py +234 -0
- phoenix/db/insertion/evaluation.py +6 -6
- phoenix/db/insertion/helpers.py +13 -2
- phoenix/db/migrations/types.py +29 -0
- phoenix/db/migrations/versions/10460e46d750_datasets.py +291 -0
- phoenix/db/migrations/versions/cf03bd6bae1d_init.py +2 -28
- phoenix/db/models.py +230 -3
- phoenix/inferences/fixtures.py +23 -23
- phoenix/inferences/inferences.py +7 -7
- phoenix/inferences/validation.py +1 -1
- phoenix/server/api/context.py +16 -0
- phoenix/server/api/dataloaders/__init__.py +16 -0
- phoenix/server/api/dataloaders/dataset_example_revisions.py +100 -0
- phoenix/server/api/dataloaders/dataset_example_spans.py +43 -0
- phoenix/server/api/dataloaders/experiment_annotation_summaries.py +85 -0
- phoenix/server/api/dataloaders/experiment_error_rates.py +43 -0
- phoenix/server/api/dataloaders/experiment_sequence_number.py +49 -0
- phoenix/server/api/dataloaders/project_by_name.py +31 -0
- phoenix/server/api/dataloaders/span_descendants.py +2 -3
- phoenix/server/api/dataloaders/span_projects.py +33 -0
- phoenix/server/api/dataloaders/trace_row_ids.py +39 -0
- phoenix/server/api/helpers/dataset_helpers.py +178 -0
- phoenix/server/api/input_types/AddExamplesToDatasetInput.py +16 -0
- phoenix/server/api/input_types/AddSpansToDatasetInput.py +14 -0
- phoenix/server/api/input_types/CreateDatasetInput.py +12 -0
- phoenix/server/api/input_types/DatasetExampleInput.py +14 -0
- phoenix/server/api/input_types/DatasetSort.py +17 -0
- phoenix/server/api/input_types/DatasetVersionSort.py +16 -0
- phoenix/server/api/input_types/DeleteDatasetExamplesInput.py +13 -0
- phoenix/server/api/input_types/DeleteDatasetInput.py +7 -0
- phoenix/server/api/input_types/DeleteExperimentsInput.py +9 -0
- phoenix/server/api/input_types/PatchDatasetExamplesInput.py +35 -0
- phoenix/server/api/input_types/PatchDatasetInput.py +14 -0
- phoenix/server/api/mutations/__init__.py +13 -0
- phoenix/server/api/mutations/auth.py +11 -0
- phoenix/server/api/mutations/dataset_mutations.py +520 -0
- phoenix/server/api/mutations/experiment_mutations.py +65 -0
- phoenix/server/api/{types/ExportEventsMutation.py → mutations/export_events_mutations.py} +17 -14
- phoenix/server/api/mutations/project_mutations.py +42 -0
- phoenix/server/api/openapi/__init__.py +0 -0
- phoenix/server/api/openapi/main.py +6 -0
- phoenix/server/api/openapi/schema.py +15 -0
- phoenix/server/api/queries.py +503 -0
- phoenix/server/api/routers/v1/__init__.py +77 -2
- phoenix/server/api/routers/v1/dataset_examples.py +178 -0
- phoenix/server/api/routers/v1/datasets.py +861 -0
- phoenix/server/api/routers/v1/evaluations.py +4 -2
- phoenix/server/api/routers/v1/experiment_evaluations.py +65 -0
- phoenix/server/api/routers/v1/experiment_runs.py +108 -0
- phoenix/server/api/routers/v1/experiments.py +174 -0
- phoenix/server/api/routers/v1/spans.py +3 -1
- phoenix/server/api/routers/v1/traces.py +1 -4
- phoenix/server/api/schema.py +2 -303
- phoenix/server/api/types/AnnotatorKind.py +10 -0
- phoenix/server/api/types/Cluster.py +19 -19
- phoenix/server/api/types/CreateDatasetPayload.py +8 -0
- phoenix/server/api/types/Dataset.py +282 -63
- phoenix/server/api/types/DatasetExample.py +85 -0
- phoenix/server/api/types/DatasetExampleRevision.py +34 -0
- phoenix/server/api/types/DatasetVersion.py +14 -0
- phoenix/server/api/types/Dimension.py +30 -29
- phoenix/server/api/types/EmbeddingDimension.py +40 -34
- phoenix/server/api/types/Event.py +16 -16
- phoenix/server/api/types/ExampleRevisionInterface.py +14 -0
- phoenix/server/api/types/Experiment.py +135 -0
- phoenix/server/api/types/ExperimentAnnotationSummary.py +13 -0
- phoenix/server/api/types/ExperimentComparison.py +19 -0
- phoenix/server/api/types/ExperimentRun.py +91 -0
- phoenix/server/api/types/ExperimentRunAnnotation.py +57 -0
- phoenix/server/api/types/Inferences.py +80 -0
- phoenix/server/api/types/InferencesRole.py +23 -0
- phoenix/server/api/types/Model.py +43 -42
- phoenix/server/api/types/Project.py +26 -12
- phoenix/server/api/types/Span.py +78 -2
- phoenix/server/api/types/TimeSeries.py +6 -6
- phoenix/server/api/types/Trace.py +15 -4
- phoenix/server/api/types/UMAPPoints.py +1 -1
- phoenix/server/api/types/node.py +5 -111
- phoenix/server/api/types/pagination.py +10 -52
- phoenix/server/app.py +99 -49
- phoenix/server/main.py +49 -27
- phoenix/server/openapi/docs.py +3 -0
- phoenix/server/static/index.js +2246 -1368
- phoenix/server/templates/index.html +1 -0
- phoenix/services.py +15 -15
- phoenix/session/client.py +316 -21
- phoenix/session/session.py +47 -37
- phoenix/trace/exporter.py +14 -9
- phoenix/trace/fixtures.py +133 -7
- phoenix/trace/span_evaluations.py +3 -3
- phoenix/trace/trace_dataset.py +6 -6
- phoenix/utilities/json.py +61 -0
- phoenix/utilities/re.py +50 -0
- phoenix/version.py +1 -1
- phoenix/server/api/types/DatasetRole.py +0 -23
- {arize_phoenix-4.4.3.dist-info → arize_phoenix-4.4.4rc1.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-4.4.3.dist-info → arize_phoenix-4.4.4rc1.dist-info}/licenses/LICENSE +0 -0
- /phoenix/server/api/{helpers.py → helpers/__init__.py} +0 -0
phoenix/server/api/mutations/dataset_mutations.py
@@ -0,0 +1,520 @@
+from datetime import datetime
+from typing import Any, Dict
+
+import strawberry
+from openinference.semconv.trace import (
+    SpanAttributes,
+)
+from sqlalchemy import and_, delete, distinct, func, insert, select, update
+from strawberry import UNSET
+from strawberry.types import Info
+
+from phoenix.db import models
+from phoenix.server.api.context import Context
+from phoenix.server.api.helpers.dataset_helpers import (
+    get_dataset_example_input,
+    get_dataset_example_output,
+)
+from phoenix.server.api.input_types.AddExamplesToDatasetInput import AddExamplesToDatasetInput
+from phoenix.server.api.input_types.AddSpansToDatasetInput import AddSpansToDatasetInput
+from phoenix.server.api.input_types.CreateDatasetInput import CreateDatasetInput
+from phoenix.server.api.input_types.DeleteDatasetExamplesInput import DeleteDatasetExamplesInput
+from phoenix.server.api.input_types.DeleteDatasetInput import DeleteDatasetInput
+from phoenix.server.api.input_types.PatchDatasetExamplesInput import (
+    DatasetExamplePatch,
+    PatchDatasetExamplesInput,
+)
+from phoenix.server.api.input_types.PatchDatasetInput import PatchDatasetInput
+from phoenix.server.api.mutations.auth import IsAuthenticated
+from phoenix.server.api.types.Dataset import Dataset, to_gql_dataset
+from phoenix.server.api.types.DatasetExample import DatasetExample
+from phoenix.server.api.types.node import from_global_id_with_expected_type
+from phoenix.server.api.types.Span import Span
+
+
+@strawberry.type
+class DatasetMutationPayload:
+    dataset: Dataset
+
+
+@strawberry.type
+class DatasetMutationMixin:
+    @strawberry.mutation(permission_classes=[IsAuthenticated])  # type: ignore
+    async def create_dataset(
+        self,
+        info: Info[Context, None],
+        input: CreateDatasetInput,
+    ) -> DatasetMutationPayload:
+        name = input.name
+        description = input.description if input.description is not UNSET else None
+        metadata = input.metadata
+        async with info.context.db() as session:
+            dataset = await session.scalar(
+                insert(models.Dataset)
+                .values(
+                    name=name,
+                    description=description,
+                    metadata_=metadata,
+                )
+                .returning(models.Dataset)
+            )
+            assert dataset is not None
+        return DatasetMutationPayload(dataset=to_gql_dataset(dataset))
+
+    @strawberry.mutation(permission_classes=[IsAuthenticated])  # type: ignore
+    async def patch_dataset(
+        self,
+        info: Info[Context, None],
+        input: PatchDatasetInput,
+    ) -> DatasetMutationPayload:
+        dataset_id = from_global_id_with_expected_type(
+            global_id=input.dataset_id, expected_type_name=Dataset.__name__
+        )
+        patch = {
+            column.key: patch_value
+            for column, patch_value, column_is_nullable in (
+                (models.Dataset.name, input.name, False),
+                (models.Dataset.description, input.description, True),
+                (models.Dataset.metadata_, input.metadata, False),
+            )
+            if patch_value is not UNSET and (patch_value is not None or column_is_nullable)
+        }
+        async with info.context.db() as session:
+            dataset = await session.scalar(
+                update(models.Dataset)
+                .where(models.Dataset.id == dataset_id)
+                .returning(models.Dataset)
+                .values(**patch)
+            )
+            assert dataset is not None
+        return DatasetMutationPayload(dataset=to_gql_dataset(dataset))
+
+    @strawberry.mutation(permission_classes=[IsAuthenticated])  # type: ignore
+    async def add_spans_to_dataset(
+        self,
+        info: Info[Context, None],
+        input: AddSpansToDatasetInput,
+    ) -> DatasetMutationPayload:
+        dataset_id = input.dataset_id
+        span_ids = input.span_ids
+        dataset_version_description = (
+            input.dataset_version_description
+            if isinstance(input.dataset_version_description, str)
+            else None
+        )
+        dataset_version_metadata = input.dataset_version_metadata
+        dataset_rowid = from_global_id_with_expected_type(
+            global_id=dataset_id, expected_type_name=Dataset.__name__
+        )
+        span_rowids = {
+            from_global_id_with_expected_type(global_id=span_id, expected_type_name=Span.__name__)
+            for span_id in set(span_ids)
+        }
+        async with info.context.db() as session:
+            if (
+                dataset := await session.scalar(
+                    select(models.Dataset).where(models.Dataset.id == dataset_rowid)
+                )
+            ) is None:
+                raise ValueError(
+                    f"Unknown dataset: {dataset_id}"
+                )  # todo: implement error types https://github.com/Arize-ai/phoenix/issues/3221
+            dataset_version_rowid = await session.scalar(
+                insert(models.DatasetVersion)
+                .values(
+                    dataset_id=dataset_rowid,
+                    description=dataset_version_description,
+                    metadata_=dataset_version_metadata,
+                )
+                .returning(models.DatasetVersion.id)
+            )
+            spans = (
+                await session.execute(
+                    select(
+                        models.Span.id,
+                        models.Span.span_kind,
+                        models.Span.attributes,
+                        _span_attribute(INPUT_MIME_TYPE),
+                        _span_attribute(INPUT_VALUE),
+                        _span_attribute(OUTPUT_MIME_TYPE),
+                        _span_attribute(OUTPUT_VALUE),
+                        _span_attribute(LLM_PROMPT_TEMPLATE_VARIABLES),
+                        _span_attribute(LLM_INPUT_MESSAGES),
+                        _span_attribute(LLM_OUTPUT_MESSAGES),
+                        _span_attribute(RETRIEVAL_DOCUMENTS),
+                    )
+                    .select_from(models.Span)
+                    .where(models.Span.id.in_(span_rowids))
+                )
+            ).all()
+            if missing_span_rowids := span_rowids - {span.id for span in spans}:
+                raise ValueError(
+                    f"Could not find spans with rowids: {', '.join(map(str, missing_span_rowids))}"
+                )  # todo: implement error handling types https://github.com/Arize-ai/phoenix/issues/3221
+            DatasetExample = models.DatasetExample
+            dataset_example_rowids = (
+                await session.scalars(
+                    insert(DatasetExample).returning(DatasetExample.id),
+                    [
+                        {
+                            DatasetExample.dataset_id.key: dataset_rowid,
+                            DatasetExample.span_rowid.key: span.id,
+                        }
+                        for span in spans
+                    ],
+                )
+            ).all()
+            assert len(dataset_example_rowids) == len(spans)
+            assert all(map(lambda id: isinstance(id, int), dataset_example_rowids))
+            DatasetExampleRevision = models.DatasetExampleRevision
+            await session.execute(
+                insert(DatasetExampleRevision),
+                [
+                    {
+                        DatasetExampleRevision.dataset_example_id.key: dataset_example_rowid,
+                        DatasetExampleRevision.dataset_version_id.key: dataset_version_rowid,
+                        DatasetExampleRevision.input.key: get_dataset_example_input(span),
+                        DatasetExampleRevision.output.key: get_dataset_example_output(span),
+                        DatasetExampleRevision.metadata_.key: span.attributes,
+                        DatasetExampleRevision.revision_kind.key: "CREATE",
+                    }
+                    for dataset_example_rowid, span in zip(dataset_example_rowids, spans)
+                ],
+            )
+        return DatasetMutationPayload(dataset=to_gql_dataset(dataset))
+
+    @strawberry.mutation(permission_classes=[IsAuthenticated])  # type: ignore
+    async def add_examples_to_dataset(
+        self, info: Info[Context, None], input: AddExamplesToDatasetInput
+    ) -> DatasetMutationPayload:
+        dataset_id = input.dataset_id
+        # Extract the span rowids from the input examples if they exist
+        span_ids = [example.span_id for example in input.examples if example.span_id]
+        span_rowids = {
+            from_global_id_with_expected_type(global_id=span_id, expected_type_name=Span.__name__)
+            for span_id in set(span_ids)
+        }
+        dataset_version_description = (
+            input.dataset_version_description if input.dataset_version_description else None
+        )
+        dataset_version_metadata = input.dataset_version_metadata
+        dataset_rowid = from_global_id_with_expected_type(
+            global_id=dataset_id, expected_type_name=Dataset.__name__
+        )
+        async with info.context.db() as session:
+            if (
+                dataset := await session.scalar(
+                    select(models.Dataset).where(models.Dataset.id == dataset_rowid)
+                )
+            ) is None:
+                raise ValueError(
+                    f"Unknown dataset: {dataset_id}"
+                )  # todo: implement error types https://github.com/Arize-ai/phoenix/issues/3221
+            dataset_version_rowid = await session.scalar(
+                insert(models.DatasetVersion)
+                .values(
+                    dataset_id=dataset_rowid,
+                    description=dataset_version_description,
+                    metadata_=dataset_version_metadata,
+                )
+                .returning(models.DatasetVersion.id)
+            )
+            spans = (
+                await session.execute(
+                    select(models.Span.id)
+                    .select_from(models.Span)
+                    .where(models.Span.id.in_(span_rowids))
+                )
+            ).all()
+            # Just validate that the number of spans matches the number of span_ids
+            # to ensure that the span_ids are valid
+            assert len(spans) == len(span_rowids)
+            DatasetExample = models.DatasetExample
+            dataset_example_rowids = (
+                await session.scalars(
+                    insert(DatasetExample).returning(DatasetExample.id),
+                    [
+                        {
+                            DatasetExample.dataset_id.key: dataset_rowid,
+                            DatasetExample.span_rowid.key: from_global_id_with_expected_type(
+                                global_id=example.span_id,
+                                expected_type_name=Span.__name__,
+                            )
+                            if example.span_id
+                            else None,
+                        }
+                        for example in input.examples
+                    ],
+                )
+            ).all()
+            assert len(dataset_example_rowids) == len(input.examples)
+            assert all(map(lambda id: isinstance(id, int), dataset_example_rowids))
+            DatasetExampleRevision = models.DatasetExampleRevision
+            await session.execute(
+                insert(DatasetExampleRevision),
+                [
+                    {
+                        DatasetExampleRevision.dataset_example_id.key: dataset_example_rowid,
+                        DatasetExampleRevision.dataset_version_id.key: dataset_version_rowid,
+                        DatasetExampleRevision.input.key: example.input,
+                        DatasetExampleRevision.output.key: example.output,
+                        DatasetExampleRevision.metadata_.key: example.metadata,
+                        DatasetExampleRevision.revision_kind.key: "CREATE",
+                    }
+                    for dataset_example_rowid, example in zip(
+                        dataset_example_rowids, input.examples
+                    )
+                ],
+            )
+        return DatasetMutationPayload(dataset=to_gql_dataset(dataset))
+
+    @strawberry.mutation(permission_classes=[IsAuthenticated])  # type: ignore
+    async def delete_dataset(
+        self,
+        info: Info[Context, None],
+        input: DeleteDatasetInput,
+    ) -> DatasetMutationPayload:
+        dataset_id = input.dataset_id
+        dataset_rowid = from_global_id_with_expected_type(
+            global_id=dataset_id, expected_type_name=Dataset.__name__
+        )
+
+        async with info.context.db() as session:
+            delete_result = await session.execute(
+                delete(models.Dataset)
+                .where(models.Dataset.id == dataset_rowid)
+                .returning(models.Dataset)
+            )
+            if not (datasets := delete_result.first()):
+                raise ValueError(f"Unknown dataset: {dataset_id}")
+
+            dataset = datasets[0]
+        return DatasetMutationPayload(dataset=to_gql_dataset(dataset))
+
+    @strawberry.mutation(permission_classes=[IsAuthenticated])  # type: ignore
+    async def patch_dataset_examples(
+        self,
+        info: Info[Context, None],
+        input: PatchDatasetExamplesInput,
+    ) -> DatasetMutationPayload:
+        if not (patches := input.patches):
+            raise ValueError("Must provide examples to patch.")
+        by_numeric_id = [
+            (
+                from_global_id_with_expected_type(patch.example_id, DatasetExample.__name__),
+                index,
+                patch,
+            )
+            for index, patch in enumerate(patches)
+        ]
+        example_ids, _, patches = map(list, zip(*sorted(by_numeric_id)))
+        if len(set(example_ids)) < len(example_ids):
+            raise ValueError("Cannot patch the same example more than once per mutation.")
+        if any(patch.is_empty() for patch in patches):
+            raise ValueError("Received one or more empty patches that contain no fields to update.")
+        version_description = input.version_description or None
+        version_metadata = input.version_metadata
+        async with info.context.db() as session:
+            datasets = (
+                await session.scalars(
+                    select(models.Dataset)
+                    .where(
+                        models.Dataset.id.in_(
+                            select(distinct(models.DatasetExample.dataset_id))
+                            .where(models.DatasetExample.id.in_(example_ids))
+                            .scalar_subquery()
+                        )
+                    )
+                    .limit(2)
+                )
+            ).all()
+            if not datasets:
+                raise ValueError("No examples found.")
+            if len(set(ds.id for ds in datasets)) > 1:
+                raise ValueError("Examples must come from the same dataset.")
+            dataset = datasets[0]
+
+            revision_ids = (
+                select(func.max(models.DatasetExampleRevision.id))
+                .where(models.DatasetExampleRevision.dataset_example_id.in_(example_ids))
+                .group_by(models.DatasetExampleRevision.dataset_example_id)
+                .scalar_subquery()
+            )
+            revisions = (
+                await session.scalars(
+                    select(models.DatasetExampleRevision)
+                    .where(
+                        and_(
+                            models.DatasetExampleRevision.id.in_(revision_ids),
+                            models.DatasetExampleRevision.revision_kind != "DELETE",
+                        )
+                    )
+                    .order_by(
+                        models.DatasetExampleRevision.dataset_example_id
+                    )  # ensure the order of the revisions matches the order of the input patches
+                )
+            ).all()
+            if (num_missing_examples := len(example_ids) - len(revisions)) > 0:
+                raise ValueError(f"{num_missing_examples} example(s) could not be found.")
+
+            version_id = await session.scalar(
+                insert(models.DatasetVersion)
+                .returning(models.DatasetVersion.id)
+                .values(
+                    dataset_id=dataset.id,
+                    description=version_description,
+                    metadata_=version_metadata,
+                )
+            )
+            assert version_id is not None
+
+            await session.execute(
+                insert(models.DatasetExampleRevision),
+                [
+                    _to_orm_revision(
+                        existing_revision=revision,
+                        patch=patch,
+                        example_id=example_id,
+                        version_id=version_id,
+                    )
+                    for revision, patch, example_id in zip(revisions, patches, example_ids)
+                ],
+            )
+
+        return DatasetMutationPayload(dataset=to_gql_dataset(dataset))
+
+    @strawberry.mutation(permission_classes=[IsAuthenticated])  # type: ignore
+    async def delete_dataset_examples(
+        self, info: Info[Context, None], input: DeleteDatasetExamplesInput
+    ) -> DatasetMutationPayload:
+        timestamp = datetime.now()
+        example_db_ids = [
+            from_global_id_with_expected_type(global_id, models.DatasetExample.__name__)
+            for global_id in input.example_ids
+        ]
+        # Guard against empty input
+        if not example_db_ids:
+            raise ValueError("Must provide examples to delete")
+        dataset_version_description = (
+            input.dataset_version_description
+            if isinstance(input.dataset_version_description, str)
+            else None
+        )
+        dataset_version_metadata = input.dataset_version_metadata
+        async with info.context.db() as session:
+            # Check if the examples are from a single dataset
+            datasets = (
+                await session.scalars(
+                    select(models.Dataset)
+                    .join(
+                        models.DatasetExample, models.Dataset.id == models.DatasetExample.dataset_id
+                    )
+                    .where(models.DatasetExample.id.in_(example_db_ids))
+                    .distinct()
+                    .limit(2)  # limit to 2 to check if there are more than 1 dataset
+                )
+            ).all()
+            if len(datasets) > 1:
+                raise ValueError("Examples must be from the same dataset")
+            elif not datasets:
+                raise ValueError("Examples not found")
+
+            dataset = datasets[0]
+
+            dataset_version_rowid = await session.scalar(
+                insert(models.DatasetVersion)
+                .values(
+                    dataset_id=dataset.id,
+                    description=dataset_version_description,
+                    metadata_=dataset_version_metadata,
+                    created_at=timestamp,
+                )
+                .returning(models.DatasetVersion.id)
+            )
+
+            # If the examples already have a delete revision, skip the deletion
+            existing_delete_revisions = (
+                await session.scalars(
+                    select(models.DatasetExampleRevision).where(
+                        models.DatasetExampleRevision.dataset_example_id.in_(example_db_ids),
+                        models.DatasetExampleRevision.revision_kind == "DELETE",
+                    )
+                )
+            ).all()
+
+            if existing_delete_revisions:
+                raise ValueError(
+                    "Provided examples contain already deleted examples. Delete aborted."
+                )
+
+            DatasetExampleRevision = models.DatasetExampleRevision
+            await session.execute(
+                insert(DatasetExampleRevision),
+                [
+                    {
+                        DatasetExampleRevision.dataset_example_id.key: dataset_example_rowid,
+                        DatasetExampleRevision.dataset_version_id.key: dataset_version_rowid,
+                        DatasetExampleRevision.input.key: {},
+                        DatasetExampleRevision.output.key: {},
+                        DatasetExampleRevision.metadata_.key: {},
+                        DatasetExampleRevision.revision_kind.key: "DELETE",
+                        DatasetExampleRevision.created_at.key: timestamp,
+                    }
+                    for dataset_example_rowid in example_db_ids
+                ],
+            )
+
+        return DatasetMutationPayload(dataset=to_gql_dataset(dataset))
+
+
+def _span_attribute(semconv: str) -> Any:
+    """
+    Extracts an attribute from the ORM span attributes column and labels the
+    result.
+
+    E.g., "input.value" -> Span.attributes["input"]["value"].label("input_value")
+    """
+    attribute_value: Any = models.Span.attributes
+    for key in semconv.split("."):
+        attribute_value = attribute_value[key]
+    return attribute_value.label(semconv.replace(".", "_"))
+
+
+def _to_orm_revision(
+    *,
+    existing_revision: models.DatasetExampleRevision,
+    patch: DatasetExamplePatch,
+    example_id: int,
+    version_id: int,
+) -> Dict[str, Any]:
+    """
+    Creates a new revision from an existing revision and a patch. The output is a
+    dictionary suitable for insertion into the database using the sqlalchemy
+    bulk insertion API.
+    """
+
+    db_rev = models.DatasetExampleRevision
+    input = patch.input if isinstance(patch.input, dict) else existing_revision.input
+    output = patch.output if isinstance(patch.output, dict) else existing_revision.output
+    metadata = patch.metadata if isinstance(patch.metadata, dict) else existing_revision.metadata_
+    return {
+        str(db_column.key): patch_value
+        for db_column, patch_value in (
+            (db_rev.dataset_example_id, example_id),
+            (db_rev.dataset_version_id, version_id),
+            (db_rev.input, input),
+            (db_rev.output, output),
+            (db_rev.metadata_, metadata),
+            (db_rev.revision_kind, "PATCH"),
+        )
+    }
+
+
+INPUT_MIME_TYPE = SpanAttributes.INPUT_MIME_TYPE
+INPUT_VALUE = SpanAttributes.INPUT_VALUE
+OUTPUT_MIME_TYPE = SpanAttributes.OUTPUT_MIME_TYPE
+OUTPUT_VALUE = SpanAttributes.OUTPUT_VALUE
+LLM_PROMPT_TEMPLATE_VARIABLES = SpanAttributes.LLM_PROMPT_TEMPLATE_VARIABLES
+LLM_INPUT_MESSAGES = SpanAttributes.LLM_INPUT_MESSAGES
+LLM_OUTPUT_MESSAGES = SpanAttributes.LLM_OUTPUT_MESSAGES
+RETRIEVAL_DOCUMENTS = SpanAttributes.RETRIEVAL_DOCUMENTS
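The dict comprehension in patch_dataset above encodes an easy-to-miss rule: a column is written only when the client actually supplied a value (not UNSET), and a supplied None is honored only for nullable columns, so a patch can clear description but can never null out name. A minimal, self-contained sketch of that rule (the build_patch helper and the column triples are illustrative, not part of the package):

# Editorial sketch of the patch-filtering rule in patch_dataset.
# UNSET stands in for strawberry.UNSET; column names are illustrative.
UNSET = object()

def build_patch(name=UNSET, description=UNSET, metadata=UNSET):
    candidates = (
        ("name", name, False),               # NOT NULL column
        ("description", description, True),  # nullable column
        ("metadata_", metadata, False),      # NOT NULL column
    )
    return {
        key: value
        for key, value, nullable in candidates
        if value is not UNSET and (value is not None or nullable)
    }

# A supplied None clears the nullable column; None for a non-nullable
# column is dropped instead of violating the constraint.
assert build_patch(name=None, description=None) == {"description": None}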
phoenix/server/api/mutations/experiment_mutations.py
@@ -0,0 +1,65 @@
+from typing import List
+
+import strawberry
+from sqlalchemy import delete
+from strawberry.relay import GlobalID
+from strawberry.types import Info
+
+from phoenix.db import models
+from phoenix.server.api.context import Context
+from phoenix.server.api.input_types.DeleteExperimentsInput import DeleteExperimentsInput
+from phoenix.server.api.mutations.auth import IsAuthenticated
+from phoenix.server.api.types.Experiment import Experiment, to_gql_experiment
+from phoenix.server.api.types.node import from_global_id_with_expected_type
+
+
+@strawberry.type
+class ExperimentMutationPayload:
+    experiments: List[Experiment]
+
+
+@strawberry.type
+class ExperimentMutationMixin:
+    @strawberry.mutation(permission_classes=[IsAuthenticated])  # type: ignore
+    async def delete_experiments(
+        self,
+        info: Info[Context, None],
+        input: DeleteExperimentsInput,
+    ) -> ExperimentMutationPayload:
+        experiment_ids = [
+            from_global_id_with_expected_type(experiment_id, Experiment.__name__)
+            for experiment_id in input.experiment_ids
+        ]
+        async with info.context.db() as session:
+            savepoint = await session.begin_nested()
+            experiments = {
+                experiment.id: experiment
+                async for experiment in (
+                    await session.stream_scalars(
+                        delete(models.Experiment)
+                        .where(models.Experiment.id.in_(experiment_ids))
+                        .returning(models.Experiment)
+                    )
+                )
+            }
+            if unknown_experiment_ids := set(experiment_ids) - set(experiments.keys()):
+                await savepoint.rollback()
+                raise ValueError(
+                    "Failed to delete experiment(s), "
+                    "probably due to invalid input experiment ID(s): "
+                    + str(
+                        [
+                            str(GlobalID(Experiment.__name__, str(experiment_id)))
+                            for experiment_id in unknown_experiment_ids
+                        ]
+                    )
+                )
+            if project_names := set(filter(bool, (e.project_name for e in experiments.values()))):
+                await session.execute(
+                    delete(models.Project).where(models.Project.name.in_(project_names))
+                )
+        return ExperimentMutationPayload(
+            experiments=[
+                to_gql_experiment(experiments[experiment_id]) for experiment_id in experiment_ids
+            ]
+        )
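delete_experiments above wraps its bulk DELETE ... RETURNING in session.begin_nested(), i.e. a SAVEPOINT, so that when any requested ID turns out to be unknown the partially applied delete can be rolled back without abandoning the enclosing transaction. A self-contained sketch of that all-or-nothing pattern using SQLAlchemy Core against in-memory SQLite (the items table and IDs are illustrative):

# Editorial sketch: roll back only the nested block (SAVEPOINT) when
# validation fails, leaving the outer transaction usable.
from sqlalchemy import Column, Integer, MetaData, Table, create_engine, delete, select

metadata = MetaData()
items = Table("items", metadata, Column("id", Integer, primary_key=True))
engine = create_engine("sqlite://")
metadata.create_all(engine)

with engine.begin() as conn:
    conn.execute(items.insert(), [{"id": 1}, {"id": 2}])
    requested = [1, 2, 3]  # id 3 does not exist
    savepoint = conn.begin_nested()  # SAVEPOINT
    result = conn.execute(delete(items).where(items.c.id.in_(requested)))
    if result.rowcount != len(requested):
        savepoint.rollback()  # undo only the partial delete
    remaining = conn.execute(select(items.c.id)).scalars().all()
    print(remaining)  # [1, 2] -- nothing was deleted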
phoenix/server/api/{types/ExportEventsMutation.py → mutations/export_events_mutations.py}
@@ -10,14 +10,16 @@ from strawberry.types import Info
 import phoenix.core.model_schema as ms
 from phoenix.server.api.context import Context
 from phoenix.server.api.input_types.ClusterInput import ClusterInput
-from phoenix.server.api.
-from phoenix.server.api.types.Event import
+from phoenix.server.api.mutations.auth import IsAuthenticated
+from phoenix.server.api.types.Event import parse_event_ids_by_inferences_role, unpack_event_id
 from phoenix.server.api.types.ExportedFile import ExportedFile
+from phoenix.server.api.types.InferencesRole import AncillaryInferencesRole, InferencesRole
 
 
 @strawberry.type
-class
+class ExportEventsMutationMixin:
     @strawberry.mutation(
+        permission_classes=[IsAuthenticated],
         description=(
             "Given a list of event ids, export the corresponding data subset in Parquet format."
             " File name is optional, but if specified, should be without file extension. By default"
@@ -32,11 +34,11 @@ class ExportEventsMutation:
     ) -> ExportedFile:
         if not isinstance(file_name, str):
             file_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
-        row_ids =
+        row_ids = parse_event_ids_by_inferences_role(event_ids)
         exclude_corpus_row_ids = {}
-        for
-            if isinstance(
-                exclude_corpus_row_ids[
+        for inferences_role in list(row_ids.keys()):
+            if isinstance(inferences_role, InferencesRole):
+                exclude_corpus_row_ids[inferences_role.value] = row_ids[inferences_role]
         path = info.context.export_path
        with open(path / (file_name + ".parquet"), "wb") as fd:
             loop = asyncio.get_running_loop()
@@ -49,6 +51,7 @@ class ExportEventsMutation:
         return ExportedFile(file_name=file_name)
 
     @strawberry.mutation(
+        permission_classes=[IsAuthenticated],
         description=(
             "Given a list of clusters, export the corresponding data subset in Parquet format."
             " File name is optional, but if specified, should be without file extension. By default"
@@ -79,13 +82,13 @@ class ExportEventsMutation:
 
 def _unpack_clusters(
     clusters: List[ClusterInput],
-) -> Tuple[Dict[ms.
-    row_numbers: Dict[ms.
-    cluster_ids: Dict[ms.
+) -> Tuple[Dict[ms.InferencesRole, List[int]], Dict[ms.InferencesRole, Dict[int, str]]]:
+    row_numbers: Dict[ms.InferencesRole, List[int]] = defaultdict(list)
+    cluster_ids: Dict[ms.InferencesRole, Dict[int, str]] = defaultdict(dict)
     for i, cluster in enumerate(clusters):
-        for row_number,
-            if isinstance(
+        for row_number, inferences_role in map(unpack_event_id, cluster.event_ids):
+            if isinstance(inferences_role, AncillaryInferencesRole):
                 continue
-            row_numbers[
-            cluster_ids[
+            row_numbers[inferences_role.value].append(row_number)
+            cluster_ids[inferences_role.value][row_number] = cluster.id or str(i)
     return row_numbers, cluster_ids