arize-phoenix 10.0.4__py3-none-any.whl → 12.28.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/METADATA +124 -72
- arize_phoenix-12.28.1.dist-info/RECORD +499 -0
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/WHEEL +1 -1
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/IP_NOTICE +1 -1
- phoenix/__generated__/__init__.py +0 -0
- phoenix/__generated__/classification_evaluator_configs/__init__.py +20 -0
- phoenix/__generated__/classification_evaluator_configs/_document_relevance_classification_evaluator_config.py +17 -0
- phoenix/__generated__/classification_evaluator_configs/_hallucination_classification_evaluator_config.py +17 -0
- phoenix/__generated__/classification_evaluator_configs/_models.py +18 -0
- phoenix/__generated__/classification_evaluator_configs/_tool_selection_classification_evaluator_config.py +17 -0
- phoenix/__init__.py +5 -4
- phoenix/auth.py +39 -2
- phoenix/config.py +1763 -91
- phoenix/datetime_utils.py +120 -2
- phoenix/db/README.md +595 -25
- phoenix/db/bulk_inserter.py +145 -103
- phoenix/db/engines.py +140 -33
- phoenix/db/enums.py +3 -12
- phoenix/db/facilitator.py +302 -35
- phoenix/db/helpers.py +1000 -65
- phoenix/db/iam_auth.py +64 -0
- phoenix/db/insertion/dataset.py +135 -2
- phoenix/db/insertion/document_annotation.py +9 -6
- phoenix/db/insertion/evaluation.py +2 -3
- phoenix/db/insertion/helpers.py +17 -2
- phoenix/db/insertion/session_annotation.py +176 -0
- phoenix/db/insertion/span.py +15 -11
- phoenix/db/insertion/span_annotation.py +3 -4
- phoenix/db/insertion/trace_annotation.py +3 -4
- phoenix/db/insertion/types.py +50 -20
- phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
- phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
- phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
- phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
- phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
- phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
- phoenix/db/migrations/versions/a20694b15f82_cost.py +196 -0
- phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
- phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
- phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
- phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
- phoenix/db/models.py +669 -56
- phoenix/db/pg_config.py +10 -0
- phoenix/db/types/model_provider.py +4 -0
- phoenix/db/types/token_price_customization.py +29 -0
- phoenix/db/types/trace_retention.py +23 -15
- phoenix/experiments/evaluators/utils.py +3 -3
- phoenix/experiments/functions.py +160 -52
- phoenix/experiments/tracing.py +2 -2
- phoenix/experiments/types.py +1 -1
- phoenix/inferences/inferences.py +1 -2
- phoenix/server/api/auth.py +38 -7
- phoenix/server/api/auth_messages.py +46 -0
- phoenix/server/api/context.py +100 -4
- phoenix/server/api/dataloaders/__init__.py +79 -5
- phoenix/server/api/dataloaders/annotation_configs_by_project.py +31 -0
- phoenix/server/api/dataloaders/annotation_summaries.py +60 -8
- phoenix/server/api/dataloaders/average_experiment_repeated_run_group_latency.py +50 -0
- phoenix/server/api/dataloaders/average_experiment_run_latency.py +17 -24
- phoenix/server/api/dataloaders/cache/two_tier_cache.py +1 -2
- phoenix/server/api/dataloaders/dataset_dataset_splits.py +52 -0
- phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -1
- phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
- phoenix/server/api/dataloaders/dataset_examples_and_versions_by_experiment_run.py +47 -0
- phoenix/server/api/dataloaders/dataset_labels.py +36 -0
- phoenix/server/api/dataloaders/document_evaluation_summaries.py +2 -2
- phoenix/server/api/dataloaders/document_evaluations.py +6 -9
- phoenix/server/api/dataloaders/experiment_annotation_summaries.py +88 -34
- phoenix/server/api/dataloaders/experiment_dataset_splits.py +43 -0
- phoenix/server/api/dataloaders/experiment_error_rates.py +21 -28
- phoenix/server/api/dataloaders/experiment_repeated_run_group_annotation_summaries.py +77 -0
- phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +57 -0
- phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
- phoenix/server/api/dataloaders/last_used_times_by_generative_model_id.py +35 -0
- phoenix/server/api/dataloaders/latency_ms_quantile.py +40 -8
- phoenix/server/api/dataloaders/record_counts.py +37 -10
- phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
- phoenix/server/api/dataloaders/span_cost_by_span.py +24 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_generative_model.py +56 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_project_session.py +57 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_span.py +43 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_trace.py +56 -0
- phoenix/server/api/dataloaders/span_cost_details_by_span_cost.py +27 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment.py +57 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment_repeated_run_group.py +64 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment_run.py +58 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_generative_model.py +55 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_project.py +152 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_project_session.py +56 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_trace.py +55 -0
- phoenix/server/api/dataloaders/span_costs.py +29 -0
- phoenix/server/api/dataloaders/table_fields.py +2 -2
- phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
- phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
- phoenix/server/api/dataloaders/types.py +29 -0
- phoenix/server/api/exceptions.py +11 -1
- phoenix/server/api/helpers/dataset_helpers.py +5 -1
- phoenix/server/api/helpers/playground_clients.py +1243 -292
- phoenix/server/api/helpers/playground_registry.py +2 -2
- phoenix/server/api/helpers/playground_spans.py +8 -4
- phoenix/server/api/helpers/playground_users.py +26 -0
- phoenix/server/api/helpers/prompts/conversions/aws.py +83 -0
- phoenix/server/api/helpers/prompts/conversions/google.py +103 -0
- phoenix/server/api/helpers/prompts/models.py +205 -22
- phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
- phoenix/server/api/input_types/ChatCompletionInput.py +6 -2
- phoenix/server/api/input_types/CreateProjectInput.py +27 -0
- phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
- phoenix/server/api/input_types/DatasetFilter.py +17 -0
- phoenix/server/api/input_types/ExperimentRunSort.py +237 -0
- phoenix/server/api/input_types/GenerativeCredentialInput.py +9 -0
- phoenix/server/api/input_types/GenerativeModelInput.py +5 -0
- phoenix/server/api/input_types/ProjectSessionSort.py +161 -1
- phoenix/server/api/input_types/PromptFilter.py +14 -0
- phoenix/server/api/input_types/PromptVersionInput.py +52 -1
- phoenix/server/api/input_types/SpanSort.py +44 -7
- phoenix/server/api/input_types/TimeBinConfig.py +23 -0
- phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
- phoenix/server/api/input_types/UserRoleInput.py +1 -0
- phoenix/server/api/mutations/__init__.py +10 -0
- phoenix/server/api/mutations/annotation_config_mutations.py +8 -8
- phoenix/server/api/mutations/api_key_mutations.py +19 -23
- phoenix/server/api/mutations/chat_mutations.py +154 -47
- phoenix/server/api/mutations/dataset_label_mutations.py +243 -0
- phoenix/server/api/mutations/dataset_mutations.py +21 -16
- phoenix/server/api/mutations/dataset_split_mutations.py +351 -0
- phoenix/server/api/mutations/experiment_mutations.py +2 -2
- phoenix/server/api/mutations/export_events_mutations.py +3 -3
- phoenix/server/api/mutations/model_mutations.py +210 -0
- phoenix/server/api/mutations/project_mutations.py +49 -10
- phoenix/server/api/mutations/project_session_annotations_mutations.py +158 -0
- phoenix/server/api/mutations/project_trace_retention_policy_mutations.py +8 -4
- phoenix/server/api/mutations/prompt_label_mutations.py +74 -65
- phoenix/server/api/mutations/prompt_mutations.py +65 -129
- phoenix/server/api/mutations/prompt_version_tag_mutations.py +11 -8
- phoenix/server/api/mutations/span_annotations_mutations.py +15 -10
- phoenix/server/api/mutations/trace_annotations_mutations.py +14 -10
- phoenix/server/api/mutations/trace_mutations.py +47 -3
- phoenix/server/api/mutations/user_mutations.py +66 -41
- phoenix/server/api/queries.py +768 -293
- phoenix/server/api/routers/__init__.py +2 -2
- phoenix/server/api/routers/auth.py +154 -88
- phoenix/server/api/routers/ldap.py +229 -0
- phoenix/server/api/routers/oauth2.py +369 -106
- phoenix/server/api/routers/v1/__init__.py +24 -4
- phoenix/server/api/routers/v1/annotation_configs.py +23 -31
- phoenix/server/api/routers/v1/annotations.py +481 -17
- phoenix/server/api/routers/v1/datasets.py +395 -81
- phoenix/server/api/routers/v1/documents.py +142 -0
- phoenix/server/api/routers/v1/evaluations.py +24 -31
- phoenix/server/api/routers/v1/experiment_evaluations.py +19 -8
- phoenix/server/api/routers/v1/experiment_runs.py +337 -59
- phoenix/server/api/routers/v1/experiments.py +479 -48
- phoenix/server/api/routers/v1/models.py +7 -0
- phoenix/server/api/routers/v1/projects.py +18 -49
- phoenix/server/api/routers/v1/prompts.py +54 -40
- phoenix/server/api/routers/v1/sessions.py +108 -0
- phoenix/server/api/routers/v1/spans.py +1091 -81
- phoenix/server/api/routers/v1/traces.py +132 -78
- phoenix/server/api/routers/v1/users.py +389 -0
- phoenix/server/api/routers/v1/utils.py +3 -7
- phoenix/server/api/subscriptions.py +305 -88
- phoenix/server/api/types/Annotation.py +90 -23
- phoenix/server/api/types/ApiKey.py +13 -17
- phoenix/server/api/types/AuthMethod.py +1 -0
- phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +1 -0
- phoenix/server/api/types/CostBreakdown.py +12 -0
- phoenix/server/api/types/Dataset.py +226 -72
- phoenix/server/api/types/DatasetExample.py +88 -18
- phoenix/server/api/types/DatasetExperimentAnnotationSummary.py +10 -0
- phoenix/server/api/types/DatasetLabel.py +57 -0
- phoenix/server/api/types/DatasetSplit.py +98 -0
- phoenix/server/api/types/DatasetVersion.py +49 -4
- phoenix/server/api/types/DocumentAnnotation.py +212 -0
- phoenix/server/api/types/Experiment.py +264 -59
- phoenix/server/api/types/ExperimentComparison.py +5 -10
- phoenix/server/api/types/ExperimentRepeatedRunGroup.py +155 -0
- phoenix/server/api/types/ExperimentRepeatedRunGroupAnnotationSummary.py +9 -0
- phoenix/server/api/types/ExperimentRun.py +169 -65
- phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
- phoenix/server/api/types/GenerativeModel.py +245 -3
- phoenix/server/api/types/GenerativeProvider.py +70 -11
- phoenix/server/api/types/{Model.py → InferenceModel.py} +1 -1
- phoenix/server/api/types/ModelInterface.py +16 -0
- phoenix/server/api/types/PlaygroundModel.py +20 -0
- phoenix/server/api/types/Project.py +1278 -216
- phoenix/server/api/types/ProjectSession.py +188 -28
- phoenix/server/api/types/ProjectSessionAnnotation.py +187 -0
- phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
- phoenix/server/api/types/Prompt.py +119 -39
- phoenix/server/api/types/PromptLabel.py +42 -25
- phoenix/server/api/types/PromptVersion.py +11 -8
- phoenix/server/api/types/PromptVersionTag.py +65 -25
- phoenix/server/api/types/ServerStatus.py +6 -0
- phoenix/server/api/types/Span.py +167 -123
- phoenix/server/api/types/SpanAnnotation.py +189 -42
- phoenix/server/api/types/SpanCostDetailSummaryEntry.py +10 -0
- phoenix/server/api/types/SpanCostSummary.py +10 -0
- phoenix/server/api/types/SystemApiKey.py +65 -1
- phoenix/server/api/types/TokenPrice.py +16 -0
- phoenix/server/api/types/TokenUsage.py +3 -3
- phoenix/server/api/types/Trace.py +223 -51
- phoenix/server/api/types/TraceAnnotation.py +149 -50
- phoenix/server/api/types/User.py +137 -32
- phoenix/server/api/types/UserApiKey.py +73 -26
- phoenix/server/api/types/node.py +10 -0
- phoenix/server/api/types/pagination.py +11 -2
- phoenix/server/app.py +290 -45
- phoenix/server/authorization.py +38 -3
- phoenix/server/bearer_auth.py +34 -24
- phoenix/server/cost_tracking/cost_details_calculator.py +196 -0
- phoenix/server/cost_tracking/cost_model_lookup.py +179 -0
- phoenix/server/cost_tracking/helpers.py +68 -0
- phoenix/server/cost_tracking/model_cost_manifest.json +3657 -830
- phoenix/server/cost_tracking/regex_specificity.py +397 -0
- phoenix/server/cost_tracking/token_cost_calculator.py +57 -0
- phoenix/server/daemons/__init__.py +0 -0
- phoenix/server/daemons/db_disk_usage_monitor.py +214 -0
- phoenix/server/daemons/generative_model_store.py +103 -0
- phoenix/server/daemons/span_cost_calculator.py +99 -0
- phoenix/server/dml_event.py +17 -0
- phoenix/server/dml_event_handler.py +5 -0
- phoenix/server/email/sender.py +56 -3
- phoenix/server/email/templates/db_disk_usage_notification.html +19 -0
- phoenix/server/email/types.py +11 -0
- phoenix/server/experiments/__init__.py +0 -0
- phoenix/server/experiments/utils.py +14 -0
- phoenix/server/grpc_server.py +11 -11
- phoenix/server/jwt_store.py +17 -15
- phoenix/server/ldap.py +1449 -0
- phoenix/server/main.py +26 -10
- phoenix/server/oauth2.py +330 -12
- phoenix/server/prometheus.py +66 -6
- phoenix/server/rate_limiters.py +4 -9
- phoenix/server/retention.py +33 -20
- phoenix/server/session_filters.py +49 -0
- phoenix/server/static/.vite/manifest.json +55 -51
- phoenix/server/static/assets/components-BreFUQQa.js +6702 -0
- phoenix/server/static/assets/{index-E0M82BdE.js → index-CTQoemZv.js} +140 -56
- phoenix/server/static/assets/pages-DBE5iYM3.js +9524 -0
- phoenix/server/static/assets/vendor-BGzfc4EU.css +1 -0
- phoenix/server/static/assets/vendor-DCE4v-Ot.js +920 -0
- phoenix/server/static/assets/vendor-codemirror-D5f205eT.js +25 -0
- phoenix/server/static/assets/vendor-recharts-V9cwpXsm.js +37 -0
- phoenix/server/static/assets/vendor-shiki-Do--csgv.js +5 -0
- phoenix/server/static/assets/vendor-three-CmB8bl_y.js +3840 -0
- phoenix/server/templates/index.html +40 -6
- phoenix/server/thread_server.py +1 -2
- phoenix/server/types.py +14 -4
- phoenix/server/utils.py +74 -0
- phoenix/session/client.py +56 -3
- phoenix/session/data_extractor.py +5 -0
- phoenix/session/evaluation.py +14 -5
- phoenix/session/session.py +45 -9
- phoenix/settings.py +5 -0
- phoenix/trace/attributes.py +80 -13
- phoenix/trace/dsl/helpers.py +90 -1
- phoenix/trace/dsl/query.py +8 -6
- phoenix/trace/projects.py +5 -0
- phoenix/utilities/template_formatters.py +1 -1
- phoenix/version.py +1 -1
- arize_phoenix-10.0.4.dist-info/RECORD +0 -405
- phoenix/server/api/types/Evaluation.py +0 -39
- phoenix/server/cost_tracking/cost_lookup.py +0 -255
- phoenix/server/static/assets/components-DULKeDfL.js +0 -4365
- phoenix/server/static/assets/pages-Cl0A-0U2.js +0 -7430
- phoenix/server/static/assets/vendor-WIZid84E.css +0 -1
- phoenix/server/static/assets/vendor-arizeai-Dy-0mSNw.js +0 -649
- phoenix/server/static/assets/vendor-codemirror-DBtifKNr.js +0 -33
- phoenix/server/static/assets/vendor-oB4u9zuV.js +0 -905
- phoenix/server/static/assets/vendor-recharts-D-T4KPz2.js +0 -59
- phoenix/server/static/assets/vendor-shiki-BMn4O_9F.js +0 -5
- phoenix/server/static/assets/vendor-three-C5WAXd5r.js +0 -2998
- phoenix/utilities/deprecation.py +0 -31
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -10,79 +10,89 @@ from strawberry.relay import GlobalID
|
|
|
10
10
|
from strawberry.types import Info
|
|
11
11
|
|
|
12
12
|
from phoenix.db import models
|
|
13
|
-
from phoenix.
|
|
14
|
-
from phoenix.server.api.auth import IsLocked, IsNotReadOnly
|
|
13
|
+
from phoenix.server.api.auth import IsLocked, IsNotReadOnly, IsNotViewer
|
|
15
14
|
from phoenix.server.api.context import Context
|
|
16
15
|
from phoenix.server.api.exceptions import Conflict, NotFound
|
|
17
16
|
from phoenix.server.api.queries import Query
|
|
18
|
-
from phoenix.server.api.types.Identifier import Identifier
|
|
19
17
|
from phoenix.server.api.types.node import from_global_id_with_expected_type
|
|
20
18
|
from phoenix.server.api.types.Prompt import Prompt
|
|
21
|
-
from phoenix.server.api.types.PromptLabel import PromptLabel
|
|
19
|
+
from phoenix.server.api.types.PromptLabel import PromptLabel
|
|
22
20
|
|
|
23
21
|
|
|
24
22
|
@strawberry.input
|
|
25
23
|
class CreatePromptLabelInput:
|
|
26
|
-
name:
|
|
24
|
+
name: str
|
|
27
25
|
description: Optional[str] = None
|
|
26
|
+
color: str
|
|
28
27
|
|
|
29
28
|
|
|
30
29
|
@strawberry.input
|
|
31
30
|
class PatchPromptLabelInput:
|
|
32
31
|
prompt_label_id: GlobalID
|
|
33
|
-
name: Optional[
|
|
32
|
+
name: Optional[str] = None
|
|
34
33
|
description: Optional[str] = None
|
|
35
34
|
|
|
36
35
|
|
|
37
36
|
@strawberry.input
|
|
38
|
-
class
|
|
39
|
-
|
|
37
|
+
class DeletePromptLabelsInput:
|
|
38
|
+
prompt_label_ids: list[GlobalID]
|
|
40
39
|
|
|
41
40
|
|
|
42
41
|
@strawberry.input
|
|
43
|
-
class
|
|
42
|
+
class SetPromptLabelsInput:
|
|
44
43
|
prompt_id: GlobalID
|
|
45
|
-
|
|
44
|
+
prompt_label_ids: list[GlobalID]
|
|
46
45
|
|
|
47
46
|
|
|
48
47
|
@strawberry.input
|
|
49
|
-
class
|
|
48
|
+
class UnsetPromptLabelsInput:
|
|
50
49
|
prompt_id: GlobalID
|
|
51
|
-
|
|
50
|
+
prompt_label_ids: list[GlobalID]
|
|
52
51
|
|
|
53
52
|
|
|
54
53
|
@strawberry.type
|
|
55
54
|
class PromptLabelMutationPayload:
|
|
56
|
-
|
|
55
|
+
prompt_labels: list["PromptLabel"]
|
|
56
|
+
query: "Query"
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
@strawberry.type
|
|
60
|
+
class PromptLabelDeleteMutationPayload:
|
|
61
|
+
deleted_prompt_label_ids: list["GlobalID"]
|
|
62
|
+
query: "Query"
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
@strawberry.type
|
|
66
|
+
class PromptLabelAssociationMutationPayload:
|
|
57
67
|
query: "Query"
|
|
58
68
|
|
|
59
69
|
|
|
60
70
|
@strawberry.type
|
|
61
71
|
class PromptLabelMutationMixin:
|
|
62
|
-
@strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked]) # type: ignore
|
|
72
|
+
@strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked]) # type: ignore
|
|
63
73
|
async def create_prompt_label(
|
|
64
74
|
self, info: Info[Context, None], input: CreatePromptLabelInput
|
|
65
75
|
) -> PromptLabelMutationPayload:
|
|
66
76
|
async with info.context.db() as session:
|
|
67
|
-
|
|
68
|
-
|
|
77
|
+
label_orm = models.PromptLabel(
|
|
78
|
+
name=input.name, description=input.description, color=input.color
|
|
79
|
+
)
|
|
69
80
|
session.add(label_orm)
|
|
70
81
|
|
|
71
82
|
try:
|
|
72
83
|
await session.commit()
|
|
73
84
|
except (PostgreSQLIntegrityError, SQLiteIntegrityError):
|
|
74
|
-
raise Conflict(f"A prompt label named '{name}' already exists.")
|
|
85
|
+
raise Conflict(f"A prompt label named '{input.name}' already exists.")
|
|
75
86
|
|
|
76
87
|
return PromptLabelMutationPayload(
|
|
77
|
-
|
|
88
|
+
prompt_labels=[PromptLabel(id=label_orm.id, db_record=label_orm)],
|
|
78
89
|
query=Query(),
|
|
79
90
|
)
|
|
80
91
|
|
|
81
|
-
@strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked]) # type: ignore
|
|
92
|
+
@strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked]) # type: ignore
|
|
82
93
|
async def patch_prompt_label(
|
|
83
94
|
self, info: Info[Context, None], input: PatchPromptLabelInput
|
|
84
95
|
) -> PromptLabelMutationPayload:
|
|
85
|
-
validated_name = IdentifierModel.model_validate(str(input.name)) if input.name else None
|
|
86
96
|
async with info.context.db() as session:
|
|
87
97
|
label_id = from_global_id_with_expected_type(
|
|
88
98
|
input.prompt_label_id, PromptLabel.__name__
|
|
@@ -92,8 +102,8 @@ class PromptLabelMutationMixin:
|
|
|
92
102
|
if not label_orm:
|
|
93
103
|
raise NotFound(f"PromptLabel with ID {input.prompt_label_id} not found")
|
|
94
104
|
|
|
95
|
-
if
|
|
96
|
-
label_orm.name =
|
|
105
|
+
if input.name is not None:
|
|
106
|
+
label_orm.name = input.name
|
|
97
107
|
if input.description is not None:
|
|
98
108
|
label_orm.description = input.description
|
|
99
109
|
|
|
@@ -103,46 +113,48 @@ class PromptLabelMutationMixin:
|
|
|
103
113
|
raise Conflict("Error patching PromptLabel. Possibly a name conflict?")
|
|
104
114
|
|
|
105
115
|
return PromptLabelMutationPayload(
|
|
106
|
-
|
|
116
|
+
prompt_labels=[PromptLabel(id=label_orm.id, db_record=label_orm)],
|
|
107
117
|
query=Query(),
|
|
108
118
|
)
|
|
109
119
|
|
|
110
|
-
@strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore
|
|
111
|
-
async def
|
|
112
|
-
self, info: Info[Context, None], input:
|
|
113
|
-
) ->
|
|
120
|
+
@strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer]) # type: ignore
|
|
121
|
+
async def delete_prompt_labels(
|
|
122
|
+
self, info: Info[Context, None], input: DeletePromptLabelsInput
|
|
123
|
+
) -> PromptLabelDeleteMutationPayload:
|
|
114
124
|
"""
|
|
115
125
|
Deletes a PromptLabel (and any crosswalk references).
|
|
116
126
|
"""
|
|
117
127
|
async with info.context.db() as session:
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
if result.rowcount == 0:
|
|
125
|
-
raise NotFound(f"PromptLabel with ID {input.prompt_label_id} not found")
|
|
128
|
+
label_ids = [
|
|
129
|
+
from_global_id_with_expected_type(prompt_label_id, PromptLabel.__name__)
|
|
130
|
+
for prompt_label_id in input.prompt_label_ids
|
|
131
|
+
]
|
|
132
|
+
stmt = delete(models.PromptLabel).where(models.PromptLabel.id.in_(label_ids))
|
|
133
|
+
await session.execute(stmt)
|
|
126
134
|
|
|
127
135
|
await session.commit()
|
|
128
136
|
|
|
129
|
-
return
|
|
130
|
-
|
|
137
|
+
return PromptLabelDeleteMutationPayload(
|
|
138
|
+
deleted_prompt_label_ids=input.prompt_label_ids,
|
|
131
139
|
query=Query(),
|
|
132
140
|
)
|
|
133
141
|
|
|
134
|
-
@strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked]) # type: ignore
|
|
135
|
-
async def
|
|
136
|
-
self, info: Info[Context, None], input:
|
|
137
|
-
) ->
|
|
142
|
+
@strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked]) # type: ignore
|
|
143
|
+
async def set_prompt_labels(
|
|
144
|
+
self, info: Info[Context, None], input: SetPromptLabelsInput
|
|
145
|
+
) -> PromptLabelAssociationMutationPayload:
|
|
138
146
|
async with info.context.db() as session:
|
|
139
147
|
prompt_id = from_global_id_with_expected_type(input.prompt_id, Prompt.__name__)
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
148
|
+
label_ids = [
|
|
149
|
+
from_global_id_with_expected_type(prompt_label_id, PromptLabel.__name__)
|
|
150
|
+
for prompt_label_id in input.prompt_label_ids
|
|
151
|
+
]
|
|
143
152
|
|
|
144
|
-
|
|
145
|
-
|
|
153
|
+
crosswalk_items = [
|
|
154
|
+
models.PromptPromptLabel(prompt_id=prompt_id, prompt_label_id=label_id)
|
|
155
|
+
for label_id in label_ids
|
|
156
|
+
]
|
|
157
|
+
session.add_all(crosswalk_items)
|
|
146
158
|
|
|
147
159
|
try:
|
|
148
160
|
await session.commit()
|
|
@@ -152,41 +164,38 @@ class PromptLabelMutationMixin:
|
|
|
152
164
|
# - Foreign key violation => prompt_id or label_id doesn't exist
|
|
153
165
|
raise Conflict("Failed to associate PromptLabel with Prompt.") from e
|
|
154
166
|
|
|
155
|
-
|
|
156
|
-
if not label_orm:
|
|
157
|
-
raise NotFound(f"PromptLabel with ID {input.prompt_label_id} not found")
|
|
158
|
-
|
|
159
|
-
return PromptLabelMutationPayload(
|
|
160
|
-
prompt_label=to_gql_prompt_label(label_orm),
|
|
167
|
+
return PromptLabelAssociationMutationPayload(
|
|
161
168
|
query=Query(),
|
|
162
169
|
)
|
|
163
170
|
|
|
164
|
-
@strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore
|
|
165
|
-
async def
|
|
166
|
-
self, info: Info[Context, None], input:
|
|
167
|
-
) ->
|
|
171
|
+
@strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer]) # type: ignore
|
|
172
|
+
async def unset_prompt_labels(
|
|
173
|
+
self, info: Info[Context, None], input: UnsetPromptLabelsInput
|
|
174
|
+
) -> PromptLabelAssociationMutationPayload:
|
|
168
175
|
"""
|
|
169
176
|
Unsets a PromptLabel from a Prompt by removing the row in the crosswalk.
|
|
170
177
|
"""
|
|
171
178
|
async with info.context.db() as session:
|
|
172
179
|
prompt_id = from_global_id_with_expected_type(input.prompt_id, Prompt.__name__)
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
180
|
+
label_ids = [
|
|
181
|
+
from_global_id_with_expected_type(prompt_label_id, PromptLabel.__name__)
|
|
182
|
+
for prompt_label_id in input.prompt_label_ids
|
|
183
|
+
]
|
|
176
184
|
|
|
177
185
|
stmt = delete(models.PromptPromptLabel).where(
|
|
178
186
|
(models.PromptPromptLabel.prompt_id == prompt_id)
|
|
179
|
-
& (models.PromptPromptLabel.prompt_label_id
|
|
187
|
+
& (models.PromptPromptLabel.prompt_label_id.in_(label_ids))
|
|
180
188
|
)
|
|
181
189
|
result = await session.execute(stmt)
|
|
182
190
|
|
|
183
|
-
if result.rowcount
|
|
184
|
-
|
|
191
|
+
if result.rowcount != len(label_ids): # type: ignore[attr-defined]
|
|
192
|
+
label_ids_str = ", ".join(str(i) for i in label_ids)
|
|
193
|
+
raise NotFound(
|
|
194
|
+
f"No association between prompt={prompt_id} and labels={label_ids_str}."
|
|
195
|
+
)
|
|
185
196
|
|
|
186
197
|
await session.commit()
|
|
187
198
|
|
|
188
|
-
|
|
189
|
-
return PromptLabelMutationPayload(
|
|
190
|
-
prompt_label=to_gql_prompt_label(label_orm) if label_orm else None,
|
|
199
|
+
return PromptLabelAssociationMutationPayload(
|
|
191
200
|
query=Query(),
|
|
192
201
|
)
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
from typing import Any, Optional
|
|
1
|
+
from typing import Any, Optional
|
|
2
2
|
|
|
3
3
|
import strawberry
|
|
4
4
|
from fastapi import Request
|
|
@@ -7,23 +7,17 @@ from sqlalchemy import delete, select, update
|
|
|
7
7
|
from sqlalchemy.exc import IntegrityError as PostgreSQLIntegrityError
|
|
8
8
|
from sqlalchemy.orm import joinedload
|
|
9
9
|
from sqlean.dbapi2 import IntegrityError as SQLiteIntegrityError # type: ignore[import-untyped]
|
|
10
|
+
from strawberry import UNSET
|
|
10
11
|
from strawberry.relay.types import GlobalID
|
|
11
12
|
from strawberry.types import Info
|
|
12
13
|
|
|
13
14
|
from phoenix.db import models
|
|
14
15
|
from phoenix.db.types.identifier import Identifier as IdentifierModel
|
|
15
|
-
from phoenix.
|
|
16
|
-
from phoenix.server.api.auth import IsLocked, IsNotReadOnly
|
|
16
|
+
from phoenix.server.api.auth import IsLocked, IsNotReadOnly, IsNotViewer
|
|
17
17
|
from phoenix.server.api.context import Context
|
|
18
18
|
from phoenix.server.api.exceptions import BadRequest, Conflict, NotFound
|
|
19
|
-
from phoenix.server.api.helpers.prompts.models import (
|
|
20
|
-
normalize_response_format,
|
|
21
|
-
normalize_tools,
|
|
22
|
-
validate_invocation_parameters,
|
|
23
|
-
)
|
|
24
19
|
from phoenix.server.api.input_types.PromptVersionInput import (
|
|
25
20
|
ChatPromptVersionInput,
|
|
26
|
-
to_pydantic_prompt_chat_template_v1,
|
|
27
21
|
)
|
|
28
22
|
from phoenix.server.api.mutations.prompt_version_tag_mutations import (
|
|
29
23
|
SetPromptVersionTagInput,
|
|
@@ -32,7 +26,7 @@ from phoenix.server.api.mutations.prompt_version_tag_mutations import (
|
|
|
32
26
|
from phoenix.server.api.queries import Query
|
|
33
27
|
from phoenix.server.api.types.Identifier import Identifier
|
|
34
28
|
from phoenix.server.api.types.node import from_global_id_with_expected_type
|
|
35
|
-
from phoenix.server.api.types.Prompt import Prompt
|
|
29
|
+
from phoenix.server.api.types.Prompt import Prompt
|
|
36
30
|
from phoenix.server.bearer_auth import PhoenixUser
|
|
37
31
|
|
|
38
32
|
|
|
@@ -41,6 +35,7 @@ class CreateChatPromptInput:
|
|
|
41
35
|
name: Identifier
|
|
42
36
|
description: Optional[str] = None
|
|
43
37
|
prompt_version: ChatPromptVersionInput
|
|
38
|
+
metadata: Optional[strawberry.scalars.JSON] = None
|
|
44
39
|
|
|
45
40
|
|
|
46
41
|
@strawberry.input
|
|
@@ -58,14 +53,16 @@ class DeletePromptInput:
|
|
|
58
53
|
@strawberry.input
|
|
59
54
|
class ClonePromptInput:
|
|
60
55
|
name: Identifier
|
|
61
|
-
description: Optional[str] = None
|
|
62
56
|
prompt_id: GlobalID
|
|
57
|
+
description: Optional[str] = UNSET
|
|
58
|
+
metadata: Optional[strawberry.scalars.JSON] = UNSET
|
|
63
59
|
|
|
64
60
|
|
|
65
61
|
@strawberry.input
|
|
66
62
|
class PatchPromptInput:
|
|
67
63
|
prompt_id: GlobalID
|
|
68
|
-
description: str
|
|
64
|
+
description: Optional[str] = UNSET
|
|
65
|
+
metadata: Optional[strawberry.scalars.JSON] = UNSET
|
|
69
66
|
|
|
70
67
|
|
|
71
68
|
@strawberry.type
|
|
@@ -75,7 +72,7 @@ class DeletePromptMutationPayload:
|
|
|
75
72
|
|
|
76
73
|
@strawberry.type
|
|
77
74
|
class PromptMutationMixin:
|
|
78
|
-
@strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked]) # type: ignore
|
|
75
|
+
@strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked]) # type: ignore
|
|
79
76
|
async def create_chat_prompt(
|
|
80
77
|
self, info: Info[Context, None], input: CreateChatPromptInput
|
|
81
78
|
) -> Prompt:
|
|
@@ -84,65 +81,26 @@ class PromptMutationMixin:
|
|
|
84
81
|
if "user" in request.scope:
|
|
85
82
|
assert isinstance(user := request.user, PhoenixUser)
|
|
86
83
|
user_id = int(user.identity)
|
|
87
|
-
|
|
88
|
-
input_prompt_version = input.prompt_version
|
|
89
|
-
tool_definitions = [tool.definition for tool in input_prompt_version.tools]
|
|
90
|
-
tool_choice = cast(
|
|
91
|
-
Optional[Union[str, dict[str, Any]]],
|
|
92
|
-
cast(dict[str, Any], input.prompt_version.invocation_parameters).pop(
|
|
93
|
-
"tool_choice", None
|
|
94
|
-
),
|
|
95
|
-
)
|
|
96
|
-
model_provider = ModelProvider(input_prompt_version.model_provider)
|
|
97
84
|
try:
|
|
98
|
-
|
|
99
|
-
normalize_tools(tool_definitions, model_provider, tool_choice)
|
|
100
|
-
if tool_definitions
|
|
101
|
-
else None
|
|
102
|
-
)
|
|
103
|
-
template = to_pydantic_prompt_chat_template_v1(input_prompt_version.template)
|
|
104
|
-
response_format = (
|
|
105
|
-
normalize_response_format(
|
|
106
|
-
input_prompt_version.response_format.definition,
|
|
107
|
-
model_provider,
|
|
108
|
-
)
|
|
109
|
-
if input_prompt_version.response_format
|
|
110
|
-
else None
|
|
111
|
-
)
|
|
112
|
-
invocation_parameters = validate_invocation_parameters(
|
|
113
|
-
input_prompt_version.invocation_parameters,
|
|
114
|
-
model_provider,
|
|
115
|
-
)
|
|
85
|
+
prompt_version = input.prompt_version.to_orm_prompt_version(user_id)
|
|
116
86
|
except ValidationError as error:
|
|
117
87
|
raise BadRequest(str(error))
|
|
118
|
-
|
|
88
|
+
name = IdentifierModel.model_validate(str(input.name))
|
|
89
|
+
prompt = models.Prompt(
|
|
90
|
+
name=name,
|
|
91
|
+
description=input.description,
|
|
92
|
+
metadata_=input.metadata or {},
|
|
93
|
+
prompt_versions=[prompt_version],
|
|
94
|
+
)
|
|
119
95
|
async with info.context.db() as session:
|
|
120
|
-
prompt_version = models.PromptVersion(
|
|
121
|
-
description=input_prompt_version.description,
|
|
122
|
-
user_id=user_id,
|
|
123
|
-
template_type="CHAT",
|
|
124
|
-
template_format=input_prompt_version.template_format,
|
|
125
|
-
template=template,
|
|
126
|
-
invocation_parameters=invocation_parameters,
|
|
127
|
-
tools=tools,
|
|
128
|
-
response_format=response_format,
|
|
129
|
-
model_provider=input_prompt_version.model_provider,
|
|
130
|
-
model_name=input_prompt_version.model_name,
|
|
131
|
-
)
|
|
132
|
-
name = IdentifierModel.model_validate(str(input.name))
|
|
133
|
-
prompt = models.Prompt(
|
|
134
|
-
name=name,
|
|
135
|
-
description=input.description,
|
|
136
|
-
prompt_versions=[prompt_version],
|
|
137
|
-
)
|
|
138
96
|
session.add(prompt)
|
|
139
97
|
try:
|
|
140
98
|
await session.commit()
|
|
141
99
|
except (PostgreSQLIntegrityError, SQLiteIntegrityError):
|
|
142
100
|
raise Conflict(f"A prompt named '{input.name}' already exists")
|
|
143
|
-
return
|
|
101
|
+
return Prompt(id=prompt.id, db_record=prompt)
|
|
144
102
|
|
|
145
|
-
@strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked]) # type: ignore
|
|
103
|
+
@strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked]) # type: ignore
|
|
146
104
|
async def create_chat_prompt_version(
|
|
147
105
|
self,
|
|
148
106
|
info: Info[Context, None],
|
|
@@ -153,74 +111,28 @@ class PromptMutationMixin:
|
|
|
153
111
|
if "user" in request.scope:
|
|
154
112
|
assert isinstance(user := request.user, PhoenixUser)
|
|
155
113
|
user_id = int(user.identity)
|
|
156
|
-
|
|
157
|
-
input_prompt_version = input.prompt_version
|
|
158
|
-
tool_definitions = [tool.definition for tool in input.prompt_version.tools]
|
|
159
|
-
tool_choice = cast(
|
|
160
|
-
Optional[Union[str, dict[str, Any]]],
|
|
161
|
-
cast(dict[str, Any], input.prompt_version.invocation_parameters).pop(
|
|
162
|
-
"tool_choice", None
|
|
163
|
-
),
|
|
164
|
-
)
|
|
165
|
-
model_provider = ModelProvider(input_prompt_version.model_provider)
|
|
166
114
|
try:
|
|
167
|
-
|
|
168
|
-
normalize_tools(tool_definitions, model_provider, tool_choice)
|
|
169
|
-
if tool_definitions
|
|
170
|
-
else None
|
|
171
|
-
)
|
|
172
|
-
template = to_pydantic_prompt_chat_template_v1(input_prompt_version.template)
|
|
173
|
-
response_format = (
|
|
174
|
-
normalize_response_format(
|
|
175
|
-
input_prompt_version.response_format.definition,
|
|
176
|
-
model_provider,
|
|
177
|
-
)
|
|
178
|
-
if input_prompt_version.response_format
|
|
179
|
-
else None
|
|
180
|
-
)
|
|
181
|
-
invocation_parameters = validate_invocation_parameters(
|
|
182
|
-
input_prompt_version.invocation_parameters,
|
|
183
|
-
model_provider,
|
|
184
|
-
)
|
|
115
|
+
prompt_version = input.prompt_version.to_orm_prompt_version(user_id)
|
|
185
116
|
except ValidationError as error:
|
|
186
117
|
raise BadRequest(str(error))
|
|
187
|
-
|
|
188
118
|
prompt_id = from_global_id_with_expected_type(
|
|
189
119
|
global_id=input.prompt_id, expected_type_name=Prompt.__name__
|
|
190
120
|
)
|
|
121
|
+
prompt_version.prompt_id = prompt_id
|
|
191
122
|
async with info.context.db() as session:
|
|
192
|
-
prompt = await session.get(models.Prompt, prompt_id)
|
|
193
|
-
if not prompt:
|
|
194
|
-
raise NotFound(f"Prompt with ID '{input.prompt_id}' not found")
|
|
195
|
-
|
|
196
|
-
prompt_version = models.PromptVersion(
|
|
197
|
-
prompt_id=prompt_id,
|
|
198
|
-
description=input.prompt_version.description,
|
|
199
|
-
user_id=user_id,
|
|
200
|
-
template_type="CHAT",
|
|
201
|
-
template_format=input.prompt_version.template_format,
|
|
202
|
-
template=template,
|
|
203
|
-
invocation_parameters=invocation_parameters,
|
|
204
|
-
tools=tools,
|
|
205
|
-
response_format=response_format,
|
|
206
|
-
model_provider=input.prompt_version.model_provider,
|
|
207
|
-
model_name=input.prompt_version.model_name,
|
|
208
|
-
)
|
|
209
123
|
session.add(prompt_version)
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
@strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore
|
|
124
|
+
try:
|
|
125
|
+
await session.flush()
|
|
126
|
+
except (PostgreSQLIntegrityError, SQLiteIntegrityError):
|
|
127
|
+
raise NotFound(f"Prompt with ID '{input.prompt_id}' not found")
|
|
128
|
+
if input.tags:
|
|
129
|
+
for tag in input.tags:
|
|
130
|
+
await upsert_prompt_version_tag(
|
|
131
|
+
session, prompt_id, prompt_version.id, tag.name, tag.description
|
|
132
|
+
)
|
|
133
|
+
return Prompt(id=prompt_id)
|
|
134
|
+
|
|
135
|
+
@strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer]) # type: ignore
|
|
224
136
|
async def delete_prompt(
|
|
225
137
|
self, info: Info[Context, None], input: DeletePromptInput
|
|
226
138
|
) -> DeletePromptMutationPayload:
|
|
@@ -231,13 +143,13 @@ class PromptMutationMixin:
|
|
|
231
143
|
stmt = delete(models.Prompt).where(models.Prompt.id == prompt_id)
|
|
232
144
|
result = await session.execute(stmt)
|
|
233
145
|
|
|
234
|
-
if result.rowcount == 0:
|
|
146
|
+
if result.rowcount == 0: # type: ignore[attr-defined]
|
|
235
147
|
raise NotFound(f"Prompt with ID '{input.prompt_id}' not found")
|
|
236
148
|
|
|
237
149
|
await session.commit()
|
|
238
150
|
return DeletePromptMutationPayload(query=Query())
|
|
239
151
|
|
|
240
|
-
@strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked]) # type: ignore
|
|
152
|
+
@strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked]) # type: ignore
|
|
241
153
|
async def clone_prompt(self, info: Info[Context, None], input: ClonePromptInput) -> Prompt:
|
|
242
154
|
prompt_id = from_global_id_with_expected_type(
|
|
243
155
|
global_id=input.prompt_id, expected_type_name=Prompt.__name__
|
|
@@ -256,10 +168,23 @@ class PromptMutationMixin:
|
|
|
256
168
|
|
|
257
169
|
# Create new prompt
|
|
258
170
|
name = IdentifierModel.model_validate(str(input.name))
|
|
171
|
+
# Handle description: inherit if UNSET, otherwise use value (can be None)
|
|
172
|
+
if input.description is UNSET:
|
|
173
|
+
description = prompt.description
|
|
174
|
+
else:
|
|
175
|
+
description = input.description.strip() if input.description is not None else None
|
|
176
|
+
|
|
177
|
+
# Handle metadata: inherit if UNSET, clear to empty dict if None, or use value
|
|
178
|
+
if input.metadata is UNSET:
|
|
179
|
+
metadata = prompt.metadata_
|
|
180
|
+
else:
|
|
181
|
+
metadata = input.metadata or {}
|
|
182
|
+
|
|
259
183
|
new_prompt = models.Prompt(
|
|
260
184
|
name=name,
|
|
261
|
-
description=input.description,
|
|
262
185
|
source_prompt_id=prompt_id,
|
|
186
|
+
description=description,
|
|
187
|
+
metadata_=metadata,
|
|
263
188
|
)
|
|
264
189
|
|
|
265
190
|
# Create copies of all versions
|
|
@@ -288,19 +213,30 @@ class PromptMutationMixin:
|
|
|
288
213
|
await session.commit()
|
|
289
214
|
except (PostgreSQLIntegrityError, SQLiteIntegrityError):
|
|
290
215
|
raise Conflict(f"A prompt named '{input.name}' already exists")
|
|
291
|
-
return
|
|
216
|
+
return Prompt(id=new_prompt.id, db_record=new_prompt)
|
|
292
217
|
|
|
293
|
-
@strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked]) # type: ignore
|
|
218
|
+
@strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked]) # type: ignore
|
|
294
219
|
async def patch_prompt(self, info: Info[Context, None], input: PatchPromptInput) -> Prompt:
|
|
295
220
|
prompt_id = from_global_id_with_expected_type(
|
|
296
221
|
global_id=input.prompt_id, expected_type_name=Prompt.__name__
|
|
297
222
|
)
|
|
298
223
|
|
|
224
|
+
values: dict[str, Any] = {}
|
|
225
|
+
if input.description is not UNSET:
|
|
226
|
+
values["description"] = (
|
|
227
|
+
input.description.strip() if input.description is not None else None
|
|
228
|
+
)
|
|
229
|
+
if input.metadata is not UNSET:
|
|
230
|
+
values["metadata_"] = input.metadata or {}
|
|
231
|
+
|
|
232
|
+
if not values:
|
|
233
|
+
raise BadRequest("No fields provided to update")
|
|
234
|
+
|
|
299
235
|
async with info.context.db() as session:
|
|
300
236
|
stmt = (
|
|
301
237
|
update(models.Prompt)
|
|
302
238
|
.where(models.Prompt.id == prompt_id)
|
|
303
|
-
.values(
|
|
239
|
+
.values(**values)
|
|
304
240
|
.returning(models.Prompt)
|
|
305
241
|
)
|
|
306
242
|
|
|
@@ -310,4 +246,4 @@ class PromptMutationMixin:
|
|
|
310
246
|
if prompt is None:
|
|
311
247
|
raise NotFound(f"Prompt with ID '{input.prompt_id}' not found")
|
|
312
248
|
|
|
313
|
-
return
|
|
249
|
+
return Prompt(id=prompt.id, db_record=prompt)
|
|
@@ -10,15 +10,15 @@ from strawberry.types import Info
|
|
|
10
10
|
|
|
11
11
|
from phoenix.db import models
|
|
12
12
|
from phoenix.db.types.identifier import Identifier as IdentifierModel
|
|
13
|
-
from phoenix.server.api.auth import IsLocked, IsNotReadOnly
|
|
13
|
+
from phoenix.server.api.auth import IsLocked, IsNotReadOnly, IsNotViewer
|
|
14
14
|
from phoenix.server.api.context import Context
|
|
15
15
|
from phoenix.server.api.exceptions import BadRequest, Conflict, NotFound
|
|
16
16
|
from phoenix.server.api.queries import Query
|
|
17
17
|
from phoenix.server.api.types.Identifier import Identifier
|
|
18
18
|
from phoenix.server.api.types.node import from_global_id_with_expected_type
|
|
19
|
-
from phoenix.server.api.types.Prompt import Prompt
|
|
19
|
+
from phoenix.server.api.types.Prompt import Prompt
|
|
20
20
|
from phoenix.server.api.types.PromptVersion import PromptVersion
|
|
21
|
-
from phoenix.server.api.types.PromptVersionTag import PromptVersionTag
|
|
21
|
+
from phoenix.server.api.types.PromptVersionTag import PromptVersionTag
|
|
22
22
|
|
|
23
23
|
|
|
24
24
|
@strawberry.input
|
|
@@ -42,7 +42,7 @@ class PromptVersionTagMutationPayload:
|
|
|
42
42
|
|
|
43
43
|
@strawberry.type
|
|
44
44
|
class PromptVersionTagMutationMixin:
|
|
45
|
-
@strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore
|
|
45
|
+
@strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer]) # type: ignore
|
|
46
46
|
async def delete_prompt_version_tag(
|
|
47
47
|
self, info: Info[Context, None], input: DeletePromptVersionTagInput
|
|
48
48
|
) -> PromptVersionTagMutationPayload:
|
|
@@ -75,10 +75,12 @@ class PromptVersionTagMutationMixin:
|
|
|
75
75
|
await session.delete(prompt_version_tag)
|
|
76
76
|
await session.commit()
|
|
77
77
|
return PromptVersionTagMutationPayload(
|
|
78
|
-
prompt_version_tag=None,
|
|
78
|
+
prompt_version_tag=None,
|
|
79
|
+
query=Query(),
|
|
80
|
+
prompt=Prompt(id=prompt.id, db_record=prompt),
|
|
79
81
|
)
|
|
80
82
|
|
|
81
|
-
@strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked]) # type: ignore
|
|
83
|
+
@strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked]) # type: ignore
|
|
82
84
|
async def set_prompt_version_tag(
|
|
83
85
|
self, info: Info[Context, None], input: SetPromptVersionTagInput
|
|
84
86
|
) -> PromptVersionTagMutationPayload:
|
|
@@ -111,9 +113,10 @@ class PromptVersionTagMutationMixin:
|
|
|
111
113
|
except (PostgreSQLIntegrityError, SQLiteIntegrityError):
|
|
112
114
|
raise Conflict("Failed to update PromptVersionTag.")
|
|
113
115
|
|
|
114
|
-
version_tag = to_gql_prompt_version_tag(updated_tag)
|
|
115
116
|
return PromptVersionTagMutationPayload(
|
|
116
|
-
prompt_version_tag=
|
|
117
|
+
prompt_version_tag=PromptVersionTag(id=updated_tag.id, db_record=updated_tag),
|
|
118
|
+
prompt=Prompt(id=prompt.id, db_record=prompt),
|
|
119
|
+
query=Query(),
|
|
117
120
|
)
|
|
118
121
|
|
|
119
122
|
|