arize-phoenix 7.12.3__py3-none-any.whl → 8.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of arize-phoenix might be problematic. Click here for more details.
- {arize_phoenix-7.12.3.dist-info → arize_phoenix-8.0.0.dist-info}/METADATA +31 -28
- {arize_phoenix-7.12.3.dist-info → arize_phoenix-8.0.0.dist-info}/RECORD +70 -47
- phoenix/db/migrations/versions/bc8fea3c2bc8_add_prompt_tables.py +197 -0
- phoenix/db/models.py +307 -0
- phoenix/db/types/__init__.py +0 -0
- phoenix/db/types/identifier.py +7 -0
- phoenix/db/types/model_provider.py +8 -0
- phoenix/server/api/context.py +2 -0
- phoenix/server/api/dataloaders/__init__.py +2 -0
- phoenix/server/api/dataloaders/prompt_version_sequence_number.py +35 -0
- phoenix/server/api/helpers/jsonschema.py +135 -0
- phoenix/server/api/helpers/playground_clients.py +15 -15
- phoenix/server/api/helpers/playground_spans.py +9 -0
- phoenix/server/api/helpers/prompts/__init__.py +0 -0
- phoenix/server/api/helpers/prompts/conversions/__init__.py +0 -0
- phoenix/server/api/helpers/prompts/conversions/anthropic.py +87 -0
- phoenix/server/api/helpers/prompts/conversions/openai.py +78 -0
- phoenix/server/api/helpers/prompts/models.py +575 -0
- phoenix/server/api/input_types/ChatCompletionInput.py +9 -4
- phoenix/server/api/input_types/PromptTemplateOptions.py +10 -0
- phoenix/server/api/input_types/PromptVersionInput.py +133 -0
- phoenix/server/api/mutations/__init__.py +6 -0
- phoenix/server/api/mutations/chat_mutations.py +18 -16
- phoenix/server/api/mutations/prompt_label_mutations.py +191 -0
- phoenix/server/api/mutations/prompt_mutations.py +312 -0
- phoenix/server/api/mutations/prompt_version_tag_mutations.py +148 -0
- phoenix/server/api/mutations/user_mutations.py +7 -6
- phoenix/server/api/openapi/schema.py +1 -0
- phoenix/server/api/queries.py +84 -31
- phoenix/server/api/routers/oauth2.py +3 -2
- phoenix/server/api/routers/v1/__init__.py +2 -0
- phoenix/server/api/routers/v1/datasets.py +1 -1
- phoenix/server/api/routers/v1/experiment_evaluations.py +1 -1
- phoenix/server/api/routers/v1/experiment_runs.py +1 -1
- phoenix/server/api/routers/v1/experiments.py +1 -1
- phoenix/server/api/routers/v1/models.py +45 -0
- phoenix/server/api/routers/v1/prompts.py +412 -0
- phoenix/server/api/routers/v1/spans.py +1 -1
- phoenix/server/api/routers/v1/traces.py +1 -1
- phoenix/server/api/routers/v1/utils.py +1 -1
- phoenix/server/api/subscriptions.py +21 -24
- phoenix/server/api/types/GenerativeProvider.py +4 -4
- phoenix/server/api/types/Identifier.py +15 -0
- phoenix/server/api/types/Project.py +5 -7
- phoenix/server/api/types/Prompt.py +134 -0
- phoenix/server/api/types/PromptLabel.py +41 -0
- phoenix/server/api/types/PromptVersion.py +148 -0
- phoenix/server/api/types/PromptVersionTag.py +27 -0
- phoenix/server/api/types/PromptVersionTemplate.py +148 -0
- phoenix/server/api/types/ResponseFormat.py +9 -0
- phoenix/server/api/types/ToolDefinition.py +9 -0
- phoenix/server/app.py +3 -0
- phoenix/server/static/.vite/manifest.json +45 -45
- phoenix/server/static/assets/components-B-qgPyHv.js +2699 -0
- phoenix/server/static/assets/index-D4KO1IcF.js +1125 -0
- phoenix/server/static/assets/pages-DdcuL3Rh.js +5634 -0
- phoenix/server/static/assets/vendor-DQp7CrDA.js +894 -0
- phoenix/server/static/assets/vendor-arizeai-C1nEIEQq.js +657 -0
- phoenix/server/static/assets/vendor-codemirror-BZXYUIkP.js +24 -0
- phoenix/server/static/assets/vendor-recharts-BUFpwCVD.js +59 -0
- phoenix/server/static/assets/{vendor-shiki-Cl9QBraO.js → vendor-shiki-C8L-c9jT.js} +2 -2
- phoenix/server/static/assets/{vendor-three-DwGkEfCM.js → vendor-three-C-AGeJYv.js} +1 -1
- phoenix/session/client.py +25 -21
- phoenix/utilities/client.py +6 -0
- phoenix/version.py +1 -1
- phoenix/server/api/input_types/TemplateOptions.py +0 -10
- phoenix/server/api/routers/v1/pydantic_compat.py +0 -78
- phoenix/server/api/types/TemplateLanguage.py +0 -10
- phoenix/server/static/assets/components-DckIzNmE.js +0 -2125
- phoenix/server/static/assets/index-Bf25Ogon.js +0 -113
- phoenix/server/static/assets/pages-DL7J9q9w.js +0 -4463
- phoenix/server/static/assets/vendor-DvC8cT4X.js +0 -894
- phoenix/server/static/assets/vendor-arizeai-Do1793cv.js +0 -662
- phoenix/server/static/assets/vendor-codemirror-BzwZPyJM.js +0 -24
- phoenix/server/static/assets/vendor-recharts-_Jb7JjhG.js +0 -59
- {arize_phoenix-7.12.3.dist-info → arize_phoenix-8.0.0.dist-info}/WHEEL +0 -0
- {arize_phoenix-7.12.3.dist-info → arize_phoenix-8.0.0.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-7.12.3.dist-info → arize_phoenix-8.0.0.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-7.12.3.dist-info → arize_phoenix-8.0.0.dist-info}/licenses/LICENSE +0 -0
- /phoenix/server/static/assets/{vendor-DxkFTwjz.css → vendor-Cg6lcjUC.css} +0 -0
|
@@ -0,0 +1,133 @@
|
|
|
1
|
+
import json
|
|
2
|
+
from typing import Optional, cast
|
|
3
|
+
|
|
4
|
+
import strawberry
|
|
5
|
+
from strawberry import UNSET
|
|
6
|
+
from strawberry.scalars import JSON
|
|
7
|
+
|
|
8
|
+
from phoenix.db.types.model_provider import ModelProvider
|
|
9
|
+
from phoenix.server.api.helpers.prompts.models import (
|
|
10
|
+
ContentPart,
|
|
11
|
+
PromptChatTemplate,
|
|
12
|
+
PromptMessage,
|
|
13
|
+
PromptMessageRole,
|
|
14
|
+
PromptTemplateFormat,
|
|
15
|
+
RoleConversion,
|
|
16
|
+
TextContentPart,
|
|
17
|
+
ToolCallContentPart,
|
|
18
|
+
ToolCallFunction,
|
|
19
|
+
ToolResultContentPart,
|
|
20
|
+
)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@strawberry.input
class ToolDefinitionInput:
    # GraphQL input wrapping one tool definition, carried as an arbitrary JSON value.
    definition: JSON
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@strawberry.input
class ResponseFormatInput:
    # GraphQL input wrapping a response-format definition, carried as an arbitrary JSON value.
    definition: JSON
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@strawberry.input
class TextContentValueInput:
    # Payload of a text content part: just the text itself.
    text: str
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
@strawberry.input
class ToolResultContentValueInput:
    # Payload of a tool-result content part.
    tool_call_id: str  # id of the tool call this result belongs to
    result: JSON  # the tool's result as a JSON scalar
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
@strawberry.input
class ToolCallFunctionInput:
    # Function invocation carried inside a tool-call content part.
    # Field order is kept as declared; `type` defaults to "function".
    type: Optional[str] = strawberry.field(default="function")
    name: str
    arguments: str
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
@strawberry.input
class ToolCallContentValueInput:
    # Payload of a tool-call content part: the call id plus the function being called.
    tool_call_id: str
    tool_call: ToolCallFunctionInput
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
@strawberry.input(one_of=True)
class ContentPartInput:
    # One-of input: a content part is either text, a tool call, or a tool result.
    # Every alternative defaults to UNSET so the converter can tell which one was supplied.
    text: Optional[TextContentValueInput] = UNSET
    tool_call: Optional[ToolCallContentValueInput] = UNSET
    tool_result: Optional[ToolResultContentValueInput] = UNSET
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
@strawberry.input
class PromptMessageInput:
    # A single chat message: a role name (plain string) and its ordered content parts.
    role: str
    content: list[ContentPartInput]
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
@strawberry.input
class PromptChatTemplateInput:
    # A chat template expressed as an ordered list of messages.
    messages: list[PromptMessageInput]
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
@strawberry.input
class ChatPromptVersionInput:
    # Input describing one chat prompt version. Field order is kept as declared.
    description: Optional[str] = None
    template_format: PromptTemplateFormat
    template: PromptChatTemplateInput
    # Mutable defaults are produced per instance through default_factory.
    invocation_parameters: JSON = strawberry.field(default_factory=dict)
    tools: list[ToolDefinitionInput] = strawberry.field(default_factory=list)
    response_format: Optional[ResponseFormatInput] = None
    model_provider: ModelProvider
    model_name: str
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def to_pydantic_prompt_chat_template_v1(
    prompt_chat_template_input: PromptChatTemplateInput,
) -> PromptChatTemplate:
    """
    Build a pydantic ``PromptChatTemplate`` from its GraphQL input.

    Every input message is converted with ``to_pydantic_prompt_message`` in the
    order given; the resulting template always has ``type="chat"``.
    """
    converted: list[PromptMessage] = []
    for message_input in prompt_chat_template_input.messages:
        converted.append(to_pydantic_prompt_message(message_input))
    return PromptChatTemplate(type="chat", messages=converted)
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def to_pydantic_prompt_message(prompt_message_input: PromptMessageInput) -> PromptMessage:
    """
    Build a pydantic ``PromptMessage`` from its GraphQL input.

    Content parts are converted first (in order) with ``to_pydantic_content_part``;
    the role string is then turned into a ``PromptMessageRole`` and mapped through
    ``RoleConversion.from_gql``.
    """
    # Content is converted before the role so that a failure in either step
    # surfaces in the same order as it always has.
    parts = list(map(to_pydantic_content_part, prompt_message_input.content))
    gql_role = PromptMessageRole(prompt_message_input.role)
    return PromptMessage(role=RoleConversion.from_gql(gql_role), content=parts)
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def to_pydantic_content_part(content_part_input: ContentPartInput) -> ContentPart:
    """
    Convert a one-of GraphQL content part input into a pydantic content part.

    The ``text``, ``tool_call`` and ``tool_result`` fields are examined in that
    order; the first one that is not ``UNSET`` determines the returned part.

    Raises:
        ValueError: if all three fields are ``UNSET``.
    """
    text_value = content_part_input.text
    if text_value is not UNSET:
        assert text_value is not None
        return TextContentPart(type="text", text=text_value.text)

    call_value = content_part_input.tool_call
    if call_value is not UNSET:
        assert call_value is not None
        function_input = call_value.tool_call
        # The function type is fixed to "function"; the input's own `type` field is not read.
        function = ToolCallFunction(
            type="function",
            name=function_input.name,
            arguments=function_input.arguments,
        )
        return ToolCallContentPart(
            type="tool_call",
            tool_call_id=call_value.tool_call_id,
            tool_call=function,
        )

    result_value = content_part_input.tool_result
    if result_value is not UNSET:
        assert result_value is not None
        # The JSON scalar is decoded with json.loads, i.e. it is treated as a JSON-encoded
        # string here. NOTE(review): presumably clients send the result stringified - confirm.
        decoded = json.loads(cast(str, result_value.result))
        return ToolResultContentPart(
            type="tool_result",
            tool_call_id=result_value.tool_call_id,
            tool_result=decoded,
        )

    raise ValueError("content part input has no content")
|
|
@@ -8,6 +8,9 @@ from phoenix.server.api.mutations.dataset_mutations import DatasetMutationMixin
|
|
|
8
8
|
from phoenix.server.api.mutations.experiment_mutations import ExperimentMutationMixin
|
|
9
9
|
from phoenix.server.api.mutations.export_events_mutations import ExportEventsMutationMixin
|
|
10
10
|
from phoenix.server.api.mutations.project_mutations import ProjectMutationMixin
|
|
11
|
+
from phoenix.server.api.mutations.prompt_label_mutations import PromptLabelMutationMixin
|
|
12
|
+
from phoenix.server.api.mutations.prompt_mutations import PromptMutationMixin
|
|
13
|
+
from phoenix.server.api.mutations.prompt_version_tag_mutations import PromptVersionTagMutationMixin
|
|
11
14
|
from phoenix.server.api.mutations.span_annotations_mutations import SpanAnnotationMutationMixin
|
|
12
15
|
from phoenix.server.api.mutations.trace_annotations_mutations import TraceAnnotationMutationMixin
|
|
13
16
|
from phoenix.server.api.mutations.user_mutations import UserMutationMixin
|
|
@@ -20,6 +23,9 @@ class Mutation(
|
|
|
20
23
|
ExperimentMutationMixin,
|
|
21
24
|
ExportEventsMutationMixin,
|
|
22
25
|
ProjectMutationMixin,
|
|
26
|
+
PromptMutationMixin,
|
|
27
|
+
PromptVersionTagMutationMixin,
|
|
28
|
+
PromptLabelMutationMixin,
|
|
23
29
|
SpanAnnotationMutationMixin,
|
|
24
30
|
TraceAnnotationMutationMixin,
|
|
25
31
|
UserMutationMixin,
|
|
@@ -41,14 +41,15 @@ from phoenix.server.api.helpers.playground_spans import (
|
|
|
41
41
|
llm_model_name,
|
|
42
42
|
llm_span_kind,
|
|
43
43
|
llm_tools,
|
|
44
|
+
prompt_metadata,
|
|
44
45
|
)
|
|
46
|
+
from phoenix.server.api.helpers.prompts.models import PromptTemplateFormat
|
|
45
47
|
from phoenix.server.api.input_types.ChatCompletionInput import (
|
|
46
48
|
ChatCompletionInput,
|
|
47
49
|
ChatCompletionOverDatasetInput,
|
|
48
50
|
)
|
|
49
|
-
from phoenix.server.api.input_types.
|
|
51
|
+
from phoenix.server.api.input_types.PromptTemplateOptions import PromptTemplateOptions
|
|
50
52
|
from phoenix.server.api.subscriptions import (
|
|
51
|
-
_default_playground_experiment_description,
|
|
52
53
|
_default_playground_experiment_name,
|
|
53
54
|
)
|
|
54
55
|
from phoenix.server.api.types.ChatCompletionMessageRole import ChatCompletionMessageRole
|
|
@@ -60,7 +61,6 @@ from phoenix.server.api.types.Dataset import Dataset
|
|
|
60
61
|
from phoenix.server.api.types.DatasetVersion import DatasetVersion
|
|
61
62
|
from phoenix.server.api.types.node import from_global_id_with_expected_type
|
|
62
63
|
from phoenix.server.api.types.Span import Span, to_gql_span
|
|
63
|
-
from phoenix.server.api.types.TemplateLanguage import TemplateLanguage
|
|
64
64
|
from phoenix.server.dml_event import SpanInsertEvent
|
|
65
65
|
from phoenix.trace.attributes import unflatten
|
|
66
66
|
from phoenix.trace.schemas import SpanException
|
|
@@ -178,9 +178,9 @@ class ChatCompletionMutationMixin:
|
|
|
178
178
|
experiment = models.Experiment(
|
|
179
179
|
dataset_id=from_global_id_with_expected_type(input.dataset_id, Dataset.__name__),
|
|
180
180
|
dataset_version_id=resolved_version_id,
|
|
181
|
-
name=input.experiment_name
|
|
182
|
-
|
|
183
|
-
|
|
181
|
+
name=input.experiment_name
|
|
182
|
+
or _default_playground_experiment_name(input.prompt_name),
|
|
183
|
+
description=input.experiment_description,
|
|
184
184
|
repetitions=1,
|
|
185
185
|
metadata_=input.experiment_metadata or dict(),
|
|
186
186
|
project_name=PLAYGROUND_PROJECT_NAME,
|
|
@@ -203,10 +203,11 @@ class ChatCompletionMutationMixin:
|
|
|
203
203
|
messages=input.messages,
|
|
204
204
|
tools=input.tools,
|
|
205
205
|
invocation_parameters=input.invocation_parameters,
|
|
206
|
-
template=
|
|
207
|
-
|
|
206
|
+
template=PromptTemplateOptions(
|
|
207
|
+
format=input.template_format,
|
|
208
208
|
variables=revision.input,
|
|
209
209
|
),
|
|
210
|
+
prompt_name=input.prompt_name,
|
|
210
211
|
),
|
|
211
212
|
)
|
|
212
213
|
for revision in batch
|
|
@@ -300,6 +301,7 @@ class ChatCompletionMutationMixin:
|
|
|
300
301
|
input: ChatCompletionInput,
|
|
301
302
|
) -> ChatCompletionMutationPayload:
|
|
302
303
|
attributes: dict[str, Any] = {}
|
|
304
|
+
attributes.update(dict(prompt_metadata(input.prompt_name)))
|
|
303
305
|
|
|
304
306
|
messages = [
|
|
305
307
|
(
|
|
@@ -453,12 +455,12 @@ class ChatCompletionMutationMixin:
|
|
|
453
455
|
|
|
454
456
|
def _formatted_messages(
|
|
455
457
|
messages: Iterable[ChatCompletionMessage],
|
|
456
|
-
template_options:
|
|
458
|
+
template_options: PromptTemplateOptions,
|
|
457
459
|
) -> Iterator[ChatCompletionMessage]:
|
|
458
460
|
"""
|
|
459
461
|
Formats the messages using the given template options.
|
|
460
462
|
"""
|
|
461
|
-
template_formatter = _template_formatter(
|
|
463
|
+
template_formatter = _template_formatter(template_format=template_options.format)
|
|
462
464
|
(
|
|
463
465
|
roles,
|
|
464
466
|
templates,
|
|
@@ -473,17 +475,17 @@ def _formatted_messages(
|
|
|
473
475
|
return formatted_messages
|
|
474
476
|
|
|
475
477
|
|
|
476
|
-
def _template_formatter(
|
|
478
|
+
def _template_formatter(template_format: PromptTemplateFormat) -> TemplateFormatter:
|
|
477
479
|
"""
|
|
478
|
-
Instantiates the appropriate template formatter for the template
|
|
480
|
+
Instantiates the appropriate template formatter for the template format.
|
|
479
481
|
"""
|
|
480
|
-
if
|
|
482
|
+
if template_format is PromptTemplateFormat.MUSTACHE:
|
|
481
483
|
return MustacheTemplateFormatter()
|
|
482
|
-
if
|
|
484
|
+
if template_format is PromptTemplateFormat.F_STRING:
|
|
483
485
|
return FStringTemplateFormatter()
|
|
484
|
-
if
|
|
486
|
+
if template_format is PromptTemplateFormat.NONE:
|
|
485
487
|
return NoOpFormatter()
|
|
486
|
-
assert_never(
|
|
488
|
+
assert_never(template_format)
|
|
487
489
|
|
|
488
490
|
|
|
489
491
|
def _output_value_and_mime_type(
|
|
@@ -0,0 +1,191 @@
|
|
|
1
|
+
# file: prompt_label_mutations.py
|
|
2
|
+
|
|
3
|
+
from typing import Optional
|
|
4
|
+
|
|
5
|
+
import strawberry
|
|
6
|
+
from sqlalchemy import delete
|
|
7
|
+
from sqlalchemy.exc import IntegrityError as PostgreSQLIntegrityError
|
|
8
|
+
from sqlean.dbapi2 import IntegrityError as SQLiteIntegrityError # type: ignore[import-untyped]
|
|
9
|
+
from strawberry.relay import GlobalID
|
|
10
|
+
from strawberry.types import Info
|
|
11
|
+
|
|
12
|
+
from phoenix.db import models
|
|
13
|
+
from phoenix.db.types.identifier import Identifier as IdentifierModel
|
|
14
|
+
from phoenix.server.api.context import Context
|
|
15
|
+
from phoenix.server.api.exceptions import Conflict, NotFound
|
|
16
|
+
from phoenix.server.api.queries import Query
|
|
17
|
+
from phoenix.server.api.types.Identifier import Identifier
|
|
18
|
+
from phoenix.server.api.types.node import from_global_id_with_expected_type
|
|
19
|
+
from phoenix.server.api.types.Prompt import Prompt
|
|
20
|
+
from phoenix.server.api.types.PromptLabel import PromptLabel, to_gql_prompt_label
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@strawberry.input
class CreatePromptLabelInput:
    # Input for creating a prompt label: an identifier name and an optional description.
    name: Identifier
    description: Optional[str] = None
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@strawberry.input
class PatchPromptLabelInput:
    # Input for patching an existing prompt label, addressed by its global id.
    # `name` and `description` are both optional and default to None.
    prompt_label_id: GlobalID
    name: Optional[Identifier] = None
    description: Optional[str] = None
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@strawberry.input
class DeletePromptLabelInput:
    # Input for deleting a prompt label, addressed by its global id.
    prompt_label_id: GlobalID
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
@strawberry.input
class SetPromptLabelInput:
    # Input naming the prompt and the prompt label that should be associated.
    prompt_id: GlobalID
    prompt_label_id: GlobalID
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
@strawberry.input
class UnsetPromptLabelInput:
    # Input naming the prompt and the prompt label whose association should be removed.
    prompt_id: GlobalID
    prompt_label_id: GlobalID
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
@strawberry.type
class PromptLabelMutationPayload:
    # Payload type for prompt-label mutations: the affected label (None when there
    # is no label to return) plus a Query root.
    prompt_label: Optional["PromptLabel"]
    query: "Query"
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
@strawberry.type
class PromptLabelMutationMixin:
    # GraphQL mutations for creating, patching and deleting prompt labels, and for
    # attaching/detaching labels to prompts through the PromptPromptLabel crosswalk.

    @strawberry.mutation
    async def create_prompt_label(
        self, info: Info[Context, None], input: CreatePromptLabelInput
    ) -> PromptLabelMutationPayload:
        """
        Creates a PromptLabel from a validated identifier name and optional description.

        Raises Conflict when the commit fails with an integrity error.
        """
        async with info.context.db() as session:
            name = IdentifierModel.model_validate(str(input.name))
            # NOTE(review): the validated model is passed here, while patch_prompt_label
            # assigns `validated_name.root` - confirm which form the column expects.
            label_orm = models.PromptLabel(name=name, description=input.description)
            session.add(label_orm)

            try:
                await session.commit()
            except (PostgreSQLIntegrityError, SQLiteIntegrityError) as e:
                # Interpolate the identifier string itself (`.root`) rather than the model
                # object, and chain the driver error so the cause stays visible.
                raise Conflict(f"A prompt label named '{name.root}' already exists.") from e

            return PromptLabelMutationPayload(
                prompt_label=to_gql_prompt_label(label_orm),
                query=Query(),
            )

    @strawberry.mutation
    async def patch_prompt_label(
        self, info: Info[Context, None], input: PatchPromptLabelInput
    ) -> PromptLabelMutationPayload:
        """
        Updates the name and/or description of an existing PromptLabel.

        A name is applied only when one is supplied (and non-empty); a description is
        applied only when it is not None. Raises NotFound when the label does not exist
        and Conflict when the commit fails with an integrity error.
        """
        # Validate before opening a session so an invalid name never touches the database.
        validated_name = IdentifierModel.model_validate(str(input.name)) if input.name else None
        async with info.context.db() as session:
            label_id = from_global_id_with_expected_type(
                input.prompt_label_id, PromptLabel.__name__
            )

            label_orm = await session.get(models.PromptLabel, label_id)
            if not label_orm:
                raise NotFound(f"PromptLabel with ID {input.prompt_label_id} not found")

            if validated_name is not None:
                label_orm.name = validated_name.root
            if input.description is not None:
                label_orm.description = input.description

            try:
                await session.commit()
            except (PostgreSQLIntegrityError, SQLiteIntegrityError) as e:
                # Chain the driver error, matching set_prompt_label.
                raise Conflict("Error patching PromptLabel. Possibly a name conflict?") from e

            return PromptLabelMutationPayload(
                prompt_label=to_gql_prompt_label(label_orm),
                query=Query(),
            )

    @strawberry.mutation
    async def delete_prompt_label(
        self, info: Info[Context, None], input: DeletePromptLabelInput
    ) -> PromptLabelMutationPayload:
        """
        Deletes a PromptLabel (and any crosswalk references).
        """
        async with info.context.db() as session:
            label_id = from_global_id_with_expected_type(
                input.prompt_label_id, PromptLabel.__name__
            )
            stmt = delete(models.PromptLabel).where(models.PromptLabel.id == label_id)
            result = await session.execute(stmt)

            # No row deleted means the id did not match an existing label.
            if result.rowcount == 0:
                raise NotFound(f"PromptLabel with ID {input.prompt_label_id} not found")

            await session.commit()

            return PromptLabelMutationPayload(
                prompt_label=None,
                query=Query(),
            )

    @strawberry.mutation
    async def set_prompt_label(
        self, info: Info[Context, None], input: SetPromptLabelInput
    ) -> PromptLabelMutationPayload:
        """
        Associates a PromptLabel with a Prompt by inserting a row in the crosswalk.

        Raises Conflict when the commit fails with an integrity error, and NotFound
        when the label cannot be loaded after the commit.
        """
        async with info.context.db() as session:
            prompt_id = from_global_id_with_expected_type(input.prompt_id, Prompt.__name__)
            label_id = from_global_id_with_expected_type(
                input.prompt_label_id, PromptLabel.__name__
            )

            crosswalk = models.PromptPromptLabel(prompt_id=prompt_id, prompt_label_id=label_id)
            session.add(crosswalk)

            try:
                await session.commit()
            except (PostgreSQLIntegrityError, SQLiteIntegrityError) as e:
                # The error could be:
                # - Unique constraint violation => row already exists
                # - Foreign key violation => prompt_id or label_id doesn't exist
                raise Conflict("Failed to associate PromptLabel with Prompt.") from e

            label_orm = await session.get(models.PromptLabel, label_id)
            if not label_orm:
                raise NotFound(f"PromptLabel with ID {input.prompt_label_id} not found")

            return PromptLabelMutationPayload(
                prompt_label=to_gql_prompt_label(label_orm),
                query=Query(),
            )

    @strawberry.mutation
    async def unset_prompt_label(
        self, info: Info[Context, None], input: UnsetPromptLabelInput
    ) -> PromptLabelMutationPayload:
        """
        Unsets a PromptLabel from a Prompt by removing the row in the crosswalk.
        """
        async with info.context.db() as session:
            prompt_id = from_global_id_with_expected_type(input.prompt_id, Prompt.__name__)
            label_id = from_global_id_with_expected_type(
                input.prompt_label_id, PromptLabel.__name__
            )

            stmt = delete(models.PromptPromptLabel).where(
                (models.PromptPromptLabel.prompt_id == prompt_id)
                & (models.PromptPromptLabel.prompt_label_id == label_id)
            )
            result = await session.execute(stmt)

            # No row deleted means the prompt and label were not associated.
            if result.rowcount == 0:
                raise NotFound(f"No association between prompt={prompt_id} and label={label_id}.")

            await session.commit()

            # The label itself is returned when it can still be loaded.
            label_orm = await session.get(models.PromptLabel, label_id)
            return PromptLabelMutationPayload(
                prompt_label=to_gql_prompt_label(label_orm) if label_orm else None,
                query=Query(),
            )
|