arize-phoenix 12.0.0__py3-none-any.whl → 12.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of arize-phoenix might be problematic. Click here for more details.
- {arize_phoenix-12.0.0.dist-info → arize_phoenix-12.3.0.dist-info}/METADATA +1 -1
- {arize_phoenix-12.0.0.dist-info → arize_phoenix-12.3.0.dist-info}/RECORD +43 -38
- phoenix/db/insertion/document_annotation.py +1 -1
- phoenix/db/insertion/session_annotation.py +1 -1
- phoenix/db/insertion/span_annotation.py +1 -1
- phoenix/db/insertion/trace_annotation.py +1 -1
- phoenix/db/insertion/types.py +0 -4
- phoenix/db/models.py +22 -1
- phoenix/server/api/auth_messages.py +46 -0
- phoenix/server/api/context.py +2 -0
- phoenix/server/api/dataloaders/__init__.py +2 -0
- phoenix/server/api/dataloaders/dataset_labels.py +36 -0
- phoenix/server/api/helpers/playground_clients.py +1 -0
- phoenix/server/api/mutations/__init__.py +2 -0
- phoenix/server/api/mutations/dataset_label_mutations.py +291 -0
- phoenix/server/api/mutations/dataset_split_mutations.py +38 -2
- phoenix/server/api/queries.py +21 -0
- phoenix/server/api/routers/auth.py +5 -5
- phoenix/server/api/routers/oauth2.py +53 -51
- phoenix/server/api/types/Dataset.py +8 -0
- phoenix/server/api/types/DatasetExample.py +7 -0
- phoenix/server/api/types/DatasetLabel.py +23 -0
- phoenix/server/api/types/Prompt.py +18 -1
- phoenix/server/app.py +12 -12
- phoenix/server/cost_tracking/model_cost_manifest.json +54 -54
- phoenix/server/oauth2.py +2 -4
- phoenix/server/static/.vite/manifest.json +39 -39
- phoenix/server/static/assets/{components-Dl9SUw1U.js → components-Bs8eJEpU.js} +699 -378
- phoenix/server/static/assets/{index-CqQS0dTo.js → index-C6WEu5UP.js} +3 -3
- phoenix/server/static/assets/{pages-DKSjVA_E.js → pages-D-n2pkoG.js} +1149 -1142
- phoenix/server/static/assets/vendor-D2eEI-6h.js +914 -0
- phoenix/server/static/assets/{vendor-arizeai-D-lWOwIS.js → vendor-arizeai-kfOei7nf.js} +15 -24
- phoenix/server/static/assets/{vendor-codemirror-BRBpy3_z.js → vendor-codemirror-1bq_t1Ec.js} +3 -3
- phoenix/server/static/assets/{vendor-recharts--KdSwB3m.js → vendor-recharts-DQ4xfrf4.js} +1 -1
- phoenix/server/static/assets/{vendor-shiki-CvRzZnIo.js → vendor-shiki-GGmcIQxA.js} +1 -1
- phoenix/server/templates/index.html +1 -0
- phoenix/server/utils.py +74 -0
- phoenix/session/session.py +25 -5
- phoenix/version.py +1 -1
- phoenix/server/static/assets/vendor-CtbHQYl8.js +0 -903
- {arize_phoenix-12.0.0.dist-info → arize_phoenix-12.3.0.dist-info}/WHEEL +0 -0
- {arize_phoenix-12.0.0.dist-info → arize_phoenix-12.3.0.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-12.0.0.dist-info → arize_phoenix-12.3.0.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-12.0.0.dist-info → arize_phoenix-12.3.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,291 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
import sqlalchemy
|
|
4
|
+
import strawberry
|
|
5
|
+
from sqlalchemy import delete, select
|
|
6
|
+
from sqlalchemy.exc import IntegrityError as PostgreSQLIntegrityError
|
|
7
|
+
from sqlean.dbapi2 import IntegrityError as SQLiteIntegrityError # type: ignore[import-untyped]
|
|
8
|
+
from strawberry import UNSET
|
|
9
|
+
from strawberry.relay.types import GlobalID
|
|
10
|
+
from strawberry.types import Info
|
|
11
|
+
|
|
12
|
+
from phoenix.db import models
|
|
13
|
+
from phoenix.server.api.auth import IsLocked, IsNotReadOnly
|
|
14
|
+
from phoenix.server.api.context import Context
|
|
15
|
+
from phoenix.server.api.exceptions import BadRequest, Conflict, NotFound
|
|
16
|
+
from phoenix.server.api.queries import Query
|
|
17
|
+
from phoenix.server.api.types.Dataset import Dataset
|
|
18
|
+
from phoenix.server.api.types.DatasetLabel import DatasetLabel, to_gql_dataset_label
|
|
19
|
+
from phoenix.server.api.types.node import from_global_id_with_expected_type
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@strawberry.input
class CreateDatasetLabelInput:
    # Arguments accepted by the `create_dataset_label` mutation.
    name: str
    # Left as strawberry's UNSET sentinel when the client omits the field.
    description: Optional[str] = UNSET
    color: str
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@strawberry.type
class CreateDatasetLabelMutationPayload:
    # Returned by `create_dataset_label`: the newly created label.
    dataset_label: DatasetLabel
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@strawberry.input
class DeleteDatasetLabelsInput:
    # Global IDs of the labels that `delete_dataset_labels` should remove.
    dataset_label_ids: list[GlobalID]
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
@strawberry.type
class DeleteDatasetLabelsMutationPayload:
    # Returned by `delete_dataset_labels`: the labels that were deleted.
    dataset_labels: list[DatasetLabel]
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
@strawberry.input
class UpdateDatasetLabelInput:
    # Arguments accepted by `update_dataset_label`; every editable field is overwritten.
    dataset_label_id: GlobalID
    name: str
    # Defaults to None, so omitting it clears the stored description.
    description: Optional[str] = None
    color: str
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
@strawberry.type
class UpdateDatasetLabelMutationPayload:
    # Returned by `update_dataset_label`: the label after the update.
    dataset_label: DatasetLabel
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
@strawberry.input
class SetDatasetLabelsInput:
    # `set_dataset_labels` attaches every listed label to every listed dataset.
    dataset_label_ids: list[GlobalID]
    dataset_ids: list[GlobalID]
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
@strawberry.type
class SetDatasetLabelsMutationPayload:
    # Exposes the root query so clients can refetch state after the mutation.
    query: "Query"
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
@strawberry.input
class UnsetDatasetLabelsInput:
    # `unset_dataset_labels` detaches every listed label from every listed dataset.
    dataset_label_ids: list[GlobalID]
    dataset_ids: list[GlobalID]
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
@strawberry.type
class UnsetDatasetLabelsMutationPayload:
    # Exposes the root query so clients can refetch state after the mutation.
    query: "Query"
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
@strawberry.type
class DatasetLabelMutationMixin:
    # GraphQL mutations for managing dataset labels and their association with datasets.
    # Every mutation is guarded by the IsNotReadOnly and IsLocked permission classes.

    @strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked])  # type: ignore
    async def create_dataset_label(
        self,
        info: Info[Context, None],
        input: CreateDatasetLabelInput,
    ) -> CreateDatasetLabelMutationPayload:
        """
        Create a dataset label.

        The name and color are stripped of surrounding whitespace, matching
        `update_dataset_label`. An omitted description is stored as NULL.

        Raises:
            BadRequest: the name is empty/blank, or the database rejects a value.
            Conflict: the insert violates an integrity constraint (e.g. a duplicate name).
        """
        if not input.name or not input.name.strip():
            raise BadRequest("Dataset label name cannot be empty")
        name = input.name.strip()
        # `description` defaults to strawberry's UNSET sentinel when the client omits it.
        # The sentinel is not a storable column value, so normalize it to NULL.
        description = None if input.description is UNSET else input.description
        color = input.color.strip()
        async with info.context.db() as session:
            dataset_label_orm = models.DatasetLabel(
                name=name, description=description, color=color
            )
            session.add(dataset_label_orm)
            try:
                await session.commit()
            except (PostgreSQLIntegrityError, SQLiteIntegrityError) as e:
                raise Conflict(f"A dataset label named '{name}' already exists") from e
            except sqlalchemy.exc.StatementError as error:
                raise BadRequest(str(error.orig)) from error
        return CreateDatasetLabelMutationPayload(
            dataset_label=to_gql_dataset_label(dataset_label_orm)
        )

    @strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked])  # type: ignore
    async def update_dataset_label(
        self, info: Info[Context, None], input: UpdateDatasetLabelInput
    ) -> UpdateDatasetLabelMutationPayload:
        """
        Overwrite the name, description, and color of an existing dataset label.

        Raises:
            BadRequest: blank name, malformed label ID, or a value the database rejects.
            NotFound: no label exists with the given ID.
            Conflict: the update violates an integrity constraint (e.g. a duplicate name).
        """
        if not input.name or not input.name.strip():
            raise BadRequest("Dataset label name cannot be empty")
        name = input.name.strip()

        try:
            dataset_label_id = from_global_id_with_expected_type(
                input.dataset_label_id, DatasetLabel.__name__
            )
        except ValueError:
            raise BadRequest(f"Invalid dataset label ID: {input.dataset_label_id}")

        async with info.context.db() as session:
            dataset_label_orm = await session.get(models.DatasetLabel, dataset_label_id)
            if not dataset_label_orm:
                raise NotFound(f"DatasetLabel with ID {input.dataset_label_id} not found")

            dataset_label_orm.name = name
            dataset_label_orm.description = input.description
            dataset_label_orm.color = input.color.strip()

            try:
                await session.commit()
            except (PostgreSQLIntegrityError, SQLiteIntegrityError) as e:
                # Report the name that was actually written, not the raw unstripped input.
                raise Conflict(f"A dataset label named '{name}' already exists") from e
            except sqlalchemy.exc.StatementError as error:
                raise BadRequest(str(error.orig)) from error
        return UpdateDatasetLabelMutationPayload(
            dataset_label=to_gql_dataset_label(dataset_label_orm)
        )

    @strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked])  # type: ignore
    async def delete_dataset_labels(
        self, info: Info[Context, None], input: DeleteDatasetLabelsInput
    ) -> DeleteDatasetLabelsMutationPayload:
        """
        Delete the given dataset labels, all-or-nothing.

        The deleted labels are returned in the order their IDs were first given, with
        duplicates collapsed. If any ID does not match an existing label, the session is
        rolled back and NotFound is raised.

        Raises:
            BadRequest: an ID is not a DatasetLabel global ID.
            NotFound: one or more IDs do not match an existing label.
        """
        # A dict (rather than a set) de-duplicates while preserving input order.
        dataset_label_row_ids: dict[int, None] = {}
        for dataset_label_node_id in input.dataset_label_ids:
            try:
                dataset_label_row_id = from_global_id_with_expected_type(
                    dataset_label_node_id, DatasetLabel.__name__
                )
            except ValueError:
                raise BadRequest(f"Unknown dataset label: {dataset_label_node_id}")
            dataset_label_row_ids[dataset_label_row_id] = None
        async with info.context.db() as session:
            stmt = (
                delete(models.DatasetLabel)
                .where(models.DatasetLabel.id.in_(list(dataset_label_row_ids)))
                .returning(models.DatasetLabel)
            )
            deleted_dataset_labels = (await session.scalars(stmt)).all()
            if len(deleted_dataset_labels) < len(dataset_label_row_ids):
                await session.rollback()
                raise NotFound("Could not find one or more dataset labels with given IDs")
            deleted_dataset_labels_by_id = {
                dataset_label.id: dataset_label for dataset_label in deleted_dataset_labels
            }
            return DeleteDatasetLabelsMutationPayload(
                dataset_labels=[
                    to_gql_dataset_label(deleted_dataset_labels_by_id[dataset_label_row_id])
                    for dataset_label_row_id in dataset_label_row_ids
                ]
            )

    @strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked])  # type: ignore
    async def set_dataset_labels(
        self, info: Info[Context, None], input: SetDatasetLabelsInput
    ) -> SetDatasetLabelsMutationPayload:
        """
        Attach every given label to every given dataset.

        Pairs that already exist are skipped, so repeating the call is harmless.

        Raises:
            BadRequest: either list is empty, or an ID has the wrong type.
            NotFound: a referenced dataset or label does not exist.
            Conflict: inserting an association violates an integrity constraint.
        """
        if not input.dataset_ids:
            raise BadRequest("No datasets provided.")
        if not input.dataset_label_ids:
            raise BadRequest("No dataset labels provided.")

        unique_dataset_rowids: set[int] = set()
        for dataset_gid in input.dataset_ids:
            try:
                dataset_rowid = from_global_id_with_expected_type(dataset_gid, Dataset.__name__)
            except ValueError:
                raise BadRequest(f"Invalid dataset ID: {dataset_gid}")
            unique_dataset_rowids.add(dataset_rowid)
        dataset_rowids = list(unique_dataset_rowids)

        unique_dataset_label_rowids: set[int] = set()
        for dataset_label_gid in input.dataset_label_ids:
            try:
                dataset_label_rowid = from_global_id_with_expected_type(
                    dataset_label_gid, DatasetLabel.__name__
                )
            except ValueError:
                raise BadRequest(f"Invalid dataset label ID: {dataset_label_gid}")
            unique_dataset_label_rowids.add(dataset_label_rowid)
        dataset_label_rowids = list(unique_dataset_label_rowids)

        async with info.context.db() as session:
            existing_dataset_ids = (
                await session.scalars(
                    select(models.Dataset.id).where(models.Dataset.id.in_(dataset_rowids))
                )
            ).all()
            if len(existing_dataset_ids) != len(dataset_rowids):
                raise NotFound("One or more datasets not found")

            existing_dataset_label_ids = (
                await session.scalars(
                    select(models.DatasetLabel.id).where(
                        models.DatasetLabel.id.in_(dataset_label_rowids)
                    )
                )
            ).all()
            if len(existing_dataset_label_ids) != len(dataset_label_rowids):
                raise NotFound("One or more dataset labels not found")

            existing_dataset_label_keys = await session.execute(
                select(
                    models.DatasetsDatasetLabel.dataset_id,
                    models.DatasetsDatasetLabel.dataset_label_id,
                ).where(
                    models.DatasetsDatasetLabel.dataset_id.in_(dataset_rowids)
                    & models.DatasetsDatasetLabel.dataset_label_id.in_(dataset_label_rowids)
                )
            )
            # Plain tuples make the membership test below explicit.
            unique_dataset_label_keys = {tuple(row) for row in existing_dataset_label_keys.all()}

            datasets_dataset_labels = []
            for dataset_rowid in dataset_rowids:
                for dataset_label_rowid in dataset_label_rowids:
                    if (dataset_rowid, dataset_label_rowid) in unique_dataset_label_keys:
                        continue
                    datasets_dataset_labels.append(
                        models.DatasetsDatasetLabel(
                            dataset_id=dataset_rowid,
                            dataset_label_id=dataset_label_rowid,
                        )
                    )

            if datasets_dataset_labels:
                session.add_all(datasets_dataset_labels)
                try:
                    await session.commit()
                except (PostgreSQLIntegrityError, SQLiteIntegrityError) as e:
                    raise Conflict("Failed to add dataset labels to datasets.") from e

        return SetDatasetLabelsMutationPayload(
            query=Query(),
        )

    @strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked])  # type: ignore
    async def unset_dataset_labels(
        self, info: Info[Context, None], input: UnsetDatasetLabelsInput
    ) -> UnsetDatasetLabelsMutationPayload:
        """
        Detach every given label from every given dataset.

        Pairs that are not currently associated are ignored. The existence of the
        referenced datasets and labels is not checked.

        Raises:
            BadRequest: either list is empty, or an ID has the wrong type.
        """
        if not input.dataset_ids:
            raise BadRequest("No datasets provided.")
        if not input.dataset_label_ids:
            raise BadRequest("No dataset labels provided.")

        unique_dataset_rowids: set[int] = set()
        for dataset_gid in input.dataset_ids:
            try:
                dataset_rowid = from_global_id_with_expected_type(dataset_gid, Dataset.__name__)
            except ValueError:
                raise BadRequest(f"Invalid dataset ID: {dataset_gid}")
            unique_dataset_rowids.add(dataset_rowid)
        dataset_rowids = list(unique_dataset_rowids)

        unique_dataset_label_rowids: set[int] = set()
        for dataset_label_gid in input.dataset_label_ids:
            try:
                dataset_label_rowid = from_global_id_with_expected_type(
                    dataset_label_gid, DatasetLabel.__name__
                )
            except ValueError:
                raise BadRequest(f"Invalid dataset label ID: {dataset_label_gid}")
            unique_dataset_label_rowids.add(dataset_label_rowid)
        dataset_label_rowids = list(unique_dataset_label_rowids)

        async with info.context.db() as session:
            await session.execute(
                delete(models.DatasetsDatasetLabel).where(
                    models.DatasetsDatasetLabel.dataset_id.in_(dataset_rowids)
                    & models.DatasetsDatasetLabel.dataset_label_id.in_(dataset_label_rowids)
                )
            )
            await session.commit()

        return UnsetDatasetLabelsMutationPayload(
            query=Query(),
        )
|
|
@@ -15,6 +15,7 @@ from phoenix.server.api.context import Context
|
|
|
15
15
|
from phoenix.server.api.exceptions import BadRequest, Conflict, NotFound
|
|
16
16
|
from phoenix.server.api.helpers.playground_users import get_user
|
|
17
17
|
from phoenix.server.api.queries import Query
|
|
18
|
+
from phoenix.server.api.types.DatasetExample import DatasetExample, to_gql_dataset_example
|
|
18
19
|
from phoenix.server.api.types.DatasetSplit import DatasetSplit, to_gql_dataset_split
|
|
19
20
|
from phoenix.server.api.types.node import from_global_id_with_expected_type
|
|
20
21
|
|
|
@@ -68,6 +69,13 @@ class DatasetSplitMutationPayload:
|
|
|
68
69
|
query: "Query"
|
|
69
70
|
|
|
70
71
|
|
|
72
|
+
@strawberry.type
class DatasetSplitMutationPayloadWithExamples:
    # Returned by `create_dataset_split_with_examples`: the new split, the root query,
    # and the examples that were associated with it.
    dataset_split: DatasetSplit
    query: "Query"
    examples: list[DatasetExample]
|
|
77
|
+
|
|
78
|
+
|
|
71
79
|
@strawberry.type
|
|
72
80
|
class DeleteDatasetSplitsMutationPayload:
|
|
73
81
|
dataset_splits: list[DatasetSplit]
|
|
@@ -77,11 +85,13 @@ class DeleteDatasetSplitsMutationPayload:
|
|
|
77
85
|
@strawberry.type
class AddDatasetExamplesToDatasetSplitsMutationPayload:
    # The root query plus the examples whose IDs were passed to the mutation.
    query: "Query"
    examples: list[DatasetExample]
|
|
80
89
|
|
|
81
90
|
|
|
82
91
|
@strawberry.type
class RemoveDatasetExamplesFromDatasetSplitsMutationPayload:
    # The root query plus the examples whose IDs were passed to the mutation.
    query: "Query"
    examples: list[DatasetExample]
|
|
85
95
|
|
|
86
96
|
|
|
87
97
|
@strawberry.type
|
|
@@ -262,8 +272,16 @@ class DatasetSplitMutationMixin:
|
|
|
262
272
|
except (PostgreSQLIntegrityError, SQLiteIntegrityError) as e:
|
|
263
273
|
raise Conflict("Failed to add examples to dataset splits.") from e
|
|
264
274
|
|
|
275
|
+
examples = (
|
|
276
|
+
await session.scalars(
|
|
277
|
+
select(models.DatasetExample).where(
|
|
278
|
+
models.DatasetExample.id.in_(example_rowids)
|
|
279
|
+
)
|
|
280
|
+
)
|
|
281
|
+
).all()
|
|
265
282
|
return AddDatasetExamplesToDatasetSplitsMutationPayload(
|
|
266
283
|
query=Query(),
|
|
284
|
+
examples=[to_gql_dataset_example(example) for example in examples],
|
|
267
285
|
)
|
|
268
286
|
|
|
269
287
|
@strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore
|
|
@@ -314,14 +332,23 @@ class DatasetSplitMutationMixin:
|
|
|
314
332
|
|
|
315
333
|
await session.execute(stmt)
|
|
316
334
|
|
|
335
|
+
examples = (
|
|
336
|
+
await session.scalars(
|
|
337
|
+
select(models.DatasetExample).where(
|
|
338
|
+
models.DatasetExample.id.in_(example_rowids)
|
|
339
|
+
)
|
|
340
|
+
)
|
|
341
|
+
).all()
|
|
342
|
+
|
|
317
343
|
return RemoveDatasetExamplesFromDatasetSplitsMutationPayload(
|
|
318
344
|
query=Query(),
|
|
345
|
+
examples=[to_gql_dataset_example(example) for example in examples],
|
|
319
346
|
)
|
|
320
347
|
|
|
321
348
|
@strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked]) # type: ignore
|
|
322
349
|
async def create_dataset_split_with_examples(
|
|
323
350
|
self, info: Info[Context, None], input: CreateDatasetSplitWithExamplesInput
|
|
324
|
-
) ->
|
|
351
|
+
) -> DatasetSplitMutationPayloadWithExamples:
|
|
325
352
|
user_id = get_user(info)
|
|
326
353
|
validated_name = _validated_name(input.name)
|
|
327
354
|
unique_example_rowids: set[int] = set()
|
|
@@ -374,9 +401,18 @@ class DatasetSplitMutationMixin:
|
|
|
374
401
|
"Failed to associate examples with the new dataset split."
|
|
375
402
|
) from e
|
|
376
403
|
|
|
377
|
-
|
|
404
|
+
examples = (
|
|
405
|
+
await session.scalars(
|
|
406
|
+
select(models.DatasetExample).where(
|
|
407
|
+
models.DatasetExample.id.in_(example_rowids)
|
|
408
|
+
)
|
|
409
|
+
)
|
|
410
|
+
).all()
|
|
411
|
+
|
|
412
|
+
return DatasetSplitMutationPayloadWithExamples(
|
|
378
413
|
dataset_split=to_gql_dataset_split(dataset_split_orm),
|
|
379
414
|
query=Query(),
|
|
415
|
+
examples=[to_gql_dataset_example(example) for example in examples],
|
|
380
416
|
)
|
|
381
417
|
|
|
382
418
|
|
phoenix/server/api/queries.py
CHANGED
|
@@ -48,6 +48,7 @@ from phoenix.server.api.types.AnnotationConfig import AnnotationConfig, to_gql_a
|
|
|
48
48
|
from phoenix.server.api.types.Cluster import Cluster, to_gql_clusters
|
|
49
49
|
from phoenix.server.api.types.Dataset import Dataset, to_gql_dataset
|
|
50
50
|
from phoenix.server.api.types.DatasetExample import DatasetExample
|
|
51
|
+
from phoenix.server.api.types.DatasetLabel import DatasetLabel, to_gql_dataset_label
|
|
51
52
|
from phoenix.server.api.types.DatasetSplit import DatasetSplit, to_gql_dataset_split
|
|
52
53
|
from phoenix.server.api.types.Dimension import to_gql_dimension
|
|
53
54
|
from phoenix.server.api.types.EmbeddingDimension import (
|
|
@@ -1149,6 +1150,26 @@ class Query:
|
|
|
1149
1150
|
args=args,
|
|
1150
1151
|
)
|
|
1151
1152
|
|
|
1153
|
+
@strawberry.field
async def dataset_labels(
    self,
    info: Info[Context, None],
    first: Optional[int] = 50,
    last: Optional[int] = UNSET,
    after: Optional[CursorString] = UNSET,
    before: Optional[CursorString] = UNSET,
) -> Connection[DatasetLabel]:
    """
    Page through all dataset labels as a relay-style connection.

    Labels are ordered by row ID, so cursors stay valid across successive requests.
    """
    args = ConnectionArgs(
        first=first,
        after=after if isinstance(after, CursorString) else None,
        last=last,
        before=before if isinstance(before, CursorString) else None,
    )
    async with info.context.db() as session:
        # `connection_from_list` slices the list by position. Without an ORDER BY, the
        # database may return rows in any order, and pages could skip or repeat labels.
        dataset_labels = await session.scalars(
            select(models.DatasetLabel).order_by(models.DatasetLabel.id)
        )
        data = [to_gql_dataset_label(dataset_label) for dataset_label in dataset_labels]
    return connection_from_list(data=data, args=args)
|
|
1172
|
+
|
|
1152
1173
|
@strawberry.field
|
|
1153
1174
|
async def dataset_splits(
|
|
1154
1175
|
self,
|
|
@@ -2,7 +2,6 @@ import asyncio
|
|
|
2
2
|
import secrets
|
|
3
3
|
from datetime import datetime, timedelta, timezone
|
|
4
4
|
from functools import partial
|
|
5
|
-
from pathlib import Path
|
|
6
5
|
from urllib.parse import urlencode, urlparse, urlunparse
|
|
7
6
|
|
|
8
7
|
from fastapi import APIRouter, Depends, HTTPException, Request, Response
|
|
@@ -38,7 +37,6 @@ from phoenix.config import (
|
|
|
38
37
|
get_base_url,
|
|
39
38
|
get_env_disable_basic_auth,
|
|
40
39
|
get_env_disable_rate_limit,
|
|
41
|
-
get_env_host_root_path,
|
|
42
40
|
)
|
|
43
41
|
from phoenix.db import models
|
|
44
42
|
from phoenix.server.bearer_auth import PhoenixUser, create_access_and_refresh_tokens
|
|
@@ -52,6 +50,7 @@ from phoenix.server.types import (
|
|
|
52
50
|
TokenStore,
|
|
53
51
|
UserId,
|
|
54
52
|
)
|
|
53
|
+
from phoenix.server.utils import prepend_root_path
|
|
55
54
|
|
|
56
55
|
rate_limiter = ServerRateLimiter(
|
|
57
56
|
per_second_rate_limit=0.2,
|
|
@@ -145,7 +144,8 @@ async def logout(
|
|
|
145
144
|
user_id = subject
|
|
146
145
|
if user_id:
|
|
147
146
|
await token_store.log_out(user_id)
|
|
148
|
-
|
|
147
|
+
redirect_path = "/logout" if get_env_disable_basic_auth() else "/login"
|
|
148
|
+
redirect_url = prepend_root_path(request.scope, redirect_path)
|
|
149
149
|
response = Response(status_code=HTTP_302_FOUND, headers={"Location": redirect_url})
|
|
150
150
|
response = delete_access_token_cookie(response)
|
|
151
151
|
response = delete_refresh_token_cookie(response)
|
|
@@ -242,9 +242,9 @@ async def initiate_password_reset(request: Request) -> Response:
|
|
|
242
242
|
)
|
|
243
243
|
token, _ = await token_store.create_password_reset_token(password_reset_token_claims)
|
|
244
244
|
url = urlparse(request.headers.get("referer") or get_base_url())
|
|
245
|
-
path =
|
|
245
|
+
path = prepend_root_path(request.scope, "/reset-password-with-token")
|
|
246
246
|
query_string = urlencode(dict(token=token))
|
|
247
|
-
components = (url.scheme, url.netloc, path
|
|
247
|
+
components = (url.scheme, url.netloc, path, "", query_string, "")
|
|
248
248
|
reset_url = urlunparse(components)
|
|
249
249
|
await sender.send_password_reset_email(email, reset_url)
|
|
250
250
|
return Response(status_code=HTTP_204_NO_CONTENT)
|
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
import logging
|
|
1
2
|
import re
|
|
2
3
|
from dataclasses import dataclass
|
|
3
4
|
from datetime import timedelta
|
|
@@ -38,17 +39,20 @@ from phoenix.config import (
|
|
|
38
39
|
get_env_disable_rate_limit,
|
|
39
40
|
)
|
|
40
41
|
from phoenix.db import models
|
|
42
|
+
from phoenix.server.api.auth_messages import AuthErrorCode
|
|
41
43
|
from phoenix.server.bearer_auth import create_access_and_refresh_tokens
|
|
42
|
-
from phoenix.server.oauth2 import OAuth2Client
|
|
43
44
|
from phoenix.server.rate_limiters import (
|
|
44
45
|
ServerRateLimiter,
|
|
45
46
|
fastapi_ip_rate_limiter,
|
|
46
47
|
fastapi_route_rate_limiter,
|
|
47
48
|
)
|
|
48
49
|
from phoenix.server.types import TokenStore
|
|
50
|
+
from phoenix.server.utils import get_root_path, prepend_root_path
|
|
49
51
|
|
|
50
52
|
_LOWERCASE_ALPHANUMS_AND_UNDERSCORES = r"[a-z0-9_]+"
|
|
51
53
|
|
|
54
|
+
logger = logging.getLogger(__name__)
|
|
55
|
+
|
|
52
56
|
login_rate_limiter = fastapi_ip_rate_limiter(
|
|
53
57
|
ServerRateLimiter(
|
|
54
58
|
per_second_rate_limit=0.2,
|
|
@@ -87,11 +91,12 @@ async def login(
|
|
|
87
91
|
idp_name: Annotated[str, Path(min_length=1, pattern=_LOWERCASE_ALPHANUMS_AND_UNDERSCORES)],
|
|
88
92
|
return_url: Optional[str] = Query(default=None, alias="returnUrl"),
|
|
89
93
|
) -> RedirectResponse:
|
|
94
|
+
# Security Note: Query parameters should be treated as untrusted user input. Never display
|
|
95
|
+
# these values directly to users as they could be manipulated for XSS, phishing, or social
|
|
96
|
+
# engineering attacks.
|
|
97
|
+
if (oauth2_client := request.app.state.oauth2_clients.get_client(idp_name)) is None:
|
|
98
|
+
return _redirect_to_login(request=request, error="unknown_idp")
|
|
90
99
|
secret = request.app.state.get_secret()
|
|
91
|
-
if not isinstance(
|
|
92
|
-
oauth2_client := request.app.state.oauth2_clients.get_client(idp_name), OAuth2Client
|
|
93
|
-
):
|
|
94
|
-
return _redirect_to_login(request=request, error=f"Unknown IDP: {idp_name}.")
|
|
95
100
|
if (referer := request.headers.get("referer")) is not None:
|
|
96
101
|
# if the referer header is present, use it as the origin URL
|
|
97
102
|
parsed_url = urlparse(referer)
|
|
@@ -131,28 +136,42 @@ async def create_tokens(
|
|
|
131
136
|
request: Request,
|
|
132
137
|
idp_name: Annotated[str, Path(min_length=1, pattern=_LOWERCASE_ALPHANUMS_AND_UNDERSCORES)],
|
|
133
138
|
state: str = Query(),
|
|
134
|
-
authorization_code: str = Query(alias="code"),
|
|
139
|
+
authorization_code: Optional[str] = Query(default=None, alias="code"),
|
|
140
|
+
error: Optional[str] = Query(default=None),
|
|
141
|
+
error_description: Optional[str] = Query(default=None),
|
|
135
142
|
stored_state: str = Cookie(alias=PHOENIX_OAUTH2_STATE_COOKIE_NAME),
|
|
136
143
|
stored_nonce: str = Cookie(alias=PHOENIX_OAUTH2_NONCE_COOKIE_NAME),
|
|
137
144
|
) -> RedirectResponse:
|
|
145
|
+
# Security Note: Query parameters should be treated as untrusted user input. Never display
|
|
146
|
+
# these values directly to users as they could be manipulated for XSS, phishing, or social
|
|
147
|
+
# engineering attacks.
|
|
148
|
+
if (oauth2_client := request.app.state.oauth2_clients.get_client(idp_name)) is None:
|
|
149
|
+
return _redirect_to_login(request=request, error="unknown_idp")
|
|
150
|
+
if error or error_description:
|
|
151
|
+
logger.error(
|
|
152
|
+
"OAuth2 authentication failed for IDP %s: error=%s, description=%s",
|
|
153
|
+
idp_name,
|
|
154
|
+
error,
|
|
155
|
+
error_description,
|
|
156
|
+
)
|
|
157
|
+
return _redirect_to_login(request=request, error="auth_failed")
|
|
158
|
+
if authorization_code is None:
|
|
159
|
+
logger.error("OAuth2 callback missing authorization code for IDP %s", idp_name)
|
|
160
|
+
return _redirect_to_login(request=request, error="auth_failed")
|
|
138
161
|
secret = request.app.state.get_secret()
|
|
139
162
|
if state != stored_state:
|
|
140
|
-
return _redirect_to_login(request=request, error=
|
|
163
|
+
return _redirect_to_login(request=request, error="invalid_state")
|
|
141
164
|
try:
|
|
142
165
|
payload = _parse_state_payload(secret=secret, state=state)
|
|
143
166
|
except JoseError:
|
|
144
|
-
return _redirect_to_login(request=request, error=
|
|
167
|
+
return _redirect_to_login(request=request, error="invalid_state")
|
|
145
168
|
if (return_url := payload.get("return_url")) is not None and not _is_relative_url(
|
|
146
169
|
unquote(return_url)
|
|
147
170
|
):
|
|
148
|
-
return _redirect_to_login(request=request, error="
|
|
171
|
+
return _redirect_to_login(request=request, error="unsafe_return_url")
|
|
149
172
|
assert isinstance(access_token_expiry := request.app.state.access_token_expiry, timedelta)
|
|
150
173
|
assert isinstance(refresh_token_expiry := request.app.state.refresh_token_expiry, timedelta)
|
|
151
174
|
token_store: TokenStore = request.app.state.get_token_store()
|
|
152
|
-
if not isinstance(
|
|
153
|
-
oauth2_client := request.app.state.oauth2_clients.get_client(idp_name), OAuth2Client
|
|
154
|
-
):
|
|
155
|
-
return _redirect_to_login(request=request, error=f"Unknown IDP: {idp_name}.")
|
|
156
175
|
try:
|
|
157
176
|
token_data = await oauth2_client.fetch_access_token(
|
|
158
177
|
state=state,
|
|
@@ -161,19 +180,19 @@ async def create_tokens(
|
|
|
161
180
|
request=request, origin_url=payload["origin_url"], idp_name=idp_name
|
|
162
181
|
),
|
|
163
182
|
)
|
|
164
|
-
except OAuthError as
|
|
165
|
-
|
|
183
|
+
except OAuthError as e:
|
|
184
|
+
logger.error("OAuth2 error for IDP %s: %s", idp_name, e)
|
|
185
|
+
return _redirect_to_login(request=request, error="oauth_error")
|
|
166
186
|
_validate_token_data(token_data)
|
|
167
187
|
if "id_token" not in token_data:
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
error=f"OAuth2 IDP {idp_name} does not appear to support OpenID Connect.",
|
|
171
|
-
)
|
|
188
|
+
logger.error("OAuth2 IDP %s does not appear to support OpenID Connect", idp_name)
|
|
189
|
+
return _redirect_to_login(request=request, error="no_oidc_support")
|
|
172
190
|
user_info = await oauth2_client.parse_id_token(token_data, nonce=stored_nonce)
|
|
173
191
|
try:
|
|
174
192
|
user_info = _parse_user_info(user_info)
|
|
175
|
-
except MissingEmailScope as
|
|
176
|
-
|
|
193
|
+
except MissingEmailScope as e:
|
|
194
|
+
logger.error("Missing email scope for IDP %s: %s", idp_name, e)
|
|
195
|
+
return _redirect_to_login(request=request, error="missing_email_scope")
|
|
177
196
|
|
|
178
197
|
try:
|
|
179
198
|
async with request.app.state.db() as session:
|
|
@@ -183,15 +202,19 @@ async def create_tokens(
|
|
|
183
202
|
user_info=user_info,
|
|
184
203
|
allow_sign_up=oauth2_client.allow_sign_up,
|
|
185
204
|
)
|
|
186
|
-
except
|
|
187
|
-
|
|
205
|
+
except EmailAlreadyInUse as e:
|
|
206
|
+
logger.error("Email already in use for IDP %s: %s", idp_name, e)
|
|
207
|
+
return _redirect_to_login(request=request, error="email_in_use")
|
|
208
|
+
except SignInNotAllowed as e:
|
|
209
|
+
logger.error("Sign in not allowed for IDP %s: %s", idp_name, e)
|
|
210
|
+
return _redirect_to_login(request=request, error="sign_in_not_allowed")
|
|
188
211
|
access_token, refresh_token = await create_access_and_refresh_tokens(
|
|
189
212
|
user=user,
|
|
190
213
|
token_store=token_store,
|
|
191
214
|
access_token_expiry=access_token_expiry,
|
|
192
215
|
refresh_token_expiry=refresh_token_expiry,
|
|
193
216
|
)
|
|
194
|
-
redirect_path =
|
|
217
|
+
redirect_path = prepend_root_path(request.scope, return_url or "/")
|
|
195
218
|
response = RedirectResponse(
|
|
196
219
|
url=redirect_path,
|
|
197
220
|
status_code=HTTP_302_FOUND,
|
|
@@ -559,13 +582,14 @@ class MissingEmailScope(Exception):
|
|
|
559
582
|
pass
|
|
560
583
|
|
|
561
584
|
|
|
562
|
-
def _redirect_to_login(*, request: Request, error:
|
|
585
|
+
def _redirect_to_login(*, request: Request, error: AuthErrorCode) -> RedirectResponse:
|
|
563
586
|
"""
|
|
564
|
-
Creates a RedirectResponse to the login page to display an error
|
|
587
|
+
Creates a RedirectResponse to the login page to display an error code.
|
|
588
|
+
The error code will be validated and mapped to a user-friendly message on the frontend.
|
|
565
589
|
"""
|
|
566
590
|
# TODO: this needs some cleanup
|
|
567
|
-
login_path =
|
|
568
|
-
request
|
|
591
|
+
login_path = prepend_root_path(
|
|
592
|
+
request.scope, "/login" if not get_env_disable_basic_auth() else "/logout"
|
|
569
593
|
)
|
|
570
594
|
url = URL(login_path).include_query_params(error=error)
|
|
571
595
|
response = RedirectResponse(url=url)
|
|
@@ -574,34 +598,15 @@ def _redirect_to_login(*, request: Request, error: str) -> RedirectResponse:
|
|
|
574
598
|
return response
|
|
575
599
|
|
|
576
600
|
|
|
577
|
-
def _prepend_root_path_if_exists(*, request: Request, path: str) -> str:
|
|
578
|
-
"""
|
|
579
|
-
If a root path is configured, prepends it to the input path.
|
|
580
|
-
"""
|
|
581
|
-
if not path.startswith("/"):
|
|
582
|
-
raise ValueError("path must start with a forward slash")
|
|
583
|
-
root_path = _get_root_path(request=request)
|
|
584
|
-
if root_path.endswith("/"):
|
|
585
|
-
root_path = root_path.rstrip("/")
|
|
586
|
-
return root_path + path
|
|
587
|
-
|
|
588
|
-
|
|
589
601
|
def _append_root_path_if_exists(*, request: Request, base_url: str) -> str:
    """
    Return `base_url` with the request's root path appended.

    When the request scope reports no root path (empty), `base_url` is returned untouched.
    """
    root_path = get_root_path(request.scope)
    if root_path:
        return str(URLPath(root_path).make_absolute_url(base_url=base_url))
    return base_url
|
|
596
608
|
|
|
597
609
|
|
|
598
|
-
def _get_root_path(*, request: Request) -> str:
|
|
599
|
-
"""
|
|
600
|
-
Gets the root path from the request.
|
|
601
|
-
"""
|
|
602
|
-
return str(request.scope.get("root_path", ""))
|
|
603
|
-
|
|
604
|
-
|
|
605
610
|
def _get_create_tokens_endpoint(*, request: Request, origin_url: str, idp_name: str) -> str:
|
|
606
611
|
"""
|
|
607
612
|
Gets the endpoint for create tokens route.
|
|
@@ -679,7 +684,4 @@ def _is_oauth2_state_payload(maybe_state_payload: Any) -> TypeGuard[_OAuth2State
|
|
|
679
684
|
|
|
680
685
|
|
|
681
686
|
_JWT_ALGORITHM = "HS256"
|
|
682
|
-
_INVALID_OAUTH2_STATE_MESSAGE = (
|
|
683
|
-
"Received invalid state parameter during OAuth2 authorization code flow for IDP {idp_name}."
|
|
684
|
-
)
|
|
685
687
|
_RELATIVE_URL_PATTERN = re.compile(r"^/($|\w)")
|