arize-phoenix 12.4.0__py3-none-any.whl → 12.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic. Click here for more details.

Files changed (65)
  1. {arize_phoenix-12.4.0.dist-info → arize_phoenix-12.6.0.dist-info}/METADATA +1 -1
  2. {arize_phoenix-12.4.0.dist-info → arize_phoenix-12.6.0.dist-info}/RECORD +62 -60
  3. phoenix/auth.py +8 -2
  4. phoenix/db/models.py +3 -3
  5. phoenix/server/api/auth.py +9 -0
  6. phoenix/server/api/context.py +2 -0
  7. phoenix/server/api/dataloaders/__init__.py +2 -0
  8. phoenix/server/api/dataloaders/dataset_dataset_splits.py +52 -0
  9. phoenix/server/api/input_types/ExperimentRunSort.py +237 -0
  10. phoenix/server/api/input_types/ProjectSessionSort.py +158 -1
  11. phoenix/server/api/input_types/SpanSort.py +2 -1
  12. phoenix/server/api/input_types/UserRoleInput.py +1 -0
  13. phoenix/server/api/mutations/annotation_config_mutations.py +6 -6
  14. phoenix/server/api/mutations/api_key_mutations.py +13 -5
  15. phoenix/server/api/mutations/chat_mutations.py +3 -3
  16. phoenix/server/api/mutations/dataset_label_mutations.py +6 -6
  17. phoenix/server/api/mutations/dataset_mutations.py +8 -8
  18. phoenix/server/api/mutations/dataset_split_mutations.py +7 -7
  19. phoenix/server/api/mutations/experiment_mutations.py +2 -2
  20. phoenix/server/api/mutations/export_events_mutations.py +3 -3
  21. phoenix/server/api/mutations/model_mutations.py +4 -4
  22. phoenix/server/api/mutations/project_mutations.py +4 -4
  23. phoenix/server/api/mutations/project_session_annotations_mutations.py +4 -4
  24. phoenix/server/api/mutations/project_trace_retention_policy_mutations.py +8 -4
  25. phoenix/server/api/mutations/prompt_label_mutations.py +7 -7
  26. phoenix/server/api/mutations/prompt_mutations.py +7 -7
  27. phoenix/server/api/mutations/prompt_version_tag_mutations.py +3 -3
  28. phoenix/server/api/mutations/span_annotations_mutations.py +5 -5
  29. phoenix/server/api/mutations/trace_annotations_mutations.py +4 -4
  30. phoenix/server/api/mutations/trace_mutations.py +3 -3
  31. phoenix/server/api/mutations/user_mutations.py +8 -5
  32. phoenix/server/api/routers/auth.py +2 -2
  33. phoenix/server/api/routers/v1/__init__.py +16 -1
  34. phoenix/server/api/routers/v1/annotation_configs.py +7 -1
  35. phoenix/server/api/routers/v1/datasets.py +48 -8
  36. phoenix/server/api/routers/v1/experiment_runs.py +7 -1
  37. phoenix/server/api/routers/v1/experiments.py +41 -5
  38. phoenix/server/api/routers/v1/projects.py +3 -31
  39. phoenix/server/api/routers/v1/users.py +0 -7
  40. phoenix/server/api/subscriptions.py +3 -3
  41. phoenix/server/api/types/Dataset.py +95 -6
  42. phoenix/server/api/types/Experiment.py +60 -25
  43. phoenix/server/api/types/Project.py +24 -68
  44. phoenix/server/app.py +2 -0
  45. phoenix/server/authorization.py +3 -1
  46. phoenix/server/bearer_auth.py +9 -0
  47. phoenix/server/jwt_store.py +8 -6
  48. phoenix/server/static/.vite/manifest.json +44 -44
  49. phoenix/server/static/assets/{components-BvsExS75.js → components-CboqzKQ9.js} +520 -397
  50. phoenix/server/static/assets/{index-iq8WDxat.js → index-CYYGI5-x.js} +2 -2
  51. phoenix/server/static/assets/{pages-Ckg4SLQ9.js → pages-DdlUeKi2.js} +616 -604
  52. phoenix/server/static/assets/vendor-CQ4tN9P7.js +918 -0
  53. phoenix/server/static/assets/vendor-arizeai-Cb1ncvYH.js +106 -0
  54. phoenix/server/static/assets/{vendor-codemirror-1bq_t1Ec.js → vendor-codemirror-CckmKopH.js} +3 -3
  55. phoenix/server/static/assets/{vendor-recharts-DQ4xfrf4.js → vendor-recharts-BC1ysIKu.js} +1 -1
  56. phoenix/server/static/assets/{vendor-shiki-GGmcIQxA.js → vendor-shiki-B45T-YxN.js} +1 -1
  57. phoenix/server/static/assets/vendor-three-BtCyLs1w.js +3840 -0
  58. phoenix/version.py +1 -1
  59. phoenix/server/static/assets/vendor-D2eEI-6h.js +0 -914
  60. phoenix/server/static/assets/vendor-arizeai-kfOei7nf.js +0 -156
  61. phoenix/server/static/assets/vendor-three-BLWp5bic.js +0 -2998
  62. {arize_phoenix-12.4.0.dist-info → arize_phoenix-12.6.0.dist-info}/WHEEL +0 -0
  63. {arize_phoenix-12.4.0.dist-info → arize_phoenix-12.6.0.dist-info}/entry_points.txt +0 -0
  64. {arize_phoenix-12.4.0.dist-info → arize_phoenix-12.6.0.dist-info}/licenses/IP_NOTICE +0 -0
  65. {arize_phoenix-12.4.0.dist-info → arize_phoenix-12.6.0.dist-info}/licenses/LICENSE +0 -0
@@ -161,7 +161,7 @@ async def refresh_tokens(request: Request) -> Response:
161
161
  or (expiration_time := refresh_token_claims.expiration_time) is None
162
162
  ):
163
163
  raise HTTPException(status_code=401, detail="Invalid refresh token")
164
- if expiration_time.timestamp() < datetime.now().timestamp():
164
+ if expiration_time.timestamp() <= datetime.now(timezone.utc).timestamp():
165
165
  raise HTTPException(status_code=401, detail="Expired refresh token")
166
166
  await token_store.revoke(refresh_token_id)
167
167
 
@@ -253,7 +253,7 @@ async def reset_password(request: Request) -> Response:
253
253
  not (token := data.get("token"))
254
254
  or not isinstance((claims := await token_store.read(token)), PasswordResetTokenClaims)
255
255
  or not claims.expiration_time
256
- or claims.expiration_time < datetime.now(timezone.utc)
256
+ or claims.expiration_time <= datetime.now(timezone.utc)
257
257
  ):
258
258
  raise INVALID_TOKEN
259
259
  assert (user_id := claims.subject)
@@ -1,7 +1,7 @@
1
1
  from fastapi import APIRouter, Depends, HTTPException, Request
2
2
  from fastapi.security import APIKeyHeader
3
3
 
4
- from phoenix.server.bearer_auth import is_authenticated
4
+ from phoenix.server.bearer_auth import PhoenixUser, is_authenticated
5
5
 
6
6
  from .annotation_configs import router as annotation_configs_router
7
7
  from .annotations import router as annotations_router
@@ -33,6 +33,20 @@ async def prevent_access_in_read_only_mode(request: Request) -> None:
33
33
  )
34
34
 
35
35
 
36
+ async def restrict_access_by_viewers(request: Request) -> None:
37
+ """
38
+ Prevents access to the REST API for viewers, except for GET requests
39
+ and specific allowed POST routes.
40
+ """
41
+ if request.method == "GET":
42
+ return
43
+ if isinstance(request.user, PhoenixUser) and request.user.is_viewer:
44
+ raise HTTPException(
45
+ status_code=403,
46
+ detail="Viewers cannot perform this action.",
47
+ )
48
+
49
+
36
50
  def create_v1_router(authentication_enabled: bool) -> APIRouter:
37
51
  """
38
52
  Instantiates the v1 REST API router.
@@ -50,6 +64,7 @@ def create_v1_router(authentication_enabled: bool) -> APIRouter:
50
64
  )
51
65
  )
52
66
  dependencies.append(Depends(is_authenticated))
67
+ dependencies.append(Depends(restrict_access_by_viewers))
53
68
 
54
69
  router = APIRouter(
55
70
  prefix="/v1",
@@ -349,7 +349,13 @@ async def delete_annotation_config(
349
349
  request: Request,
350
350
  config_id: str = Path(..., description="ID of the annotation configuration"),
351
351
  ) -> DeleteAnnotationConfigResponseBody:
352
- config_gid = GlobalID.from_id(config_id)
352
+ try:
353
+ config_gid = GlobalID.from_id(config_id)
354
+ except Exception:
355
+ raise HTTPException(
356
+ status_code=422,
357
+ detail=f"Invalid annotation configuration ID format: {config_id}",
358
+ )
353
359
  if config_gid.type_name not in (
354
360
  CategoricalAnnotationConfigType.__name__,
355
361
  ContinuousAnnotationConfigType.__name__,
@@ -209,7 +209,13 @@ class GetDatasetResponseBody(ResponseBody[DatasetWithExampleCount]):
209
209
  async def get_dataset(
210
210
  request: Request, id: str = Path(description="The ID of the dataset")
211
211
  ) -> GetDatasetResponseBody:
212
- dataset_id = GlobalID.from_id(id)
212
+ try:
213
+ dataset_id = GlobalID.from_id(id)
214
+ except Exception as e:
215
+ raise HTTPException(
216
+ detail=f"Invalid dataset ID format: {id}",
217
+ status_code=422,
218
+ ) from e
213
219
 
214
220
  if (type_name := dataset_id.type_name) != DATASET_NODE_NAME:
215
221
  raise HTTPException(detail=f"ID {dataset_id} refers to a f{type_name}", status_code=404)
@@ -400,7 +406,12 @@ async def upload_dataset(
400
406
  description="If true, fulfill request synchronously and return JSON containing dataset_id.",
401
407
  ),
402
408
  ) -> Optional[UploadDatasetResponseBody]:
403
- request_content_type = request.headers["content-type"]
409
+ request_content_type = request.headers.get("content-type")
410
+ if not request_content_type:
411
+ raise HTTPException(
412
+ detail="Missing content-type header",
413
+ status_code=400,
414
+ )
404
415
  examples: Union[Examples, Awaitable[Examples]]
405
416
  if request_content_type.startswith("application/json"):
406
417
  try:
@@ -709,8 +720,24 @@ async def get_dataset_examples(
709
720
  ),
710
721
  ),
711
722
  ) -> ListDatasetExamplesResponseBody:
712
- dataset_gid = GlobalID.from_id(id)
713
- version_gid = GlobalID.from_id(version_id) if version_id else None
723
+ try:
724
+ dataset_gid = GlobalID.from_id(id)
725
+ except Exception as e:
726
+ raise HTTPException(
727
+ detail=f"Invalid dataset ID format: {id}",
728
+ status_code=422,
729
+ ) from e
730
+
731
+ if version_id:
732
+ try:
733
+ version_gid = GlobalID.from_id(version_id)
734
+ except Exception as e:
735
+ raise HTTPException(
736
+ detail=f"Invalid dataset version ID format: {version_id}",
737
+ status_code=422,
738
+ ) from e
739
+ else:
740
+ version_gid = None
714
741
 
715
742
  if (dataset_type := dataset_gid.type_name) != "Dataset":
716
743
  raise HTTPException(detail=f"ID {dataset_gid} refers to a {dataset_type}", status_code=404)
@@ -992,12 +1019,25 @@ def _get_content_jsonl_openai_evals(examples: list[models.DatasetExampleRevision
992
1019
  async def _get_db_examples(
993
1020
  *, session: Any, id: str, version_id: Optional[str]
994
1021
  ) -> tuple[str, list[models.DatasetExampleRevision]]:
995
- dataset_id = from_global_id_with_expected_type(GlobalID.from_id(id), DATASET_NODE_NAME)
1022
+ try:
1023
+ dataset_id = from_global_id_with_expected_type(GlobalID.from_id(id), DATASET_NODE_NAME)
1024
+ except Exception as e:
1025
+ raise HTTPException(
1026
+ detail=f"Invalid dataset ID format: {id}",
1027
+ status_code=422,
1028
+ ) from e
1029
+
996
1030
  dataset_version_id: Optional[int] = None
997
1031
  if version_id:
998
- dataset_version_id = from_global_id_with_expected_type(
999
- GlobalID.from_id(version_id), DATASET_VERSION_NODE_NAME
1000
- )
1032
+ try:
1033
+ dataset_version_id = from_global_id_with_expected_type(
1034
+ GlobalID.from_id(version_id), DATASET_VERSION_NODE_NAME
1035
+ )
1036
+ except Exception as e:
1037
+ raise HTTPException(
1038
+ detail=f"Invalid dataset version ID format: {version_id}",
1039
+ status_code=422,
1040
+ ) from e
1001
1041
  latest_version = (
1002
1042
  select(
1003
1043
  models.DatasetExampleRevision.dataset_example_id,
@@ -159,7 +159,13 @@ async def list_experiment_runs(
159
159
  gt=0,
160
160
  ),
161
161
  ) -> ListExperimentRunsResponseBody:
162
- experiment_gid = GlobalID.from_id(experiment_id)
162
+ try:
163
+ experiment_gid = GlobalID.from_id(experiment_id)
164
+ except Exception as e:
165
+ raise HTTPException(
166
+ detail=f"Invalid experiment ID format: {experiment_id}",
167
+ status_code=422,
168
+ ) from e
163
169
  try:
164
170
  experiment_rowid = from_global_id_with_expected_type(experiment_gid, "Experiment")
165
171
  except ValueError:
@@ -104,7 +104,13 @@ async def create_experiment(
104
104
  request_body: CreateExperimentRequestBody,
105
105
  dataset_id: str = Path(..., title="Dataset ID"),
106
106
  ) -> CreateExperimentResponseBody:
107
- dataset_globalid = GlobalID.from_id(dataset_id)
107
+ try:
108
+ dataset_globalid = GlobalID.from_id(dataset_id)
109
+ except Exception as e:
110
+ raise HTTPException(
111
+ detail=f"Invalid dataset ID format: {dataset_id}",
112
+ status_code=422,
113
+ ) from e
108
114
  try:
109
115
  dataset_rowid = from_global_id_with_expected_type(dataset_globalid, "Dataset")
110
116
  except ValueError:
@@ -117,6 +123,12 @@ async def create_experiment(
117
123
  if dataset_version_globalid_str is not None:
118
124
  try:
119
125
  dataset_version_globalid = GlobalID.from_id(dataset_version_globalid_str)
126
+ except Exception as e:
127
+ raise HTTPException(
128
+ detail=f"Invalid dataset version ID format: {dataset_version_globalid_str}",
129
+ status_code=422,
130
+ ) from e
131
+ try:
120
132
  dataset_version_id = from_global_id_with_expected_type(
121
133
  dataset_version_globalid, "DatasetVersion"
122
134
  )
@@ -232,7 +244,13 @@ class GetExperimentResponseBody(ResponseBody[Experiment]):
232
244
  response_description="Experiment retrieved successfully",
233
245
  )
234
246
  async def get_experiment(request: Request, experiment_id: str) -> GetExperimentResponseBody:
235
- experiment_globalid = GlobalID.from_id(experiment_id)
247
+ try:
248
+ experiment_globalid = GlobalID.from_id(experiment_id)
249
+ except Exception as e:
250
+ raise HTTPException(
251
+ detail=f"Invalid experiment ID format: {experiment_id}",
252
+ status_code=422,
253
+ ) from e
236
254
  try:
237
255
  experiment_rowid = from_global_id_with_expected_type(experiment_globalid, "Experiment")
238
256
  except ValueError:
@@ -282,7 +300,13 @@ async def list_experiments(
282
300
  request: Request,
283
301
  dataset_id: str = Path(..., title="Dataset ID"),
284
302
  ) -> ListExperimentsResponseBody:
285
- dataset_gid = GlobalID.from_id(dataset_id)
303
+ try:
304
+ dataset_gid = GlobalID.from_id(dataset_id)
305
+ except Exception as e:
306
+ raise HTTPException(
307
+ detail=f"Invalid dataset ID format: {dataset_id}",
308
+ status_code=422,
309
+ ) from e
286
310
  try:
287
311
  dataset_rowid = from_global_id_with_expected_type(dataset_gid, "Dataset")
288
312
  except ValueError:
@@ -397,7 +421,13 @@ async def get_experiment_json(
397
421
  request: Request,
398
422
  experiment_id: str = Path(..., title="Experiment ID"),
399
423
  ) -> Response:
400
- experiment_globalid = GlobalID.from_id(experiment_id)
424
+ try:
425
+ experiment_globalid = GlobalID.from_id(experiment_id)
426
+ except Exception as e:
427
+ raise HTTPException(
428
+ detail=f"Invalid experiment ID format: {experiment_id}",
429
+ status_code=422,
430
+ ) from e
401
431
  try:
402
432
  experiment_rowid = from_global_id_with_expected_type(experiment_globalid, "Experiment")
403
433
  except ValueError:
@@ -464,7 +494,13 @@ async def get_experiment_csv(
464
494
  request: Request,
465
495
  experiment_id: str = Path(..., title="Experiment ID"),
466
496
  ) -> Response:
467
- experiment_globalid = GlobalID.from_id(experiment_id)
497
+ try:
498
+ experiment_globalid = GlobalID.from_id(experiment_id)
499
+ except Exception as e:
500
+ raise HTTPException(
501
+ detail=f"Invalid experiment ID format: {experiment_id}",
502
+ status_code=422,
503
+ ) from e
468
504
  try:
469
505
  experiment_rowid = from_global_id_with_expected_type(experiment_globalid, "Experiment")
470
506
  except ValueError:
@@ -9,7 +9,6 @@ from strawberry.relay import GlobalID
9
9
  from phoenix.config import DEFAULT_PROJECT_NAME
10
10
  from phoenix.db import models
11
11
  from phoenix.db.helpers import exclude_experiment_projects
12
- from phoenix.db.models import UserRoleName
13
12
  from phoenix.server.api.routers.v1.models import V1RoutesBaseModel
14
13
  from phoenix.server.api.routers.v1.utils import (
15
14
  PaginatedResponseBody,
@@ -18,7 +17,7 @@ from phoenix.server.api.routers.v1.utils import (
18
17
  add_errors_to_responses,
19
18
  )
20
19
  from phoenix.server.api.types.Project import Project as ProjectNodeType
21
- from phoenix.server.authorization import is_not_locked
20
+ from phoenix.server.authorization import is_not_locked, require_admin
22
21
 
23
22
  router = APIRouter(tags=["projects"])
24
23
 
@@ -210,7 +209,7 @@ async def create_project(
210
209
 
211
210
  @router.put(
212
211
  "/projects/{project_identifier}",
213
- dependencies=[Depends(is_not_locked)],
212
+ dependencies=[Depends(require_admin), Depends(is_not_locked)],
214
213
  operation_id="updateProject",
215
214
  summary="Update a project by ID or name", # noqa: E501
216
215
  description="Update an existing project with new configuration. Project names cannot be changed. The project identifier is either project ID or project name. Note: When using a project name as the identifier, it cannot contain slash (/), question mark (?), or pound sign (#) characters.", # noqa: E501
@@ -245,20 +244,6 @@ async def update_project(
245
244
  Raises:
246
245
  HTTPException: If the project identifier format is invalid or the project is not found.
247
246
  """ # noqa: E501
248
- if request.app.state.authentication_enabled:
249
- async with request.app.state.db() as session:
250
- # Check if the user is an admin
251
- stmt = (
252
- select(models.UserRole.name)
253
- .join(models.User)
254
- .where(models.User.id == int(request.user.identity))
255
- )
256
- role_name: UserRoleName = await session.scalar(stmt)
257
- if role_name != "ADMIN" and role_name != "SYSTEM":
258
- raise HTTPException(
259
- status_code=403,
260
- detail="Only admins can update projects",
261
- )
262
247
  async with request.app.state.db() as session:
263
248
  project = await _get_project_by_identifier(session, project_identifier)
264
249
 
@@ -272,6 +257,7 @@ async def update_project(
272
257
 
273
258
  @router.delete(
274
259
  "/projects/{project_identifier}",
260
+ dependencies=[Depends(require_admin)],
275
261
  operation_id="deleteProject",
276
262
  summary="Delete a project by ID or name", # noqa: E501
277
263
  description="Delete an existing project and all its associated data. The project identifier is either project ID or project name. The default project cannot be deleted. Note: When using a project name as the identifier, it cannot contain slash (/), question mark (?), or pound sign (#) characters.", # noqa: E501
@@ -305,20 +291,6 @@ async def delete_project(
305
291
  Raises:
306
292
  HTTPException: If the project identifier format is invalid, the project is not found, or it's the default project.
307
293
  """ # noqa: E501
308
- if request.app.state.authentication_enabled:
309
- async with request.app.state.db() as session:
310
- # Check if the user is an admin
311
- stmt = (
312
- select(models.UserRole.name)
313
- .join(models.User)
314
- .where(models.User.id == int(request.user.identity))
315
- )
316
- role_name: UserRoleName = await session.scalar(stmt)
317
- if role_name != "ADMIN" and role_name != "SYSTEM":
318
- raise HTTPException(
319
- status_code=403,
320
- detail="Only admins can delete projects",
321
- )
322
294
  async with request.app.state.db() as session:
323
295
  project = await _get_project_by_identifier(session, project_identifier)
324
296
 
@@ -208,13 +208,6 @@ async def create_user(
208
208
  detail="Cannot create users with SYSTEM role",
209
209
  )
210
210
 
211
- # TODO: Implement VIEWER role
212
- if role == "VIEWER":
213
- raise HTTPException(
214
- status_code=400,
215
- detail="VIEWER role not yet implemented",
216
- )
217
-
218
211
  user: models.User
219
212
  if isinstance(user_data, LocalUserData):
220
213
  password = (user_data.password or secrets.token_hex()).strip()
@@ -27,7 +27,7 @@ from phoenix.config import PLAYGROUND_PROJECT_NAME
27
27
  from phoenix.datetime_utils import local_now, normalize_datetime
28
28
  from phoenix.db import models
29
29
  from phoenix.db.helpers import insert_experiment_with_examples_snapshot
30
- from phoenix.server.api.auth import IsLocked, IsNotReadOnly
30
+ from phoenix.server.api.auth import IsLocked, IsNotReadOnly, IsNotViewer
31
31
  from phoenix.server.api.context import Context
32
32
  from phoenix.server.api.exceptions import BadRequest, CustomGraphQLError, NotFound
33
33
  from phoenix.server.api.helpers.playground_clients import (
@@ -94,7 +94,7 @@ ChatStream: TypeAlias = AsyncGenerator[ChatCompletionSubscriptionPayload, None]
94
94
 
95
95
  @strawberry.type
96
96
  class Subscription:
97
- @strawberry.subscription(permission_classes=[IsNotReadOnly, IsLocked]) # type: ignore
97
+ @strawberry.subscription(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked]) # type: ignore
98
98
  async def chat_completion(
99
99
  self, info: Info[Context, None], input: ChatCompletionInput
100
100
  ) -> AsyncIterator[ChatCompletionSubscriptionPayload]:
@@ -193,7 +193,7 @@ class Subscription:
193
193
  info.context.event_queue.put(SpanInsertEvent(ids=(playground_project_id,)))
194
194
  yield ChatCompletionSubscriptionResult(span=Span(span_rowid=db_span.id, db_span=db_span))
195
195
 
196
- @strawberry.subscription(permission_classes=[IsNotReadOnly, IsLocked]) # type: ignore
196
+ @strawberry.subscription(permission_classes=[IsNotReadOnly, IsNotViewer, IsLocked]) # type: ignore
197
197
  async def chat_completion_over_dataset(
198
198
  self, info: Info[Context, None], input: ChatCompletionOverDatasetInput
199
199
  ) -> AsyncIterator[ChatCompletionSubscriptionPayload]:
@@ -3,7 +3,7 @@ from datetime import datetime
3
3
  from typing import ClassVar, Optional, cast
4
4
 
5
5
  import strawberry
6
- from sqlalchemy import and_, func, or_, select
6
+ from sqlalchemy import Text, and_, func, or_, select
7
7
  from sqlalchemy.sql.functions import count
8
8
  from strawberry import UNSET
9
9
  from strawberry.relay import Connection, GlobalID, Node, NodeID
@@ -19,6 +19,7 @@ from phoenix.server.api.types.DatasetExperimentAnnotationSummary import (
19
19
  DatasetExperimentAnnotationSummary,
20
20
  )
21
21
  from phoenix.server.api.types.DatasetLabel import DatasetLabel, to_gql_dataset_label
22
+ from phoenix.server.api.types.DatasetSplit import DatasetSplit, to_gql_dataset_split
22
23
  from phoenix.server.api.types.DatasetVersion import DatasetVersion
23
24
  from phoenix.server.api.types.Experiment import Experiment, to_gql_experiment
24
25
  from phoenix.server.api.types.node import from_global_id_with_expected_type
@@ -87,6 +88,7 @@ class Dataset(Node):
87
88
  self,
88
89
  info: Info[Context, None],
89
90
  dataset_version_id: Optional[GlobalID] = UNSET,
91
+ split_ids: Optional[list[GlobalID]] = UNSET,
90
92
  ) -> int:
91
93
  dataset_id = self.id_attr
92
94
  version_id = (
@@ -97,6 +99,20 @@ class Dataset(Node):
97
99
  if dataset_version_id
98
100
  else None
99
101
  )
102
+
103
+ # Parse split IDs if provided
104
+ split_rowids: Optional[list[int]] = None
105
+ if split_ids:
106
+ split_rowids = []
107
+ for split_id in split_ids:
108
+ try:
109
+ split_rowid = from_global_id_with_expected_type(
110
+ global_id=split_id, expected_type_name=models.DatasetSplit.__name__
111
+ )
112
+ split_rowids.append(split_rowid)
113
+ except Exception:
114
+ raise BadRequest(f"Invalid split ID: {split_id}")
115
+
100
116
  revision_ids = (
101
117
  select(func.max(models.DatasetExampleRevision.id))
102
118
  .join(models.DatasetExample)
@@ -113,11 +129,36 @@ class Dataset(Node):
113
129
  revision_ids = revision_ids.where(
114
130
  models.DatasetExampleRevision.dataset_version_id <= version_id_subquery
115
131
  )
116
- stmt = (
117
- select(count(models.DatasetExampleRevision.id))
118
- .where(models.DatasetExampleRevision.id.in_(revision_ids))
119
- .where(models.DatasetExampleRevision.revision_kind != "DELETE")
120
- )
132
+
133
+ # Build the count query
134
+ if split_rowids:
135
+ # When filtering by splits, count distinct examples that belong to those splits
136
+ stmt = (
137
+ select(count(models.DatasetExample.id.distinct()))
138
+ .join(
139
+ models.DatasetExampleRevision,
140
+ onclause=(
141
+ models.DatasetExample.id == models.DatasetExampleRevision.dataset_example_id
142
+ ),
143
+ )
144
+ .join(
145
+ models.DatasetSplitDatasetExample,
146
+ onclause=(
147
+ models.DatasetExample.id
148
+ == models.DatasetSplitDatasetExample.dataset_example_id
149
+ ),
150
+ )
151
+ .where(models.DatasetExampleRevision.id.in_(revision_ids))
152
+ .where(models.DatasetExampleRevision.revision_kind != "DELETE")
153
+ .where(models.DatasetSplitDatasetExample.dataset_split_id.in_(split_rowids))
154
+ )
155
+ else:
156
+ stmt = (
157
+ select(count(models.DatasetExampleRevision.id))
158
+ .where(models.DatasetExampleRevision.id.in_(revision_ids))
159
+ .where(models.DatasetExampleRevision.revision_kind != "DELETE")
160
+ )
161
+
121
162
  async with info.context.db() as session:
122
163
  return (await session.scalar(stmt)) or 0
123
164
 
@@ -126,10 +167,12 @@ class Dataset(Node):
126
167
  self,
127
168
  info: Info[Context, None],
128
169
  dataset_version_id: Optional[GlobalID] = UNSET,
170
+ split_ids: Optional[list[GlobalID]] = UNSET,
129
171
  first: Optional[int] = 50,
130
172
  last: Optional[int] = UNSET,
131
173
  after: Optional[CursorString] = UNSET,
132
174
  before: Optional[CursorString] = UNSET,
175
+ filter: Optional[str] = UNSET,
133
176
  ) -> Connection[DatasetExample]:
134
177
  args = ConnectionArgs(
135
178
  first=first,
@@ -145,6 +188,20 @@ class Dataset(Node):
145
188
  if dataset_version_id
146
189
  else None
147
190
  )
191
+
192
+ # Parse split IDs if provided
193
+ split_rowids: Optional[list[int]] = None
194
+ if split_ids:
195
+ split_rowids = []
196
+ for split_id in split_ids:
197
+ try:
198
+ split_rowid = from_global_id_with_expected_type(
199
+ global_id=split_id, expected_type_name=models.DatasetSplit.__name__
200
+ )
201
+ split_rowids.append(split_rowid)
202
+ except Exception:
203
+ raise BadRequest(f"Invalid split ID: {split_id}")
204
+
148
205
  revision_ids = (
149
206
  select(func.max(models.DatasetExampleRevision.id))
150
207
  .join(models.DatasetExample)
@@ -176,6 +233,31 @@ class Dataset(Node):
176
233
  )
177
234
  .order_by(models.DatasetExampleRevision.dataset_example_id.desc())
178
235
  )
236
+
237
+ # Filter by split IDs if provided
238
+ if split_rowids:
239
+ query = (
240
+ query.join(
241
+ models.DatasetSplitDatasetExample,
242
+ onclause=(
243
+ models.DatasetExample.id
244
+ == models.DatasetSplitDatasetExample.dataset_example_id
245
+ ),
246
+ )
247
+ .where(models.DatasetSplitDatasetExample.dataset_split_id.in_(split_rowids))
248
+ .distinct()
249
+ )
250
+ # Apply filter if provided - search through JSON fields (input, output, metadata)
251
+ if filter is not UNSET and filter:
252
+ # Create a filter that searches for the filter string in JSON fields
253
+ # Using PostgreSQL's JSON operators for case-insensitive text search
254
+ filter_condition = or_(
255
+ func.cast(models.DatasetExampleRevision.input, Text).ilike(f"%{filter}%"),
256
+ func.cast(models.DatasetExampleRevision.output, Text).ilike(f"%{filter}%"),
257
+ func.cast(models.DatasetExampleRevision.metadata_, Text).ilike(f"%{filter}%"),
258
+ )
259
+ query = query.where(filter_condition)
260
+
179
261
  async with info.context.db() as session:
180
262
  dataset_examples = [
181
263
  DatasetExample(
@@ -187,6 +269,13 @@ class Dataset(Node):
187
269
  ]
188
270
  return connection_from_list(data=dataset_examples, args=args)
189
271
 
272
+ @strawberry.field
273
+ async def splits(self, info: Info[Context, None]) -> list[DatasetSplit]:
274
+ return [
275
+ to_gql_dataset_split(split)
276
+ for split in await info.context.data_loaders.dataset_dataset_splits.load(self.id_attr)
277
+ ]
278
+
190
279
  @strawberry.field(
191
280
  description="Number of experiments for a specific version if version is specified, "
192
281
  "or for all versions if version is not specified."
@@ -11,19 +11,27 @@ from strawberry.types import Info
11
11
 
12
12
  from phoenix.db import models
13
13
  from phoenix.server.api.context import Context
14
+ from phoenix.server.api.exceptions import BadRequest
15
+ from phoenix.server.api.input_types.ExperimentRunSort import (
16
+ ExperimentRunSort,
17
+ add_order_by_and_page_start_to_query,
18
+ get_experiment_run_cursor,
19
+ )
14
20
  from phoenix.server.api.types.CostBreakdown import CostBreakdown
15
21
  from phoenix.server.api.types.DatasetVersion import DatasetVersion
16
22
  from phoenix.server.api.types.ExperimentAnnotationSummary import ExperimentAnnotationSummary
17
23
  from phoenix.server.api.types.ExperimentRun import ExperimentRun, to_gql_experiment_run
18
24
  from phoenix.server.api.types.pagination import (
19
- ConnectionArgs,
25
+ Cursor,
20
26
  CursorString,
21
- connection_from_list,
27
+ connection_from_cursors_and_nodes,
22
28
  )
23
29
  from phoenix.server.api.types.Project import Project
24
30
  from phoenix.server.api.types.SpanCostDetailSummaryEntry import SpanCostDetailSummaryEntry
25
31
  from phoenix.server.api.types.SpanCostSummary import SpanCostSummary
26
32
 
33
+ _DEFAULT_EXPERIMENT_RUNS_PAGE_SIZE = 50
34
+
27
35
 
28
36
  @strawberry.type
29
37
  class Experiment(Node):
@@ -57,33 +65,60 @@ class Experiment(Node):
57
65
  async def runs(
58
66
  self,
59
67
  info: Info[Context, None],
60
- first: Optional[int] = 50,
61
- last: Optional[int] = UNSET,
68
+ first: Optional[int] = _DEFAULT_EXPERIMENT_RUNS_PAGE_SIZE,
62
69
  after: Optional[CursorString] = UNSET,
63
- before: Optional[CursorString] = UNSET,
70
+ sort: Optional[ExperimentRunSort] = UNSET,
64
71
  ) -> Connection[ExperimentRun]:
65
- args = ConnectionArgs(
66
- first=first,
67
- after=after if isinstance(after, CursorString) else None,
68
- last=last,
69
- before=before if isinstance(before, CursorString) else None,
72
+ if first is not None and first <= 0:
73
+ raise BadRequest("first must be a positive integer if set")
74
+ experiment_rowid = self.id_attr
75
+ page_size = first or _DEFAULT_EXPERIMENT_RUNS_PAGE_SIZE
76
+ experiment_runs_query = (
77
+ select(models.ExperimentRun)
78
+ .where(models.ExperimentRun.experiment_id == experiment_rowid)
79
+ .options(joinedload(models.ExperimentRun.trace).load_only(models.Trace.trace_id))
80
+ .limit(page_size + 1)
70
81
  )
71
- experiment_id = self.id_attr
82
+
83
+ after_experiment_run_rowid = None
84
+ after_sort_column_value = None
85
+ if after:
86
+ cursor = Cursor.from_string(after)
87
+ after_experiment_run_rowid = cursor.rowid
88
+ if cursor.sort_column is not None:
89
+ after_sort_column_value = cursor.sort_column.value
90
+
91
+ experiment_runs_query = add_order_by_and_page_start_to_query(
92
+ query=experiment_runs_query,
93
+ sort=sort,
94
+ experiment_rowid=experiment_rowid,
95
+ after_experiment_run_rowid=after_experiment_run_rowid,
96
+ after_sort_column_value=after_sort_column_value,
97
+ )
98
+
72
99
  async with info.context.db() as session:
73
- runs = (
74
- await session.scalars(
75
- select(models.ExperimentRun)
76
- .where(models.ExperimentRun.experiment_id == experiment_id)
77
- .order_by(
78
- models.ExperimentRun.dataset_example_id.asc(),
79
- models.ExperimentRun.repetition_number.asc(),
80
- )
81
- .options(
82
- joinedload(models.ExperimentRun.trace).load_only(models.Trace.trace_id)
83
- )
84
- )
85
- ).all()
86
- return connection_from_list([to_gql_experiment_run(run) for run in runs], args)
100
+ results = (await session.execute(experiment_runs_query)).all()
101
+
102
+ has_next_page = False
103
+ if len(results) > page_size:
104
+ results = results[:page_size]
105
+ has_next_page = True
106
+
107
+ cursors_and_nodes = []
108
+ for result in results:
109
+ run = result[0]
110
+ annotation_score = result[1] if len(result) > 1 else None
111
+ gql_run = to_gql_experiment_run(run)
112
+ cursor = get_experiment_run_cursor(
113
+ run=run, annotation_score=annotation_score, sort=sort
114
+ )
115
+ cursors_and_nodes.append((cursor, gql_run))
116
+
117
+ return connection_from_cursors_and_nodes(
118
+ cursors_and_nodes=cursors_and_nodes,
119
+ has_previous_page=False, # set to false since we are only doing forward pagination (https://relay.dev/graphql/connections.htm#sec-undefined.PageInfo.Fields) # noqa: E501
120
+ has_next_page=has_next_page,
121
+ )
87
122
 
88
123
  @strawberry.field
89
124
  async def run_count(self, info: Info[Context, None]) -> int: