arize-phoenix 12.7.1__py3-none-any.whl → 12.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic. Click here for more details.

Files changed (76) hide show
  1. {arize_phoenix-12.7.1.dist-info → arize_phoenix-12.9.0.dist-info}/METADATA +3 -1
  2. {arize_phoenix-12.7.1.dist-info → arize_phoenix-12.9.0.dist-info}/RECORD +76 -73
  3. phoenix/config.py +131 -9
  4. phoenix/db/engines.py +127 -14
  5. phoenix/db/iam_auth.py +64 -0
  6. phoenix/db/pg_config.py +10 -0
  7. phoenix/server/api/context.py +23 -0
  8. phoenix/server/api/dataloaders/__init__.py +6 -0
  9. phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +0 -2
  10. phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
  11. phoenix/server/api/dataloaders/span_costs.py +3 -9
  12. phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
  13. phoenix/server/api/helpers/playground_clients.py +3 -3
  14. phoenix/server/api/input_types/PromptVersionInput.py +47 -1
  15. phoenix/server/api/mutations/annotation_config_mutations.py +2 -2
  16. phoenix/server/api/mutations/api_key_mutations.py +2 -15
  17. phoenix/server/api/mutations/chat_mutations.py +3 -2
  18. phoenix/server/api/mutations/dataset_label_mutations.py +109 -157
  19. phoenix/server/api/mutations/dataset_mutations.py +8 -8
  20. phoenix/server/api/mutations/dataset_split_mutations.py +13 -9
  21. phoenix/server/api/mutations/model_mutations.py +4 -4
  22. phoenix/server/api/mutations/project_session_annotations_mutations.py +4 -7
  23. phoenix/server/api/mutations/prompt_label_mutations.py +3 -3
  24. phoenix/server/api/mutations/prompt_mutations.py +24 -117
  25. phoenix/server/api/mutations/prompt_version_tag_mutations.py +8 -5
  26. phoenix/server/api/mutations/span_annotations_mutations.py +10 -5
  27. phoenix/server/api/mutations/trace_annotations_mutations.py +9 -4
  28. phoenix/server/api/mutations/user_mutations.py +4 -4
  29. phoenix/server/api/queries.py +80 -213
  30. phoenix/server/api/subscriptions.py +4 -4
  31. phoenix/server/api/types/Annotation.py +90 -23
  32. phoenix/server/api/types/ApiKey.py +13 -17
  33. phoenix/server/api/types/Dataset.py +88 -48
  34. phoenix/server/api/types/DatasetExample.py +34 -30
  35. phoenix/server/api/types/DatasetLabel.py +47 -13
  36. phoenix/server/api/types/DatasetSplit.py +87 -21
  37. phoenix/server/api/types/DatasetVersion.py +49 -4
  38. phoenix/server/api/types/DocumentAnnotation.py +182 -62
  39. phoenix/server/api/types/Experiment.py +146 -55
  40. phoenix/server/api/types/ExperimentRepeatedRunGroup.py +10 -1
  41. phoenix/server/api/types/ExperimentRun.py +118 -61
  42. phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
  43. phoenix/server/api/types/GenerativeModel.py +95 -42
  44. phoenix/server/api/types/ModelInterface.py +7 -2
  45. phoenix/server/api/types/PlaygroundModel.py +12 -2
  46. phoenix/server/api/types/Project.py +70 -75
  47. phoenix/server/api/types/ProjectSession.py +69 -37
  48. phoenix/server/api/types/ProjectSessionAnnotation.py +166 -47
  49. phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
  50. phoenix/server/api/types/Prompt.py +82 -44
  51. phoenix/server/api/types/PromptLabel.py +47 -13
  52. phoenix/server/api/types/PromptVersion.py +11 -8
  53. phoenix/server/api/types/PromptVersionTag.py +65 -25
  54. phoenix/server/api/types/Span.py +116 -115
  55. phoenix/server/api/types/SpanAnnotation.py +189 -42
  56. phoenix/server/api/types/SystemApiKey.py +65 -1
  57. phoenix/server/api/types/Trace.py +45 -44
  58. phoenix/server/api/types/TraceAnnotation.py +144 -48
  59. phoenix/server/api/types/User.py +103 -33
  60. phoenix/server/api/types/UserApiKey.py +73 -26
  61. phoenix/server/app.py +29 -0
  62. phoenix/server/cost_tracking/model_cost_manifest.json +2 -2
  63. phoenix/server/static/.vite/manifest.json +43 -43
  64. phoenix/server/static/assets/{components-BLK5vehh.js → components-v927s3NF.js} +471 -484
  65. phoenix/server/static/assets/{index-BP0Shd90.js → index-DrD9eSrN.js} +20 -16
  66. phoenix/server/static/assets/{pages-DIVgyYyy.js → pages-GVybXa_W.js} +754 -753
  67. phoenix/server/static/assets/{vendor-3BvTzoBp.js → vendor-D-csRHGZ.js} +1 -1
  68. phoenix/server/static/assets/{vendor-arizeai-C6_oC0y8.js → vendor-arizeai-BJLCG_Gc.js} +1 -1
  69. phoenix/server/static/assets/{vendor-codemirror-DPnZGAZA.js → vendor-codemirror-Cr963DyP.js} +3 -3
  70. phoenix/server/static/assets/{vendor-recharts-CjgSbsB0.js → vendor-recharts-DgmPLgIp.js} +1 -1
  71. phoenix/server/static/assets/{vendor-shiki-CJyhDG0E.js → vendor-shiki-wYOt1s7u.js} +1 -1
  72. phoenix/version.py +1 -1
  73. {arize_phoenix-12.7.1.dist-info → arize_phoenix-12.9.0.dist-info}/WHEEL +0 -0
  74. {arize_phoenix-12.7.1.dist-info → arize_phoenix-12.9.0.dist-info}/entry_points.txt +0 -0
  75. {arize_phoenix-12.7.1.dist-info → arize_phoenix-12.9.0.dist-info}/licenses/IP_NOTICE +0 -0
  76. {arize_phoenix-12.7.1.dist-info → arize_phoenix-12.9.0.dist-info}/licenses/LICENSE +0 -0
@@ -4,6 +4,8 @@ import sqlalchemy
4
4
  import strawberry
5
5
  from sqlalchemy import delete, select
6
6
  from sqlalchemy.exc import IntegrityError as PostgreSQLIntegrityError
7
+ from sqlalchemy.orm import joinedload
8
+ from sqlalchemy.sql import tuple_
7
9
  from sqlean.dbapi2 import IntegrityError as SQLiteIntegrityError # type: ignore[import-untyped]
8
10
  from strawberry import UNSET
9
11
  from strawberry.relay.types import GlobalID
@@ -15,7 +17,7 @@ from phoenix.server.api.context import Context
15
17
  from phoenix.server.api.exceptions import BadRequest, Conflict, NotFound
16
18
  from phoenix.server.api.queries import Query
17
19
  from phoenix.server.api.types.Dataset import Dataset
18
- from phoenix.server.api.types.DatasetLabel import DatasetLabel, to_gql_dataset_label
20
+ from phoenix.server.api.types.DatasetLabel import DatasetLabel
19
21
  from phoenix.server.api.types.node import from_global_id_with_expected_type
20
22
 
21
23
 
@@ -24,11 +26,13 @@ class CreateDatasetLabelInput:
24
26
  name: str
25
27
  description: Optional[str] = UNSET
26
28
  color: str
29
+ dataset_ids: Optional[list[GlobalID]] = UNSET
27
30
 
28
31
 
29
32
  @strawberry.type
30
33
  class CreateDatasetLabelMutationPayload:
31
34
  dataset_label: DatasetLabel
35
+ datasets: list[Dataset]
32
36
 
33
37
 
34
38
  @strawberry.input
@@ -41,39 +45,16 @@ class DeleteDatasetLabelsMutationPayload:
41
45
  dataset_labels: list[DatasetLabel]
42
46
 
43
47
 
44
- @strawberry.input
45
- class UpdateDatasetLabelInput:
46
- dataset_label_id: GlobalID
47
- name: str
48
- description: Optional[str] = None
49
- color: str
50
-
51
-
52
- @strawberry.type
53
- class UpdateDatasetLabelMutationPayload:
54
- dataset_label: DatasetLabel
55
-
56
-
57
48
  @strawberry.input
58
49
  class SetDatasetLabelsInput:
50
+ dataset_id: GlobalID
59
51
  dataset_label_ids: list[GlobalID]
60
- dataset_ids: list[GlobalID]
61
52
 
62
53
 
63
54
  @strawberry.type
64
55
  class SetDatasetLabelsMutationPayload:
65
- query: "Query"
66
-
67
-
68
- @strawberry.input
69
- class UnsetDatasetLabelsInput:
70
- dataset_label_ids: list[GlobalID]
71
- dataset_ids: list[GlobalID]
72
-
73
-
74
- @strawberry.type
75
- class UnsetDatasetLabelsMutationPayload:
76
- query: "Query"
56
+ query: Query
57
+ dataset: Dataset
77
58
 
78
59
 
79
60
  @strawberry.type
@@ -87,50 +68,56 @@ class DatasetLabelMutationMixin:
87
68
  name = input.name
88
69
  description = input.description
89
70
  color = input.color
71
+ dataset_rowids: dict[
72
+ int, None
73
+ ] = {} # use dictionary to de-duplicate while preserving order
74
+ if input.dataset_ids:
75
+ for dataset_id in input.dataset_ids:
76
+ try:
77
+ dataset_rowid = from_global_id_with_expected_type(dataset_id, Dataset.__name__)
78
+ except ValueError:
79
+ raise BadRequest(f"Invalid dataset ID: {dataset_id}")
80
+ dataset_rowids[dataset_rowid] = None
81
+
90
82
  async with info.context.db() as session:
91
83
  dataset_label_orm = models.DatasetLabel(name=name, description=description, color=color)
92
84
  session.add(dataset_label_orm)
93
85
  try:
94
- await session.commit()
86
+ await session.flush()
95
87
  except (PostgreSQLIntegrityError, SQLiteIntegrityError):
96
88
  raise Conflict(f"A dataset label named '{name}' already exists")
97
89
  except sqlalchemy.exc.StatementError as error:
98
90
  raise BadRequest(str(error.orig))
99
- return CreateDatasetLabelMutationPayload(
100
- dataset_label=to_gql_dataset_label(dataset_label_orm)
101
- )
102
91
 
103
- @strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked]) # type: ignore
104
- async def update_dataset_label(
105
- self, info: Info[Context, None], input: UpdateDatasetLabelInput
106
- ) -> UpdateDatasetLabelMutationPayload:
107
- if not input.name or not input.name.strip():
108
- raise BadRequest("Dataset label name cannot be empty")
109
-
110
- try:
111
- dataset_label_id = from_global_id_with_expected_type(
112
- input.dataset_label_id, DatasetLabel.__name__
113
- )
114
- except ValueError:
115
- raise BadRequest(f"Invalid dataset label ID: {input.dataset_label_id}")
116
-
117
- async with info.context.db() as session:
118
- dataset_label_orm = await session.get(models.DatasetLabel, dataset_label_id)
119
- if not dataset_label_orm:
120
- raise NotFound(f"DatasetLabel with ID {input.dataset_label_id} not found")
121
-
122
- dataset_label_orm.name = input.name.strip()
123
- dataset_label_orm.description = input.description
124
- dataset_label_orm.color = input.color.strip()
125
-
126
- try:
92
+ datasets_by_id: dict[int, models.Dataset] = {}
93
+ if dataset_rowids:
94
+ datasets_by_id = {
95
+ dataset.id: dataset
96
+ for dataset in await session.scalars(
97
+ select(models.Dataset).where(models.Dataset.id.in_(dataset_rowids.keys()))
98
+ )
99
+ }
100
+ if len(datasets_by_id) < len(dataset_rowids):
101
+ raise NotFound("One or more datasets not found")
102
+ session.add_all(
103
+ [
104
+ models.DatasetsDatasetLabel(
105
+ dataset_id=dataset_rowid,
106
+ dataset_label_id=dataset_label_orm.id,
107
+ )
108
+ for dataset_rowid in dataset_rowids
109
+ ]
110
+ )
127
111
  await session.commit()
128
- except (PostgreSQLIntegrityError, SQLiteIntegrityError):
129
- raise Conflict(f"A dataset label named '{input.name}' already exists")
130
- except sqlalchemy.exc.StatementError as error:
131
- raise BadRequest(str(error.orig))
132
- return UpdateDatasetLabelMutationPayload(
133
- dataset_label=to_gql_dataset_label(dataset_label_orm)
112
+
113
+ return CreateDatasetLabelMutationPayload(
114
+ dataset_label=DatasetLabel(id=dataset_label_orm.id, db_record=dataset_label_orm),
115
+ datasets=[
116
+ Dataset(
117
+ id=datasets_by_id[dataset_rowid].id, db_record=datasets_by_id[dataset_rowid]
118
+ )
119
+ for dataset_rowid in dataset_rowids
120
+ ],
134
121
  )
135
122
 
136
123
  @strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked]) # type: ignore
@@ -161,7 +148,10 @@ class DatasetLabelMutationMixin:
161
148
  }
162
149
  return DeleteDatasetLabelsMutationPayload(
163
150
  dataset_labels=[
164
- to_gql_dataset_label(deleted_dataset_labels_by_id[dataset_label_row_id])
151
+ DatasetLabel(
152
+ id=deleted_dataset_labels_by_id[dataset_label_row_id].id,
153
+ db_record=deleted_dataset_labels_by_id[dataset_label_row_id],
154
+ )
165
155
  for dataset_label_row_id in dataset_label_row_ids
166
156
  ]
167
157
  )
@@ -170,122 +160,84 @@ class DatasetLabelMutationMixin:
170
160
  async def set_dataset_labels(
171
161
  self, info: Info[Context, None], input: SetDatasetLabelsInput
172
162
  ) -> SetDatasetLabelsMutationPayload:
173
- if not input.dataset_ids:
174
- raise BadRequest("No datasets provided.")
175
- if not input.dataset_label_ids:
176
- raise BadRequest("No dataset labels provided.")
177
-
178
- unique_dataset_rowids: set[int] = set()
179
- for dataset_gid in input.dataset_ids:
180
- try:
181
- dataset_rowid = from_global_id_with_expected_type(dataset_gid, Dataset.__name__)
182
- except ValueError:
183
- raise BadRequest(f"Invalid dataset ID: {dataset_gid}")
184
- unique_dataset_rowids.add(dataset_rowid)
185
- dataset_rowids = list(unique_dataset_rowids)
163
+ try:
164
+ dataset_id = from_global_id_with_expected_type(input.dataset_id, Dataset.__name__)
165
+ except ValueError:
166
+ raise BadRequest(f"Invalid dataset ID: {input.dataset_id}")
186
167
 
187
- unique_dataset_label_rowids: set[int] = set()
168
+ dataset_label_ids: dict[
169
+ int, None
170
+ ] = {} # use dictionary to de-duplicate while preserving order
188
171
  for dataset_label_gid in input.dataset_label_ids:
189
172
  try:
190
- dataset_label_rowid = from_global_id_with_expected_type(
173
+ dataset_label_id = from_global_id_with_expected_type(
191
174
  dataset_label_gid, DatasetLabel.__name__
192
175
  )
193
176
  except ValueError:
194
177
  raise BadRequest(f"Invalid dataset label ID: {dataset_label_gid}")
195
- unique_dataset_label_rowids.add(dataset_label_rowid)
196
- dataset_label_rowids = list(unique_dataset_label_rowids)
178
+ dataset_label_ids[dataset_label_id] = None
197
179
 
198
180
  async with info.context.db() as session:
199
- existing_dataset_ids = (
200
- await session.scalars(
201
- select(models.Dataset.id).where(models.Dataset.id.in_(dataset_rowids))
202
- )
203
- ).all()
204
- if len(existing_dataset_ids) != len(dataset_rowids):
205
- raise NotFound("One or more datasets not found")
181
+ dataset = await session.scalar(
182
+ select(models.Dataset)
183
+ .where(models.Dataset.id == dataset_id)
184
+ .options(joinedload(models.Dataset.datasets_dataset_labels))
185
+ )
206
186
 
207
- existing_dataset_label_ids = (
187
+ if not dataset:
188
+ raise NotFound(f"Dataset with ID {input.dataset_id} not found")
189
+
190
+ existing_label_ids = (
208
191
  await session.scalars(
209
192
  select(models.DatasetLabel.id).where(
210
- models.DatasetLabel.id.in_(dataset_label_rowids)
193
+ models.DatasetLabel.id.in_(dataset_label_ids.keys())
211
194
  )
212
195
  )
213
196
  ).all()
214
- if len(existing_dataset_label_ids) != len(dataset_label_rowids):
197
+ if len(existing_label_ids) != len(dataset_label_ids):
215
198
  raise NotFound("One or more dataset labels not found")
216
199
 
217
- existing_dataset_label_keys = await session.execute(
218
- select(
219
- models.DatasetsDatasetLabel.dataset_id,
220
- models.DatasetsDatasetLabel.dataset_label_id,
221
- ).where(
222
- models.DatasetsDatasetLabel.dataset_id.in_(dataset_rowids)
223
- & models.DatasetsDatasetLabel.dataset_label_id.in_(dataset_label_rowids)
224
- )
225
- )
226
- unique_dataset_label_keys = set(existing_dataset_label_keys.all())
200
+ previously_applied_dataset_label_ids = {
201
+ dataset_dataset_label.dataset_label_id
202
+ for dataset_dataset_label in dataset.datasets_dataset_labels
203
+ }
227
204
 
228
- datasets_dataset_labels = []
229
- for dataset_rowid in dataset_rowids:
230
- for dataset_label_rowid in dataset_label_rowids:
231
- if (dataset_rowid, dataset_label_rowid) in unique_dataset_label_keys:
232
- continue
233
- datasets_dataset_labels.append(
234
- models.DatasetsDatasetLabel(
235
- dataset_id=dataset_rowid,
236
- dataset_label_id=dataset_label_rowid,
205
+ datasets_dataset_labels_to_add = [
206
+ models.DatasetsDatasetLabel(
207
+ dataset_id=dataset_id,
208
+ dataset_label_id=dataset_label_id,
209
+ )
210
+ for dataset_label_id in dataset_label_ids
211
+ if dataset_label_id not in previously_applied_dataset_label_ids
212
+ ]
213
+ if datasets_dataset_labels_to_add:
214
+ session.add_all(datasets_dataset_labels_to_add)
215
+ await session.flush()
216
+
217
+ datasets_dataset_labels_to_delete = [
218
+ dataset_dataset_label
219
+ for dataset_dataset_label in dataset.datasets_dataset_labels
220
+ if dataset_dataset_label.dataset_label_id not in dataset_label_ids
221
+ ]
222
+ if datasets_dataset_labels_to_delete:
223
+ await session.execute(
224
+ delete(models.DatasetsDatasetLabel).where(
225
+ tuple_(
226
+ models.DatasetsDatasetLabel.dataset_id,
227
+ models.DatasetsDatasetLabel.dataset_label_id,
228
+ ).in_(
229
+ [
230
+ (
231
+ datasets_dataset_labels.dataset_id,
232
+ datasets_dataset_labels.dataset_label_id,
233
+ )
234
+ for datasets_dataset_labels in datasets_dataset_labels_to_delete
235
+ ]
237
236
  )
238
237
  )
239
- session.add_all(datasets_dataset_labels)
240
-
241
- if datasets_dataset_labels:
242
- try:
243
- await session.commit()
244
- except (PostgreSQLIntegrityError, SQLiteIntegrityError) as e:
245
- raise Conflict("Failed to add dataset labels to datasets.") from e
246
-
247
- return SetDatasetLabelsMutationPayload(
248
- query=Query(),
249
- )
250
-
251
- @strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked]) # type: ignore
252
- async def unset_dataset_labels(
253
- self, info: Info[Context, None], input: UnsetDatasetLabelsInput
254
- ) -> UnsetDatasetLabelsMutationPayload:
255
- if not input.dataset_ids:
256
- raise BadRequest("No datasets provided.")
257
- if not input.dataset_label_ids:
258
- raise BadRequest("No dataset labels provided.")
259
-
260
- unique_dataset_rowids: set[int] = set()
261
- for dataset_gid in input.dataset_ids:
262
- try:
263
- dataset_rowid = from_global_id_with_expected_type(dataset_gid, Dataset.__name__)
264
- except ValueError:
265
- raise BadRequest(f"Invalid dataset ID: {dataset_gid}")
266
- unique_dataset_rowids.add(dataset_rowid)
267
- dataset_rowids = list(unique_dataset_rowids)
268
-
269
- unique_dataset_label_rowids: set[int] = set()
270
- for dataset_label_gid in input.dataset_label_ids:
271
- try:
272
- dataset_label_rowid = from_global_id_with_expected_type(
273
- dataset_label_gid, DatasetLabel.__name__
274
- )
275
- except ValueError:
276
- raise BadRequest(f"Invalid dataset label ID: {dataset_label_gid}")
277
- unique_dataset_label_rowids.add(dataset_label_rowid)
278
- dataset_label_rowids = list(unique_dataset_label_rowids)
279
-
280
- async with info.context.db() as session:
281
- await session.execute(
282
- delete(models.DatasetsDatasetLabel).where(
283
- models.DatasetsDatasetLabel.dataset_id.in_(dataset_rowids)
284
- & models.DatasetsDatasetLabel.dataset_label_id.in_(dataset_label_rowids)
285
238
  )
286
- )
287
- await session.commit()
288
239
 
289
- return UnsetDatasetLabelsMutationPayload(
240
+ return SetDatasetLabelsMutationPayload(
241
+ dataset=Dataset(id=dataset.id, db_record=dataset),
290
242
  query=Query(),
291
243
  )
@@ -35,7 +35,7 @@ from phoenix.server.api.input_types.PatchDatasetExamplesInput import (
35
35
  PatchDatasetExamplesInput,
36
36
  )
37
37
  from phoenix.server.api.input_types.PatchDatasetInput import PatchDatasetInput
38
- from phoenix.server.api.types.Dataset import Dataset, to_gql_dataset
38
+ from phoenix.server.api.types.Dataset import Dataset
39
39
  from phoenix.server.api.types.DatasetExample import DatasetExample
40
40
  from phoenix.server.api.types.node import from_global_id_with_expected_type
41
41
  from phoenix.server.api.types.Span import Span
@@ -72,7 +72,7 @@ class DatasetMutationMixin:
72
72
  )
73
73
  assert dataset is not None
74
74
  info.context.event_queue.put(DatasetInsertEvent((dataset.id,)))
75
- return DatasetMutationPayload(dataset=to_gql_dataset(dataset))
75
+ return DatasetMutationPayload(dataset=Dataset(id=dataset.id, db_record=dataset))
76
76
 
77
77
  @strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked]) # type: ignore
78
78
  async def patch_dataset(
@@ -101,7 +101,7 @@ class DatasetMutationMixin:
101
101
  )
102
102
  assert dataset is not None
103
103
  info.context.event_queue.put(DatasetInsertEvent((dataset.id,)))
104
- return DatasetMutationPayload(dataset=to_gql_dataset(dataset))
104
+ return DatasetMutationPayload(dataset=Dataset(id=dataset.id, db_record=dataset))
105
105
 
106
106
  @strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked]) # type: ignore
107
107
  async def add_spans_to_dataset(
@@ -221,7 +221,7 @@ class DatasetMutationMixin:
221
221
  ],
222
222
  )
223
223
  info.context.event_queue.put(DatasetInsertEvent((dataset.id,)))
224
- return DatasetMutationPayload(dataset=to_gql_dataset(dataset))
224
+ return DatasetMutationPayload(dataset=Dataset(id=dataset.id, db_record=dataset))
225
225
 
226
226
  @strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked]) # type: ignore
227
227
  async def add_examples_to_dataset(
@@ -348,7 +348,7 @@ class DatasetMutationMixin:
348
348
  dataset_example_revisions,
349
349
  )
350
350
  info.context.event_queue.put(DatasetInsertEvent((dataset.id,)))
351
- return DatasetMutationPayload(dataset=to_gql_dataset(dataset))
351
+ return DatasetMutationPayload(dataset=Dataset(id=dataset.id, db_record=dataset))
352
352
 
353
353
  @strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer]) # type: ignore
354
354
  async def delete_dataset(
@@ -379,7 +379,7 @@ class DatasetMutationMixin:
379
379
  return_exceptions=True,
380
380
  )
381
381
  info.context.event_queue.put(DatasetDeleteEvent((dataset.id,)))
382
- return DatasetMutationPayload(dataset=to_gql_dataset(dataset))
382
+ return DatasetMutationPayload(dataset=Dataset(id=dataset.id, db_record=dataset))
383
383
 
384
384
  @strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked]) # type: ignore
385
385
  async def patch_dataset_examples(
@@ -472,7 +472,7 @@ class DatasetMutationMixin:
472
472
  ],
473
473
  )
474
474
  info.context.event_queue.put(DatasetInsertEvent((dataset.id,)))
475
- return DatasetMutationPayload(dataset=to_gql_dataset(dataset))
475
+ return DatasetMutationPayload(dataset=Dataset(id=dataset.id, db_record=dataset))
476
476
 
477
477
  @strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked]) # type: ignore
478
478
  async def delete_dataset_examples(
@@ -556,7 +556,7 @@ class DatasetMutationMixin:
556
556
  ],
557
557
  )
558
558
  info.context.event_queue.put(DatasetInsertEvent((dataset.id,)))
559
- return DatasetMutationPayload(dataset=to_gql_dataset(dataset))
559
+ return DatasetMutationPayload(dataset=Dataset(id=dataset.id, db_record=dataset))
560
560
 
561
561
 
562
562
  def _span_attribute(semconv: str) -> Any:
@@ -15,8 +15,8 @@ from phoenix.server.api.context import Context
15
15
  from phoenix.server.api.exceptions import BadRequest, Conflict, NotFound
16
16
  from phoenix.server.api.helpers.playground_users import get_user
17
17
  from phoenix.server.api.queries import Query
18
- from phoenix.server.api.types.DatasetExample import DatasetExample, to_gql_dataset_example
19
- from phoenix.server.api.types.DatasetSplit import DatasetSplit, to_gql_dataset_split
18
+ from phoenix.server.api.types.DatasetExample import DatasetExample
19
+ from phoenix.server.api.types.DatasetSplit import DatasetSplit
20
20
  from phoenix.server.api.types.node import from_global_id_with_expected_type
21
21
 
22
22
 
@@ -116,7 +116,8 @@ class DatasetSplitMutationMixin:
116
116
  except (PostgreSQLIntegrityError, SQLiteIntegrityError):
117
117
  raise Conflict(f"A dataset split named '{input.name}' already exists.")
118
118
  return DatasetSplitMutationPayload(
119
- dataset_split=to_gql_dataset_split(dataset_split_orm), query=Query()
119
+ dataset_split=DatasetSplit(id=dataset_split_orm.id, db_record=dataset_split_orm),
120
+ query=Query(),
120
121
  )
121
122
 
122
123
  @strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked]) # type: ignore
@@ -141,7 +142,7 @@ class DatasetSplitMutationMixin:
141
142
  if isinstance(input.metadata, dict):
142
143
  dataset_split_orm.metadata_ = input.metadata
143
144
 
144
- gql_dataset_split = to_gql_dataset_split(dataset_split_orm)
145
+ gql_dataset_split = DatasetSplit(id=dataset_split_orm.id, db_record=dataset_split_orm)
145
146
  try:
146
147
  await session.commit()
147
148
  except (PostgreSQLIntegrityError, SQLiteIntegrityError):
@@ -185,7 +186,10 @@ class DatasetSplitMutationMixin:
185
186
 
186
187
  return DeleteDatasetSplitsMutationPayload(
187
188
  dataset_splits=[
188
- to_gql_dataset_split(deleted_splits_by_id[dataset_split_rowid])
189
+ DatasetSplit(
190
+ id=deleted_splits_by_id[dataset_split_rowid].id,
191
+ db_record=deleted_splits_by_id[dataset_split_rowid],
192
+ )
189
193
  for dataset_split_rowid in dataset_split_rowids
190
194
  ],
191
195
  query=Query(),
@@ -281,7 +285,7 @@ class DatasetSplitMutationMixin:
281
285
  ).all()
282
286
  return AddDatasetExamplesToDatasetSplitsMutationPayload(
283
287
  query=Query(),
284
- examples=[to_gql_dataset_example(example) for example in examples],
288
+ examples=[DatasetExample(id=example.id, db_record=example) for example in examples],
285
289
  )
286
290
 
287
291
  @strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer]) # type: ignore
@@ -342,7 +346,7 @@ class DatasetSplitMutationMixin:
342
346
 
343
347
  return RemoveDatasetExamplesFromDatasetSplitsMutationPayload(
344
348
  query=Query(),
345
- examples=[to_gql_dataset_example(example) for example in examples],
349
+ examples=[DatasetExample(id=example.id, db_record=example) for example in examples],
346
350
  )
347
351
 
348
352
  @strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked]) # type: ignore
@@ -410,9 +414,9 @@ class DatasetSplitMutationMixin:
410
414
  ).all()
411
415
 
412
416
  return DatasetSplitMutationPayloadWithExamples(
413
- dataset_split=to_gql_dataset_split(dataset_split_orm),
417
+ dataset_split=DatasetSplit(id=dataset_split_orm.id, db_record=dataset_split_orm),
414
418
  query=Query(),
415
- examples=[to_gql_dataset_example(example) for example in examples],
419
+ examples=[DatasetExample(id=example.id, db_record=example) for example in examples],
416
420
  )
417
421
 
418
422
 
@@ -16,7 +16,7 @@ from phoenix.server.api.auth import IsNotReadOnly, IsNotViewer
16
16
  from phoenix.server.api.context import Context
17
17
  from phoenix.server.api.exceptions import BadRequest, Conflict, NotFound
18
18
  from phoenix.server.api.queries import Query
19
- from phoenix.server.api.types.GenerativeModel import GenerativeModel, to_gql_generative_model
19
+ from phoenix.server.api.types.GenerativeModel import GenerativeModel
20
20
  from phoenix.server.api.types.node import from_global_id_with_expected_type
21
21
  from phoenix.server.api.types.TokenPrice import TokenKind
22
22
 
@@ -110,7 +110,7 @@ class ModelMutationMixin:
110
110
  raise Conflict(f"Model with name '{input.name}' already exists")
111
111
 
112
112
  return CreateModelMutationPayload(
113
- model=to_gql_generative_model(model),
113
+ model=GenerativeModel(id=model.id, db_record=model),
114
114
  query=Query(),
115
115
  )
116
116
 
@@ -163,7 +163,7 @@ class ModelMutationMixin:
163
163
  await session.refresh(model)
164
164
 
165
165
  return UpdateModelMutationPayload(
166
- model=to_gql_generative_model(model),
166
+ model=GenerativeModel(id=model.id, db_record=model),
167
167
  query=Query(),
168
168
  )
169
169
 
@@ -192,7 +192,7 @@ class ModelMutationMixin:
192
192
  await session.rollback()
193
193
  raise BadRequest("Cannot delete built-in model")
194
194
  return DeleteModelMutationPayload(
195
- model=to_gql_generative_model(model),
195
+ model=GenerativeModel(id=model.id, db_record=model),
196
196
  query=Query(),
197
197
  )
198
198
 
@@ -19,10 +19,7 @@ from phoenix.server.api.input_types.UpdateAnnotationInput import UpdateAnnotatio
19
19
  from phoenix.server.api.queries import Query
20
20
  from phoenix.server.api.types.AnnotationSource import AnnotationSource
21
21
  from phoenix.server.api.types.node import from_global_id_with_expected_type
22
- from phoenix.server.api.types.ProjectSessionAnnotation import (
23
- ProjectSessionAnnotation,
24
- to_gql_project_session_annotation,
25
- )
22
+ from phoenix.server.api.types.ProjectSessionAnnotation import ProjectSessionAnnotation
26
23
  from phoenix.server.bearer_auth import PhoenixUser
27
24
  from phoenix.server.dml_event import (
28
25
  ProjectSessionAnnotationDeleteEvent,
@@ -81,7 +78,7 @@ class ProjectSessionAnnotationMutationMixin:
81
78
  info.context.event_queue.put(ProjectSessionAnnotationInsertEvent((anno.id,)))
82
79
 
83
80
  return ProjectSessionAnnotationMutationPayload(
84
- project_session_annotation=to_gql_project_session_annotation(anno),
81
+ project_session_annotation=ProjectSessionAnnotation(id=anno.id, db_record=anno),
85
82
  query=Query(),
86
83
  )
87
84
 
@@ -122,7 +119,7 @@ class ProjectSessionAnnotationMutationMixin:
122
119
 
123
120
  info.context.event_queue.put(ProjectSessionAnnotationInsertEvent((anno.id,)))
124
121
  return ProjectSessionAnnotationMutationPayload(
125
- project_session_annotation=to_gql_project_session_annotation(anno),
122
+ project_session_annotation=ProjectSessionAnnotation(id=anno.id, db_record=anno),
126
123
  query=Query(),
127
124
  )
128
125
 
@@ -154,7 +151,7 @@ class ProjectSessionAnnotationMutationMixin:
154
151
 
155
152
  await session.delete(anno)
156
153
 
157
- deleted_gql_annotation = to_gql_project_session_annotation(anno)
154
+ deleted_gql_annotation = ProjectSessionAnnotation(id=anno.id, db_record=anno)
158
155
  info.context.event_queue.put(ProjectSessionAnnotationDeleteEvent((id_,)))
159
156
  return ProjectSessionAnnotationMutationPayload(
160
157
  project_session_annotation=deleted_gql_annotation, query=Query()
@@ -16,7 +16,7 @@ from phoenix.server.api.exceptions import Conflict, NotFound
16
16
  from phoenix.server.api.queries import Query
17
17
  from phoenix.server.api.types.node import from_global_id_with_expected_type
18
18
  from phoenix.server.api.types.Prompt import Prompt
19
- from phoenix.server.api.types.PromptLabel import PromptLabel, to_gql_prompt_label
19
+ from phoenix.server.api.types.PromptLabel import PromptLabel
20
20
 
21
21
 
22
22
  @strawberry.input
@@ -85,7 +85,7 @@ class PromptLabelMutationMixin:
85
85
  raise Conflict(f"A prompt label named '{input.name}' already exists.")
86
86
 
87
87
  return PromptLabelMutationPayload(
88
- prompt_labels=[to_gql_prompt_label(label_orm)],
88
+ prompt_labels=[PromptLabel(id=label_orm.id, db_record=label_orm)],
89
89
  query=Query(),
90
90
  )
91
91
 
@@ -113,7 +113,7 @@ class PromptLabelMutationMixin:
113
113
  raise Conflict("Error patching PromptLabel. Possibly a name conflict?")
114
114
 
115
115
  return PromptLabelMutationPayload(
116
- prompt_labels=[to_gql_prompt_label(label_orm)],
116
+ prompt_labels=[PromptLabel(id=label_orm.id, db_record=label_orm)],
117
117
  query=Query(),
118
118
  )
119
119