arize-phoenix 11.38.0__py3-none-any.whl → 12.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of arize-phoenix might be problematic. Click here for more details.
- {arize_phoenix-11.38.0.dist-info → arize_phoenix-12.2.0.dist-info}/METADATA +3 -3
- {arize_phoenix-11.38.0.dist-info → arize_phoenix-12.2.0.dist-info}/RECORD +83 -58
- phoenix/config.py +1 -11
- phoenix/db/bulk_inserter.py +8 -0
- phoenix/db/facilitator.py +1 -1
- phoenix/db/helpers.py +202 -33
- phoenix/db/insertion/dataset.py +7 -0
- phoenix/db/insertion/document_annotation.py +1 -1
- phoenix/db/insertion/helpers.py +2 -2
- phoenix/db/insertion/session_annotation.py +176 -0
- phoenix/db/insertion/span_annotation.py +1 -1
- phoenix/db/insertion/trace_annotation.py +1 -1
- phoenix/db/insertion/types.py +29 -3
- phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
- phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
- phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
- phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
- phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
- phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
- phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
- phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
- phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
- phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
- phoenix/db/models.py +306 -46
- phoenix/server/api/context.py +15 -2
- phoenix/server/api/dataloaders/__init__.py +8 -2
- phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
- phoenix/server/api/dataloaders/dataset_labels.py +36 -0
- phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
- phoenix/server/api/dataloaders/table_fields.py +2 -2
- phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
- phoenix/server/api/helpers/playground_clients.py +66 -35
- phoenix/server/api/helpers/playground_users.py +26 -0
- phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
- phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
- phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
- phoenix/server/api/mutations/__init__.py +8 -0
- phoenix/server/api/mutations/chat_mutations.py +8 -3
- phoenix/server/api/mutations/dataset_label_mutations.py +291 -0
- phoenix/server/api/mutations/dataset_mutations.py +5 -0
- phoenix/server/api/mutations/dataset_split_mutations.py +423 -0
- phoenix/server/api/mutations/project_session_annotations_mutations.py +161 -0
- phoenix/server/api/queries.py +53 -0
- phoenix/server/api/routers/auth.py +5 -5
- phoenix/server/api/routers/oauth2.py +5 -23
- phoenix/server/api/routers/v1/__init__.py +2 -0
- phoenix/server/api/routers/v1/annotations.py +320 -0
- phoenix/server/api/routers/v1/datasets.py +5 -0
- phoenix/server/api/routers/v1/experiments.py +10 -3
- phoenix/server/api/routers/v1/sessions.py +111 -0
- phoenix/server/api/routers/v1/traces.py +1 -2
- phoenix/server/api/routers/v1/users.py +7 -0
- phoenix/server/api/subscriptions.py +5 -2
- phoenix/server/api/types/Dataset.py +8 -0
- phoenix/server/api/types/DatasetExample.py +18 -0
- phoenix/server/api/types/DatasetLabel.py +23 -0
- phoenix/server/api/types/DatasetSplit.py +32 -0
- phoenix/server/api/types/Experiment.py +0 -4
- phoenix/server/api/types/Project.py +16 -0
- phoenix/server/api/types/ProjectSession.py +88 -3
- phoenix/server/api/types/ProjectSessionAnnotation.py +68 -0
- phoenix/server/api/types/Prompt.py +18 -1
- phoenix/server/api/types/Span.py +5 -5
- phoenix/server/api/types/Trace.py +61 -0
- phoenix/server/app.py +13 -14
- phoenix/server/cost_tracking/model_cost_manifest.json +132 -2
- phoenix/server/dml_event.py +13 -0
- phoenix/server/static/.vite/manifest.json +39 -39
- phoenix/server/static/assets/{components-BQPHTBfv.js → components-BG6v0EM8.js} +705 -385
- phoenix/server/static/assets/{index-BL5BMgJU.js → index-CSVcULw1.js} +13 -13
- phoenix/server/static/assets/{pages-C0Y17J0T.js → pages-DgaM7kpM.js} +1356 -1155
- phoenix/server/static/assets/{vendor-BdjZxMii.js → vendor-BqTEkGQU.js} +183 -183
- phoenix/server/static/assets/{vendor-arizeai-CHYlS8jV.js → vendor-arizeai-DlOj0PQQ.js} +15 -24
- phoenix/server/static/assets/{vendor-codemirror-Di6t4HnH.js → vendor-codemirror-B2PHH5yZ.js} +3 -3
- phoenix/server/static/assets/{vendor-recharts-C9wCDYj3.js → vendor-recharts-CKsi4IjN.js} +1 -1
- phoenix/server/static/assets/{vendor-shiki-MNnmOotP.js → vendor-shiki-DN26BkKE.js} +1 -1
- phoenix/server/utils.py +74 -0
- phoenix/session/session.py +25 -5
- phoenix/version.py +1 -1
- phoenix/server/api/dataloaders/experiment_repetition_counts.py +0 -39
- {arize_phoenix-11.38.0.dist-info → arize_phoenix-12.2.0.dist-info}/WHEEL +0 -0
- {arize_phoenix-11.38.0.dist-info → arize_phoenix-12.2.0.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-11.38.0.dist-info → arize_phoenix-12.2.0.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-11.38.0.dist-info → arize_phoenix-12.2.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -5,11 +5,16 @@ from phoenix.server.api.mutations.api_key_mutations import ApiKeyMutationMixin
|
|
|
5
5
|
from phoenix.server.api.mutations.chat_mutations import (
|
|
6
6
|
ChatCompletionMutationMixin,
|
|
7
7
|
)
|
|
8
|
+
from phoenix.server.api.mutations.dataset_label_mutations import DatasetLabelMutationMixin
|
|
8
9
|
from phoenix.server.api.mutations.dataset_mutations import DatasetMutationMixin
|
|
10
|
+
from phoenix.server.api.mutations.dataset_split_mutations import DatasetSplitMutationMixin
|
|
9
11
|
from phoenix.server.api.mutations.experiment_mutations import ExperimentMutationMixin
|
|
10
12
|
from phoenix.server.api.mutations.export_events_mutations import ExportEventsMutationMixin
|
|
11
13
|
from phoenix.server.api.mutations.model_mutations import ModelMutationMixin
|
|
12
14
|
from phoenix.server.api.mutations.project_mutations import ProjectMutationMixin
|
|
15
|
+
from phoenix.server.api.mutations.project_session_annotations_mutations import (
|
|
16
|
+
ProjectSessionAnnotationMutationMixin,
|
|
17
|
+
)
|
|
13
18
|
from phoenix.server.api.mutations.project_trace_retention_policy_mutations import (
|
|
14
19
|
ProjectTraceRetentionPolicyMutationMixin,
|
|
15
20
|
)
|
|
@@ -27,7 +32,9 @@ class Mutation(
|
|
|
27
32
|
AnnotationConfigMutationMixin,
|
|
28
33
|
ApiKeyMutationMixin,
|
|
29
34
|
ChatCompletionMutationMixin,
|
|
35
|
+
DatasetLabelMutationMixin,
|
|
30
36
|
DatasetMutationMixin,
|
|
37
|
+
DatasetSplitMutationMixin,
|
|
31
38
|
ExperimentMutationMixin,
|
|
32
39
|
ExportEventsMutationMixin,
|
|
33
40
|
ModelMutationMixin,
|
|
@@ -37,6 +44,7 @@ class Mutation(
|
|
|
37
44
|
PromptVersionTagMutationMixin,
|
|
38
45
|
PromptLabelMutationMixin,
|
|
39
46
|
SpanAnnotationMutationMixin,
|
|
47
|
+
ProjectSessionAnnotationMutationMixin,
|
|
40
48
|
TraceAnnotationMutationMixin,
|
|
41
49
|
TraceMutationMixin,
|
|
42
50
|
UserMutationMixin,
|
|
@@ -26,7 +26,10 @@ from typing_extensions import assert_never
|
|
|
26
26
|
from phoenix.config import PLAYGROUND_PROJECT_NAME
|
|
27
27
|
from phoenix.datetime_utils import local_now, normalize_datetime
|
|
28
28
|
from phoenix.db import models
|
|
29
|
-
from phoenix.db.helpers import
|
|
29
|
+
from phoenix.db.helpers import (
|
|
30
|
+
get_dataset_example_revisions,
|
|
31
|
+
insert_experiment_with_examples_snapshot,
|
|
32
|
+
)
|
|
30
33
|
from phoenix.server.api.auth import IsLocked, IsNotReadOnly
|
|
31
34
|
from phoenix.server.api.context import Context
|
|
32
35
|
from phoenix.server.api.exceptions import BadRequest, CustomGraphQLError, NotFound
|
|
@@ -46,6 +49,7 @@ from phoenix.server.api.helpers.playground_spans import (
|
|
|
46
49
|
llm_tools,
|
|
47
50
|
prompt_metadata,
|
|
48
51
|
)
|
|
52
|
+
from phoenix.server.api.helpers.playground_users import get_user
|
|
49
53
|
from phoenix.server.api.helpers.prompts.models import PromptTemplateFormat
|
|
50
54
|
from phoenix.server.api.input_types.ChatCompletionInput import (
|
|
51
55
|
ChatCompletionInput,
|
|
@@ -192,6 +196,7 @@ class ChatCompletionMutationMixin:
|
|
|
192
196
|
]
|
|
193
197
|
if not revisions:
|
|
194
198
|
raise NotFound("No examples found for the given dataset and version")
|
|
199
|
+
user_id = get_user(info)
|
|
195
200
|
experiment = models.Experiment(
|
|
196
201
|
dataset_id=from_global_id_with_expected_type(input.dataset_id, Dataset.__name__),
|
|
197
202
|
dataset_version_id=resolved_version_id,
|
|
@@ -201,9 +206,9 @@ class ChatCompletionMutationMixin:
|
|
|
201
206
|
repetitions=1,
|
|
202
207
|
metadata_=input.experiment_metadata or dict(),
|
|
203
208
|
project_name=project_name,
|
|
209
|
+
user_id=user_id,
|
|
204
210
|
)
|
|
205
|
-
session
|
|
206
|
-
await session.flush()
|
|
211
|
+
await insert_experiment_with_examples_snapshot(session, experiment)
|
|
207
212
|
|
|
208
213
|
results: list[Union[ChatCompletionMutationPayload, BaseException]] = []
|
|
209
214
|
batch_size = 3
|
|
@@ -0,0 +1,291 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
import sqlalchemy
|
|
4
|
+
import strawberry
|
|
5
|
+
from sqlalchemy import delete, select
|
|
6
|
+
from sqlalchemy.exc import IntegrityError as PostgreSQLIntegrityError
|
|
7
|
+
from sqlean.dbapi2 import IntegrityError as SQLiteIntegrityError # type: ignore[import-untyped]
|
|
8
|
+
from strawberry import UNSET
|
|
9
|
+
from strawberry.relay.types import GlobalID
|
|
10
|
+
from strawberry.types import Info
|
|
11
|
+
|
|
12
|
+
from phoenix.db import models
|
|
13
|
+
from phoenix.server.api.auth import IsLocked, IsNotReadOnly
|
|
14
|
+
from phoenix.server.api.context import Context
|
|
15
|
+
from phoenix.server.api.exceptions import BadRequest, Conflict, NotFound
|
|
16
|
+
from phoenix.server.api.queries import Query
|
|
17
|
+
from phoenix.server.api.types.Dataset import Dataset
|
|
18
|
+
from phoenix.server.api.types.DatasetLabel import DatasetLabel, to_gql_dataset_label
|
|
19
|
+
from phoenix.server.api.types.node import from_global_id_with_expected_type
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@strawberry.input
class CreateDatasetLabelInput:
    # Human-readable label name.
    name: str
    # Optional free-text description. Defaults to None (not UNSET) so that an
    # omitted description is stored as NULL instead of the UNSET sentinel
    # being handed to the ORM; this matches UpdateDatasetLabelInput.
    description: Optional[str] = None
    # Display color for the label.
    color: str
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@strawberry.type
class CreateDatasetLabelMutationPayload:
    # The newly created label, converted to its GraphQL representation.
    dataset_label: DatasetLabel
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@strawberry.input
class DeleteDatasetLabelsInput:
    # Relay global IDs of the DatasetLabel nodes to delete.
    dataset_label_ids: list[GlobalID]
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
@strawberry.type
class DeleteDatasetLabelsMutationPayload:
    # The labels that were removed, as GraphQL objects.
    dataset_labels: list[DatasetLabel]
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
@strawberry.input
class UpdateDatasetLabelInput:
    # Relay global ID of the DatasetLabel node being edited.
    dataset_label_id: GlobalID
    # Replacement name for the label.
    name: str
    # Replacement description; omitting it yields None.
    description: Optional[str] = None
    # Replacement display color.
    color: str
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
@strawberry.type
class UpdateDatasetLabelMutationPayload:
    # The label after the update, as a GraphQL object.
    dataset_label: DatasetLabel
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
@strawberry.input
class SetDatasetLabelsInput:
    # Relay global IDs of the DatasetLabel nodes to attach.
    dataset_label_ids: list[GlobalID]
    # Relay global IDs of the Dataset nodes that receive every listed label.
    dataset_ids: list[GlobalID]
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
@strawberry.type
class SetDatasetLabelsMutationPayload:
    # Root query handle so clients can refetch data after the mutation.
    query: "Query"
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
@strawberry.input
class UnsetDatasetLabelsInput:
    # Relay global IDs of the DatasetLabel nodes to detach.
    dataset_label_ids: list[GlobalID]
    # Relay global IDs of the Dataset nodes to detach them from.
    dataset_ids: list[GlobalID]
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
@strawberry.type
class UnsetDatasetLabelsMutationPayload:
    # Root query handle so clients can refetch data after the mutation.
    query: "Query"
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def _parse_unique_row_ids(
    global_ids: list[GlobalID], type_name: str, error_prefix: str
) -> list[int]:
    """Decode relay global IDs of one node type into de-duplicated row IDs.

    Args:
        global_ids: relay global IDs supplied by the client.
        type_name: the GraphQL type name every ID is expected to encode.
        error_prefix: message prefix used when an ID fails to decode.

    Returns:
        The distinct row IDs, in unspecified order.

    Raises:
        BadRequest: ``f"{error_prefix}: {global_id}"`` for the first ID that does
            not decode to ``type_name``.
    """
    row_ids: set[int] = set()
    for global_id in global_ids:
        try:
            row_id = from_global_id_with_expected_type(global_id, type_name)
        except ValueError:
            raise BadRequest(f"{error_prefix}: {global_id}")
        row_ids.add(row_id)
    return list(row_ids)


@strawberry.type
class DatasetLabelMutationMixin:
    """GraphQL mutations that create, edit, delete, and (un)assign dataset labels."""

    @strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked])  # type: ignore
    async def create_dataset_label(
        self,
        info: Info[Context, None],
        input: CreateDatasetLabelInput,
    ) -> CreateDatasetLabelMutationPayload:
        """Insert a new dataset label.

        Raises:
            BadRequest: if the name is empty/blank, or the database rejects the statement.
            Conflict: if the insert violates an integrity constraint (e.g. duplicate name).
        """
        # Validate and normalize exactly like update_dataset_label does.
        if not input.name or not input.name.strip():
            raise BadRequest("Dataset label name cannot be empty")
        name = input.name.strip()
        # An omitted optional input field arrives as strawberry's UNSET sentinel;
        # it must never reach the ORM, so store NULL instead.
        description = None if input.description is UNSET else input.description
        color = input.color.strip()
        async with info.context.db() as session:
            dataset_label_orm = models.DatasetLabel(name=name, description=description, color=color)
            session.add(dataset_label_orm)
            try:
                await session.commit()
            except (PostgreSQLIntegrityError, SQLiteIntegrityError):
                raise Conflict(f"A dataset label named '{name}' already exists")
            except sqlalchemy.exc.StatementError as error:
                raise BadRequest(str(error.orig))
            return CreateDatasetLabelMutationPayload(
                dataset_label=to_gql_dataset_label(dataset_label_orm)
            )

    @strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked])  # type: ignore
    async def update_dataset_label(
        self, info: Info[Context, None], input: UpdateDatasetLabelInput
    ) -> UpdateDatasetLabelMutationPayload:
        """Overwrite the name, description, and color of an existing dataset label.

        Raises:
            BadRequest: for a blank name, an undecodable ID, or a rejected statement.
            NotFound: if no label has the given ID.
            Conflict: if the update violates an integrity constraint (e.g. duplicate name).
        """
        if not input.name or not input.name.strip():
            raise BadRequest("Dataset label name cannot be empty")
        name = input.name.strip()

        try:
            dataset_label_id = from_global_id_with_expected_type(
                input.dataset_label_id, DatasetLabel.__name__
            )
        except ValueError:
            raise BadRequest(f"Invalid dataset label ID: {input.dataset_label_id}")

        async with info.context.db() as session:
            dataset_label_orm = await session.get(models.DatasetLabel, dataset_label_id)
            if not dataset_label_orm:
                raise NotFound(f"DatasetLabel with ID {input.dataset_label_id} not found")

            dataset_label_orm.name = name
            dataset_label_orm.description = input.description
            dataset_label_orm.color = input.color.strip()

            try:
                await session.commit()
            except (PostgreSQLIntegrityError, SQLiteIntegrityError):
                # Report the name as it was actually stored (stripped).
                raise Conflict(f"A dataset label named '{name}' already exists")
            except sqlalchemy.exc.StatementError as error:
                raise BadRequest(str(error.orig))
            return UpdateDatasetLabelMutationPayload(
                dataset_label=to_gql_dataset_label(dataset_label_orm)
            )

    @strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked])  # type: ignore
    async def delete_dataset_labels(
        self, info: Info[Context, None], input: DeleteDatasetLabelsInput
    ) -> DeleteDatasetLabelsMutationPayload:
        """Delete the given dataset labels all-or-nothing.

        Returns the deleted labels de-duplicated, in the order first requested.

        Raises:
            BadRequest: if any ID cannot be decoded as a DatasetLabel.
            NotFound: if any requested label does not exist (nothing is deleted).
        """
        # A dict (not a set) de-duplicates while preserving request order.
        dataset_label_row_ids: dict[int, None] = {}
        for dataset_label_node_id in input.dataset_label_ids:
            try:
                dataset_label_row_id = from_global_id_with_expected_type(
                    dataset_label_node_id, DatasetLabel.__name__
                )
            except ValueError:
                raise BadRequest(f"Unknown dataset label: {dataset_label_node_id}")
            dataset_label_row_ids[dataset_label_row_id] = None
        async with info.context.db() as session:
            stmt = (
                delete(models.DatasetLabel)
                .where(models.DatasetLabel.id.in_(list(dataset_label_row_ids)))
                .returning(models.DatasetLabel)
            )
            deleted_dataset_labels = (await session.scalars(stmt)).all()
            if len(deleted_dataset_labels) < len(dataset_label_row_ids):
                # Some IDs did not match a row: undo the partial delete.
                await session.rollback()
                raise NotFound("Could not find one or more dataset labels with given IDs")
            deleted_dataset_labels_by_id = {
                dataset_label.id: dataset_label for dataset_label in deleted_dataset_labels
            }
            return DeleteDatasetLabelsMutationPayload(
                dataset_labels=[
                    to_gql_dataset_label(deleted_dataset_labels_by_id[dataset_label_row_id])
                    for dataset_label_row_id in dataset_label_row_ids
                ]
            )

    @strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked])  # type: ignore
    async def set_dataset_labels(
        self, info: Info[Context, None], input: SetDatasetLabelsInput
    ) -> SetDatasetLabelsMutationPayload:
        """Attach every given label to every given dataset.

        Pairs that are already linked are skipped; nothing is committed when no new
        links are needed.

        Raises:
            BadRequest: if either list is empty or contains an undecodable ID.
            NotFound: if any dataset or label does not exist.
            Conflict: if inserting the new links violates an integrity constraint.
        """
        if not input.dataset_ids:
            raise BadRequest("No datasets provided.")
        if not input.dataset_label_ids:
            raise BadRequest("No dataset labels provided.")

        dataset_rowids = _parse_unique_row_ids(
            input.dataset_ids, Dataset.__name__, "Invalid dataset ID"
        )
        dataset_label_rowids = _parse_unique_row_ids(
            input.dataset_label_ids, DatasetLabel.__name__, "Invalid dataset label ID"
        )

        async with info.context.db() as session:
            existing_dataset_ids = (
                await session.scalars(
                    select(models.Dataset.id).where(models.Dataset.id.in_(dataset_rowids))
                )
            ).all()
            if len(existing_dataset_ids) != len(dataset_rowids):
                raise NotFound("One or more datasets not found")

            existing_dataset_label_ids = (
                await session.scalars(
                    select(models.DatasetLabel.id).where(
                        models.DatasetLabel.id.in_(dataset_label_rowids)
                    )
                )
            ).all()
            if len(existing_dataset_label_ids) != len(dataset_label_rowids):
                raise NotFound("One or more dataset labels not found")

            existing_links = await session.execute(
                select(
                    models.DatasetsDatasetLabel.dataset_id,
                    models.DatasetsDatasetLabel.dataset_label_id,
                ).where(
                    models.DatasetsDatasetLabel.dataset_id.in_(dataset_rowids)
                    & models.DatasetsDatasetLabel.dataset_label_id.in_(dataset_label_rowids)
                )
            )
            # Normalize rows to plain tuples so membership tests below do not depend
            # on SQLAlchemy Row hashing/equality semantics.
            linked_pairs = {
                (dataset_id, dataset_label_id)
                for dataset_id, dataset_label_id in existing_links
            }

            datasets_dataset_labels = []
            for dataset_rowid in dataset_rowids:
                for dataset_label_rowid in dataset_label_rowids:
                    if (dataset_rowid, dataset_label_rowid) in linked_pairs:
                        continue
                    datasets_dataset_labels.append(
                        models.DatasetsDatasetLabel(
                            dataset_id=dataset_rowid,
                            dataset_label_id=dataset_label_rowid,
                        )
                    )

            if datasets_dataset_labels:
                session.add_all(datasets_dataset_labels)
                try:
                    await session.commit()
                except (PostgreSQLIntegrityError, SQLiteIntegrityError) as e:
                    raise Conflict("Failed to add dataset labels to datasets.") from e

        return SetDatasetLabelsMutationPayload(
            query=Query(),
        )

    @strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked])  # type: ignore
    async def unset_dataset_labels(
        self, info: Info[Context, None], input: UnsetDatasetLabelsInput
    ) -> UnsetDatasetLabelsMutationPayload:
        """Detach every given label from every given dataset.

        IDs that decode correctly but match no existing link are silently ignored.

        Raises:
            BadRequest: if either list is empty or contains an undecodable ID.
        """
        if not input.dataset_ids:
            raise BadRequest("No datasets provided.")
        if not input.dataset_label_ids:
            raise BadRequest("No dataset labels provided.")

        dataset_rowids = _parse_unique_row_ids(
            input.dataset_ids, Dataset.__name__, "Invalid dataset ID"
        )
        dataset_label_rowids = _parse_unique_row_ids(
            input.dataset_label_ids, DatasetLabel.__name__, "Invalid dataset label ID"
        )

        async with info.context.db() as session:
            await session.execute(
                delete(models.DatasetsDatasetLabel).where(
                    models.DatasetsDatasetLabel.dataset_id.in_(dataset_rowids)
                    & models.DatasetsDatasetLabel.dataset_label_id.in_(dataset_label_rowids)
                )
            )
            await session.commit()

        return UnsetDatasetLabelsMutationPayload(
            query=Query(),
        )
|
|
@@ -66,6 +66,7 @@ class DatasetMutationMixin:
|
|
|
66
66
|
name=name,
|
|
67
67
|
description=description,
|
|
68
68
|
metadata_=metadata,
|
|
69
|
+
user_id=info.context.user_id,
|
|
69
70
|
)
|
|
70
71
|
.returning(models.Dataset)
|
|
71
72
|
)
|
|
@@ -136,6 +137,7 @@ class DatasetMutationMixin:
|
|
|
136
137
|
dataset_id=dataset_rowid,
|
|
137
138
|
description=dataset_version_description,
|
|
138
139
|
metadata_=dataset_version_metadata or {},
|
|
140
|
+
user_id=info.context.user_id,
|
|
139
141
|
)
|
|
140
142
|
session.add(dataset_version)
|
|
141
143
|
await session.flush()
|
|
@@ -254,6 +256,7 @@ class DatasetMutationMixin:
|
|
|
254
256
|
dataset_id=dataset_rowid,
|
|
255
257
|
description=dataset_version_description,
|
|
256
258
|
metadata_=dataset_version_metadata,
|
|
259
|
+
user_id=info.context.user_id,
|
|
257
260
|
)
|
|
258
261
|
.returning(models.DatasetVersion.id)
|
|
259
262
|
)
|
|
@@ -451,6 +454,7 @@ class DatasetMutationMixin:
|
|
|
451
454
|
dataset_id=dataset.id,
|
|
452
455
|
description=version_description,
|
|
453
456
|
metadata_=version_metadata,
|
|
457
|
+
user_id=info.context.user_id,
|
|
454
458
|
)
|
|
455
459
|
)
|
|
456
460
|
assert version_id is not None
|
|
@@ -514,6 +518,7 @@ class DatasetMutationMixin:
|
|
|
514
518
|
dataset_id=dataset.id,
|
|
515
519
|
description=dataset_version_description,
|
|
516
520
|
metadata_=dataset_version_metadata,
|
|
521
|
+
user_id=info.context.user_id,
|
|
517
522
|
created_at=timestamp,
|
|
518
523
|
)
|
|
519
524
|
.returning(models.DatasetVersion.id)
|