arize-phoenix 12.4.0__py3-none-any.whl → 12.6.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the versions exactly as they appear in the public registry.
Potentially problematic release: this version of arize-phoenix might be problematic.
- {arize_phoenix-12.4.0.dist-info → arize_phoenix-12.6.0.dist-info}/METADATA +1 -1
- {arize_phoenix-12.4.0.dist-info → arize_phoenix-12.6.0.dist-info}/RECORD +62 -60
- phoenix/auth.py +8 -2
- phoenix/db/models.py +3 -3
- phoenix/server/api/auth.py +9 -0
- phoenix/server/api/context.py +2 -0
- phoenix/server/api/dataloaders/__init__.py +2 -0
- phoenix/server/api/dataloaders/dataset_dataset_splits.py +52 -0
- phoenix/server/api/input_types/ExperimentRunSort.py +237 -0
- phoenix/server/api/input_types/ProjectSessionSort.py +158 -1
- phoenix/server/api/input_types/SpanSort.py +2 -1
- phoenix/server/api/input_types/UserRoleInput.py +1 -0
- phoenix/server/api/mutations/annotation_config_mutations.py +6 -6
- phoenix/server/api/mutations/api_key_mutations.py +13 -5
- phoenix/server/api/mutations/chat_mutations.py +3 -3
- phoenix/server/api/mutations/dataset_label_mutations.py +6 -6
- phoenix/server/api/mutations/dataset_mutations.py +8 -8
- phoenix/server/api/mutations/dataset_split_mutations.py +7 -7
- phoenix/server/api/mutations/experiment_mutations.py +2 -2
- phoenix/server/api/mutations/export_events_mutations.py +3 -3
- phoenix/server/api/mutations/model_mutations.py +4 -4
- phoenix/server/api/mutations/project_mutations.py +4 -4
- phoenix/server/api/mutations/project_session_annotations_mutations.py +4 -4
- phoenix/server/api/mutations/project_trace_retention_policy_mutations.py +8 -4
- phoenix/server/api/mutations/prompt_label_mutations.py +7 -7
- phoenix/server/api/mutations/prompt_mutations.py +7 -7
- phoenix/server/api/mutations/prompt_version_tag_mutations.py +3 -3
- phoenix/server/api/mutations/span_annotations_mutations.py +5 -5
- phoenix/server/api/mutations/trace_annotations_mutations.py +4 -4
- phoenix/server/api/mutations/trace_mutations.py +3 -3
- phoenix/server/api/mutations/user_mutations.py +8 -5
- phoenix/server/api/routers/auth.py +2 -2
- phoenix/server/api/routers/v1/__init__.py +16 -1
- phoenix/server/api/routers/v1/annotation_configs.py +7 -1
- phoenix/server/api/routers/v1/datasets.py +48 -8
- phoenix/server/api/routers/v1/experiment_runs.py +7 -1
- phoenix/server/api/routers/v1/experiments.py +41 -5
- phoenix/server/api/routers/v1/projects.py +3 -31
- phoenix/server/api/routers/v1/users.py +0 -7
- phoenix/server/api/subscriptions.py +3 -3
- phoenix/server/api/types/Dataset.py +95 -6
- phoenix/server/api/types/Experiment.py +60 -25
- phoenix/server/api/types/Project.py +24 -68
- phoenix/server/app.py +2 -0
- phoenix/server/authorization.py +3 -1
- phoenix/server/bearer_auth.py +9 -0
- phoenix/server/jwt_store.py +8 -6
- phoenix/server/static/.vite/manifest.json +44 -44
- phoenix/server/static/assets/{components-BvsExS75.js → components-CboqzKQ9.js} +520 -397
- phoenix/server/static/assets/{index-iq8WDxat.js → index-CYYGI5-x.js} +2 -2
- phoenix/server/static/assets/{pages-Ckg4SLQ9.js → pages-DdlUeKi2.js} +616 -604
- phoenix/server/static/assets/vendor-CQ4tN9P7.js +918 -0
- phoenix/server/static/assets/vendor-arizeai-Cb1ncvYH.js +106 -0
- phoenix/server/static/assets/{vendor-codemirror-1bq_t1Ec.js → vendor-codemirror-CckmKopH.js} +3 -3
- phoenix/server/static/assets/{vendor-recharts-DQ4xfrf4.js → vendor-recharts-BC1ysIKu.js} +1 -1
- phoenix/server/static/assets/{vendor-shiki-GGmcIQxA.js → vendor-shiki-B45T-YxN.js} +1 -1
- phoenix/server/static/assets/vendor-three-BtCyLs1w.js +3840 -0
- phoenix/version.py +1 -1
- phoenix/server/static/assets/vendor-D2eEI-6h.js +0 -914
- phoenix/server/static/assets/vendor-arizeai-kfOei7nf.js +0 -156
- phoenix/server/static/assets/vendor-three-BLWp5bic.js +0 -2998
- {arize_phoenix-12.4.0.dist-info → arize_phoenix-12.6.0.dist-info}/WHEEL +0 -0
- {arize_phoenix-12.4.0.dist-info → arize_phoenix-12.6.0.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-12.4.0.dist-info → arize_phoenix-12.6.0.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-12.4.0.dist-info → arize_phoenix-12.6.0.dist-info}/licenses/LICENSE +0 -0
phoenix/server/api/routers/auth.py

```diff
@@ -161,7 +161,7 @@ async def refresh_tokens(request: Request) -> Response:
         or (expiration_time := refresh_token_claims.expiration_time) is None
     ):
         raise HTTPException(status_code=401, detail="Invalid refresh token")
-    if expiration_time.timestamp()
+    if expiration_time.timestamp() <= datetime.now(timezone.utc).timestamp():
         raise HTTPException(status_code=401, detail="Expired refresh token")
     await token_store.revoke(refresh_token_id)
 
@@ -253,7 +253,7 @@ async def reset_password(request: Request) -> Response:
         not (token := data.get("token"))
         or not isinstance((claims := await token_store.read(token)), PasswordResetTokenClaims)
         or not claims.expiration_time
-        or claims.expiration_time
+        or claims.expiration_time <= datetime.now(timezone.utc)
     ):
         raise INVALID_TOKEN
     assert (user_id := claims.subject)
```
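Both hunks complete expiry comparisons for refresh and password-reset tokens: a token whose expiration time is at or before the current UTC instant is now rejected. A minimal standalone sketch of the comparison (not Phoenix's actual helper):

```python
from datetime import datetime, timedelta, timezone

def is_expired(expiration_time: datetime) -> bool:
    # Timezone-aware comparison in UTC; "<=" also rejects a token
    # that expires at exactly the current instant.
    return expiration_time.timestamp() <= datetime.now(timezone.utc).timestamp()

# A token that expired a second ago is rejected; a future one is not.
assert is_expired(datetime.now(timezone.utc) - timedelta(seconds=1))
assert not is_expired(datetime.now(timezone.utc) + timedelta(hours=1))
```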
phoenix/server/api/routers/v1/__init__.py

```diff
@@ -1,7 +1,7 @@
 from fastapi import APIRouter, Depends, HTTPException, Request
 from fastapi.security import APIKeyHeader
 
-from phoenix.server.bearer_auth import is_authenticated
+from phoenix.server.bearer_auth import PhoenixUser, is_authenticated
 
 from .annotation_configs import router as annotation_configs_router
 from .annotations import router as annotations_router
@@ -33,6 +33,20 @@ async def prevent_access_in_read_only_mode(request: Request) -> None:
     )
 
 
+async def restrict_access_by_viewers(request: Request) -> None:
+    """
+    Prevents access to the REST API for viewers, except for GET requests
+    and specific allowed POST routes.
+    """
+    if request.method == "GET":
+        return
+    if isinstance(request.user, PhoenixUser) and request.user.is_viewer:
+        raise HTTPException(
+            status_code=403,
+            detail="Viewers cannot perform this action.",
+        )
+
+
 def create_v1_router(authentication_enabled: bool) -> APIRouter:
     """
     Instantiates the v1 REST API router.
@@ -50,6 +64,7 @@ def create_v1_router(authentication_enabled: bool) -> APIRouter:
             )
         )
         dependencies.append(Depends(is_authenticated))
+        dependencies.append(Depends(restrict_access_by_viewers))
 
     router = APIRouter(
         prefix="/v1",
```
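With authentication enabled, `restrict_access_by_viewers` now runs as a router-level dependency on every v1 request, so any non-GET call from a viewer-role user is rejected with 403 before reaching the handler. A self-contained sketch of the same mechanism (the `User` type here is a stand-in for Phoenix's `PhoenixUser`):

```python
from dataclasses import dataclass

from fastapi import Depends, FastAPI, HTTPException, Request

@dataclass
class User:
    is_viewer: bool  # stand-in attribute; Phoenix checks PhoenixUser.is_viewer

async def restrict_viewers(request: Request) -> None:
    if request.method == "GET":
        return  # reads are always allowed
    user = request.scope.get("user")  # set by an auth middleware in real apps
    if isinstance(user, User) and user.is_viewer:
        raise HTTPException(status_code=403, detail="Viewers cannot perform this action.")

# Declaring the dependency on the app (or an APIRouter) applies it to every route.
app = FastAPI(dependencies=[Depends(restrict_viewers)])
```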
phoenix/server/api/routers/v1/annotation_configs.py

```diff
@@ -349,7 +349,13 @@ async def delete_annotation_config(
     request: Request,
     config_id: str = Path(..., description="ID of the annotation configuration"),
 ) -> DeleteAnnotationConfigResponseBody:
-    config_gid = GlobalID.from_id(config_id)
+    try:
+        config_gid = GlobalID.from_id(config_id)
+    except Exception:
+        raise HTTPException(
+            status_code=422,
+            detail=f"Invalid annotation configuration ID format: {config_id}",
+        )
     if config_gid.type_name not in (
         CategoricalAnnotationConfigType.__name__,
         ContinuousAnnotationConfigType.__name__,
```
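Previously a malformed ID propagated as an unhandled exception; it now surfaces as a 422 with a descriptive message. The same parse-or-422 wrapper recurs throughout the dataset and experiment routes below; a hedged sketch of it factored into a reusable helper (`parse_global_id` is hypothetical, not a function in Phoenix):

```python
from fastapi import HTTPException
from strawberry.relay import GlobalID

def parse_global_id(raw_id: str, label: str) -> GlobalID:
    # Hypothetical helper: normalize malformed-ID failures into HTTP 422.
    try:
        return GlobalID.from_id(raw_id)
    except Exception as e:
        raise HTTPException(
            status_code=422,
            detail=f"Invalid {label} ID format: {raw_id}",
        ) from e
```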
phoenix/server/api/routers/v1/datasets.py

```diff
@@ -209,7 +209,13 @@ class GetDatasetResponseBody(ResponseBody[DatasetWithExampleCount]):
 async def get_dataset(
     request: Request, id: str = Path(description="The ID of the dataset")
 ) -> GetDatasetResponseBody:
-    dataset_id = GlobalID.from_id(id)
+    try:
+        dataset_id = GlobalID.from_id(id)
+    except Exception as e:
+        raise HTTPException(
+            detail=f"Invalid dataset ID format: {id}",
+            status_code=422,
+        ) from e
 
     if (type_name := dataset_id.type_name) != DATASET_NODE_NAME:
         raise HTTPException(detail=f"ID {dataset_id} refers to a f{type_name}", status_code=404)
@@ -400,7 +406,12 @@ async def upload_dataset(
         description="If true, fulfill request synchronously and return JSON containing dataset_id.",
     ),
 ) -> Optional[UploadDatasetResponseBody]:
-    request_content_type = request.headers
+    request_content_type = request.headers.get("content-type")
+    if not request_content_type:
+        raise HTTPException(
+            detail="Missing content-type header",
+            status_code=400,
+        )
     examples: Union[Examples, Awaitable[Examples]]
     if request_content_type.startswith("application/json"):
         try:
```
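Still in datasets.py: `upload_dataset` now reads the Content-Type header explicitly and fails fast with a 400 when it is absent, rather than failing later when dispatching on the header value. A sketch of the guard in isolation:

```python
from fastapi import HTTPException, Request

async def require_content_type(request: Request) -> str:
    # A missing header becomes an explicit 400 instead of a downstream
    # error when the value is dispatched on (JSON vs. other upload formats).
    content_type = request.headers.get("content-type")
    if not content_type:
        raise HTTPException(status_code=400, detail="Missing content-type header")
    return content_type
```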
```diff
@@ -709,8 +720,24 @@ async def get_dataset_examples(
         ),
     ),
 ) -> ListDatasetExamplesResponseBody:
-    dataset_gid = GlobalID.from_id(id)
-    version_gid = GlobalID.from_id(version_id) if version_id else None
+    try:
+        dataset_gid = GlobalID.from_id(id)
+    except Exception as e:
+        raise HTTPException(
+            detail=f"Invalid dataset ID format: {id}",
+            status_code=422,
+        ) from e
+
+    if version_id:
+        try:
+            version_gid = GlobalID.from_id(version_id)
+        except Exception as e:
+            raise HTTPException(
+                detail=f"Invalid dataset version ID format: {version_id}",
+                status_code=422,
+            ) from e
+    else:
+        version_gid = None
 
     if (dataset_type := dataset_gid.type_name) != "Dataset":
         raise HTTPException(detail=f"ID {dataset_gid} refers to a {dataset_type}", status_code=404)
@@ -992,12 +1019,25 @@ def _get_content_jsonl_openai_evals(examples: list[models.DatasetExampleRevision
 async def _get_db_examples(
     *, session: Any, id: str, version_id: Optional[str]
 ) -> tuple[str, list[models.DatasetExampleRevision]]:
-    dataset_id = from_global_id_with_expected_type(GlobalID.from_id(id), DATASET_NODE_NAME)
+    try:
+        dataset_id = from_global_id_with_expected_type(GlobalID.from_id(id), DATASET_NODE_NAME)
+    except Exception as e:
+        raise HTTPException(
+            detail=f"Invalid dataset ID format: {id}",
+            status_code=422,
+        ) from e
+
     dataset_version_id: Optional[int] = None
     if version_id:
-        dataset_version_id = from_global_id_with_expected_type(
-            GlobalID.from_id(version_id), DATASET_VERSION_NODE_NAME
-        )
+        try:
+            dataset_version_id = from_global_id_with_expected_type(
+                GlobalID.from_id(version_id), DATASET_VERSION_NODE_NAME
+            )
+        except Exception as e:
+            raise HTTPException(
+                detail=f"Invalid dataset version ID format: {version_id}",
+                status_code=422,
+            ) from e
         latest_version = (
             select(
                 models.DatasetExampleRevision.dataset_example_id,
```
phoenix/server/api/routers/v1/experiment_runs.py

```diff
@@ -159,7 +159,13 @@ async def list_experiment_runs(
         gt=0,
     ),
 ) -> ListExperimentRunsResponseBody:
-    experiment_gid = GlobalID.from_id(experiment_id)
+    try:
+        experiment_gid = GlobalID.from_id(experiment_id)
+    except Exception as e:
+        raise HTTPException(
+            detail=f"Invalid experiment ID format: {experiment_id}",
+            status_code=422,
+        ) from e
     try:
         experiment_rowid = from_global_id_with_expected_type(experiment_gid, "Experiment")
     except ValueError:
```
phoenix/server/api/routers/v1/experiments.py

```diff
@@ -104,7 +104,13 @@ async def create_experiment(
     request_body: CreateExperimentRequestBody,
     dataset_id: str = Path(..., title="Dataset ID"),
 ) -> CreateExperimentResponseBody:
-    dataset_globalid = GlobalID.from_id(dataset_id)
+    try:
+        dataset_globalid = GlobalID.from_id(dataset_id)
+    except Exception as e:
+        raise HTTPException(
+            detail=f"Invalid dataset ID format: {dataset_id}",
+            status_code=422,
+        ) from e
     try:
         dataset_rowid = from_global_id_with_expected_type(dataset_globalid, "Dataset")
     except ValueError:
@@ -117,6 +123,12 @@ async def create_experiment(
     if dataset_version_globalid_str is not None:
         try:
             dataset_version_globalid = GlobalID.from_id(dataset_version_globalid_str)
+        except Exception as e:
+            raise HTTPException(
+                detail=f"Invalid dataset version ID format: {dataset_version_globalid_str}",
+                status_code=422,
+            ) from e
+        try:
             dataset_version_id = from_global_id_with_expected_type(
                 dataset_version_globalid, "DatasetVersion"
             )
@@ -232,7 +244,13 @@ class GetExperimentResponseBody(ResponseBody[Experiment]):
     response_description="Experiment retrieved successfully",
 )
 async def get_experiment(request: Request, experiment_id: str) -> GetExperimentResponseBody:
-    experiment_globalid = GlobalID.from_id(experiment_id)
+    try:
+        experiment_globalid = GlobalID.from_id(experiment_id)
+    except Exception as e:
+        raise HTTPException(
+            detail=f"Invalid experiment ID format: {experiment_id}",
+            status_code=422,
+        ) from e
     try:
         experiment_rowid = from_global_id_with_expected_type(experiment_globalid, "Experiment")
     except ValueError:
@@ -282,7 +300,13 @@ async def list_experiments(
     request: Request,
     dataset_id: str = Path(..., title="Dataset ID"),
 ) -> ListExperimentsResponseBody:
-    dataset_gid = GlobalID.from_id(dataset_id)
+    try:
+        dataset_gid = GlobalID.from_id(dataset_id)
+    except Exception as e:
+        raise HTTPException(
+            detail=f"Invalid dataset ID format: {dataset_id}",
+            status_code=422,
+        ) from e
     try:
         dataset_rowid = from_global_id_with_expected_type(dataset_gid, "Dataset")
     except ValueError:
@@ -397,7 +421,13 @@ async def get_experiment_json(
     request: Request,
     experiment_id: str = Path(..., title="Experiment ID"),
 ) -> Response:
-    experiment_globalid = GlobalID.from_id(experiment_id)
+    try:
+        experiment_globalid = GlobalID.from_id(experiment_id)
+    except Exception as e:
+        raise HTTPException(
+            detail=f"Invalid experiment ID format: {experiment_id}",
+            status_code=422,
+        ) from e
     try:
         experiment_rowid = from_global_id_with_expected_type(experiment_globalid, "Experiment")
     except ValueError:
@@ -464,7 +494,13 @@ async def get_experiment_csv(
     request: Request,
     experiment_id: str = Path(..., title="Experiment ID"),
 ) -> Response:
-    experiment_globalid = GlobalID.from_id(experiment_id)
+    try:
+        experiment_globalid = GlobalID.from_id(experiment_id)
+    except Exception as e:
+        raise HTTPException(
+            detail=f"Invalid experiment ID format: {experiment_id}",
+            status_code=422,
+        ) from e
     try:
         experiment_rowid = from_global_id_with_expected_type(experiment_globalid, "Experiment")
     except ValueError:
```
phoenix/server/api/routers/v1/projects.py

```diff
@@ -9,7 +9,6 @@ from strawberry.relay import GlobalID
 from phoenix.config import DEFAULT_PROJECT_NAME
 from phoenix.db import models
 from phoenix.db.helpers import exclude_experiment_projects
-from phoenix.db.models import UserRoleName
 from phoenix.server.api.routers.v1.models import V1RoutesBaseModel
 from phoenix.server.api.routers.v1.utils import (
     PaginatedResponseBody,
@@ -18,7 +17,7 @@ from phoenix.server.api.routers.v1.utils import (
     add_errors_to_responses,
 )
 from phoenix.server.api.types.Project import Project as ProjectNodeType
-from phoenix.server.authorization import is_not_locked
+from phoenix.server.authorization import is_not_locked, require_admin
 
 router = APIRouter(tags=["projects"])
 
@@ -210,7 +209,7 @@ async def create_project(
 
 @router.put(
     "/projects/{project_identifier}",
-    dependencies=[Depends(is_not_locked)],
+    dependencies=[Depends(require_admin), Depends(is_not_locked)],
     operation_id="updateProject",
     summary="Update a project by ID or name", # noqa: E501
     description="Update an existing project with new configuration. Project names cannot be changed. The project identifier is either project ID or project name. Note: When using a project name as the identifier, it cannot contain slash (/), question mark (?), or pound sign (#) characters.", # noqa: E501
@@ -245,20 +244,6 @@ async def update_project(
     Raises:
         HTTPException: If the project identifier format is invalid or the project is not found.
     """ # noqa: E501
-    if request.app.state.authentication_enabled:
-        async with request.app.state.db() as session:
-            # Check if the user is an admin
-            stmt = (
-                select(models.UserRole.name)
-                .join(models.User)
-                .where(models.User.id == int(request.user.identity))
-            )
-            role_name: UserRoleName = await session.scalar(stmt)
-            if role_name != "ADMIN" and role_name != "SYSTEM":
-                raise HTTPException(
-                    status_code=403,
-                    detail="Only admins can update projects",
-                )
     async with request.app.state.db() as session:
         project = await _get_project_by_identifier(session, project_identifier)
 
@@ -272,6 +257,7 @@ async def update_project(
 
 @router.delete(
     "/projects/{project_identifier}",
+    dependencies=[Depends(require_admin)],
     operation_id="deleteProject",
     summary="Delete a project by ID or name", # noqa: E501
     description="Delete an existing project and all its associated data. The project identifier is either project ID or project name. The default project cannot be deleted. Note: When using a project name as the identifier, it cannot contain slash (/), question mark (?), or pound sign (#) characters.", # noqa: E501
@@ -305,20 +291,6 @@ async def delete_project(
     Raises:
         HTTPException: If the project identifier format is invalid, the project is not found, or it's the default project.
     """ # noqa: E501
-    if request.app.state.authentication_enabled:
-        async with request.app.state.db() as session:
-            # Check if the user is an admin
-            stmt = (
-                select(models.UserRole.name)
-                .join(models.User)
-                .where(models.User.id == int(request.user.identity))
-            )
-            role_name: UserRoleName = await session.scalar(stmt)
-            if role_name != "ADMIN" and role_name != "SYSTEM":
-                raise HTTPException(
-                    status_code=403,
-                    detail="Only admins can delete projects",
-                )
     async with request.app.state.db() as session:
         project = await _get_project_by_identifier(session, project_identifier)
 
```
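Both handlers previously duplicated an inline role query per request; that check now lives in a shared `require_admin` dependency declared on the route. The diff does not show `require_admin`'s body; a plausible sketch of a dependency with the same effect:

```python
from fastapi import Depends, FastAPI, HTTPException, Request

_ADMIN_ROLES = {"ADMIN", "SYSTEM"}

async def require_admin(request: Request) -> None:
    # Assumption: the authenticated user's role name is resolvable from the
    # request; Phoenix's actual implementation may query the database instead.
    user = request.scope.get("user")  # set by Starlette's AuthenticationMiddleware
    if getattr(user, "role_name", None) not in _ADMIN_ROLES:
        raise HTTPException(status_code=403, detail="Only admins can perform this action")

app = FastAPI()

@app.delete("/projects/{project_identifier}", dependencies=[Depends(require_admin)])
async def delete_project(project_identifier: str) -> None:
    ...
```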
phoenix/server/api/routers/v1/users.py

```diff
@@ -208,13 +208,6 @@ async def create_user(
             detail="Cannot create users with SYSTEM role",
         )
 
-    # TODO: Implement VIEWER role
-    if role == "VIEWER":
-        raise HTTPException(
-            status_code=400,
-            detail="VIEWER role not yet implemented",
-        )
-
     user: models.User
     if isinstance(user_data, LocalUserData):
         password = (user_data.password or secrets.token_hex()).strip()
```
phoenix/server/api/subscriptions.py

```diff
@@ -27,7 +27,7 @@ from phoenix.config import PLAYGROUND_PROJECT_NAME
 from phoenix.datetime_utils import local_now, normalize_datetime
 from phoenix.db import models
 from phoenix.db.helpers import insert_experiment_with_examples_snapshot
-from phoenix.server.api.auth import IsLocked, IsNotReadOnly
+from phoenix.server.api.auth import IsLocked, IsNotReadOnly, IsNotViewer
 from phoenix.server.api.context import Context
 from phoenix.server.api.exceptions import BadRequest, CustomGraphQLError, NotFound
 from phoenix.server.api.helpers.playground_clients import (
@@ -94,7 +94,7 @@ ChatStream: TypeAlias = AsyncGenerator[ChatCompletionSubscriptionPayload, None]
 
 @strawberry.type
 class Subscription:
-    @strawberry.subscription(permission_classes=[IsNotReadOnly, IsLocked]) # type: ignore
+    @strawberry.subscription(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked]) # type: ignore
     async def chat_completion(
         self, info: Info[Context, None], input: ChatCompletionInput
     ) -> AsyncIterator[ChatCompletionSubscriptionPayload]:
@@ -193,7 +193,7 @@ class Subscription:
         info.context.event_queue.put(SpanInsertEvent(ids=(playground_project_id,)))
         yield ChatCompletionSubscriptionResult(span=Span(span_rowid=db_span.id, db_span=db_span))
 
-    @strawberry.subscription(permission_classes=[IsNotReadOnly, IsLocked]) # type: ignore
+    @strawberry.subscription(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked]) # type: ignore
     async def chat_completion_over_dataset(
         self, info: Info[Context, None], input: ChatCompletionOverDatasetInput
     ) -> AsyncIterator[ChatCompletionSubscriptionPayload]:
```
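`IsNotViewer` (exported from phoenix/server/api/auth.py, which this release touches per the file list) now gates both playground subscriptions. Its body is not shown in the diff; a plausible sketch of a Strawberry permission class with this behavior:

```python
from typing import Any

from strawberry.permission import BasePermission
from strawberry.types import Info

class IsNotViewer(BasePermission):
    # Sketch only; Phoenix's actual class lives in phoenix/server/api/auth.py.
    message = "Viewers cannot perform this action."

    def has_permission(self, source: Any, info: Info, **kwargs: Any) -> bool:
        # Assumes the GraphQL context exposes the authenticated user.
        user = getattr(info.context, "user", None)
        return not getattr(user, "is_viewer", False)
```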
phoenix/server/api/types/Dataset.py

```diff
@@ -3,7 +3,7 @@ from datetime import datetime
 from typing import ClassVar, Optional, cast
 
 import strawberry
-from sqlalchemy import and_, func, or_, select
+from sqlalchemy import Text, and_, func, or_, select
 from sqlalchemy.sql.functions import count
 from strawberry import UNSET
 from strawberry.relay import Connection, GlobalID, Node, NodeID
@@ -19,6 +19,7 @@ from phoenix.server.api.types.DatasetExperimentAnnotationSummary import (
     DatasetExperimentAnnotationSummary,
 )
 from phoenix.server.api.types.DatasetLabel import DatasetLabel, to_gql_dataset_label
+from phoenix.server.api.types.DatasetSplit import DatasetSplit, to_gql_dataset_split
 from phoenix.server.api.types.DatasetVersion import DatasetVersion
 from phoenix.server.api.types.Experiment import Experiment, to_gql_experiment
 from phoenix.server.api.types.node import from_global_id_with_expected_type
@@ -87,6 +88,7 @@ class Dataset(Node):
         self,
         info: Info[Context, None],
         dataset_version_id: Optional[GlobalID] = UNSET,
+        split_ids: Optional[list[GlobalID]] = UNSET,
     ) -> int:
         dataset_id = self.id_attr
         version_id = (
@@ -97,6 +99,20 @@ class Dataset(Node):
             if dataset_version_id
             else None
         )
+
+        # Parse split IDs if provided
+        split_rowids: Optional[list[int]] = None
+        if split_ids:
+            split_rowids = []
+            for split_id in split_ids:
+                try:
+                    split_rowid = from_global_id_with_expected_type(
+                        global_id=split_id, expected_type_name=models.DatasetSplit.__name__
+                    )
+                    split_rowids.append(split_rowid)
+                except Exception:
+                    raise BadRequest(f"Invalid split ID: {split_id}")
+
         revision_ids = (
             select(func.max(models.DatasetExampleRevision.id))
             .join(models.DatasetExample)
@@ -113,11 +129,36 @@ class Dataset(Node):
             revision_ids = revision_ids.where(
                 models.DatasetExampleRevision.dataset_version_id <= version_id_subquery
             )
-        stmt = (
-            select(count(models.DatasetExampleRevision.id))
-            .where(models.DatasetExampleRevision.id.in_(revision_ids))
-            .where(models.DatasetExampleRevision.revision_kind != "DELETE")
-        )
+
+        # Build the count query
+        if split_rowids:
+            # When filtering by splits, count distinct examples that belong to those splits
+            stmt = (
+                select(count(models.DatasetExample.id.distinct()))
+                .join(
+                    models.DatasetExampleRevision,
+                    onclause=(
+                        models.DatasetExample.id == models.DatasetExampleRevision.dataset_example_id
+                    ),
+                )
+                .join(
+                    models.DatasetSplitDatasetExample,
+                    onclause=(
+                        models.DatasetExample.id
+                        == models.DatasetSplitDatasetExample.dataset_example_id
+                    ),
+                )
+                .where(models.DatasetExampleRevision.id.in_(revision_ids))
+                .where(models.DatasetExampleRevision.revision_kind != "DELETE")
+                .where(models.DatasetSplitDatasetExample.dataset_split_id.in_(split_rowids))
+            )
+        else:
+            stmt = (
+                select(count(models.DatasetExampleRevision.id))
+                .where(models.DatasetExampleRevision.id.in_(revision_ids))
+                .where(models.DatasetExampleRevision.revision_kind != "DELETE")
+            )
+
         async with info.context.db() as session:
             return (await session.scalar(stmt)) or 0
 
@@ -126,10 +167,12 @@ class Dataset(Node):
         self,
         info: Info[Context, None],
         dataset_version_id: Optional[GlobalID] = UNSET,
+        split_ids: Optional[list[GlobalID]] = UNSET,
         first: Optional[int] = 50,
         last: Optional[int] = UNSET,
         after: Optional[CursorString] = UNSET,
         before: Optional[CursorString] = UNSET,
+        filter: Optional[str] = UNSET,
     ) -> Connection[DatasetExample]:
         args = ConnectionArgs(
             first=first,
@@ -145,6 +188,20 @@ class Dataset(Node):
             if dataset_version_id
             else None
         )
+
+        # Parse split IDs if provided
+        split_rowids: Optional[list[int]] = None
+        if split_ids:
+            split_rowids = []
+            for split_id in split_ids:
+                try:
+                    split_rowid = from_global_id_with_expected_type(
+                        global_id=split_id, expected_type_name=models.DatasetSplit.__name__
+                    )
+                    split_rowids.append(split_rowid)
+                except Exception:
+                    raise BadRequest(f"Invalid split ID: {split_id}")
+
         revision_ids = (
             select(func.max(models.DatasetExampleRevision.id))
             .join(models.DatasetExample)
@@ -176,6 +233,31 @@ class Dataset(Node):
             )
             .order_by(models.DatasetExampleRevision.dataset_example_id.desc())
         )
+
+        # Filter by split IDs if provided
+        if split_rowids:
+            query = (
+                query.join(
+                    models.DatasetSplitDatasetExample,
+                    onclause=(
+                        models.DatasetExample.id
+                        == models.DatasetSplitDatasetExample.dataset_example_id
+                    ),
+                )
+                .where(models.DatasetSplitDatasetExample.dataset_split_id.in_(split_rowids))
+                .distinct()
+            )
+        # Apply filter if provided - search through JSON fields (input, output, metadata)
+        if filter is not UNSET and filter:
+            # Create a filter that searches for the filter string in JSON fields
+            # Using PostgreSQL's JSON operators for case-insensitive text search
+            filter_condition = or_(
+                func.cast(models.DatasetExampleRevision.input, Text).ilike(f"%{filter}%"),
+                func.cast(models.DatasetExampleRevision.output, Text).ilike(f"%{filter}%"),
+                func.cast(models.DatasetExampleRevision.metadata_, Text).ilike(f"%{filter}%"),
+            )
+            query = query.where(filter_condition)
+
         async with info.context.db() as session:
             dataset_examples = [
                 DatasetExample(
```
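The new `filter` argument performs a case-insensitive substring search by casting each JSON column (input, output, metadata) to text and applying ILIKE; despite the comment's mention of PostgreSQL operators, `func.cast(..., Text)` plus `.ilike()` is portable SQLAlchemy. A standalone sketch of the predicate against an illustrative table (not Phoenix's exact schema):

```python
import sqlalchemy as sa

metadata = sa.MetaData()
revisions = sa.Table(
    "dataset_example_revisions",  # illustrative table definition
    metadata,
    sa.Column("input", sa.JSON),
    sa.Column("output", sa.JSON),
)

def json_contains(term: str) -> "sa.ColumnElement[bool]":
    # Cast JSON to text, then match the term case-insensitively anywhere.
    pattern = f"%{term}%"
    return sa.or_(
        sa.cast(revisions.c.input, sa.Text).ilike(pattern),
        sa.cast(revisions.c.output, sa.Text).ilike(pattern),
    )

print(json_contains("hello"))  # compiled boolean expression over both columns
```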
```diff
@@ -187,6 +269,13 @@ class Dataset(Node):
         ]
         return connection_from_list(data=dataset_examples, args=args)
 
+    @strawberry.field
+    async def splits(self, info: Info[Context, None]) -> list[DatasetSplit]:
+        return [
+            to_gql_dataset_split(split)
+            for split in await info.context.data_loaders.dataset_dataset_splits.load(self.id_attr)
+        ]
+
     @strawberry.field(
         description="Number of experiments for a specific version if version is specified, "
         "or for all versions if version is not specified."
```
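Dataset's new `splits` field resolves through the `dataset_dataset_splits` dataloader (added in phoenix/server/api/dataloaders/dataset_dataset_splits.py per the file list), so split lookups for many datasets in one GraphQL query are batched into a single database round trip. A generic sketch of the batching idea with Strawberry's DataLoader (the batch function here is illustrative):

```python
import asyncio

from strawberry.dataloader import DataLoader

async def load_splits_by_dataset(keys: list[int]) -> list[list[str]]:
    # Illustrative batch function: in real code, one query fetches the splits
    # for every requested dataset ID; results are returned in key order.
    by_dataset = {k: [f"split-of-{k}"] for k in keys}
    return [by_dataset.get(k, []) for k in keys]

async def main() -> None:
    loader = DataLoader(load_fn=load_splits_by_dataset)
    # Both loads resolve from a single batched call with keys=[1, 2].
    a, b = await asyncio.gather(loader.load(1), loader.load(2))
    print(a, b)

asyncio.run(main())
```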
phoenix/server/api/types/Experiment.py

```diff
@@ -11,19 +11,27 @@ from strawberry.types import Info
 
 from phoenix.db import models
 from phoenix.server.api.context import Context
+from phoenix.server.api.exceptions import BadRequest
+from phoenix.server.api.input_types.ExperimentRunSort import (
+    ExperimentRunSort,
+    add_order_by_and_page_start_to_query,
+    get_experiment_run_cursor,
+)
 from phoenix.server.api.types.CostBreakdown import CostBreakdown
 from phoenix.server.api.types.DatasetVersion import DatasetVersion
 from phoenix.server.api.types.ExperimentAnnotationSummary import ExperimentAnnotationSummary
 from phoenix.server.api.types.ExperimentRun import ExperimentRun, to_gql_experiment_run
 from phoenix.server.api.types.pagination import (
-
+    Cursor,
     CursorString,
-
+    connection_from_cursors_and_nodes,
 )
 from phoenix.server.api.types.Project import Project
 from phoenix.server.api.types.SpanCostDetailSummaryEntry import SpanCostDetailSummaryEntry
 from phoenix.server.api.types.SpanCostSummary import SpanCostSummary
 
+_DEFAULT_EXPERIMENT_RUNS_PAGE_SIZE = 50
+
 
 @strawberry.type
 class Experiment(Node):
@@ -57,33 +65,60 @@ class Experiment(Node):
     async def runs(
         self,
         info: Info[Context, None],
-        first: Optional[int] =
-        last: Optional[int] = UNSET,
+        first: Optional[int] = _DEFAULT_EXPERIMENT_RUNS_PAGE_SIZE,
         after: Optional[CursorString] = UNSET,
-
+        sort: Optional[ExperimentRunSort] = UNSET,
     ) -> Connection[ExperimentRun]:
-
-        first
-
-
-
+        if first is not None and first <= 0:
+            raise BadRequest("first must be a positive integer if set")
+        experiment_rowid = self.id_attr
+        page_size = first or _DEFAULT_EXPERIMENT_RUNS_PAGE_SIZE
+        experiment_runs_query = (
+            select(models.ExperimentRun)
+            .where(models.ExperimentRun.experiment_id == experiment_rowid)
+            .options(joinedload(models.ExperimentRun.trace).load_only(models.Trace.trace_id))
+            .limit(page_size + 1)
         )
-
+
+        after_experiment_run_rowid = None
+        after_sort_column_value = None
+        if after:
+            cursor = Cursor.from_string(after)
+            after_experiment_run_rowid = cursor.rowid
+            if cursor.sort_column is not None:
+                after_sort_column_value = cursor.sort_column.value
+
+        experiment_runs_query = add_order_by_and_page_start_to_query(
+            query=experiment_runs_query,
+            sort=sort,
+            experiment_rowid=experiment_rowid,
+            after_experiment_run_rowid=after_experiment_run_rowid,
+            after_sort_column_value=after_sort_column_value,
+        )
+
         async with info.context.db() as session:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            results = (await session.execute(experiment_runs_query)).all()
+
+        has_next_page = False
+        if len(results) > page_size:
+            results = results[:page_size]
+            has_next_page = True
+
+        cursors_and_nodes = []
+        for result in results:
+            run = result[0]
+            annotation_score = result[1] if len(result) > 1 else None
+            gql_run = to_gql_experiment_run(run)
+            cursor = get_experiment_run_cursor(
+                run=run, annotation_score=annotation_score, sort=sort
+            )
+            cursors_and_nodes.append((cursor, gql_run))
+
+        return connection_from_cursors_and_nodes(
+            cursors_and_nodes=cursors_and_nodes,
+            has_previous_page=False, # set to false since we are only doing forward pagination (https://relay.dev/graphql/connections.htm#sec-undefined.PageInfo.Fields) # noqa: E501
+            has_next_page=has_next_page,
+        )
 
     @strawberry.field
     async def run_count(self, info: Info[Context, None]) -> int:
```