arize-phoenix 11.38.0__py3-none-any.whl → 12.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic. Click here for more details.

Files changed (84) hide show
  1. {arize_phoenix-11.38.0.dist-info → arize_phoenix-12.2.0.dist-info}/METADATA +3 -3
  2. {arize_phoenix-11.38.0.dist-info → arize_phoenix-12.2.0.dist-info}/RECORD +83 -58
  3. phoenix/config.py +1 -11
  4. phoenix/db/bulk_inserter.py +8 -0
  5. phoenix/db/facilitator.py +1 -1
  6. phoenix/db/helpers.py +202 -33
  7. phoenix/db/insertion/dataset.py +7 -0
  8. phoenix/db/insertion/document_annotation.py +1 -1
  9. phoenix/db/insertion/helpers.py +2 -2
  10. phoenix/db/insertion/session_annotation.py +176 -0
  11. phoenix/db/insertion/span_annotation.py +1 -1
  12. phoenix/db/insertion/trace_annotation.py +1 -1
  13. phoenix/db/insertion/types.py +29 -3
  14. phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
  15. phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
  16. phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
  17. phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
  18. phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
  19. phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
  20. phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
  21. phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
  22. phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
  23. phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
  24. phoenix/db/models.py +306 -46
  25. phoenix/server/api/context.py +15 -2
  26. phoenix/server/api/dataloaders/__init__.py +8 -2
  27. phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
  28. phoenix/server/api/dataloaders/dataset_labels.py +36 -0
  29. phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
  30. phoenix/server/api/dataloaders/table_fields.py +2 -2
  31. phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
  32. phoenix/server/api/helpers/playground_clients.py +66 -35
  33. phoenix/server/api/helpers/playground_users.py +26 -0
  34. phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
  35. phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
  36. phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
  37. phoenix/server/api/mutations/__init__.py +8 -0
  38. phoenix/server/api/mutations/chat_mutations.py +8 -3
  39. phoenix/server/api/mutations/dataset_label_mutations.py +291 -0
  40. phoenix/server/api/mutations/dataset_mutations.py +5 -0
  41. phoenix/server/api/mutations/dataset_split_mutations.py +423 -0
  42. phoenix/server/api/mutations/project_session_annotations_mutations.py +161 -0
  43. phoenix/server/api/queries.py +53 -0
  44. phoenix/server/api/routers/auth.py +5 -5
  45. phoenix/server/api/routers/oauth2.py +5 -23
  46. phoenix/server/api/routers/v1/__init__.py +2 -0
  47. phoenix/server/api/routers/v1/annotations.py +320 -0
  48. phoenix/server/api/routers/v1/datasets.py +5 -0
  49. phoenix/server/api/routers/v1/experiments.py +10 -3
  50. phoenix/server/api/routers/v1/sessions.py +111 -0
  51. phoenix/server/api/routers/v1/traces.py +1 -2
  52. phoenix/server/api/routers/v1/users.py +7 -0
  53. phoenix/server/api/subscriptions.py +5 -2
  54. phoenix/server/api/types/Dataset.py +8 -0
  55. phoenix/server/api/types/DatasetExample.py +18 -0
  56. phoenix/server/api/types/DatasetLabel.py +23 -0
  57. phoenix/server/api/types/DatasetSplit.py +32 -0
  58. phoenix/server/api/types/Experiment.py +0 -4
  59. phoenix/server/api/types/Project.py +16 -0
  60. phoenix/server/api/types/ProjectSession.py +88 -3
  61. phoenix/server/api/types/ProjectSessionAnnotation.py +68 -0
  62. phoenix/server/api/types/Prompt.py +18 -1
  63. phoenix/server/api/types/Span.py +5 -5
  64. phoenix/server/api/types/Trace.py +61 -0
  65. phoenix/server/app.py +13 -14
  66. phoenix/server/cost_tracking/model_cost_manifest.json +132 -2
  67. phoenix/server/dml_event.py +13 -0
  68. phoenix/server/static/.vite/manifest.json +39 -39
  69. phoenix/server/static/assets/{components-BQPHTBfv.js → components-BG6v0EM8.js} +705 -385
  70. phoenix/server/static/assets/{index-BL5BMgJU.js → index-CSVcULw1.js} +13 -13
  71. phoenix/server/static/assets/{pages-C0Y17J0T.js → pages-DgaM7kpM.js} +1356 -1155
  72. phoenix/server/static/assets/{vendor-BdjZxMii.js → vendor-BqTEkGQU.js} +183 -183
  73. phoenix/server/static/assets/{vendor-arizeai-CHYlS8jV.js → vendor-arizeai-DlOj0PQQ.js} +15 -24
  74. phoenix/server/static/assets/{vendor-codemirror-Di6t4HnH.js → vendor-codemirror-B2PHH5yZ.js} +3 -3
  75. phoenix/server/static/assets/{vendor-recharts-C9wCDYj3.js → vendor-recharts-CKsi4IjN.js} +1 -1
  76. phoenix/server/static/assets/{vendor-shiki-MNnmOotP.js → vendor-shiki-DN26BkKE.js} +1 -1
  77. phoenix/server/utils.py +74 -0
  78. phoenix/session/session.py +25 -5
  79. phoenix/version.py +1 -1
  80. phoenix/server/api/dataloaders/experiment_repetition_counts.py +0 -39
  81. {arize_phoenix-11.38.0.dist-info → arize_phoenix-12.2.0.dist-info}/WHEEL +0 -0
  82. {arize_phoenix-11.38.0.dist-info → arize_phoenix-12.2.0.dist-info}/entry_points.txt +0 -0
  83. {arize_phoenix-11.38.0.dist-info → arize_phoenix-12.2.0.dist-info}/licenses/IP_NOTICE +0 -0
  84. {arize_phoenix-11.38.0.dist-info → arize_phoenix-12.2.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,423 @@
1
+ from typing import Optional
2
+
3
+ import strawberry
4
+ from sqlalchemy import delete, func, insert, select
5
+ from sqlalchemy.exc import IntegrityError as PostgreSQLIntegrityError
6
+ from sqlean.dbapi2 import IntegrityError as SQLiteIntegrityError # type: ignore[import-untyped]
7
+ from strawberry import UNSET
8
+ from strawberry.relay import GlobalID
9
+ from strawberry.scalars import JSON
10
+ from strawberry.types import Info
11
+
12
+ from phoenix.db import models
13
+ from phoenix.server.api.auth import IsLocked, IsNotReadOnly
14
+ from phoenix.server.api.context import Context
15
+ from phoenix.server.api.exceptions import BadRequest, Conflict, NotFound
16
+ from phoenix.server.api.helpers.playground_users import get_user
17
+ from phoenix.server.api.queries import Query
18
+ from phoenix.server.api.types.DatasetExample import DatasetExample, to_gql_dataset_example
19
+ from phoenix.server.api.types.DatasetSplit import DatasetSplit, to_gql_dataset_split
20
+ from phoenix.server.api.types.node import from_global_id_with_expected_type
21
+
22
+
23
@strawberry.input
class CreateDatasetSplitInput:
    # Name is required and validated/stripped by the mutation; uniqueness is
    # enforced at the database level.
    name: str
    description: Optional[str] = UNSET
    # Display color for the split (stored as-is).
    color: str
    metadata: Optional[JSON] = UNSET
29
+
30
+
31
@strawberry.input
class PatchDatasetSplitInput:
    # Relay ID of the split to modify; all other fields are optional patches.
    dataset_split_id: GlobalID
    name: Optional[str] = UNSET
    description: Optional[str] = UNSET
    color: Optional[str] = UNSET
    metadata: Optional[JSON] = UNSET
38
+
39
+
40
@strawberry.input
class DeleteDatasetSplitInput:
    # Relay IDs of the splits to delete; duplicates are tolerated by the mutation.
    dataset_split_ids: list[GlobalID]
43
+
44
+
45
@strawberry.input
class AddDatasetExamplesToDatasetSplitsInput:
    # Every example in `example_ids` is associated with every split in
    # `dataset_split_ids` (cross product), skipping pairs that already exist.
    dataset_split_ids: list[GlobalID]
    example_ids: list[GlobalID]
49
+
50
+
51
@strawberry.input
class RemoveDatasetExamplesFromDatasetSplitsInput:
    # Every (split, example) association matching the cross product of the two
    # ID lists is removed.
    dataset_split_ids: list[GlobalID]
    example_ids: list[GlobalID]
55
+
56
+
57
@strawberry.input
class CreateDatasetSplitWithExamplesInput:
    # Same fields as CreateDatasetSplitInput, plus the examples to attach to
    # the newly created split in the same transaction.
    name: str
    description: Optional[str] = UNSET
    color: str
    metadata: Optional[JSON] = UNSET
    example_ids: list[GlobalID]
64
+
65
+
66
@strawberry.type
class DatasetSplitMutationPayload:
    # The affected split plus a root Query object for follow-up selections.
    dataset_split: DatasetSplit
    query: "Query"
70
+
71
+
72
@strawberry.type
class DatasetSplitMutationPayloadWithExamples:
    # Like DatasetSplitMutationPayload, but also echoes the examples that were
    # attached to the split.
    dataset_split: DatasetSplit
    query: "Query"
    examples: list[DatasetExample]
77
+
78
+
79
@strawberry.type
class DeleteDatasetSplitsMutationPayload:
    # Echoes the deleted splits in the order the caller supplied them.
    dataset_splits: list[DatasetSplit]
    query: "Query"
83
+
84
+
85
@strawberry.type
class AddDatasetExamplesToDatasetSplitsMutationPayload:
    # Returns the affected examples; the splits themselves are reachable via `query`.
    query: "Query"
    examples: list[DatasetExample]
89
+
90
+
91
@strawberry.type
class RemoveDatasetExamplesFromDatasetSplitsMutationPayload:
    # Returns the affected examples; the splits themselves are reachable via `query`.
    query: "Query"
    examples: list[DatasetExample]
95
+
96
+
97
@strawberry.type
class DatasetSplitMutationMixin:
    # GraphQL mutations for creating, patching, and deleting dataset splits, and
    # for managing which dataset examples belong to which splits.
    # NOTE: no docstrings on the strawberry-decorated members — strawberry would
    # publish them as schema descriptions, changing the emitted GraphQL schema.

    @strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked])  # type: ignore
    async def create_dataset_split(
        self, info: Info[Context, None], input: CreateDatasetSplitInput
    ) -> DatasetSplitMutationPayload:
        # Attribute the new split to the requesting user.
        creator_id = get_user(info)
        name = _validated_name(input.name)
        async with info.context.db() as session:
            split_row = models.DatasetSplit(
                name=name,
                description=input.description,
                color=input.color,
                metadata_=input.metadata or {},
                user_id=creator_id,
            )
            session.add(split_row)
            try:
                await session.commit()
            except (PostgreSQLIntegrityError, SQLiteIntegrityError):
                # The split name carries a uniqueness constraint.
                raise Conflict(f"A dataset split named '{input.name}' already exists.")
        return DatasetSplitMutationPayload(
            dataset_split=to_gql_dataset_split(split_row), query=Query()
        )

    @strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked])  # type: ignore
    async def patch_dataset_split(
        self, info: Info[Context, None], input: PatchDatasetSplitInput
    ) -> DatasetSplitMutationPayload:
        # Only a non-empty name is validated; unset/empty fields are left untouched.
        new_name = _validated_name(input.name) if input.name else None
        async with info.context.db() as session:
            split_rowid = from_global_id_with_expected_type(
                input.dataset_split_id, DatasetSplit.__name__
            )
            split_row = await session.get(models.DatasetSplit, split_rowid)
            if split_row is None:
                raise NotFound(f"Dataset split with ID {input.dataset_split_id} not found")

            if new_name:
                split_row.name = new_name
            if input.description:
                split_row.description = input.description
            if input.color:
                split_row.color = input.color
            if isinstance(input.metadata, dict):
                split_row.metadata_ = input.metadata

            # Build the GraphQL object before commit so attribute access does not
            # depend on post-commit ORM state.
            payload_split = to_gql_dataset_split(split_row)
            try:
                await session.commit()
            except (PostgreSQLIntegrityError, SQLiteIntegrityError):
                raise Conflict("A dataset split with this name already exists")

        return DatasetSplitMutationPayload(
            dataset_split=payload_split,
            query=Query(),
        )

    @strawberry.mutation(permission_classes=[IsNotReadOnly])  # type: ignore
    async def delete_dataset_splits(
        self, info: Info[Context, None], input: DeleteDatasetSplitInput
    ) -> DeleteDatasetSplitsMutationPayload:
        # A dict serves as an ordered set so the payload echoes the input order.
        ordered_rowids: dict[int, None] = {}
        for split_gid in input.dataset_split_ids:
            try:
                split_rowid = from_global_id_with_expected_type(
                    split_gid, DatasetSplit.__name__
                )
            except ValueError:
                raise BadRequest(f"Invalid dataset split ID: {split_gid}")
            ordered_rowids[split_rowid] = None
        rowids = list(ordered_rowids)

        async with info.context.db() as session:
            result = await session.scalars(
                delete(models.DatasetSplit)
                .where(models.DatasetSplit.id.in_(rowids))
                .returning(models.DatasetSplit)
            )
            deleted_by_id = {row.id: row for row in result.all()}
            if len(deleted_by_id) < len(rowids):
                # All-or-nothing: undo the partial delete and report.
                await session.rollback()
                raise NotFound("One or more dataset splits not found")
            await session.commit()

        return DeleteDatasetSplitsMutationPayload(
            dataset_splits=[to_gql_dataset_split(deleted_by_id[rowid]) for rowid in rowids],
            query=Query(),
        )

    @strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked])  # type: ignore
    async def add_dataset_examples_to_dataset_splits(
        self, info: Info[Context, None], input: AddDatasetExamplesToDatasetSplitsInput
    ) -> AddDatasetExamplesToDatasetSplitsMutationPayload:
        if not input.example_ids:
            raise BadRequest("No examples provided.")
        if not input.dataset_split_ids:
            raise BadRequest("No dataset splits provided.")

        split_rowid_set: set[int] = set()
        for split_gid in input.dataset_split_ids:
            try:
                split_rowid_set.add(
                    from_global_id_with_expected_type(split_gid, DatasetSplit.__name__)
                )
            except ValueError:
                raise BadRequest(f"Invalid dataset split ID: {split_gid}")
        split_rowids = list(split_rowid_set)

        example_rowid_set: set[int] = set()
        for ex_gid in input.example_ids:
            try:
                example_rowid_set.add(
                    from_global_id_with_expected_type(ex_gid, models.DatasetExample.__name__)
                )
            except ValueError:
                raise BadRequest(f"Invalid example ID: {ex_gid}")
        ex_rowids = list(example_rowid_set)

        async with info.context.db() as session:
            found_split_ids = (
                await session.scalars(
                    select(models.DatasetSplit.id).where(
                        models.DatasetSplit.id.in_(split_rowids)
                    )
                )
            ).all()
            if len(found_split_ids) != len(split_rowids):
                raise NotFound("One or more dataset splits not found")

            # Fetch the (split_id, example_id) associations that already exist so
            # the insert below only adds missing pairs — callers may submit
            # examples that already participate in some of the splits.
            existing = await session.execute(
                select(
                    models.DatasetSplitDatasetExample.dataset_split_id,
                    models.DatasetSplitDatasetExample.dataset_example_id,
                ).where(
                    models.DatasetSplitDatasetExample.dataset_split_id.in_(split_rowids)
                    & models.DatasetSplitDatasetExample.dataset_example_id.in_(ex_rowids)
                )
            )
            existing_pairs = set(existing.all())

            split_id_col = models.DatasetSplitDatasetExample.dataset_split_id.key
            example_id_col = models.DatasetSplitDatasetExample.dataset_example_id.key
            # Cross product of splits × examples, minus pairs already present.
            values = [
                {split_id_col: split_rowid, example_id_col: ex_rowid}
                for split_rowid in split_rowids
                for ex_rowid in ex_rowids
                if (split_rowid, ex_rowid) not in existing_pairs
            ]

            if values:
                try:
                    await session.execute(insert(models.DatasetSplitDatasetExample), values)
                    await session.flush()
                except (PostgreSQLIntegrityError, SQLiteIntegrityError) as e:
                    raise Conflict("Failed to add examples to dataset splits.") from e

            examples = (
                await session.scalars(
                    select(models.DatasetExample).where(
                        models.DatasetExample.id.in_(ex_rowids)
                    )
                )
            ).all()
            return AddDatasetExamplesToDatasetSplitsMutationPayload(
                query=Query(),
                examples=[to_gql_dataset_example(example) for example in examples],
            )

    @strawberry.mutation(permission_classes=[IsNotReadOnly])  # type: ignore
    async def remove_dataset_examples_from_dataset_splits(
        self, info: Info[Context, None], input: RemoveDatasetExamplesFromDatasetSplitsInput
    ) -> RemoveDatasetExamplesFromDatasetSplitsMutationPayload:
        if not input.dataset_split_ids:
            raise BadRequest("No dataset splits provided.")
        if not input.example_ids:
            raise BadRequest("No examples provided.")

        split_rowid_set: set[int] = set()
        for split_gid in input.dataset_split_ids:
            try:
                split_rowid_set.add(
                    from_global_id_with_expected_type(split_gid, DatasetSplit.__name__)
                )
            except ValueError:
                raise BadRequest(f"Invalid dataset split ID: {split_gid}")
        split_rowids = list(split_rowid_set)

        example_rowid_set: set[int] = set()
        for ex_gid in input.example_ids:
            try:
                example_rowid_set.add(
                    from_global_id_with_expected_type(ex_gid, models.DatasetExample.__name__)
                )
            except ValueError:
                raise BadRequest(f"Invalid example ID: {ex_gid}")
        ex_rowids = list(example_rowid_set)

        # Delete every association in the cross product of splits × examples.
        removal = delete(models.DatasetSplitDatasetExample).where(
            models.DatasetSplitDatasetExample.dataset_split_id.in_(split_rowids)
            & models.DatasetSplitDatasetExample.dataset_example_id.in_(ex_rowids)
        )
        async with info.context.db() as session:
            found_split_ids = (
                await session.scalars(
                    select(models.DatasetSplit.id).where(
                        models.DatasetSplit.id.in_(split_rowids)
                    )
                )
            ).all()
            if len(found_split_ids) != len(split_rowids):
                raise NotFound("One or more dataset splits not found")

            await session.execute(removal)

            examples = (
                await session.scalars(
                    select(models.DatasetExample).where(
                        models.DatasetExample.id.in_(ex_rowids)
                    )
                )
            ).all()

        return RemoveDatasetExamplesFromDatasetSplitsMutationPayload(
            query=Query(),
            examples=[to_gql_dataset_example(example) for example in examples],
        )

    @strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked])  # type: ignore
    async def create_dataset_split_with_examples(
        self, info: Info[Context, None], input: CreateDatasetSplitWithExamplesInput
    ) -> DatasetSplitMutationPayloadWithExamples:
        creator_id = get_user(info)
        name = _validated_name(input.name)
        example_rowid_set: set[int] = set()
        for ex_gid in input.example_ids:
            try:
                example_rowid_set.add(
                    from_global_id_with_expected_type(ex_gid, models.DatasetExample.__name__)
                )
            except ValueError:
                raise BadRequest(f"Invalid example ID: {ex_gid}")
        ex_rowids = list(example_rowid_set)
        async with info.context.db() as session:
            if ex_rowids:
                # Verify every requested example exists before creating the split.
                found_count = await session.scalar(
                    select(func.count(models.DatasetExample.id)).where(
                        models.DatasetExample.id.in_(ex_rowids)
                    )
                )
                if found_count is None or found_count < len(ex_rowids):
                    raise NotFound("One or more dataset examples were not found.")

            split_row = models.DatasetSplit(
                name=name,
                description=input.description or None,
                color=input.color,
                metadata_=input.metadata or {},
                user_id=creator_id,
            )
            session.add(split_row)
            try:
                # Flush (not commit) so the new row id is available for the
                # association inserts within the same transaction.
                await session.flush()
            except (PostgreSQLIntegrityError, SQLiteIntegrityError):
                raise Conflict(f"A dataset split named '{name}' already exists.")

            if ex_rowids:
                split_id_col = models.DatasetSplitDatasetExample.dataset_split_id.key
                example_id_col = models.DatasetSplitDatasetExample.dataset_example_id.key
                values = [
                    {split_id_col: split_row.id, example_id_col: ex_rowid}
                    for ex_rowid in ex_rowids
                ]
                try:
                    await session.execute(insert(models.DatasetSplitDatasetExample), values)
                except (PostgreSQLIntegrityError, SQLiteIntegrityError) as e:
                    # Roll back so the split itself is not created either.
                    await session.rollback()
                    raise Conflict(
                        "Failed to associate examples with the new dataset split."
                    ) from e

            examples = (
                await session.scalars(
                    select(models.DatasetExample).where(
                        models.DatasetExample.id.in_(ex_rowids)
                    )
                )
            ).all()

            return DatasetSplitMutationPayloadWithExamples(
                dataset_split=to_gql_dataset_split(split_row),
                query=Query(),
                examples=[to_gql_dataset_example(example) for example in examples],
            )
417
+
418
+
419
+ def _validated_name(name: str) -> str:
420
+ validated_name = name.strip()
421
+ if not validated_name:
422
+ raise BadRequest("Name cannot be empty")
423
+ return validated_name
@@ -0,0 +1,161 @@
1
+ from typing import Optional
2
+
3
+ import strawberry
4
+ from sqlalchemy.exc import IntegrityError as PostgreSQLIntegrityError
5
+ from sqlean.dbapi2 import IntegrityError as SQLiteIntegrityError # type: ignore[import-untyped]
6
+ from starlette.requests import Request
7
+ from strawberry import Info
8
+ from strawberry.relay import GlobalID
9
+
10
+ from phoenix.db import models
11
+ from phoenix.server.api.auth import IsLocked, IsNotReadOnly
12
+ from phoenix.server.api.context import Context
13
+ from phoenix.server.api.exceptions import BadRequest, Conflict, NotFound, Unauthorized
14
+ from phoenix.server.api.helpers.annotations import get_user_identifier
15
+ from phoenix.server.api.input_types.CreateProjectSessionAnnotationInput import (
16
+ CreateProjectSessionAnnotationInput,
17
+ )
18
+ from phoenix.server.api.input_types.UpdateAnnotationInput import UpdateAnnotationInput
19
+ from phoenix.server.api.queries import Query
20
+ from phoenix.server.api.types.AnnotationSource import AnnotationSource
21
+ from phoenix.server.api.types.node import from_global_id_with_expected_type
22
+ from phoenix.server.api.types.ProjectSessionAnnotation import (
23
+ ProjectSessionAnnotation,
24
+ to_gql_project_session_annotation,
25
+ )
26
+ from phoenix.server.bearer_auth import PhoenixUser
27
+ from phoenix.server.dml_event import (
28
+ ProjectSessionAnnotationDeleteEvent,
29
+ ProjectSessionAnnotationInsertEvent,
30
+ )
31
+
32
+
33
+ @strawberry.type
34
+ class ProjectSessionAnnotationMutationPayload:
35
+ project_session_annotation: ProjectSessionAnnotation
36
+ query: Query
37
+
38
+
39
+ @strawberry.type
40
+ class ProjectSessionAnnotationMutationMixin:
41
+ @strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked]) # type: ignore
42
+ async def create_project_session_annotations(
43
+ self, info: Info[Context, None], input: CreateProjectSessionAnnotationInput
44
+ ) -> ProjectSessionAnnotationMutationPayload:
45
+ assert isinstance(request := info.context.request, Request)
46
+ user_id: Optional[int] = None
47
+ if "user" in request.scope and isinstance((user := info.context.user), PhoenixUser):
48
+ user_id = int(user.identity)
49
+
50
+ try:
51
+ project_session_id = from_global_id_with_expected_type(
52
+ input.project_session_id, "ProjectSession"
53
+ )
54
+ except ValueError:
55
+ raise BadRequest(f"Invalid session ID: {input.project_session_id}")
56
+
57
+ identifier = ""
58
+ if isinstance(input.identifier, str):
59
+ identifier = input.identifier # Already trimmed in __post_init__
60
+ elif input.source == AnnotationSource.APP and user_id is not None:
61
+ identifier = get_user_identifier(user_id)
62
+
63
+ try:
64
+ async with info.context.db() as session:
65
+ anno = models.ProjectSessionAnnotation(
66
+ project_session_id=project_session_id,
67
+ name=input.name,
68
+ label=input.label,
69
+ score=input.score,
70
+ explanation=input.explanation,
71
+ annotator_kind=input.annotator_kind.value,
72
+ metadata_=input.metadata,
73
+ identifier=identifier,
74
+ source=input.source.value,
75
+ user_id=user_id,
76
+ )
77
+ session.add(anno)
78
+ except (PostgreSQLIntegrityError, SQLiteIntegrityError) as e:
79
+ raise Conflict(f"Error creating annotation: {e}")
80
+
81
+ info.context.event_queue.put(ProjectSessionAnnotationInsertEvent((anno.id,)))
82
+
83
+ return ProjectSessionAnnotationMutationPayload(
84
+ project_session_annotation=to_gql_project_session_annotation(anno),
85
+ query=Query(),
86
+ )
87
+
88
+ @strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked]) # type: ignore
89
+ async def update_project_session_annotations(
90
+ self, info: Info[Context, None], input: UpdateAnnotationInput
91
+ ) -> ProjectSessionAnnotationMutationPayload:
92
+ assert isinstance(request := info.context.request, Request)
93
+ user_id: Optional[int] = None
94
+ if "user" in request.scope and isinstance((user := info.context.user), PhoenixUser):
95
+ user_id = int(user.identity)
96
+
97
+ try:
98
+ id_ = from_global_id_with_expected_type(input.id, "ProjectSessionAnnotation")
99
+ except ValueError:
100
+ raise BadRequest(f"Invalid session annotation ID: {input.id}")
101
+
102
+ async with info.context.db() as session:
103
+ if not (anno := await session.get(models.ProjectSessionAnnotation, id_)):
104
+ raise NotFound(f"Could not find session annotation with ID: {input.id}")
105
+ if anno.user_id != user_id:
106
+ raise Unauthorized("Session annotation is not associated with the current user.")
107
+
108
+ # Update the annotation fields
109
+ anno.name = input.name
110
+ anno.label = input.label
111
+ anno.score = input.score
112
+ anno.explanation = input.explanation
113
+ anno.annotator_kind = input.annotator_kind.value
114
+ anno.metadata_ = input.metadata
115
+ anno.source = input.source.value
116
+
117
+ session.add(anno)
118
+ try:
119
+ await session.flush()
120
+ except (PostgreSQLIntegrityError, SQLiteIntegrityError) as e:
121
+ raise Conflict(f"Error updating annotation: {e}")
122
+
123
+ info.context.event_queue.put(ProjectSessionAnnotationInsertEvent((anno.id,)))
124
+ return ProjectSessionAnnotationMutationPayload(
125
+ project_session_annotation=to_gql_project_session_annotation(anno),
126
+ query=Query(),
127
+ )
128
+
129
+ @strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore
130
+ async def delete_project_session_annotation(
131
+ self, info: Info[Context, None], id: GlobalID
132
+ ) -> ProjectSessionAnnotationMutationPayload:
133
+ try:
134
+ id_ = from_global_id_with_expected_type(id, "ProjectSessionAnnotation")
135
+ except ValueError:
136
+ raise BadRequest(f"Invalid session annotation ID: {id}")
137
+
138
+ assert isinstance(request := info.context.request, Request)
139
+ user_id: Optional[int] = None
140
+ user_is_admin = False
141
+ if "user" in request.scope and isinstance((user := info.context.user), PhoenixUser):
142
+ user_id = int(user.identity)
143
+ user_is_admin = user.is_admin
144
+
145
+ async with info.context.db() as session:
146
+ if not (anno := await session.get(models.ProjectSessionAnnotation, id_)):
147
+ raise NotFound(f"Could not find session annotation with ID: {id}")
148
+
149
+ if not user_is_admin and anno.user_id != user_id:
150
+ raise Unauthorized(
151
+ "Session annotation is not associated with the current user and "
152
+ "the current user is not an admin."
153
+ )
154
+
155
+ await session.delete(anno)
156
+
157
+ deleted_gql_annotation = to_gql_project_session_annotation(anno)
158
+ info.context.event_queue.put(ProjectSessionAnnotationDeleteEvent((id_,)))
159
+ return ProjectSessionAnnotationMutationPayload(
160
+ project_session_annotation=deleted_gql_annotation, query=Query()
161
+ )
@@ -48,6 +48,8 @@ from phoenix.server.api.types.AnnotationConfig import AnnotationConfig, to_gql_a
48
48
  from phoenix.server.api.types.Cluster import Cluster, to_gql_clusters
49
49
  from phoenix.server.api.types.Dataset import Dataset, to_gql_dataset
50
50
  from phoenix.server.api.types.DatasetExample import DatasetExample
51
+ from phoenix.server.api.types.DatasetLabel import DatasetLabel, to_gql_dataset_label
52
+ from phoenix.server.api.types.DatasetSplit import DatasetSplit, to_gql_dataset_split
51
53
  from phoenix.server.api.types.Dimension import to_gql_dimension
52
54
  from phoenix.server.api.types.EmbeddingDimension import (
53
55
  DEFAULT_CLUSTER_SELECTION_EPSILON,
@@ -959,6 +961,14 @@ class Query:
959
961
  id_attr=example.id,
960
962
  created_at=example.created_at,
961
963
  )
964
+ elif type_name == DatasetSplit.__name__:
965
+ async with info.context.db() as session:
966
+ dataset_split = await session.scalar(
967
+ select(models.DatasetSplit).where(models.DatasetSplit.id == node_id)
968
+ )
969
+ if not dataset_split:
970
+ raise NotFound(f"Unknown dataset split: {id}")
971
+ return to_gql_dataset_split(dataset_split)
962
972
  elif type_name == Experiment.__name__:
963
973
  async with info.context.db() as session:
964
974
  experiment = await session.scalar(
@@ -1140,6 +1150,49 @@ class Query:
1140
1150
  args=args,
1141
1151
  )
1142
1152
 
1153
+ @strawberry.field
1154
+ async def dataset_labels(
1155
+ self,
1156
+ info: Info[Context, None],
1157
+ first: Optional[int] = 50,
1158
+ last: Optional[int] = UNSET,
1159
+ after: Optional[CursorString] = UNSET,
1160
+ before: Optional[CursorString] = UNSET,
1161
+ ) -> Connection[DatasetLabel]:
1162
+ args = ConnectionArgs(
1163
+ first=first,
1164
+ after=after if isinstance(after, CursorString) else None,
1165
+ last=last,
1166
+ before=before if isinstance(before, CursorString) else None,
1167
+ )
1168
+ async with info.context.db() as session:
1169
+ dataset_labels = await session.scalars(select(models.DatasetLabel))
1170
+ data = [to_gql_dataset_label(dataset_label) for dataset_label in dataset_labels]
1171
+ return connection_from_list(data=data, args=args)
1172
+
1173
+ @strawberry.field
1174
+ async def dataset_splits(
1175
+ self,
1176
+ info: Info[Context, None],
1177
+ first: Optional[int] = 50,
1178
+ last: Optional[int] = UNSET,
1179
+ after: Optional[CursorString] = UNSET,
1180
+ before: Optional[CursorString] = UNSET,
1181
+ ) -> Connection[DatasetSplit]:
1182
+ args = ConnectionArgs(
1183
+ first=first,
1184
+ after=after if isinstance(after, CursorString) else None,
1185
+ last=last,
1186
+ before=before if isinstance(before, CursorString) else None,
1187
+ )
1188
+ async with info.context.db() as session:
1189
+ splits = await session.stream_scalars(select(models.DatasetSplit))
1190
+ data = [to_gql_dataset_split(split) async for split in splits]
1191
+ return connection_from_list(
1192
+ data=data,
1193
+ args=args,
1194
+ )
1195
+
1143
1196
  @strawberry.field
1144
1197
  async def annotation_configs(
1145
1198
  self,