arize-phoenix 10.0.4__py3-none-any.whl → 12.28.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (276) hide show
  1. {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/METADATA +124 -72
  2. arize_phoenix-12.28.1.dist-info/RECORD +499 -0
  3. {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/WHEEL +1 -1
  4. {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/IP_NOTICE +1 -1
  5. phoenix/__generated__/__init__.py +0 -0
  6. phoenix/__generated__/classification_evaluator_configs/__init__.py +20 -0
  7. phoenix/__generated__/classification_evaluator_configs/_document_relevance_classification_evaluator_config.py +17 -0
  8. phoenix/__generated__/classification_evaluator_configs/_hallucination_classification_evaluator_config.py +17 -0
  9. phoenix/__generated__/classification_evaluator_configs/_models.py +18 -0
  10. phoenix/__generated__/classification_evaluator_configs/_tool_selection_classification_evaluator_config.py +17 -0
  11. phoenix/__init__.py +5 -4
  12. phoenix/auth.py +39 -2
  13. phoenix/config.py +1763 -91
  14. phoenix/datetime_utils.py +120 -2
  15. phoenix/db/README.md +595 -25
  16. phoenix/db/bulk_inserter.py +145 -103
  17. phoenix/db/engines.py +140 -33
  18. phoenix/db/enums.py +3 -12
  19. phoenix/db/facilitator.py +302 -35
  20. phoenix/db/helpers.py +1000 -65
  21. phoenix/db/iam_auth.py +64 -0
  22. phoenix/db/insertion/dataset.py +135 -2
  23. phoenix/db/insertion/document_annotation.py +9 -6
  24. phoenix/db/insertion/evaluation.py +2 -3
  25. phoenix/db/insertion/helpers.py +17 -2
  26. phoenix/db/insertion/session_annotation.py +176 -0
  27. phoenix/db/insertion/span.py +15 -11
  28. phoenix/db/insertion/span_annotation.py +3 -4
  29. phoenix/db/insertion/trace_annotation.py +3 -4
  30. phoenix/db/insertion/types.py +50 -20
  31. phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
  32. phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
  33. phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
  34. phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
  35. phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
  36. phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
  37. phoenix/db/migrations/versions/a20694b15f82_cost.py +196 -0
  38. phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
  39. phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
  40. phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
  41. phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
  42. phoenix/db/models.py +669 -56
  43. phoenix/db/pg_config.py +10 -0
  44. phoenix/db/types/model_provider.py +4 -0
  45. phoenix/db/types/token_price_customization.py +29 -0
  46. phoenix/db/types/trace_retention.py +23 -15
  47. phoenix/experiments/evaluators/utils.py +3 -3
  48. phoenix/experiments/functions.py +160 -52
  49. phoenix/experiments/tracing.py +2 -2
  50. phoenix/experiments/types.py +1 -1
  51. phoenix/inferences/inferences.py +1 -2
  52. phoenix/server/api/auth.py +38 -7
  53. phoenix/server/api/auth_messages.py +46 -0
  54. phoenix/server/api/context.py +100 -4
  55. phoenix/server/api/dataloaders/__init__.py +79 -5
  56. phoenix/server/api/dataloaders/annotation_configs_by_project.py +31 -0
  57. phoenix/server/api/dataloaders/annotation_summaries.py +60 -8
  58. phoenix/server/api/dataloaders/average_experiment_repeated_run_group_latency.py +50 -0
  59. phoenix/server/api/dataloaders/average_experiment_run_latency.py +17 -24
  60. phoenix/server/api/dataloaders/cache/two_tier_cache.py +1 -2
  61. phoenix/server/api/dataloaders/dataset_dataset_splits.py +52 -0
  62. phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -1
  63. phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
  64. phoenix/server/api/dataloaders/dataset_examples_and_versions_by_experiment_run.py +47 -0
  65. phoenix/server/api/dataloaders/dataset_labels.py +36 -0
  66. phoenix/server/api/dataloaders/document_evaluation_summaries.py +2 -2
  67. phoenix/server/api/dataloaders/document_evaluations.py +6 -9
  68. phoenix/server/api/dataloaders/experiment_annotation_summaries.py +88 -34
  69. phoenix/server/api/dataloaders/experiment_dataset_splits.py +43 -0
  70. phoenix/server/api/dataloaders/experiment_error_rates.py +21 -28
  71. phoenix/server/api/dataloaders/experiment_repeated_run_group_annotation_summaries.py +77 -0
  72. phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +57 -0
  73. phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
  74. phoenix/server/api/dataloaders/last_used_times_by_generative_model_id.py +35 -0
  75. phoenix/server/api/dataloaders/latency_ms_quantile.py +40 -8
  76. phoenix/server/api/dataloaders/record_counts.py +37 -10
  77. phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
  78. phoenix/server/api/dataloaders/span_cost_by_span.py +24 -0
  79. phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_generative_model.py +56 -0
  80. phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_project_session.py +57 -0
  81. phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_span.py +43 -0
  82. phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_trace.py +56 -0
  83. phoenix/server/api/dataloaders/span_cost_details_by_span_cost.py +27 -0
  84. phoenix/server/api/dataloaders/span_cost_summary_by_experiment.py +57 -0
  85. phoenix/server/api/dataloaders/span_cost_summary_by_experiment_repeated_run_group.py +64 -0
  86. phoenix/server/api/dataloaders/span_cost_summary_by_experiment_run.py +58 -0
  87. phoenix/server/api/dataloaders/span_cost_summary_by_generative_model.py +55 -0
  88. phoenix/server/api/dataloaders/span_cost_summary_by_project.py +152 -0
  89. phoenix/server/api/dataloaders/span_cost_summary_by_project_session.py +56 -0
  90. phoenix/server/api/dataloaders/span_cost_summary_by_trace.py +55 -0
  91. phoenix/server/api/dataloaders/span_costs.py +29 -0
  92. phoenix/server/api/dataloaders/table_fields.py +2 -2
  93. phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
  94. phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
  95. phoenix/server/api/dataloaders/types.py +29 -0
  96. phoenix/server/api/exceptions.py +11 -1
  97. phoenix/server/api/helpers/dataset_helpers.py +5 -1
  98. phoenix/server/api/helpers/playground_clients.py +1243 -292
  99. phoenix/server/api/helpers/playground_registry.py +2 -2
  100. phoenix/server/api/helpers/playground_spans.py +8 -4
  101. phoenix/server/api/helpers/playground_users.py +26 -0
  102. phoenix/server/api/helpers/prompts/conversions/aws.py +83 -0
  103. phoenix/server/api/helpers/prompts/conversions/google.py +103 -0
  104. phoenix/server/api/helpers/prompts/models.py +205 -22
  105. phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
  106. phoenix/server/api/input_types/ChatCompletionInput.py +6 -2
  107. phoenix/server/api/input_types/CreateProjectInput.py +27 -0
  108. phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
  109. phoenix/server/api/input_types/DatasetFilter.py +17 -0
  110. phoenix/server/api/input_types/ExperimentRunSort.py +237 -0
  111. phoenix/server/api/input_types/GenerativeCredentialInput.py +9 -0
  112. phoenix/server/api/input_types/GenerativeModelInput.py +5 -0
  113. phoenix/server/api/input_types/ProjectSessionSort.py +161 -1
  114. phoenix/server/api/input_types/PromptFilter.py +14 -0
  115. phoenix/server/api/input_types/PromptVersionInput.py +52 -1
  116. phoenix/server/api/input_types/SpanSort.py +44 -7
  117. phoenix/server/api/input_types/TimeBinConfig.py +23 -0
  118. phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
  119. phoenix/server/api/input_types/UserRoleInput.py +1 -0
  120. phoenix/server/api/mutations/__init__.py +10 -0
  121. phoenix/server/api/mutations/annotation_config_mutations.py +8 -8
  122. phoenix/server/api/mutations/api_key_mutations.py +19 -23
  123. phoenix/server/api/mutations/chat_mutations.py +154 -47
  124. phoenix/server/api/mutations/dataset_label_mutations.py +243 -0
  125. phoenix/server/api/mutations/dataset_mutations.py +21 -16
  126. phoenix/server/api/mutations/dataset_split_mutations.py +351 -0
  127. phoenix/server/api/mutations/experiment_mutations.py +2 -2
  128. phoenix/server/api/mutations/export_events_mutations.py +3 -3
  129. phoenix/server/api/mutations/model_mutations.py +210 -0
  130. phoenix/server/api/mutations/project_mutations.py +49 -10
  131. phoenix/server/api/mutations/project_session_annotations_mutations.py +158 -0
  132. phoenix/server/api/mutations/project_trace_retention_policy_mutations.py +8 -4
  133. phoenix/server/api/mutations/prompt_label_mutations.py +74 -65
  134. phoenix/server/api/mutations/prompt_mutations.py +65 -129
  135. phoenix/server/api/mutations/prompt_version_tag_mutations.py +11 -8
  136. phoenix/server/api/mutations/span_annotations_mutations.py +15 -10
  137. phoenix/server/api/mutations/trace_annotations_mutations.py +14 -10
  138. phoenix/server/api/mutations/trace_mutations.py +47 -3
  139. phoenix/server/api/mutations/user_mutations.py +66 -41
  140. phoenix/server/api/queries.py +768 -293
  141. phoenix/server/api/routers/__init__.py +2 -2
  142. phoenix/server/api/routers/auth.py +154 -88
  143. phoenix/server/api/routers/ldap.py +229 -0
  144. phoenix/server/api/routers/oauth2.py +369 -106
  145. phoenix/server/api/routers/v1/__init__.py +24 -4
  146. phoenix/server/api/routers/v1/annotation_configs.py +23 -31
  147. phoenix/server/api/routers/v1/annotations.py +481 -17
  148. phoenix/server/api/routers/v1/datasets.py +395 -81
  149. phoenix/server/api/routers/v1/documents.py +142 -0
  150. phoenix/server/api/routers/v1/evaluations.py +24 -31
  151. phoenix/server/api/routers/v1/experiment_evaluations.py +19 -8
  152. phoenix/server/api/routers/v1/experiment_runs.py +337 -59
  153. phoenix/server/api/routers/v1/experiments.py +479 -48
  154. phoenix/server/api/routers/v1/models.py +7 -0
  155. phoenix/server/api/routers/v1/projects.py +18 -49
  156. phoenix/server/api/routers/v1/prompts.py +54 -40
  157. phoenix/server/api/routers/v1/sessions.py +108 -0
  158. phoenix/server/api/routers/v1/spans.py +1091 -81
  159. phoenix/server/api/routers/v1/traces.py +132 -78
  160. phoenix/server/api/routers/v1/users.py +389 -0
  161. phoenix/server/api/routers/v1/utils.py +3 -7
  162. phoenix/server/api/subscriptions.py +305 -88
  163. phoenix/server/api/types/Annotation.py +90 -23
  164. phoenix/server/api/types/ApiKey.py +13 -17
  165. phoenix/server/api/types/AuthMethod.py +1 -0
  166. phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +1 -0
  167. phoenix/server/api/types/CostBreakdown.py +12 -0
  168. phoenix/server/api/types/Dataset.py +226 -72
  169. phoenix/server/api/types/DatasetExample.py +88 -18
  170. phoenix/server/api/types/DatasetExperimentAnnotationSummary.py +10 -0
  171. phoenix/server/api/types/DatasetLabel.py +57 -0
  172. phoenix/server/api/types/DatasetSplit.py +98 -0
  173. phoenix/server/api/types/DatasetVersion.py +49 -4
  174. phoenix/server/api/types/DocumentAnnotation.py +212 -0
  175. phoenix/server/api/types/Experiment.py +264 -59
  176. phoenix/server/api/types/ExperimentComparison.py +5 -10
  177. phoenix/server/api/types/ExperimentRepeatedRunGroup.py +155 -0
  178. phoenix/server/api/types/ExperimentRepeatedRunGroupAnnotationSummary.py +9 -0
  179. phoenix/server/api/types/ExperimentRun.py +169 -65
  180. phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
  181. phoenix/server/api/types/GenerativeModel.py +245 -3
  182. phoenix/server/api/types/GenerativeProvider.py +70 -11
  183. phoenix/server/api/types/{Model.py → InferenceModel.py} +1 -1
  184. phoenix/server/api/types/ModelInterface.py +16 -0
  185. phoenix/server/api/types/PlaygroundModel.py +20 -0
  186. phoenix/server/api/types/Project.py +1278 -216
  187. phoenix/server/api/types/ProjectSession.py +188 -28
  188. phoenix/server/api/types/ProjectSessionAnnotation.py +187 -0
  189. phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
  190. phoenix/server/api/types/Prompt.py +119 -39
  191. phoenix/server/api/types/PromptLabel.py +42 -25
  192. phoenix/server/api/types/PromptVersion.py +11 -8
  193. phoenix/server/api/types/PromptVersionTag.py +65 -25
  194. phoenix/server/api/types/ServerStatus.py +6 -0
  195. phoenix/server/api/types/Span.py +167 -123
  196. phoenix/server/api/types/SpanAnnotation.py +189 -42
  197. phoenix/server/api/types/SpanCostDetailSummaryEntry.py +10 -0
  198. phoenix/server/api/types/SpanCostSummary.py +10 -0
  199. phoenix/server/api/types/SystemApiKey.py +65 -1
  200. phoenix/server/api/types/TokenPrice.py +16 -0
  201. phoenix/server/api/types/TokenUsage.py +3 -3
  202. phoenix/server/api/types/Trace.py +223 -51
  203. phoenix/server/api/types/TraceAnnotation.py +149 -50
  204. phoenix/server/api/types/User.py +137 -32
  205. phoenix/server/api/types/UserApiKey.py +73 -26
  206. phoenix/server/api/types/node.py +10 -0
  207. phoenix/server/api/types/pagination.py +11 -2
  208. phoenix/server/app.py +290 -45
  209. phoenix/server/authorization.py +38 -3
  210. phoenix/server/bearer_auth.py +34 -24
  211. phoenix/server/cost_tracking/cost_details_calculator.py +196 -0
  212. phoenix/server/cost_tracking/cost_model_lookup.py +179 -0
  213. phoenix/server/cost_tracking/helpers.py +68 -0
  214. phoenix/server/cost_tracking/model_cost_manifest.json +3657 -830
  215. phoenix/server/cost_tracking/regex_specificity.py +397 -0
  216. phoenix/server/cost_tracking/token_cost_calculator.py +57 -0
  217. phoenix/server/daemons/__init__.py +0 -0
  218. phoenix/server/daemons/db_disk_usage_monitor.py +214 -0
  219. phoenix/server/daemons/generative_model_store.py +103 -0
  220. phoenix/server/daemons/span_cost_calculator.py +99 -0
  221. phoenix/server/dml_event.py +17 -0
  222. phoenix/server/dml_event_handler.py +5 -0
  223. phoenix/server/email/sender.py +56 -3
  224. phoenix/server/email/templates/db_disk_usage_notification.html +19 -0
  225. phoenix/server/email/types.py +11 -0
  226. phoenix/server/experiments/__init__.py +0 -0
  227. phoenix/server/experiments/utils.py +14 -0
  228. phoenix/server/grpc_server.py +11 -11
  229. phoenix/server/jwt_store.py +17 -15
  230. phoenix/server/ldap.py +1449 -0
  231. phoenix/server/main.py +26 -10
  232. phoenix/server/oauth2.py +330 -12
  233. phoenix/server/prometheus.py +66 -6
  234. phoenix/server/rate_limiters.py +4 -9
  235. phoenix/server/retention.py +33 -20
  236. phoenix/server/session_filters.py +49 -0
  237. phoenix/server/static/.vite/manifest.json +55 -51
  238. phoenix/server/static/assets/components-BreFUQQa.js +6702 -0
  239. phoenix/server/static/assets/{index-E0M82BdE.js → index-CTQoemZv.js} +140 -56
  240. phoenix/server/static/assets/pages-DBE5iYM3.js +9524 -0
  241. phoenix/server/static/assets/vendor-BGzfc4EU.css +1 -0
  242. phoenix/server/static/assets/vendor-DCE4v-Ot.js +920 -0
  243. phoenix/server/static/assets/vendor-codemirror-D5f205eT.js +25 -0
  244. phoenix/server/static/assets/vendor-recharts-V9cwpXsm.js +37 -0
  245. phoenix/server/static/assets/vendor-shiki-Do--csgv.js +5 -0
  246. phoenix/server/static/assets/vendor-three-CmB8bl_y.js +3840 -0
  247. phoenix/server/templates/index.html +40 -6
  248. phoenix/server/thread_server.py +1 -2
  249. phoenix/server/types.py +14 -4
  250. phoenix/server/utils.py +74 -0
  251. phoenix/session/client.py +56 -3
  252. phoenix/session/data_extractor.py +5 -0
  253. phoenix/session/evaluation.py +14 -5
  254. phoenix/session/session.py +45 -9
  255. phoenix/settings.py +5 -0
  256. phoenix/trace/attributes.py +80 -13
  257. phoenix/trace/dsl/helpers.py +90 -1
  258. phoenix/trace/dsl/query.py +8 -6
  259. phoenix/trace/projects.py +5 -0
  260. phoenix/utilities/template_formatters.py +1 -1
  261. phoenix/version.py +1 -1
  262. arize_phoenix-10.0.4.dist-info/RECORD +0 -405
  263. phoenix/server/api/types/Evaluation.py +0 -39
  264. phoenix/server/cost_tracking/cost_lookup.py +0 -255
  265. phoenix/server/static/assets/components-DULKeDfL.js +0 -4365
  266. phoenix/server/static/assets/pages-Cl0A-0U2.js +0 -7430
  267. phoenix/server/static/assets/vendor-WIZid84E.css +0 -1
  268. phoenix/server/static/assets/vendor-arizeai-Dy-0mSNw.js +0 -649
  269. phoenix/server/static/assets/vendor-codemirror-DBtifKNr.js +0 -33
  270. phoenix/server/static/assets/vendor-oB4u9zuV.js +0 -905
  271. phoenix/server/static/assets/vendor-recharts-D-T4KPz2.js +0 -59
  272. phoenix/server/static/assets/vendor-shiki-BMn4O_9F.js +0 -5
  273. phoenix/server/static/assets/vendor-three-C5WAXd5r.js +0 -2998
  274. phoenix/utilities/deprecation.py +0 -31
  275. {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/entry_points.txt +0 -0
  276. {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/LICENSE +0 -0
@@ -10,79 +10,89 @@ from strawberry.relay import GlobalID
10
10
  from strawberry.types import Info
11
11
 
12
12
  from phoenix.db import models
13
- from phoenix.db.types.identifier import Identifier as IdentifierModel
14
- from phoenix.server.api.auth import IsLocked, IsNotReadOnly
13
+ from phoenix.server.api.auth import IsLocked, IsNotReadOnly, IsNotViewer
15
14
  from phoenix.server.api.context import Context
16
15
  from phoenix.server.api.exceptions import Conflict, NotFound
17
16
  from phoenix.server.api.queries import Query
18
- from phoenix.server.api.types.Identifier import Identifier
19
17
  from phoenix.server.api.types.node import from_global_id_with_expected_type
20
18
  from phoenix.server.api.types.Prompt import Prompt
21
- from phoenix.server.api.types.PromptLabel import PromptLabel, to_gql_prompt_label
19
+ from phoenix.server.api.types.PromptLabel import PromptLabel
22
20
 
23
21
 
24
22
  @strawberry.input
25
23
  class CreatePromptLabelInput:
26
- name: Identifier
24
+ name: str
27
25
  description: Optional[str] = None
26
+ color: str
28
27
 
29
28
 
30
29
  @strawberry.input
31
30
  class PatchPromptLabelInput:
32
31
  prompt_label_id: GlobalID
33
- name: Optional[Identifier] = None
32
+ name: Optional[str] = None
34
33
  description: Optional[str] = None
35
34
 
36
35
 
37
36
  @strawberry.input
38
- class DeletePromptLabelInput:
39
- prompt_label_id: GlobalID
37
+ class DeletePromptLabelsInput:
38
+ prompt_label_ids: list[GlobalID]
40
39
 
41
40
 
42
41
  @strawberry.input
43
- class SetPromptLabelInput:
42
+ class SetPromptLabelsInput:
44
43
  prompt_id: GlobalID
45
- prompt_label_id: GlobalID
44
+ prompt_label_ids: list[GlobalID]
46
45
 
47
46
 
48
47
  @strawberry.input
49
- class UnsetPromptLabelInput:
48
+ class UnsetPromptLabelsInput:
50
49
  prompt_id: GlobalID
51
- prompt_label_id: GlobalID
50
+ prompt_label_ids: list[GlobalID]
52
51
 
53
52
 
54
53
  @strawberry.type
55
54
  class PromptLabelMutationPayload:
56
- prompt_label: Optional["PromptLabel"]
55
+ prompt_labels: list["PromptLabel"]
56
+ query: "Query"
57
+
58
+
59
+ @strawberry.type
60
+ class PromptLabelDeleteMutationPayload:
61
+ deleted_prompt_label_ids: list["GlobalID"]
62
+ query: "Query"
63
+
64
+
65
+ @strawberry.type
66
+ class PromptLabelAssociationMutationPayload:
57
67
  query: "Query"
58
68
 
59
69
 
60
70
  @strawberry.type
61
71
  class PromptLabelMutationMixin:
62
- @strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked]) # type: ignore
72
+ @strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked]) # type: ignore
63
73
  async def create_prompt_label(
64
74
  self, info: Info[Context, None], input: CreatePromptLabelInput
65
75
  ) -> PromptLabelMutationPayload:
66
76
  async with info.context.db() as session:
67
- name = IdentifierModel.model_validate(str(input.name))
68
- label_orm = models.PromptLabel(name=name, description=input.description)
77
+ label_orm = models.PromptLabel(
78
+ name=input.name, description=input.description, color=input.color
79
+ )
69
80
  session.add(label_orm)
70
81
 
71
82
  try:
72
83
  await session.commit()
73
84
  except (PostgreSQLIntegrityError, SQLiteIntegrityError):
74
- raise Conflict(f"A prompt label named '{name}' already exists.")
85
+ raise Conflict(f"A prompt label named '{input.name}' already exists.")
75
86
 
76
87
  return PromptLabelMutationPayload(
77
- prompt_label=to_gql_prompt_label(label_orm),
88
+ prompt_labels=[PromptLabel(id=label_orm.id, db_record=label_orm)],
78
89
  query=Query(),
79
90
  )
80
91
 
81
- @strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked]) # type: ignore
92
+ @strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked]) # type: ignore
82
93
  async def patch_prompt_label(
83
94
  self, info: Info[Context, None], input: PatchPromptLabelInput
84
95
  ) -> PromptLabelMutationPayload:
85
- validated_name = IdentifierModel.model_validate(str(input.name)) if input.name else None
86
96
  async with info.context.db() as session:
87
97
  label_id = from_global_id_with_expected_type(
88
98
  input.prompt_label_id, PromptLabel.__name__
@@ -92,8 +102,8 @@ class PromptLabelMutationMixin:
92
102
  if not label_orm:
93
103
  raise NotFound(f"PromptLabel with ID {input.prompt_label_id} not found")
94
104
 
95
- if validated_name is not None:
96
- label_orm.name = validated_name.root
105
+ if input.name is not None:
106
+ label_orm.name = input.name
97
107
  if input.description is not None:
98
108
  label_orm.description = input.description
99
109
 
@@ -103,46 +113,48 @@ class PromptLabelMutationMixin:
103
113
  raise Conflict("Error patching PromptLabel. Possibly a name conflict?")
104
114
 
105
115
  return PromptLabelMutationPayload(
106
- prompt_label=to_gql_prompt_label(label_orm),
116
+ prompt_labels=[PromptLabel(id=label_orm.id, db_record=label_orm)],
107
117
  query=Query(),
108
118
  )
109
119
 
110
- @strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore
111
- async def delete_prompt_label(
112
- self, info: Info[Context, None], input: DeletePromptLabelInput
113
- ) -> PromptLabelMutationPayload:
120
+ @strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer]) # type: ignore
121
+ async def delete_prompt_labels(
122
+ self, info: Info[Context, None], input: DeletePromptLabelsInput
123
+ ) -> PromptLabelDeleteMutationPayload:
114
124
  """
115
125
  Deletes a PromptLabel (and any crosswalk references).
116
126
  """
117
127
  async with info.context.db() as session:
118
- label_id = from_global_id_with_expected_type(
119
- input.prompt_label_id, PromptLabel.__name__
120
- )
121
- stmt = delete(models.PromptLabel).where(models.PromptLabel.id == label_id)
122
- result = await session.execute(stmt)
123
-
124
- if result.rowcount == 0:
125
- raise NotFound(f"PromptLabel with ID {input.prompt_label_id} not found")
128
+ label_ids = [
129
+ from_global_id_with_expected_type(prompt_label_id, PromptLabel.__name__)
130
+ for prompt_label_id in input.prompt_label_ids
131
+ ]
132
+ stmt = delete(models.PromptLabel).where(models.PromptLabel.id.in_(label_ids))
133
+ await session.execute(stmt)
126
134
 
127
135
  await session.commit()
128
136
 
129
- return PromptLabelMutationPayload(
130
- prompt_label=None,
137
+ return PromptLabelDeleteMutationPayload(
138
+ deleted_prompt_label_ids=input.prompt_label_ids,
131
139
  query=Query(),
132
140
  )
133
141
 
134
- @strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked]) # type: ignore
135
- async def set_prompt_label(
136
- self, info: Info[Context, None], input: SetPromptLabelInput
137
- ) -> PromptLabelMutationPayload:
142
+ @strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked]) # type: ignore
143
+ async def set_prompt_labels(
144
+ self, info: Info[Context, None], input: SetPromptLabelsInput
145
+ ) -> PromptLabelAssociationMutationPayload:
138
146
  async with info.context.db() as session:
139
147
  prompt_id = from_global_id_with_expected_type(input.prompt_id, Prompt.__name__)
140
- label_id = from_global_id_with_expected_type(
141
- input.prompt_label_id, PromptLabel.__name__
142
- )
148
+ label_ids = [
149
+ from_global_id_with_expected_type(prompt_label_id, PromptLabel.__name__)
150
+ for prompt_label_id in input.prompt_label_ids
151
+ ]
143
152
 
144
- crosswalk = models.PromptPromptLabel(prompt_id=prompt_id, prompt_label_id=label_id)
145
- session.add(crosswalk)
153
+ crosswalk_items = [
154
+ models.PromptPromptLabel(prompt_id=prompt_id, prompt_label_id=label_id)
155
+ for label_id in label_ids
156
+ ]
157
+ session.add_all(crosswalk_items)
146
158
 
147
159
  try:
148
160
  await session.commit()
@@ -152,41 +164,38 @@ class PromptLabelMutationMixin:
152
164
  # - Foreign key violation => prompt_id or label_id doesn't exist
153
165
  raise Conflict("Failed to associate PromptLabel with Prompt.") from e
154
166
 
155
- label_orm = await session.get(models.PromptLabel, label_id)
156
- if not label_orm:
157
- raise NotFound(f"PromptLabel with ID {input.prompt_label_id} not found")
158
-
159
- return PromptLabelMutationPayload(
160
- prompt_label=to_gql_prompt_label(label_orm),
167
+ return PromptLabelAssociationMutationPayload(
161
168
  query=Query(),
162
169
  )
163
170
 
164
- @strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore
165
- async def unset_prompt_label(
166
- self, info: Info[Context, None], input: UnsetPromptLabelInput
167
- ) -> PromptLabelMutationPayload:
171
+ @strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer]) # type: ignore
172
+ async def unset_prompt_labels(
173
+ self, info: Info[Context, None], input: UnsetPromptLabelsInput
174
+ ) -> PromptLabelAssociationMutationPayload:
168
175
  """
169
176
  Unsets a PromptLabel from a Prompt by removing the row in the crosswalk.
170
177
  """
171
178
  async with info.context.db() as session:
172
179
  prompt_id = from_global_id_with_expected_type(input.prompt_id, Prompt.__name__)
173
- label_id = from_global_id_with_expected_type(
174
- input.prompt_label_id, PromptLabel.__name__
175
- )
180
+ label_ids = [
181
+ from_global_id_with_expected_type(prompt_label_id, PromptLabel.__name__)
182
+ for prompt_label_id in input.prompt_label_ids
183
+ ]
176
184
 
177
185
  stmt = delete(models.PromptPromptLabel).where(
178
186
  (models.PromptPromptLabel.prompt_id == prompt_id)
179
- & (models.PromptPromptLabel.prompt_label_id == label_id)
187
+ & (models.PromptPromptLabel.prompt_label_id.in_(label_ids))
180
188
  )
181
189
  result = await session.execute(stmt)
182
190
 
183
- if result.rowcount == 0:
184
- raise NotFound(f"No association between prompt={prompt_id} and label={label_id}.")
191
+ if result.rowcount != len(label_ids): # type: ignore[attr-defined]
192
+ label_ids_str = ", ".join(str(i) for i in label_ids)
193
+ raise NotFound(
194
+ f"No association between prompt={prompt_id} and labels={label_ids_str}."
195
+ )
185
196
 
186
197
  await session.commit()
187
198
 
188
- label_orm = await session.get(models.PromptLabel, label_id)
189
- return PromptLabelMutationPayload(
190
- prompt_label=to_gql_prompt_label(label_orm) if label_orm else None,
199
+ return PromptLabelAssociationMutationPayload(
191
200
  query=Query(),
192
201
  )
@@ -1,4 +1,4 @@
1
- from typing import Any, Optional, Union, cast
1
+ from typing import Any, Optional
2
2
 
3
3
  import strawberry
4
4
  from fastapi import Request
@@ -7,23 +7,17 @@ from sqlalchemy import delete, select, update
7
7
  from sqlalchemy.exc import IntegrityError as PostgreSQLIntegrityError
8
8
  from sqlalchemy.orm import joinedload
9
9
  from sqlean.dbapi2 import IntegrityError as SQLiteIntegrityError # type: ignore[import-untyped]
10
+ from strawberry import UNSET
10
11
  from strawberry.relay.types import GlobalID
11
12
  from strawberry.types import Info
12
13
 
13
14
  from phoenix.db import models
14
15
  from phoenix.db.types.identifier import Identifier as IdentifierModel
15
- from phoenix.db.types.model_provider import ModelProvider
16
- from phoenix.server.api.auth import IsLocked, IsNotReadOnly
16
+ from phoenix.server.api.auth import IsLocked, IsNotReadOnly, IsNotViewer
17
17
  from phoenix.server.api.context import Context
18
18
  from phoenix.server.api.exceptions import BadRequest, Conflict, NotFound
19
- from phoenix.server.api.helpers.prompts.models import (
20
- normalize_response_format,
21
- normalize_tools,
22
- validate_invocation_parameters,
23
- )
24
19
  from phoenix.server.api.input_types.PromptVersionInput import (
25
20
  ChatPromptVersionInput,
26
- to_pydantic_prompt_chat_template_v1,
27
21
  )
28
22
  from phoenix.server.api.mutations.prompt_version_tag_mutations import (
29
23
  SetPromptVersionTagInput,
@@ -32,7 +26,7 @@ from phoenix.server.api.mutations.prompt_version_tag_mutations import (
32
26
  from phoenix.server.api.queries import Query
33
27
  from phoenix.server.api.types.Identifier import Identifier
34
28
  from phoenix.server.api.types.node import from_global_id_with_expected_type
35
- from phoenix.server.api.types.Prompt import Prompt, to_gql_prompt_from_orm
29
+ from phoenix.server.api.types.Prompt import Prompt
36
30
  from phoenix.server.bearer_auth import PhoenixUser
37
31
 
38
32
 
@@ -41,6 +35,7 @@ class CreateChatPromptInput:
41
35
  name: Identifier
42
36
  description: Optional[str] = None
43
37
  prompt_version: ChatPromptVersionInput
38
+ metadata: Optional[strawberry.scalars.JSON] = None
44
39
 
45
40
 
46
41
  @strawberry.input
@@ -58,14 +53,16 @@ class DeletePromptInput:
58
53
  @strawberry.input
59
54
  class ClonePromptInput:
60
55
  name: Identifier
61
- description: Optional[str] = None
62
56
  prompt_id: GlobalID
57
+ description: Optional[str] = UNSET
58
+ metadata: Optional[strawberry.scalars.JSON] = UNSET
63
59
 
64
60
 
65
61
  @strawberry.input
66
62
  class PatchPromptInput:
67
63
  prompt_id: GlobalID
68
- description: str
64
+ description: Optional[str] = UNSET
65
+ metadata: Optional[strawberry.scalars.JSON] = UNSET
69
66
 
70
67
 
71
68
  @strawberry.type
@@ -75,7 +72,7 @@ class DeletePromptMutationPayload:
75
72
 
76
73
  @strawberry.type
77
74
  class PromptMutationMixin:
78
- @strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked]) # type: ignore
75
+ @strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked]) # type: ignore
79
76
  async def create_chat_prompt(
80
77
  self, info: Info[Context, None], input: CreateChatPromptInput
81
78
  ) -> Prompt:
@@ -84,65 +81,26 @@ class PromptMutationMixin:
84
81
  if "user" in request.scope:
85
82
  assert isinstance(user := request.user, PhoenixUser)
86
83
  user_id = int(user.identity)
87
-
88
- input_prompt_version = input.prompt_version
89
- tool_definitions = [tool.definition for tool in input_prompt_version.tools]
90
- tool_choice = cast(
91
- Optional[Union[str, dict[str, Any]]],
92
- cast(dict[str, Any], input.prompt_version.invocation_parameters).pop(
93
- "tool_choice", None
94
- ),
95
- )
96
- model_provider = ModelProvider(input_prompt_version.model_provider)
97
84
  try:
98
- tools = (
99
- normalize_tools(tool_definitions, model_provider, tool_choice)
100
- if tool_definitions
101
- else None
102
- )
103
- template = to_pydantic_prompt_chat_template_v1(input_prompt_version.template)
104
- response_format = (
105
- normalize_response_format(
106
- input_prompt_version.response_format.definition,
107
- model_provider,
108
- )
109
- if input_prompt_version.response_format
110
- else None
111
- )
112
- invocation_parameters = validate_invocation_parameters(
113
- input_prompt_version.invocation_parameters,
114
- model_provider,
115
- )
85
+ prompt_version = input.prompt_version.to_orm_prompt_version(user_id)
116
86
  except ValidationError as error:
117
87
  raise BadRequest(str(error))
118
-
88
+ name = IdentifierModel.model_validate(str(input.name))
89
+ prompt = models.Prompt(
90
+ name=name,
91
+ description=input.description,
92
+ metadata_=input.metadata or {},
93
+ prompt_versions=[prompt_version],
94
+ )
119
95
  async with info.context.db() as session:
120
- prompt_version = models.PromptVersion(
121
- description=input_prompt_version.description,
122
- user_id=user_id,
123
- template_type="CHAT",
124
- template_format=input_prompt_version.template_format,
125
- template=template,
126
- invocation_parameters=invocation_parameters,
127
- tools=tools,
128
- response_format=response_format,
129
- model_provider=input_prompt_version.model_provider,
130
- model_name=input_prompt_version.model_name,
131
- )
132
- name = IdentifierModel.model_validate(str(input.name))
133
- prompt = models.Prompt(
134
- name=name,
135
- description=input.description,
136
- prompt_versions=[prompt_version],
137
- )
138
96
  session.add(prompt)
139
97
  try:
140
98
  await session.commit()
141
99
  except (PostgreSQLIntegrityError, SQLiteIntegrityError):
142
100
  raise Conflict(f"A prompt named '{input.name}' already exists")
143
- return to_gql_prompt_from_orm(prompt)
101
+ return Prompt(id=prompt.id, db_record=prompt)
144
102
 
145
- @strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked]) # type: ignore
103
+ @strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked]) # type: ignore
146
104
  async def create_chat_prompt_version(
147
105
  self,
148
106
  info: Info[Context, None],
@@ -153,74 +111,28 @@ class PromptMutationMixin:
153
111
  if "user" in request.scope:
154
112
  assert isinstance(user := request.user, PhoenixUser)
155
113
  user_id = int(user.identity)
156
-
157
- input_prompt_version = input.prompt_version
158
- tool_definitions = [tool.definition for tool in input.prompt_version.tools]
159
- tool_choice = cast(
160
- Optional[Union[str, dict[str, Any]]],
161
- cast(dict[str, Any], input.prompt_version.invocation_parameters).pop(
162
- "tool_choice", None
163
- ),
164
- )
165
- model_provider = ModelProvider(input_prompt_version.model_provider)
166
114
  try:
167
- tools = (
168
- normalize_tools(tool_definitions, model_provider, tool_choice)
169
- if tool_definitions
170
- else None
171
- )
172
- template = to_pydantic_prompt_chat_template_v1(input_prompt_version.template)
173
- response_format = (
174
- normalize_response_format(
175
- input_prompt_version.response_format.definition,
176
- model_provider,
177
- )
178
- if input_prompt_version.response_format
179
- else None
180
- )
181
- invocation_parameters = validate_invocation_parameters(
182
- input_prompt_version.invocation_parameters,
183
- model_provider,
184
- )
115
+ prompt_version = input.prompt_version.to_orm_prompt_version(user_id)
185
116
  except ValidationError as error:
186
117
  raise BadRequest(str(error))
187
-
188
118
  prompt_id = from_global_id_with_expected_type(
189
119
  global_id=input.prompt_id, expected_type_name=Prompt.__name__
190
120
  )
121
+ prompt_version.prompt_id = prompt_id
191
122
  async with info.context.db() as session:
192
- prompt = await session.get(models.Prompt, prompt_id)
193
- if not prompt:
194
- raise NotFound(f"Prompt with ID '{input.prompt_id}' not found")
195
-
196
- prompt_version = models.PromptVersion(
197
- prompt_id=prompt_id,
198
- description=input.prompt_version.description,
199
- user_id=user_id,
200
- template_type="CHAT",
201
- template_format=input.prompt_version.template_format,
202
- template=template,
203
- invocation_parameters=invocation_parameters,
204
- tools=tools,
205
- response_format=response_format,
206
- model_provider=input.prompt_version.model_provider,
207
- model_name=input.prompt_version.model_name,
208
- )
209
123
  session.add(prompt_version)
210
-
211
- # ensure prompt_version is flushed to the database before creating tags against the
212
- # prompt_version id
213
- await session.flush()
214
-
215
- if input.tags:
216
- for tag in input.tags:
217
- await upsert_prompt_version_tag(
218
- session, prompt_id, prompt_version.id, tag.name, tag.description
219
- )
220
-
221
- return to_gql_prompt_from_orm(prompt)
222
-
223
- @strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore
124
+ try:
125
+ await session.flush()
126
+ except (PostgreSQLIntegrityError, SQLiteIntegrityError):
127
+ raise NotFound(f"Prompt with ID '{input.prompt_id}' not found")
128
+ if input.tags:
129
+ for tag in input.tags:
130
+ await upsert_prompt_version_tag(
131
+ session, prompt_id, prompt_version.id, tag.name, tag.description
132
+ )
133
+ return Prompt(id=prompt_id)
134
+
135
+ @strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer]) # type: ignore
224
136
  async def delete_prompt(
225
137
  self, info: Info[Context, None], input: DeletePromptInput
226
138
  ) -> DeletePromptMutationPayload:
@@ -231,13 +143,13 @@ class PromptMutationMixin:
231
143
  stmt = delete(models.Prompt).where(models.Prompt.id == prompt_id)
232
144
  result = await session.execute(stmt)
233
145
 
234
- if result.rowcount == 0:
146
+ if result.rowcount == 0: # type: ignore[attr-defined]
235
147
  raise NotFound(f"Prompt with ID '{input.prompt_id}' not found")
236
148
 
237
149
  await session.commit()
238
150
  return DeletePromptMutationPayload(query=Query())
239
151
 
240
- @strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked]) # type: ignore
152
+ @strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked]) # type: ignore
241
153
  async def clone_prompt(self, info: Info[Context, None], input: ClonePromptInput) -> Prompt:
242
154
  prompt_id = from_global_id_with_expected_type(
243
155
  global_id=input.prompt_id, expected_type_name=Prompt.__name__
@@ -256,10 +168,23 @@ class PromptMutationMixin:
256
168
 
257
169
  # Create new prompt
258
170
  name = IdentifierModel.model_validate(str(input.name))
171
+ # Handle description: inherit if UNSET, otherwise use value (can be None)
172
+ if input.description is UNSET:
173
+ description = prompt.description
174
+ else:
175
+ description = input.description.strip() if input.description is not None else None
176
+
177
+ # Handle metadata: inherit if UNSET, clear to empty dict if None, or use value
178
+ if input.metadata is UNSET:
179
+ metadata = prompt.metadata_
180
+ else:
181
+ metadata = input.metadata or {}
182
+
259
183
  new_prompt = models.Prompt(
260
184
  name=name,
261
- description=input.description,
262
185
  source_prompt_id=prompt_id,
186
+ description=description,
187
+ metadata_=metadata,
263
188
  )
264
189
 
265
190
  # Create copies of all versions
@@ -288,19 +213,30 @@ class PromptMutationMixin:
288
213
  await session.commit()
289
214
  except (PostgreSQLIntegrityError, SQLiteIntegrityError):
290
215
  raise Conflict(f"A prompt named '{input.name}' already exists")
291
- return to_gql_prompt_from_orm(new_prompt)
216
+ return Prompt(id=new_prompt.id, db_record=new_prompt)
292
217
 
293
- @strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked]) # type: ignore
218
+ @strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked]) # type: ignore
294
219
  async def patch_prompt(self, info: Info[Context, None], input: PatchPromptInput) -> Prompt:
295
220
  prompt_id = from_global_id_with_expected_type(
296
221
  global_id=input.prompt_id, expected_type_name=Prompt.__name__
297
222
  )
298
223
 
224
+ values: dict[str, Any] = {}
225
+ if input.description is not UNSET:
226
+ values["description"] = (
227
+ input.description.strip() if input.description is not None else None
228
+ )
229
+ if input.metadata is not UNSET:
230
+ values["metadata_"] = input.metadata or {}
231
+
232
+ if not values:
233
+ raise BadRequest("No fields provided to update")
234
+
299
235
  async with info.context.db() as session:
300
236
  stmt = (
301
237
  update(models.Prompt)
302
238
  .where(models.Prompt.id == prompt_id)
303
- .values(description=input.description)
239
+ .values(**values)
304
240
  .returning(models.Prompt)
305
241
  )
306
242
 
@@ -310,4 +246,4 @@ class PromptMutationMixin:
310
246
  if prompt is None:
311
247
  raise NotFound(f"Prompt with ID '{input.prompt_id}' not found")
312
248
 
313
- return to_gql_prompt_from_orm(prompt)
249
+ return Prompt(id=prompt.id, db_record=prompt)
@@ -10,15 +10,15 @@ from strawberry.types import Info
10
10
 
11
11
  from phoenix.db import models
12
12
  from phoenix.db.types.identifier import Identifier as IdentifierModel
13
- from phoenix.server.api.auth import IsLocked, IsNotReadOnly
13
+ from phoenix.server.api.auth import IsLocked, IsNotReadOnly, IsNotViewer
14
14
  from phoenix.server.api.context import Context
15
15
  from phoenix.server.api.exceptions import BadRequest, Conflict, NotFound
16
16
  from phoenix.server.api.queries import Query
17
17
  from phoenix.server.api.types.Identifier import Identifier
18
18
  from phoenix.server.api.types.node import from_global_id_with_expected_type
19
- from phoenix.server.api.types.Prompt import Prompt, to_gql_prompt_from_orm
19
+ from phoenix.server.api.types.Prompt import Prompt
20
20
  from phoenix.server.api.types.PromptVersion import PromptVersion
21
- from phoenix.server.api.types.PromptVersionTag import PromptVersionTag, to_gql_prompt_version_tag
21
+ from phoenix.server.api.types.PromptVersionTag import PromptVersionTag
22
22
 
23
23
 
24
24
  @strawberry.input
@@ -42,7 +42,7 @@ class PromptVersionTagMutationPayload:
42
42
 
43
43
  @strawberry.type
44
44
  class PromptVersionTagMutationMixin:
45
- @strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore
45
+ @strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer]) # type: ignore
46
46
  async def delete_prompt_version_tag(
47
47
  self, info: Info[Context, None], input: DeletePromptVersionTagInput
48
48
  ) -> PromptVersionTagMutationPayload:
@@ -75,10 +75,12 @@ class PromptVersionTagMutationMixin:
75
75
  await session.delete(prompt_version_tag)
76
76
  await session.commit()
77
77
  return PromptVersionTagMutationPayload(
78
- prompt_version_tag=None, query=Query(), prompt=to_gql_prompt_from_orm(prompt)
78
+ prompt_version_tag=None,
79
+ query=Query(),
80
+ prompt=Prompt(id=prompt.id, db_record=prompt),
79
81
  )
80
82
 
81
- @strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked]) # type: ignore
83
+ @strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked]) # type: ignore
82
84
  async def set_prompt_version_tag(
83
85
  self, info: Info[Context, None], input: SetPromptVersionTagInput
84
86
  ) -> PromptVersionTagMutationPayload:
@@ -111,9 +113,10 @@ class PromptVersionTagMutationMixin:
111
113
  except (PostgreSQLIntegrityError, SQLiteIntegrityError):
112
114
  raise Conflict("Failed to update PromptVersionTag.")
113
115
 
114
- version_tag = to_gql_prompt_version_tag(updated_tag)
115
116
  return PromptVersionTagMutationPayload(
116
- prompt_version_tag=version_tag, prompt=to_gql_prompt_from_orm(prompt), query=Query()
117
+ prompt_version_tag=PromptVersionTag(id=updated_tag.id, db_record=updated_tag),
118
+ prompt=Prompt(id=prompt.id, db_record=prompt),
119
+ query=Query(),
117
120
  )
118
121
 
119
122