arize-phoenix 11.38.0__py3-none-any.whl → 12.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of arize-phoenix might be problematic.
- {arize_phoenix-11.38.0.dist-info → arize_phoenix-12.2.0.dist-info}/METADATA +3 -3
- {arize_phoenix-11.38.0.dist-info → arize_phoenix-12.2.0.dist-info}/RECORD +83 -58
- phoenix/config.py +1 -11
- phoenix/db/bulk_inserter.py +8 -0
- phoenix/db/facilitator.py +1 -1
- phoenix/db/helpers.py +202 -33
- phoenix/db/insertion/dataset.py +7 -0
- phoenix/db/insertion/document_annotation.py +1 -1
- phoenix/db/insertion/helpers.py +2 -2
- phoenix/db/insertion/session_annotation.py +176 -0
- phoenix/db/insertion/span_annotation.py +1 -1
- phoenix/db/insertion/trace_annotation.py +1 -1
- phoenix/db/insertion/types.py +29 -3
- phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
- phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
- phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
- phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
- phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
- phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
- phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
- phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
- phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
- phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
- phoenix/db/models.py +306 -46
- phoenix/server/api/context.py +15 -2
- phoenix/server/api/dataloaders/__init__.py +8 -2
- phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
- phoenix/server/api/dataloaders/dataset_labels.py +36 -0
- phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
- phoenix/server/api/dataloaders/table_fields.py +2 -2
- phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
- phoenix/server/api/helpers/playground_clients.py +66 -35
- phoenix/server/api/helpers/playground_users.py +26 -0
- phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
- phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
- phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
- phoenix/server/api/mutations/__init__.py +8 -0
- phoenix/server/api/mutations/chat_mutations.py +8 -3
- phoenix/server/api/mutations/dataset_label_mutations.py +291 -0
- phoenix/server/api/mutations/dataset_mutations.py +5 -0
- phoenix/server/api/mutations/dataset_split_mutations.py +423 -0
- phoenix/server/api/mutations/project_session_annotations_mutations.py +161 -0
- phoenix/server/api/queries.py +53 -0
- phoenix/server/api/routers/auth.py +5 -5
- phoenix/server/api/routers/oauth2.py +5 -23
- phoenix/server/api/routers/v1/__init__.py +2 -0
- phoenix/server/api/routers/v1/annotations.py +320 -0
- phoenix/server/api/routers/v1/datasets.py +5 -0
- phoenix/server/api/routers/v1/experiments.py +10 -3
- phoenix/server/api/routers/v1/sessions.py +111 -0
- phoenix/server/api/routers/v1/traces.py +1 -2
- phoenix/server/api/routers/v1/users.py +7 -0
- phoenix/server/api/subscriptions.py +5 -2
- phoenix/server/api/types/Dataset.py +8 -0
- phoenix/server/api/types/DatasetExample.py +18 -0
- phoenix/server/api/types/DatasetLabel.py +23 -0
- phoenix/server/api/types/DatasetSplit.py +32 -0
- phoenix/server/api/types/Experiment.py +0 -4
- phoenix/server/api/types/Project.py +16 -0
- phoenix/server/api/types/ProjectSession.py +88 -3
- phoenix/server/api/types/ProjectSessionAnnotation.py +68 -0
- phoenix/server/api/types/Prompt.py +18 -1
- phoenix/server/api/types/Span.py +5 -5
- phoenix/server/api/types/Trace.py +61 -0
- phoenix/server/app.py +13 -14
- phoenix/server/cost_tracking/model_cost_manifest.json +132 -2
- phoenix/server/dml_event.py +13 -0
- phoenix/server/static/.vite/manifest.json +39 -39
- phoenix/server/static/assets/{components-BQPHTBfv.js → components-BG6v0EM8.js} +705 -385
- phoenix/server/static/assets/{index-BL5BMgJU.js → index-CSVcULw1.js} +13 -13
- phoenix/server/static/assets/{pages-C0Y17J0T.js → pages-DgaM7kpM.js} +1356 -1155
- phoenix/server/static/assets/{vendor-BdjZxMii.js → vendor-BqTEkGQU.js} +183 -183
- phoenix/server/static/assets/{vendor-arizeai-CHYlS8jV.js → vendor-arizeai-DlOj0PQQ.js} +15 -24
- phoenix/server/static/assets/{vendor-codemirror-Di6t4HnH.js → vendor-codemirror-B2PHH5yZ.js} +3 -3
- phoenix/server/static/assets/{vendor-recharts-C9wCDYj3.js → vendor-recharts-CKsi4IjN.js} +1 -1
- phoenix/server/static/assets/{vendor-shiki-MNnmOotP.js → vendor-shiki-DN26BkKE.js} +1 -1
- phoenix/server/utils.py +74 -0
- phoenix/session/session.py +25 -5
- phoenix/version.py +1 -1
- phoenix/server/api/dataloaders/experiment_repetition_counts.py +0 -39
- {arize_phoenix-11.38.0.dist-info → arize_phoenix-12.2.0.dist-info}/WHEEL +0 -0
- {arize_phoenix-11.38.0.dist-info → arize_phoenix-12.2.0.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-11.38.0.dist-info → arize_phoenix-12.2.0.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-11.38.0.dist-info → arize_phoenix-12.2.0.dist-info}/licenses/LICENSE +0 -0
phoenix/server/api/mutations/dataset_split_mutations.py
ADDED

@@ -0,0 +1,423 @@
+from typing import Optional
+
+import strawberry
+from sqlalchemy import delete, func, insert, select
+from sqlalchemy.exc import IntegrityError as PostgreSQLIntegrityError
+from sqlean.dbapi2 import IntegrityError as SQLiteIntegrityError  # type: ignore[import-untyped]
+from strawberry import UNSET
+from strawberry.relay import GlobalID
+from strawberry.scalars import JSON
+from strawberry.types import Info
+
+from phoenix.db import models
+from phoenix.server.api.auth import IsLocked, IsNotReadOnly
+from phoenix.server.api.context import Context
+from phoenix.server.api.exceptions import BadRequest, Conflict, NotFound
+from phoenix.server.api.helpers.playground_users import get_user
+from phoenix.server.api.queries import Query
+from phoenix.server.api.types.DatasetExample import DatasetExample, to_gql_dataset_example
+from phoenix.server.api.types.DatasetSplit import DatasetSplit, to_gql_dataset_split
+from phoenix.server.api.types.node import from_global_id_with_expected_type
+
+
+@strawberry.input
+class CreateDatasetSplitInput:
+    name: str
+    description: Optional[str] = UNSET
+    color: str
+    metadata: Optional[JSON] = UNSET
+
+
+@strawberry.input
+class PatchDatasetSplitInput:
+    dataset_split_id: GlobalID
+    name: Optional[str] = UNSET
+    description: Optional[str] = UNSET
+    color: Optional[str] = UNSET
+    metadata: Optional[JSON] = UNSET
+
+
+@strawberry.input
+class DeleteDatasetSplitInput:
+    dataset_split_ids: list[GlobalID]
+
+
+@strawberry.input
+class AddDatasetExamplesToDatasetSplitsInput:
+    dataset_split_ids: list[GlobalID]
+    example_ids: list[GlobalID]
+
+
+@strawberry.input
+class RemoveDatasetExamplesFromDatasetSplitsInput:
+    dataset_split_ids: list[GlobalID]
+    example_ids: list[GlobalID]
+
+
+@strawberry.input
+class CreateDatasetSplitWithExamplesInput:
+    name: str
+    description: Optional[str] = UNSET
+    color: str
+    metadata: Optional[JSON] = UNSET
+    example_ids: list[GlobalID]
+
+
+@strawberry.type
+class DatasetSplitMutationPayload:
+    dataset_split: DatasetSplit
+    query: "Query"
+
+
+@strawberry.type
+class DatasetSplitMutationPayloadWithExamples:
+    dataset_split: DatasetSplit
+    query: "Query"
+    examples: list[DatasetExample]
+
+
+@strawberry.type
+class DeleteDatasetSplitsMutationPayload:
+    dataset_splits: list[DatasetSplit]
+    query: "Query"
+
+
+@strawberry.type
+class AddDatasetExamplesToDatasetSplitsMutationPayload:
+    query: "Query"
+    examples: list[DatasetExample]
+
+
+@strawberry.type
+class RemoveDatasetExamplesFromDatasetSplitsMutationPayload:
+    query: "Query"
+    examples: list[DatasetExample]
+
+
+@strawberry.type
+class DatasetSplitMutationMixin:
+    @strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked])  # type: ignore
+    async def create_dataset_split(
+        self, info: Info[Context, None], input: CreateDatasetSplitInput
+    ) -> DatasetSplitMutationPayload:
+        user_id = get_user(info)
+        validated_name = _validated_name(input.name)
+        async with info.context.db() as session:
+            dataset_split_orm = models.DatasetSplit(
+                name=validated_name,
+                description=input.description,
+                color=input.color,
+                metadata_=input.metadata or {},
+                user_id=user_id,
+            )
+            session.add(dataset_split_orm)
+            try:
+                await session.commit()
+            except (PostgreSQLIntegrityError, SQLiteIntegrityError):
+                raise Conflict(f"A dataset split named '{input.name}' already exists.")
+        return DatasetSplitMutationPayload(
+            dataset_split=to_gql_dataset_split(dataset_split_orm), query=Query()
+        )
+
+    @strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked])  # type: ignore
+    async def patch_dataset_split(
+        self, info: Info[Context, None], input: PatchDatasetSplitInput
+    ) -> DatasetSplitMutationPayload:
+        validated_name = _validated_name(input.name) if input.name else None
+        async with info.context.db() as session:
+            dataset_split_id = from_global_id_with_expected_type(
+                input.dataset_split_id, DatasetSplit.__name__
+            )
+            dataset_split_orm = await session.get(models.DatasetSplit, dataset_split_id)
+            if not dataset_split_orm:
+                raise NotFound(f"Dataset split with ID {input.dataset_split_id} not found")
+
+            if validated_name:
+                dataset_split_orm.name = validated_name
+            if input.description:
+                dataset_split_orm.description = input.description
+            if input.color:
+                dataset_split_orm.color = input.color
+            if isinstance(input.metadata, dict):
+                dataset_split_orm.metadata_ = input.metadata
+
+            gql_dataset_split = to_gql_dataset_split(dataset_split_orm)
+            try:
+                await session.commit()
+            except (PostgreSQLIntegrityError, SQLiteIntegrityError):
+                raise Conflict("A dataset split with this name already exists")
+
+        return DatasetSplitMutationPayload(
+            dataset_split=gql_dataset_split,
+            query=Query(),
+        )
+
+    @strawberry.mutation(permission_classes=[IsNotReadOnly])  # type: ignore
+    async def delete_dataset_splits(
+        self, info: Info[Context, None], input: DeleteDatasetSplitInput
+    ) -> DeleteDatasetSplitsMutationPayload:
+        unique_dataset_split_rowids: dict[int, None] = {}  # use a dict to preserve ordering
+        for dataset_split_gid in input.dataset_split_ids:
+            try:
+                dataset_split_rowid = from_global_id_with_expected_type(
+                    dataset_split_gid, DatasetSplit.__name__
+                )
+            except ValueError:
+                raise BadRequest(f"Invalid dataset split ID: {dataset_split_gid}")
+            unique_dataset_split_rowids[dataset_split_rowid] = None
+        dataset_split_rowids = list(unique_dataset_split_rowids.keys())
+
+        async with info.context.db() as session:
+            deleted_splits_by_id = {
+                split.id: split
+                for split in (
+                    await session.scalars(
+                        delete(models.DatasetSplit)
+                        .where(models.DatasetSplit.id.in_(dataset_split_rowids))
+                        .returning(models.DatasetSplit)
+                    )
+                ).all()
+            }
+            if len(deleted_splits_by_id) < len(dataset_split_rowids):
+                await session.rollback()
+                raise NotFound("One or more dataset splits not found")
+            await session.commit()
+
+        return DeleteDatasetSplitsMutationPayload(
+            dataset_splits=[
+                to_gql_dataset_split(deleted_splits_by_id[dataset_split_rowid])
+                for dataset_split_rowid in dataset_split_rowids
+            ],
+            query=Query(),
+        )
+
+    @strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked])  # type: ignore
+    async def add_dataset_examples_to_dataset_splits(
+        self, info: Info[Context, None], input: AddDatasetExamplesToDatasetSplitsInput
+    ) -> AddDatasetExamplesToDatasetSplitsMutationPayload:
+        if not input.example_ids:
+            raise BadRequest("No examples provided.")
+        if not input.dataset_split_ids:
+            raise BadRequest("No dataset splits provided.")
+
+        unique_dataset_split_rowids: set[int] = set()
+        for dataset_split_gid in input.dataset_split_ids:
+            try:
+                dataset_split_rowid = from_global_id_with_expected_type(
+                    dataset_split_gid, DatasetSplit.__name__
+                )
+            except ValueError:
+                raise BadRequest(f"Invalid dataset split ID: {dataset_split_gid}")
+            unique_dataset_split_rowids.add(dataset_split_rowid)
+        dataset_split_rowids = list(unique_dataset_split_rowids)
+
+        unique_example_rowids: set[int] = set()
+        for example_gid in input.example_ids:
+            try:
+                example_rowid = from_global_id_with_expected_type(
+                    example_gid, models.DatasetExample.__name__
+                )
+            except ValueError:
+                raise BadRequest(f"Invalid example ID: {example_gid}")
+            unique_example_rowids.add(example_rowid)
+        example_rowids = list(unique_example_rowids)
+
+        async with info.context.db() as session:
+            existing_dataset_split_ids = (
+                await session.scalars(
+                    select(models.DatasetSplit.id).where(
+                        models.DatasetSplit.id.in_(dataset_split_rowids)
+                    )
+                )
+            ).all()
+            if len(existing_dataset_split_ids) != len(dataset_split_rowids):
+                raise NotFound("One or more dataset splits not found")
+
+            # Find existing (dataset_split_id, dataset_example_id) keys to avoid duplicates
+            # Users can submit multiple examples at once which can have
+            # indeterminate participation in multiple splits
+            existing_dataset_example_split_keys = await session.execute(
+                select(
+                    models.DatasetSplitDatasetExample.dataset_split_id,
+                    models.DatasetSplitDatasetExample.dataset_example_id,
+                ).where(
+                    models.DatasetSplitDatasetExample.dataset_split_id.in_(dataset_split_rowids)
+                    & models.DatasetSplitDatasetExample.dataset_example_id.in_(example_rowids)
+                )
+            )
+            unique_dataset_example_split_keys = set(existing_dataset_example_split_keys.all())
+
+            # Compute all desired pairs and insert only missing
+            values = []
+            for dataset_split_rowid in dataset_split_rowids:
+                for example_rowid in example_rowids:
+                    # if the keys already exists, skip
+                    if (dataset_split_rowid, example_rowid) in unique_dataset_example_split_keys:
+                        continue
+                    dataset_split_id_key = models.DatasetSplitDatasetExample.dataset_split_id.key
+                    dataset_example_id_key = (
+                        models.DatasetSplitDatasetExample.dataset_example_id.key
+                    )
+                    values.append(
+                        {
+                            dataset_split_id_key: dataset_split_rowid,
+                            dataset_example_id_key: example_rowid,
+                        }
+                    )
+
+            if values:
+                try:
+                    await session.execute(insert(models.DatasetSplitDatasetExample), values)
+                    await session.flush()
+                except (PostgreSQLIntegrityError, SQLiteIntegrityError) as e:
+                    raise Conflict("Failed to add examples to dataset splits.") from e
+
+            examples = (
+                await session.scalars(
+                    select(models.DatasetExample).where(
+                        models.DatasetExample.id.in_(example_rowids)
+                    )
+                )
+            ).all()
+        return AddDatasetExamplesToDatasetSplitsMutationPayload(
+            query=Query(),
+            examples=[to_gql_dataset_example(example) for example in examples],
+        )
+
+    @strawberry.mutation(permission_classes=[IsNotReadOnly])  # type: ignore
+    async def remove_dataset_examples_from_dataset_splits(
+        self, info: Info[Context, None], input: RemoveDatasetExamplesFromDatasetSplitsInput
+    ) -> RemoveDatasetExamplesFromDatasetSplitsMutationPayload:
+        if not input.dataset_split_ids:
+            raise BadRequest("No dataset splits provided.")
+        if not input.example_ids:
+            raise BadRequest("No examples provided.")
+
+        unique_dataset_split_rowids: set[int] = set()
+        for dataset_split_gid in input.dataset_split_ids:
+            try:
+                dataset_split_rowid = from_global_id_with_expected_type(
+                    dataset_split_gid, DatasetSplit.__name__
+                )
+            except ValueError:
+                raise BadRequest(f"Invalid dataset split ID: {dataset_split_gid}")
+            unique_dataset_split_rowids.add(dataset_split_rowid)
+        dataset_split_rowids = list(unique_dataset_split_rowids)
+
+        unique_example_rowids: set[int] = set()
+        for example_gid in input.example_ids:
+            try:
+                example_rowid = from_global_id_with_expected_type(
+                    example_gid, models.DatasetExample.__name__
+                )
+            except ValueError:
+                raise BadRequest(f"Invalid example ID: {example_gid}")
+            unique_example_rowids.add(example_rowid)
+        example_rowids = list(unique_example_rowids)
+
+        stmt = delete(models.DatasetSplitDatasetExample).where(
+            models.DatasetSplitDatasetExample.dataset_split_id.in_(dataset_split_rowids)
+            & models.DatasetSplitDatasetExample.dataset_example_id.in_(example_rowids)
+        )
+        async with info.context.db() as session:
+            existing_dataset_split_ids = (
+                await session.scalars(
+                    select(models.DatasetSplit.id).where(
+                        models.DatasetSplit.id.in_(dataset_split_rowids)
+                    )
+                )
+            ).all()
+            if len(existing_dataset_split_ids) != len(dataset_split_rowids):
+                raise NotFound("One or more dataset splits not found")
+
+            await session.execute(stmt)
+
+            examples = (
+                await session.scalars(
+                    select(models.DatasetExample).where(
+                        models.DatasetExample.id.in_(example_rowids)
+                    )
+                )
+            ).all()
+
+        return RemoveDatasetExamplesFromDatasetSplitsMutationPayload(
+            query=Query(),
+            examples=[to_gql_dataset_example(example) for example in examples],
+        )
+
+    @strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked])  # type: ignore
+    async def create_dataset_split_with_examples(
+        self, info: Info[Context, None], input: CreateDatasetSplitWithExamplesInput
+    ) -> DatasetSplitMutationPayloadWithExamples:
+        user_id = get_user(info)
+        validated_name = _validated_name(input.name)
+        unique_example_rowids: set[int] = set()
+        for example_gid in input.example_ids:
+            try:
+                example_rowid = from_global_id_with_expected_type(
+                    example_gid, models.DatasetExample.__name__
+                )
+                unique_example_rowids.add(example_rowid)
+            except ValueError:
+                raise BadRequest(f"Invalid example ID: {example_gid}")
+        example_rowids = list(unique_example_rowids)
+        async with info.context.db() as session:
+            if example_rowids:
+                found_count = await session.scalar(
+                    select(func.count(models.DatasetExample.id)).where(
+                        models.DatasetExample.id.in_(example_rowids)
+                    )
+                )
+                if found_count is None or found_count < len(example_rowids):
+                    raise NotFound("One or more dataset examples were not found.")
+
+            dataset_split_orm = models.DatasetSplit(
+                name=validated_name,
+                description=input.description or None,
+                color=input.color,
+                metadata_=input.metadata or {},
+                user_id=user_id,
+            )
+            session.add(dataset_split_orm)
+            try:
+                await session.flush()
+            except (PostgreSQLIntegrityError, SQLiteIntegrityError):
+                raise Conflict(f"A dataset split named '{validated_name}' already exists.")
+
+            if example_rowids:
+                values = [
+                    {
+                        models.DatasetSplitDatasetExample.dataset_split_id.key: dataset_split_orm.id,  # noqa: E501
+                        models.DatasetSplitDatasetExample.dataset_example_id.key: example_id,
+                    }
+                    for example_id in example_rowids
+                ]
+                try:
+                    await session.execute(insert(models.DatasetSplitDatasetExample), values)
+                except (PostgreSQLIntegrityError, SQLiteIntegrityError) as e:
+                    # Roll back the transaction on association failure
+                    await session.rollback()
+                    raise Conflict(
+                        "Failed to associate examples with the new dataset split."
+                    ) from e
+
+            examples = (
+                await session.scalars(
+                    select(models.DatasetExample).where(
+                        models.DatasetExample.id.in_(example_rowids)
+                    )
+                )
+            ).all()
+
+        return DatasetSplitMutationPayloadWithExamples(
+            dataset_split=to_gql_dataset_split(dataset_split_orm),
+            query=Query(),
+            examples=[to_gql_dataset_example(example) for example in examples],
+        )
+
+
+def _validated_name(name: str) -> str:
+    validated_name = name.strip()
+    if not validated_name:
+        raise BadRequest("Name cannot be empty")
+    return validated_name
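For orientation, here is a minimal client-side sketch (not part of the diff) of how the new create_dataset_split mutation might be invoked over HTTP. Strawberry converts snake_case resolver names to camelCase by default, so the sketch assumes the field is exposed as createDatasetSplit; the endpoint URL and the selected response fields are likewise assumptions, not something this diff confirms.

# Hypothetical usage sketch: call the new mutation against a locally
# running Phoenix server. URL, field name, and input shape are assumptions.
import httpx

MUTATION = """
mutation ($input: CreateDatasetSplitInput!) {
  createDatasetSplit(input: $input) {
    datasetSplit { id name color }
  }
}
"""

response = httpx.post(
    "http://localhost:6006/graphql",  # assumed default Phoenix address
    json={
        "query": MUTATION,
        "variables": {"input": {"name": "holdout", "color": "#8B5CF6"}},
    },
)
response.raise_for_status()  # surface HTTP-level failures
print(response.json())       # GraphQL errors, if any, appear under "errors"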
phoenix/server/api/mutations/project_session_annotations_mutations.py
ADDED

@@ -0,0 +1,161 @@
+from typing import Optional
+
+import strawberry
+from sqlalchemy.exc import IntegrityError as PostgreSQLIntegrityError
+from sqlean.dbapi2 import IntegrityError as SQLiteIntegrityError  # type: ignore[import-untyped]
+from starlette.requests import Request
+from strawberry import Info
+from strawberry.relay import GlobalID
+
+from phoenix.db import models
+from phoenix.server.api.auth import IsLocked, IsNotReadOnly
+from phoenix.server.api.context import Context
+from phoenix.server.api.exceptions import BadRequest, Conflict, NotFound, Unauthorized
+from phoenix.server.api.helpers.annotations import get_user_identifier
+from phoenix.server.api.input_types.CreateProjectSessionAnnotationInput import (
+    CreateProjectSessionAnnotationInput,
+)
+from phoenix.server.api.input_types.UpdateAnnotationInput import UpdateAnnotationInput
+from phoenix.server.api.queries import Query
+from phoenix.server.api.types.AnnotationSource import AnnotationSource
+from phoenix.server.api.types.node import from_global_id_with_expected_type
+from phoenix.server.api.types.ProjectSessionAnnotation import (
+    ProjectSessionAnnotation,
+    to_gql_project_session_annotation,
+)
+from phoenix.server.bearer_auth import PhoenixUser
+from phoenix.server.dml_event import (
+    ProjectSessionAnnotationDeleteEvent,
+    ProjectSessionAnnotationInsertEvent,
+)
+
+
+@strawberry.type
+class ProjectSessionAnnotationMutationPayload:
+    project_session_annotation: ProjectSessionAnnotation
+    query: Query
+
+
+@strawberry.type
+class ProjectSessionAnnotationMutationMixin:
+    @strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked])  # type: ignore
+    async def create_project_session_annotations(
+        self, info: Info[Context, None], input: CreateProjectSessionAnnotationInput
+    ) -> ProjectSessionAnnotationMutationPayload:
+        assert isinstance(request := info.context.request, Request)
+        user_id: Optional[int] = None
+        if "user" in request.scope and isinstance((user := info.context.user), PhoenixUser):
+            user_id = int(user.identity)
+
+        try:
+            project_session_id = from_global_id_with_expected_type(
+                input.project_session_id, "ProjectSession"
+            )
+        except ValueError:
+            raise BadRequest(f"Invalid session ID: {input.project_session_id}")
+
+        identifier = ""
+        if isinstance(input.identifier, str):
+            identifier = input.identifier  # Already trimmed in __post_init__
+        elif input.source == AnnotationSource.APP and user_id is not None:
+            identifier = get_user_identifier(user_id)
+
+        try:
+            async with info.context.db() as session:
+                anno = models.ProjectSessionAnnotation(
+                    project_session_id=project_session_id,
+                    name=input.name,
+                    label=input.label,
+                    score=input.score,
+                    explanation=input.explanation,
+                    annotator_kind=input.annotator_kind.value,
+                    metadata_=input.metadata,
+                    identifier=identifier,
+                    source=input.source.value,
+                    user_id=user_id,
+                )
+                session.add(anno)
+        except (PostgreSQLIntegrityError, SQLiteIntegrityError) as e:
+            raise Conflict(f"Error creating annotation: {e}")
+
+        info.context.event_queue.put(ProjectSessionAnnotationInsertEvent((anno.id,)))
+
+        return ProjectSessionAnnotationMutationPayload(
+            project_session_annotation=to_gql_project_session_annotation(anno),
+            query=Query(),
+        )
+
+    @strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked])  # type: ignore
+    async def update_project_session_annotations(
+        self, info: Info[Context, None], input: UpdateAnnotationInput
+    ) -> ProjectSessionAnnotationMutationPayload:
+        assert isinstance(request := info.context.request, Request)
+        user_id: Optional[int] = None
+        if "user" in request.scope and isinstance((user := info.context.user), PhoenixUser):
+            user_id = int(user.identity)
+
+        try:
+            id_ = from_global_id_with_expected_type(input.id, "ProjectSessionAnnotation")
+        except ValueError:
+            raise BadRequest(f"Invalid session annotation ID: {input.id}")
+
+        async with info.context.db() as session:
+            if not (anno := await session.get(models.ProjectSessionAnnotation, id_)):
+                raise NotFound(f"Could not find session annotation with ID: {input.id}")
+            if anno.user_id != user_id:
+                raise Unauthorized("Session annotation is not associated with the current user.")
+
+            # Update the annotation fields
+            anno.name = input.name
+            anno.label = input.label
+            anno.score = input.score
+            anno.explanation = input.explanation
+            anno.annotator_kind = input.annotator_kind.value
+            anno.metadata_ = input.metadata
+            anno.source = input.source.value
+
+            session.add(anno)
+            try:
+                await session.flush()
+            except (PostgreSQLIntegrityError, SQLiteIntegrityError) as e:
+                raise Conflict(f"Error updating annotation: {e}")
+
+        info.context.event_queue.put(ProjectSessionAnnotationInsertEvent((anno.id,)))
+        return ProjectSessionAnnotationMutationPayload(
+            project_session_annotation=to_gql_project_session_annotation(anno),
+            query=Query(),
+        )
+
+    @strawberry.mutation(permission_classes=[IsNotReadOnly])  # type: ignore
+    async def delete_project_session_annotation(
+        self, info: Info[Context, None], id: GlobalID
+    ) -> ProjectSessionAnnotationMutationPayload:
+        try:
+            id_ = from_global_id_with_expected_type(id, "ProjectSessionAnnotation")
+        except ValueError:
+            raise BadRequest(f"Invalid session annotation ID: {id}")
+
+        assert isinstance(request := info.context.request, Request)
+        user_id: Optional[int] = None
+        user_is_admin = False
+        if "user" in request.scope and isinstance((user := info.context.user), PhoenixUser):
+            user_id = int(user.identity)
+            user_is_admin = user.is_admin
+
+        async with info.context.db() as session:
+            if not (anno := await session.get(models.ProjectSessionAnnotation, id_)):
+                raise NotFound(f"Could not find session annotation with ID: {id}")
+
+            if not user_is_admin and anno.user_id != user_id:
+                raise Unauthorized(
+                    "Session annotation is not associated with the current user and "
+                    "the current user is not an admin."
+                )
+
+            await session.delete(anno)
+
+        deleted_gql_annotation = to_gql_project_session_annotation(anno)
+        info.context.event_queue.put(ProjectSessionAnnotationDeleteEvent((id_,)))
+        return ProjectSessionAnnotationMutationPayload(
+            project_session_annotation=deleted_gql_annotation, query=Query()
+        )
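Similarly, a hypothetical sketch of creating a session annotation through the new mutation. The camelCase field and input key names, the enum values, and the placeholder global ID are all assumptions inferred from Strawberry's default name conversion and the ORM columns above; consult the released schema before relying on them.

# Hypothetical usage sketch: annotate a project session. Field names,
# enum values, and the Relay global ID below are assumptions.
import httpx

MUTATION = """
mutation ($input: CreateProjectSessionAnnotationInput!) {
  createProjectSessionAnnotations(input: $input) {
    projectSessionAnnotation { id name label score }
  }
}
"""

variables = {
    "input": {
        "projectSessionId": "UHJvamVjdFNlc3Npb246MQ==",  # placeholder global ID
        "name": "quality",
        "label": "good",
        "score": 1.0,
        "annotatorKind": "HUMAN",  # assumed enum value
        "source": "APP",           # assumed enum value
        "metadata": {},
    }
}
response = httpx.post(
    "http://localhost:6006/graphql",
    json={"query": MUTATION, "variables": variables},
)
response.raise_for_status()
print(response.json())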
phoenix/server/api/queries.py
CHANGED

@@ -48,6 +48,8 @@ from phoenix.server.api.types.AnnotationConfig import AnnotationConfig, to_gql_a
 from phoenix.server.api.types.Cluster import Cluster, to_gql_clusters
 from phoenix.server.api.types.Dataset import Dataset, to_gql_dataset
 from phoenix.server.api.types.DatasetExample import DatasetExample
+from phoenix.server.api.types.DatasetLabel import DatasetLabel, to_gql_dataset_label
+from phoenix.server.api.types.DatasetSplit import DatasetSplit, to_gql_dataset_split
 from phoenix.server.api.types.Dimension import to_gql_dimension
 from phoenix.server.api.types.EmbeddingDimension import (
     DEFAULT_CLUSTER_SELECTION_EPSILON,

@@ -959,6 +961,14 @@ class Query:
                 id_attr=example.id,
                 created_at=example.created_at,
             )
+        elif type_name == DatasetSplit.__name__:
+            async with info.context.db() as session:
+                dataset_split = await session.scalar(
+                    select(models.DatasetSplit).where(models.DatasetSplit.id == node_id)
+                )
+                if not dataset_split:
+                    raise NotFound(f"Unknown dataset split: {id}")
+                return to_gql_dataset_split(dataset_split)
         elif type_name == Experiment.__name__:
             async with info.context.db() as session:
                 experiment = await session.scalar(

@@ -1140,6 +1150,49 @@ class Query:
             args=args,
         )

+    @strawberry.field
+    async def dataset_labels(
+        self,
+        info: Info[Context, None],
+        first: Optional[int] = 50,
+        last: Optional[int] = UNSET,
+        after: Optional[CursorString] = UNSET,
+        before: Optional[CursorString] = UNSET,
+    ) -> Connection[DatasetLabel]:
+        args = ConnectionArgs(
+            first=first,
+            after=after if isinstance(after, CursorString) else None,
+            last=last,
+            before=before if isinstance(before, CursorString) else None,
+        )
+        async with info.context.db() as session:
+            dataset_labels = await session.scalars(select(models.DatasetLabel))
+            data = [to_gql_dataset_label(dataset_label) for dataset_label in dataset_labels]
+        return connection_from_list(data=data, args=args)
+
+    @strawberry.field
+    async def dataset_splits(
+        self,
+        info: Info[Context, None],
+        first: Optional[int] = 50,
+        last: Optional[int] = UNSET,
+        after: Optional[CursorString] = UNSET,
+        before: Optional[CursorString] = UNSET,
+    ) -> Connection[DatasetSplit]:
+        args = ConnectionArgs(
+            first=first,
+            after=after if isinstance(after, CursorString) else None,
+            last=last,
+            before=before if isinstance(before, CursorString) else None,
+        )
+        async with info.context.db() as session:
+            splits = await session.stream_scalars(select(models.DatasetSplit))
+            data = [to_gql_dataset_split(split) async for split in splits]
+        return connection_from_list(
+            data=data,
+            args=args,
+        )
+
     @strawberry.field
     async def annotation_configs(
         self,