arize-phoenix 7.12.3__py3-none-any.whl → 8.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic. Click here for more details.

Files changed (80)
  1. {arize_phoenix-7.12.3.dist-info → arize_phoenix-8.0.1.dist-info}/METADATA +31 -28
  2. {arize_phoenix-7.12.3.dist-info → arize_phoenix-8.0.1.dist-info}/RECORD +70 -47
  3. phoenix/db/migrations/versions/bc8fea3c2bc8_add_prompt_tables.py +197 -0
  4. phoenix/db/models.py +307 -0
  5. phoenix/db/types/__init__.py +0 -0
  6. phoenix/db/types/identifier.py +7 -0
  7. phoenix/db/types/model_provider.py +8 -0
  8. phoenix/server/api/context.py +2 -0
  9. phoenix/server/api/dataloaders/__init__.py +2 -0
  10. phoenix/server/api/dataloaders/prompt_version_sequence_number.py +35 -0
  11. phoenix/server/api/helpers/jsonschema.py +135 -0
  12. phoenix/server/api/helpers/playground_clients.py +15 -15
  13. phoenix/server/api/helpers/playground_spans.py +9 -0
  14. phoenix/server/api/helpers/prompts/__init__.py +0 -0
  15. phoenix/server/api/helpers/prompts/conversions/__init__.py +0 -0
  16. phoenix/server/api/helpers/prompts/conversions/anthropic.py +87 -0
  17. phoenix/server/api/helpers/prompts/conversions/openai.py +78 -0
  18. phoenix/server/api/helpers/prompts/models.py +575 -0
  19. phoenix/server/api/input_types/ChatCompletionInput.py +9 -4
  20. phoenix/server/api/input_types/PromptTemplateOptions.py +10 -0
  21. phoenix/server/api/input_types/PromptVersionInput.py +133 -0
  22. phoenix/server/api/mutations/__init__.py +6 -0
  23. phoenix/server/api/mutations/chat_mutations.py +18 -16
  24. phoenix/server/api/mutations/prompt_label_mutations.py +191 -0
  25. phoenix/server/api/mutations/prompt_mutations.py +312 -0
  26. phoenix/server/api/mutations/prompt_version_tag_mutations.py +148 -0
  27. phoenix/server/api/mutations/user_mutations.py +7 -6
  28. phoenix/server/api/openapi/schema.py +1 -0
  29. phoenix/server/api/queries.py +84 -31
  30. phoenix/server/api/routers/oauth2.py +3 -2
  31. phoenix/server/api/routers/v1/__init__.py +2 -0
  32. phoenix/server/api/routers/v1/datasets.py +1 -1
  33. phoenix/server/api/routers/v1/experiment_evaluations.py +1 -1
  34. phoenix/server/api/routers/v1/experiment_runs.py +1 -1
  35. phoenix/server/api/routers/v1/experiments.py +1 -1
  36. phoenix/server/api/routers/v1/models.py +45 -0
  37. phoenix/server/api/routers/v1/prompts.py +415 -0
  38. phoenix/server/api/routers/v1/spans.py +1 -1
  39. phoenix/server/api/routers/v1/traces.py +1 -1
  40. phoenix/server/api/routers/v1/utils.py +1 -1
  41. phoenix/server/api/subscriptions.py +21 -24
  42. phoenix/server/api/types/GenerativeProvider.py +4 -4
  43. phoenix/server/api/types/Identifier.py +15 -0
  44. phoenix/server/api/types/Project.py +5 -7
  45. phoenix/server/api/types/Prompt.py +134 -0
  46. phoenix/server/api/types/PromptLabel.py +41 -0
  47. phoenix/server/api/types/PromptVersion.py +148 -0
  48. phoenix/server/api/types/PromptVersionTag.py +27 -0
  49. phoenix/server/api/types/PromptVersionTemplate.py +148 -0
  50. phoenix/server/api/types/ResponseFormat.py +9 -0
  51. phoenix/server/api/types/ToolDefinition.py +9 -0
  52. phoenix/server/app.py +3 -0
  53. phoenix/server/static/.vite/manifest.json +45 -45
  54. phoenix/server/static/assets/components-B-qgPyHv.js +2699 -0
  55. phoenix/server/static/assets/index-D4KO1IcF.js +1125 -0
  56. phoenix/server/static/assets/pages-DdcuL3Rh.js +5634 -0
  57. phoenix/server/static/assets/vendor-DQp7CrDA.js +894 -0
  58. phoenix/server/static/assets/vendor-arizeai-C1nEIEQq.js +657 -0
  59. phoenix/server/static/assets/vendor-codemirror-BZXYUIkP.js +24 -0
  60. phoenix/server/static/assets/vendor-recharts-BUFpwCVD.js +59 -0
  61. phoenix/server/static/assets/{vendor-shiki-Cl9QBraO.js → vendor-shiki-C8L-c9jT.js} +2 -2
  62. phoenix/server/static/assets/{vendor-three-DwGkEfCM.js → vendor-three-C-AGeJYv.js} +1 -1
  63. phoenix/session/client.py +25 -21
  64. phoenix/utilities/client.py +6 -0
  65. phoenix/version.py +1 -1
  66. phoenix/server/api/input_types/TemplateOptions.py +0 -10
  67. phoenix/server/api/routers/v1/pydantic_compat.py +0 -78
  68. phoenix/server/api/types/TemplateLanguage.py +0 -10
  69. phoenix/server/static/assets/components-DckIzNmE.js +0 -2125
  70. phoenix/server/static/assets/index-Bf25Ogon.js +0 -113
  71. phoenix/server/static/assets/pages-DL7J9q9w.js +0 -4463
  72. phoenix/server/static/assets/vendor-DvC8cT4X.js +0 -894
  73. phoenix/server/static/assets/vendor-arizeai-Do1793cv.js +0 -662
  74. phoenix/server/static/assets/vendor-codemirror-BzwZPyJM.js +0 -24
  75. phoenix/server/static/assets/vendor-recharts-_Jb7JjhG.js +0 -59
  76. {arize_phoenix-7.12.3.dist-info → arize_phoenix-8.0.1.dist-info}/WHEEL +0 -0
  77. {arize_phoenix-7.12.3.dist-info → arize_phoenix-8.0.1.dist-info}/entry_points.txt +0 -0
  78. {arize_phoenix-7.12.3.dist-info → arize_phoenix-8.0.1.dist-info}/licenses/IP_NOTICE +0 -0
  79. {arize_phoenix-7.12.3.dist-info → arize_phoenix-8.0.1.dist-info}/licenses/LICENSE +0 -0
  80. phoenix/server/static/assets/{vendor-DxkFTwjz.css → vendor-Cg6lcjUC.css} +0 -0
@@ -0,0 +1,133 @@
1
+ import json
2
+ from typing import Optional, cast
3
+
4
+ import strawberry
5
+ from strawberry import UNSET
6
+ from strawberry.scalars import JSON
7
+
8
+ from phoenix.db.types.model_provider import ModelProvider
9
+ from phoenix.server.api.helpers.prompts.models import (
10
+ ContentPart,
11
+ PromptChatTemplate,
12
+ PromptMessage,
13
+ PromptMessageRole,
14
+ PromptTemplateFormat,
15
+ RoleConversion,
16
+ TextContentPart,
17
+ ToolCallContentPart,
18
+ ToolCallFunction,
19
+ ToolResultContentPart,
20
+ )
21
+
22
+
23
@strawberry.input
class ToolDefinitionInput:
    """GraphQL input carrying one tool definition as an arbitrary JSON value."""

    # Raw tool definition payload (JSON scalar).
    definition: JSON
26
+
27
+
28
@strawberry.input
class ResponseFormatInput:
    """GraphQL input carrying a response-format definition as an arbitrary JSON value."""

    # Raw response-format payload (JSON scalar).
    definition: JSON
31
+
32
+
33
@strawberry.input
class TextContentValueInput:
    """Value of a plain-text content part."""

    # The text of this content part.
    text: str
36
+
37
+
38
@strawberry.input
class ToolResultContentValueInput:
    """Value of a tool-result content part."""

    # Identifier of the tool call this result answers.
    tool_call_id: str
    # Result payload (JSON scalar).
    result: JSON
42
+
43
+
44
@strawberry.input
class ToolCallFunctionInput:
    """Function invocation described by a tool-call content part."""

    # Kind of tool call; defaults to "function".
    type: Optional[str] = strawberry.field(default="function")
    # Name of the function being called.
    name: str
    # Call arguments as a string (presumably JSON-encoded — TODO confirm with clients).
    arguments: str
49
+
50
+
51
@strawberry.input
class ToolCallContentValueInput:
    """Value of a tool-call content part."""

    # Identifier assigned to this tool call.
    tool_call_id: str
    # The function invocation itself.
    tool_call: ToolCallFunctionInput
55
+
56
+
57
@strawberry.input(one_of=True)
class ContentPartInput:
    """
    A single piece of message content.

    Declared ``one_of=True``: exactly one of the fields below is expected to be
    provided; the remaining ones stay ``UNSET``.
    """

    text: Optional[TextContentValueInput] = UNSET
    tool_call: Optional[ToolCallContentValueInput] = UNSET
    tool_result: Optional[ToolResultContentValueInput] = UNSET
62
+
63
+
64
@strawberry.input
class PromptMessageInput:
    """One chat message of a prompt template."""

    # Role name as a string; converted via ``PromptMessageRole`` during conversion.
    role: str
    # Ordered content parts making up the message body.
    content: list[ContentPartInput]
68
+
69
+
70
@strawberry.input
class PromptChatTemplateInput:
    """A chat prompt template: an ordered list of messages."""

    messages: list[PromptMessageInput]
73
+
74
+
75
@strawberry.input
class ChatPromptVersionInput:
    """Input describing a chat prompt version: template, target model and settings."""

    # Optional human-readable description of this version.
    description: Optional[str] = None
    # Templating syntax used inside the message contents.
    template_format: PromptTemplateFormat
    # The chat template (ordered messages).
    template: PromptChatTemplateInput
    # Invocation parameters (JSON scalar); an empty dict when omitted.
    invocation_parameters: JSON = strawberry.field(default_factory=dict)
    # Tool definitions; an empty list when omitted.
    tools: list[ToolDefinitionInput] = strawberry.field(default_factory=list)
    # Optional response-format definition.
    response_format: Optional[ResponseFormatInput] = None
    # Provider and model name this version targets.
    model_provider: ModelProvider
    model_name: str
85
+
86
+
87
def to_pydantic_prompt_chat_template_v1(
    prompt_chat_template_input: PromptChatTemplateInput,
) -> PromptChatTemplate:
    """
    Convert a GraphQL chat-template input into a pydantic ``PromptChatTemplate``.

    Messages are converted one by one with ``to_pydantic_prompt_message``,
    keeping their original order.
    """
    converted_messages = list(
        map(to_pydantic_prompt_message, prompt_chat_template_input.messages)
    )
    return PromptChatTemplate(type="chat", messages=converted_messages)
96
+
97
+
98
def to_pydantic_prompt_message(prompt_message_input: PromptMessageInput) -> PromptMessage:
    """
    Convert a GraphQL message input into a pydantic ``PromptMessage``.

    Content parts are converted first (in order), then the string role is parsed
    into ``PromptMessageRole`` and mapped through ``RoleConversion.from_gql``.
    """
    parts = [to_pydantic_content_part(part) for part in prompt_message_input.content]
    gql_role = PromptMessageRole(prompt_message_input.role)
    return PromptMessage(role=RoleConversion.from_gql(gql_role), content=parts)
106
+
107
+
108
def to_pydantic_content_part(content_part_input: ContentPartInput) -> ContentPart:
    """
    Convert a one-of GraphQL content-part input into the matching pydantic part.

    Exactly one of ``text``, ``tool_call`` or ``tool_result`` is expected to be
    set (``ContentPartInput`` is declared ``one_of=True``); the first one found
    that is not ``UNSET`` is converted.

    Raises:
        ValueError: if none of the one-of members is set.
        json.JSONDecodeError: if a tool result is given as a string that is not
            valid JSON.
    """
    if content_part_input.text is not UNSET:
        assert content_part_input.text is not None
        return TextContentPart(
            type="text",
            text=content_part_input.text.text,
        )
    if content_part_input.tool_call is not UNSET:
        assert content_part_input.tool_call is not None
        tool_call_input = content_part_input.tool_call
        return ToolCallContentPart(
            type="tool_call",
            tool_call_id=tool_call_input.tool_call_id,
            tool_call=ToolCallFunction(
                type="function",
                name=tool_call_input.tool_call.name,
                arguments=tool_call_input.tool_call.arguments,
            ),
        )
    if content_part_input.tool_result is not UNSET:
        assert content_part_input.tool_result is not None
        tool_result_input = content_part_input.tool_result
        result = tool_result_input.result
        # ``result`` is a ``JSON`` scalar, which strawberry hands over already
        # decoded (dict/list/number/bool/None/str). Only strings are treated as
        # JSON-encoded text (the previous behaviour); any other value is used
        # as-is instead of crashing ``json.loads`` with a TypeError.
        if isinstance(result, str):
            result = json.loads(result)
        return ToolResultContentPart(
            type="tool_result",
            tool_call_id=tool_result_input.tool_call_id,
            tool_result=result,
        )
    raise ValueError("content part input has no content")
@@ -8,6 +8,9 @@ from phoenix.server.api.mutations.dataset_mutations import DatasetMutationMixin
8
8
  from phoenix.server.api.mutations.experiment_mutations import ExperimentMutationMixin
9
9
  from phoenix.server.api.mutations.export_events_mutations import ExportEventsMutationMixin
10
10
  from phoenix.server.api.mutations.project_mutations import ProjectMutationMixin
11
+ from phoenix.server.api.mutations.prompt_label_mutations import PromptLabelMutationMixin
12
+ from phoenix.server.api.mutations.prompt_mutations import PromptMutationMixin
13
+ from phoenix.server.api.mutations.prompt_version_tag_mutations import PromptVersionTagMutationMixin
11
14
  from phoenix.server.api.mutations.span_annotations_mutations import SpanAnnotationMutationMixin
12
15
  from phoenix.server.api.mutations.trace_annotations_mutations import TraceAnnotationMutationMixin
13
16
  from phoenix.server.api.mutations.user_mutations import UserMutationMixin
@@ -20,6 +23,9 @@ class Mutation(
20
23
  ExperimentMutationMixin,
21
24
  ExportEventsMutationMixin,
22
25
  ProjectMutationMixin,
26
+ PromptMutationMixin,
27
+ PromptVersionTagMutationMixin,
28
+ PromptLabelMutationMixin,
23
29
  SpanAnnotationMutationMixin,
24
30
  TraceAnnotationMutationMixin,
25
31
  UserMutationMixin,
@@ -41,14 +41,15 @@ from phoenix.server.api.helpers.playground_spans import (
41
41
  llm_model_name,
42
42
  llm_span_kind,
43
43
  llm_tools,
44
+ prompt_metadata,
44
45
  )
46
+ from phoenix.server.api.helpers.prompts.models import PromptTemplateFormat
45
47
  from phoenix.server.api.input_types.ChatCompletionInput import (
46
48
  ChatCompletionInput,
47
49
  ChatCompletionOverDatasetInput,
48
50
  )
49
- from phoenix.server.api.input_types.TemplateOptions import TemplateOptions
51
+ from phoenix.server.api.input_types.PromptTemplateOptions import PromptTemplateOptions
50
52
  from phoenix.server.api.subscriptions import (
51
- _default_playground_experiment_description,
52
53
  _default_playground_experiment_name,
53
54
  )
54
55
  from phoenix.server.api.types.ChatCompletionMessageRole import ChatCompletionMessageRole
@@ -60,7 +61,6 @@ from phoenix.server.api.types.Dataset import Dataset
60
61
  from phoenix.server.api.types.DatasetVersion import DatasetVersion
61
62
  from phoenix.server.api.types.node import from_global_id_with_expected_type
62
63
  from phoenix.server.api.types.Span import Span, to_gql_span
63
- from phoenix.server.api.types.TemplateLanguage import TemplateLanguage
64
64
  from phoenix.server.dml_event import SpanInsertEvent
65
65
  from phoenix.trace.attributes import unflatten
66
66
  from phoenix.trace.schemas import SpanException
@@ -178,9 +178,9 @@ class ChatCompletionMutationMixin:
178
178
  experiment = models.Experiment(
179
179
  dataset_id=from_global_id_with_expected_type(input.dataset_id, Dataset.__name__),
180
180
  dataset_version_id=resolved_version_id,
181
- name=input.experiment_name or _default_playground_experiment_name(),
182
- description=input.experiment_description
183
- or _default_playground_experiment_description(dataset_name=dataset.name),
181
+ name=input.experiment_name
182
+ or _default_playground_experiment_name(input.prompt_name),
183
+ description=input.experiment_description,
184
184
  repetitions=1,
185
185
  metadata_=input.experiment_metadata or dict(),
186
186
  project_name=PLAYGROUND_PROJECT_NAME,
@@ -203,10 +203,11 @@ class ChatCompletionMutationMixin:
203
203
  messages=input.messages,
204
204
  tools=input.tools,
205
205
  invocation_parameters=input.invocation_parameters,
206
- template=TemplateOptions(
207
- language=input.template_language,
206
+ template=PromptTemplateOptions(
207
+ format=input.template_format,
208
208
  variables=revision.input,
209
209
  ),
210
+ prompt_name=input.prompt_name,
210
211
  ),
211
212
  )
212
213
  for revision in batch
@@ -300,6 +301,7 @@ class ChatCompletionMutationMixin:
300
301
  input: ChatCompletionInput,
301
302
  ) -> ChatCompletionMutationPayload:
302
303
  attributes: dict[str, Any] = {}
304
+ attributes.update(dict(prompt_metadata(input.prompt_name)))
303
305
 
304
306
  messages = [
305
307
  (
@@ -453,12 +455,12 @@ class ChatCompletionMutationMixin:
453
455
 
454
456
  def _formatted_messages(
455
457
  messages: Iterable[ChatCompletionMessage],
456
- template_options: TemplateOptions,
458
+ template_options: PromptTemplateOptions,
457
459
  ) -> Iterator[ChatCompletionMessage]:
458
460
  """
459
461
  Formats the messages using the given template options.
460
462
  """
461
- template_formatter = _template_formatter(template_language=template_options.language)
463
+ template_formatter = _template_formatter(template_format=template_options.format)
462
464
  (
463
465
  roles,
464
466
  templates,
@@ -473,17 +475,17 @@ def _formatted_messages(
473
475
  return formatted_messages
474
476
 
475
477
 
476
- def _template_formatter(template_language: TemplateLanguage) -> TemplateFormatter:
478
+ def _template_formatter(template_format: PromptTemplateFormat) -> TemplateFormatter:
477
479
  """
478
- Instantiates the appropriate template formatter for the template language.
480
+ Instantiates the appropriate template formatter for the template format.
479
481
  """
480
- if template_language is TemplateLanguage.MUSTACHE:
482
+ if template_format is PromptTemplateFormat.MUSTACHE:
481
483
  return MustacheTemplateFormatter()
482
- if template_language is TemplateLanguage.F_STRING:
484
+ if template_format is PromptTemplateFormat.F_STRING:
483
485
  return FStringTemplateFormatter()
484
- if template_language is TemplateLanguage.NONE:
486
+ if template_format is PromptTemplateFormat.NONE:
485
487
  return NoOpFormatter()
486
- assert_never(template_language)
488
+ assert_never(template_format)
487
489
 
488
490
 
489
491
  def _output_value_and_mime_type(
@@ -0,0 +1,191 @@
1
+ # file: prompt_label_mutations.py
2
+
3
+ from typing import Optional
4
+
5
+ import strawberry
6
+ from sqlalchemy import delete
7
+ from sqlalchemy.exc import IntegrityError as PostgreSQLIntegrityError
8
+ from sqlean.dbapi2 import IntegrityError as SQLiteIntegrityError # type: ignore[import-untyped]
9
+ from strawberry.relay import GlobalID
10
+ from strawberry.types import Info
11
+
12
+ from phoenix.db import models
13
+ from phoenix.db.types.identifier import Identifier as IdentifierModel
14
+ from phoenix.server.api.context import Context
15
+ from phoenix.server.api.exceptions import Conflict, NotFound
16
+ from phoenix.server.api.queries import Query
17
+ from phoenix.server.api.types.Identifier import Identifier
18
+ from phoenix.server.api.types.node import from_global_id_with_expected_type
19
+ from phoenix.server.api.types.Prompt import Prompt
20
+ from phoenix.server.api.types.PromptLabel import PromptLabel, to_gql_prompt_label
21
+
22
+
23
@strawberry.input
class CreatePromptLabelInput:
    """Input for ``PromptLabelMutationMixin.create_prompt_label``."""

    # Label name; validated with the identifier model when the mutation runs.
    name: Identifier
    # Optional free-text description.
    description: Optional[str] = None
27
+
28
+
29
@strawberry.input
class PatchPromptLabelInput:
    """Input for ``PromptLabelMutationMixin.patch_prompt_label``."""

    # Global ID of the label to modify.
    prompt_label_id: GlobalID
    # New name; left unchanged when not provided.
    name: Optional[Identifier] = None
    # New description; left unchanged when not provided.
    description: Optional[str] = None
34
+
35
+
36
@strawberry.input
class DeletePromptLabelInput:
    """Input for ``PromptLabelMutationMixin.delete_prompt_label``."""

    # Global ID of the label to delete.
    prompt_label_id: GlobalID
39
+
40
+
41
@strawberry.input
class SetPromptLabelInput:
    """Input for ``PromptLabelMutationMixin.set_prompt_label`` (attach label to prompt)."""

    # Global ID of the prompt receiving the label.
    prompt_id: GlobalID
    # Global ID of the label to attach.
    prompt_label_id: GlobalID
45
+
46
+
47
@strawberry.input
class UnsetPromptLabelInput:
    """Input for ``PromptLabelMutationMixin.unset_prompt_label`` (detach label from prompt)."""

    # Global ID of the prompt losing the label.
    prompt_id: GlobalID
    # Global ID of the label to detach.
    prompt_label_id: GlobalID
51
+
52
+
53
@strawberry.type
class PromptLabelMutationPayload:
    """Result type returned by the prompt-label mutations."""

    # The affected label, or None when there is none to return (e.g. after deletion).
    prompt_label: Optional["PromptLabel"]
    # Root ``Query`` object exposed alongside the mutation result.
    query: "Query"
57
+
58
+
59
@strawberry.type
class PromptLabelMutationMixin:
    """
    GraphQL mutations that create, patch and delete prompt labels, and attach
    them to / detach them from prompts.

    Each mutation opens its own session via ``info.context.db()`` and returns a
    ``PromptLabelMutationPayload``.
    """

    @strawberry.mutation
    async def create_prompt_label(
        self, info: Info[Context, None], input: CreatePromptLabelInput
    ) -> PromptLabelMutationPayload:
        """
        Create a new prompt label.

        Raises:
            Conflict: if committing the label violates an integrity constraint
                (e.g. a label with the same name already exists).
        """
        async with info.context.db() as session:
            name = IdentifierModel.model_validate(str(input.name))
            # NOTE(review): the validated model object is stored here, whereas
            # ``patch_prompt_label`` assigns ``validated_name.root`` (a plain str).
            # Only one of these can match the ``PromptLabel.name`` column type —
            # confirm against ``models.PromptLabel``.
            label_orm = models.PromptLabel(name=name, description=input.description)
            session.add(label_orm)

            try:
                await session.commit()
            except (PostgreSQLIntegrityError, SQLiteIntegrityError) as e:
                # Interpolate ``name.root``: formatting the pydantic model itself
                # renders as "root='...'" instead of the bare label name.
                raise Conflict(f"A prompt label named '{name.root}' already exists.") from e

            return PromptLabelMutationPayload(
                prompt_label=to_gql_prompt_label(label_orm),
                query=Query(),
            )

    @strawberry.mutation
    async def patch_prompt_label(
        self, info: Info[Context, None], input: PatchPromptLabelInput
    ) -> PromptLabelMutationPayload:
        """
        Update the name and/or description of an existing prompt label.

        Fields passed as ``None`` are left unchanged.

        Raises:
            NotFound: if no label exists with the given ID.
            Conflict: if the commit violates an integrity constraint (e.g. the
                new name is already taken).
        """
        # Test against None (not truthiness) so an empty name goes through
        # validation instead of being silently ignored.
        validated_name = (
            IdentifierModel.model_validate(str(input.name)) if input.name is not None else None
        )
        async with info.context.db() as session:
            label_id = from_global_id_with_expected_type(
                input.prompt_label_id, PromptLabel.__name__
            )

            label_orm = await session.get(models.PromptLabel, label_id)
            if not label_orm:
                raise NotFound(f"PromptLabel with ID {input.prompt_label_id} not found")

            if validated_name is not None:
                label_orm.name = validated_name.root
            if input.description is not None:
                label_orm.description = input.description

            try:
                await session.commit()
            except (PostgreSQLIntegrityError, SQLiteIntegrityError) as e:
                raise Conflict("Error patching PromptLabel. Possibly a name conflict?") from e

            return PromptLabelMutationPayload(
                prompt_label=to_gql_prompt_label(label_orm),
                query=Query(),
            )

    @strawberry.mutation
    async def delete_prompt_label(
        self, info: Info[Context, None], input: DeletePromptLabelInput
    ) -> PromptLabelMutationPayload:
        """
        Deletes a PromptLabel (and any crosswalk references).

        Raises:
            NotFound: if the DELETE matched no row.
        """
        # NOTE(review): removal of crosswalk rows presumably relies on an
        # ON DELETE CASCADE in the schema — not visible here; confirm.
        async with info.context.db() as session:
            label_id = from_global_id_with_expected_type(
                input.prompt_label_id, PromptLabel.__name__
            )
            stmt = delete(models.PromptLabel).where(models.PromptLabel.id == label_id)
            result = await session.execute(stmt)

            if result.rowcount == 0:
                raise NotFound(f"PromptLabel with ID {input.prompt_label_id} not found")

            await session.commit()

            return PromptLabelMutationPayload(
                prompt_label=None,
                query=Query(),
            )

    @strawberry.mutation
    async def set_prompt_label(
        self, info: Info[Context, None], input: SetPromptLabelInput
    ) -> PromptLabelMutationPayload:
        """
        Attach a label to a prompt by inserting a crosswalk row.

        Raises:
            Conflict: if the insert violates an integrity constraint.
            NotFound: if the label cannot be loaded after the commit.
        """
        async with info.context.db() as session:
            prompt_id = from_global_id_with_expected_type(input.prompt_id, Prompt.__name__)
            label_id = from_global_id_with_expected_type(
                input.prompt_label_id, PromptLabel.__name__
            )

            crosswalk = models.PromptPromptLabel(prompt_id=prompt_id, prompt_label_id=label_id)
            session.add(crosswalk)

            try:
                await session.commit()
            except (PostgreSQLIntegrityError, SQLiteIntegrityError) as e:
                # The error could be:
                # - Unique constraint violation => row already exists
                # - Foreign key violation => prompt_id or label_id doesn't exist
                raise Conflict("Failed to associate PromptLabel with Prompt.") from e

            label_orm = await session.get(models.PromptLabel, label_id)
            if not label_orm:
                raise NotFound(f"PromptLabel with ID {input.prompt_label_id} not found")

            return PromptLabelMutationPayload(
                prompt_label=to_gql_prompt_label(label_orm),
                query=Query(),
            )

    @strawberry.mutation
    async def unset_prompt_label(
        self, info: Info[Context, None], input: UnsetPromptLabelInput
    ) -> PromptLabelMutationPayload:
        """
        Unsets a PromptLabel from a Prompt by removing the row in the crosswalk.

        Raises:
            NotFound: if no crosswalk row links the given prompt and label.
        """
        async with info.context.db() as session:
            prompt_id = from_global_id_with_expected_type(input.prompt_id, Prompt.__name__)
            label_id = from_global_id_with_expected_type(
                input.prompt_label_id, PromptLabel.__name__
            )

            stmt = delete(models.PromptPromptLabel).where(
                (models.PromptPromptLabel.prompt_id == prompt_id)
                & (models.PromptPromptLabel.prompt_label_id == label_id)
            )
            result = await session.execute(stmt)

            if result.rowcount == 0:
                raise NotFound(f"No association between prompt={prompt_id} and label={label_id}.")

            await session.commit()

            # The label itself is kept; return it if it can still be loaded.
            label_orm = await session.get(models.PromptLabel, label_id)
            return PromptLabelMutationPayload(
                prompt_label=to_gql_prompt_label(label_orm) if label_orm else None,
                query=Query(),
            )