arize-phoenix 10.0.4__py3-none-any.whl → 12.28.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/METADATA +124 -72
- arize_phoenix-12.28.1.dist-info/RECORD +499 -0
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/WHEEL +1 -1
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/IP_NOTICE +1 -1
- phoenix/__generated__/__init__.py +0 -0
- phoenix/__generated__/classification_evaluator_configs/__init__.py +20 -0
- phoenix/__generated__/classification_evaluator_configs/_document_relevance_classification_evaluator_config.py +17 -0
- phoenix/__generated__/classification_evaluator_configs/_hallucination_classification_evaluator_config.py +17 -0
- phoenix/__generated__/classification_evaluator_configs/_models.py +18 -0
- phoenix/__generated__/classification_evaluator_configs/_tool_selection_classification_evaluator_config.py +17 -0
- phoenix/__init__.py +5 -4
- phoenix/auth.py +39 -2
- phoenix/config.py +1763 -91
- phoenix/datetime_utils.py +120 -2
- phoenix/db/README.md +595 -25
- phoenix/db/bulk_inserter.py +145 -103
- phoenix/db/engines.py +140 -33
- phoenix/db/enums.py +3 -12
- phoenix/db/facilitator.py +302 -35
- phoenix/db/helpers.py +1000 -65
- phoenix/db/iam_auth.py +64 -0
- phoenix/db/insertion/dataset.py +135 -2
- phoenix/db/insertion/document_annotation.py +9 -6
- phoenix/db/insertion/evaluation.py +2 -3
- phoenix/db/insertion/helpers.py +17 -2
- phoenix/db/insertion/session_annotation.py +176 -0
- phoenix/db/insertion/span.py +15 -11
- phoenix/db/insertion/span_annotation.py +3 -4
- phoenix/db/insertion/trace_annotation.py +3 -4
- phoenix/db/insertion/types.py +50 -20
- phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
- phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
- phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
- phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
- phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
- phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
- phoenix/db/migrations/versions/a20694b15f82_cost.py +196 -0
- phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
- phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
- phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
- phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
- phoenix/db/models.py +669 -56
- phoenix/db/pg_config.py +10 -0
- phoenix/db/types/model_provider.py +4 -0
- phoenix/db/types/token_price_customization.py +29 -0
- phoenix/db/types/trace_retention.py +23 -15
- phoenix/experiments/evaluators/utils.py +3 -3
- phoenix/experiments/functions.py +160 -52
- phoenix/experiments/tracing.py +2 -2
- phoenix/experiments/types.py +1 -1
- phoenix/inferences/inferences.py +1 -2
- phoenix/server/api/auth.py +38 -7
- phoenix/server/api/auth_messages.py +46 -0
- phoenix/server/api/context.py +100 -4
- phoenix/server/api/dataloaders/__init__.py +79 -5
- phoenix/server/api/dataloaders/annotation_configs_by_project.py +31 -0
- phoenix/server/api/dataloaders/annotation_summaries.py +60 -8
- phoenix/server/api/dataloaders/average_experiment_repeated_run_group_latency.py +50 -0
- phoenix/server/api/dataloaders/average_experiment_run_latency.py +17 -24
- phoenix/server/api/dataloaders/cache/two_tier_cache.py +1 -2
- phoenix/server/api/dataloaders/dataset_dataset_splits.py +52 -0
- phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -1
- phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
- phoenix/server/api/dataloaders/dataset_examples_and_versions_by_experiment_run.py +47 -0
- phoenix/server/api/dataloaders/dataset_labels.py +36 -0
- phoenix/server/api/dataloaders/document_evaluation_summaries.py +2 -2
- phoenix/server/api/dataloaders/document_evaluations.py +6 -9
- phoenix/server/api/dataloaders/experiment_annotation_summaries.py +88 -34
- phoenix/server/api/dataloaders/experiment_dataset_splits.py +43 -0
- phoenix/server/api/dataloaders/experiment_error_rates.py +21 -28
- phoenix/server/api/dataloaders/experiment_repeated_run_group_annotation_summaries.py +77 -0
- phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +57 -0
- phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
- phoenix/server/api/dataloaders/last_used_times_by_generative_model_id.py +35 -0
- phoenix/server/api/dataloaders/latency_ms_quantile.py +40 -8
- phoenix/server/api/dataloaders/record_counts.py +37 -10
- phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
- phoenix/server/api/dataloaders/span_cost_by_span.py +24 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_generative_model.py +56 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_project_session.py +57 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_span.py +43 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_trace.py +56 -0
- phoenix/server/api/dataloaders/span_cost_details_by_span_cost.py +27 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment.py +57 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment_repeated_run_group.py +64 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment_run.py +58 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_generative_model.py +55 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_project.py +152 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_project_session.py +56 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_trace.py +55 -0
- phoenix/server/api/dataloaders/span_costs.py +29 -0
- phoenix/server/api/dataloaders/table_fields.py +2 -2
- phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
- phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
- phoenix/server/api/dataloaders/types.py +29 -0
- phoenix/server/api/exceptions.py +11 -1
- phoenix/server/api/helpers/dataset_helpers.py +5 -1
- phoenix/server/api/helpers/playground_clients.py +1243 -292
- phoenix/server/api/helpers/playground_registry.py +2 -2
- phoenix/server/api/helpers/playground_spans.py +8 -4
- phoenix/server/api/helpers/playground_users.py +26 -0
- phoenix/server/api/helpers/prompts/conversions/aws.py +83 -0
- phoenix/server/api/helpers/prompts/conversions/google.py +103 -0
- phoenix/server/api/helpers/prompts/models.py +205 -22
- phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
- phoenix/server/api/input_types/ChatCompletionInput.py +6 -2
- phoenix/server/api/input_types/CreateProjectInput.py +27 -0
- phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
- phoenix/server/api/input_types/DatasetFilter.py +17 -0
- phoenix/server/api/input_types/ExperimentRunSort.py +237 -0
- phoenix/server/api/input_types/GenerativeCredentialInput.py +9 -0
- phoenix/server/api/input_types/GenerativeModelInput.py +5 -0
- phoenix/server/api/input_types/ProjectSessionSort.py +161 -1
- phoenix/server/api/input_types/PromptFilter.py +14 -0
- phoenix/server/api/input_types/PromptVersionInput.py +52 -1
- phoenix/server/api/input_types/SpanSort.py +44 -7
- phoenix/server/api/input_types/TimeBinConfig.py +23 -0
- phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
- phoenix/server/api/input_types/UserRoleInput.py +1 -0
- phoenix/server/api/mutations/__init__.py +10 -0
- phoenix/server/api/mutations/annotation_config_mutations.py +8 -8
- phoenix/server/api/mutations/api_key_mutations.py +19 -23
- phoenix/server/api/mutations/chat_mutations.py +154 -47
- phoenix/server/api/mutations/dataset_label_mutations.py +243 -0
- phoenix/server/api/mutations/dataset_mutations.py +21 -16
- phoenix/server/api/mutations/dataset_split_mutations.py +351 -0
- phoenix/server/api/mutations/experiment_mutations.py +2 -2
- phoenix/server/api/mutations/export_events_mutations.py +3 -3
- phoenix/server/api/mutations/model_mutations.py +210 -0
- phoenix/server/api/mutations/project_mutations.py +49 -10
- phoenix/server/api/mutations/project_session_annotations_mutations.py +158 -0
- phoenix/server/api/mutations/project_trace_retention_policy_mutations.py +8 -4
- phoenix/server/api/mutations/prompt_label_mutations.py +74 -65
- phoenix/server/api/mutations/prompt_mutations.py +65 -129
- phoenix/server/api/mutations/prompt_version_tag_mutations.py +11 -8
- phoenix/server/api/mutations/span_annotations_mutations.py +15 -10
- phoenix/server/api/mutations/trace_annotations_mutations.py +14 -10
- phoenix/server/api/mutations/trace_mutations.py +47 -3
- phoenix/server/api/mutations/user_mutations.py +66 -41
- phoenix/server/api/queries.py +768 -293
- phoenix/server/api/routers/__init__.py +2 -2
- phoenix/server/api/routers/auth.py +154 -88
- phoenix/server/api/routers/ldap.py +229 -0
- phoenix/server/api/routers/oauth2.py +369 -106
- phoenix/server/api/routers/v1/__init__.py +24 -4
- phoenix/server/api/routers/v1/annotation_configs.py +23 -31
- phoenix/server/api/routers/v1/annotations.py +481 -17
- phoenix/server/api/routers/v1/datasets.py +395 -81
- phoenix/server/api/routers/v1/documents.py +142 -0
- phoenix/server/api/routers/v1/evaluations.py +24 -31
- phoenix/server/api/routers/v1/experiment_evaluations.py +19 -8
- phoenix/server/api/routers/v1/experiment_runs.py +337 -59
- phoenix/server/api/routers/v1/experiments.py +479 -48
- phoenix/server/api/routers/v1/models.py +7 -0
- phoenix/server/api/routers/v1/projects.py +18 -49
- phoenix/server/api/routers/v1/prompts.py +54 -40
- phoenix/server/api/routers/v1/sessions.py +108 -0
- phoenix/server/api/routers/v1/spans.py +1091 -81
- phoenix/server/api/routers/v1/traces.py +132 -78
- phoenix/server/api/routers/v1/users.py +389 -0
- phoenix/server/api/routers/v1/utils.py +3 -7
- phoenix/server/api/subscriptions.py +305 -88
- phoenix/server/api/types/Annotation.py +90 -23
- phoenix/server/api/types/ApiKey.py +13 -17
- phoenix/server/api/types/AuthMethod.py +1 -0
- phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +1 -0
- phoenix/server/api/types/CostBreakdown.py +12 -0
- phoenix/server/api/types/Dataset.py +226 -72
- phoenix/server/api/types/DatasetExample.py +88 -18
- phoenix/server/api/types/DatasetExperimentAnnotationSummary.py +10 -0
- phoenix/server/api/types/DatasetLabel.py +57 -0
- phoenix/server/api/types/DatasetSplit.py +98 -0
- phoenix/server/api/types/DatasetVersion.py +49 -4
- phoenix/server/api/types/DocumentAnnotation.py +212 -0
- phoenix/server/api/types/Experiment.py +264 -59
- phoenix/server/api/types/ExperimentComparison.py +5 -10
- phoenix/server/api/types/ExperimentRepeatedRunGroup.py +155 -0
- phoenix/server/api/types/ExperimentRepeatedRunGroupAnnotationSummary.py +9 -0
- phoenix/server/api/types/ExperimentRun.py +169 -65
- phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
- phoenix/server/api/types/GenerativeModel.py +245 -3
- phoenix/server/api/types/GenerativeProvider.py +70 -11
- phoenix/server/api/types/{Model.py → InferenceModel.py} +1 -1
- phoenix/server/api/types/ModelInterface.py +16 -0
- phoenix/server/api/types/PlaygroundModel.py +20 -0
- phoenix/server/api/types/Project.py +1278 -216
- phoenix/server/api/types/ProjectSession.py +188 -28
- phoenix/server/api/types/ProjectSessionAnnotation.py +187 -0
- phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
- phoenix/server/api/types/Prompt.py +119 -39
- phoenix/server/api/types/PromptLabel.py +42 -25
- phoenix/server/api/types/PromptVersion.py +11 -8
- phoenix/server/api/types/PromptVersionTag.py +65 -25
- phoenix/server/api/types/ServerStatus.py +6 -0
- phoenix/server/api/types/Span.py +167 -123
- phoenix/server/api/types/SpanAnnotation.py +189 -42
- phoenix/server/api/types/SpanCostDetailSummaryEntry.py +10 -0
- phoenix/server/api/types/SpanCostSummary.py +10 -0
- phoenix/server/api/types/SystemApiKey.py +65 -1
- phoenix/server/api/types/TokenPrice.py +16 -0
- phoenix/server/api/types/TokenUsage.py +3 -3
- phoenix/server/api/types/Trace.py +223 -51
- phoenix/server/api/types/TraceAnnotation.py +149 -50
- phoenix/server/api/types/User.py +137 -32
- phoenix/server/api/types/UserApiKey.py +73 -26
- phoenix/server/api/types/node.py +10 -0
- phoenix/server/api/types/pagination.py +11 -2
- phoenix/server/app.py +290 -45
- phoenix/server/authorization.py +38 -3
- phoenix/server/bearer_auth.py +34 -24
- phoenix/server/cost_tracking/cost_details_calculator.py +196 -0
- phoenix/server/cost_tracking/cost_model_lookup.py +179 -0
- phoenix/server/cost_tracking/helpers.py +68 -0
- phoenix/server/cost_tracking/model_cost_manifest.json +3657 -830
- phoenix/server/cost_tracking/regex_specificity.py +397 -0
- phoenix/server/cost_tracking/token_cost_calculator.py +57 -0
- phoenix/server/daemons/__init__.py +0 -0
- phoenix/server/daemons/db_disk_usage_monitor.py +214 -0
- phoenix/server/daemons/generative_model_store.py +103 -0
- phoenix/server/daemons/span_cost_calculator.py +99 -0
- phoenix/server/dml_event.py +17 -0
- phoenix/server/dml_event_handler.py +5 -0
- phoenix/server/email/sender.py +56 -3
- phoenix/server/email/templates/db_disk_usage_notification.html +19 -0
- phoenix/server/email/types.py +11 -0
- phoenix/server/experiments/__init__.py +0 -0
- phoenix/server/experiments/utils.py +14 -0
- phoenix/server/grpc_server.py +11 -11
- phoenix/server/jwt_store.py +17 -15
- phoenix/server/ldap.py +1449 -0
- phoenix/server/main.py +26 -10
- phoenix/server/oauth2.py +330 -12
- phoenix/server/prometheus.py +66 -6
- phoenix/server/rate_limiters.py +4 -9
- phoenix/server/retention.py +33 -20
- phoenix/server/session_filters.py +49 -0
- phoenix/server/static/.vite/manifest.json +55 -51
- phoenix/server/static/assets/components-BreFUQQa.js +6702 -0
- phoenix/server/static/assets/{index-E0M82BdE.js → index-CTQoemZv.js} +140 -56
- phoenix/server/static/assets/pages-DBE5iYM3.js +9524 -0
- phoenix/server/static/assets/vendor-BGzfc4EU.css +1 -0
- phoenix/server/static/assets/vendor-DCE4v-Ot.js +920 -0
- phoenix/server/static/assets/vendor-codemirror-D5f205eT.js +25 -0
- phoenix/server/static/assets/vendor-recharts-V9cwpXsm.js +37 -0
- phoenix/server/static/assets/vendor-shiki-Do--csgv.js +5 -0
- phoenix/server/static/assets/vendor-three-CmB8bl_y.js +3840 -0
- phoenix/server/templates/index.html +40 -6
- phoenix/server/thread_server.py +1 -2
- phoenix/server/types.py +14 -4
- phoenix/server/utils.py +74 -0
- phoenix/session/client.py +56 -3
- phoenix/session/data_extractor.py +5 -0
- phoenix/session/evaluation.py +14 -5
- phoenix/session/session.py +45 -9
- phoenix/settings.py +5 -0
- phoenix/trace/attributes.py +80 -13
- phoenix/trace/dsl/helpers.py +90 -1
- phoenix/trace/dsl/query.py +8 -6
- phoenix/trace/projects.py +5 -0
- phoenix/utilities/template_formatters.py +1 -1
- phoenix/version.py +1 -1
- arize_phoenix-10.0.4.dist-info/RECORD +0 -405
- phoenix/server/api/types/Evaluation.py +0 -39
- phoenix/server/cost_tracking/cost_lookup.py +0 -255
- phoenix/server/static/assets/components-DULKeDfL.js +0 -4365
- phoenix/server/static/assets/pages-Cl0A-0U2.js +0 -7430
- phoenix/server/static/assets/vendor-WIZid84E.css +0 -1
- phoenix/server/static/assets/vendor-arizeai-Dy-0mSNw.js +0 -649
- phoenix/server/static/assets/vendor-codemirror-DBtifKNr.js +0 -33
- phoenix/server/static/assets/vendor-oB4u9zuV.js +0 -905
- phoenix/server/static/assets/vendor-recharts-D-T4KPz2.js +0 -59
- phoenix/server/static/assets/vendor-shiki-BMn4O_9F.js +0 -5
- phoenix/server/static/assets/vendor-three-C5WAXd5r.js +0 -2998
- phoenix/utilities/deprecation.py +0 -31
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,9 +1,10 @@
|
|
|
1
1
|
import asyncio
|
|
2
|
+
import logging
|
|
2
3
|
from dataclasses import asdict, field
|
|
3
4
|
from datetime import datetime, timezone
|
|
4
5
|
from itertools import chain, islice
|
|
5
6
|
from traceback import format_exc
|
|
6
|
-
from typing import Any, Iterable, Iterator,
|
|
7
|
+
from typing import Any, Iterable, Iterator, Optional, TypeVar, Union
|
|
7
8
|
|
|
8
9
|
import strawberry
|
|
9
10
|
from openinference.instrumentation import safe_json_dumps
|
|
@@ -22,14 +23,19 @@ from strawberry.relay import GlobalID
|
|
|
22
23
|
from strawberry.types import Info
|
|
23
24
|
from typing_extensions import assert_never
|
|
24
25
|
|
|
26
|
+
from phoenix.config import PLAYGROUND_PROJECT_NAME
|
|
25
27
|
from phoenix.datetime_utils import local_now, normalize_datetime
|
|
26
28
|
from phoenix.db import models
|
|
27
|
-
from phoenix.db.helpers import
|
|
28
|
-
|
|
29
|
+
from phoenix.db.helpers import (
|
|
30
|
+
get_dataset_example_revisions,
|
|
31
|
+
insert_experiment_with_examples_snapshot,
|
|
32
|
+
)
|
|
33
|
+
from phoenix.server.api.auth import IsLocked, IsNotReadOnly, IsNotViewer
|
|
29
34
|
from phoenix.server.api.context import Context
|
|
30
35
|
from phoenix.server.api.exceptions import BadRequest, CustomGraphQLError, NotFound
|
|
31
36
|
from phoenix.server.api.helpers.dataset_helpers import get_dataset_example_output
|
|
32
37
|
from phoenix.server.api.helpers.playground_clients import (
|
|
38
|
+
PlaygroundClientCredential,
|
|
33
39
|
PlaygroundStreamingClient,
|
|
34
40
|
initialize_playground_clients,
|
|
35
41
|
)
|
|
@@ -43,6 +49,7 @@ from phoenix.server.api.helpers.playground_spans import (
|
|
|
43
49
|
llm_tools,
|
|
44
50
|
prompt_metadata,
|
|
45
51
|
)
|
|
52
|
+
from phoenix.server.api.helpers.playground_users import get_user
|
|
46
53
|
from phoenix.server.api.helpers.prompts.models import PromptTemplateFormat
|
|
47
54
|
from phoenix.server.api.input_types.ChatCompletionInput import (
|
|
48
55
|
ChatCompletionInput,
|
|
@@ -62,6 +69,7 @@ from phoenix.server.api.types.DatasetVersion import DatasetVersion
|
|
|
62
69
|
from phoenix.server.api.types.node import from_global_id_with_expected_type
|
|
63
70
|
from phoenix.server.api.types.Span import Span
|
|
64
71
|
from phoenix.server.dml_event import SpanInsertEvent
|
|
72
|
+
from phoenix.server.experiments.utils import generate_experiment_project_name
|
|
65
73
|
from phoenix.trace.attributes import unflatten
|
|
66
74
|
from phoenix.trace.schemas import SpanException
|
|
67
75
|
from phoenix.utilities.json import jsonify
|
|
@@ -72,9 +80,11 @@ from phoenix.utilities.template_formatters import (
|
|
|
72
80
|
TemplateFormatter,
|
|
73
81
|
)
|
|
74
82
|
|
|
83
|
+
logger = logging.getLogger(__name__)
|
|
84
|
+
|
|
75
85
|
initialize_playground_clients()
|
|
76
86
|
|
|
77
|
-
ChatCompletionMessage = tuple[ChatCompletionMessageRole, str, Optional[str], Optional[
|
|
87
|
+
ChatCompletionMessage = tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[Any]]]
|
|
78
88
|
|
|
79
89
|
|
|
80
90
|
@strawberry.type
|
|
@@ -90,24 +100,25 @@ class ChatCompletionToolCall:
|
|
|
90
100
|
|
|
91
101
|
|
|
92
102
|
@strawberry.type
|
|
93
|
-
class
|
|
94
|
-
|
|
103
|
+
class ChatCompletionRepetition:
|
|
104
|
+
repetition_number: int
|
|
95
105
|
content: Optional[str]
|
|
96
|
-
tool_calls:
|
|
97
|
-
span: Span
|
|
106
|
+
tool_calls: list[ChatCompletionToolCall]
|
|
107
|
+
span: Optional[Span]
|
|
98
108
|
error_message: Optional[str]
|
|
99
109
|
|
|
100
110
|
|
|
101
111
|
@strawberry.type
|
|
102
|
-
class
|
|
103
|
-
|
|
112
|
+
class ChatCompletionMutationPayload:
|
|
113
|
+
repetitions: list[ChatCompletionRepetition]
|
|
104
114
|
|
|
105
115
|
|
|
106
116
|
@strawberry.type
|
|
107
117
|
class ChatCompletionOverDatasetMutationExamplePayload:
|
|
108
118
|
dataset_example_id: GlobalID
|
|
119
|
+
repetition_number: int
|
|
109
120
|
experiment_run_id: GlobalID
|
|
110
|
-
|
|
121
|
+
repetition: ChatCompletionRepetition
|
|
111
122
|
|
|
112
123
|
|
|
113
124
|
@strawberry.type
|
|
@@ -120,7 +131,7 @@ class ChatCompletionOverDatasetMutationPayload:
|
|
|
120
131
|
|
|
121
132
|
@strawberry.type
|
|
122
133
|
class ChatCompletionMutationMixin:
|
|
123
|
-
@strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked]) # type: ignore
|
|
134
|
+
@strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked]) # type: ignore
|
|
124
135
|
@classmethod
|
|
125
136
|
async def chat_completion_over_dataset(
|
|
126
137
|
cls,
|
|
@@ -132,9 +143,17 @@ class ChatCompletionMutationMixin:
|
|
|
132
143
|
if llm_client_class is None:
|
|
133
144
|
raise BadRequest(f"Unknown LLM provider: '{provider_key.value}'")
|
|
134
145
|
try:
|
|
146
|
+
# Convert GraphQL credentials to PlaygroundCredential objects
|
|
147
|
+
credentials = None
|
|
148
|
+
if input.credentials:
|
|
149
|
+
credentials = [
|
|
150
|
+
PlaygroundClientCredential(env_var_name=cred.env_var_name, value=cred.value)
|
|
151
|
+
for cred in input.credentials
|
|
152
|
+
]
|
|
153
|
+
|
|
135
154
|
llm_client = llm_client_class(
|
|
136
155
|
model=input.model,
|
|
137
|
-
|
|
156
|
+
credentials=credentials,
|
|
138
157
|
)
|
|
139
158
|
except CustomGraphQLError:
|
|
140
159
|
raise
|
|
@@ -151,6 +170,7 @@ class ChatCompletionMutationMixin:
|
|
|
151
170
|
if input.dataset_version_id
|
|
152
171
|
else None
|
|
153
172
|
)
|
|
173
|
+
project_name = generate_experiment_project_name()
|
|
154
174
|
async with info.context.db() as session:
|
|
155
175
|
dataset = await session.scalar(select(models.Dataset).filter_by(id=dataset_id))
|
|
156
176
|
if dataset is None:
|
|
@@ -166,16 +186,26 @@ class ChatCompletionMutationMixin:
|
|
|
166
186
|
raise NotFound("No versions found for the given dataset")
|
|
167
187
|
else:
|
|
168
188
|
resolved_version_id = dataset_version_id
|
|
189
|
+
# Parse split IDs if provided
|
|
190
|
+
resolved_split_ids: Optional[list[int]] = None
|
|
191
|
+
if input.split_ids is not None and len(input.split_ids) > 0:
|
|
192
|
+
resolved_split_ids = [
|
|
193
|
+
from_global_id_with_expected_type(split_id, models.DatasetSplit.__name__)
|
|
194
|
+
for split_id in input.split_ids
|
|
195
|
+
]
|
|
196
|
+
|
|
169
197
|
revisions = [
|
|
170
198
|
revision
|
|
171
199
|
async for revision in await session.stream_scalars(
|
|
172
|
-
get_dataset_example_revisions(
|
|
173
|
-
|
|
174
|
-
|
|
200
|
+
get_dataset_example_revisions(
|
|
201
|
+
resolved_version_id,
|
|
202
|
+
split_ids=resolved_split_ids,
|
|
203
|
+
).order_by(models.DatasetExampleRevision.id)
|
|
175
204
|
)
|
|
176
205
|
]
|
|
177
206
|
if not revisions:
|
|
178
207
|
raise NotFound("No examples found for the given dataset and version")
|
|
208
|
+
user_id = get_user(info)
|
|
179
209
|
experiment = models.Experiment(
|
|
180
210
|
dataset_id=from_global_id_with_expected_type(input.dataset_id, Dataset.__name__),
|
|
181
211
|
dataset_version_id=resolved_version_id,
|
|
@@ -184,15 +214,25 @@ class ChatCompletionMutationMixin:
|
|
|
184
214
|
description=input.experiment_description,
|
|
185
215
|
repetitions=1,
|
|
186
216
|
metadata_=input.experiment_metadata or dict(),
|
|
187
|
-
project_name=
|
|
217
|
+
project_name=project_name,
|
|
218
|
+
user_id=user_id,
|
|
188
219
|
)
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
220
|
+
if resolved_split_ids:
|
|
221
|
+
experiment.experiment_dataset_splits = [
|
|
222
|
+
models.ExperimentDatasetSplit(dataset_split_id=split_id)
|
|
223
|
+
for split_id in resolved_split_ids
|
|
224
|
+
]
|
|
225
|
+
await insert_experiment_with_examples_snapshot(session, experiment)
|
|
226
|
+
|
|
227
|
+
results: list[Union[tuple[ChatCompletionRepetition, models.Span], BaseException]] = []
|
|
193
228
|
batch_size = 3
|
|
194
229
|
start_time = datetime.now(timezone.utc)
|
|
195
|
-
|
|
230
|
+
unbatched_items = [
|
|
231
|
+
(revision, repetition_number)
|
|
232
|
+
for revision in revisions
|
|
233
|
+
for repetition_number in range(1, input.repetitions + 1)
|
|
234
|
+
]
|
|
235
|
+
for batch in _get_batches(unbatched_items, batch_size):
|
|
196
236
|
batch_results = await asyncio.gather(
|
|
197
237
|
*(
|
|
198
238
|
cls._chat_completion(
|
|
@@ -200,7 +240,7 @@ class ChatCompletionMutationMixin:
|
|
|
200
240
|
llm_client,
|
|
201
241
|
ChatCompletionInput(
|
|
202
242
|
model=input.model,
|
|
203
|
-
|
|
243
|
+
credentials=input.credentials,
|
|
204
244
|
messages=input.messages,
|
|
205
245
|
tools=input.tools,
|
|
206
246
|
invocation_parameters=input.invocation_parameters,
|
|
@@ -209,9 +249,12 @@ class ChatCompletionMutationMixin:
|
|
|
209
249
|
variables=revision.input,
|
|
210
250
|
),
|
|
211
251
|
prompt_name=input.prompt_name,
|
|
252
|
+
repetitions=repetition_number,
|
|
212
253
|
),
|
|
254
|
+
repetition_number=repetition_number,
|
|
255
|
+
project_name=project_name,
|
|
213
256
|
)
|
|
214
|
-
for revision in batch
|
|
257
|
+
for revision, repetition_number in batch
|
|
215
258
|
),
|
|
216
259
|
return_exceptions=True,
|
|
217
260
|
)
|
|
@@ -223,19 +266,19 @@ class ChatCompletionMutationMixin:
|
|
|
223
266
|
experiment_id=GlobalID(models.Experiment.__name__, str(experiment.id)),
|
|
224
267
|
)
|
|
225
268
|
experiment_runs = []
|
|
226
|
-
for revision, result in zip(
|
|
269
|
+
for (revision, repetition_number), result in zip(unbatched_items, results):
|
|
227
270
|
if isinstance(result, BaseException):
|
|
228
271
|
experiment_run = models.ExperimentRun(
|
|
229
272
|
experiment_id=experiment.id,
|
|
230
273
|
dataset_example_id=revision.dataset_example_id,
|
|
231
274
|
output={},
|
|
232
|
-
repetition_number=
|
|
275
|
+
repetition_number=repetition_number,
|
|
233
276
|
start_time=start_time,
|
|
234
277
|
end_time=start_time,
|
|
235
278
|
error=str(result),
|
|
236
279
|
)
|
|
237
280
|
else:
|
|
238
|
-
db_span
|
|
281
|
+
repetition, db_span = result
|
|
239
282
|
experiment_run = models.ExperimentRun(
|
|
240
283
|
experiment_id=experiment.id,
|
|
241
284
|
dataset_example_id=revision.dataset_example_id,
|
|
@@ -245,10 +288,10 @@ class ChatCompletionMutationMixin:
|
|
|
245
288
|
),
|
|
246
289
|
prompt_token_count=db_span.cumulative_llm_token_count_prompt,
|
|
247
290
|
completion_token_count=db_span.cumulative_llm_token_count_completion,
|
|
248
|
-
repetition_number=
|
|
291
|
+
repetition_number=repetition_number,
|
|
249
292
|
start_time=db_span.start_time,
|
|
250
293
|
end_time=db_span.end_time,
|
|
251
|
-
error=str(
|
|
294
|
+
error=str(repetition.error_message) if repetition.error_message else None,
|
|
252
295
|
)
|
|
253
296
|
experiment_runs.append(experiment_run)
|
|
254
297
|
|
|
@@ -256,22 +299,31 @@ class ChatCompletionMutationMixin:
|
|
|
256
299
|
session.add_all(experiment_runs)
|
|
257
300
|
await session.flush()
|
|
258
301
|
|
|
259
|
-
for revision, experiment_run, result in zip(
|
|
302
|
+
for (revision, repetition_number), experiment_run, result in zip(
|
|
303
|
+
unbatched_items, experiment_runs, results
|
|
304
|
+
):
|
|
260
305
|
dataset_example_id = GlobalID(
|
|
261
306
|
models.DatasetExample.__name__, str(revision.dataset_example_id)
|
|
262
307
|
)
|
|
263
308
|
experiment_run_id = GlobalID(models.ExperimentRun.__name__, str(experiment_run.id))
|
|
264
309
|
example_payload = ChatCompletionOverDatasetMutationExamplePayload(
|
|
265
310
|
dataset_example_id=dataset_example_id,
|
|
311
|
+
repetition_number=repetition_number,
|
|
266
312
|
experiment_run_id=experiment_run_id,
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
313
|
+
repetition=ChatCompletionRepetition(
|
|
314
|
+
repetition_number=repetition_number,
|
|
315
|
+
content=None,
|
|
316
|
+
tool_calls=[],
|
|
317
|
+
span=None,
|
|
318
|
+
error_message=str(result),
|
|
319
|
+
)
|
|
320
|
+
if isinstance(result, BaseException)
|
|
321
|
+
else result[0],
|
|
270
322
|
)
|
|
271
323
|
payload.examples.append(example_payload)
|
|
272
324
|
return payload
|
|
273
325
|
|
|
274
|
-
@strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked]) # type: ignore
|
|
326
|
+
@strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked]) # type: ignore
|
|
275
327
|
@classmethod
|
|
276
328
|
async def chat_completion(
|
|
277
329
|
cls, info: Info[Context, None], input: ChatCompletionInput
|
|
@@ -281,9 +333,17 @@ class ChatCompletionMutationMixin:
|
|
|
281
333
|
if llm_client_class is None:
|
|
282
334
|
raise BadRequest(f"Unknown LLM provider: '{provider_key.value}'")
|
|
283
335
|
try:
|
|
336
|
+
# Convert GraphQL credentials to PlaygroundCredential objects
|
|
337
|
+
credentials = None
|
|
338
|
+
if input.credentials:
|
|
339
|
+
credentials = [
|
|
340
|
+
PlaygroundClientCredential(env_var_name=cred.env_var_name, value=cred.value)
|
|
341
|
+
for cred in input.credentials
|
|
342
|
+
]
|
|
343
|
+
|
|
284
344
|
llm_client = llm_client_class(
|
|
285
345
|
model=input.model,
|
|
286
|
-
|
|
346
|
+
credentials=credentials,
|
|
287
347
|
)
|
|
288
348
|
except CustomGraphQLError:
|
|
289
349
|
raise
|
|
@@ -292,7 +352,38 @@ class ChatCompletionMutationMixin:
|
|
|
292
352
|
f"Failed to connect to LLM API for {provider_key.value} {input.model.name}: "
|
|
293
353
|
f"{str(error)}"
|
|
294
354
|
)
|
|
295
|
-
|
|
355
|
+
|
|
356
|
+
results: list[Union[tuple[ChatCompletionRepetition, models.Span], BaseException]] = []
|
|
357
|
+
batch_size = 3
|
|
358
|
+
for batch in _get_batches(range(1, input.repetitions + 1), batch_size):
|
|
359
|
+
batch_results = await asyncio.gather(
|
|
360
|
+
*(
|
|
361
|
+
cls._chat_completion(
|
|
362
|
+
info, llm_client, input, repetition_number=repetition_number
|
|
363
|
+
)
|
|
364
|
+
for repetition_number in batch
|
|
365
|
+
),
|
|
366
|
+
return_exceptions=True,
|
|
367
|
+
)
|
|
368
|
+
results.extend(batch_results)
|
|
369
|
+
|
|
370
|
+
repetitions: list[ChatCompletionRepetition] = []
|
|
371
|
+
for repetition_number, result in enumerate(results, start=1):
|
|
372
|
+
if isinstance(result, BaseException):
|
|
373
|
+
repetitions.append(
|
|
374
|
+
ChatCompletionRepetition(
|
|
375
|
+
repetition_number=repetition_number,
|
|
376
|
+
content=None,
|
|
377
|
+
tool_calls=[],
|
|
378
|
+
span=None,
|
|
379
|
+
error_message=str(result),
|
|
380
|
+
)
|
|
381
|
+
)
|
|
382
|
+
else:
|
|
383
|
+
repetition, _ = result
|
|
384
|
+
repetitions.append(repetition)
|
|
385
|
+
|
|
386
|
+
return ChatCompletionMutationPayload(repetitions=repetitions)
|
|
296
387
|
|
|
297
388
|
@classmethod
|
|
298
389
|
async def _chat_completion(
|
|
@@ -300,7 +391,10 @@ class ChatCompletionMutationMixin:
|
|
|
300
391
|
info: Info[Context, None],
|
|
301
392
|
llm_client: PlaygroundStreamingClient,
|
|
302
393
|
input: ChatCompletionInput,
|
|
303
|
-
|
|
394
|
+
repetition_number: int,
|
|
395
|
+
project_name: str = PLAYGROUND_PROJECT_NAME,
|
|
396
|
+
project_description: str = "Traces from prompt playground",
|
|
397
|
+
) -> tuple[ChatCompletionRepetition, models.Span]:
|
|
304
398
|
attributes: dict[str, Any] = {}
|
|
305
399
|
attributes.update(dict(prompt_metadata(input.prompt_name)))
|
|
306
400
|
|
|
@@ -394,15 +488,15 @@ class ChatCompletionMutationMixin:
|
|
|
394
488
|
# Get or create the project ID
|
|
395
489
|
if (
|
|
396
490
|
project_id := await session.scalar(
|
|
397
|
-
select(models.Project.id).where(models.Project.name ==
|
|
491
|
+
select(models.Project.id).where(models.Project.name == project_name)
|
|
398
492
|
)
|
|
399
493
|
) is None:
|
|
400
494
|
project_id = await session.scalar(
|
|
401
495
|
insert(models.Project)
|
|
402
496
|
.returning(models.Project.id)
|
|
403
497
|
.values(
|
|
404
|
-
name=
|
|
405
|
-
description=
|
|
498
|
+
name=project_name,
|
|
499
|
+
description=project_description,
|
|
406
500
|
)
|
|
407
501
|
)
|
|
408
502
|
trace = models.Trace(
|
|
@@ -433,27 +527,41 @@ class ChatCompletionMutationMixin:
|
|
|
433
527
|
session.add(trace)
|
|
434
528
|
session.add(span)
|
|
435
529
|
await session.flush()
|
|
530
|
+
try:
|
|
531
|
+
span_cost = info.context.span_cost_calculator.calculate_cost(
|
|
532
|
+
start_time=span.start_time,
|
|
533
|
+
attributes=span.attributes,
|
|
534
|
+
)
|
|
535
|
+
except Exception as e:
|
|
536
|
+
logger.exception(f"Failed to calculate cost for span {span.id}: {e}")
|
|
537
|
+
span_cost = None
|
|
538
|
+
if span_cost:
|
|
539
|
+
span_cost.span_rowid = span.id
|
|
540
|
+
span_cost.trace_rowid = trace.id
|
|
541
|
+
session.add(span_cost)
|
|
542
|
+
await session.flush()
|
|
436
543
|
|
|
437
|
-
gql_span = Span(
|
|
544
|
+
gql_span = Span(id=span.id, db_record=span)
|
|
438
545
|
|
|
439
546
|
info.context.event_queue.put(SpanInsertEvent(ids=(project_id,)))
|
|
440
547
|
|
|
441
548
|
if status_code is StatusCode.ERROR:
|
|
442
|
-
|
|
443
|
-
|
|
549
|
+
repetition = ChatCompletionRepetition(
|
|
550
|
+
repetition_number=repetition_number,
|
|
444
551
|
content=None,
|
|
445
552
|
tool_calls=[],
|
|
446
553
|
span=gql_span,
|
|
447
554
|
error_message=status_message,
|
|
448
555
|
)
|
|
449
556
|
else:
|
|
450
|
-
|
|
451
|
-
|
|
557
|
+
repetition = ChatCompletionRepetition(
|
|
558
|
+
repetition_number=repetition_number,
|
|
452
559
|
content=text_content if text_content else None,
|
|
453
560
|
tool_calls=list(tool_calls.values()),
|
|
454
561
|
span=gql_span,
|
|
455
562
|
error_message=None,
|
|
456
563
|
)
|
|
564
|
+
return repetition, span
|
|
457
565
|
|
|
458
566
|
|
|
459
567
|
def _formatted_messages(
|
|
@@ -588,5 +696,4 @@ TOOL_CALL_FUNCTION_ARGUMENTS_JSON = ToolCallAttributes.TOOL_CALL_FUNCTION_ARGUME
|
|
|
588
696
|
TOOL_JSON_SCHEMA = ToolAttributes.TOOL_JSON_SCHEMA
|
|
589
697
|
PROMPT_TEMPLATE_VARIABLES = SpanAttributes.LLM_PROMPT_TEMPLATE_VARIABLES
|
|
590
698
|
|
|
591
|
-
|
|
592
|
-
PLAYGROUND_PROJECT_NAME = "playground"
|
|
699
|
+
LLM_PROVIDER = SpanAttributes.LLM_PROVIDER
|
|
@@ -0,0 +1,243 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
import sqlalchemy
|
|
4
|
+
import strawberry
|
|
5
|
+
from sqlalchemy import delete, select
|
|
6
|
+
from sqlalchemy.exc import IntegrityError as PostgreSQLIntegrityError
|
|
7
|
+
from sqlalchemy.orm import joinedload
|
|
8
|
+
from sqlalchemy.sql import tuple_
|
|
9
|
+
from sqlean.dbapi2 import IntegrityError as SQLiteIntegrityError # type: ignore[import-untyped]
|
|
10
|
+
from strawberry import UNSET
|
|
11
|
+
from strawberry.relay.types import GlobalID
|
|
12
|
+
from strawberry.types import Info
|
|
13
|
+
|
|
14
|
+
from phoenix.db import models
|
|
15
|
+
from phoenix.server.api.auth import IsLocked, IsNotReadOnly, IsNotViewer
|
|
16
|
+
from phoenix.server.api.context import Context
|
|
17
|
+
from phoenix.server.api.exceptions import BadRequest, Conflict, NotFound
|
|
18
|
+
from phoenix.server.api.queries import Query
|
|
19
|
+
from phoenix.server.api.types.Dataset import Dataset
|
|
20
|
+
from phoenix.server.api.types.DatasetLabel import DatasetLabel
|
|
21
|
+
from phoenix.server.api.types.node import from_global_id_with_expected_type
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
# Input accepted by the `create_dataset_label` mutation.
@strawberry.input
class CreateDatasetLabelInput:
    # Name stored on the new label.
    name: str
    # Optional description; left as the UNSET sentinel when the client omits it.
    description: Optional[str] = UNSET
    # Color stored on the new label (format is not validated in this module).
    color: str
    # Optional datasets the new label is attached to right after creation.
    dataset_ids: Optional[list[GlobalID]] = UNSET
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
# Result returned by the `create_dataset_label` mutation.
@strawberry.type
class CreateDatasetLabelMutationPayload:
    # The label that was just created.
    dataset_label: DatasetLabel
    # The datasets the label was attached to, in de-duplicated request order.
    datasets: list[Dataset]
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
# Input accepted by the `delete_dataset_labels` mutation.
@strawberry.input
class DeleteDatasetLabelsInput:
    # Global IDs of the labels to delete; duplicates are collapsed by the mutation.
    dataset_label_ids: list[GlobalID]
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
# Result returned by the `delete_dataset_labels` mutation.
@strawberry.type
class DeleteDatasetLabelsMutationPayload:
    # The deleted labels, in de-duplicated request order.
    dataset_labels: list[DatasetLabel]
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
# Input accepted by the `set_dataset_labels` mutation.
@strawberry.input
class SetDatasetLabelsInput:
    # Global ID of the dataset whose labels are being replaced.
    dataset_id: GlobalID
    # Complete set of labels the dataset should carry afterwards; labels not
    # listed here are removed from the dataset by the mutation.
    dataset_label_ids: list[GlobalID]
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
# Result returned by the `set_dataset_labels` mutation.
@strawberry.type
class SetDatasetLabelsMutationPayload:
    # Root query object returned alongside the dataset.
    query: Query
    # The dataset whose labels were updated.
    dataset: Dataset
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
# Mutations for creating dataset labels, deleting them, and replacing the set
# of labels applied to a single dataset.
@strawberry.type
class DatasetLabelMutationMixin:
    # Create a dataset label and, optionally, attach it to the given datasets.
    #
    # Raises:
    #   BadRequest: a dataset global ID cannot be decoded as a Dataset ID, or
    #       the insert fails with a non-integrity statement error.
    #   Conflict: the insert violates an integrity constraint.
    #   NotFound: at least one referenced dataset does not exist.
    @strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked])  # type: ignore
    async def create_dataset_label(
        self,
        info: Info[Context, None],
        input: CreateDatasetLabelInput,
    ) -> CreateDatasetLabelMutationPayload:
        name = input.name
        # `description` defaults to the UNSET sentinel when the client omits it;
        # normalise it to None so the sentinel object is never handed to the
        # ORM column.
        description = None if input.description is UNSET else input.description
        color = input.color
        dataset_rowids: dict[
            int, None
        ] = {}  # use dictionary to de-duplicate while preserving order
        if input.dataset_ids:
            for dataset_id in input.dataset_ids:
                try:
                    dataset_rowid = from_global_id_with_expected_type(dataset_id, Dataset.__name__)
                except ValueError as error:
                    raise BadRequest(f"Invalid dataset ID: {dataset_id}") from error
                dataset_rowids[dataset_rowid] = None

        async with info.context.db() as session:
            dataset_label_orm = models.DatasetLabel(name=name, description=description, color=color)
            session.add(dataset_label_orm)
            try:
                # Flush first so the label receives its primary key before the
                # association rows below reference it.
                await session.flush()
            except (PostgreSQLIntegrityError, SQLiteIntegrityError) as error:
                raise Conflict(f"A dataset label named '{name}' already exists") from error
            except sqlalchemy.exc.StatementError as error:
                raise BadRequest(str(error.orig)) from error

            datasets_by_id: dict[int, models.Dataset] = {}
            if dataset_rowids:
                datasets_by_id = {
                    dataset.id: dataset
                    for dataset in await session.scalars(
                        select(models.Dataset).where(models.Dataset.id.in_(dataset_rowids.keys()))
                    )
                }
                # Fewer rows than requested IDs means at least one is missing.
                if len(datasets_by_id) < len(dataset_rowids):
                    raise NotFound("One or more datasets not found")
                session.add_all(
                    [
                        models.DatasetsDatasetLabel(
                            dataset_id=dataset_rowid,
                            dataset_label_id=dataset_label_orm.id,
                        )
                        for dataset_rowid in dataset_rowids
                    ]
                )
            await session.commit()

        return CreateDatasetLabelMutationPayload(
            dataset_label=DatasetLabel(id=dataset_label_orm.id, db_record=dataset_label_orm),
            datasets=[
                Dataset(
                    id=datasets_by_id[dataset_rowid].id, db_record=datasets_by_id[dataset_rowid]
                )
                for dataset_rowid in dataset_rowids
            ],
        )

    # Delete the given dataset labels and return them in request order.
    #
    # Raises:
    #   BadRequest: a global ID cannot be decoded as a DatasetLabel ID.
    #   NotFound: fewer labels were deleted than were requested; the session
    #       is rolled back before raising.
    @strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked])  # type: ignore
    async def delete_dataset_labels(
        self, info: Info[Context, None], input: DeleteDatasetLabelsInput
    ) -> DeleteDatasetLabelsMutationPayload:
        # dict keys de-duplicate the IDs while preserving request order
        dataset_label_row_ids: dict[int, None] = {}
        for dataset_label_node_id in input.dataset_label_ids:
            try:
                dataset_label_row_id = from_global_id_with_expected_type(
                    dataset_label_node_id, DatasetLabel.__name__
                )
            except ValueError as error:
                raise BadRequest(f"Unknown dataset label: {dataset_label_node_id}") from error
            dataset_label_row_ids[dataset_label_row_id] = None
        async with info.context.db() as session:
            stmt = (
                delete(models.DatasetLabel)
                .where(models.DatasetLabel.id.in_(dataset_label_row_ids.keys()))
                .returning(models.DatasetLabel)
            )
            deleted_dataset_labels = (await session.scalars(stmt)).all()
            if len(deleted_dataset_labels) < len(dataset_label_row_ids):
                await session.rollback()
                raise NotFound("Could not find one or more dataset labels with given IDs")
        deleted_dataset_labels_by_id = {
            dataset_label.id: dataset_label for dataset_label in deleted_dataset_labels
        }
        return DeleteDatasetLabelsMutationPayload(
            dataset_labels=[
                DatasetLabel(
                    id=deleted_dataset_labels_by_id[dataset_label_row_id].id,
                    db_record=deleted_dataset_labels_by_id[dataset_label_row_id],
                )
                for dataset_label_row_id in dataset_label_row_ids
            ]
        )

    # Replace the labels applied to a dataset with exactly the given set:
    # missing associations are added and associations not listed are deleted.
    #
    # Raises:
    #   BadRequest: the dataset ID or a label ID cannot be decoded.
    #   NotFound: the dataset or at least one label does not exist.
    @strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked])  # type: ignore
    async def set_dataset_labels(
        self, info: Info[Context, None], input: SetDatasetLabelsInput
    ) -> SetDatasetLabelsMutationPayload:
        try:
            dataset_id = from_global_id_with_expected_type(input.dataset_id, Dataset.__name__)
        except ValueError as error:
            raise BadRequest(f"Invalid dataset ID: {input.dataset_id}") from error

        dataset_label_ids: dict[
            int, None
        ] = {}  # use dictionary to de-duplicate while preserving order
        for dataset_label_gid in input.dataset_label_ids:
            try:
                dataset_label_id = from_global_id_with_expected_type(
                    dataset_label_gid, DatasetLabel.__name__
                )
            except ValueError as error:
                raise BadRequest(f"Invalid dataset label ID: {dataset_label_gid}") from error
            dataset_label_ids[dataset_label_id] = None

        async with info.context.db() as session:
            # Load the dataset together with its current label associations.
            dataset = await session.scalar(
                select(models.Dataset)
                .where(models.Dataset.id == dataset_id)
                .options(joinedload(models.Dataset.datasets_dataset_labels))
            )

            if not dataset:
                raise NotFound(f"Dataset with ID {input.dataset_id} not found")

            # Every requested label must exist before anything is changed.
            existing_label_ids = (
                await session.scalars(
                    select(models.DatasetLabel.id).where(
                        models.DatasetLabel.id.in_(dataset_label_ids.keys())
                    )
                )
            ).all()
            if len(existing_label_ids) != len(dataset_label_ids):
                raise NotFound("One or more dataset labels not found")

            previously_applied_dataset_label_ids = {
                dataset_dataset_label.dataset_label_id
                for dataset_dataset_label in dataset.datasets_dataset_labels
            }

            # Requested labels that are not applied yet.
            datasets_dataset_labels_to_add = [
                models.DatasetsDatasetLabel(
                    dataset_id=dataset_id,
                    dataset_label_id=dataset_label_id,
                )
                for dataset_label_id in dataset_label_ids
                if dataset_label_id not in previously_applied_dataset_label_ids
            ]
            if datasets_dataset_labels_to_add:
                session.add_all(datasets_dataset_labels_to_add)
                await session.flush()

            # Applied labels that were not requested.
            datasets_dataset_labels_to_delete = [
                dataset_dataset_label
                for dataset_dataset_label in dataset.datasets_dataset_labels
                if dataset_dataset_label.dataset_label_id not in dataset_label_ids
            ]
            if datasets_dataset_labels_to_delete:
                await session.execute(
                    delete(models.DatasetsDatasetLabel).where(
                        tuple_(
                            models.DatasetsDatasetLabel.dataset_id,
                            models.DatasetsDatasetLabel.dataset_label_id,
                        ).in_(
                            [
                                (
                                    datasets_dataset_labels.dataset_id,
                                    datasets_dataset_labels.dataset_label_id,
                                )
                                for datasets_dataset_labels in datasets_dataset_labels_to_delete
                            ]
                        )
                    )
                )

        return SetDatasetLabelsMutationPayload(
            dataset=Dataset(id=dataset.id, db_record=dataset),
            query=Query(),
        )
|