arize-phoenix 11.38.0__py3-none-any.whl → 12.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic. Click here for more details.

Files changed (72)
  1. {arize_phoenix-11.38.0.dist-info → arize_phoenix-12.0.0.dist-info}/METADATA +3 -3
  2. {arize_phoenix-11.38.0.dist-info → arize_phoenix-12.0.0.dist-info}/RECORD +71 -50
  3. phoenix/config.py +1 -11
  4. phoenix/db/bulk_inserter.py +8 -0
  5. phoenix/db/facilitator.py +1 -1
  6. phoenix/db/helpers.py +202 -33
  7. phoenix/db/insertion/dataset.py +7 -0
  8. phoenix/db/insertion/helpers.py +2 -2
  9. phoenix/db/insertion/session_annotation.py +176 -0
  10. phoenix/db/insertion/types.py +30 -0
  11. phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
  12. phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
  13. phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
  14. phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
  15. phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
  16. phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
  17. phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
  18. phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
  19. phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
  20. phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
  21. phoenix/db/models.py +285 -46
  22. phoenix/server/api/context.py +13 -2
  23. phoenix/server/api/dataloaders/__init__.py +6 -2
  24. phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
  25. phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
  26. phoenix/server/api/dataloaders/table_fields.py +2 -2
  27. phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
  28. phoenix/server/api/helpers/playground_clients.py +65 -35
  29. phoenix/server/api/helpers/playground_users.py +26 -0
  30. phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
  31. phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
  32. phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
  33. phoenix/server/api/mutations/__init__.py +6 -0
  34. phoenix/server/api/mutations/chat_mutations.py +8 -3
  35. phoenix/server/api/mutations/dataset_mutations.py +5 -0
  36. phoenix/server/api/mutations/dataset_split_mutations.py +387 -0
  37. phoenix/server/api/mutations/project_session_annotations_mutations.py +161 -0
  38. phoenix/server/api/queries.py +32 -0
  39. phoenix/server/api/routers/v1/__init__.py +2 -0
  40. phoenix/server/api/routers/v1/annotations.py +320 -0
  41. phoenix/server/api/routers/v1/datasets.py +5 -0
  42. phoenix/server/api/routers/v1/experiments.py +10 -3
  43. phoenix/server/api/routers/v1/sessions.py +111 -0
  44. phoenix/server/api/routers/v1/traces.py +1 -2
  45. phoenix/server/api/routers/v1/users.py +7 -0
  46. phoenix/server/api/subscriptions.py +5 -2
  47. phoenix/server/api/types/DatasetExample.py +11 -0
  48. phoenix/server/api/types/DatasetSplit.py +32 -0
  49. phoenix/server/api/types/Experiment.py +0 -4
  50. phoenix/server/api/types/Project.py +16 -0
  51. phoenix/server/api/types/ProjectSession.py +88 -3
  52. phoenix/server/api/types/ProjectSessionAnnotation.py +68 -0
  53. phoenix/server/api/types/Span.py +5 -5
  54. phoenix/server/api/types/Trace.py +61 -0
  55. phoenix/server/app.py +6 -2
  56. phoenix/server/cost_tracking/model_cost_manifest.json +132 -2
  57. phoenix/server/dml_event.py +13 -0
  58. phoenix/server/static/.vite/manifest.json +39 -39
  59. phoenix/server/static/assets/{components-BQPHTBfv.js → components-Dl9SUw1U.js} +371 -327
  60. phoenix/server/static/assets/{index-BL5BMgJU.js → index-CqQS0dTo.js} +2 -2
  61. phoenix/server/static/assets/{pages-C0Y17J0T.js → pages-DKSjVA_E.js} +762 -514
  62. phoenix/server/static/assets/{vendor-BdjZxMii.js → vendor-CtbHQYl8.js} +1 -1
  63. phoenix/server/static/assets/{vendor-arizeai-CHYlS8jV.js → vendor-arizeai-D-lWOwIS.js} +1 -1
  64. phoenix/server/static/assets/{vendor-codemirror-Di6t4HnH.js → vendor-codemirror-BRBpy3_z.js} +3 -3
  65. phoenix/server/static/assets/{vendor-recharts-C9wCDYj3.js → vendor-recharts--KdSwB3m.js} +1 -1
  66. phoenix/server/static/assets/{vendor-shiki-MNnmOotP.js → vendor-shiki-CvRzZnIo.js} +1 -1
  67. phoenix/version.py +1 -1
  68. phoenix/server/api/dataloaders/experiment_repetition_counts.py +0 -39
  69. {arize_phoenix-11.38.0.dist-info → arize_phoenix-12.0.0.dist-info}/WHEEL +0 -0
  70. {arize_phoenix-11.38.0.dist-info → arize_phoenix-12.0.0.dist-info}/entry_points.txt +0 -0
  71. {arize_phoenix-11.38.0.dist-info → arize_phoenix-12.0.0.dist-info}/licenses/IP_NOTICE +0 -0
  72. {arize_phoenix-11.38.0.dist-info → arize_phoenix-12.0.0.dist-info}/licenses/LICENSE +0 -0
@@ -14,6 +14,9 @@ from strawberry.relay import GlobalID
14
14
  from phoenix.db import models
15
15
  from phoenix.db.insertion.types import Precursors
16
16
  from phoenix.server.api.routers.v1.models import V1RoutesBaseModel
17
+ from phoenix.server.api.types.ProjectSessionAnnotation import (
18
+ ProjectSessionAnnotation as SessionAnnotationNodeType,
19
+ )
17
20
  from phoenix.server.api.types.SpanAnnotation import SpanAnnotation as SpanAnnotationNodeType
18
21
  from phoenix.server.api.types.TraceAnnotation import TraceAnnotation as TraceAnnotationNodeType
19
22
  from phoenix.server.api.types.User import User as UserNodeType
@@ -24,6 +27,7 @@ logger = logging.getLogger(__name__)
24
27
 
25
28
  SPAN_ANNOTATION_NODE_NAME = SpanAnnotationNodeType.__name__
26
29
  TRACE_ANNOTATION_NODE_NAME = TraceAnnotationNodeType.__name__
30
+ SESSION_ANNOTATION_NODE_NAME = SessionAnnotationNodeType.__name__
27
31
  MAX_TRACE_IDS = 1_000
28
32
  USER_NODE_NAME = UserNodeType.__name__
29
33
  MAX_SPAN_IDS = 1_000
@@ -161,6 +165,35 @@ class TraceAnnotationsResponseBody(PaginatedResponseBody[TraceAnnotation]):
161
165
  pass
162
166
 
163
167
 
168
class SessionAnnotationData(AnnotationData):
    # Payload describing one annotation on a session: the fields inherited
    # from AnnotationData plus the ID of the annotated session.
    # (Documented with comments rather than a docstring so the generated
    # JSON schema description of this model is left untouched.)
    session_id: str = Field(description="Session ID")

    def as_precursor(self, *, user_id: Optional[int] = None) -> Precursors.SessionAnnotation:
        """Package this payload as a ``Precursors.SessionAnnotation``.

        The precursor carries the current UTC time, the session ID, and an
        unsaved ``models.ProjectSessionAnnotation`` row whose source is
        ``"API"``. ``user_id`` is stored on the row as given (``None`` when
        the caller has no user to attribute the annotation to).
        """
        now = datetime.now(timezone.utc)

        # Score, label and explanation are only present when a result was supplied.
        if self.result:
            score = self.result.score
            label = self.result.label
            explanation = self.result.explanation
        else:
            score = label = explanation = None

        row = models.ProjectSessionAnnotation(
            name=self.name,
            annotator_kind=self.annotator_kind,
            score=score,
            label=label,
            explanation=explanation,
            metadata_=self.metadata or {},
            identifier=self.identifier,
            source="API",
            user_id=user_id,
        )
        return Precursors.SessionAnnotation(now, self.session_id, row)
187
+
188
+
189
class SessionAnnotation(SessionAnnotationData, Annotation):
    # Combines the session annotation payload with the fields of Annotation;
    # declares nothing of its own.
    pass
191
+
192
+
193
class SessionAnnotationsResponseBody(PaginatedResponseBody[SessionAnnotation]):
    # Paginated response envelope whose items are SessionAnnotation objects.
    pass
195
+
196
+
164
197
  @router.get(
165
198
  "/projects/{project_identifier}/span_annotations",
166
199
  operation_id="listSpanAnnotationsBySpanIds",
@@ -304,3 +337,290 @@ async def list_span_annotations(
304
337
  ]
305
338
 
306
339
  return SpanAnnotationsResponseBody(data=data, next_cursor=next_cursor)
340
+
341
+
342
@router.get(
    "/projects/{project_identifier}/trace_annotations",
    operation_id="listTraceAnnotationsByTraceIds",
    summary="Get trace annotations for a list of trace_ids.",
    status_code=HTTP_200_OK,
    responses=add_errors_to_responses(
        [
            {"status_code": HTTP_404_NOT_FOUND, "description": "Project or traces not found"},
            {"status_code": HTTP_422_UNPROCESSABLE_ENTITY, "description": "Invalid parameters"},
        ]
    ),
)
async def list_trace_annotations(
    request: Request,
    project_identifier: str = Path(
        description=(
            "The project identifier: either project ID or project name. If using a project name as "
            "the identifier, it cannot contain slash (/), question mark (?), or pound sign (#) "
            "characters."
        )
    ),
    trace_ids: list[str] = Query(
        ..., min_length=1, description="One or more trace id to fetch annotations for"
    ),
    include_annotation_names: Optional[list[str]] = Query(
        default=None,
        description=(
            "Optional list of annotation names to include. If provided, only annotations with "
            "these names will be returned. 'note' annotations are excluded by default unless "
            "explicitly included in this list."
        ),
    ),
    exclude_annotation_names: Optional[list[str]] = Query(
        default=None, description="Optional list of annotation names to exclude from results."
    ),
    cursor: Optional[str] = Query(default=None, description="A cursor for pagination"),
    limit: int = Query(
        default=10,
        gt=0,
        le=10000,
        description="The maximum number of annotations to return in a single request",
    ),
) -> TraceAnnotationsResponseBody:
    # Lists annotations attached to the given traces of one project, newest
    # annotation id first, one page at a time.
    #
    # Raises HTTPException:
    #   422 - more than MAX_TRACE_IDS distinct trace ids, or an unparsable cursor
    #   404 - unknown project, or none of the trace ids exist in the project
    #
    # Documented with comments rather than a docstring because FastAPI would
    # publish a docstring as the operation's OpenAPI description.

    # De-duplicate first so repeated ids do not count against the cap.
    trace_ids = list({*trace_ids})
    if len(trace_ids) > MAX_TRACE_IDS:
        raise HTTPException(
            status_code=HTTP_422_UNPROCESSABLE_ENTITY,
            detail=f"Too many trace_ids supplied: {len(trace_ids)} (max {MAX_TRACE_IDS})",
        )

    async with request.app.state.db() as session:
        project = await _get_project_by_identifier(session, project_identifier)
        if not project:
            raise HTTPException(
                status_code=HTTP_404_NOT_FOUND,
                detail=f"Project with identifier {project_identifier} not found",
            )

        # Build the base query
        where_conditions = [
            models.Project.id == project.id,
            models.Trace.trace_id.in_(trace_ids),
        ]

        # Add annotation name filtering
        # NOTE(review): the `include_annotation_names` description says 'note'
        # annotations are excluded by default, but no default filter is applied
        # in this handler - confirm whether that is enforced elsewhere.
        if include_annotation_names:
            where_conditions.append(models.TraceAnnotation.name.in_(include_annotation_names))

        if exclude_annotation_names:
            where_conditions.append(models.TraceAnnotation.name.not_in(exclude_annotation_names))

        # Fetch one row beyond `limit`; its presence signals another page.
        stmt = (
            select(models.Trace.trace_id, models.TraceAnnotation)
            .join(models.Project, models.Trace.project_rowid == models.Project.id)
            .join(models.TraceAnnotation, models.TraceAnnotation.trace_rowid == models.Trace.id)
            .where(*where_conditions)
            .order_by(models.TraceAnnotation.id.desc())
            .limit(limit + 1)
        )

        if cursor:
            try:
                cursor_id = int(GlobalID.from_id(cursor).node_id)
            except ValueError:
                raise HTTPException(
                    status_code=HTTP_422_UNPROCESSABLE_ENTITY,
                    detail="Invalid cursor value",
                )
            # The cursor names the first row of the requested page (inclusive).
            stmt = stmt.where(models.TraceAnnotation.id <= cursor_id)

        rows: list[tuple[str, models.TraceAnnotation]] = [
            r async for r in (await session.stream(stmt))
        ]

        next_cursor: Optional[str] = None
        if len(rows) == limit + 1:
            # Drop the look-ahead row and hand its id back as the next cursor.
            *rows, extra = rows
            next_cursor = str(GlobalID(TRACE_ANNOTATION_NODE_NAME, str(extra[1].id)))

        if not rows:
            # Tell "traces exist but carry no matching annotations" (empty
            # page) apart from "no such traces in this project" (404).
            traces_exist = await session.scalar(
                select(
                    exists().where(
                        models.Trace.trace_id.in_(trace_ids),
                        models.Trace.project_rowid == project.id,
                    )
                )
            )
            if not traces_exist:
                raise HTTPException(
                    detail="None of the supplied trace_ids exist in this project",
                    status_code=HTTP_404_NOT_FOUND,
                )

            return TraceAnnotationsResponseBody(data=[], next_cursor=None)

        data = [
            TraceAnnotation(
                id=str(GlobalID(TRACE_ANNOTATION_NODE_NAME, str(anno.id))),
                trace_id=trace_id,
                name=anno.name,
                result=AnnotationResult(
                    label=anno.label,
                    score=anno.score,
                    explanation=anno.explanation,
                ),
                metadata=anno.metadata_,
                annotator_kind=anno.annotator_kind,
                created_at=anno.created_at,
                updated_at=anno.updated_at,
                identifier=anno.identifier,
                source=anno.source,
                # Use the module-level node-name constant, as the session
                # annotation handler does, instead of a hard-coded "User".
                user_id=str(GlobalID(USER_NODE_NAME, str(anno.user_id))) if anno.user_id else None,
            )
            for trace_id, anno in rows
        ]

    return TraceAnnotationsResponseBody(data=data, next_cursor=next_cursor)
480
+
481
+
482
@router.get(
    "/projects/{project_identifier}/session_annotations",
    operation_id="listSessionAnnotationsBySessionIds",
    summary="Get session annotations for a list of session_ids.",
    status_code=HTTP_200_OK,
    responses=add_errors_to_responses(
        [
            {"status_code": HTTP_404_NOT_FOUND, "description": "Project or sessions not found"},
            {"status_code": HTTP_422_UNPROCESSABLE_ENTITY, "description": "Invalid parameters"},
        ]
    ),
)
async def list_session_annotations(
    request: Request,
    project_identifier: str = Path(
        description=(
            "The project identifier: either project ID or project name. If using a project name as "
            "the identifier, it cannot contain slash (/), question mark (?), or pound sign (#) "
            "characters."
        )
    ),
    session_ids: list[str] = Query(
        ..., min_length=1, description="One or more session id to fetch annotations for"
    ),
    include_annotation_names: Optional[list[str]] = Query(
        default=None,
        description=(
            "Optional list of annotation names to include. If provided, only annotations with "
            "these names will be returned. 'note' annotations are excluded by default unless "
            "explicitly included in this list."
        ),
    ),
    exclude_annotation_names: Optional[list[str]] = Query(
        default=None, description="Optional list of annotation names to exclude from results."
    ),
    cursor: Optional[str] = Query(default=None, description="A cursor for pagination"),
    limit: int = Query(
        default=10,
        gt=0,
        le=10000,
        description="The maximum number of annotations to return in a single request",
    ),
) -> SessionAnnotationsResponseBody:
    # Pages through the annotations attached to the given sessions of one
    # project, highest annotation id first.
    #
    # Raises HTTPException:
    #   422 - more than MAX_SESSION_IDS distinct session ids, or a bad cursor
    #   404 - unknown project, or none of the session ids exist in the project
    #
    # Comments are used instead of a docstring because FastAPI would expose a
    # docstring as the operation's OpenAPI description.

    # Collapse duplicates before checking the cap.
    session_ids = list(set(session_ids))
    id_count = len(session_ids)
    if id_count > MAX_SESSION_IDS:
        raise HTTPException(
            status_code=HTTP_422_UNPROCESSABLE_ENTITY,
            detail=f"Too many session_ids supplied: {id_count} (max {MAX_SESSION_IDS})",
        )

    annotation_table = models.ProjectSessionAnnotation
    session_table = models.ProjectSession

    async with request.app.state.db() as db:
        project = await _get_project_by_identifier(db, project_identifier)
        if not project:
            raise HTTPException(
                status_code=HTTP_404_NOT_FOUND,
                detail=f"Project with identifier {project_identifier} not found",
            )

        # Collect every filter up front: project + session ids, the optional
        # name filters, and finally the (inclusive) pagination cursor.
        filters = [
            models.Project.id == project.id,
            session_table.session_id.in_(session_ids),
        ]
        if include_annotation_names:
            filters.append(annotation_table.name.in_(include_annotation_names))
        if exclude_annotation_names:
            filters.append(annotation_table.name.not_in(exclude_annotation_names))
        if cursor:
            try:
                cursor_id = int(GlobalID.from_id(cursor).node_id)
            except ValueError:
                raise HTTPException(
                    status_code=HTTP_422_UNPROCESSABLE_ENTITY,
                    detail="Invalid cursor value",
                )
            filters.append(annotation_table.id <= cursor_id)

        # Ask for one extra row so we can tell whether another page follows.
        stmt = (
            select(session_table.session_id, annotation_table)
            .join(models.Project, session_table.project_id == models.Project.id)
            .join(
                annotation_table,
                annotation_table.project_session_id == session_table.id,
            )
            .where(*filters)
            .order_by(annotation_table.id.desc())
            .limit(limit + 1)
        )
        result = await db.stream(stmt)
        rows: list[tuple[str, models.ProjectSessionAnnotation]] = [row async for row in result]

        next_cursor: Optional[str] = None
        if len(rows) == limit + 1:
            # The look-ahead row is not returned; its id seeds the next page.
            overflow = rows.pop()
            next_cursor = str(GlobalID(SESSION_ANNOTATION_NODE_NAME, str(overflow[1].id)))

        if not rows:
            # Empty page: either the sessions exist without matching
            # annotations, or none of them exist in this project at all.
            sessions_exist = await db.scalar(
                select(
                    exists().where(
                        session_table.session_id.in_(session_ids),
                        session_table.project_id == project.id,
                    )
                )
            )
            if sessions_exist:
                return SessionAnnotationsResponseBody(data=[], next_cursor=None)
            raise HTTPException(
                detail="None of the supplied session_ids exist in this project",
                status_code=HTTP_404_NOT_FOUND,
            )

        data = [
            SessionAnnotation(
                id=str(GlobalID(SESSION_ANNOTATION_NODE_NAME, str(anno.id))),
                session_id=sid,
                name=anno.name,
                result=AnnotationResult(
                    label=anno.label,
                    score=anno.score,
                    explanation=anno.explanation,
                ),
                metadata=anno.metadata_,
                annotator_kind=anno.annotator_kind,
                created_at=anno.created_at,
                updated_at=anno.updated_at,
                identifier=anno.identifier,
                source=anno.source,
                user_id=str(GlobalID(USER_NODE_NAME, str(anno.user_id))) if anno.user_id else None,
            )
            for sid, anno in rows
        ]

    return SessionAnnotationsResponseBody(data=data, next_cursor=next_cursor)
@@ -48,6 +48,7 @@ from phoenix.server.api.types.DatasetVersion import DatasetVersion as DatasetVer
48
48
  from phoenix.server.api.types.node import from_global_id_with_expected_type
49
49
  from phoenix.server.api.utils import delete_projects, delete_traces
50
50
  from phoenix.server.authorization import is_not_locked
51
+ from phoenix.server.bearer_auth import PhoenixUser
51
52
  from phoenix.server.dml_event import DatasetInsertEvent
52
53
 
53
54
  from .models import V1RoutesBaseModel
@@ -478,6 +479,9 @@ async def upload_dataset(
478
479
  detail="Invalid request Content-Type",
479
480
  status_code=HTTP_422_UNPROCESSABLE_ENTITY,
480
481
  )
482
+ user_id: Optional[int] = None
483
+ if request.app.state.authentication_enabled and isinstance(request.user, PhoenixUser):
484
+ user_id = int(request.user.identity)
481
485
  operation = cast(
482
486
  Callable[[AsyncSession], Awaitable[DatasetExampleAdditionEvent]],
483
487
  partial(
@@ -486,6 +490,7 @@ async def upload_dataset(
486
490
  action=action,
487
491
  name=name,
488
492
  description=description,
493
+ user_id=user_id,
489
494
  ),
490
495
  )
491
496
  if sync:
@@ -15,10 +15,14 @@ from starlette.status import HTTP_200_OK, HTTP_404_NOT_FOUND, HTTP_422_UNPROCESS
15
15
  from strawberry.relay import GlobalID
16
16
 
17
17
  from phoenix.db import models
18
- from phoenix.db.helpers import SupportedSQLDialect
18
+ from phoenix.db.helpers import (
19
+ SupportedSQLDialect,
20
+ insert_experiment_with_examples_snapshot,
21
+ )
19
22
  from phoenix.db.insertion.helpers import insert_on_conflict
20
23
  from phoenix.server.api.types.node import from_global_id_with_expected_type
21
24
  from phoenix.server.authorization import is_not_locked
25
+ from phoenix.server.bearer_auth import PhoenixUser
22
26
  from phoenix.server.dml_event import ExperimentInsertEvent
23
27
  from phoenix.server.experiments.utils import generate_experiment_project_name
24
28
 
@@ -157,6 +161,9 @@ async def create_experiment(
157
161
  detail=f"DatasetVersion with ID {dataset_version_globalid} does not exist",
158
162
  status_code=HTTP_404_NOT_FOUND,
159
163
  )
164
+ user_id: Optional[int] = None
165
+ if request.app.state.authentication_enabled and isinstance(request.user, PhoenixUser):
166
+ user_id = int(request.user.identity)
160
167
 
161
168
  # generate a semi-unique name for the experiment
162
169
  experiment_name = request_body.name or _generate_experiment_name(dataset_name)
@@ -172,9 +179,9 @@ async def create_experiment(
172
179
  repetitions=request_body.repetitions,
173
180
  metadata_=request_body.metadata or {},
174
181
  project_name=project_name,
182
+ user_id=user_id,
175
183
  )
176
- session.add(experiment)
177
- await session.flush()
184
+ await insert_experiment_with_examples_snapshot(session, experiment)
178
185
 
179
186
  dialect = SupportedSQLDialect(session.bind.dialect.name)
180
187
  project_rowid = await session.scalar(
@@ -0,0 +1,111 @@
1
+ from __future__ import annotations
2
+
3
+ import warnings
4
+ from typing import Optional
5
+
6
+ from fastapi import APIRouter, Depends, HTTPException, Query
7
+ from pydantic import Field
8
+ from sqlalchemy import select
9
+ from starlette.requests import Request
10
+ from starlette.status import HTTP_404_NOT_FOUND
11
+
12
+ from phoenix.db import models
13
+ from phoenix.db.helpers import SupportedSQLDialect
14
+ from phoenix.db.insertion.helpers import as_kv, insert_on_conflict
15
+ from phoenix.server.api.routers.v1.models import V1RoutesBaseModel
16
+ from phoenix.server.authorization import is_not_locked
17
+ from phoenix.server.bearer_auth import PhoenixUser
18
+
19
+ from .annotations import SessionAnnotationData
20
+ from .utils import RequestBody, ResponseBody, add_errors_to_responses
21
+
22
+ router = APIRouter(tags=["sessions"])
23
+
24
+
25
class InsertedSessionAnnotation(V1RoutesBaseModel):
    # One entry of the annotate-sessions response: the ID of an inserted
    # session annotation. (Comment, not docstring, to keep the generated
    # schema description unchanged.)
    id: str = Field(description="The ID of the inserted session annotation")
27
+
28
+
29
class AnnotateSessionsRequestBody(RequestBody[list[SessionAnnotationData]]):
    # Request envelope carrying a list of SessionAnnotationData items.
    pass
31
+
32
+
33
class AnnotateSessionsResponseBody(ResponseBody[list[InsertedSessionAnnotation]]):
    # Response envelope carrying a list of InsertedSessionAnnotation items.
    pass
35
+
36
+
37
@router.post(
    "/session_annotations",
    dependencies=[Depends(is_not_locked)],
    operation_id="annotateSessions",
    summary="Create session annotations",
    responses=add_errors_to_responses(
        [{"status_code": HTTP_404_NOT_FOUND, "description": "Session not found"}]
    ),
    response_description="Session annotations inserted successfully",
    include_in_schema=True,
)
async def annotate_sessions(
    request: Request,
    request_body: AnnotateSessionsRequestBody,
    sync: bool = Query(default=False, description="If true, fulfill request synchronously."),
) -> AnnotateSessionsResponseBody:
    # Creates session annotations from the request body.
    #
    # - Annotations named "note" are dropped with a UserWarning.
    # - sync=False: the annotations are enqueued and an empty list is returned.
    # - sync=True: every referenced session must exist (404 otherwise); the
    #   annotations are written immediately and their row ids are returned.
    #
    # Comments are used instead of a docstring because FastAPI would expose a
    # docstring as the operation's OpenAPI description.
    if not request_body.data:
        return AnnotateSessionsResponseBody(data=[])

    # Attribute the annotations to the caller only when authentication is on
    # and the request carries a PhoenixUser.
    user_id: Optional[int] = None
    if request.app.state.authentication_enabled and isinstance(request.user, PhoenixUser):
        user_id = int(request.user.identity)

    session_annotations = request_body.data
    filtered_session_annotations = [d for d in session_annotations if d.name != "note"]
    if len(filtered_session_annotations) != len(session_annotations):
        warnings.warn(
            (
                "Session annotations with the name 'note' are not supported in this endpoint. "
                "They will be ignored."
            ),
            UserWarning,
        )
    precursors = [d.as_precursor(user_id=user_id) for d in filtered_session_annotations]
    if not sync:
        await request.state.enqueue_annotations(*precursors)
        return AnnotateSessionsResponseBody(data=[])

    session_ids = {p.session_id for p in precursors}
    async with request.app.state.db() as session:
        # Map each referenced session_id to its row id.
        existing_sessions = {
            session_id: rowid
            async for session_id, rowid in await session.stream(
                select(models.ProjectSession.session_id, models.ProjectSession.id).filter(
                    models.ProjectSession.session_id.in_(session_ids)
                )
            )
        }

    missing_session_ids = session_ids - set(existing_sessions.keys())
    # We prefer to fail the entire operation if there are missing sessions in sync mode
    if missing_session_ids:
        # Sort so the error message is deterministic (set order is not).
        raise HTTPException(
            detail=f"Sessions with IDs {', '.join(sorted(missing_session_ids))} do not exist.",
            status_code=HTTP_404_NOT_FOUND,
        )

    async with request.app.state.db() as session:
        inserted_ids = []
        dialect = SupportedSQLDialect(session.bind.dialect.name)
        for p in precursors:
            values = dict(as_kv(p.as_insertable(existing_sessions[p.session_id]).row))
            # One statement per annotation, keyed on
            # (name, project_session_id, identifier); presumably an existing
            # row with the same key is updated - confirm in insert_on_conflict.
            session_annotation_id = await session.scalar(
                insert_on_conflict(
                    values,
                    dialect=dialect,
                    table=models.ProjectSessionAnnotation,
                    unique_by=("name", "project_session_id", "identifier"),
                ).returning(models.ProjectSessionAnnotation.id)
            )
            inserted_ids.append(session_annotation_id)

    # NOTE(review): ids are returned as plain stringified row ids, whereas the
    # session-annotation list endpoint encodes ids as GlobalIDs - confirm
    # whether this difference is intended.
    return AnnotateSessionsResponseBody(
        data=[InsertedSessionAnnotation(id=str(inserted_id)) for inserted_id in inserted_ids]
    )
@@ -144,12 +144,11 @@ class AnnotateTracesResponseBody(ResponseBody[list[InsertedTraceAnnotation]]):
144
144
  responses=add_errors_to_responses(
145
145
  [{"status_code": HTTP_404_NOT_FOUND, "description": "Trace not found"}]
146
146
  ),
147
- include_in_schema=False,
148
147
  )
149
148
  async def annotate_traces(
150
149
  request: Request,
151
150
  request_body: AnnotateTracesRequestBody,
152
- sync: bool = Query(default=True, description="If true, fulfill request synchronously."),
151
+ sync: bool = Query(default=False, description="If true, fulfill request synchronously."),
153
152
  ) -> AnnotateTracesResponseBody:
154
153
  if not request_body.data:
155
154
  return AnnotateTracesResponseBody(data=[])
@@ -217,6 +217,13 @@ async def create_user(
217
217
  detail="Cannot create users with SYSTEM role",
218
218
  )
219
219
 
220
+ # TODO: Implement VIEWER role
221
+ if role == "VIEWER":
222
+ raise HTTPException(
223
+ status_code=HTTP_400_BAD_REQUEST,
224
+ detail="VIEWER role not yet implemented",
225
+ )
226
+
220
227
  user: models.User
221
228
  if isinstance(user_data, LocalUserData):
222
229
  password = (user_data.password or secrets.token_hex()).strip()
@@ -26,6 +26,7 @@ from typing_extensions import TypeAlias, assert_never
26
26
  from phoenix.config import PLAYGROUND_PROJECT_NAME
27
27
  from phoenix.datetime_utils import local_now, normalize_datetime
28
28
  from phoenix.db import models
29
+ from phoenix.db.helpers import insert_experiment_with_examples_snapshot
29
30
  from phoenix.server.api.auth import IsLocked, IsNotReadOnly
30
31
  from phoenix.server.api.context import Context
31
32
  from phoenix.server.api.exceptions import BadRequest, CustomGraphQLError, NotFound
@@ -43,6 +44,7 @@ from phoenix.server.api.helpers.playground_spans import (
43
44
  get_db_trace,
44
45
  streaming_llm_span,
45
46
  )
47
+ from phoenix.server.api.helpers.playground_users import get_user
46
48
  from phoenix.server.api.helpers.prompts.models import PromptTemplateFormat
47
49
  from phoenix.server.api.input_types.ChatCompletionInput import (
48
50
  ChatCompletionInput,
@@ -302,6 +304,7 @@ class Subscription:
302
304
  description="Traces from prompt playground",
303
305
  )
304
306
  )
307
+ user_id = get_user(info)
305
308
  experiment = models.Experiment(
306
309
  dataset_id=from_global_id_with_expected_type(input.dataset_id, Dataset.__name__),
307
310
  dataset_version_id=resolved_version_id,
@@ -311,9 +314,9 @@ class Subscription:
311
314
  repetitions=input.repetitions,
312
315
  metadata_=input.experiment_metadata or dict(),
313
316
  project_name=project_name,
317
+ user_id=user_id,
314
318
  )
315
- session.add(experiment)
316
- await session.flush()
319
+ await insert_experiment_with_examples_snapshot(session, experiment)
317
320
  yield ChatCompletionSubscriptionExperiment(
318
321
  experiment=to_gql_experiment(experiment)
319
322
  ) # eagerly yields experiment so it can be linked by consumers of the subscription
@@ -12,6 +12,7 @@ from phoenix.db import models
12
12
  from phoenix.server.api.context import Context
13
13
  from phoenix.server.api.exceptions import BadRequest
14
14
  from phoenix.server.api.types.DatasetExampleRevision import DatasetExampleRevision
15
+ from phoenix.server.api.types.DatasetSplit import DatasetSplit, to_gql_dataset_split
15
16
  from phoenix.server.api.types.DatasetVersion import DatasetVersion
16
17
  from phoenix.server.api.types.ExperimentRepeatedRunGroup import (
17
18
  ExperimentRepeatedRunGroup,
@@ -131,3 +132,13 @@ class DatasetExample(Node):
131
132
  )
132
133
  for group in repeated_run_groups
133
134
  ]
135
+
136
@strawberry.field
async def dataset_splits(
    self,
    info: Info[Context, None],
) -> list[DatasetSplit]:
    # Resolve the splits for this example through the
    # `dataset_example_splits` data loader, keyed by this node's id, and
    # convert each loaded row into its GraphQL type.
    loader = info.context.data_loaders.dataset_example_splits
    split_rows = await loader.load(self.id_attr)
    return [to_gql_dataset_split(row) for row in split_rows]
@@ -0,0 +1,32 @@
1
+ from datetime import datetime
2
+ from typing import ClassVar, Optional
3
+
4
+ import strawberry
5
+ from strawberry.relay import Node, NodeID
6
+ from strawberry.scalars import JSON
7
+
8
+ from phoenix.db import models
9
+
10
+
11
@strawberry.type
class DatasetSplit(Node):
    # GraphQL node type for a dataset split. Field order is kept as declared
    # because it determines the order of fields in the generated schema.

    # ORM model this node type corresponds to.
    _table: ClassVar[type[models.Base]] = models.DatasetSplit
    # Integer id used as the relay node id.
    id_attr: NodeID[int]
    name: str
    description: Optional[str]
    metadata: JSON
    color: str
    created_at: datetime
    updated_at: datetime
21
+
22
+
23
def to_gql_dataset_split(dataset_split: models.DatasetSplit) -> DatasetSplit:
    """Convert a ``models.DatasetSplit`` row into its GraphQL ``DatasetSplit`` node.

    Every field is copied across unchanged except ``color``, which falls back
    to ``"#ffffff"`` when the row's color is missing or empty.
    """
    color = dataset_split.color if dataset_split.color else "#ffffff"
    return DatasetSplit(
        id_attr=dataset_split.id,
        name=dataset_split.name,
        description=dataset_split.description,
        color=color,
        metadata=dataset_split.metadata_,
        created_at=dataset_split.created_at,
        updated_at=dataset_split.updated_at,
    )
@@ -193,10 +193,6 @@ class Experiment(Node):
193
193
  async for token_type, is_prompt, cost, tokens in data
194
194
  ]
195
195
 
196
- @strawberry.field
197
- async def repetition_count(self, info: Info[Context, None]) -> int:
198
- return await info.context.data_loaders.experiment_repetition_counts.load(self.id_attr)
199
-
200
196
 
201
197
  def to_gql_experiment(
202
198
  experiment: models.Experiment,
@@ -588,6 +588,22 @@ class Project(Node):
588
588
  async with info.context.db() as session:
589
589
  return list(await session.scalars(stmt))
590
590
 
591
@strawberry.field(
    description="Names of all available annotations for sessions. "
    "(The list contains no duplicates.)"
)  # type: ignore
async def session_annotation_names(
    self,
    info: Info[Context, None],
) -> list[str]:
    # Select the distinct annotation names across all sessions whose
    # project_id equals this project's row id.
    distinct_names = (
        select(distinct(models.ProjectSessionAnnotation.name))
        .join(models.ProjectSession)
        .where(models.ProjectSession.project_id == self.project_rowid)
    )
    async with info.context.db() as db:
        names = await db.scalars(distinct_names)
        return list(names)
606
+
591
607
  @strawberry.field(
592
608
  description="Names of available document evaluations.",
593
609
  ) # type: ignore