arize-phoenix 11.23.1__py3-none-any.whl → 12.28.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/METADATA +61 -36
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/RECORD +212 -162
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/WHEEL +1 -1
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/IP_NOTICE +1 -1
- phoenix/__generated__/__init__.py +0 -0
- phoenix/__generated__/classification_evaluator_configs/__init__.py +20 -0
- phoenix/__generated__/classification_evaluator_configs/_document_relevance_classification_evaluator_config.py +17 -0
- phoenix/__generated__/classification_evaluator_configs/_hallucination_classification_evaluator_config.py +17 -0
- phoenix/__generated__/classification_evaluator_configs/_models.py +18 -0
- phoenix/__generated__/classification_evaluator_configs/_tool_selection_classification_evaluator_config.py +17 -0
- phoenix/__init__.py +2 -1
- phoenix/auth.py +27 -2
- phoenix/config.py +1594 -81
- phoenix/db/README.md +546 -28
- phoenix/db/bulk_inserter.py +119 -116
- phoenix/db/engines.py +140 -33
- phoenix/db/facilitator.py +22 -1
- phoenix/db/helpers.py +818 -65
- phoenix/db/iam_auth.py +64 -0
- phoenix/db/insertion/dataset.py +133 -1
- phoenix/db/insertion/document_annotation.py +9 -6
- phoenix/db/insertion/evaluation.py +2 -3
- phoenix/db/insertion/helpers.py +2 -2
- phoenix/db/insertion/session_annotation.py +176 -0
- phoenix/db/insertion/span_annotation.py +3 -4
- phoenix/db/insertion/trace_annotation.py +3 -4
- phoenix/db/insertion/types.py +41 -18
- phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
- phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
- phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
- phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
- phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
- phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
- phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
- phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
- phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
- phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
- phoenix/db/models.py +364 -56
- phoenix/db/pg_config.py +10 -0
- phoenix/db/types/trace_retention.py +7 -6
- phoenix/experiments/functions.py +69 -19
- phoenix/inferences/inferences.py +1 -2
- phoenix/server/api/auth.py +9 -0
- phoenix/server/api/auth_messages.py +46 -0
- phoenix/server/api/context.py +60 -0
- phoenix/server/api/dataloaders/__init__.py +36 -0
- phoenix/server/api/dataloaders/annotation_summaries.py +60 -8
- phoenix/server/api/dataloaders/average_experiment_repeated_run_group_latency.py +50 -0
- phoenix/server/api/dataloaders/average_experiment_run_latency.py +17 -24
- phoenix/server/api/dataloaders/cache/two_tier_cache.py +1 -2
- phoenix/server/api/dataloaders/dataset_dataset_splits.py +52 -0
- phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -1
- phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
- phoenix/server/api/dataloaders/dataset_examples_and_versions_by_experiment_run.py +47 -0
- phoenix/server/api/dataloaders/dataset_labels.py +36 -0
- phoenix/server/api/dataloaders/document_evaluation_summaries.py +2 -2
- phoenix/server/api/dataloaders/document_evaluations.py +6 -9
- phoenix/server/api/dataloaders/experiment_annotation_summaries.py +88 -34
- phoenix/server/api/dataloaders/experiment_dataset_splits.py +43 -0
- phoenix/server/api/dataloaders/experiment_error_rates.py +21 -28
- phoenix/server/api/dataloaders/experiment_repeated_run_group_annotation_summaries.py +77 -0
- phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +57 -0
- phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
- phoenix/server/api/dataloaders/latency_ms_quantile.py +40 -8
- phoenix/server/api/dataloaders/record_counts.py +37 -10
- phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment_repeated_run_group.py +64 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_project.py +28 -14
- phoenix/server/api/dataloaders/span_costs.py +3 -9
- phoenix/server/api/dataloaders/table_fields.py +2 -2
- phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
- phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
- phoenix/server/api/exceptions.py +5 -1
- phoenix/server/api/helpers/playground_clients.py +263 -83
- phoenix/server/api/helpers/playground_spans.py +2 -1
- phoenix/server/api/helpers/playground_users.py +26 -0
- phoenix/server/api/helpers/prompts/conversions/google.py +103 -0
- phoenix/server/api/helpers/prompts/models.py +61 -19
- phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
- phoenix/server/api/input_types/ChatCompletionInput.py +3 -0
- phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
- phoenix/server/api/input_types/DatasetFilter.py +5 -2
- phoenix/server/api/input_types/ExperimentRunSort.py +237 -0
- phoenix/server/api/input_types/GenerativeModelInput.py +3 -0
- phoenix/server/api/input_types/ProjectSessionSort.py +158 -1
- phoenix/server/api/input_types/PromptVersionInput.py +47 -1
- phoenix/server/api/input_types/SpanSort.py +3 -2
- phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
- phoenix/server/api/input_types/UserRoleInput.py +1 -0
- phoenix/server/api/mutations/__init__.py +8 -0
- phoenix/server/api/mutations/annotation_config_mutations.py +8 -8
- phoenix/server/api/mutations/api_key_mutations.py +15 -20
- phoenix/server/api/mutations/chat_mutations.py +106 -37
- phoenix/server/api/mutations/dataset_label_mutations.py +243 -0
- phoenix/server/api/mutations/dataset_mutations.py +21 -16
- phoenix/server/api/mutations/dataset_split_mutations.py +351 -0
- phoenix/server/api/mutations/experiment_mutations.py +2 -2
- phoenix/server/api/mutations/export_events_mutations.py +3 -3
- phoenix/server/api/mutations/model_mutations.py +11 -9
- phoenix/server/api/mutations/project_mutations.py +4 -4
- phoenix/server/api/mutations/project_session_annotations_mutations.py +158 -0
- phoenix/server/api/mutations/project_trace_retention_policy_mutations.py +8 -4
- phoenix/server/api/mutations/prompt_label_mutations.py +74 -65
- phoenix/server/api/mutations/prompt_mutations.py +65 -129
- phoenix/server/api/mutations/prompt_version_tag_mutations.py +11 -8
- phoenix/server/api/mutations/span_annotations_mutations.py +15 -10
- phoenix/server/api/mutations/trace_annotations_mutations.py +13 -8
- phoenix/server/api/mutations/trace_mutations.py +3 -3
- phoenix/server/api/mutations/user_mutations.py +55 -26
- phoenix/server/api/queries.py +501 -617
- phoenix/server/api/routers/__init__.py +2 -2
- phoenix/server/api/routers/auth.py +141 -87
- phoenix/server/api/routers/ldap.py +229 -0
- phoenix/server/api/routers/oauth2.py +349 -101
- phoenix/server/api/routers/v1/__init__.py +22 -4
- phoenix/server/api/routers/v1/annotation_configs.py +19 -30
- phoenix/server/api/routers/v1/annotations.py +455 -13
- phoenix/server/api/routers/v1/datasets.py +355 -68
- phoenix/server/api/routers/v1/documents.py +142 -0
- phoenix/server/api/routers/v1/evaluations.py +20 -28
- phoenix/server/api/routers/v1/experiment_evaluations.py +16 -6
- phoenix/server/api/routers/v1/experiment_runs.py +335 -59
- phoenix/server/api/routers/v1/experiments.py +475 -47
- phoenix/server/api/routers/v1/projects.py +16 -50
- phoenix/server/api/routers/v1/prompts.py +50 -39
- phoenix/server/api/routers/v1/sessions.py +108 -0
- phoenix/server/api/routers/v1/spans.py +156 -96
- phoenix/server/api/routers/v1/traces.py +51 -77
- phoenix/server/api/routers/v1/users.py +64 -24
- phoenix/server/api/routers/v1/utils.py +3 -7
- phoenix/server/api/subscriptions.py +257 -93
- phoenix/server/api/types/Annotation.py +90 -23
- phoenix/server/api/types/ApiKey.py +13 -17
- phoenix/server/api/types/AuthMethod.py +1 -0
- phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +1 -0
- phoenix/server/api/types/Dataset.py +199 -72
- phoenix/server/api/types/DatasetExample.py +88 -18
- phoenix/server/api/types/DatasetExperimentAnnotationSummary.py +10 -0
- phoenix/server/api/types/DatasetLabel.py +57 -0
- phoenix/server/api/types/DatasetSplit.py +98 -0
- phoenix/server/api/types/DatasetVersion.py +49 -4
- phoenix/server/api/types/DocumentAnnotation.py +212 -0
- phoenix/server/api/types/Experiment.py +215 -68
- phoenix/server/api/types/ExperimentComparison.py +3 -9
- phoenix/server/api/types/ExperimentRepeatedRunGroup.py +155 -0
- phoenix/server/api/types/ExperimentRepeatedRunGroupAnnotationSummary.py +9 -0
- phoenix/server/api/types/ExperimentRun.py +120 -70
- phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
- phoenix/server/api/types/GenerativeModel.py +95 -42
- phoenix/server/api/types/GenerativeProvider.py +1 -1
- phoenix/server/api/types/ModelInterface.py +7 -2
- phoenix/server/api/types/PlaygroundModel.py +12 -2
- phoenix/server/api/types/Project.py +218 -185
- phoenix/server/api/types/ProjectSession.py +146 -29
- phoenix/server/api/types/ProjectSessionAnnotation.py +187 -0
- phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
- phoenix/server/api/types/Prompt.py +119 -39
- phoenix/server/api/types/PromptLabel.py +42 -25
- phoenix/server/api/types/PromptVersion.py +11 -8
- phoenix/server/api/types/PromptVersionTag.py +65 -25
- phoenix/server/api/types/Span.py +130 -123
- phoenix/server/api/types/SpanAnnotation.py +189 -42
- phoenix/server/api/types/SystemApiKey.py +65 -1
- phoenix/server/api/types/Trace.py +184 -53
- phoenix/server/api/types/TraceAnnotation.py +149 -50
- phoenix/server/api/types/User.py +128 -33
- phoenix/server/api/types/UserApiKey.py +73 -26
- phoenix/server/api/types/node.py +10 -0
- phoenix/server/api/types/pagination.py +11 -2
- phoenix/server/app.py +154 -36
- phoenix/server/authorization.py +5 -4
- phoenix/server/bearer_auth.py +13 -5
- phoenix/server/cost_tracking/cost_model_lookup.py +42 -14
- phoenix/server/cost_tracking/model_cost_manifest.json +1085 -194
- phoenix/server/daemons/generative_model_store.py +61 -9
- phoenix/server/daemons/span_cost_calculator.py +10 -8
- phoenix/server/dml_event.py +13 -0
- phoenix/server/email/sender.py +29 -2
- phoenix/server/grpc_server.py +9 -9
- phoenix/server/jwt_store.py +8 -6
- phoenix/server/ldap.py +1449 -0
- phoenix/server/main.py +9 -3
- phoenix/server/oauth2.py +330 -12
- phoenix/server/prometheus.py +43 -6
- phoenix/server/rate_limiters.py +4 -9
- phoenix/server/retention.py +33 -20
- phoenix/server/session_filters.py +49 -0
- phoenix/server/static/.vite/manifest.json +51 -53
- phoenix/server/static/assets/components-BreFUQQa.js +6702 -0
- phoenix/server/static/assets/{index-BPCwGQr8.js → index-CTQoemZv.js} +42 -35
- phoenix/server/static/assets/pages-DBE5iYM3.js +9524 -0
- phoenix/server/static/assets/vendor-BGzfc4EU.css +1 -0
- phoenix/server/static/assets/vendor-DCE4v-Ot.js +920 -0
- phoenix/server/static/assets/vendor-codemirror-D5f205eT.js +25 -0
- phoenix/server/static/assets/{vendor-recharts-Bw30oz1A.js → vendor-recharts-V9cwpXsm.js} +7 -7
- phoenix/server/static/assets/{vendor-shiki-DZajAPeq.js → vendor-shiki-Do--csgv.js} +1 -1
- phoenix/server/static/assets/vendor-three-CmB8bl_y.js +3840 -0
- phoenix/server/templates/index.html +7 -1
- phoenix/server/thread_server.py +1 -2
- phoenix/server/utils.py +74 -0
- phoenix/session/client.py +55 -1
- phoenix/session/data_extractor.py +5 -0
- phoenix/session/evaluation.py +8 -4
- phoenix/session/session.py +44 -8
- phoenix/settings.py +2 -0
- phoenix/trace/attributes.py +80 -13
- phoenix/trace/dsl/query.py +2 -0
- phoenix/trace/projects.py +5 -0
- phoenix/utilities/template_formatters.py +1 -1
- phoenix/version.py +1 -1
- phoenix/server/api/types/Evaluation.py +0 -39
- phoenix/server/static/assets/components-D0DWAf0l.js +0 -5650
- phoenix/server/static/assets/pages-Creyamao.js +0 -8612
- phoenix/server/static/assets/vendor-CU36oj8y.js +0 -905
- phoenix/server/static/assets/vendor-CqDb5u4o.css +0 -1
- phoenix/server/static/assets/vendor-arizeai-Ctgw0e1G.js +0 -168
- phoenix/server/static/assets/vendor-codemirror-Cojjzqb9.js +0 -25
- phoenix/server/static/assets/vendor-three-BLWp5bic.js +0 -2998
- phoenix/utilities/deprecation.py +0 -31
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
from datetime import datetime, timezone
|
|
2
|
-
from typing import Optional
|
|
2
|
+
from typing import Literal, Optional
|
|
3
3
|
|
|
4
4
|
import strawberry
|
|
5
5
|
from sqlalchemy import select
|
|
@@ -9,7 +9,7 @@ from strawberry.types import Info
|
|
|
9
9
|
|
|
10
10
|
from phoenix.db import models
|
|
11
11
|
from phoenix.db.models import UserRoleName
|
|
12
|
-
from phoenix.server.api.auth import IsAdmin, IsLocked, IsNotReadOnly
|
|
12
|
+
from phoenix.server.api.auth import IsAdmin, IsLocked, IsNotReadOnly, IsNotViewer
|
|
13
13
|
from phoenix.server.api.context import Context
|
|
14
14
|
from phoenix.server.api.exceptions import Unauthorized
|
|
15
15
|
from phoenix.server.api.queries import Query
|
|
@@ -61,7 +61,7 @@ class DeleteApiKeyMutationPayload:
|
|
|
61
61
|
|
|
62
62
|
@strawberry.type
|
|
63
63
|
class ApiKeyMutationMixin:
|
|
64
|
-
@strawberry.mutation(permission_classes=[IsNotReadOnly, IsAdmin, IsLocked]) # type: ignore
|
|
64
|
+
@strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer, IsAdmin, IsLocked]) # type: ignore
|
|
65
65
|
async def create_system_api_key(
|
|
66
66
|
self, info: Info[Context, None], input: CreateApiKeyInput
|
|
67
67
|
) -> CreateSystemApiKeyMutationPayload:
|
|
@@ -92,13 +92,7 @@ class ApiKeyMutationMixin:
|
|
|
92
92
|
token, token_id = await token_store.create_api_key(claims)
|
|
93
93
|
return CreateSystemApiKeyMutationPayload(
|
|
94
94
|
jwt=token,
|
|
95
|
-
api_key=SystemApiKey(
|
|
96
|
-
id_attr=int(token_id),
|
|
97
|
-
name=input.name,
|
|
98
|
-
description=input.description or None,
|
|
99
|
-
created_at=issued_at,
|
|
100
|
-
expires_at=input.expires_at or None,
|
|
101
|
-
),
|
|
95
|
+
api_key=SystemApiKey(id=int(token_id)),
|
|
102
96
|
query=Query(),
|
|
103
97
|
)
|
|
104
98
|
|
|
@@ -113,12 +107,20 @@ class ApiKeyMutationMixin:
|
|
|
113
107
|
except AttributeError:
|
|
114
108
|
raise ValueError("User not found")
|
|
115
109
|
issued_at = datetime.now(timezone.utc)
|
|
110
|
+
# Determine user role for API key
|
|
111
|
+
user_role: Literal["ADMIN", "MEMBER", "VIEWER"]
|
|
112
|
+
if user.is_admin:
|
|
113
|
+
user_role = "ADMIN"
|
|
114
|
+
elif user.is_viewer:
|
|
115
|
+
user_role = "VIEWER"
|
|
116
|
+
else:
|
|
117
|
+
user_role = "MEMBER"
|
|
116
118
|
claims = ApiKeyClaims(
|
|
117
119
|
subject=user.identity,
|
|
118
120
|
issued_at=issued_at,
|
|
119
121
|
expiration_time=input.expires_at or None,
|
|
120
122
|
attributes=ApiKeyAttributes(
|
|
121
|
-
user_role=
|
|
123
|
+
user_role=user_role,
|
|
122
124
|
name=input.name,
|
|
123
125
|
description=input.description,
|
|
124
126
|
),
|
|
@@ -126,18 +128,11 @@ class ApiKeyMutationMixin:
|
|
|
126
128
|
token, token_id = await token_store.create_api_key(claims)
|
|
127
129
|
return CreateUserApiKeyMutationPayload(
|
|
128
130
|
jwt=token,
|
|
129
|
-
api_key=UserApiKey(
|
|
130
|
-
id_attr=int(token_id),
|
|
131
|
-
name=input.name,
|
|
132
|
-
description=input.description or None,
|
|
133
|
-
created_at=issued_at,
|
|
134
|
-
expires_at=input.expires_at or None,
|
|
135
|
-
user_id=int(user.identity),
|
|
136
|
-
),
|
|
131
|
+
api_key=UserApiKey(id=int(token_id)),
|
|
137
132
|
query=Query(),
|
|
138
133
|
)
|
|
139
134
|
|
|
140
|
-
@strawberry.mutation(permission_classes=[IsNotReadOnly, IsAdmin]) # type: ignore
|
|
135
|
+
@strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer, IsAdmin]) # type: ignore
|
|
141
136
|
async def delete_system_api_key(
|
|
142
137
|
self, info: Info[Context, None], input: DeleteApiKeyInput
|
|
143
138
|
) -> DeleteApiKeyMutationPayload:
|
|
@@ -4,7 +4,7 @@ from dataclasses import asdict, field
|
|
|
4
4
|
from datetime import datetime, timezone
|
|
5
5
|
from itertools import chain, islice
|
|
6
6
|
from traceback import format_exc
|
|
7
|
-
from typing import Any, Iterable, Iterator,
|
|
7
|
+
from typing import Any, Iterable, Iterator, Optional, TypeVar, Union
|
|
8
8
|
|
|
9
9
|
import strawberry
|
|
10
10
|
from openinference.instrumentation import safe_json_dumps
|
|
@@ -26,8 +26,11 @@ from typing_extensions import assert_never
|
|
|
26
26
|
from phoenix.config import PLAYGROUND_PROJECT_NAME
|
|
27
27
|
from phoenix.datetime_utils import local_now, normalize_datetime
|
|
28
28
|
from phoenix.db import models
|
|
29
|
-
from phoenix.db.helpers import
|
|
30
|
-
|
|
29
|
+
from phoenix.db.helpers import (
|
|
30
|
+
get_dataset_example_revisions,
|
|
31
|
+
insert_experiment_with_examples_snapshot,
|
|
32
|
+
)
|
|
33
|
+
from phoenix.server.api.auth import IsLocked, IsNotReadOnly, IsNotViewer
|
|
31
34
|
from phoenix.server.api.context import Context
|
|
32
35
|
from phoenix.server.api.exceptions import BadRequest, CustomGraphQLError, NotFound
|
|
33
36
|
from phoenix.server.api.helpers.dataset_helpers import get_dataset_example_output
|
|
@@ -46,6 +49,7 @@ from phoenix.server.api.helpers.playground_spans import (
|
|
|
46
49
|
llm_tools,
|
|
47
50
|
prompt_metadata,
|
|
48
51
|
)
|
|
52
|
+
from phoenix.server.api.helpers.playground_users import get_user
|
|
49
53
|
from phoenix.server.api.helpers.prompts.models import PromptTemplateFormat
|
|
50
54
|
from phoenix.server.api.input_types.ChatCompletionInput import (
|
|
51
55
|
ChatCompletionInput,
|
|
@@ -80,7 +84,7 @@ logger = logging.getLogger(__name__)
|
|
|
80
84
|
|
|
81
85
|
initialize_playground_clients()
|
|
82
86
|
|
|
83
|
-
ChatCompletionMessage = tuple[ChatCompletionMessageRole, str, Optional[str], Optional[
|
|
87
|
+
ChatCompletionMessage = tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[Any]]]
|
|
84
88
|
|
|
85
89
|
|
|
86
90
|
@strawberry.type
|
|
@@ -96,24 +100,25 @@ class ChatCompletionToolCall:
|
|
|
96
100
|
|
|
97
101
|
|
|
98
102
|
@strawberry.type
|
|
99
|
-
class
|
|
100
|
-
|
|
103
|
+
class ChatCompletionRepetition:
|
|
104
|
+
repetition_number: int
|
|
101
105
|
content: Optional[str]
|
|
102
|
-
tool_calls:
|
|
103
|
-
span: Span
|
|
106
|
+
tool_calls: list[ChatCompletionToolCall]
|
|
107
|
+
span: Optional[Span]
|
|
104
108
|
error_message: Optional[str]
|
|
105
109
|
|
|
106
110
|
|
|
107
111
|
@strawberry.type
|
|
108
|
-
class
|
|
109
|
-
|
|
112
|
+
class ChatCompletionMutationPayload:
|
|
113
|
+
repetitions: list[ChatCompletionRepetition]
|
|
110
114
|
|
|
111
115
|
|
|
112
116
|
@strawberry.type
|
|
113
117
|
class ChatCompletionOverDatasetMutationExamplePayload:
|
|
114
118
|
dataset_example_id: GlobalID
|
|
119
|
+
repetition_number: int
|
|
115
120
|
experiment_run_id: GlobalID
|
|
116
|
-
|
|
121
|
+
repetition: ChatCompletionRepetition
|
|
117
122
|
|
|
118
123
|
|
|
119
124
|
@strawberry.type
|
|
@@ -126,7 +131,7 @@ class ChatCompletionOverDatasetMutationPayload:
|
|
|
126
131
|
|
|
127
132
|
@strawberry.type
|
|
128
133
|
class ChatCompletionMutationMixin:
|
|
129
|
-
@strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked]) # type: ignore
|
|
134
|
+
@strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked]) # type: ignore
|
|
130
135
|
@classmethod
|
|
131
136
|
async def chat_completion_over_dataset(
|
|
132
137
|
cls,
|
|
@@ -181,16 +186,26 @@ class ChatCompletionMutationMixin:
|
|
|
181
186
|
raise NotFound("No versions found for the given dataset")
|
|
182
187
|
else:
|
|
183
188
|
resolved_version_id = dataset_version_id
|
|
189
|
+
# Parse split IDs if provided
|
|
190
|
+
resolved_split_ids: Optional[list[int]] = None
|
|
191
|
+
if input.split_ids is not None and len(input.split_ids) > 0:
|
|
192
|
+
resolved_split_ids = [
|
|
193
|
+
from_global_id_with_expected_type(split_id, models.DatasetSplit.__name__)
|
|
194
|
+
for split_id in input.split_ids
|
|
195
|
+
]
|
|
196
|
+
|
|
184
197
|
revisions = [
|
|
185
198
|
revision
|
|
186
199
|
async for revision in await session.stream_scalars(
|
|
187
|
-
get_dataset_example_revisions(
|
|
188
|
-
|
|
189
|
-
|
|
200
|
+
get_dataset_example_revisions(
|
|
201
|
+
resolved_version_id,
|
|
202
|
+
split_ids=resolved_split_ids,
|
|
203
|
+
).order_by(models.DatasetExampleRevision.id)
|
|
190
204
|
)
|
|
191
205
|
]
|
|
192
206
|
if not revisions:
|
|
193
207
|
raise NotFound("No examples found for the given dataset and version")
|
|
208
|
+
user_id = get_user(info)
|
|
194
209
|
experiment = models.Experiment(
|
|
195
210
|
dataset_id=from_global_id_with_expected_type(input.dataset_id, Dataset.__name__),
|
|
196
211
|
dataset_version_id=resolved_version_id,
|
|
@@ -200,14 +215,24 @@ class ChatCompletionMutationMixin:
|
|
|
200
215
|
repetitions=1,
|
|
201
216
|
metadata_=input.experiment_metadata or dict(),
|
|
202
217
|
project_name=project_name,
|
|
218
|
+
user_id=user_id,
|
|
203
219
|
)
|
|
204
|
-
|
|
205
|
-
|
|
220
|
+
if resolved_split_ids:
|
|
221
|
+
experiment.experiment_dataset_splits = [
|
|
222
|
+
models.ExperimentDatasetSplit(dataset_split_id=split_id)
|
|
223
|
+
for split_id in resolved_split_ids
|
|
224
|
+
]
|
|
225
|
+
await insert_experiment_with_examples_snapshot(session, experiment)
|
|
206
226
|
|
|
207
|
-
results: list[Union[
|
|
227
|
+
results: list[Union[tuple[ChatCompletionRepetition, models.Span], BaseException]] = []
|
|
208
228
|
batch_size = 3
|
|
209
229
|
start_time = datetime.now(timezone.utc)
|
|
210
|
-
|
|
230
|
+
unbatched_items = [
|
|
231
|
+
(revision, repetition_number)
|
|
232
|
+
for revision in revisions
|
|
233
|
+
for repetition_number in range(1, input.repetitions + 1)
|
|
234
|
+
]
|
|
235
|
+
for batch in _get_batches(unbatched_items, batch_size):
|
|
211
236
|
batch_results = await asyncio.gather(
|
|
212
237
|
*(
|
|
213
238
|
cls._chat_completion(
|
|
@@ -224,10 +249,12 @@ class ChatCompletionMutationMixin:
|
|
|
224
249
|
variables=revision.input,
|
|
225
250
|
),
|
|
226
251
|
prompt_name=input.prompt_name,
|
|
252
|
+
repetitions=repetition_number,
|
|
227
253
|
),
|
|
254
|
+
repetition_number=repetition_number,
|
|
228
255
|
project_name=project_name,
|
|
229
256
|
)
|
|
230
|
-
for revision in batch
|
|
257
|
+
for revision, repetition_number in batch
|
|
231
258
|
),
|
|
232
259
|
return_exceptions=True,
|
|
233
260
|
)
|
|
@@ -239,19 +266,19 @@ class ChatCompletionMutationMixin:
|
|
|
239
266
|
experiment_id=GlobalID(models.Experiment.__name__, str(experiment.id)),
|
|
240
267
|
)
|
|
241
268
|
experiment_runs = []
|
|
242
|
-
for revision, result in zip(
|
|
269
|
+
for (revision, repetition_number), result in zip(unbatched_items, results):
|
|
243
270
|
if isinstance(result, BaseException):
|
|
244
271
|
experiment_run = models.ExperimentRun(
|
|
245
272
|
experiment_id=experiment.id,
|
|
246
273
|
dataset_example_id=revision.dataset_example_id,
|
|
247
274
|
output={},
|
|
248
|
-
repetition_number=
|
|
275
|
+
repetition_number=repetition_number,
|
|
249
276
|
start_time=start_time,
|
|
250
277
|
end_time=start_time,
|
|
251
278
|
error=str(result),
|
|
252
279
|
)
|
|
253
280
|
else:
|
|
254
|
-
db_span
|
|
281
|
+
repetition, db_span = result
|
|
255
282
|
experiment_run = models.ExperimentRun(
|
|
256
283
|
experiment_id=experiment.id,
|
|
257
284
|
dataset_example_id=revision.dataset_example_id,
|
|
@@ -261,10 +288,10 @@ class ChatCompletionMutationMixin:
|
|
|
261
288
|
),
|
|
262
289
|
prompt_token_count=db_span.cumulative_llm_token_count_prompt,
|
|
263
290
|
completion_token_count=db_span.cumulative_llm_token_count_completion,
|
|
264
|
-
repetition_number=
|
|
291
|
+
repetition_number=repetition_number,
|
|
265
292
|
start_time=db_span.start_time,
|
|
266
293
|
end_time=db_span.end_time,
|
|
267
|
-
error=str(
|
|
294
|
+
error=str(repetition.error_message) if repetition.error_message else None,
|
|
268
295
|
)
|
|
269
296
|
experiment_runs.append(experiment_run)
|
|
270
297
|
|
|
@@ -272,22 +299,31 @@ class ChatCompletionMutationMixin:
|
|
|
272
299
|
session.add_all(experiment_runs)
|
|
273
300
|
await session.flush()
|
|
274
301
|
|
|
275
|
-
for revision, experiment_run, result in zip(
|
|
302
|
+
for (revision, repetition_number), experiment_run, result in zip(
|
|
303
|
+
unbatched_items, experiment_runs, results
|
|
304
|
+
):
|
|
276
305
|
dataset_example_id = GlobalID(
|
|
277
306
|
models.DatasetExample.__name__, str(revision.dataset_example_id)
|
|
278
307
|
)
|
|
279
308
|
experiment_run_id = GlobalID(models.ExperimentRun.__name__, str(experiment_run.id))
|
|
280
309
|
example_payload = ChatCompletionOverDatasetMutationExamplePayload(
|
|
281
310
|
dataset_example_id=dataset_example_id,
|
|
311
|
+
repetition_number=repetition_number,
|
|
282
312
|
experiment_run_id=experiment_run_id,
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
313
|
+
repetition=ChatCompletionRepetition(
|
|
314
|
+
repetition_number=repetition_number,
|
|
315
|
+
content=None,
|
|
316
|
+
tool_calls=[],
|
|
317
|
+
span=None,
|
|
318
|
+
error_message=str(result),
|
|
319
|
+
)
|
|
320
|
+
if isinstance(result, BaseException)
|
|
321
|
+
else result[0],
|
|
286
322
|
)
|
|
287
323
|
payload.examples.append(example_payload)
|
|
288
324
|
return payload
|
|
289
325
|
|
|
290
|
-
@strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked]) # type: ignore
|
|
326
|
+
@strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked]) # type: ignore
|
|
291
327
|
@classmethod
|
|
292
328
|
async def chat_completion(
|
|
293
329
|
cls, info: Info[Context, None], input: ChatCompletionInput
|
|
@@ -316,7 +352,38 @@ class ChatCompletionMutationMixin:
|
|
|
316
352
|
f"Failed to connect to LLM API for {provider_key.value} {input.model.name}: "
|
|
317
353
|
f"{str(error)}"
|
|
318
354
|
)
|
|
319
|
-
|
|
355
|
+
|
|
356
|
+
results: list[Union[tuple[ChatCompletionRepetition, models.Span], BaseException]] = []
|
|
357
|
+
batch_size = 3
|
|
358
|
+
for batch in _get_batches(range(1, input.repetitions + 1), batch_size):
|
|
359
|
+
batch_results = await asyncio.gather(
|
|
360
|
+
*(
|
|
361
|
+
cls._chat_completion(
|
|
362
|
+
info, llm_client, input, repetition_number=repetition_number
|
|
363
|
+
)
|
|
364
|
+
for repetition_number in batch
|
|
365
|
+
),
|
|
366
|
+
return_exceptions=True,
|
|
367
|
+
)
|
|
368
|
+
results.extend(batch_results)
|
|
369
|
+
|
|
370
|
+
repetitions: list[ChatCompletionRepetition] = []
|
|
371
|
+
for repetition_number, result in enumerate(results, start=1):
|
|
372
|
+
if isinstance(result, BaseException):
|
|
373
|
+
repetitions.append(
|
|
374
|
+
ChatCompletionRepetition(
|
|
375
|
+
repetition_number=repetition_number,
|
|
376
|
+
content=None,
|
|
377
|
+
tool_calls=[],
|
|
378
|
+
span=None,
|
|
379
|
+
error_message=str(result),
|
|
380
|
+
)
|
|
381
|
+
)
|
|
382
|
+
else:
|
|
383
|
+
repetition, _ = result
|
|
384
|
+
repetitions.append(repetition)
|
|
385
|
+
|
|
386
|
+
return ChatCompletionMutationPayload(repetitions=repetitions)
|
|
320
387
|
|
|
321
388
|
@classmethod
|
|
322
389
|
async def _chat_completion(
|
|
@@ -324,9 +391,10 @@ class ChatCompletionMutationMixin:
|
|
|
324
391
|
info: Info[Context, None],
|
|
325
392
|
llm_client: PlaygroundStreamingClient,
|
|
326
393
|
input: ChatCompletionInput,
|
|
394
|
+
repetition_number: int,
|
|
327
395
|
project_name: str = PLAYGROUND_PROJECT_NAME,
|
|
328
396
|
project_description: str = "Traces from prompt playground",
|
|
329
|
-
) ->
|
|
397
|
+
) -> tuple[ChatCompletionRepetition, models.Span]:
|
|
330
398
|
attributes: dict[str, Any] = {}
|
|
331
399
|
attributes.update(dict(prompt_metadata(input.prompt_name)))
|
|
332
400
|
|
|
@@ -473,26 +541,27 @@ class ChatCompletionMutationMixin:
|
|
|
473
541
|
session.add(span_cost)
|
|
474
542
|
await session.flush()
|
|
475
543
|
|
|
476
|
-
gql_span = Span(
|
|
544
|
+
gql_span = Span(id=span.id, db_record=span)
|
|
477
545
|
|
|
478
546
|
info.context.event_queue.put(SpanInsertEvent(ids=(project_id,)))
|
|
479
547
|
|
|
480
548
|
if status_code is StatusCode.ERROR:
|
|
481
|
-
|
|
482
|
-
|
|
549
|
+
repetition = ChatCompletionRepetition(
|
|
550
|
+
repetition_number=repetition_number,
|
|
483
551
|
content=None,
|
|
484
552
|
tool_calls=[],
|
|
485
553
|
span=gql_span,
|
|
486
554
|
error_message=status_message,
|
|
487
555
|
)
|
|
488
556
|
else:
|
|
489
|
-
|
|
490
|
-
|
|
557
|
+
repetition = ChatCompletionRepetition(
|
|
558
|
+
repetition_number=repetition_number,
|
|
491
559
|
content=text_content if text_content else None,
|
|
492
560
|
tool_calls=list(tool_calls.values()),
|
|
493
561
|
span=gql_span,
|
|
494
562
|
error_message=None,
|
|
495
563
|
)
|
|
564
|
+
return repetition, span
|
|
496
565
|
|
|
497
566
|
|
|
498
567
|
def _formatted_messages(
|
|
@@ -0,0 +1,243 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
import sqlalchemy
|
|
4
|
+
import strawberry
|
|
5
|
+
from sqlalchemy import delete, select
|
|
6
|
+
from sqlalchemy.exc import IntegrityError as PostgreSQLIntegrityError
|
|
7
|
+
from sqlalchemy.orm import joinedload
|
|
8
|
+
from sqlalchemy.sql import tuple_
|
|
9
|
+
from sqlean.dbapi2 import IntegrityError as SQLiteIntegrityError # type: ignore[import-untyped]
|
|
10
|
+
from strawberry import UNSET
|
|
11
|
+
from strawberry.relay.types import GlobalID
|
|
12
|
+
from strawberry.types import Info
|
|
13
|
+
|
|
14
|
+
from phoenix.db import models
|
|
15
|
+
from phoenix.server.api.auth import IsLocked, IsNotReadOnly, IsNotViewer
|
|
16
|
+
from phoenix.server.api.context import Context
|
|
17
|
+
from phoenix.server.api.exceptions import BadRequest, Conflict, NotFound
|
|
18
|
+
from phoenix.server.api.queries import Query
|
|
19
|
+
from phoenix.server.api.types.Dataset import Dataset
|
|
20
|
+
from phoenix.server.api.types.DatasetLabel import DatasetLabel
|
|
21
|
+
from phoenix.server.api.types.node import from_global_id_with_expected_type
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@strawberry.input
|
|
25
|
+
class CreateDatasetLabelInput:
|
|
26
|
+
name: str
|
|
27
|
+
description: Optional[str] = UNSET
|
|
28
|
+
color: str
|
|
29
|
+
dataset_ids: Optional[list[GlobalID]] = UNSET
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@strawberry.type
|
|
33
|
+
class CreateDatasetLabelMutationPayload:
|
|
34
|
+
dataset_label: DatasetLabel
|
|
35
|
+
datasets: list[Dataset]
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
@strawberry.input
|
|
39
|
+
class DeleteDatasetLabelsInput:
|
|
40
|
+
dataset_label_ids: list[GlobalID]
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
@strawberry.type
|
|
44
|
+
class DeleteDatasetLabelsMutationPayload:
|
|
45
|
+
dataset_labels: list[DatasetLabel]
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@strawberry.input
|
|
49
|
+
class SetDatasetLabelsInput:
|
|
50
|
+
dataset_id: GlobalID
|
|
51
|
+
dataset_label_ids: list[GlobalID]
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
@strawberry.type
|
|
55
|
+
class SetDatasetLabelsMutationPayload:
|
|
56
|
+
query: Query
|
|
57
|
+
dataset: Dataset
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
@strawberry.type
class DatasetLabelMutationMixin:
    """Mixin providing create/delete/set mutations for dataset labels."""

    @strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked])  # type: ignore
    async def create_dataset_label(
        self,
        info: Info[Context, None],
        input: CreateDatasetLabelInput,
    ) -> CreateDatasetLabelMutationPayload:
        """Create a new dataset label and optionally attach it to datasets.

        Raises:
            BadRequest: if a dataset GlobalID is malformed, or the insert
                fails a non-uniqueness statement check.
            Conflict: if a label with the same name already exists.
            NotFound: if any referenced dataset does not exist.
        """
        name = input.name
        # NOTE(review): description may be strawberry UNSET (not None) when the
        # client omits it; confirm the ORM column handles UNSET as intended.
        description = input.description
        color = input.color
        dataset_rowids: dict[
            int, None
        ] = {}  # use dictionary to de-duplicate while preserving order
        if input.dataset_ids:
            for dataset_id in input.dataset_ids:
                try:
                    dataset_rowid = from_global_id_with_expected_type(dataset_id, Dataset.__name__)
                except ValueError:
                    raise BadRequest(f"Invalid dataset ID: {dataset_id}")
                dataset_rowids[dataset_rowid] = None

        async with info.context.db() as session:
            dataset_label_orm = models.DatasetLabel(name=name, description=description, color=color)
            session.add(dataset_label_orm)
            try:
                # Flush now so a uniqueness violation surfaces before we start
                # creating the dataset associations.
                await session.flush()
            # NOTE(review): these are DBAPI-level IntegrityErrors; SQLAlchemy
            # normally wraps driver errors in sqlalchemy.exc.IntegrityError (a
            # StatementError subclass), so confirm this clause is reachable and
            # duplicates are not falling through to the BadRequest below.
            except (PostgreSQLIntegrityError, SQLiteIntegrityError):
                raise Conflict(f"A dataset label named '{name}' already exists")
            except sqlalchemy.exc.StatementError as error:
                raise BadRequest(str(error.orig))

            datasets_by_id: dict[int, models.Dataset] = {}
            if dataset_rowids:
                # Load the requested datasets so missing IDs can be detected
                # and the payload can return full Dataset records.
                datasets_by_id = {
                    dataset.id: dataset
                    for dataset in await session.scalars(
                        select(models.Dataset).where(models.Dataset.id.in_(dataset_rowids.keys()))
                    )
                }
                if len(datasets_by_id) < len(dataset_rowids):
                    raise NotFound("One or more datasets not found")
                # Associate the freshly flushed label with each dataset.
                session.add_all(
                    [
                        models.DatasetsDatasetLabel(
                            dataset_id=dataset_rowid,
                            dataset_label_id=dataset_label_orm.id,
                        )
                        for dataset_rowid in dataset_rowids
                    ]
                )
            await session.commit()

        # Datasets are returned in the de-duplicated request order.
        return CreateDatasetLabelMutationPayload(
            dataset_label=DatasetLabel(id=dataset_label_orm.id, db_record=dataset_label_orm),
            datasets=[
                Dataset(
                    id=datasets_by_id[dataset_rowid].id, db_record=datasets_by_id[dataset_rowid]
                )
                for dataset_rowid in dataset_rowids
            ],
        )

    @strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked])  # type: ignore
    async def delete_dataset_labels(
        self, info: Info[Context, None], input: DeleteDatasetLabelsInput
    ) -> DeleteDatasetLabelsMutationPayload:
        """Delete the given dataset labels, all-or-nothing.

        Raises:
            BadRequest: if a label GlobalID is malformed.
            NotFound: if any label ID does not exist (the delete is rolled
                back so no partial deletion is persisted).
        """
        dataset_label_row_ids: dict[int, None] = {}  # de-duplicates, preserves order
        for dataset_label_node_id in input.dataset_label_ids:
            try:
                dataset_label_row_id = from_global_id_with_expected_type(
                    dataset_label_node_id, DatasetLabel.__name__
                )
            except ValueError:
                raise BadRequest(f"Unknown dataset label: {dataset_label_node_id}")
            dataset_label_row_ids[dataset_label_row_id] = None
        async with info.context.db() as session:
            # RETURNING gives back the deleted rows so the payload can echo them.
            stmt = (
                delete(models.DatasetLabel)
                .where(models.DatasetLabel.id.in_(dataset_label_row_ids.keys()))
                .returning(models.DatasetLabel)
            )
            deleted_dataset_labels = (await session.scalars(stmt)).all()
            if len(deleted_dataset_labels) < len(dataset_label_row_ids):
                # Some requested IDs matched nothing: undo the partial delete.
                await session.rollback()
                raise NotFound("Could not find one or more dataset labels with given IDs")
            deleted_dataset_labels_by_id = {
                dataset_label.id: dataset_label for dataset_label in deleted_dataset_labels
            }
            # Return labels in the de-duplicated request order.
            return DeleteDatasetLabelsMutationPayload(
                dataset_labels=[
                    DatasetLabel(
                        id=deleted_dataset_labels_by_id[dataset_label_row_id].id,
                        db_record=deleted_dataset_labels_by_id[dataset_label_row_id],
                    )
                    for dataset_label_row_id in dataset_label_row_ids
                ]
            )

    @strawberry.mutation(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked])  # type: ignore
    async def set_dataset_labels(
        self, info: Info[Context, None], input: SetDatasetLabelsInput
    ) -> SetDatasetLabelsMutationPayload:
        """Replace a dataset's label set with exactly the given labels.

        Labels in the input that are not yet applied are added; currently
        applied labels missing from the input are removed.

        Raises:
            BadRequest: if the dataset or a label GlobalID is malformed.
            NotFound: if the dataset or any label does not exist.
        """
        try:
            dataset_id = from_global_id_with_expected_type(input.dataset_id, Dataset.__name__)
        except ValueError:
            raise BadRequest(f"Invalid dataset ID: {input.dataset_id}")

        dataset_label_ids: dict[
            int, None
        ] = {}  # use dictionary to de-duplicate while preserving order
        for dataset_label_gid in input.dataset_label_ids:
            try:
                dataset_label_id = from_global_id_with_expected_type(
                    dataset_label_gid, DatasetLabel.__name__
                )
            except ValueError:
                raise BadRequest(f"Invalid dataset label ID: {dataset_label_gid}")
            dataset_label_ids[dataset_label_id] = None

        async with info.context.db() as session:
            # Eagerly load the current label associations; they are read twice
            # below (to compute additions and removals).
            dataset = await session.scalar(
                select(models.Dataset)
                .where(models.Dataset.id == dataset_id)
                .options(joinedload(models.Dataset.datasets_dataset_labels))
            )

            if not dataset:
                raise NotFound(f"Dataset with ID {input.dataset_id} not found")

            # Verify every requested label actually exists before mutating.
            existing_label_ids = (
                await session.scalars(
                    select(models.DatasetLabel.id).where(
                        models.DatasetLabel.id.in_(dataset_label_ids.keys())
                    )
                )
            ).all()
            if len(existing_label_ids) != len(dataset_label_ids):
                raise NotFound("One or more dataset labels not found")

            previously_applied_dataset_label_ids = {
                dataset_dataset_label.dataset_label_id
                for dataset_dataset_label in dataset.datasets_dataset_labels
            }

            # Additions: requested labels not already applied.
            datasets_dataset_labels_to_add = [
                models.DatasetsDatasetLabel(
                    dataset_id=dataset_id,
                    dataset_label_id=dataset_label_id,
                )
                for dataset_label_id in dataset_label_ids
                if dataset_label_id not in previously_applied_dataset_label_ids
            ]
            if datasets_dataset_labels_to_add:
                session.add_all(datasets_dataset_labels_to_add)
                await session.flush()

            # Removals: currently applied labels absent from the request.
            datasets_dataset_labels_to_delete = [
                dataset_dataset_label
                for dataset_dataset_label in dataset.datasets_dataset_labels
                if dataset_dataset_label.dataset_label_id not in dataset_label_ids
            ]
            if datasets_dataset_labels_to_delete:
                # Bulk-delete by composite key (dataset_id, dataset_label_id).
                await session.execute(
                    delete(models.DatasetsDatasetLabel).where(
                        tuple_(
                            models.DatasetsDatasetLabel.dataset_id,
                            models.DatasetsDatasetLabel.dataset_label_id,
                        ).in_(
                            [
                                (
                                    datasets_dataset_labels.dataset_id,
                                    datasets_dataset_labels.dataset_label_id,
                                )
                                for datasets_dataset_labels in datasets_dataset_labels_to_delete
                            ]
                        )
                    )
                )

            # NOTE(review): no explicit commit here (unlike create) —
            # presumably info.context.db() commits on clean exit; verify.
            return SetDatasetLabelsMutationPayload(
                dataset=Dataset(id=dataset.id, db_record=dataset),
                query=Query(),
            )
|