arize-phoenix 12.8.0-py3-none-any.whl → 12.9.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic. Click here for more details.

Files changed (70)
  1. {arize_phoenix-12.8.0.dist-info → arize_phoenix-12.9.0.dist-info}/METADATA +3 -1
  2. {arize_phoenix-12.8.0.dist-info → arize_phoenix-12.9.0.dist-info}/RECORD +70 -67
  3. phoenix/config.py +131 -9
  4. phoenix/db/engines.py +127 -14
  5. phoenix/db/iam_auth.py +64 -0
  6. phoenix/db/pg_config.py +10 -0
  7. phoenix/server/api/context.py +23 -0
  8. phoenix/server/api/dataloaders/__init__.py +6 -0
  9. phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +0 -2
  10. phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
  11. phoenix/server/api/dataloaders/span_costs.py +3 -9
  12. phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
  13. phoenix/server/api/helpers/playground_clients.py +3 -3
  14. phoenix/server/api/input_types/PromptVersionInput.py +47 -1
  15. phoenix/server/api/mutations/annotation_config_mutations.py +2 -2
  16. phoenix/server/api/mutations/api_key_mutations.py +2 -15
  17. phoenix/server/api/mutations/chat_mutations.py +3 -2
  18. phoenix/server/api/mutations/dataset_label_mutations.py +12 -6
  19. phoenix/server/api/mutations/dataset_mutations.py +8 -8
  20. phoenix/server/api/mutations/dataset_split_mutations.py +13 -9
  21. phoenix/server/api/mutations/model_mutations.py +4 -4
  22. phoenix/server/api/mutations/project_session_annotations_mutations.py +4 -7
  23. phoenix/server/api/mutations/prompt_label_mutations.py +3 -3
  24. phoenix/server/api/mutations/prompt_mutations.py +24 -117
  25. phoenix/server/api/mutations/prompt_version_tag_mutations.py +8 -5
  26. phoenix/server/api/mutations/span_annotations_mutations.py +10 -5
  27. phoenix/server/api/mutations/trace_annotations_mutations.py +9 -4
  28. phoenix/server/api/mutations/user_mutations.py +4 -4
  29. phoenix/server/api/queries.py +65 -210
  30. phoenix/server/api/subscriptions.py +4 -4
  31. phoenix/server/api/types/Annotation.py +90 -23
  32. phoenix/server/api/types/ApiKey.py +13 -17
  33. phoenix/server/api/types/Dataset.py +88 -48
  34. phoenix/server/api/types/DatasetExample.py +34 -30
  35. phoenix/server/api/types/DatasetLabel.py +47 -13
  36. phoenix/server/api/types/DatasetSplit.py +87 -21
  37. phoenix/server/api/types/DatasetVersion.py +49 -4
  38. phoenix/server/api/types/DocumentAnnotation.py +182 -62
  39. phoenix/server/api/types/Experiment.py +146 -55
  40. phoenix/server/api/types/ExperimentRepeatedRunGroup.py +10 -1
  41. phoenix/server/api/types/ExperimentRun.py +118 -61
  42. phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
  43. phoenix/server/api/types/GenerativeModel.py +95 -42
  44. phoenix/server/api/types/ModelInterface.py +7 -2
  45. phoenix/server/api/types/PlaygroundModel.py +12 -2
  46. phoenix/server/api/types/Project.py +70 -75
  47. phoenix/server/api/types/ProjectSession.py +69 -37
  48. phoenix/server/api/types/ProjectSessionAnnotation.py +166 -47
  49. phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
  50. phoenix/server/api/types/Prompt.py +82 -44
  51. phoenix/server/api/types/PromptLabel.py +47 -13
  52. phoenix/server/api/types/PromptVersion.py +11 -8
  53. phoenix/server/api/types/PromptVersionTag.py +65 -25
  54. phoenix/server/api/types/Span.py +116 -115
  55. phoenix/server/api/types/SpanAnnotation.py +189 -42
  56. phoenix/server/api/types/SystemApiKey.py +65 -1
  57. phoenix/server/api/types/Trace.py +45 -44
  58. phoenix/server/api/types/TraceAnnotation.py +144 -48
  59. phoenix/server/api/types/User.py +103 -33
  60. phoenix/server/api/types/UserApiKey.py +73 -26
  61. phoenix/server/app.py +29 -0
  62. phoenix/server/static/.vite/manifest.json +9 -9
  63. phoenix/server/static/assets/{components-Bem6_7MW.js → components-v927s3NF.js} +427 -397
  64. phoenix/server/static/assets/{index-NdiXbuNL.js → index-DrD9eSrN.js} +9 -5
  65. phoenix/server/static/assets/{pages-CEJgMVKU.js → pages-GVybXa_W.js} +489 -486
  66. phoenix/version.py +1 -1
  67. {arize_phoenix-12.8.0.dist-info → arize_phoenix-12.9.0.dist-info}/WHEEL +0 -0
  68. {arize_phoenix-12.8.0.dist-info → arize_phoenix-12.9.0.dist-info}/entry_points.txt +0 -0
  69. {arize_phoenix-12.8.0.dist-info → arize_phoenix-12.9.0.dist-info}/licenses/IP_NOTICE +0 -0
  70. {arize_phoenix-12.8.0.dist-info → arize_phoenix-12.9.0.dist-info}/licenses/LICENSE +0 -0
@@ -16,8 +16,8 @@ from phoenix.server.api.auth import IsLocked, IsNotReadOnly, IsNotViewer
16
16
  from phoenix.server.api.context import Context
17
17
  from phoenix.server.api.exceptions import BadRequest, Conflict, NotFound
18
18
  from phoenix.server.api.queries import Query
19
- from phoenix.server.api.types.Dataset import Dataset, to_gql_dataset
20
- from phoenix.server.api.types.DatasetLabel import DatasetLabel, to_gql_dataset_label
19
+ from phoenix.server.api.types.Dataset import Dataset
20
+ from phoenix.server.api.types.DatasetLabel import DatasetLabel
21
21
  from phoenix.server.api.types.node import from_global_id_with_expected_type
22
22
 
23
23
 
@@ -111,9 +111,12 @@ class DatasetLabelMutationMixin:
111
111
  await session.commit()
112
112
 
113
113
  return CreateDatasetLabelMutationPayload(
114
- dataset_label=to_gql_dataset_label(dataset_label_orm),
114
+ dataset_label=DatasetLabel(id=dataset_label_orm.id, db_record=dataset_label_orm),
115
115
  datasets=[
116
- to_gql_dataset(datasets_by_id[dataset_rowid]) for dataset_rowid in dataset_rowids
116
+ Dataset(
117
+ id=datasets_by_id[dataset_rowid].id, db_record=datasets_by_id[dataset_rowid]
118
+ )
119
+ for dataset_rowid in dataset_rowids
117
120
  ],
118
121
  )
119
122
 
@@ -145,7 +148,10 @@ class DatasetLabelMutationMixin:
145
148
  }
146
149
  return DeleteDatasetLabelsMutationPayload(
147
150
  dataset_labels=[
148
- to_gql_dataset_label(deleted_dataset_labels_by_id[dataset_label_row_id])
151
+ DatasetLabel(
152
+ id=deleted_dataset_labels_by_id[dataset_label_row_id].id,
153
+ db_record=deleted_dataset_labels_by_id[dataset_label_row_id],
154
+ )
149
155
  for dataset_label_row_id in dataset_label_row_ids
150
156
  ]
151
157
  )
@@ -232,6 +238,6 @@ class DatasetLabelMutationMixin:
232
238
  )
233
239
 
234
240
  return SetDatasetLabelsMutationPayload(
235
- dataset=to_gql_dataset(dataset),
241
+ dataset=Dataset(id=dataset.id, db_record=dataset),
236
242
  query=Query(),
237
243
  )
@@ -35,7 +35,7 @@ from phoenix.server.api.input_types.PatchDatasetExamplesInput import (
35
35
  PatchDatasetExamplesInput,
36
36
  )
37
37
  from phoenix.server.api.input_types.PatchDatasetInput import PatchDatasetInput
38
- from phoenix.server.api.types.Dataset import Dataset, to_gql_dataset
38
+ from phoenix.server.api.types.Dataset import Dataset
39
39
  from phoenix.server.api.types.DatasetExample import DatasetExample
40
40
  from phoenix.server.api.types.node import from_global_id_with_expected_type
41
41
  from phoenix.server.api.types.Span import Span
@@ -72,7 +72,7 @@ class DatasetMutationMixin:
72
72
  )
73
73
  assert dataset is not None
74
74
  info.context.event_queue.put(DatasetInsertEvent((dataset.id,)))
75
- return DatasetMutationPayload(dataset=to_gql_dataset(dataset))
75
+ return DatasetMutationPayload(dataset=Dataset(id=dataset.id, db_record=dataset))
76
76
 
77
77
  @strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked]) # type: ignore
78
78
  async def patch_dataset(
@@ -101,7 +101,7 @@ class DatasetMutationMixin:
101
101
  )
102
102
  assert dataset is not None
103
103
  info.context.event_queue.put(DatasetInsertEvent((dataset.id,)))
104
- return DatasetMutationPayload(dataset=to_gql_dataset(dataset))
104
+ return DatasetMutationPayload(dataset=Dataset(id=dataset.id, db_record=dataset))
105
105
 
106
106
  @strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked]) # type: ignore
107
107
  async def add_spans_to_dataset(
@@ -221,7 +221,7 @@ class DatasetMutationMixin:
221
221
  ],
222
222
  )
223
223
  info.context.event_queue.put(DatasetInsertEvent((dataset.id,)))
224
- return DatasetMutationPayload(dataset=to_gql_dataset(dataset))
224
+ return DatasetMutationPayload(dataset=Dataset(id=dataset.id, db_record=dataset))
225
225
 
226
226
  @strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked]) # type: ignore
227
227
  async def add_examples_to_dataset(
@@ -348,7 +348,7 @@ class DatasetMutationMixin:
348
348
  dataset_example_revisions,
349
349
  )
350
350
  info.context.event_queue.put(DatasetInsertEvent((dataset.id,)))
351
- return DatasetMutationPayload(dataset=to_gql_dataset(dataset))
351
+ return DatasetMutationPayload(dataset=Dataset(id=dataset.id, db_record=dataset))
352
352
 
353
353
  @strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer]) # type: ignore
354
354
  async def delete_dataset(
@@ -379,7 +379,7 @@ class DatasetMutationMixin:
379
379
  return_exceptions=True,
380
380
  )
381
381
  info.context.event_queue.put(DatasetDeleteEvent((dataset.id,)))
382
- return DatasetMutationPayload(dataset=to_gql_dataset(dataset))
382
+ return DatasetMutationPayload(dataset=Dataset(id=dataset.id, db_record=dataset))
383
383
 
384
384
  @strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked]) # type: ignore
385
385
  async def patch_dataset_examples(
@@ -472,7 +472,7 @@ class DatasetMutationMixin:
472
472
  ],
473
473
  )
474
474
  info.context.event_queue.put(DatasetInsertEvent((dataset.id,)))
475
- return DatasetMutationPayload(dataset=to_gql_dataset(dataset))
475
+ return DatasetMutationPayload(dataset=Dataset(id=dataset.id, db_record=dataset))
476
476
 
477
477
  @strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked]) # type: ignore
478
478
  async def delete_dataset_examples(
@@ -556,7 +556,7 @@ class DatasetMutationMixin:
556
556
  ],
557
557
  )
558
558
  info.context.event_queue.put(DatasetInsertEvent((dataset.id,)))
559
- return DatasetMutationPayload(dataset=to_gql_dataset(dataset))
559
+ return DatasetMutationPayload(dataset=Dataset(id=dataset.id, db_record=dataset))
560
560
 
561
561
 
562
562
  def _span_attribute(semconv: str) -> Any:
@@ -15,8 +15,8 @@ from phoenix.server.api.context import Context
15
15
  from phoenix.server.api.exceptions import BadRequest, Conflict, NotFound
16
16
  from phoenix.server.api.helpers.playground_users import get_user
17
17
  from phoenix.server.api.queries import Query
18
- from phoenix.server.api.types.DatasetExample import DatasetExample, to_gql_dataset_example
19
- from phoenix.server.api.types.DatasetSplit import DatasetSplit, to_gql_dataset_split
18
+ from phoenix.server.api.types.DatasetExample import DatasetExample
19
+ from phoenix.server.api.types.DatasetSplit import DatasetSplit
20
20
  from phoenix.server.api.types.node import from_global_id_with_expected_type
21
21
 
22
22
 
@@ -116,7 +116,8 @@ class DatasetSplitMutationMixin:
116
116
  except (PostgreSQLIntegrityError, SQLiteIntegrityError):
117
117
  raise Conflict(f"A dataset split named '{input.name}' already exists.")
118
118
  return DatasetSplitMutationPayload(
119
- dataset_split=to_gql_dataset_split(dataset_split_orm), query=Query()
119
+ dataset_split=DatasetSplit(id=dataset_split_orm.id, db_record=dataset_split_orm),
120
+ query=Query(),
120
121
  )
121
122
 
122
123
  @strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked]) # type: ignore
@@ -141,7 +142,7 @@ class DatasetSplitMutationMixin:
141
142
  if isinstance(input.metadata, dict):
142
143
  dataset_split_orm.metadata_ = input.metadata
143
144
 
144
- gql_dataset_split = to_gql_dataset_split(dataset_split_orm)
145
+ gql_dataset_split = DatasetSplit(id=dataset_split_orm.id, db_record=dataset_split_orm)
145
146
  try:
146
147
  await session.commit()
147
148
  except (PostgreSQLIntegrityError, SQLiteIntegrityError):
@@ -185,7 +186,10 @@ class DatasetSplitMutationMixin:
185
186
 
186
187
  return DeleteDatasetSplitsMutationPayload(
187
188
  dataset_splits=[
188
- to_gql_dataset_split(deleted_splits_by_id[dataset_split_rowid])
189
+ DatasetSplit(
190
+ id=deleted_splits_by_id[dataset_split_rowid].id,
191
+ db_record=deleted_splits_by_id[dataset_split_rowid],
192
+ )
189
193
  for dataset_split_rowid in dataset_split_rowids
190
194
  ],
191
195
  query=Query(),
@@ -281,7 +285,7 @@ class DatasetSplitMutationMixin:
281
285
  ).all()
282
286
  return AddDatasetExamplesToDatasetSplitsMutationPayload(
283
287
  query=Query(),
284
- examples=[to_gql_dataset_example(example) for example in examples],
288
+ examples=[DatasetExample(id=example.id, db_record=example) for example in examples],
285
289
  )
286
290
 
287
291
  @strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer]) # type: ignore
@@ -342,7 +346,7 @@ class DatasetSplitMutationMixin:
342
346
 
343
347
  return RemoveDatasetExamplesFromDatasetSplitsMutationPayload(
344
348
  query=Query(),
345
- examples=[to_gql_dataset_example(example) for example in examples],
349
+ examples=[DatasetExample(id=example.id, db_record=example) for example in examples],
346
350
  )
347
351
 
348
352
  @strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked]) # type: ignore
@@ -410,9 +414,9 @@ class DatasetSplitMutationMixin:
410
414
  ).all()
411
415
 
412
416
  return DatasetSplitMutationPayloadWithExamples(
413
- dataset_split=to_gql_dataset_split(dataset_split_orm),
417
+ dataset_split=DatasetSplit(id=dataset_split_orm.id, db_record=dataset_split_orm),
414
418
  query=Query(),
415
- examples=[to_gql_dataset_example(example) for example in examples],
419
+ examples=[DatasetExample(id=example.id, db_record=example) for example in examples],
416
420
  )
417
421
 
418
422
 
@@ -16,7 +16,7 @@ from phoenix.server.api.auth import IsNotReadOnly, IsNotViewer
16
16
  from phoenix.server.api.context import Context
17
17
  from phoenix.server.api.exceptions import BadRequest, Conflict, NotFound
18
18
  from phoenix.server.api.queries import Query
19
- from phoenix.server.api.types.GenerativeModel import GenerativeModel, to_gql_generative_model
19
+ from phoenix.server.api.types.GenerativeModel import GenerativeModel
20
20
  from phoenix.server.api.types.node import from_global_id_with_expected_type
21
21
  from phoenix.server.api.types.TokenPrice import TokenKind
22
22
 
@@ -110,7 +110,7 @@ class ModelMutationMixin:
110
110
  raise Conflict(f"Model with name '{input.name}' already exists")
111
111
 
112
112
  return CreateModelMutationPayload(
113
- model=to_gql_generative_model(model),
113
+ model=GenerativeModel(id=model.id, db_record=model),
114
114
  query=Query(),
115
115
  )
116
116
 
@@ -163,7 +163,7 @@ class ModelMutationMixin:
163
163
  await session.refresh(model)
164
164
 
165
165
  return UpdateModelMutationPayload(
166
- model=to_gql_generative_model(model),
166
+ model=GenerativeModel(id=model.id, db_record=model),
167
167
  query=Query(),
168
168
  )
169
169
 
@@ -192,7 +192,7 @@ class ModelMutationMixin:
192
192
  await session.rollback()
193
193
  raise BadRequest("Cannot delete built-in model")
194
194
  return DeleteModelMutationPayload(
195
- model=to_gql_generative_model(model),
195
+ model=GenerativeModel(id=model.id, db_record=model),
196
196
  query=Query(),
197
197
  )
198
198
 
@@ -19,10 +19,7 @@ from phoenix.server.api.input_types.UpdateAnnotationInput import UpdateAnnotatio
19
19
  from phoenix.server.api.queries import Query
20
20
  from phoenix.server.api.types.AnnotationSource import AnnotationSource
21
21
  from phoenix.server.api.types.node import from_global_id_with_expected_type
22
- from phoenix.server.api.types.ProjectSessionAnnotation import (
23
- ProjectSessionAnnotation,
24
- to_gql_project_session_annotation,
25
- )
22
+ from phoenix.server.api.types.ProjectSessionAnnotation import ProjectSessionAnnotation
26
23
  from phoenix.server.bearer_auth import PhoenixUser
27
24
  from phoenix.server.dml_event import (
28
25
  ProjectSessionAnnotationDeleteEvent,
@@ -81,7 +78,7 @@ class ProjectSessionAnnotationMutationMixin:
81
78
  info.context.event_queue.put(ProjectSessionAnnotationInsertEvent((anno.id,)))
82
79
 
83
80
  return ProjectSessionAnnotationMutationPayload(
84
- project_session_annotation=to_gql_project_session_annotation(anno),
81
+ project_session_annotation=ProjectSessionAnnotation(id=anno.id, db_record=anno),
85
82
  query=Query(),
86
83
  )
87
84
 
@@ -122,7 +119,7 @@ class ProjectSessionAnnotationMutationMixin:
122
119
 
123
120
  info.context.event_queue.put(ProjectSessionAnnotationInsertEvent((anno.id,)))
124
121
  return ProjectSessionAnnotationMutationPayload(
125
- project_session_annotation=to_gql_project_session_annotation(anno),
122
+ project_session_annotation=ProjectSessionAnnotation(id=anno.id, db_record=anno),
126
123
  query=Query(),
127
124
  )
128
125
 
@@ -154,7 +151,7 @@ class ProjectSessionAnnotationMutationMixin:
154
151
 
155
152
  await session.delete(anno)
156
153
 
157
- deleted_gql_annotation = to_gql_project_session_annotation(anno)
154
+ deleted_gql_annotation = ProjectSessionAnnotation(id=anno.id, db_record=anno)
158
155
  info.context.event_queue.put(ProjectSessionAnnotationDeleteEvent((id_,)))
159
156
  return ProjectSessionAnnotationMutationPayload(
160
157
  project_session_annotation=deleted_gql_annotation, query=Query()
@@ -16,7 +16,7 @@ from phoenix.server.api.exceptions import Conflict, NotFound
16
16
  from phoenix.server.api.queries import Query
17
17
  from phoenix.server.api.types.node import from_global_id_with_expected_type
18
18
  from phoenix.server.api.types.Prompt import Prompt
19
- from phoenix.server.api.types.PromptLabel import PromptLabel, to_gql_prompt_label
19
+ from phoenix.server.api.types.PromptLabel import PromptLabel
20
20
 
21
21
 
22
22
  @strawberry.input
@@ -85,7 +85,7 @@ class PromptLabelMutationMixin:
85
85
  raise Conflict(f"A prompt label named '{input.name}' already exists.")
86
86
 
87
87
  return PromptLabelMutationPayload(
88
- prompt_labels=[to_gql_prompt_label(label_orm)],
88
+ prompt_labels=[PromptLabel(id=label_orm.id, db_record=label_orm)],
89
89
  query=Query(),
90
90
  )
91
91
 
@@ -113,7 +113,7 @@ class PromptLabelMutationMixin:
113
113
  raise Conflict("Error patching PromptLabel. Possibly a name conflict?")
114
114
 
115
115
  return PromptLabelMutationPayload(
116
- prompt_labels=[to_gql_prompt_label(label_orm)],
116
+ prompt_labels=[PromptLabel(id=label_orm.id, db_record=label_orm)],
117
117
  query=Query(),
118
118
  )
119
119
 
@@ -1,4 +1,4 @@
1
- from typing import Any, Optional, Union, cast
1
+ from typing import Optional
2
2
 
3
3
  import strawberry
4
4
  from fastapi import Request
@@ -12,18 +12,11 @@ from strawberry.types import Info
12
12
 
13
13
  from phoenix.db import models
14
14
  from phoenix.db.types.identifier import Identifier as IdentifierModel
15
- from phoenix.db.types.model_provider import ModelProvider
16
15
  from phoenix.server.api.auth import IsLocked, IsNotReadOnly, IsNotViewer
17
16
  from phoenix.server.api.context import Context
18
17
  from phoenix.server.api.exceptions import BadRequest, Conflict, NotFound
19
- from phoenix.server.api.helpers.prompts.models import (
20
- normalize_response_format,
21
- normalize_tools,
22
- validate_invocation_parameters,
23
- )
24
18
  from phoenix.server.api.input_types.PromptVersionInput import (
25
19
  ChatPromptVersionInput,
26
- to_pydantic_prompt_chat_template_v1,
27
20
  )
28
21
  from phoenix.server.api.mutations.prompt_version_tag_mutations import (
29
22
  SetPromptVersionTagInput,
@@ -32,7 +25,7 @@ from phoenix.server.api.mutations.prompt_version_tag_mutations import (
32
25
  from phoenix.server.api.queries import Query
33
26
  from phoenix.server.api.types.Identifier import Identifier
34
27
  from phoenix.server.api.types.node import from_global_id_with_expected_type
35
- from phoenix.server.api.types.Prompt import Prompt, to_gql_prompt_from_orm
28
+ from phoenix.server.api.types.Prompt import Prompt
36
29
  from phoenix.server.bearer_auth import PhoenixUser
37
30
 
38
31
 
@@ -84,63 +77,23 @@ class PromptMutationMixin:
84
77
  if "user" in request.scope:
85
78
  assert isinstance(user := request.user, PhoenixUser)
86
79
  user_id = int(user.identity)
87
-
88
- input_prompt_version = input.prompt_version
89
- tool_definitions = [tool.definition for tool in input_prompt_version.tools]
90
- tool_choice = cast(
91
- Optional[Union[str, dict[str, Any]]],
92
- cast(dict[str, Any], input.prompt_version.invocation_parameters).pop(
93
- "tool_choice", None
94
- ),
95
- )
96
- model_provider = ModelProvider(input_prompt_version.model_provider)
97
80
  try:
98
- tools = (
99
- normalize_tools(tool_definitions, model_provider, tool_choice)
100
- if tool_definitions
101
- else None
102
- )
103
- template = to_pydantic_prompt_chat_template_v1(input_prompt_version.template)
104
- response_format = (
105
- normalize_response_format(
106
- input_prompt_version.response_format.definition,
107
- model_provider,
108
- )
109
- if input_prompt_version.response_format
110
- else None
111
- )
112
- invocation_parameters = validate_invocation_parameters(
113
- input_prompt_version.invocation_parameters,
114
- model_provider,
115
- )
81
+ prompt_version = input.prompt_version.to_orm_prompt_version(user_id)
116
82
  except ValidationError as error:
117
83
  raise BadRequest(str(error))
118
-
84
+ name = IdentifierModel.model_validate(str(input.name))
85
+ prompt = models.Prompt(
86
+ name=name,
87
+ description=input.description,
88
+ prompt_versions=[prompt_version],
89
+ )
119
90
  async with info.context.db() as session:
120
- prompt_version = models.PromptVersion(
121
- description=input_prompt_version.description,
122
- user_id=user_id,
123
- template_type="CHAT",
124
- template_format=input_prompt_version.template_format,
125
- template=template,
126
- invocation_parameters=invocation_parameters,
127
- tools=tools,
128
- response_format=response_format,
129
- model_provider=input_prompt_version.model_provider,
130
- model_name=input_prompt_version.model_name,
131
- )
132
- name = IdentifierModel.model_validate(str(input.name))
133
- prompt = models.Prompt(
134
- name=name,
135
- description=input.description,
136
- prompt_versions=[prompt_version],
137
- )
138
91
  session.add(prompt)
139
92
  try:
140
93
  await session.commit()
141
94
  except (PostgreSQLIntegrityError, SQLiteIntegrityError):
142
95
  raise Conflict(f"A prompt named '{input.name}' already exists")
143
- return to_gql_prompt_from_orm(prompt)
96
+ return Prompt(id=prompt.id, db_record=prompt)
144
97
 
145
98
  @strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked]) # type: ignore
146
99
  async def create_chat_prompt_version(
@@ -153,72 +106,26 @@ class PromptMutationMixin:
153
106
  if "user" in request.scope:
154
107
  assert isinstance(user := request.user, PhoenixUser)
155
108
  user_id = int(user.identity)
156
-
157
- input_prompt_version = input.prompt_version
158
- tool_definitions = [tool.definition for tool in input.prompt_version.tools]
159
- tool_choice = cast(
160
- Optional[Union[str, dict[str, Any]]],
161
- cast(dict[str, Any], input.prompt_version.invocation_parameters).pop(
162
- "tool_choice", None
163
- ),
164
- )
165
- model_provider = ModelProvider(input_prompt_version.model_provider)
166
109
  try:
167
- tools = (
168
- normalize_tools(tool_definitions, model_provider, tool_choice)
169
- if tool_definitions
170
- else None
171
- )
172
- template = to_pydantic_prompt_chat_template_v1(input_prompt_version.template)
173
- response_format = (
174
- normalize_response_format(
175
- input_prompt_version.response_format.definition,
176
- model_provider,
177
- )
178
- if input_prompt_version.response_format
179
- else None
180
- )
181
- invocation_parameters = validate_invocation_parameters(
182
- input_prompt_version.invocation_parameters,
183
- model_provider,
184
- )
110
+ prompt_version = input.prompt_version.to_orm_prompt_version(user_id)
185
111
  except ValidationError as error:
186
112
  raise BadRequest(str(error))
187
-
188
113
  prompt_id = from_global_id_with_expected_type(
189
114
  global_id=input.prompt_id, expected_type_name=Prompt.__name__
190
115
  )
116
+ prompt_version.prompt_id = prompt_id
191
117
  async with info.context.db() as session:
192
- prompt = await session.get(models.Prompt, prompt_id)
193
- if not prompt:
194
- raise NotFound(f"Prompt with ID '{input.prompt_id}' not found")
195
-
196
- prompt_version = models.PromptVersion(
197
- prompt_id=prompt_id,
198
- description=input.prompt_version.description,
199
- user_id=user_id,
200
- template_type="CHAT",
201
- template_format=input.prompt_version.template_format,
202
- template=template,
203
- invocation_parameters=invocation_parameters,
204
- tools=tools,
205
- response_format=response_format,
206
- model_provider=input.prompt_version.model_provider,
207
- model_name=input.prompt_version.model_name,
208
- )
209
118
  session.add(prompt_version)
210
-
211
- # ensure prompt_version is flushed to the database before creating tags against the
212
- # prompt_version id
213
- await session.flush()
214
-
215
- if input.tags:
216
- for tag in input.tags:
217
- await upsert_prompt_version_tag(
218
- session, prompt_id, prompt_version.id, tag.name, tag.description
219
- )
220
-
221
- return to_gql_prompt_from_orm(prompt)
119
+ try:
120
+ await session.flush()
121
+ except (PostgreSQLIntegrityError, SQLiteIntegrityError):
122
+ raise NotFound(f"Prompt with ID '{input.prompt_id}' not found")
123
+ if input.tags:
124
+ for tag in input.tags:
125
+ await upsert_prompt_version_tag(
126
+ session, prompt_id, prompt_version.id, tag.name, tag.description
127
+ )
128
+ return Prompt(id=prompt_id)
222
129
 
223
130
  @strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer]) # type: ignore
224
131
  async def delete_prompt(
@@ -288,7 +195,7 @@ class PromptMutationMixin:
288
195
  await session.commit()
289
196
  except (PostgreSQLIntegrityError, SQLiteIntegrityError):
290
197
  raise Conflict(f"A prompt named '{input.name}' already exists")
291
- return to_gql_prompt_from_orm(new_prompt)
198
+ return Prompt(id=new_prompt.id, db_record=new_prompt)
292
199
 
293
200
  @strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked]) # type: ignore
294
201
  async def patch_prompt(self, info: Info[Context, None], input: PatchPromptInput) -> Prompt:
@@ -310,4 +217,4 @@ class PromptMutationMixin:
310
217
  if prompt is None:
311
218
  raise NotFound(f"Prompt with ID '{input.prompt_id}' not found")
312
219
 
313
- return to_gql_prompt_from_orm(prompt)
220
+ return Prompt(id=prompt.id, db_record=prompt)
@@ -16,9 +16,9 @@ from phoenix.server.api.exceptions import BadRequest, Conflict, NotFound
16
16
  from phoenix.server.api.queries import Query
17
17
  from phoenix.server.api.types.Identifier import Identifier
18
18
  from phoenix.server.api.types.node import from_global_id_with_expected_type
19
- from phoenix.server.api.types.Prompt import Prompt, to_gql_prompt_from_orm
19
+ from phoenix.server.api.types.Prompt import Prompt
20
20
  from phoenix.server.api.types.PromptVersion import PromptVersion
21
- from phoenix.server.api.types.PromptVersionTag import PromptVersionTag, to_gql_prompt_version_tag
21
+ from phoenix.server.api.types.PromptVersionTag import PromptVersionTag
22
22
 
23
23
 
24
24
  @strawberry.input
@@ -75,7 +75,9 @@ class PromptVersionTagMutationMixin:
75
75
  await session.delete(prompt_version_tag)
76
76
  await session.commit()
77
77
  return PromptVersionTagMutationPayload(
78
- prompt_version_tag=None, query=Query(), prompt=to_gql_prompt_from_orm(prompt)
78
+ prompt_version_tag=None,
79
+ query=Query(),
80
+ prompt=Prompt(id=prompt.id, db_record=prompt),
79
81
  )
80
82
 
81
83
  @strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked]) # type: ignore
@@ -111,9 +113,10 @@ class PromptVersionTagMutationMixin:
111
113
  except (PostgreSQLIntegrityError, SQLiteIntegrityError):
112
114
  raise Conflict("Failed to update PromptVersionTag.")
113
115
 
114
- version_tag = to_gql_prompt_version_tag(updated_tag)
115
116
  return PromptVersionTagMutationPayload(
116
- prompt_version_tag=version_tag, prompt=to_gql_prompt_from_orm(prompt), query=Query()
117
+ prompt_version_tag=PromptVersionTag(id=updated_tag.id, db_record=updated_tag),
118
+ prompt=Prompt(id=prompt.id, db_record=prompt),
119
+ query=Query(),
117
120
  )
118
121
 
119
122
 
@@ -21,7 +21,7 @@ from phoenix.server.api.queries import Query
21
21
  from phoenix.server.api.types.AnnotationSource import AnnotationSource
22
22
  from phoenix.server.api.types.AnnotatorKind import AnnotatorKind
23
23
  from phoenix.server.api.types.node import from_global_id_with_expected_type
24
- from phoenix.server.api.types.SpanAnnotation import SpanAnnotation, to_gql_span_annotation
24
+ from phoenix.server.api.types.SpanAnnotation import SpanAnnotation
25
25
  from phoenix.server.bearer_auth import PhoenixUser
26
26
  from phoenix.server.dml_event import SpanAnnotationDeleteEvent, SpanAnnotationInsertEvent
27
27
 
@@ -138,7 +138,7 @@ class SpanAnnotationMutationMixin:
138
138
 
139
139
  # Convert the fully loaded annotations to GQL types
140
140
  returned_annotations = [
141
- to_gql_span_annotation(anno) for anno in ordered_final_annotations
141
+ SpanAnnotation(id=anno.id, db_record=anno) for anno in ordered_final_annotations
142
142
  ]
143
143
 
144
144
  await session.commit()
@@ -184,7 +184,9 @@ class SpanAnnotationMutationMixin:
184
184
  processed_annotation = result.one()
185
185
 
186
186
  info.context.event_queue.put(SpanAnnotationInsertEvent((processed_annotation.id,)))
187
- returned_annotation = to_gql_span_annotation(processed_annotation)
187
+ returned_annotation = SpanAnnotation(
188
+ id=processed_annotation.id, db_record=processed_annotation
189
+ )
188
190
  await session.commit()
189
191
  return SpanAnnotationMutationPayload(
190
192
  span_annotations=[returned_annotation],
@@ -256,7 +258,7 @@ class SpanAnnotationMutationMixin:
256
258
  session.add(span_annotation)
257
259
 
258
260
  patched_annotations = [
259
- to_gql_span_annotation(span_annotation)
261
+ SpanAnnotation(id=span_annotation.id, db_record=span_annotation)
260
262
  for span_annotation in span_annotations_by_id.values()
261
263
  ]
262
264
 
@@ -320,7 +322,10 @@ class SpanAnnotationMutationMixin:
320
322
  )
321
323
 
322
324
  deleted_annotations_gql = [
323
- to_gql_span_annotation(deleted_annotations_by_id[id]) for id in span_annotation_ids
325
+ SpanAnnotation(
326
+ id=deleted_annotations_by_id[id].id, db_record=deleted_annotations_by_id[id]
327
+ )
328
+ for id in span_annotation_ids
324
329
  ]
325
330
  info.context.event_queue.put(
326
331
  SpanAnnotationDeleteEvent(tuple(deleted_annotations_by_id.keys()))
@@ -16,7 +16,7 @@ from phoenix.server.api.input_types.PatchAnnotationInput import PatchAnnotationI
16
16
  from phoenix.server.api.queries import Query
17
17
  from phoenix.server.api.types.AnnotationSource import AnnotationSource
18
18
  from phoenix.server.api.types.node import from_global_id_with_expected_type
19
- from phoenix.server.api.types.TraceAnnotation import TraceAnnotation, to_gql_trace_annotation
19
+ from phoenix.server.api.types.TraceAnnotation import TraceAnnotation
20
20
  from phoenix.server.bearer_auth import PhoenixUser
21
21
  from phoenix.server.dml_event import TraceAnnotationDeleteEvent, TraceAnnotationInsertEvent
22
22
 
@@ -111,7 +111,9 @@ class TraceAnnotationMutationMixin:
111
111
  info.context.event_queue.put(TraceAnnotationInsertEvent(inserted_annotation_ids))
112
112
 
113
113
  returned_annotations = [
114
- to_gql_trace_annotation(processed_annotations_map[i])
114
+ TraceAnnotation(
115
+ id=processed_annotations_map[i].id, db_record=processed_annotations_map[i]
116
+ )
115
117
  for i in sorted(processed_annotations_map.keys())
116
118
  ]
117
119
 
@@ -186,7 +188,7 @@ class TraceAnnotationMutationMixin:
186
188
  await session.commit()
187
189
 
188
190
  patched_annotations = [
189
- to_gql_trace_annotation(trace_annotation)
191
+ TraceAnnotation(id=trace_annotation.id, db_record=trace_annotation)
190
192
  for trace_annotation in trace_annotations_by_id.values()
191
193
  ]
192
194
  info.context.event_queue.put(TraceAnnotationInsertEvent(tuple(patch_by_id.keys())))
@@ -245,7 +247,10 @@ class TraceAnnotationMutationMixin:
245
247
  )
246
248
 
247
249
  deleted_gql_annotations = [
248
- to_gql_trace_annotation(deleted_annotations_by_id[id]) for id in trace_annotation_ids
250
+ TraceAnnotation(
251
+ id=deleted_annotations_by_id[id].id, db_record=deleted_annotations_by_id[id]
252
+ )
253
+ for id in trace_annotation_ids
249
254
  ]
250
255
  info.context.event_queue.put(
251
256
  TraceAnnotationDeleteEvent(tuple(deleted_annotations_by_id.keys()))
@@ -33,7 +33,7 @@ from phoenix.server.api.exceptions import BadRequest, Conflict, NotFound, Unauth
33
33
  from phoenix.server.api.input_types.UserRoleInput import UserRoleInput
34
34
  from phoenix.server.api.types.AuthMethod import AuthMethod
35
35
  from phoenix.server.api.types.node import from_global_id_with_expected_type
36
- from phoenix.server.api.types.User import User, to_gql_user
36
+ from phoenix.server.api.types.User import User
37
37
  from phoenix.server.bearer_auth import PhoenixUser
38
38
  from phoenix.server.types import AccessTokenId, ApiKeyId, PasswordResetTokenId, RefreshTokenId
39
39
 
@@ -155,7 +155,7 @@ class UserMutationMixin:
155
155
  except Exception as error:
156
156
  # Log the error but do not raise it
157
157
  logger.error(f"Failed to send welcome email: {error}")
158
- return UserMutationPayload(user=to_gql_user(user))
158
+ return UserMutationPayload(user=User(id=user.id, db_record=user))
159
159
 
160
160
  @strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer, IsAdmin]) # type: ignore
161
161
  async def patch_user(
@@ -204,7 +204,7 @@ class UserMutationMixin:
204
204
  assert user
205
205
  if should_log_out:
206
206
  await info.context.log_out(user.id)
207
- return UserMutationPayload(user=to_gql_user(user))
207
+ return UserMutationPayload(user=User(id=user.id, db_record=user))
208
208
 
209
209
  @strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore
210
210
  async def patch_viewer(
@@ -246,7 +246,7 @@ class UserMutationMixin:
246
246
  response = info.context.get_response()
247
247
  response.delete_cookie(PHOENIX_REFRESH_TOKEN_COOKIE_NAME)
248
248
  response.delete_cookie(PHOENIX_ACCESS_TOKEN_COOKIE_NAME)
249
- return UserMutationPayload(user=to_gql_user(user))
249
+ return UserMutationPayload(user=User(id=user.id, db_record=user))
250
250
 
251
251
  @strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer, IsAdmin, IsLocked]) # type: ignore
252
252
  async def delete_users(