arize-phoenix 11.37.0__py3-none-any.whl → 12.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic. See the registry's advisory page for more details.

Files changed (75):
  1. {arize_phoenix-11.37.0.dist-info → arize_phoenix-12.0.0.dist-info}/METADATA +3 -3
  2. {arize_phoenix-11.37.0.dist-info → arize_phoenix-12.0.0.dist-info}/RECORD +74 -53
  3. phoenix/config.py +1 -11
  4. phoenix/db/bulk_inserter.py +8 -0
  5. phoenix/db/facilitator.py +1 -1
  6. phoenix/db/helpers.py +202 -33
  7. phoenix/db/insertion/dataset.py +7 -0
  8. phoenix/db/insertion/helpers.py +2 -2
  9. phoenix/db/insertion/session_annotation.py +176 -0
  10. phoenix/db/insertion/types.py +30 -0
  11. phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
  12. phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
  13. phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
  14. phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
  15. phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
  16. phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
  17. phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
  18. phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
  19. phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
  20. phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
  21. phoenix/db/models.py +285 -46
  22. phoenix/server/api/context.py +13 -2
  23. phoenix/server/api/dataloaders/__init__.py +6 -2
  24. phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
  25. phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
  26. phoenix/server/api/dataloaders/table_fields.py +2 -2
  27. phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
  28. phoenix/server/api/helpers/playground_clients.py +65 -35
  29. phoenix/server/api/helpers/playground_spans.py +2 -1
  30. phoenix/server/api/helpers/playground_users.py +26 -0
  31. phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
  32. phoenix/server/api/input_types/ChatCompletionInput.py +2 -0
  33. phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
  34. phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
  35. phoenix/server/api/mutations/__init__.py +6 -0
  36. phoenix/server/api/mutations/chat_mutations.py +24 -9
  37. phoenix/server/api/mutations/dataset_mutations.py +5 -0
  38. phoenix/server/api/mutations/dataset_split_mutations.py +387 -0
  39. phoenix/server/api/mutations/project_session_annotations_mutations.py +161 -0
  40. phoenix/server/api/queries.py +32 -0
  41. phoenix/server/api/routers/v1/__init__.py +2 -0
  42. phoenix/server/api/routers/v1/annotations.py +320 -0
  43. phoenix/server/api/routers/v1/datasets.py +5 -0
  44. phoenix/server/api/routers/v1/experiments.py +10 -3
  45. phoenix/server/api/routers/v1/sessions.py +111 -0
  46. phoenix/server/api/routers/v1/traces.py +1 -2
  47. phoenix/server/api/routers/v1/users.py +7 -0
  48. phoenix/server/api/subscriptions.py +25 -7
  49. phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +1 -0
  50. phoenix/server/api/types/DatasetExample.py +11 -0
  51. phoenix/server/api/types/DatasetSplit.py +32 -0
  52. phoenix/server/api/types/Experiment.py +0 -4
  53. phoenix/server/api/types/Project.py +16 -0
  54. phoenix/server/api/types/ProjectSession.py +88 -3
  55. phoenix/server/api/types/ProjectSessionAnnotation.py +68 -0
  56. phoenix/server/api/types/Span.py +5 -5
  57. phoenix/server/api/types/Trace.py +61 -0
  58. phoenix/server/app.py +6 -2
  59. phoenix/server/cost_tracking/model_cost_manifest.json +132 -2
  60. phoenix/server/dml_event.py +13 -0
  61. phoenix/server/static/.vite/manifest.json +39 -39
  62. phoenix/server/static/assets/{components-CFzdBkk_.js → components-Dl9SUw1U.js} +371 -327
  63. phoenix/server/static/assets/{index-DayUA9lQ.js → index-CqQS0dTo.js} +2 -2
  64. phoenix/server/static/assets/{pages-CvUhOO9h.js → pages-DKSjVA_E.js} +771 -518
  65. phoenix/server/static/assets/{vendor-BdjZxMii.js → vendor-CtbHQYl8.js} +1 -1
  66. phoenix/server/static/assets/{vendor-arizeai-CHYlS8jV.js → vendor-arizeai-D-lWOwIS.js} +1 -1
  67. phoenix/server/static/assets/{vendor-codemirror-Di6t4HnH.js → vendor-codemirror-BRBpy3_z.js} +3 -3
  68. phoenix/server/static/assets/{vendor-recharts-C9wCDYj3.js → vendor-recharts--KdSwB3m.js} +1 -1
  69. phoenix/server/static/assets/{vendor-shiki-MNnmOotP.js → vendor-shiki-CvRzZnIo.js} +1 -1
  70. phoenix/version.py +1 -1
  71. phoenix/server/api/dataloaders/experiment_repetition_counts.py +0 -39
  72. {arize_phoenix-11.37.0.dist-info → arize_phoenix-12.0.0.dist-info}/WHEEL +0 -0
  73. {arize_phoenix-11.37.0.dist-info → arize_phoenix-12.0.0.dist-info}/entry_points.txt +0 -0
  74. {arize_phoenix-11.37.0.dist-info → arize_phoenix-12.0.0.dist-info}/licenses/IP_NOTICE +0 -0
  75. {arize_phoenix-11.37.0.dist-info → arize_phoenix-12.0.0.dist-info}/licenses/LICENSE +0 -0
@@ -26,7 +26,10 @@ from typing_extensions import assert_never
26
26
  from phoenix.config import PLAYGROUND_PROJECT_NAME
27
27
  from phoenix.datetime_utils import local_now, normalize_datetime
28
28
  from phoenix.db import models
29
- from phoenix.db.helpers import get_dataset_example_revisions
29
+ from phoenix.db.helpers import (
30
+ get_dataset_example_revisions,
31
+ insert_experiment_with_examples_snapshot,
32
+ )
30
33
  from phoenix.server.api.auth import IsLocked, IsNotReadOnly
31
34
  from phoenix.server.api.context import Context
32
35
  from phoenix.server.api.exceptions import BadRequest, CustomGraphQLError, NotFound
@@ -46,6 +49,7 @@ from phoenix.server.api.helpers.playground_spans import (
46
49
  llm_tools,
47
50
  prompt_metadata,
48
51
  )
52
+ from phoenix.server.api.helpers.playground_users import get_user
49
53
  from phoenix.server.api.helpers.prompts.models import PromptTemplateFormat
50
54
  from phoenix.server.api.input_types.ChatCompletionInput import (
51
55
  ChatCompletionInput,
@@ -112,6 +116,7 @@ class ChatCompletionMutationError:
112
116
  @strawberry.type
113
117
  class ChatCompletionOverDatasetMutationExamplePayload:
114
118
  dataset_example_id: GlobalID
119
+ repetition_number: int
115
120
  experiment_run_id: GlobalID
116
121
  result: Union[ChatCompletionMutationPayload, ChatCompletionMutationError]
117
122
 
@@ -191,6 +196,7 @@ class ChatCompletionMutationMixin:
191
196
  ]
192
197
  if not revisions:
193
198
  raise NotFound("No examples found for the given dataset and version")
199
+ user_id = get_user(info)
194
200
  experiment = models.Experiment(
195
201
  dataset_id=from_global_id_with_expected_type(input.dataset_id, Dataset.__name__),
196
202
  dataset_version_id=resolved_version_id,
@@ -200,14 +206,19 @@ class ChatCompletionMutationMixin:
200
206
  repetitions=1,
201
207
  metadata_=input.experiment_metadata or dict(),
202
208
  project_name=project_name,
209
+ user_id=user_id,
203
210
  )
204
- session.add(experiment)
205
- await session.flush()
211
+ await insert_experiment_with_examples_snapshot(session, experiment)
206
212
 
207
213
  results: list[Union[ChatCompletionMutationPayload, BaseException]] = []
208
214
  batch_size = 3
209
215
  start_time = datetime.now(timezone.utc)
210
- for batch in _get_batches(revisions, batch_size):
216
+ unbatched_items = [
217
+ (revision, repetition_number)
218
+ for revision in revisions
219
+ for repetition_number in range(1, input.repetitions + 1)
220
+ ]
221
+ for batch in _get_batches(unbatched_items, batch_size):
211
222
  batch_results = await asyncio.gather(
212
223
  *(
213
224
  cls._chat_completion(
@@ -224,10 +235,11 @@ class ChatCompletionMutationMixin:
224
235
  variables=revision.input,
225
236
  ),
226
237
  prompt_name=input.prompt_name,
238
+ repetitions=repetition_number,
227
239
  ),
228
240
  project_name=project_name,
229
241
  )
230
- for revision in batch
242
+ for revision, repetition_number in batch
231
243
  ),
232
244
  return_exceptions=True,
233
245
  )
@@ -239,13 +251,13 @@ class ChatCompletionMutationMixin:
239
251
  experiment_id=GlobalID(models.Experiment.__name__, str(experiment.id)),
240
252
  )
241
253
  experiment_runs = []
242
- for revision, result in zip(revisions, results):
254
+ for (revision, repetition_number), result in zip(unbatched_items, results):
243
255
  if isinstance(result, BaseException):
244
256
  experiment_run = models.ExperimentRun(
245
257
  experiment_id=experiment.id,
246
258
  dataset_example_id=revision.dataset_example_id,
247
259
  output={},
248
- repetition_number=1,
260
+ repetition_number=repetition_number,
249
261
  start_time=start_time,
250
262
  end_time=start_time,
251
263
  error=str(result),
@@ -261,7 +273,7 @@ class ChatCompletionMutationMixin:
261
273
  ),
262
274
  prompt_token_count=db_span.cumulative_llm_token_count_prompt,
263
275
  completion_token_count=db_span.cumulative_llm_token_count_completion,
264
- repetition_number=1,
276
+ repetition_number=repetition_number,
265
277
  start_time=db_span.start_time,
266
278
  end_time=db_span.end_time,
267
279
  error=str(result.error_message) if result.error_message else None,
@@ -272,13 +284,16 @@ class ChatCompletionMutationMixin:
272
284
  session.add_all(experiment_runs)
273
285
  await session.flush()
274
286
 
275
- for revision, experiment_run, result in zip(revisions, experiment_runs, results):
287
+ for (revision, repetition_number), experiment_run, result in zip(
288
+ unbatched_items, experiment_runs, results
289
+ ):
276
290
  dataset_example_id = GlobalID(
277
291
  models.DatasetExample.__name__, str(revision.dataset_example_id)
278
292
  )
279
293
  experiment_run_id = GlobalID(models.ExperimentRun.__name__, str(experiment_run.id))
280
294
  example_payload = ChatCompletionOverDatasetMutationExamplePayload(
281
295
  dataset_example_id=dataset_example_id,
296
+ repetition_number=repetition_number,
282
297
  experiment_run_id=experiment_run_id,
283
298
  result=result
284
299
  if isinstance(result, ChatCompletionMutationPayload)
@@ -66,6 +66,7 @@ class DatasetMutationMixin:
66
66
  name=name,
67
67
  description=description,
68
68
  metadata_=metadata,
69
+ user_id=info.context.user_id,
69
70
  )
70
71
  .returning(models.Dataset)
71
72
  )
@@ -136,6 +137,7 @@ class DatasetMutationMixin:
136
137
  dataset_id=dataset_rowid,
137
138
  description=dataset_version_description,
138
139
  metadata_=dataset_version_metadata or {},
140
+ user_id=info.context.user_id,
139
141
  )
140
142
  session.add(dataset_version)
141
143
  await session.flush()
@@ -254,6 +256,7 @@ class DatasetMutationMixin:
254
256
  dataset_id=dataset_rowid,
255
257
  description=dataset_version_description,
256
258
  metadata_=dataset_version_metadata,
259
+ user_id=info.context.user_id,
257
260
  )
258
261
  .returning(models.DatasetVersion.id)
259
262
  )
@@ -451,6 +454,7 @@ class DatasetMutationMixin:
451
454
  dataset_id=dataset.id,
452
455
  description=version_description,
453
456
  metadata_=version_metadata,
457
+ user_id=info.context.user_id,
454
458
  )
455
459
  )
456
460
  assert version_id is not None
@@ -514,6 +518,7 @@ class DatasetMutationMixin:
514
518
  dataset_id=dataset.id,
515
519
  description=dataset_version_description,
516
520
  metadata_=dataset_version_metadata,
521
+ user_id=info.context.user_id,
517
522
  created_at=timestamp,
518
523
  )
519
524
  .returning(models.DatasetVersion.id)
@@ -0,0 +1,387 @@
1
+ from typing import Optional
2
+
3
+ import strawberry
4
+ from sqlalchemy import delete, func, insert, select
5
+ from sqlalchemy.exc import IntegrityError as PostgreSQLIntegrityError
6
+ from sqlean.dbapi2 import IntegrityError as SQLiteIntegrityError # type: ignore[import-untyped]
7
+ from strawberry import UNSET
8
+ from strawberry.relay import GlobalID
9
+ from strawberry.scalars import JSON
10
+ from strawberry.types import Info
11
+
12
+ from phoenix.db import models
13
+ from phoenix.server.api.auth import IsLocked, IsNotReadOnly
14
+ from phoenix.server.api.context import Context
15
+ from phoenix.server.api.exceptions import BadRequest, Conflict, NotFound
16
+ from phoenix.server.api.helpers.playground_users import get_user
17
+ from phoenix.server.api.queries import Query
18
+ from phoenix.server.api.types.DatasetSplit import DatasetSplit, to_gql_dataset_split
19
+ from phoenix.server.api.types.node import from_global_id_with_expected_type
20
+
21
+
22
+ @strawberry.input
23
+ class CreateDatasetSplitInput:
24
+ name: str
25
+ description: Optional[str] = UNSET
26
+ color: str
27
+ metadata: Optional[JSON] = UNSET
28
+
29
+
30
+ @strawberry.input
31
+ class PatchDatasetSplitInput:
32
+ dataset_split_id: GlobalID
33
+ name: Optional[str] = UNSET
34
+ description: Optional[str] = UNSET
35
+ color: Optional[str] = UNSET
36
+ metadata: Optional[JSON] = UNSET
37
+
38
+
39
+ @strawberry.input
40
+ class DeleteDatasetSplitInput:
41
+ dataset_split_ids: list[GlobalID]
42
+
43
+
44
+ @strawberry.input
45
+ class AddDatasetExamplesToDatasetSplitsInput:
46
+ dataset_split_ids: list[GlobalID]
47
+ example_ids: list[GlobalID]
48
+
49
+
50
+ @strawberry.input
51
+ class RemoveDatasetExamplesFromDatasetSplitsInput:
52
+ dataset_split_ids: list[GlobalID]
53
+ example_ids: list[GlobalID]
54
+
55
+
56
+ @strawberry.input
57
+ class CreateDatasetSplitWithExamplesInput:
58
+ name: str
59
+ description: Optional[str] = UNSET
60
+ color: str
61
+ metadata: Optional[JSON] = UNSET
62
+ example_ids: list[GlobalID]
63
+
64
+
65
+ @strawberry.type
66
+ class DatasetSplitMutationPayload:
67
+ dataset_split: DatasetSplit
68
+ query: "Query"
69
+
70
+
71
+ @strawberry.type
72
+ class DeleteDatasetSplitsMutationPayload:
73
+ dataset_splits: list[DatasetSplit]
74
+ query: "Query"
75
+
76
+
77
+ @strawberry.type
78
+ class AddDatasetExamplesToDatasetSplitsMutationPayload:
79
+ query: "Query"
80
+
81
+
82
+ @strawberry.type
83
+ class RemoveDatasetExamplesFromDatasetSplitsMutationPayload:
84
+ query: "Query"
85
+
86
+
87
+ @strawberry.type
88
+ class DatasetSplitMutationMixin:
89
+ @strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked]) # type: ignore
90
+ async def create_dataset_split(
91
+ self, info: Info[Context, None], input: CreateDatasetSplitInput
92
+ ) -> DatasetSplitMutationPayload:
93
+ user_id = get_user(info)
94
+ validated_name = _validated_name(input.name)
95
+ async with info.context.db() as session:
96
+ dataset_split_orm = models.DatasetSplit(
97
+ name=validated_name,
98
+ description=input.description,
99
+ color=input.color,
100
+ metadata_=input.metadata or {},
101
+ user_id=user_id,
102
+ )
103
+ session.add(dataset_split_orm)
104
+ try:
105
+ await session.commit()
106
+ except (PostgreSQLIntegrityError, SQLiteIntegrityError):
107
+ raise Conflict(f"A dataset split named '{input.name}' already exists.")
108
+ return DatasetSplitMutationPayload(
109
+ dataset_split=to_gql_dataset_split(dataset_split_orm), query=Query()
110
+ )
111
+
112
+ @strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked]) # type: ignore
113
+ async def patch_dataset_split(
114
+ self, info: Info[Context, None], input: PatchDatasetSplitInput
115
+ ) -> DatasetSplitMutationPayload:
116
+ validated_name = _validated_name(input.name) if input.name else None
117
+ async with info.context.db() as session:
118
+ dataset_split_id = from_global_id_with_expected_type(
119
+ input.dataset_split_id, DatasetSplit.__name__
120
+ )
121
+ dataset_split_orm = await session.get(models.DatasetSplit, dataset_split_id)
122
+ if not dataset_split_orm:
123
+ raise NotFound(f"Dataset split with ID {input.dataset_split_id} not found")
124
+
125
+ if validated_name:
126
+ dataset_split_orm.name = validated_name
127
+ if input.description:
128
+ dataset_split_orm.description = input.description
129
+ if input.color:
130
+ dataset_split_orm.color = input.color
131
+ if isinstance(input.metadata, dict):
132
+ dataset_split_orm.metadata_ = input.metadata
133
+
134
+ gql_dataset_split = to_gql_dataset_split(dataset_split_orm)
135
+ try:
136
+ await session.commit()
137
+ except (PostgreSQLIntegrityError, SQLiteIntegrityError):
138
+ raise Conflict("A dataset split with this name already exists")
139
+
140
+ return DatasetSplitMutationPayload(
141
+ dataset_split=gql_dataset_split,
142
+ query=Query(),
143
+ )
144
+
145
+ @strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore
146
+ async def delete_dataset_splits(
147
+ self, info: Info[Context, None], input: DeleteDatasetSplitInput
148
+ ) -> DeleteDatasetSplitsMutationPayload:
149
+ unique_dataset_split_rowids: dict[int, None] = {} # use a dict to preserve ordering
150
+ for dataset_split_gid in input.dataset_split_ids:
151
+ try:
152
+ dataset_split_rowid = from_global_id_with_expected_type(
153
+ dataset_split_gid, DatasetSplit.__name__
154
+ )
155
+ except ValueError:
156
+ raise BadRequest(f"Invalid dataset split ID: {dataset_split_gid}")
157
+ unique_dataset_split_rowids[dataset_split_rowid] = None
158
+ dataset_split_rowids = list(unique_dataset_split_rowids.keys())
159
+
160
+ async with info.context.db() as session:
161
+ deleted_splits_by_id = {
162
+ split.id: split
163
+ for split in (
164
+ await session.scalars(
165
+ delete(models.DatasetSplit)
166
+ .where(models.DatasetSplit.id.in_(dataset_split_rowids))
167
+ .returning(models.DatasetSplit)
168
+ )
169
+ ).all()
170
+ }
171
+ if len(deleted_splits_by_id) < len(dataset_split_rowids):
172
+ await session.rollback()
173
+ raise NotFound("One or more dataset splits not found")
174
+ await session.commit()
175
+
176
+ return DeleteDatasetSplitsMutationPayload(
177
+ dataset_splits=[
178
+ to_gql_dataset_split(deleted_splits_by_id[dataset_split_rowid])
179
+ for dataset_split_rowid in dataset_split_rowids
180
+ ],
181
+ query=Query(),
182
+ )
183
+
184
+ @strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked]) # type: ignore
185
+ async def add_dataset_examples_to_dataset_splits(
186
+ self, info: Info[Context, None], input: AddDatasetExamplesToDatasetSplitsInput
187
+ ) -> AddDatasetExamplesToDatasetSplitsMutationPayload:
188
+ if not input.example_ids:
189
+ raise BadRequest("No examples provided.")
190
+ if not input.dataset_split_ids:
191
+ raise BadRequest("No dataset splits provided.")
192
+
193
+ unique_dataset_split_rowids: set[int] = set()
194
+ for dataset_split_gid in input.dataset_split_ids:
195
+ try:
196
+ dataset_split_rowid = from_global_id_with_expected_type(
197
+ dataset_split_gid, DatasetSplit.__name__
198
+ )
199
+ except ValueError:
200
+ raise BadRequest(f"Invalid dataset split ID: {dataset_split_gid}")
201
+ unique_dataset_split_rowids.add(dataset_split_rowid)
202
+ dataset_split_rowids = list(unique_dataset_split_rowids)
203
+
204
+ unique_example_rowids: set[int] = set()
205
+ for example_gid in input.example_ids:
206
+ try:
207
+ example_rowid = from_global_id_with_expected_type(
208
+ example_gid, models.DatasetExample.__name__
209
+ )
210
+ except ValueError:
211
+ raise BadRequest(f"Invalid example ID: {example_gid}")
212
+ unique_example_rowids.add(example_rowid)
213
+ example_rowids = list(unique_example_rowids)
214
+
215
+ async with info.context.db() as session:
216
+ existing_dataset_split_ids = (
217
+ await session.scalars(
218
+ select(models.DatasetSplit.id).where(
219
+ models.DatasetSplit.id.in_(dataset_split_rowids)
220
+ )
221
+ )
222
+ ).all()
223
+ if len(existing_dataset_split_ids) != len(dataset_split_rowids):
224
+ raise NotFound("One or more dataset splits not found")
225
+
226
+ # Find existing (dataset_split_id, dataset_example_id) keys to avoid duplicates
227
+ # Users can submit multiple examples at once which can have
228
+ # indeterminate participation in multiple splits
229
+ existing_dataset_example_split_keys = await session.execute(
230
+ select(
231
+ models.DatasetSplitDatasetExample.dataset_split_id,
232
+ models.DatasetSplitDatasetExample.dataset_example_id,
233
+ ).where(
234
+ models.DatasetSplitDatasetExample.dataset_split_id.in_(dataset_split_rowids)
235
+ & models.DatasetSplitDatasetExample.dataset_example_id.in_(example_rowids)
236
+ )
237
+ )
238
+ unique_dataset_example_split_keys = set(existing_dataset_example_split_keys.all())
239
+
240
+ # Compute all desired pairs and insert only missing
241
+ values = []
242
+ for dataset_split_rowid in dataset_split_rowids:
243
+ for example_rowid in example_rowids:
244
+ # if the keys already exists, skip
245
+ if (dataset_split_rowid, example_rowid) in unique_dataset_example_split_keys:
246
+ continue
247
+ dataset_split_id_key = models.DatasetSplitDatasetExample.dataset_split_id.key
248
+ dataset_example_id_key = (
249
+ models.DatasetSplitDatasetExample.dataset_example_id.key
250
+ )
251
+ values.append(
252
+ {
253
+ dataset_split_id_key: dataset_split_rowid,
254
+ dataset_example_id_key: example_rowid,
255
+ }
256
+ )
257
+
258
+ if values:
259
+ try:
260
+ await session.execute(insert(models.DatasetSplitDatasetExample), values)
261
+ await session.flush()
262
+ except (PostgreSQLIntegrityError, SQLiteIntegrityError) as e:
263
+ raise Conflict("Failed to add examples to dataset splits.") from e
264
+
265
+ return AddDatasetExamplesToDatasetSplitsMutationPayload(
266
+ query=Query(),
267
+ )
268
+
269
+ @strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore
270
+ async def remove_dataset_examples_from_dataset_splits(
271
+ self, info: Info[Context, None], input: RemoveDatasetExamplesFromDatasetSplitsInput
272
+ ) -> RemoveDatasetExamplesFromDatasetSplitsMutationPayload:
273
+ if not input.dataset_split_ids:
274
+ raise BadRequest("No dataset splits provided.")
275
+ if not input.example_ids:
276
+ raise BadRequest("No examples provided.")
277
+
278
+ unique_dataset_split_rowids: set[int] = set()
279
+ for dataset_split_gid in input.dataset_split_ids:
280
+ try:
281
+ dataset_split_rowid = from_global_id_with_expected_type(
282
+ dataset_split_gid, DatasetSplit.__name__
283
+ )
284
+ except ValueError:
285
+ raise BadRequest(f"Invalid dataset split ID: {dataset_split_gid}")
286
+ unique_dataset_split_rowids.add(dataset_split_rowid)
287
+ dataset_split_rowids = list(unique_dataset_split_rowids)
288
+
289
+ unique_example_rowids: set[int] = set()
290
+ for example_gid in input.example_ids:
291
+ try:
292
+ example_rowid = from_global_id_with_expected_type(
293
+ example_gid, models.DatasetExample.__name__
294
+ )
295
+ except ValueError:
296
+ raise BadRequest(f"Invalid example ID: {example_gid}")
297
+ unique_example_rowids.add(example_rowid)
298
+ example_rowids = list(unique_example_rowids)
299
+
300
+ stmt = delete(models.DatasetSplitDatasetExample).where(
301
+ models.DatasetSplitDatasetExample.dataset_split_id.in_(dataset_split_rowids)
302
+ & models.DatasetSplitDatasetExample.dataset_example_id.in_(example_rowids)
303
+ )
304
+ async with info.context.db() as session:
305
+ existing_dataset_split_ids = (
306
+ await session.scalars(
307
+ select(models.DatasetSplit.id).where(
308
+ models.DatasetSplit.id.in_(dataset_split_rowids)
309
+ )
310
+ )
311
+ ).all()
312
+ if len(existing_dataset_split_ids) != len(dataset_split_rowids):
313
+ raise NotFound("One or more dataset splits not found")
314
+
315
+ await session.execute(stmt)
316
+
317
+ return RemoveDatasetExamplesFromDatasetSplitsMutationPayload(
318
+ query=Query(),
319
+ )
320
+
321
+ @strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked]) # type: ignore
322
+ async def create_dataset_split_with_examples(
323
+ self, info: Info[Context, None], input: CreateDatasetSplitWithExamplesInput
324
+ ) -> DatasetSplitMutationPayload:
325
+ user_id = get_user(info)
326
+ validated_name = _validated_name(input.name)
327
+ unique_example_rowids: set[int] = set()
328
+ for example_gid in input.example_ids:
329
+ try:
330
+ example_rowid = from_global_id_with_expected_type(
331
+ example_gid, models.DatasetExample.__name__
332
+ )
333
+ unique_example_rowids.add(example_rowid)
334
+ except ValueError:
335
+ raise BadRequest(f"Invalid example ID: {example_gid}")
336
+ example_rowids = list(unique_example_rowids)
337
+ async with info.context.db() as session:
338
+ if example_rowids:
339
+ found_count = await session.scalar(
340
+ select(func.count(models.DatasetExample.id)).where(
341
+ models.DatasetExample.id.in_(example_rowids)
342
+ )
343
+ )
344
+ if found_count is None or found_count < len(example_rowids):
345
+ raise NotFound("One or more dataset examples were not found.")
346
+
347
+ dataset_split_orm = models.DatasetSplit(
348
+ name=validated_name,
349
+ description=input.description or None,
350
+ color=input.color,
351
+ metadata_=input.metadata or {},
352
+ user_id=user_id,
353
+ )
354
+ session.add(dataset_split_orm)
355
+ try:
356
+ await session.flush()
357
+ except (PostgreSQLIntegrityError, SQLiteIntegrityError):
358
+ raise Conflict(f"A dataset split named '{validated_name}' already exists.")
359
+
360
+ if example_rowids:
361
+ values = [
362
+ {
363
+ models.DatasetSplitDatasetExample.dataset_split_id.key: dataset_split_orm.id, # noqa: E501
364
+ models.DatasetSplitDatasetExample.dataset_example_id.key: example_id,
365
+ }
366
+ for example_id in example_rowids
367
+ ]
368
+ try:
369
+ await session.execute(insert(models.DatasetSplitDatasetExample), values)
370
+ except (PostgreSQLIntegrityError, SQLiteIntegrityError) as e:
371
+ # Roll back the transaction on association failure
372
+ await session.rollback()
373
+ raise Conflict(
374
+ "Failed to associate examples with the new dataset split."
375
+ ) from e
376
+
377
+ return DatasetSplitMutationPayload(
378
+ dataset_split=to_gql_dataset_split(dataset_split_orm),
379
+ query=Query(),
380
+ )
381
+
382
+
383
+ def _validated_name(name: str) -> str:
384
+ validated_name = name.strip()
385
+ if not validated_name:
386
+ raise BadRequest("Name cannot be empty")
387
+ return validated_name