arize-phoenix 11.37.0__py3-none-any.whl → 12.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of arize-phoenix might be problematic.
- {arize_phoenix-11.37.0.dist-info → arize_phoenix-12.0.0.dist-info}/METADATA +3 -3
- {arize_phoenix-11.37.0.dist-info → arize_phoenix-12.0.0.dist-info}/RECORD +74 -53
- phoenix/config.py +1 -11
- phoenix/db/bulk_inserter.py +8 -0
- phoenix/db/facilitator.py +1 -1
- phoenix/db/helpers.py +202 -33
- phoenix/db/insertion/dataset.py +7 -0
- phoenix/db/insertion/helpers.py +2 -2
- phoenix/db/insertion/session_annotation.py +176 -0
- phoenix/db/insertion/types.py +30 -0
- phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
- phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
- phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
- phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
- phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
- phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
- phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
- phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
- phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
- phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
- phoenix/db/models.py +285 -46
- phoenix/server/api/context.py +13 -2
- phoenix/server/api/dataloaders/__init__.py +6 -2
- phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
- phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
- phoenix/server/api/dataloaders/table_fields.py +2 -2
- phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
- phoenix/server/api/helpers/playground_clients.py +65 -35
- phoenix/server/api/helpers/playground_spans.py +2 -1
- phoenix/server/api/helpers/playground_users.py +26 -0
- phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
- phoenix/server/api/input_types/ChatCompletionInput.py +2 -0
- phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
- phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
- phoenix/server/api/mutations/__init__.py +6 -0
- phoenix/server/api/mutations/chat_mutations.py +24 -9
- phoenix/server/api/mutations/dataset_mutations.py +5 -0
- phoenix/server/api/mutations/dataset_split_mutations.py +387 -0
- phoenix/server/api/mutations/project_session_annotations_mutations.py +161 -0
- phoenix/server/api/queries.py +32 -0
- phoenix/server/api/routers/v1/__init__.py +2 -0
- phoenix/server/api/routers/v1/annotations.py +320 -0
- phoenix/server/api/routers/v1/datasets.py +5 -0
- phoenix/server/api/routers/v1/experiments.py +10 -3
- phoenix/server/api/routers/v1/sessions.py +111 -0
- phoenix/server/api/routers/v1/traces.py +1 -2
- phoenix/server/api/routers/v1/users.py +7 -0
- phoenix/server/api/subscriptions.py +25 -7
- phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +1 -0
- phoenix/server/api/types/DatasetExample.py +11 -0
- phoenix/server/api/types/DatasetSplit.py +32 -0
- phoenix/server/api/types/Experiment.py +0 -4
- phoenix/server/api/types/Project.py +16 -0
- phoenix/server/api/types/ProjectSession.py +88 -3
- phoenix/server/api/types/ProjectSessionAnnotation.py +68 -0
- phoenix/server/api/types/Span.py +5 -5
- phoenix/server/api/types/Trace.py +61 -0
- phoenix/server/app.py +6 -2
- phoenix/server/cost_tracking/model_cost_manifest.json +132 -2
- phoenix/server/dml_event.py +13 -0
- phoenix/server/static/.vite/manifest.json +39 -39
- phoenix/server/static/assets/{components-CFzdBkk_.js → components-Dl9SUw1U.js} +371 -327
- phoenix/server/static/assets/{index-DayUA9lQ.js → index-CqQS0dTo.js} +2 -2
- phoenix/server/static/assets/{pages-CvUhOO9h.js → pages-DKSjVA_E.js} +771 -518
- phoenix/server/static/assets/{vendor-BdjZxMii.js → vendor-CtbHQYl8.js} +1 -1
- phoenix/server/static/assets/{vendor-arizeai-CHYlS8jV.js → vendor-arizeai-D-lWOwIS.js} +1 -1
- phoenix/server/static/assets/{vendor-codemirror-Di6t4HnH.js → vendor-codemirror-BRBpy3_z.js} +3 -3
- phoenix/server/static/assets/{vendor-recharts-C9wCDYj3.js → vendor-recharts--KdSwB3m.js} +1 -1
- phoenix/server/static/assets/{vendor-shiki-MNnmOotP.js → vendor-shiki-CvRzZnIo.js} +1 -1
- phoenix/version.py +1 -1
- phoenix/server/api/dataloaders/experiment_repetition_counts.py +0 -39
- {arize_phoenix-11.37.0.dist-info → arize_phoenix-12.0.0.dist-info}/WHEEL +0 -0
- {arize_phoenix-11.37.0.dist-info → arize_phoenix-12.0.0.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-11.37.0.dist-info → arize_phoenix-12.0.0.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-11.37.0.dist-info → arize_phoenix-12.0.0.dist-info}/licenses/LICENSE +0 -0
phoenix/server/api/mutations/chat_mutations.py
@@ -26,7 +26,10 @@ from typing_extensions import assert_never
 from phoenix.config import PLAYGROUND_PROJECT_NAME
 from phoenix.datetime_utils import local_now, normalize_datetime
 from phoenix.db import models
-from phoenix.db.helpers import
+from phoenix.db.helpers import (
+    get_dataset_example_revisions,
+    insert_experiment_with_examples_snapshot,
+)
 from phoenix.server.api.auth import IsLocked, IsNotReadOnly
 from phoenix.server.api.context import Context
 from phoenix.server.api.exceptions import BadRequest, CustomGraphQLError, NotFound
@@ -46,6 +49,7 @@ from phoenix.server.api.helpers.playground_spans import (
     llm_tools,
     prompt_metadata,
 )
+from phoenix.server.api.helpers.playground_users import get_user
 from phoenix.server.api.helpers.prompts.models import PromptTemplateFormat
 from phoenix.server.api.input_types.ChatCompletionInput import (
     ChatCompletionInput,
@@ -112,6 +116,7 @@ class ChatCompletionMutationError:
 @strawberry.type
 class ChatCompletionOverDatasetMutationExamplePayload:
     dataset_example_id: GlobalID
+    repetition_number: int
     experiment_run_id: GlobalID
     result: Union[ChatCompletionMutationPayload, ChatCompletionMutationError]
@@ -191,6 +196,7 @@ class ChatCompletionMutationMixin:
         ]
         if not revisions:
             raise NotFound("No examples found for the given dataset and version")
+        user_id = get_user(info)
        experiment = models.Experiment(
             dataset_id=from_global_id_with_expected_type(input.dataset_id, Dataset.__name__),
             dataset_version_id=resolved_version_id,
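Note: the get_user helper imported above ships in the new phoenix/server/api/helpers/playground_users.py (+26 lines), whose body is not part of this excerpt. A minimal sketch of what it plausibly does, assuming the authenticated user's row id is carried on the request context (hypothetical, not the actual implementation):

from typing import Any, Optional

def get_user(info: Any) -> Optional[int]:
    # Hypothetical sketch: read the authenticated user's row id off the
    # request context; anonymous requests yield None.
    return getattr(info.context, "user_id", None)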
@@ -200,14 +206,19 @@ class ChatCompletionMutationMixin:
             repetitions=1,
             metadata_=input.experiment_metadata or dict(),
             project_name=project_name,
+            user_id=user_id,
         )
-        session
-        await session.flush()
+        await insert_experiment_with_examples_snapshot(session, experiment)
 
         results: list[Union[ChatCompletionMutationPayload, BaseException]] = []
         batch_size = 3
         start_time = datetime.now(timezone.utc)
-
+        unbatched_items = [
+            (revision, repetition_number)
+            for revision in revisions
+            for repetition_number in range(1, input.repetitions + 1)
+        ]
+        for batch in _get_batches(unbatched_items, batch_size):
             batch_results = await asyncio.gather(
                 *(
                     cls._chat_completion(
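The rewritten loop feeds the (revision, repetition) pairs through _get_batches, which is defined outside this hunk. A minimal sketch of the slicing behavior the loop assumes (hypothetical; the real helper is not shown in this diff):

from collections.abc import Iterator, Sequence
from typing import TypeVar

T = TypeVar("T")

def _get_batches(items: Sequence[T], batch_size: int) -> Iterator[Sequence[T]]:
    # Yield consecutive slices of at most batch_size items, preserving the
    # original ordering so results can be zipped back against the inputs.
    for start in range(0, len(items), batch_size):
        yield items[start : start + batch_size]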
@@ -224,10 +235,11 @@ class ChatCompletionMutationMixin:
                                 variables=revision.input,
                             ),
                             prompt_name=input.prompt_name,
+                            repetitions=repetition_number,
                         ),
                         project_name=project_name,
                     )
-                    for revision in batch
+                    for revision, repetition_number in batch
                 ),
                 return_exceptions=True,
             )
@@ -239,13 +251,13 @@ class ChatCompletionMutationMixin:
             experiment_id=GlobalID(models.Experiment.__name__, str(experiment.id)),
         )
         experiment_runs = []
-        for revision, result in zip(
+        for (revision, repetition_number), result in zip(unbatched_items, results):
             if isinstance(result, BaseException):
                 experiment_run = models.ExperimentRun(
                     experiment_id=experiment.id,
                     dataset_example_id=revision.dataset_example_id,
                     output={},
-                    repetition_number=
+                    repetition_number=repetition_number,
                     start_time=start_time,
                     end_time=start_time,
                     error=str(result),
@@ -261,7 +273,7 @@ class ChatCompletionMutationMixin:
                     ),
                     prompt_token_count=db_span.cumulative_llm_token_count_prompt,
                     completion_token_count=db_span.cumulative_llm_token_count_completion,
-                    repetition_number=
+                    repetition_number=repetition_number,
                     start_time=db_span.start_time,
                     end_time=db_span.end_time,
                     error=str(result.error_message) if result.error_message else None,
@@ -272,13 +284,16 @@ class ChatCompletionMutationMixin:
         session.add_all(experiment_runs)
         await session.flush()
 
-        for revision, experiment_run, result in zip(
+        for (revision, repetition_number), experiment_run, result in zip(
+            unbatched_items, experiment_runs, results
+        ):
             dataset_example_id = GlobalID(
                 models.DatasetExample.__name__, str(revision.dataset_example_id)
             )
             experiment_run_id = GlobalID(models.ExperimentRun.__name__, str(experiment_run.id))
             example_payload = ChatCompletionOverDatasetMutationExamplePayload(
                 dataset_example_id=dataset_example_id,
+                repetition_number=repetition_number,
                 experiment_run_id=experiment_run_id,
                 result=result
                 if isinstance(result, ChatCompletionMutationPayload)
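Both zip(...) passes above rely on unbatched_items having a stable, deterministic order: results come back in the same order the pairs were built. A standalone illustration of that ordering (the identifiers here are invented for the example):

revisions = ["rev_a", "rev_b"]  # stand-ins for dataset example revisions
repetitions = 2

unbatched_items = [
    (revision, repetition_number)
    for revision in revisions
    for repetition_number in range(1, repetitions + 1)
]
print(unbatched_items)
# [('rev_a', 1), ('rev_a', 2), ('rev_b', 1), ('rev_b', 2)]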
phoenix/server/api/mutations/dataset_mutations.py
@@ -66,6 +66,7 @@ class DatasetMutationMixin:
                 name=name,
                 description=description,
                 metadata_=metadata,
+                user_id=info.context.user_id,
             )
             .returning(models.Dataset)
         )
@@ -136,6 +137,7 @@ class DatasetMutationMixin:
                 dataset_id=dataset_rowid,
                 description=dataset_version_description,
                 metadata_=dataset_version_metadata or {},
+                user_id=info.context.user_id,
             )
             session.add(dataset_version)
             await session.flush()
@@ -254,6 +256,7 @@ class DatasetMutationMixin:
                 dataset_id=dataset_rowid,
                 description=dataset_version_description,
                 metadata_=dataset_version_metadata,
+                user_id=info.context.user_id,
             )
             .returning(models.DatasetVersion.id)
         )
@@ -451,6 +454,7 @@ class DatasetMutationMixin:
                 dataset_id=dataset.id,
                 description=version_description,
                 metadata_=version_metadata,
+                user_id=info.context.user_id,
             )
         )
         assert version_id is not None
@@ -514,6 +518,7 @@ class DatasetMutationMixin:
                 dataset_id=dataset.id,
                 description=dataset_version_description,
                 metadata_=dataset_version_metadata,
+                user_id=info.context.user_id,
                 created_at=timestamp,
             )
             .returning(models.DatasetVersion.id)
phoenix/server/api/mutations/dataset_split_mutations.py (new file)
@@ -0,0 +1,387 @@
+from typing import Optional
+
+import strawberry
+from sqlalchemy import delete, func, insert, select
+from sqlalchemy.exc import IntegrityError as PostgreSQLIntegrityError
+from sqlean.dbapi2 import IntegrityError as SQLiteIntegrityError  # type: ignore[import-untyped]
+from strawberry import UNSET
+from strawberry.relay import GlobalID
+from strawberry.scalars import JSON
+from strawberry.types import Info
+
+from phoenix.db import models
+from phoenix.server.api.auth import IsLocked, IsNotReadOnly
+from phoenix.server.api.context import Context
+from phoenix.server.api.exceptions import BadRequest, Conflict, NotFound
+from phoenix.server.api.helpers.playground_users import get_user
+from phoenix.server.api.queries import Query
+from phoenix.server.api.types.DatasetSplit import DatasetSplit, to_gql_dataset_split
+from phoenix.server.api.types.node import from_global_id_with_expected_type
+
+
+@strawberry.input
+class CreateDatasetSplitInput:
+    name: str
+    description: Optional[str] = UNSET
+    color: str
+    metadata: Optional[JSON] = UNSET
+
+
+@strawberry.input
+class PatchDatasetSplitInput:
+    dataset_split_id: GlobalID
+    name: Optional[str] = UNSET
+    description: Optional[str] = UNSET
+    color: Optional[str] = UNSET
+    metadata: Optional[JSON] = UNSET
+
+
+@strawberry.input
+class DeleteDatasetSplitInput:
+    dataset_split_ids: list[GlobalID]
+
+
+@strawberry.input
+class AddDatasetExamplesToDatasetSplitsInput:
+    dataset_split_ids: list[GlobalID]
+    example_ids: list[GlobalID]
+
+
+@strawberry.input
+class RemoveDatasetExamplesFromDatasetSplitsInput:
+    dataset_split_ids: list[GlobalID]
+    example_ids: list[GlobalID]
+
+
+@strawberry.input
+class CreateDatasetSplitWithExamplesInput:
+    name: str
+    description: Optional[str] = UNSET
+    color: str
+    metadata: Optional[JSON] = UNSET
+    example_ids: list[GlobalID]
+
+
+@strawberry.type
+class DatasetSplitMutationPayload:
+    dataset_split: DatasetSplit
+    query: "Query"
+
+
+@strawberry.type
+class DeleteDatasetSplitsMutationPayload:
+    dataset_splits: list[DatasetSplit]
+    query: "Query"
+
+
+@strawberry.type
+class AddDatasetExamplesToDatasetSplitsMutationPayload:
+    query: "Query"
+
+
+@strawberry.type
+class RemoveDatasetExamplesFromDatasetSplitsMutationPayload:
+    query: "Query"
+
+
+@strawberry.type
+class DatasetSplitMutationMixin:
+    @strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked])  # type: ignore
+    async def create_dataset_split(
+        self, info: Info[Context, None], input: CreateDatasetSplitInput
+    ) -> DatasetSplitMutationPayload:
+        user_id = get_user(info)
+        validated_name = _validated_name(input.name)
+        async with info.context.db() as session:
+            dataset_split_orm = models.DatasetSplit(
+                name=validated_name,
+                description=input.description,
+                color=input.color,
+                metadata_=input.metadata or {},
+                user_id=user_id,
+            )
+            session.add(dataset_split_orm)
+            try:
+                await session.commit()
+            except (PostgreSQLIntegrityError, SQLiteIntegrityError):
+                raise Conflict(f"A dataset split named '{input.name}' already exists.")
+        return DatasetSplitMutationPayload(
+            dataset_split=to_gql_dataset_split(dataset_split_orm), query=Query()
+        )
+
+    @strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked])  # type: ignore
+    async def patch_dataset_split(
+        self, info: Info[Context, None], input: PatchDatasetSplitInput
+    ) -> DatasetSplitMutationPayload:
+        validated_name = _validated_name(input.name) if input.name else None
+        async with info.context.db() as session:
+            dataset_split_id = from_global_id_with_expected_type(
+                input.dataset_split_id, DatasetSplit.__name__
+            )
+            dataset_split_orm = await session.get(models.DatasetSplit, dataset_split_id)
+            if not dataset_split_orm:
+                raise NotFound(f"Dataset split with ID {input.dataset_split_id} not found")
+
+            if validated_name:
+                dataset_split_orm.name = validated_name
+            if input.description:
+                dataset_split_orm.description = input.description
+            if input.color:
+                dataset_split_orm.color = input.color
+            if isinstance(input.metadata, dict):
+                dataset_split_orm.metadata_ = input.metadata
+
+            gql_dataset_split = to_gql_dataset_split(dataset_split_orm)
+            try:
+                await session.commit()
+            except (PostgreSQLIntegrityError, SQLiteIntegrityError):
+                raise Conflict("A dataset split with this name already exists")
+
+        return DatasetSplitMutationPayload(
+            dataset_split=gql_dataset_split,
+            query=Query(),
+        )
+
+    @strawberry.mutation(permission_classes=[IsNotReadOnly])  # type: ignore
+    async def delete_dataset_splits(
+        self, info: Info[Context, None], input: DeleteDatasetSplitInput
+    ) -> DeleteDatasetSplitsMutationPayload:
+        unique_dataset_split_rowids: dict[int, None] = {}  # use a dict to preserve ordering
+        for dataset_split_gid in input.dataset_split_ids:
+            try:
+                dataset_split_rowid = from_global_id_with_expected_type(
+                    dataset_split_gid, DatasetSplit.__name__
+                )
+            except ValueError:
+                raise BadRequest(f"Invalid dataset split ID: {dataset_split_gid}")
+            unique_dataset_split_rowids[dataset_split_rowid] = None
+        dataset_split_rowids = list(unique_dataset_split_rowids.keys())
+
+        async with info.context.db() as session:
+            deleted_splits_by_id = {
+                split.id: split
+                for split in (
+                    await session.scalars(
+                        delete(models.DatasetSplit)
+                        .where(models.DatasetSplit.id.in_(dataset_split_rowids))
+                        .returning(models.DatasetSplit)
+                    )
+                ).all()
+            }
+            if len(deleted_splits_by_id) < len(dataset_split_rowids):
+                await session.rollback()
+                raise NotFound("One or more dataset splits not found")
+            await session.commit()
+
+        return DeleteDatasetSplitsMutationPayload(
+            dataset_splits=[
+                to_gql_dataset_split(deleted_splits_by_id[dataset_split_rowid])
+                for dataset_split_rowid in dataset_split_rowids
+            ],
+            query=Query(),
+        )
+
+    @strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked])  # type: ignore
+    async def add_dataset_examples_to_dataset_splits(
+        self, info: Info[Context, None], input: AddDatasetExamplesToDatasetSplitsInput
+    ) -> AddDatasetExamplesToDatasetSplitsMutationPayload:
+        if not input.example_ids:
+            raise BadRequest("No examples provided.")
+        if not input.dataset_split_ids:
+            raise BadRequest("No dataset splits provided.")
+
+        unique_dataset_split_rowids: set[int] = set()
+        for dataset_split_gid in input.dataset_split_ids:
+            try:
+                dataset_split_rowid = from_global_id_with_expected_type(
+                    dataset_split_gid, DatasetSplit.__name__
+                )
+            except ValueError:
+                raise BadRequest(f"Invalid dataset split ID: {dataset_split_gid}")
+            unique_dataset_split_rowids.add(dataset_split_rowid)
+        dataset_split_rowids = list(unique_dataset_split_rowids)
+
+        unique_example_rowids: set[int] = set()
+        for example_gid in input.example_ids:
+            try:
+                example_rowid = from_global_id_with_expected_type(
+                    example_gid, models.DatasetExample.__name__
+                )
+            except ValueError:
+                raise BadRequest(f"Invalid example ID: {example_gid}")
+            unique_example_rowids.add(example_rowid)
+        example_rowids = list(unique_example_rowids)
+
+        async with info.context.db() as session:
+            existing_dataset_split_ids = (
+                await session.scalars(
+                    select(models.DatasetSplit.id).where(
+                        models.DatasetSplit.id.in_(dataset_split_rowids)
+                    )
+                )
+            ).all()
+            if len(existing_dataset_split_ids) != len(dataset_split_rowids):
+                raise NotFound("One or more dataset splits not found")
+
+            # Find existing (dataset_split_id, dataset_example_id) keys to avoid duplicates
+            # Users can submit multiple examples at once which can have
+            # indeterminate participation in multiple splits
+            existing_dataset_example_split_keys = await session.execute(
+                select(
+                    models.DatasetSplitDatasetExample.dataset_split_id,
+                    models.DatasetSplitDatasetExample.dataset_example_id,
+                ).where(
+                    models.DatasetSplitDatasetExample.dataset_split_id.in_(dataset_split_rowids)
+                    & models.DatasetSplitDatasetExample.dataset_example_id.in_(example_rowids)
+                )
+            )
+            unique_dataset_example_split_keys = set(existing_dataset_example_split_keys.all())
+
+            # Compute all desired pairs and insert only missing
+            values = []
+            for dataset_split_rowid in dataset_split_rowids:
+                for example_rowid in example_rowids:
+                    # if the keys already exists, skip
+                    if (dataset_split_rowid, example_rowid) in unique_dataset_example_split_keys:
+                        continue
+                    dataset_split_id_key = models.DatasetSplitDatasetExample.dataset_split_id.key
+                    dataset_example_id_key = (
+                        models.DatasetSplitDatasetExample.dataset_example_id.key
+                    )
+                    values.append(
+                        {
+                            dataset_split_id_key: dataset_split_rowid,
+                            dataset_example_id_key: example_rowid,
+                        }
+                    )
+
+            if values:
+                try:
+                    await session.execute(insert(models.DatasetSplitDatasetExample), values)
+                    await session.flush()
+                except (PostgreSQLIntegrityError, SQLiteIntegrityError) as e:
+                    raise Conflict("Failed to add examples to dataset splits.") from e
+
+        return AddDatasetExamplesToDatasetSplitsMutationPayload(
+            query=Query(),
+        )
+
+    @strawberry.mutation(permission_classes=[IsNotReadOnly])  # type: ignore
+    async def remove_dataset_examples_from_dataset_splits(
+        self, info: Info[Context, None], input: RemoveDatasetExamplesFromDatasetSplitsInput
+    ) -> RemoveDatasetExamplesFromDatasetSplitsMutationPayload:
+        if not input.dataset_split_ids:
+            raise BadRequest("No dataset splits provided.")
+        if not input.example_ids:
+            raise BadRequest("No examples provided.")
+
+        unique_dataset_split_rowids: set[int] = set()
+        for dataset_split_gid in input.dataset_split_ids:
+            try:
+                dataset_split_rowid = from_global_id_with_expected_type(
+                    dataset_split_gid, DatasetSplit.__name__
+                )
+            except ValueError:
+                raise BadRequest(f"Invalid dataset split ID: {dataset_split_gid}")
+            unique_dataset_split_rowids.add(dataset_split_rowid)
+        dataset_split_rowids = list(unique_dataset_split_rowids)
+
+        unique_example_rowids: set[int] = set()
+        for example_gid in input.example_ids:
+            try:
+                example_rowid = from_global_id_with_expected_type(
+                    example_gid, models.DatasetExample.__name__
+                )
+            except ValueError:
+                raise BadRequest(f"Invalid example ID: {example_gid}")
+            unique_example_rowids.add(example_rowid)
+        example_rowids = list(unique_example_rowids)
+
+        stmt = delete(models.DatasetSplitDatasetExample).where(
+            models.DatasetSplitDatasetExample.dataset_split_id.in_(dataset_split_rowids)
+            & models.DatasetSplitDatasetExample.dataset_example_id.in_(example_rowids)
+        )
+        async with info.context.db() as session:
+            existing_dataset_split_ids = (
+                await session.scalars(
+                    select(models.DatasetSplit.id).where(
+                        models.DatasetSplit.id.in_(dataset_split_rowids)
+                    )
+                )
+            ).all()
+            if len(existing_dataset_split_ids) != len(dataset_split_rowids):
+                raise NotFound("One or more dataset splits not found")
+
+            await session.execute(stmt)
+
+        return RemoveDatasetExamplesFromDatasetSplitsMutationPayload(
+            query=Query(),
+        )
+
+    @strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked])  # type: ignore
+    async def create_dataset_split_with_examples(
+        self, info: Info[Context, None], input: CreateDatasetSplitWithExamplesInput
+    ) -> DatasetSplitMutationPayload:
+        user_id = get_user(info)
+        validated_name = _validated_name(input.name)
+        unique_example_rowids: set[int] = set()
+        for example_gid in input.example_ids:
+            try:
+                example_rowid = from_global_id_with_expected_type(
+                    example_gid, models.DatasetExample.__name__
+                )
+                unique_example_rowids.add(example_rowid)
+            except ValueError:
+                raise BadRequest(f"Invalid example ID: {example_gid}")
+        example_rowids = list(unique_example_rowids)
+        async with info.context.db() as session:
+            if example_rowids:
+                found_count = await session.scalar(
+                    select(func.count(models.DatasetExample.id)).where(
+                        models.DatasetExample.id.in_(example_rowids)
+                    )
+                )
+                if found_count is None or found_count < len(example_rowids):
+                    raise NotFound("One or more dataset examples were not found.")
+
+            dataset_split_orm = models.DatasetSplit(
+                name=validated_name,
+                description=input.description or None,
+                color=input.color,
+                metadata_=input.metadata or {},
+                user_id=user_id,
+            )
+            session.add(dataset_split_orm)
+            try:
+                await session.flush()
+            except (PostgreSQLIntegrityError, SQLiteIntegrityError):
+                raise Conflict(f"A dataset split named '{validated_name}' already exists.")
+
+            if example_rowids:
+                values = [
+                    {
+                        models.DatasetSplitDatasetExample.dataset_split_id.key: dataset_split_orm.id,  # noqa: E501
+                        models.DatasetSplitDatasetExample.dataset_example_id.key: example_id,
+                    }
+                    for example_id in example_rowids
+                ]
+                try:
+                    await session.execute(insert(models.DatasetSplitDatasetExample), values)
+                except (PostgreSQLIntegrityError, SQLiteIntegrityError) as e:
+                    # Roll back the transaction on association failure
+                    await session.rollback()
+                    raise Conflict(
+                        "Failed to associate examples with the new dataset split."
+                    ) from e
+
+        return DatasetSplitMutationPayload(
+            dataset_split=to_gql_dataset_split(dataset_split_orm),
+            query=Query(),
+        )
+
+
+def _validated_name(name: str) -> str:
+    validated_name = name.strip()
+    if not validated_name:
+        raise BadRequest("Name cannot be empty")
+    return validated_name
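For orientation, here is one way the new create_dataset_split mutation might be exercised over HTTP. The endpoint path, port, and camelCase field names are assumptions (Phoenix's default port and strawberry's default snake_case-to-camelCase conversion), not confirmed by this diff:

import requests

MUTATION = """
mutation ($input: CreateDatasetSplitInput!) {
  createDatasetSplit(input: $input) {
    datasetSplit { id name color }
  }
}
"""

# Assumes a local Phoenix server; adjust host, port, and auth as needed.
response = requests.post(
    "http://localhost:6006/graphql",
    json={
        "query": MUTATION,
        "variables": {"input": {"name": "holdout", "color": "#8B5CF6", "metadata": {}}},
    },
)
response.raise_for_status()
print(response.json())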