arize-phoenix 11.38.0__py3-none-any.whl → 12.2.0__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic. Click here for more details.

Files changed (84)
  1. {arize_phoenix-11.38.0.dist-info → arize_phoenix-12.2.0.dist-info}/METADATA +3 -3
  2. {arize_phoenix-11.38.0.dist-info → arize_phoenix-12.2.0.dist-info}/RECORD +83 -58
  3. phoenix/config.py +1 -11
  4. phoenix/db/bulk_inserter.py +8 -0
  5. phoenix/db/facilitator.py +1 -1
  6. phoenix/db/helpers.py +202 -33
  7. phoenix/db/insertion/dataset.py +7 -0
  8. phoenix/db/insertion/document_annotation.py +1 -1
  9. phoenix/db/insertion/helpers.py +2 -2
  10. phoenix/db/insertion/session_annotation.py +176 -0
  11. phoenix/db/insertion/span_annotation.py +1 -1
  12. phoenix/db/insertion/trace_annotation.py +1 -1
  13. phoenix/db/insertion/types.py +29 -3
  14. phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
  15. phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
  16. phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
  17. phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
  18. phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
  19. phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
  20. phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
  21. phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
  22. phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
  23. phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
  24. phoenix/db/models.py +306 -46
  25. phoenix/server/api/context.py +15 -2
  26. phoenix/server/api/dataloaders/__init__.py +8 -2
  27. phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
  28. phoenix/server/api/dataloaders/dataset_labels.py +36 -0
  29. phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
  30. phoenix/server/api/dataloaders/table_fields.py +2 -2
  31. phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
  32. phoenix/server/api/helpers/playground_clients.py +66 -35
  33. phoenix/server/api/helpers/playground_users.py +26 -0
  34. phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
  35. phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
  36. phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
  37. phoenix/server/api/mutations/__init__.py +8 -0
  38. phoenix/server/api/mutations/chat_mutations.py +8 -3
  39. phoenix/server/api/mutations/dataset_label_mutations.py +291 -0
  40. phoenix/server/api/mutations/dataset_mutations.py +5 -0
  41. phoenix/server/api/mutations/dataset_split_mutations.py +423 -0
  42. phoenix/server/api/mutations/project_session_annotations_mutations.py +161 -0
  43. phoenix/server/api/queries.py +53 -0
  44. phoenix/server/api/routers/auth.py +5 -5
  45. phoenix/server/api/routers/oauth2.py +5 -23
  46. phoenix/server/api/routers/v1/__init__.py +2 -0
  47. phoenix/server/api/routers/v1/annotations.py +320 -0
  48. phoenix/server/api/routers/v1/datasets.py +5 -0
  49. phoenix/server/api/routers/v1/experiments.py +10 -3
  50. phoenix/server/api/routers/v1/sessions.py +111 -0
  51. phoenix/server/api/routers/v1/traces.py +1 -2
  52. phoenix/server/api/routers/v1/users.py +7 -0
  53. phoenix/server/api/subscriptions.py +5 -2
  54. phoenix/server/api/types/Dataset.py +8 -0
  55. phoenix/server/api/types/DatasetExample.py +18 -0
  56. phoenix/server/api/types/DatasetLabel.py +23 -0
  57. phoenix/server/api/types/DatasetSplit.py +32 -0
  58. phoenix/server/api/types/Experiment.py +0 -4
  59. phoenix/server/api/types/Project.py +16 -0
  60. phoenix/server/api/types/ProjectSession.py +88 -3
  61. phoenix/server/api/types/ProjectSessionAnnotation.py +68 -0
  62. phoenix/server/api/types/Prompt.py +18 -1
  63. phoenix/server/api/types/Span.py +5 -5
  64. phoenix/server/api/types/Trace.py +61 -0
  65. phoenix/server/app.py +13 -14
  66. phoenix/server/cost_tracking/model_cost_manifest.json +132 -2
  67. phoenix/server/dml_event.py +13 -0
  68. phoenix/server/static/.vite/manifest.json +39 -39
  69. phoenix/server/static/assets/{components-BQPHTBfv.js → components-BG6v0EM8.js} +705 -385
  70. phoenix/server/static/assets/{index-BL5BMgJU.js → index-CSVcULw1.js} +13 -13
  71. phoenix/server/static/assets/{pages-C0Y17J0T.js → pages-DgaM7kpM.js} +1356 -1155
  72. phoenix/server/static/assets/{vendor-BdjZxMii.js → vendor-BqTEkGQU.js} +183 -183
  73. phoenix/server/static/assets/{vendor-arizeai-CHYlS8jV.js → vendor-arizeai-DlOj0PQQ.js} +15 -24
  74. phoenix/server/static/assets/{vendor-codemirror-Di6t4HnH.js → vendor-codemirror-B2PHH5yZ.js} +3 -3
  75. phoenix/server/static/assets/{vendor-recharts-C9wCDYj3.js → vendor-recharts-CKsi4IjN.js} +1 -1
  76. phoenix/server/static/assets/{vendor-shiki-MNnmOotP.js → vendor-shiki-DN26BkKE.js} +1 -1
  77. phoenix/server/utils.py +74 -0
  78. phoenix/session/session.py +25 -5
  79. phoenix/version.py +1 -1
  80. phoenix/server/api/dataloaders/experiment_repetition_counts.py +0 -39
  81. {arize_phoenix-11.38.0.dist-info → arize_phoenix-12.2.0.dist-info}/WHEEL +0 -0
  82. {arize_phoenix-11.38.0.dist-info → arize_phoenix-12.2.0.dist-info}/entry_points.txt +0 -0
  83. {arize_phoenix-11.38.0.dist-info → arize_phoenix-12.2.0.dist-info}/licenses/IP_NOTICE +0 -0
  84. {arize_phoenix-11.38.0.dist-info → arize_phoenix-12.2.0.dist-info}/licenses/LICENSE +0 -0
@@ -2,7 +2,6 @@ import asyncio
2
2
  import secrets
3
3
  from datetime import datetime, timedelta, timezone
4
4
  from functools import partial
5
- from pathlib import Path
6
5
  from urllib.parse import urlencode, urlparse, urlunparse
7
6
 
8
7
  from fastapi import APIRouter, Depends, HTTPException, Request, Response
@@ -38,7 +37,6 @@ from phoenix.config import (
38
37
  get_base_url,
39
38
  get_env_disable_basic_auth,
40
39
  get_env_disable_rate_limit,
41
- get_env_host_root_path,
42
40
  )
43
41
  from phoenix.db import models
44
42
  from phoenix.server.bearer_auth import PhoenixUser, create_access_and_refresh_tokens
@@ -52,6 +50,7 @@ from phoenix.server.types import (
52
50
  TokenStore,
53
51
  UserId,
54
52
  )
53
+ from phoenix.server.utils import prepend_root_path
55
54
 
56
55
  rate_limiter = ServerRateLimiter(
57
56
  per_second_rate_limit=0.2,
@@ -145,7 +144,8 @@ async def logout(
145
144
  user_id = subject
146
145
  if user_id:
147
146
  await token_store.log_out(user_id)
148
- redirect_url = "/logout" if get_env_disable_basic_auth() else "/login"
147
+ redirect_path = "/logout" if get_env_disable_basic_auth() else "/login"
148
+ redirect_url = prepend_root_path(request.scope, redirect_path)
149
149
  response = Response(status_code=HTTP_302_FOUND, headers={"Location": redirect_url})
150
150
  response = delete_access_token_cookie(response)
151
151
  response = delete_refresh_token_cookie(response)
@@ -242,9 +242,9 @@ async def initiate_password_reset(request: Request) -> Response:
242
242
  )
243
243
  token, _ = await token_store.create_password_reset_token(password_reset_token_claims)
244
244
  url = urlparse(request.headers.get("referer") or get_base_url())
245
- path = Path(get_env_host_root_path()) / "reset-password-with-token"
245
+ path = prepend_root_path(request.scope, "/reset-password-with-token")
246
246
  query_string = urlencode(dict(token=token))
247
- components = (url.scheme, url.netloc, path.as_posix(), "", query_string, "")
247
+ components = (url.scheme, url.netloc, path, "", query_string, "")
248
248
  reset_url = urlunparse(components)
249
249
  await sender.send_password_reset_email(email, reset_url)
250
250
  return Response(status_code=HTTP_204_NO_CONTENT)
@@ -46,6 +46,7 @@ from phoenix.server.rate_limiters import (
46
46
  fastapi_route_rate_limiter,
47
47
  )
48
48
  from phoenix.server.types import TokenStore
49
+ from phoenix.server.utils import get_root_path, prepend_root_path
49
50
 
50
51
  _LOWERCASE_ALPHANUMS_AND_UNDERSCORES = r"[a-z0-9_]+"
51
52
 
@@ -191,7 +192,7 @@ async def create_tokens(
191
192
  access_token_expiry=access_token_expiry,
192
193
  refresh_token_expiry=refresh_token_expiry,
193
194
  )
194
- redirect_path = _prepend_root_path_if_exists(request=request, path=return_url or "/")
195
+ redirect_path = prepend_root_path(request.scope, return_url or "/")
195
196
  response = RedirectResponse(
196
197
  url=redirect_path,
197
198
  status_code=HTTP_302_FOUND,
@@ -564,8 +565,8 @@ def _redirect_to_login(*, request: Request, error: str) -> RedirectResponse:
564
565
  Creates a RedirectResponse to the login page to display an error message.
565
566
  """
566
567
  # TODO: this needs some cleanup
567
- login_path = _prepend_root_path_if_exists(
568
- request=request, path="/login" if not get_env_disable_basic_auth() else "/logout"
568
+ login_path = prepend_root_path(
569
+ request.scope, "/login" if not get_env_disable_basic_auth() else "/logout"
569
570
  )
570
571
  url = URL(login_path).include_query_params(error=error)
571
572
  response = RedirectResponse(url=url)
@@ -574,34 +575,15 @@ def _redirect_to_login(*, request: Request, error: str) -> RedirectResponse:
574
575
  return response
575
576
 
576
577
 
577
- def _prepend_root_path_if_exists(*, request: Request, path: str) -> str:
578
- """
579
- If a root path is configured, prepends it to the input path.
580
- """
581
- if not path.startswith("/"):
582
- raise ValueError("path must start with a forward slash")
583
- root_path = _get_root_path(request=request)
584
- if root_path.endswith("/"):
585
- root_path = root_path.rstrip("/")
586
- return root_path + path
587
-
588
-
589
578
  def _append_root_path_if_exists(*, request: Request, base_url: str) -> str:
590
579
  """
591
580
  If a root path is configured, appends it to the input base url.
592
581
  """
593
- if not (root_path := _get_root_path(request=request)):
582
+ if not (root_path := get_root_path(request.scope)):
594
583
  return base_url
595
584
  return str(URLPath(root_path).make_absolute_url(base_url=base_url))
596
585
 
597
586
 
598
- def _get_root_path(*, request: Request) -> str:
599
- """
600
- Gets the root path from the request.
601
- """
602
- return str(request.scope.get("root_path", ""))
603
-
604
-
605
587
  def _get_create_tokens_endpoint(*, request: Request, origin_url: str, idp_name: str) -> str:
606
588
  """
607
589
  Gets the endpoint for create tokens route.
@@ -14,6 +14,7 @@ from .experiment_runs import router as experiment_runs_router
14
14
  from .experiments import router as experiments_router
15
15
  from .projects import router as projects_router
16
16
  from .prompts import router as prompts_router
17
+ from .sessions import router as sessions_router
17
18
  from .spans import router as spans_router
18
19
  from .traces import router as traces_router
19
20
  from .users import router as users_router
@@ -71,6 +72,7 @@ def create_v1_router(authentication_enabled: bool) -> APIRouter:
71
72
  router.include_router(evaluations_router)
72
73
  router.include_router(prompts_router)
73
74
  router.include_router(projects_router)
75
+ router.include_router(sessions_router)
74
76
  router.include_router(documents_router)
75
77
  router.include_router(users_router)
76
78
  return router
@@ -14,6 +14,9 @@ from strawberry.relay import GlobalID
14
14
  from phoenix.db import models
15
15
  from phoenix.db.insertion.types import Precursors
16
16
  from phoenix.server.api.routers.v1.models import V1RoutesBaseModel
17
+ from phoenix.server.api.types.ProjectSessionAnnotation import (
18
+ ProjectSessionAnnotation as SessionAnnotationNodeType,
19
+ )
17
20
  from phoenix.server.api.types.SpanAnnotation import SpanAnnotation as SpanAnnotationNodeType
18
21
  from phoenix.server.api.types.TraceAnnotation import TraceAnnotation as TraceAnnotationNodeType
19
22
  from phoenix.server.api.types.User import User as UserNodeType
@@ -24,6 +27,7 @@ logger = logging.getLogger(__name__)
24
27
 
25
28
  SPAN_ANNOTATION_NODE_NAME = SpanAnnotationNodeType.__name__
26
29
  TRACE_ANNOTATION_NODE_NAME = TraceAnnotationNodeType.__name__
30
+ SESSION_ANNOTATION_NODE_NAME = SessionAnnotationNodeType.__name__
27
31
  MAX_TRACE_IDS = 1_000
28
32
  USER_NODE_NAME = UserNodeType.__name__
29
33
  MAX_SPAN_IDS = 1_000
@@ -161,6 +165,35 @@ class TraceAnnotationsResponseBody(PaginatedResponseBody[TraceAnnotation]):
161
165
  pass
162
166
 
163
167
 
168
+ class SessionAnnotationData(AnnotationData):
169
+ session_id: str = Field(description="Session ID")
170
+
171
+ def as_precursor(self, *, user_id: Optional[int] = None) -> Precursors.SessionAnnotation:
172
+ return Precursors.SessionAnnotation(
173
+ datetime.now(timezone.utc),
174
+ self.session_id,
175
+ models.ProjectSessionAnnotation(
176
+ name=self.name,
177
+ annotator_kind=self.annotator_kind,
178
+ score=self.result.score if self.result else None,
179
+ label=self.result.label if self.result else None,
180
+ explanation=self.result.explanation if self.result else None,
181
+ metadata_=self.metadata or {},
182
+ identifier=self.identifier,
183
+ source="API",
184
+ user_id=user_id,
185
+ ),
186
+ )
187
+
188
+
189
+ class SessionAnnotation(SessionAnnotationData, Annotation):
190
+ pass
191
+
192
+
193
+ class SessionAnnotationsResponseBody(PaginatedResponseBody[SessionAnnotation]):
194
+ pass
195
+
196
+
164
197
  @router.get(
165
198
  "/projects/{project_identifier}/span_annotations",
166
199
  operation_id="listSpanAnnotationsBySpanIds",
@@ -304,3 +337,290 @@ async def list_span_annotations(
304
337
  ]
305
338
 
306
339
  return SpanAnnotationsResponseBody(data=data, next_cursor=next_cursor)
340
+
341
+
342
+ @router.get(
343
+ "/projects/{project_identifier}/trace_annotations",
344
+ operation_id="listTraceAnnotationsByTraceIds",
345
+ summary="Get trace annotations for a list of trace_ids.",
346
+ status_code=HTTP_200_OK,
347
+ responses=add_errors_to_responses(
348
+ [
349
+ {"status_code": HTTP_404_NOT_FOUND, "description": "Project or traces not found"},
350
+ {"status_code": HTTP_422_UNPROCESSABLE_ENTITY, "description": "Invalid parameters"},
351
+ ]
352
+ ),
353
+ )
354
+ async def list_trace_annotations(
355
+ request: Request,
356
+ project_identifier: str = Path(
357
+ description=(
358
+ "The project identifier: either project ID or project name. If using a project name as "
359
+ "the identifier, it cannot contain slash (/), question mark (?), or pound sign (#) "
360
+ "characters."
361
+ )
362
+ ),
363
+ trace_ids: list[str] = Query(
364
+ ..., min_length=1, description="One or more trace id to fetch annotations for"
365
+ ),
366
+ include_annotation_names: Optional[list[str]] = Query(
367
+ default=None,
368
+ description=(
369
+ "Optional list of annotation names to include. If provided, only annotations with "
370
+ "these names will be returned. 'note' annotations are excluded by default unless "
371
+ "explicitly included in this list."
372
+ ),
373
+ ),
374
+ exclude_annotation_names: Optional[list[str]] = Query(
375
+ default=None, description="Optional list of annotation names to exclude from results."
376
+ ),
377
+ cursor: Optional[str] = Query(default=None, description="A cursor for pagination"),
378
+ limit: int = Query(
379
+ default=10,
380
+ gt=0,
381
+ le=10000,
382
+ description="The maximum number of annotations to return in a single request",
383
+ ),
384
+ ) -> TraceAnnotationsResponseBody:
385
+ trace_ids = list({*trace_ids})
386
+ if len(trace_ids) > MAX_TRACE_IDS:
387
+ raise HTTPException(
388
+ status_code=HTTP_422_UNPROCESSABLE_ENTITY,
389
+ detail=f"Too many trace_ids supplied: {len(trace_ids)} (max {MAX_TRACE_IDS})",
390
+ )
391
+
392
+ async with request.app.state.db() as session:
393
+ project = await _get_project_by_identifier(session, project_identifier)
394
+ if not project:
395
+ raise HTTPException(
396
+ status_code=HTTP_404_NOT_FOUND,
397
+ detail=f"Project with identifier {project_identifier} not found",
398
+ )
399
+
400
+ # Build the base query
401
+ where_conditions = [
402
+ models.Project.id == project.id,
403
+ models.Trace.trace_id.in_(trace_ids),
404
+ ]
405
+
406
+ # Add annotation name filtering
407
+ if include_annotation_names:
408
+ where_conditions.append(models.TraceAnnotation.name.in_(include_annotation_names))
409
+
410
+ if exclude_annotation_names:
411
+ where_conditions.append(models.TraceAnnotation.name.not_in(exclude_annotation_names))
412
+
413
+ stmt = (
414
+ select(models.Trace.trace_id, models.TraceAnnotation)
415
+ .join(models.Project, models.Trace.project_rowid == models.Project.id)
416
+ .join(models.TraceAnnotation, models.TraceAnnotation.trace_rowid == models.Trace.id)
417
+ .where(*where_conditions)
418
+ .order_by(models.TraceAnnotation.id.desc())
419
+ .limit(limit + 1)
420
+ )
421
+
422
+ if cursor:
423
+ try:
424
+ cursor_id = int(GlobalID.from_id(cursor).node_id)
425
+ except ValueError:
426
+ raise HTTPException(
427
+ status_code=HTTP_422_UNPROCESSABLE_ENTITY,
428
+ detail="Invalid cursor value",
429
+ )
430
+ stmt = stmt.where(models.TraceAnnotation.id <= cursor_id)
431
+
432
+ rows: list[tuple[str, models.TraceAnnotation]] = [
433
+ r async for r in (await session.stream(stmt))
434
+ ]
435
+
436
+ next_cursor: Optional[str] = None
437
+ if len(rows) == limit + 1:
438
+ *rows, extra = rows
439
+ next_cursor = str(GlobalID(TRACE_ANNOTATION_NODE_NAME, str(extra[1].id)))
440
+
441
+ if not rows:
442
+ traces_exist = await session.scalar(
443
+ select(
444
+ exists().where(
445
+ models.Trace.trace_id.in_(trace_ids),
446
+ models.Trace.project_rowid == project.id,
447
+ )
448
+ )
449
+ )
450
+ if not traces_exist:
451
+ raise HTTPException(
452
+ detail="None of the supplied trace_ids exist in this project",
453
+ status_code=HTTP_404_NOT_FOUND,
454
+ )
455
+
456
+ return TraceAnnotationsResponseBody(data=[], next_cursor=None)
457
+
458
+ data = [
459
+ TraceAnnotation(
460
+ id=str(GlobalID(TRACE_ANNOTATION_NODE_NAME, str(anno.id))),
461
+ trace_id=trace_id,
462
+ name=anno.name,
463
+ result=AnnotationResult(
464
+ label=anno.label,
465
+ score=anno.score,
466
+ explanation=anno.explanation,
467
+ ),
468
+ metadata=anno.metadata_,
469
+ annotator_kind=anno.annotator_kind,
470
+ created_at=anno.created_at,
471
+ updated_at=anno.updated_at,
472
+ identifier=anno.identifier,
473
+ source=anno.source,
474
+ user_id=str(GlobalID("User", str(anno.user_id))) if anno.user_id else None,
475
+ )
476
+ for trace_id, anno in rows
477
+ ]
478
+
479
+ return TraceAnnotationsResponseBody(data=data, next_cursor=next_cursor)
480
+
481
+
482
+ @router.get(
483
+ "/projects/{project_identifier}/session_annotations",
484
+ operation_id="listSessionAnnotationsBySessionIds",
485
+ summary="Get session annotations for a list of session_ids.",
486
+ status_code=HTTP_200_OK,
487
+ responses=add_errors_to_responses(
488
+ [
489
+ {"status_code": HTTP_404_NOT_FOUND, "description": "Project or sessions not found"},
490
+ {"status_code": HTTP_422_UNPROCESSABLE_ENTITY, "description": "Invalid parameters"},
491
+ ]
492
+ ),
493
+ )
494
+ async def list_session_annotations(
495
+ request: Request,
496
+ project_identifier: str = Path(
497
+ description=(
498
+ "The project identifier: either project ID or project name. If using a project name as "
499
+ "the identifier, it cannot contain slash (/), question mark (?), or pound sign (#) "
500
+ "characters."
501
+ )
502
+ ),
503
+ session_ids: list[str] = Query(
504
+ ..., min_length=1, description="One or more session id to fetch annotations for"
505
+ ),
506
+ include_annotation_names: Optional[list[str]] = Query(
507
+ default=None,
508
+ description=(
509
+ "Optional list of annotation names to include. If provided, only annotations with "
510
+ "these names will be returned. 'note' annotations are excluded by default unless "
511
+ "explicitly included in this list."
512
+ ),
513
+ ),
514
+ exclude_annotation_names: Optional[list[str]] = Query(
515
+ default=None, description="Optional list of annotation names to exclude from results."
516
+ ),
517
+ cursor: Optional[str] = Query(default=None, description="A cursor for pagination"),
518
+ limit: int = Query(
519
+ default=10,
520
+ gt=0,
521
+ le=10000,
522
+ description="The maximum number of annotations to return in a single request",
523
+ ),
524
+ ) -> SessionAnnotationsResponseBody:
525
+ session_ids = list({*session_ids})
526
+ if len(session_ids) > MAX_SESSION_IDS:
527
+ raise HTTPException(
528
+ status_code=HTTP_422_UNPROCESSABLE_ENTITY,
529
+ detail=f"Too many session_ids supplied: {len(session_ids)} (max {MAX_SESSION_IDS})",
530
+ )
531
+
532
+ async with request.app.state.db() as session:
533
+ project = await _get_project_by_identifier(session, project_identifier)
534
+ if not project:
535
+ raise HTTPException(
536
+ status_code=HTTP_404_NOT_FOUND,
537
+ detail=f"Project with identifier {project_identifier} not found",
538
+ )
539
+
540
+ # Build the base query
541
+ where_conditions = [
542
+ models.Project.id == project.id,
543
+ models.ProjectSession.session_id.in_(session_ids),
544
+ ]
545
+
546
+ # Add annotation name filtering
547
+ if include_annotation_names:
548
+ where_conditions.append(
549
+ models.ProjectSessionAnnotation.name.in_(include_annotation_names)
550
+ )
551
+
552
+ if exclude_annotation_names:
553
+ where_conditions.append(
554
+ models.ProjectSessionAnnotation.name.not_in(exclude_annotation_names)
555
+ )
556
+
557
+ stmt = (
558
+ select(models.ProjectSession.session_id, models.ProjectSessionAnnotation)
559
+ .join(models.Project, models.ProjectSession.project_id == models.Project.id)
560
+ .join(
561
+ models.ProjectSessionAnnotation,
562
+ models.ProjectSessionAnnotation.project_session_id == models.ProjectSession.id,
563
+ )
564
+ .where(*where_conditions)
565
+ .order_by(models.ProjectSessionAnnotation.id.desc())
566
+ .limit(limit + 1)
567
+ )
568
+
569
+ if cursor:
570
+ try:
571
+ cursor_id = int(GlobalID.from_id(cursor).node_id)
572
+ except ValueError:
573
+ raise HTTPException(
574
+ status_code=HTTP_422_UNPROCESSABLE_ENTITY,
575
+ detail="Invalid cursor value",
576
+ )
577
+ stmt = stmt.where(models.ProjectSessionAnnotation.id <= cursor_id)
578
+
579
+ rows: list[tuple[str, models.ProjectSessionAnnotation]] = [
580
+ r async for r in (await session.stream(stmt))
581
+ ]
582
+
583
+ next_cursor: Optional[str] = None
584
+ if len(rows) == limit + 1:
585
+ *rows, extra = rows
586
+ next_cursor = str(GlobalID(SESSION_ANNOTATION_NODE_NAME, str(extra[1].id)))
587
+
588
+ if not rows:
589
+ sessions_exist = await session.scalar(
590
+ select(
591
+ exists().where(
592
+ models.ProjectSession.session_id.in_(session_ids),
593
+ models.ProjectSession.project_id == project.id,
594
+ )
595
+ )
596
+ )
597
+ if not sessions_exist:
598
+ raise HTTPException(
599
+ detail="None of the supplied session_ids exist in this project",
600
+ status_code=HTTP_404_NOT_FOUND,
601
+ )
602
+
603
+ return SessionAnnotationsResponseBody(data=[], next_cursor=None)
604
+
605
+ data = [
606
+ SessionAnnotation(
607
+ id=str(GlobalID(SESSION_ANNOTATION_NODE_NAME, str(anno.id))),
608
+ session_id=session_id,
609
+ name=anno.name,
610
+ result=AnnotationResult(
611
+ label=anno.label,
612
+ score=anno.score,
613
+ explanation=anno.explanation,
614
+ ),
615
+ metadata=anno.metadata_,
616
+ annotator_kind=anno.annotator_kind,
617
+ created_at=anno.created_at,
618
+ updated_at=anno.updated_at,
619
+ identifier=anno.identifier,
620
+ source=anno.source,
621
+ user_id=str(GlobalID(USER_NODE_NAME, str(anno.user_id))) if anno.user_id else None,
622
+ )
623
+ for session_id, anno in rows
624
+ ]
625
+
626
+ return SessionAnnotationsResponseBody(data=data, next_cursor=next_cursor)
@@ -48,6 +48,7 @@ from phoenix.server.api.types.DatasetVersion import DatasetVersion as DatasetVer
48
48
  from phoenix.server.api.types.node import from_global_id_with_expected_type
49
49
  from phoenix.server.api.utils import delete_projects, delete_traces
50
50
  from phoenix.server.authorization import is_not_locked
51
+ from phoenix.server.bearer_auth import PhoenixUser
51
52
  from phoenix.server.dml_event import DatasetInsertEvent
52
53
 
53
54
  from .models import V1RoutesBaseModel
@@ -478,6 +479,9 @@ async def upload_dataset(
478
479
  detail="Invalid request Content-Type",
479
480
  status_code=HTTP_422_UNPROCESSABLE_ENTITY,
480
481
  )
482
+ user_id: Optional[int] = None
483
+ if request.app.state.authentication_enabled and isinstance(request.user, PhoenixUser):
484
+ user_id = int(request.user.identity)
481
485
  operation = cast(
482
486
  Callable[[AsyncSession], Awaitable[DatasetExampleAdditionEvent]],
483
487
  partial(
@@ -486,6 +490,7 @@ async def upload_dataset(
486
490
  action=action,
487
491
  name=name,
488
492
  description=description,
493
+ user_id=user_id,
489
494
  ),
490
495
  )
491
496
  if sync:
@@ -15,10 +15,14 @@ from starlette.status import HTTP_200_OK, HTTP_404_NOT_FOUND, HTTP_422_UNPROCESS
15
15
  from strawberry.relay import GlobalID
16
16
 
17
17
  from phoenix.db import models
18
- from phoenix.db.helpers import SupportedSQLDialect
18
+ from phoenix.db.helpers import (
19
+ SupportedSQLDialect,
20
+ insert_experiment_with_examples_snapshot,
21
+ )
19
22
  from phoenix.db.insertion.helpers import insert_on_conflict
20
23
  from phoenix.server.api.types.node import from_global_id_with_expected_type
21
24
  from phoenix.server.authorization import is_not_locked
25
+ from phoenix.server.bearer_auth import PhoenixUser
22
26
  from phoenix.server.dml_event import ExperimentInsertEvent
23
27
  from phoenix.server.experiments.utils import generate_experiment_project_name
24
28
 
@@ -157,6 +161,9 @@ async def create_experiment(
157
161
  detail=f"DatasetVersion with ID {dataset_version_globalid} does not exist",
158
162
  status_code=HTTP_404_NOT_FOUND,
159
163
  )
164
+ user_id: Optional[int] = None
165
+ if request.app.state.authentication_enabled and isinstance(request.user, PhoenixUser):
166
+ user_id = int(request.user.identity)
160
167
 
161
168
  # generate a semi-unique name for the experiment
162
169
  experiment_name = request_body.name or _generate_experiment_name(dataset_name)
@@ -172,9 +179,9 @@ async def create_experiment(
172
179
  repetitions=request_body.repetitions,
173
180
  metadata_=request_body.metadata or {},
174
181
  project_name=project_name,
182
+ user_id=user_id,
175
183
  )
176
- session.add(experiment)
177
- await session.flush()
184
+ await insert_experiment_with_examples_snapshot(session, experiment)
178
185
 
179
186
  dialect = SupportedSQLDialect(session.bind.dialect.name)
180
187
  project_rowid = await session.scalar(
@@ -0,0 +1,111 @@
1
+ from __future__ import annotations
2
+
3
+ import warnings
4
+ from typing import Optional
5
+
6
+ from fastapi import APIRouter, Depends, HTTPException, Query
7
+ from pydantic import Field
8
+ from sqlalchemy import select
9
+ from starlette.requests import Request
10
+ from starlette.status import HTTP_404_NOT_FOUND
11
+
12
+ from phoenix.db import models
13
+ from phoenix.db.helpers import SupportedSQLDialect
14
+ from phoenix.db.insertion.helpers import as_kv, insert_on_conflict
15
+ from phoenix.server.api.routers.v1.models import V1RoutesBaseModel
16
+ from phoenix.server.authorization import is_not_locked
17
+ from phoenix.server.bearer_auth import PhoenixUser
18
+
19
+ from .annotations import SessionAnnotationData
20
+ from .utils import RequestBody, ResponseBody, add_errors_to_responses
21
+
22
+ router = APIRouter(tags=["sessions"])
23
+
24
+
25
+ class InsertedSessionAnnotation(V1RoutesBaseModel):
26
+ id: str = Field(description="The ID of the inserted session annotation")
27
+
28
+
29
+ class AnnotateSessionsRequestBody(RequestBody[list[SessionAnnotationData]]):
30
+ pass
31
+
32
+
33
+ class AnnotateSessionsResponseBody(ResponseBody[list[InsertedSessionAnnotation]]):
34
+ pass
35
+
36
+
37
+ @router.post(
38
+ "/session_annotations",
39
+ dependencies=[Depends(is_not_locked)],
40
+ operation_id="annotateSessions",
41
+ summary="Create session annotations",
42
+ responses=add_errors_to_responses(
43
+ [{"status_code": HTTP_404_NOT_FOUND, "description": "Session not found"}]
44
+ ),
45
+ response_description="Session annotations inserted successfully",
46
+ include_in_schema=True,
47
+ )
48
+ async def annotate_sessions(
49
+ request: Request,
50
+ request_body: AnnotateSessionsRequestBody,
51
+ sync: bool = Query(default=False, description="If true, fulfill request synchronously."),
52
+ ) -> AnnotateSessionsResponseBody:
53
+ if not request_body.data:
54
+ return AnnotateSessionsResponseBody(data=[])
55
+
56
+ user_id: Optional[int] = None
57
+ if request.app.state.authentication_enabled and isinstance(request.user, PhoenixUser):
58
+ user_id = int(request.user.identity)
59
+
60
+ session_annotations = request_body.data
61
+ filtered_session_annotations = list(filter(lambda d: d.name != "note", session_annotations))
62
+ if len(filtered_session_annotations) != len(session_annotations):
63
+ warnings.warn(
64
+ (
65
+ "Session annotations with the name 'note' are not supported in this endpoint. "
66
+ "They will be ignored."
67
+ ),
68
+ UserWarning,
69
+ )
70
+ precursors = [d.as_precursor(user_id=user_id) for d in filtered_session_annotations]
71
+ if not sync:
72
+ await request.state.enqueue_annotations(*precursors)
73
+ return AnnotateSessionsResponseBody(data=[])
74
+
75
+ session_ids = {p.session_id for p in precursors}
76
+ async with request.app.state.db() as session:
77
+ existing_sessions = {
78
+ session_id: rowid
79
+ async for session_id, rowid in await session.stream(
80
+ select(models.ProjectSession.session_id, models.ProjectSession.id).filter(
81
+ models.ProjectSession.session_id.in_(session_ids)
82
+ )
83
+ )
84
+ }
85
+
86
+ missing_session_ids = session_ids - set(existing_sessions.keys())
87
+ # We prefer to fail the entire operation if there are missing sessions in sync mode
88
+ if missing_session_ids:
89
+ raise HTTPException(
90
+ detail=f"Sessions with IDs {', '.join(missing_session_ids)} do not exist.",
91
+ status_code=HTTP_404_NOT_FOUND,
92
+ )
93
+
94
+ async with request.app.state.db() as session:
95
+ inserted_ids = []
96
+ dialect = SupportedSQLDialect(session.bind.dialect.name)
97
+ for p in precursors:
98
+ values = dict(as_kv(p.as_insertable(existing_sessions[p.session_id]).row))
99
+ session_annotation_id = await session.scalar(
100
+ insert_on_conflict(
101
+ values,
102
+ dialect=dialect,
103
+ table=models.ProjectSessionAnnotation,
104
+ unique_by=("name", "project_session_id", "identifier"),
105
+ ).returning(models.ProjectSessionAnnotation.id)
106
+ )
107
+ inserted_ids.append(session_annotation_id)
108
+
109
+ return AnnotateSessionsResponseBody(
110
+ data=[InsertedSessionAnnotation(id=str(inserted_id)) for inserted_id in inserted_ids]
111
+ )
@@ -144,12 +144,11 @@ class AnnotateTracesResponseBody(ResponseBody[list[InsertedTraceAnnotation]]):
144
144
  responses=add_errors_to_responses(
145
145
  [{"status_code": HTTP_404_NOT_FOUND, "description": "Trace not found"}]
146
146
  ),
147
- include_in_schema=False,
148
147
  )
149
148
  async def annotate_traces(
150
149
  request: Request,
151
150
  request_body: AnnotateTracesRequestBody,
152
- sync: bool = Query(default=True, description="If true, fulfill request synchronously."),
151
+ sync: bool = Query(default=False, description="If true, fulfill request synchronously."),
153
152
  ) -> AnnotateTracesResponseBody:
154
153
  if not request_body.data:
155
154
  return AnnotateTracesResponseBody(data=[])
@@ -217,6 +217,13 @@ async def create_user(
217
217
  detail="Cannot create users with SYSTEM role",
218
218
  )
219
219
 
220
+ # TODO: Implement VIEWER role
221
+ if role == "VIEWER":
222
+ raise HTTPException(
223
+ status_code=HTTP_400_BAD_REQUEST,
224
+ detail="VIEWER role not yet implemented",
225
+ )
226
+
220
227
  user: models.User
221
228
  if isinstance(user_data, LocalUserData):
222
229
  password = (user_data.password or secrets.token_hex()).strip()