arize-phoenix 12.0.0__py3-none-any.whl → 12.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of arize-phoenix might be problematic. Click here for more details.
- {arize_phoenix-12.0.0.dist-info → arize_phoenix-12.2.0.dist-info}/METADATA +1 -1
- {arize_phoenix-12.0.0.dist-info → arize_phoenix-12.2.0.dist-info}/RECORD +39 -35
- phoenix/db/insertion/document_annotation.py +1 -1
- phoenix/db/insertion/session_annotation.py +1 -1
- phoenix/db/insertion/span_annotation.py +1 -1
- phoenix/db/insertion/trace_annotation.py +1 -1
- phoenix/db/insertion/types.py +0 -4
- phoenix/db/models.py +22 -1
- phoenix/server/api/context.py +2 -0
- phoenix/server/api/dataloaders/__init__.py +2 -0
- phoenix/server/api/dataloaders/dataset_labels.py +36 -0
- phoenix/server/api/helpers/playground_clients.py +1 -0
- phoenix/server/api/mutations/__init__.py +2 -0
- phoenix/server/api/mutations/dataset_label_mutations.py +291 -0
- phoenix/server/api/mutations/dataset_split_mutations.py +38 -2
- phoenix/server/api/queries.py +21 -0
- phoenix/server/api/routers/auth.py +5 -5
- phoenix/server/api/routers/oauth2.py +5 -23
- phoenix/server/api/types/Dataset.py +8 -0
- phoenix/server/api/types/DatasetExample.py +7 -0
- phoenix/server/api/types/DatasetLabel.py +23 -0
- phoenix/server/api/types/Prompt.py +18 -1
- phoenix/server/app.py +7 -12
- phoenix/server/static/.vite/manifest.json +39 -39
- phoenix/server/static/assets/{components-Dl9SUw1U.js → components-BG6v0EM8.js} +665 -389
- phoenix/server/static/assets/{index-CqQS0dTo.js → index-CSVcULw1.js} +13 -13
- phoenix/server/static/assets/{pages-DKSjVA_E.js → pages-DgaM7kpM.js} +1135 -1182
- phoenix/server/static/assets/{vendor-CtbHQYl8.js → vendor-BqTEkGQU.js} +183 -183
- phoenix/server/static/assets/{vendor-arizeai-D-lWOwIS.js → vendor-arizeai-DlOj0PQQ.js} +15 -24
- phoenix/server/static/assets/{vendor-codemirror-BRBpy3_z.js → vendor-codemirror-B2PHH5yZ.js} +3 -3
- phoenix/server/static/assets/{vendor-recharts--KdSwB3m.js → vendor-recharts-CKsi4IjN.js} +1 -1
- phoenix/server/static/assets/{vendor-shiki-CvRzZnIo.js → vendor-shiki-DN26BkKE.js} +1 -1
- phoenix/server/utils.py +74 -0
- phoenix/session/session.py +25 -5
- phoenix/version.py +1 -1
- {arize_phoenix-12.0.0.dist-info → arize_phoenix-12.2.0.dist-info}/WHEEL +0 -0
- {arize_phoenix-12.0.0.dist-info → arize_phoenix-12.2.0.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-12.0.0.dist-info → arize_phoenix-12.2.0.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-12.0.0.dist-info → arize_phoenix-12.2.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,291 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
import sqlalchemy
|
|
4
|
+
import strawberry
|
|
5
|
+
from sqlalchemy import delete, select
|
|
6
|
+
from sqlalchemy.exc import IntegrityError as PostgreSQLIntegrityError
|
|
7
|
+
from sqlean.dbapi2 import IntegrityError as SQLiteIntegrityError # type: ignore[import-untyped]
|
|
8
|
+
from strawberry import UNSET
|
|
9
|
+
from strawberry.relay.types import GlobalID
|
|
10
|
+
from strawberry.types import Info
|
|
11
|
+
|
|
12
|
+
from phoenix.db import models
|
|
13
|
+
from phoenix.server.api.auth import IsLocked, IsNotReadOnly
|
|
14
|
+
from phoenix.server.api.context import Context
|
|
15
|
+
from phoenix.server.api.exceptions import BadRequest, Conflict, NotFound
|
|
16
|
+
from phoenix.server.api.queries import Query
|
|
17
|
+
from phoenix.server.api.types.Dataset import Dataset
|
|
18
|
+
from phoenix.server.api.types.DatasetLabel import DatasetLabel, to_gql_dataset_label
|
|
19
|
+
from phoenix.server.api.types.node import from_global_id_with_expected_type
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@strawberry.input
class CreateDatasetLabelInput:
    """Input for `DatasetLabelMutationMixin.create_dataset_label`."""

    name: str
    # Defaults to None rather than strawberry's UNSET sentinel: the create
    # mutation writes this value straight into the ORM row, and UNSET is not a
    # storable column value. This matches UpdateDatasetLabelInput.description.
    description: Optional[str] = None
    color: str
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@strawberry.type
class CreateDatasetLabelMutationPayload:
    """Result of creating a dataset label: the newly created label node."""

    dataset_label: DatasetLabel
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@strawberry.input
class DeleteDatasetLabelsInput:
    """Input for deleting several dataset labels at once."""

    # Relay global IDs; each is expected to identify a DatasetLabel node.
    dataset_label_ids: list[GlobalID]
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
@strawberry.type
class DeleteDatasetLabelsMutationPayload:
    """Result of a bulk delete: the dataset labels that were removed."""

    dataset_labels: list[DatasetLabel]
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
@strawberry.input
class UpdateDatasetLabelInput:
    """Input for replacing the name, description and colour of one dataset label."""

    # Relay global ID of the DatasetLabel node being edited.
    dataset_label_id: GlobalID
    name: str
    # Omitting the field yields None, which the update mutation writes as-is.
    description: Optional[str] = None
    color: str
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
@strawberry.type
class UpdateDatasetLabelMutationPayload:
    """Result of an update: the dataset label after the change was committed."""

    dataset_label: DatasetLabel
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
@strawberry.input
class SetDatasetLabelsInput:
    """Input for attaching every listed label to every listed dataset."""

    # Relay global IDs of DatasetLabel nodes.
    dataset_label_ids: list[GlobalID]
    # Relay global IDs of Dataset nodes.
    dataset_ids: list[GlobalID]
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
@strawberry.type
class SetDatasetLabelsMutationPayload:
    """Result of attaching labels; exposes the root query so clients can refetch."""

    query: "Query"
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
@strawberry.input
class UnsetDatasetLabelsInput:
    """Input for detaching every listed label from every listed dataset."""

    # Relay global IDs of DatasetLabel nodes.
    dataset_label_ids: list[GlobalID]
    # Relay global IDs of Dataset nodes.
    dataset_ids: list[GlobalID]
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
@strawberry.type
class UnsetDatasetLabelsMutationPayload:
    """Result of detaching labels; exposes the root query so clients can refetch."""

    query: "Query"
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def _parse_row_ids(
    global_ids: list[GlobalID], type_name: str, error_prefix: str
) -> list[int]:
    """
    Decodes relay global IDs of GraphQL type `type_name` into database row IDs.

    Duplicate IDs are collapsed; each row ID keeps the position of its first
    occurrence.

    Raises:
        BadRequest: with message "<error_prefix>: <global id>" for the first ID
            that cannot be decoded as `type_name`.
    """
    row_ids: dict[int, None] = {}
    for global_id in global_ids:
        try:
            row_id = from_global_id_with_expected_type(global_id, type_name)
        except ValueError:
            raise BadRequest(f"{error_prefix}: {global_id}")
        row_ids[row_id] = None
    return list(row_ids)


@strawberry.type
class DatasetLabelMutationMixin:
    @strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked])  # type: ignore
    async def create_dataset_label(
        self,
        info: Info[Context, None],
        input: CreateDatasetLabelInput,
    ) -> CreateDatasetLabelMutationPayload:
        """
        Creates a dataset label.

        Name and colour are whitespace-stripped, as in `update_dataset_label`.

        Raises:
            BadRequest: if the name is blank, or the database rejects a value.
            Conflict: if the insert violates an integrity constraint
                (reported as a duplicate name).
        """
        name = input.name.strip()
        if not name:
            raise BadRequest("Dataset label name cannot be empty")
        # An omitted optional field may arrive as strawberry's UNSET sentinel,
        # which cannot be bound as a column value; store NULL instead.
        description = None if input.description is UNSET else input.description
        color = input.color.strip()
        async with info.context.db() as session:
            dataset_label_orm = models.DatasetLabel(
                name=name, description=description, color=color
            )
            session.add(dataset_label_orm)
            try:
                await session.commit()
            except (PostgreSQLIntegrityError, SQLiteIntegrityError) as error:
                raise Conflict(f"A dataset label named '{name}' already exists") from error
            except sqlalchemy.exc.StatementError as error:
                raise BadRequest(str(error.orig)) from error
        return CreateDatasetLabelMutationPayload(
            dataset_label=to_gql_dataset_label(dataset_label_orm)
        )

    @strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked])  # type: ignore
    async def update_dataset_label(
        self, info: Info[Context, None], input: UpdateDatasetLabelInput
    ) -> UpdateDatasetLabelMutationPayload:
        """
        Overwrites the name, description and colour of an existing dataset label.

        Raises:
            BadRequest: if the name is blank, the ID is not a DatasetLabel ID, or
                the database rejects a value.
            NotFound: if no label has the given ID.
            Conflict: if the new values violate an integrity constraint
                (reported as a duplicate name).
        """
        name = input.name.strip()
        if not name:
            raise BadRequest("Dataset label name cannot be empty")

        try:
            dataset_label_id = from_global_id_with_expected_type(
                input.dataset_label_id, DatasetLabel.__name__
            )
        except ValueError:
            raise BadRequest(f"Invalid dataset label ID: {input.dataset_label_id}")

        async with info.context.db() as session:
            dataset_label_orm = await session.get(models.DatasetLabel, dataset_label_id)
            if not dataset_label_orm:
                raise NotFound(f"DatasetLabel with ID {input.dataset_label_id} not found")

            dataset_label_orm.name = name
            dataset_label_orm.description = input.description
            dataset_label_orm.color = input.color.strip()

            try:
                await session.commit()
            except (PostgreSQLIntegrityError, SQLiteIntegrityError) as error:
                # Report the name that was actually written (stripped).
                raise Conflict(f"A dataset label named '{name}' already exists") from error
            except sqlalchemy.exc.StatementError as error:
                raise BadRequest(str(error.orig)) from error
        return UpdateDatasetLabelMutationPayload(
            dataset_label=to_gql_dataset_label(dataset_label_orm)
        )

    @strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked])  # type: ignore
    async def delete_dataset_labels(
        self, info: Info[Context, None], input: DeleteDatasetLabelsInput
    ) -> DeleteDatasetLabelsMutationPayload:
        """
        Deletes the given dataset labels in a single statement.

        The deleted labels are returned in the order their IDs were first given.
        If any ID matches no row, the deletion is rolled back and nothing is
        removed.

        Raises:
            BadRequest: if an ID is not a DatasetLabel ID.
            NotFound: if one or more labels do not exist.
        """
        dataset_label_row_ids = _parse_row_ids(
            input.dataset_label_ids, DatasetLabel.__name__, "Unknown dataset label"
        )
        async with info.context.db() as session:
            stmt = (
                delete(models.DatasetLabel)
                .where(models.DatasetLabel.id.in_(dataset_label_row_ids))
                .returning(models.DatasetLabel)
            )
            deleted_dataset_labels = (await session.scalars(stmt)).all()
            if len(deleted_dataset_labels) < len(dataset_label_row_ids):
                await session.rollback()
                raise NotFound("Could not find one or more dataset labels with given IDs")
        deleted_dataset_labels_by_id = {
            dataset_label.id: dataset_label for dataset_label in deleted_dataset_labels
        }
        return DeleteDatasetLabelsMutationPayload(
            dataset_labels=[
                to_gql_dataset_label(deleted_dataset_labels_by_id[dataset_label_row_id])
                for dataset_label_row_id in dataset_label_row_ids
            ]
        )

    @strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked])  # type: ignore
    async def set_dataset_labels(
        self, info: Info[Context, None], input: SetDatasetLabelsInput
    ) -> SetDatasetLabelsMutationPayload:
        """
        Attaches every given label to every given dataset.

        Pairs that are already linked are skipped, so repeating the call is
        harmless.

        Raises:
            BadRequest: if either list is empty or contains an ID of the wrong type.
            NotFound: if any dataset or label does not exist.
            Conflict: if inserting the new links violates an integrity constraint.
        """
        if not input.dataset_ids:
            raise BadRequest("No datasets provided.")
        if not input.dataset_label_ids:
            raise BadRequest("No dataset labels provided.")

        dataset_rowids = _parse_row_ids(
            input.dataset_ids, Dataset.__name__, "Invalid dataset ID"
        )
        dataset_label_rowids = _parse_row_ids(
            input.dataset_label_ids, DatasetLabel.__name__, "Invalid dataset label ID"
        )

        async with info.context.db() as session:
            existing_dataset_ids = (
                await session.scalars(
                    select(models.Dataset.id).where(models.Dataset.id.in_(dataset_rowids))
                )
            ).all()
            if len(existing_dataset_ids) != len(dataset_rowids):
                raise NotFound("One or more datasets not found")

            existing_dataset_label_ids = (
                await session.scalars(
                    select(models.DatasetLabel.id).where(
                        models.DatasetLabel.id.in_(dataset_label_rowids)
                    )
                )
            ).all()
            if len(existing_dataset_label_ids) != len(dataset_label_rowids):
                raise NotFound("One or more dataset labels not found")

            existing_links = await session.execute(
                select(
                    models.DatasetsDatasetLabel.dataset_id,
                    models.DatasetsDatasetLabel.dataset_label_id,
                ).where(
                    models.DatasetsDatasetLabel.dataset_id.in_(dataset_rowids)
                    & models.DatasetsDatasetLabel.dataset_label_id.in_(dataset_label_rowids)
                )
            )
            # Convert result rows to plain tuples so the membership test below
            # compares like with like.
            linked_pairs = {tuple(row) for row in existing_links}

            new_links = [
                models.DatasetsDatasetLabel(
                    dataset_id=dataset_rowid,
                    dataset_label_id=dataset_label_rowid,
                )
                for dataset_rowid in dataset_rowids
                for dataset_label_rowid in dataset_label_rowids
                if (dataset_rowid, dataset_label_rowid) not in linked_pairs
            ]
            if new_links:
                session.add_all(new_links)
                try:
                    await session.commit()
                except (PostgreSQLIntegrityError, SQLiteIntegrityError) as e:
                    raise Conflict("Failed to add dataset labels to datasets.") from e

        return SetDatasetLabelsMutationPayload(
            query=Query(),
        )

    @strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked])  # type: ignore
    async def unset_dataset_labels(
        self, info: Info[Context, None], input: UnsetDatasetLabelsInput
    ) -> UnsetDatasetLabelsMutationPayload:
        """
        Detaches every given label from every given dataset.

        Pairs that are not linked are ignored; the call does not check that the
        datasets or labels exist.

        Raises:
            BadRequest: if either list is empty or contains an ID of the wrong type.
        """
        if not input.dataset_ids:
            raise BadRequest("No datasets provided.")
        if not input.dataset_label_ids:
            raise BadRequest("No dataset labels provided.")

        dataset_rowids = _parse_row_ids(
            input.dataset_ids, Dataset.__name__, "Invalid dataset ID"
        )
        dataset_label_rowids = _parse_row_ids(
            input.dataset_label_ids, DatasetLabel.__name__, "Invalid dataset label ID"
        )

        async with info.context.db() as session:
            await session.execute(
                delete(models.DatasetsDatasetLabel).where(
                    models.DatasetsDatasetLabel.dataset_id.in_(dataset_rowids)
                    & models.DatasetsDatasetLabel.dataset_label_id.in_(dataset_label_rowids)
                )
            )
            await session.commit()

        return UnsetDatasetLabelsMutationPayload(
            query=Query(),
        )
|
|
@@ -15,6 +15,7 @@ from phoenix.server.api.context import Context
|
|
|
15
15
|
from phoenix.server.api.exceptions import BadRequest, Conflict, NotFound
|
|
16
16
|
from phoenix.server.api.helpers.playground_users import get_user
|
|
17
17
|
from phoenix.server.api.queries import Query
|
|
18
|
+
from phoenix.server.api.types.DatasetExample import DatasetExample, to_gql_dataset_example
|
|
18
19
|
from phoenix.server.api.types.DatasetSplit import DatasetSplit, to_gql_dataset_split
|
|
19
20
|
from phoenix.server.api.types.node import from_global_id_with_expected_type
|
|
20
21
|
|
|
@@ -68,6 +69,13 @@ class DatasetSplitMutationPayload:
|
|
|
68
69
|
query: "Query"
|
|
69
70
|
|
|
70
71
|
|
|
72
|
+
@strawberry.type
class DatasetSplitMutationPayloadWithExamples:
    """Mutation result carrying a dataset split, the root query and a list of dataset examples."""

    dataset_split: DatasetSplit
    # Root query object, so clients can refetch after the mutation.
    query: "Query"
    examples: list[DatasetExample]
|
|
77
|
+
|
|
78
|
+
|
|
71
79
|
@strawberry.type
|
|
72
80
|
class DeleteDatasetSplitsMutationPayload:
|
|
73
81
|
dataset_splits: list[DatasetSplit]
|
|
@@ -77,11 +85,13 @@ class DeleteDatasetSplitsMutationPayload:
|
|
|
77
85
|
@strawberry.type
class AddDatasetExamplesToDatasetSplitsMutationPayload:
    """Result of adding examples to splits: the root query plus the affected examples."""

    query: "Query"
    examples: list[DatasetExample]
|
|
80
89
|
|
|
81
90
|
|
|
82
91
|
@strawberry.type
class RemoveDatasetExamplesFromDatasetSplitsMutationPayload:
    """Result of removing examples from splits: the root query plus the affected examples."""

    query: "Query"
    examples: list[DatasetExample]
|
|
85
95
|
|
|
86
96
|
|
|
87
97
|
@strawberry.type
|
|
@@ -262,8 +272,16 @@ class DatasetSplitMutationMixin:
|
|
|
262
272
|
except (PostgreSQLIntegrityError, SQLiteIntegrityError) as e:
|
|
263
273
|
raise Conflict("Failed to add examples to dataset splits.") from e
|
|
264
274
|
|
|
275
|
+
examples = (
|
|
276
|
+
await session.scalars(
|
|
277
|
+
select(models.DatasetExample).where(
|
|
278
|
+
models.DatasetExample.id.in_(example_rowids)
|
|
279
|
+
)
|
|
280
|
+
)
|
|
281
|
+
).all()
|
|
265
282
|
return AddDatasetExamplesToDatasetSplitsMutationPayload(
|
|
266
283
|
query=Query(),
|
|
284
|
+
examples=[to_gql_dataset_example(example) for example in examples],
|
|
267
285
|
)
|
|
268
286
|
|
|
269
287
|
@strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore
|
|
@@ -314,14 +332,23 @@ class DatasetSplitMutationMixin:
|
|
|
314
332
|
|
|
315
333
|
await session.execute(stmt)
|
|
316
334
|
|
|
335
|
+
examples = (
|
|
336
|
+
await session.scalars(
|
|
337
|
+
select(models.DatasetExample).where(
|
|
338
|
+
models.DatasetExample.id.in_(example_rowids)
|
|
339
|
+
)
|
|
340
|
+
)
|
|
341
|
+
).all()
|
|
342
|
+
|
|
317
343
|
return RemoveDatasetExamplesFromDatasetSplitsMutationPayload(
|
|
318
344
|
query=Query(),
|
|
345
|
+
examples=[to_gql_dataset_example(example) for example in examples],
|
|
319
346
|
)
|
|
320
347
|
|
|
321
348
|
@strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked]) # type: ignore
|
|
322
349
|
async def create_dataset_split_with_examples(
|
|
323
350
|
self, info: Info[Context, None], input: CreateDatasetSplitWithExamplesInput
|
|
324
|
-
) ->
|
|
351
|
+
) -> DatasetSplitMutationPayloadWithExamples:
|
|
325
352
|
user_id = get_user(info)
|
|
326
353
|
validated_name = _validated_name(input.name)
|
|
327
354
|
unique_example_rowids: set[int] = set()
|
|
@@ -374,9 +401,18 @@ class DatasetSplitMutationMixin:
|
|
|
374
401
|
"Failed to associate examples with the new dataset split."
|
|
375
402
|
) from e
|
|
376
403
|
|
|
377
|
-
|
|
404
|
+
examples = (
|
|
405
|
+
await session.scalars(
|
|
406
|
+
select(models.DatasetExample).where(
|
|
407
|
+
models.DatasetExample.id.in_(example_rowids)
|
|
408
|
+
)
|
|
409
|
+
)
|
|
410
|
+
).all()
|
|
411
|
+
|
|
412
|
+
return DatasetSplitMutationPayloadWithExamples(
|
|
378
413
|
dataset_split=to_gql_dataset_split(dataset_split_orm),
|
|
379
414
|
query=Query(),
|
|
415
|
+
examples=[to_gql_dataset_example(example) for example in examples],
|
|
380
416
|
)
|
|
381
417
|
|
|
382
418
|
|
phoenix/server/api/queries.py
CHANGED
|
@@ -48,6 +48,7 @@ from phoenix.server.api.types.AnnotationConfig import AnnotationConfig, to_gql_a
|
|
|
48
48
|
from phoenix.server.api.types.Cluster import Cluster, to_gql_clusters
|
|
49
49
|
from phoenix.server.api.types.Dataset import Dataset, to_gql_dataset
|
|
50
50
|
from phoenix.server.api.types.DatasetExample import DatasetExample
|
|
51
|
+
from phoenix.server.api.types.DatasetLabel import DatasetLabel, to_gql_dataset_label
|
|
51
52
|
from phoenix.server.api.types.DatasetSplit import DatasetSplit, to_gql_dataset_split
|
|
52
53
|
from phoenix.server.api.types.Dimension import to_gql_dimension
|
|
53
54
|
from phoenix.server.api.types.EmbeddingDimension import (
|
|
@@ -1149,6 +1150,26 @@ class Query:
|
|
|
1149
1150
|
args=args,
|
|
1150
1151
|
)
|
|
1151
1152
|
|
|
1153
|
+
@strawberry.field
async def dataset_labels(
    self,
    info: Info[Context, None],
    first: Optional[int] = 50,
    last: Optional[int] = UNSET,
    after: Optional[CursorString] = UNSET,
    before: Optional[CursorString] = UNSET,
) -> Connection[DatasetLabel]:
    """
    Lists all dataset labels as a relay connection.

    `connection_from_list` paginates by position within the fetched list, so the
    query orders by row ID. This keeps page boundaries stable across requests;
    without ORDER BY the database may return rows in any order.
    """
    args = ConnectionArgs(
        first=first,
        after=after if isinstance(after, CursorString) else None,
        last=last,
        before=before if isinstance(before, CursorString) else None,
    )
    async with info.context.db() as session:
        dataset_labels = await session.scalars(
            select(models.DatasetLabel).order_by(models.DatasetLabel.id)
        )
        data = [to_gql_dataset_label(dataset_label) for dataset_label in dataset_labels]
    return connection_from_list(data=data, args=args)
|
|
1172
|
+
|
|
1152
1173
|
@strawberry.field
|
|
1153
1174
|
async def dataset_splits(
|
|
1154
1175
|
self,
|
|
@@ -2,7 +2,6 @@ import asyncio
|
|
|
2
2
|
import secrets
|
|
3
3
|
from datetime import datetime, timedelta, timezone
|
|
4
4
|
from functools import partial
|
|
5
|
-
from pathlib import Path
|
|
6
5
|
from urllib.parse import urlencode, urlparse, urlunparse
|
|
7
6
|
|
|
8
7
|
from fastapi import APIRouter, Depends, HTTPException, Request, Response
|
|
@@ -38,7 +37,6 @@ from phoenix.config import (
|
|
|
38
37
|
get_base_url,
|
|
39
38
|
get_env_disable_basic_auth,
|
|
40
39
|
get_env_disable_rate_limit,
|
|
41
|
-
get_env_host_root_path,
|
|
42
40
|
)
|
|
43
41
|
from phoenix.db import models
|
|
44
42
|
from phoenix.server.bearer_auth import PhoenixUser, create_access_and_refresh_tokens
|
|
@@ -52,6 +50,7 @@ from phoenix.server.types import (
|
|
|
52
50
|
TokenStore,
|
|
53
51
|
UserId,
|
|
54
52
|
)
|
|
53
|
+
from phoenix.server.utils import prepend_root_path
|
|
55
54
|
|
|
56
55
|
rate_limiter = ServerRateLimiter(
|
|
57
56
|
per_second_rate_limit=0.2,
|
|
@@ -145,7 +144,8 @@ async def logout(
|
|
|
145
144
|
user_id = subject
|
|
146
145
|
if user_id:
|
|
147
146
|
await token_store.log_out(user_id)
|
|
148
|
-
|
|
147
|
+
redirect_path = "/logout" if get_env_disable_basic_auth() else "/login"
|
|
148
|
+
redirect_url = prepend_root_path(request.scope, redirect_path)
|
|
149
149
|
response = Response(status_code=HTTP_302_FOUND, headers={"Location": redirect_url})
|
|
150
150
|
response = delete_access_token_cookie(response)
|
|
151
151
|
response = delete_refresh_token_cookie(response)
|
|
@@ -242,9 +242,9 @@ async def initiate_password_reset(request: Request) -> Response:
|
|
|
242
242
|
)
|
|
243
243
|
token, _ = await token_store.create_password_reset_token(password_reset_token_claims)
|
|
244
244
|
url = urlparse(request.headers.get("referer") or get_base_url())
|
|
245
|
-
path =
|
|
245
|
+
path = prepend_root_path(request.scope, "/reset-password-with-token")
|
|
246
246
|
query_string = urlencode(dict(token=token))
|
|
247
|
-
components = (url.scheme, url.netloc, path
|
|
247
|
+
components = (url.scheme, url.netloc, path, "", query_string, "")
|
|
248
248
|
reset_url = urlunparse(components)
|
|
249
249
|
await sender.send_password_reset_email(email, reset_url)
|
|
250
250
|
return Response(status_code=HTTP_204_NO_CONTENT)
|
|
@@ -46,6 +46,7 @@ from phoenix.server.rate_limiters import (
|
|
|
46
46
|
fastapi_route_rate_limiter,
|
|
47
47
|
)
|
|
48
48
|
from phoenix.server.types import TokenStore
|
|
49
|
+
from phoenix.server.utils import get_root_path, prepend_root_path
|
|
49
50
|
|
|
50
51
|
_LOWERCASE_ALPHANUMS_AND_UNDERSCORES = r"[a-z0-9_]+"
|
|
51
52
|
|
|
@@ -191,7 +192,7 @@ async def create_tokens(
|
|
|
191
192
|
access_token_expiry=access_token_expiry,
|
|
192
193
|
refresh_token_expiry=refresh_token_expiry,
|
|
193
194
|
)
|
|
194
|
-
redirect_path =
|
|
195
|
+
redirect_path = prepend_root_path(request.scope, return_url or "/")
|
|
195
196
|
response = RedirectResponse(
|
|
196
197
|
url=redirect_path,
|
|
197
198
|
status_code=HTTP_302_FOUND,
|
|
@@ -564,8 +565,8 @@ def _redirect_to_login(*, request: Request, error: str) -> RedirectResponse:
|
|
|
564
565
|
Creates a RedirectResponse to the login page to display an error message.
|
|
565
566
|
"""
|
|
566
567
|
# TODO: this needs some cleanup
|
|
567
|
-
login_path =
|
|
568
|
-
request
|
|
568
|
+
login_path = prepend_root_path(
|
|
569
|
+
request.scope, "/login" if not get_env_disable_basic_auth() else "/logout"
|
|
569
570
|
)
|
|
570
571
|
url = URL(login_path).include_query_params(error=error)
|
|
571
572
|
response = RedirectResponse(url=url)
|
|
@@ -574,34 +575,15 @@ def _redirect_to_login(*, request: Request, error: str) -> RedirectResponse:
|
|
|
574
575
|
return response
|
|
575
576
|
|
|
576
577
|
|
|
577
|
-
def _prepend_root_path_if_exists(*, request: Request, path: str) -> str:
|
|
578
|
-
"""
|
|
579
|
-
If a root path is configured, prepends it to the input path.
|
|
580
|
-
"""
|
|
581
|
-
if not path.startswith("/"):
|
|
582
|
-
raise ValueError("path must start with a forward slash")
|
|
583
|
-
root_path = _get_root_path(request=request)
|
|
584
|
-
if root_path.endswith("/"):
|
|
585
|
-
root_path = root_path.rstrip("/")
|
|
586
|
-
return root_path + path
|
|
587
|
-
|
|
588
|
-
|
|
589
578
|
def _append_root_path_if_exists(*, request: Request, base_url: str) -> str:
    """
    Returns `base_url` joined with the request's root path.

    If the request's scope has no (or an empty) root path, `base_url` is
    returned unchanged.
    """
    root_path = get_root_path(request.scope)
    if root_path:
        return str(URLPath(root_path).make_absolute_url(base_url=base_url))
    return base_url
|
|
596
585
|
|
|
597
586
|
|
|
598
|
-
def _get_root_path(*, request: Request) -> str:
|
|
599
|
-
"""
|
|
600
|
-
Gets the root path from the request.
|
|
601
|
-
"""
|
|
602
|
-
return str(request.scope.get("root_path", ""))
|
|
603
|
-
|
|
604
|
-
|
|
605
587
|
def _get_create_tokens_endpoint(*, request: Request, origin_url: str, idp_name: str) -> str:
|
|
606
588
|
"""
|
|
607
589
|
Gets the endpoint for create tokens route.
|
|
@@ -18,6 +18,7 @@ from phoenix.server.api.types.DatasetExample import DatasetExample
|
|
|
18
18
|
from phoenix.server.api.types.DatasetExperimentAnnotationSummary import (
|
|
19
19
|
DatasetExperimentAnnotationSummary,
|
|
20
20
|
)
|
|
21
|
+
from phoenix.server.api.types.DatasetLabel import DatasetLabel, to_gql_dataset_label
|
|
21
22
|
from phoenix.server.api.types.DatasetVersion import DatasetVersion
|
|
22
23
|
from phoenix.server.api.types.Experiment import Experiment, to_gql_experiment
|
|
23
24
|
from phoenix.server.api.types.node import from_global_id_with_expected_type
|
|
@@ -303,6 +304,13 @@ class Dataset(Node):
|
|
|
303
304
|
async for scores_tuple in await session.stream(query)
|
|
304
305
|
]
|
|
305
306
|
|
|
307
|
+
@strawberry.field
async def labels(self, info: Info[Context, None]) -> list[DatasetLabel]:
    """Dataset labels attached to this dataset, fetched through the `dataset_labels` data loader."""
    label_rows = await info.context.data_loaders.dataset_labels.load(self.id_attr)
    return list(map(to_gql_dataset_label, label_rows))
|
|
313
|
+
|
|
306
314
|
@strawberry.field
|
|
307
315
|
def last_updated_at(self, info: Info[Context, None]) -> Optional[datetime]:
|
|
308
316
|
return info.context.last_updated_at.get(self._table, self.id_attr)
|
|
@@ -142,3 +142,10 @@ class DatasetExample(Node):
|
|
|
142
142
|
to_gql_dataset_split(split)
|
|
143
143
|
for split in await info.context.data_loaders.dataset_example_splits.load(self.id_attr)
|
|
144
144
|
]
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def to_gql_dataset_example(example: models.DatasetExample) -> DatasetExample:
    """Builds the GraphQL `DatasetExample` node for an ORM `models.DatasetExample` row."""
    row = example
    return DatasetExample(created_at=row.created_at, id_attr=row.id)
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
import strawberry
|
|
4
|
+
from strawberry.relay import Node, NodeID
|
|
5
|
+
|
|
6
|
+
from phoenix.db import models
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@strawberry.type
class DatasetLabel(Node):
    """GraphQL relay node for a dataset label."""

    # Integer row ID, declared as the relay NodeID of this node.
    id_attr: NodeID[int]
    name: str
    # Nullable free-text description.
    description: Optional[str]
    color: str
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def to_gql_dataset_label(dataset_label: models.DatasetLabel) -> DatasetLabel:
    """Builds the GraphQL `DatasetLabel` node for an ORM `models.DatasetLabel` row."""
    row = dataset_label
    return DatasetLabel(
        id_attr=row.id,
        name=row.name,
        color=row.color,
        description=row.description,
    )
|
|
@@ -9,6 +9,7 @@ from strawberry.relay import Connection, GlobalID, Node, NodeID
|
|
|
9
9
|
from strawberry.types import Info
|
|
10
10
|
|
|
11
11
|
from phoenix.db import models
|
|
12
|
+
from phoenix.db.types.identifier import Identifier as IdentifierModel
|
|
12
13
|
from phoenix.server.api.context import Context
|
|
13
14
|
from phoenix.server.api.exceptions import NotFound
|
|
14
15
|
from phoenix.server.api.types.Identifier import Identifier
|
|
@@ -37,7 +38,10 @@ class Prompt(Node):
|
|
|
37
38
|
|
|
38
39
|
@strawberry.field
|
|
39
40
|
async def version(
|
|
40
|
-
self,
|
|
41
|
+
self,
|
|
42
|
+
info: Info[Context, None],
|
|
43
|
+
version_id: Optional[GlobalID] = None,
|
|
44
|
+
tag_name: Optional[Identifier] = None,
|
|
41
45
|
) -> PromptVersion:
|
|
42
46
|
async with info.context.db() as session:
|
|
43
47
|
if version_id:
|
|
@@ -50,6 +54,19 @@ class Prompt(Node):
|
|
|
50
54
|
)
|
|
51
55
|
if not version:
|
|
52
56
|
raise NotFound(f"Prompt version not found: {version_id}")
|
|
57
|
+
elif tag_name:
|
|
58
|
+
try:
|
|
59
|
+
name = IdentifierModel(tag_name)
|
|
60
|
+
except ValueError:
|
|
61
|
+
raise NotFound(f"Prompt version tag not found: {tag_name}")
|
|
62
|
+
version = await session.scalar(
|
|
63
|
+
select(models.PromptVersion)
|
|
64
|
+
.where(models.PromptVersion.prompt_id == self.id_attr)
|
|
65
|
+
.join_from(models.PromptVersion, models.PromptVersionTag)
|
|
66
|
+
.where(models.PromptVersionTag.name == name)
|
|
67
|
+
)
|
|
68
|
+
if not version:
|
|
69
|
+
raise NotFound(f"This prompt has no associated versions by tag {tag_name}")
|
|
53
70
|
else:
|
|
54
71
|
stmt = (
|
|
55
72
|
select(models.PromptVersion)
|