arize-phoenix: arize_phoenix-11.37.0-py3-none-any.whl → arize_phoenix-12.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic. More details are linked from the original registry diff page.

Files changed (75)
  1. {arize_phoenix-11.37.0.dist-info → arize_phoenix-12.0.0.dist-info}/METADATA +3 -3
  2. {arize_phoenix-11.37.0.dist-info → arize_phoenix-12.0.0.dist-info}/RECORD +74 -53
  3. phoenix/config.py +1 -11
  4. phoenix/db/bulk_inserter.py +8 -0
  5. phoenix/db/facilitator.py +1 -1
  6. phoenix/db/helpers.py +202 -33
  7. phoenix/db/insertion/dataset.py +7 -0
  8. phoenix/db/insertion/helpers.py +2 -2
  9. phoenix/db/insertion/session_annotation.py +176 -0
  10. phoenix/db/insertion/types.py +30 -0
  11. phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
  12. phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
  13. phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
  14. phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
  15. phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
  16. phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
  17. phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
  18. phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
  19. phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
  20. phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
  21. phoenix/db/models.py +285 -46
  22. phoenix/server/api/context.py +13 -2
  23. phoenix/server/api/dataloaders/__init__.py +6 -2
  24. phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
  25. phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
  26. phoenix/server/api/dataloaders/table_fields.py +2 -2
  27. phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
  28. phoenix/server/api/helpers/playground_clients.py +65 -35
  29. phoenix/server/api/helpers/playground_spans.py +2 -1
  30. phoenix/server/api/helpers/playground_users.py +26 -0
  31. phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
  32. phoenix/server/api/input_types/ChatCompletionInput.py +2 -0
  33. phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
  34. phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
  35. phoenix/server/api/mutations/__init__.py +6 -0
  36. phoenix/server/api/mutations/chat_mutations.py +24 -9
  37. phoenix/server/api/mutations/dataset_mutations.py +5 -0
  38. phoenix/server/api/mutations/dataset_split_mutations.py +387 -0
  39. phoenix/server/api/mutations/project_session_annotations_mutations.py +161 -0
  40. phoenix/server/api/queries.py +32 -0
  41. phoenix/server/api/routers/v1/__init__.py +2 -0
  42. phoenix/server/api/routers/v1/annotations.py +320 -0
  43. phoenix/server/api/routers/v1/datasets.py +5 -0
  44. phoenix/server/api/routers/v1/experiments.py +10 -3
  45. phoenix/server/api/routers/v1/sessions.py +111 -0
  46. phoenix/server/api/routers/v1/traces.py +1 -2
  47. phoenix/server/api/routers/v1/users.py +7 -0
  48. phoenix/server/api/subscriptions.py +25 -7
  49. phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +1 -0
  50. phoenix/server/api/types/DatasetExample.py +11 -0
  51. phoenix/server/api/types/DatasetSplit.py +32 -0
  52. phoenix/server/api/types/Experiment.py +0 -4
  53. phoenix/server/api/types/Project.py +16 -0
  54. phoenix/server/api/types/ProjectSession.py +88 -3
  55. phoenix/server/api/types/ProjectSessionAnnotation.py +68 -0
  56. phoenix/server/api/types/Span.py +5 -5
  57. phoenix/server/api/types/Trace.py +61 -0
  58. phoenix/server/app.py +6 -2
  59. phoenix/server/cost_tracking/model_cost_manifest.json +132 -2
  60. phoenix/server/dml_event.py +13 -0
  61. phoenix/server/static/.vite/manifest.json +39 -39
  62. phoenix/server/static/assets/{components-CFzdBkk_.js → components-Dl9SUw1U.js} +371 -327
  63. phoenix/server/static/assets/{index-DayUA9lQ.js → index-CqQS0dTo.js} +2 -2
  64. phoenix/server/static/assets/{pages-CvUhOO9h.js → pages-DKSjVA_E.js} +771 -518
  65. phoenix/server/static/assets/{vendor-BdjZxMii.js → vendor-CtbHQYl8.js} +1 -1
  66. phoenix/server/static/assets/{vendor-arizeai-CHYlS8jV.js → vendor-arizeai-D-lWOwIS.js} +1 -1
  67. phoenix/server/static/assets/{vendor-codemirror-Di6t4HnH.js → vendor-codemirror-BRBpy3_z.js} +3 -3
  68. phoenix/server/static/assets/{vendor-recharts-C9wCDYj3.js → vendor-recharts--KdSwB3m.js} +1 -1
  69. phoenix/server/static/assets/{vendor-shiki-MNnmOotP.js → vendor-shiki-CvRzZnIo.js} +1 -1
  70. phoenix/version.py +1 -1
  71. phoenix/server/api/dataloaders/experiment_repetition_counts.py +0 -39
  72. {arize_phoenix-11.37.0.dist-info → arize_phoenix-12.0.0.dist-info}/WHEEL +0 -0
  73. {arize_phoenix-11.37.0.dist-info → arize_phoenix-12.0.0.dist-info}/entry_points.txt +0 -0
  74. {arize_phoenix-11.37.0.dist-info → arize_phoenix-12.0.0.dist-info}/licenses/IP_NOTICE +0 -0
  75. {arize_phoenix-11.37.0.dist-info → arize_phoenix-12.0.0.dist-info}/licenses/LICENSE +0 -0
phoenix/server/api/mutations/project_session_annotations_mutations.py +161 -0
@@ -0,0 +1,161 @@
1
+ from typing import Optional
2
+
3
+ import strawberry
4
+ from sqlalchemy.exc import IntegrityError as PostgreSQLIntegrityError
5
+ from sqlean.dbapi2 import IntegrityError as SQLiteIntegrityError # type: ignore[import-untyped]
6
+ from starlette.requests import Request
7
+ from strawberry import Info
8
+ from strawberry.relay import GlobalID
9
+
10
+ from phoenix.db import models
11
+ from phoenix.server.api.auth import IsLocked, IsNotReadOnly
12
+ from phoenix.server.api.context import Context
13
+ from phoenix.server.api.exceptions import BadRequest, Conflict, NotFound, Unauthorized
14
+ from phoenix.server.api.helpers.annotations import get_user_identifier
15
+ from phoenix.server.api.input_types.CreateProjectSessionAnnotationInput import (
16
+ CreateProjectSessionAnnotationInput,
17
+ )
18
+ from phoenix.server.api.input_types.UpdateAnnotationInput import UpdateAnnotationInput
19
+ from phoenix.server.api.queries import Query
20
+ from phoenix.server.api.types.AnnotationSource import AnnotationSource
21
+ from phoenix.server.api.types.node import from_global_id_with_expected_type
22
+ from phoenix.server.api.types.ProjectSessionAnnotation import (
23
+ ProjectSessionAnnotation,
24
+ to_gql_project_session_annotation,
25
+ )
26
+ from phoenix.server.bearer_auth import PhoenixUser
27
+ from phoenix.server.dml_event import (
28
+ ProjectSessionAnnotationDeleteEvent,
29
+ ProjectSessionAnnotationInsertEvent,
30
+ )
31
+
32
+
33
+ @strawberry.type
34
+ class ProjectSessionAnnotationMutationPayload:
35
+ project_session_annotation: ProjectSessionAnnotation
36
+ query: Query
37
+
38
+
39
+ @strawberry.type
40
+ class ProjectSessionAnnotationMutationMixin:
41
+ @strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked]) # type: ignore
42
+ async def create_project_session_annotations(
43
+ self, info: Info[Context, None], input: CreateProjectSessionAnnotationInput
44
+ ) -> ProjectSessionAnnotationMutationPayload:
45
+ assert isinstance(request := info.context.request, Request)
46
+ user_id: Optional[int] = None
47
+ if "user" in request.scope and isinstance((user := info.context.user), PhoenixUser):
48
+ user_id = int(user.identity)
49
+
50
+ try:
51
+ project_session_id = from_global_id_with_expected_type(
52
+ input.project_session_id, "ProjectSession"
53
+ )
54
+ except ValueError:
55
+ raise BadRequest(f"Invalid session ID: {input.project_session_id}")
56
+
57
+ identifier = ""
58
+ if isinstance(input.identifier, str):
59
+ identifier = input.identifier # Already trimmed in __post_init__
60
+ elif input.source == AnnotationSource.APP and user_id is not None:
61
+ identifier = get_user_identifier(user_id)
62
+
63
+ try:
64
+ async with info.context.db() as session:
65
+ anno = models.ProjectSessionAnnotation(
66
+ project_session_id=project_session_id,
67
+ name=input.name,
68
+ label=input.label,
69
+ score=input.score,
70
+ explanation=input.explanation,
71
+ annotator_kind=input.annotator_kind.value,
72
+ metadata_=input.metadata,
73
+ identifier=identifier,
74
+ source=input.source.value,
75
+ user_id=user_id,
76
+ )
77
+ session.add(anno)
78
+ except (PostgreSQLIntegrityError, SQLiteIntegrityError) as e:
79
+ raise Conflict(f"Error creating annotation: {e}")
80
+
81
+ info.context.event_queue.put(ProjectSessionAnnotationInsertEvent((anno.id,)))
82
+
83
+ return ProjectSessionAnnotationMutationPayload(
84
+ project_session_annotation=to_gql_project_session_annotation(anno),
85
+ query=Query(),
86
+ )
87
+
88
+ @strawberry.mutation(permission_classes=[IsNotReadOnly, IsLocked]) # type: ignore
89
+ async def update_project_session_annotations(
90
+ self, info: Info[Context, None], input: UpdateAnnotationInput
91
+ ) -> ProjectSessionAnnotationMutationPayload:
92
+ assert isinstance(request := info.context.request, Request)
93
+ user_id: Optional[int] = None
94
+ if "user" in request.scope and isinstance((user := info.context.user), PhoenixUser):
95
+ user_id = int(user.identity)
96
+
97
+ try:
98
+ id_ = from_global_id_with_expected_type(input.id, "ProjectSessionAnnotation")
99
+ except ValueError:
100
+ raise BadRequest(f"Invalid session annotation ID: {input.id}")
101
+
102
+ async with info.context.db() as session:
103
+ if not (anno := await session.get(models.ProjectSessionAnnotation, id_)):
104
+ raise NotFound(f"Could not find session annotation with ID: {input.id}")
105
+ if anno.user_id != user_id:
106
+ raise Unauthorized("Session annotation is not associated with the current user.")
107
+
108
+ # Update the annotation fields
109
+ anno.name = input.name
110
+ anno.label = input.label
111
+ anno.score = input.score
112
+ anno.explanation = input.explanation
113
+ anno.annotator_kind = input.annotator_kind.value
114
+ anno.metadata_ = input.metadata
115
+ anno.source = input.source.value
116
+
117
+ session.add(anno)
118
+ try:
119
+ await session.flush()
120
+ except (PostgreSQLIntegrityError, SQLiteIntegrityError) as e:
121
+ raise Conflict(f"Error updating annotation: {e}")
122
+
123
+ info.context.event_queue.put(ProjectSessionAnnotationInsertEvent((anno.id,)))
124
+ return ProjectSessionAnnotationMutationPayload(
125
+ project_session_annotation=to_gql_project_session_annotation(anno),
126
+ query=Query(),
127
+ )
128
+
129
+ @strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore
130
+ async def delete_project_session_annotation(
131
+ self, info: Info[Context, None], id: GlobalID
132
+ ) -> ProjectSessionAnnotationMutationPayload:
133
+ try:
134
+ id_ = from_global_id_with_expected_type(id, "ProjectSessionAnnotation")
135
+ except ValueError:
136
+ raise BadRequest(f"Invalid session annotation ID: {id}")
137
+
138
+ assert isinstance(request := info.context.request, Request)
139
+ user_id: Optional[int] = None
140
+ user_is_admin = False
141
+ if "user" in request.scope and isinstance((user := info.context.user), PhoenixUser):
142
+ user_id = int(user.identity)
143
+ user_is_admin = user.is_admin
144
+
145
+ async with info.context.db() as session:
146
+ if not (anno := await session.get(models.ProjectSessionAnnotation, id_)):
147
+ raise NotFound(f"Could not find session annotation with ID: {id}")
148
+
149
+ if not user_is_admin and anno.user_id != user_id:
150
+ raise Unauthorized(
151
+ "Session annotation is not associated with the current user and "
152
+ "the current user is not an admin."
153
+ )
154
+
155
+ await session.delete(anno)
156
+
157
+ deleted_gql_annotation = to_gql_project_session_annotation(anno)
158
+ info.context.event_queue.put(ProjectSessionAnnotationDeleteEvent((id_,)))
159
+ return ProjectSessionAnnotationMutationPayload(
160
+ project_session_annotation=deleted_gql_annotation, query=Query()
161
+ )
phoenix/server/api/queries.py +32 -0
@@ -48,6 +48,7 @@ from phoenix.server.api.types.AnnotationConfig import AnnotationConfig, to_gql_a
48
48
  from phoenix.server.api.types.Cluster import Cluster, to_gql_clusters
49
49
  from phoenix.server.api.types.Dataset import Dataset, to_gql_dataset
50
50
  from phoenix.server.api.types.DatasetExample import DatasetExample
51
+ from phoenix.server.api.types.DatasetSplit import DatasetSplit, to_gql_dataset_split
51
52
  from phoenix.server.api.types.Dimension import to_gql_dimension
52
53
  from phoenix.server.api.types.EmbeddingDimension import (
53
54
  DEFAULT_CLUSTER_SELECTION_EPSILON,
@@ -959,6 +960,14 @@ class Query:
959
960
  id_attr=example.id,
960
961
  created_at=example.created_at,
961
962
  )
963
+ elif type_name == DatasetSplit.__name__:
964
+ async with info.context.db() as session:
965
+ dataset_split = await session.scalar(
966
+ select(models.DatasetSplit).where(models.DatasetSplit.id == node_id)
967
+ )
968
+ if not dataset_split:
969
+ raise NotFound(f"Unknown dataset split: {id}")
970
+ return to_gql_dataset_split(dataset_split)
962
971
  elif type_name == Experiment.__name__:
963
972
  async with info.context.db() as session:
964
973
  experiment = await session.scalar(
@@ -1140,6 +1149,29 @@ class Query:
1140
1149
  args=args,
1141
1150
  )
1142
1151
 
1152
+ @strawberry.field
1153
+ async def dataset_splits(
1154
+ self,
1155
+ info: Info[Context, None],
1156
+ first: Optional[int] = 50,
1157
+ last: Optional[int] = UNSET,
1158
+ after: Optional[CursorString] = UNSET,
1159
+ before: Optional[CursorString] = UNSET,
1160
+ ) -> Connection[DatasetSplit]:
1161
+ args = ConnectionArgs(
1162
+ first=first,
1163
+ after=after if isinstance(after, CursorString) else None,
1164
+ last=last,
1165
+ before=before if isinstance(before, CursorString) else None,
1166
+ )
1167
+ async with info.context.db() as session:
1168
+ splits = await session.stream_scalars(select(models.DatasetSplit))
1169
+ data = [to_gql_dataset_split(split) async for split in splits]
1170
+ return connection_from_list(
1171
+ data=data,
1172
+ args=args,
1173
+ )
1174
+
1143
1175
  @strawberry.field
1144
1176
  async def annotation_configs(
1145
1177
  self,
phoenix/server/api/routers/v1/__init__.py +2 -0
@@ -14,6 +14,7 @@ from .experiment_runs import router as experiment_runs_router
14
14
  from .experiments import router as experiments_router
15
15
  from .projects import router as projects_router
16
16
  from .prompts import router as prompts_router
17
+ from .sessions import router as sessions_router
17
18
  from .spans import router as spans_router
18
19
  from .traces import router as traces_router
19
20
  from .users import router as users_router
@@ -71,6 +72,7 @@ def create_v1_router(authentication_enabled: bool) -> APIRouter:
71
72
  router.include_router(evaluations_router)
72
73
  router.include_router(prompts_router)
73
74
  router.include_router(projects_router)
75
+ router.include_router(sessions_router)
74
76
  router.include_router(documents_router)
75
77
  router.include_router(users_router)
76
78
  return router
phoenix/server/api/routers/v1/annotations.py +320 -0
@@ -14,6 +14,9 @@ from strawberry.relay import GlobalID
14
14
  from phoenix.db import models
15
15
  from phoenix.db.insertion.types import Precursors
16
16
  from phoenix.server.api.routers.v1.models import V1RoutesBaseModel
17
+ from phoenix.server.api.types.ProjectSessionAnnotation import (
18
+ ProjectSessionAnnotation as SessionAnnotationNodeType,
19
+ )
17
20
  from phoenix.server.api.types.SpanAnnotation import SpanAnnotation as SpanAnnotationNodeType
18
21
  from phoenix.server.api.types.TraceAnnotation import TraceAnnotation as TraceAnnotationNodeType
19
22
  from phoenix.server.api.types.User import User as UserNodeType
@@ -24,6 +27,7 @@ logger = logging.getLogger(__name__)
24
27
 
25
28
  SPAN_ANNOTATION_NODE_NAME = SpanAnnotationNodeType.__name__
26
29
  TRACE_ANNOTATION_NODE_NAME = TraceAnnotationNodeType.__name__
30
+ SESSION_ANNOTATION_NODE_NAME = SessionAnnotationNodeType.__name__
27
31
  MAX_TRACE_IDS = 1_000
28
32
  USER_NODE_NAME = UserNodeType.__name__
29
33
  MAX_SPAN_IDS = 1_000
@@ -161,6 +165,35 @@ class TraceAnnotationsResponseBody(PaginatedResponseBody[TraceAnnotation]):
161
165
  pass
162
166
 
163
167
 
168
+ class SessionAnnotationData(AnnotationData):
169
+ session_id: str = Field(description="Session ID")
170
+
171
+ def as_precursor(self, *, user_id: Optional[int] = None) -> Precursors.SessionAnnotation:
172
+ return Precursors.SessionAnnotation(
173
+ datetime.now(timezone.utc),
174
+ self.session_id,
175
+ models.ProjectSessionAnnotation(
176
+ name=self.name,
177
+ annotator_kind=self.annotator_kind,
178
+ score=self.result.score if self.result else None,
179
+ label=self.result.label if self.result else None,
180
+ explanation=self.result.explanation if self.result else None,
181
+ metadata_=self.metadata or {},
182
+ identifier=self.identifier,
183
+ source="API",
184
+ user_id=user_id,
185
+ ),
186
+ )
187
+
188
+
189
+ class SessionAnnotation(SessionAnnotationData, Annotation):
190
+ pass
191
+
192
+
193
+ class SessionAnnotationsResponseBody(PaginatedResponseBody[SessionAnnotation]):
194
+ pass
195
+
196
+
164
197
  @router.get(
165
198
  "/projects/{project_identifier}/span_annotations",
166
199
  operation_id="listSpanAnnotationsBySpanIds",
@@ -304,3 +337,290 @@ async def list_span_annotations(
304
337
  ]
305
338
 
306
339
  return SpanAnnotationsResponseBody(data=data, next_cursor=next_cursor)
340
+
341
+
342
+ @router.get(
343
+ "/projects/{project_identifier}/trace_annotations",
344
+ operation_id="listTraceAnnotationsByTraceIds",
345
+ summary="Get trace annotations for a list of trace_ids.",
346
+ status_code=HTTP_200_OK,
347
+ responses=add_errors_to_responses(
348
+ [
349
+ {"status_code": HTTP_404_NOT_FOUND, "description": "Project or traces not found"},
350
+ {"status_code": HTTP_422_UNPROCESSABLE_ENTITY, "description": "Invalid parameters"},
351
+ ]
352
+ ),
353
+ )
354
+ async def list_trace_annotations(
355
+ request: Request,
356
+ project_identifier: str = Path(
357
+ description=(
358
+ "The project identifier: either project ID or project name. If using a project name as "
359
+ "the identifier, it cannot contain slash (/), question mark (?), or pound sign (#) "
360
+ "characters."
361
+ )
362
+ ),
363
+ trace_ids: list[str] = Query(
364
+ ..., min_length=1, description="One or more trace id to fetch annotations for"
365
+ ),
366
+ include_annotation_names: Optional[list[str]] = Query(
367
+ default=None,
368
+ description=(
369
+ "Optional list of annotation names to include. If provided, only annotations with "
370
+ "these names will be returned. 'note' annotations are excluded by default unless "
371
+ "explicitly included in this list."
372
+ ),
373
+ ),
374
+ exclude_annotation_names: Optional[list[str]] = Query(
375
+ default=None, description="Optional list of annotation names to exclude from results."
376
+ ),
377
+ cursor: Optional[str] = Query(default=None, description="A cursor for pagination"),
378
+ limit: int = Query(
379
+ default=10,
380
+ gt=0,
381
+ le=10000,
382
+ description="The maximum number of annotations to return in a single request",
383
+ ),
384
+ ) -> TraceAnnotationsResponseBody:
385
+ trace_ids = list({*trace_ids})
386
+ if len(trace_ids) > MAX_TRACE_IDS:
387
+ raise HTTPException(
388
+ status_code=HTTP_422_UNPROCESSABLE_ENTITY,
389
+ detail=f"Too many trace_ids supplied: {len(trace_ids)} (max {MAX_TRACE_IDS})",
390
+ )
391
+
392
+ async with request.app.state.db() as session:
393
+ project = await _get_project_by_identifier(session, project_identifier)
394
+ if not project:
395
+ raise HTTPException(
396
+ status_code=HTTP_404_NOT_FOUND,
397
+ detail=f"Project with identifier {project_identifier} not found",
398
+ )
399
+
400
+ # Build the base query
401
+ where_conditions = [
402
+ models.Project.id == project.id,
403
+ models.Trace.trace_id.in_(trace_ids),
404
+ ]
405
+
406
+ # Add annotation name filtering
407
+ if include_annotation_names:
408
+ where_conditions.append(models.TraceAnnotation.name.in_(include_annotation_names))
409
+
410
+ if exclude_annotation_names:
411
+ where_conditions.append(models.TraceAnnotation.name.not_in(exclude_annotation_names))
412
+
413
+ stmt = (
414
+ select(models.Trace.trace_id, models.TraceAnnotation)
415
+ .join(models.Project, models.Trace.project_rowid == models.Project.id)
416
+ .join(models.TraceAnnotation, models.TraceAnnotation.trace_rowid == models.Trace.id)
417
+ .where(*where_conditions)
418
+ .order_by(models.TraceAnnotation.id.desc())
419
+ .limit(limit + 1)
420
+ )
421
+
422
+ if cursor:
423
+ try:
424
+ cursor_id = int(GlobalID.from_id(cursor).node_id)
425
+ except ValueError:
426
+ raise HTTPException(
427
+ status_code=HTTP_422_UNPROCESSABLE_ENTITY,
428
+ detail="Invalid cursor value",
429
+ )
430
+ stmt = stmt.where(models.TraceAnnotation.id <= cursor_id)
431
+
432
+ rows: list[tuple[str, models.TraceAnnotation]] = [
433
+ r async for r in (await session.stream(stmt))
434
+ ]
435
+
436
+ next_cursor: Optional[str] = None
437
+ if len(rows) == limit + 1:
438
+ *rows, extra = rows
439
+ next_cursor = str(GlobalID(TRACE_ANNOTATION_NODE_NAME, str(extra[1].id)))
440
+
441
+ if not rows:
442
+ traces_exist = await session.scalar(
443
+ select(
444
+ exists().where(
445
+ models.Trace.trace_id.in_(trace_ids),
446
+ models.Trace.project_rowid == project.id,
447
+ )
448
+ )
449
+ )
450
+ if not traces_exist:
451
+ raise HTTPException(
452
+ detail="None of the supplied trace_ids exist in this project",
453
+ status_code=HTTP_404_NOT_FOUND,
454
+ )
455
+
456
+ return TraceAnnotationsResponseBody(data=[], next_cursor=None)
457
+
458
+ data = [
459
+ TraceAnnotation(
460
+ id=str(GlobalID(TRACE_ANNOTATION_NODE_NAME, str(anno.id))),
461
+ trace_id=trace_id,
462
+ name=anno.name,
463
+ result=AnnotationResult(
464
+ label=anno.label,
465
+ score=anno.score,
466
+ explanation=anno.explanation,
467
+ ),
468
+ metadata=anno.metadata_,
469
+ annotator_kind=anno.annotator_kind,
470
+ created_at=anno.created_at,
471
+ updated_at=anno.updated_at,
472
+ identifier=anno.identifier,
473
+ source=anno.source,
474
+ user_id=str(GlobalID("User", str(anno.user_id))) if anno.user_id else None,
475
+ )
476
+ for trace_id, anno in rows
477
+ ]
478
+
479
+ return TraceAnnotationsResponseBody(data=data, next_cursor=next_cursor)
480
+
481
+
482
+ @router.get(
483
+ "/projects/{project_identifier}/session_annotations",
484
+ operation_id="listSessionAnnotationsBySessionIds",
485
+ summary="Get session annotations for a list of session_ids.",
486
+ status_code=HTTP_200_OK,
487
+ responses=add_errors_to_responses(
488
+ [
489
+ {"status_code": HTTP_404_NOT_FOUND, "description": "Project or sessions not found"},
490
+ {"status_code": HTTP_422_UNPROCESSABLE_ENTITY, "description": "Invalid parameters"},
491
+ ]
492
+ ),
493
+ )
494
+ async def list_session_annotations(
495
+ request: Request,
496
+ project_identifier: str = Path(
497
+ description=(
498
+ "The project identifier: either project ID or project name. If using a project name as "
499
+ "the identifier, it cannot contain slash (/), question mark (?), or pound sign (#) "
500
+ "characters."
501
+ )
502
+ ),
503
+ session_ids: list[str] = Query(
504
+ ..., min_length=1, description="One or more session id to fetch annotations for"
505
+ ),
506
+ include_annotation_names: Optional[list[str]] = Query(
507
+ default=None,
508
+ description=(
509
+ "Optional list of annotation names to include. If provided, only annotations with "
510
+ "these names will be returned. 'note' annotations are excluded by default unless "
511
+ "explicitly included in this list."
512
+ ),
513
+ ),
514
+ exclude_annotation_names: Optional[list[str]] = Query(
515
+ default=None, description="Optional list of annotation names to exclude from results."
516
+ ),
517
+ cursor: Optional[str] = Query(default=None, description="A cursor for pagination"),
518
+ limit: int = Query(
519
+ default=10,
520
+ gt=0,
521
+ le=10000,
522
+ description="The maximum number of annotations to return in a single request",
523
+ ),
524
+ ) -> SessionAnnotationsResponseBody:
525
+ session_ids = list({*session_ids})
526
+ if len(session_ids) > MAX_SESSION_IDS:
527
+ raise HTTPException(
528
+ status_code=HTTP_422_UNPROCESSABLE_ENTITY,
529
+ detail=f"Too many session_ids supplied: {len(session_ids)} (max {MAX_SESSION_IDS})",
530
+ )
531
+
532
+ async with request.app.state.db() as session:
533
+ project = await _get_project_by_identifier(session, project_identifier)
534
+ if not project:
535
+ raise HTTPException(
536
+ status_code=HTTP_404_NOT_FOUND,
537
+ detail=f"Project with identifier {project_identifier} not found",
538
+ )
539
+
540
+ # Build the base query
541
+ where_conditions = [
542
+ models.Project.id == project.id,
543
+ models.ProjectSession.session_id.in_(session_ids),
544
+ ]
545
+
546
+ # Add annotation name filtering
547
+ if include_annotation_names:
548
+ where_conditions.append(
549
+ models.ProjectSessionAnnotation.name.in_(include_annotation_names)
550
+ )
551
+
552
+ if exclude_annotation_names:
553
+ where_conditions.append(
554
+ models.ProjectSessionAnnotation.name.not_in(exclude_annotation_names)
555
+ )
556
+
557
+ stmt = (
558
+ select(models.ProjectSession.session_id, models.ProjectSessionAnnotation)
559
+ .join(models.Project, models.ProjectSession.project_id == models.Project.id)
560
+ .join(
561
+ models.ProjectSessionAnnotation,
562
+ models.ProjectSessionAnnotation.project_session_id == models.ProjectSession.id,
563
+ )
564
+ .where(*where_conditions)
565
+ .order_by(models.ProjectSessionAnnotation.id.desc())
566
+ .limit(limit + 1)
567
+ )
568
+
569
+ if cursor:
570
+ try:
571
+ cursor_id = int(GlobalID.from_id(cursor).node_id)
572
+ except ValueError:
573
+ raise HTTPException(
574
+ status_code=HTTP_422_UNPROCESSABLE_ENTITY,
575
+ detail="Invalid cursor value",
576
+ )
577
+ stmt = stmt.where(models.ProjectSessionAnnotation.id <= cursor_id)
578
+
579
+ rows: list[tuple[str, models.ProjectSessionAnnotation]] = [
580
+ r async for r in (await session.stream(stmt))
581
+ ]
582
+
583
+ next_cursor: Optional[str] = None
584
+ if len(rows) == limit + 1:
585
+ *rows, extra = rows
586
+ next_cursor = str(GlobalID(SESSION_ANNOTATION_NODE_NAME, str(extra[1].id)))
587
+
588
+ if not rows:
589
+ sessions_exist = await session.scalar(
590
+ select(
591
+ exists().where(
592
+ models.ProjectSession.session_id.in_(session_ids),
593
+ models.ProjectSession.project_id == project.id,
594
+ )
595
+ )
596
+ )
597
+ if not sessions_exist:
598
+ raise HTTPException(
599
+ detail="None of the supplied session_ids exist in this project",
600
+ status_code=HTTP_404_NOT_FOUND,
601
+ )
602
+
603
+ return SessionAnnotationsResponseBody(data=[], next_cursor=None)
604
+
605
+ data = [
606
+ SessionAnnotation(
607
+ id=str(GlobalID(SESSION_ANNOTATION_NODE_NAME, str(anno.id))),
608
+ session_id=session_id,
609
+ name=anno.name,
610
+ result=AnnotationResult(
611
+ label=anno.label,
612
+ score=anno.score,
613
+ explanation=anno.explanation,
614
+ ),
615
+ metadata=anno.metadata_,
616
+ annotator_kind=anno.annotator_kind,
617
+ created_at=anno.created_at,
618
+ updated_at=anno.updated_at,
619
+ identifier=anno.identifier,
620
+ source=anno.source,
621
+ user_id=str(GlobalID(USER_NODE_NAME, str(anno.user_id))) if anno.user_id else None,
622
+ )
623
+ for session_id, anno in rows
624
+ ]
625
+
626
+ return SessionAnnotationsResponseBody(data=data, next_cursor=next_cursor)
phoenix/server/api/routers/v1/datasets.py +5 -0
@@ -48,6 +48,7 @@ from phoenix.server.api.types.DatasetVersion import DatasetVersion as DatasetVer
48
48
  from phoenix.server.api.types.node import from_global_id_with_expected_type
49
49
  from phoenix.server.api.utils import delete_projects, delete_traces
50
50
  from phoenix.server.authorization import is_not_locked
51
+ from phoenix.server.bearer_auth import PhoenixUser
51
52
  from phoenix.server.dml_event import DatasetInsertEvent
52
53
 
53
54
  from .models import V1RoutesBaseModel
@@ -478,6 +479,9 @@ async def upload_dataset(
478
479
  detail="Invalid request Content-Type",
479
480
  status_code=HTTP_422_UNPROCESSABLE_ENTITY,
480
481
  )
482
+ user_id: Optional[int] = None
483
+ if request.app.state.authentication_enabled and isinstance(request.user, PhoenixUser):
484
+ user_id = int(request.user.identity)
481
485
  operation = cast(
482
486
  Callable[[AsyncSession], Awaitable[DatasetExampleAdditionEvent]],
483
487
  partial(
@@ -486,6 +490,7 @@ async def upload_dataset(
486
490
  action=action,
487
491
  name=name,
488
492
  description=description,
493
+ user_id=user_id,
489
494
  ),
490
495
  )
491
496
  if sync:
phoenix/server/api/routers/v1/experiments.py +10 -3
@@ -15,10 +15,14 @@ from starlette.status import HTTP_200_OK, HTTP_404_NOT_FOUND, HTTP_422_UNPROCESS
15
15
  from strawberry.relay import GlobalID
16
16
 
17
17
  from phoenix.db import models
18
- from phoenix.db.helpers import SupportedSQLDialect
18
+ from phoenix.db.helpers import (
19
+ SupportedSQLDialect,
20
+ insert_experiment_with_examples_snapshot,
21
+ )
19
22
  from phoenix.db.insertion.helpers import insert_on_conflict
20
23
  from phoenix.server.api.types.node import from_global_id_with_expected_type
21
24
  from phoenix.server.authorization import is_not_locked
25
+ from phoenix.server.bearer_auth import PhoenixUser
22
26
  from phoenix.server.dml_event import ExperimentInsertEvent
23
27
  from phoenix.server.experiments.utils import generate_experiment_project_name
24
28
 
@@ -157,6 +161,9 @@ async def create_experiment(
157
161
  detail=f"DatasetVersion with ID {dataset_version_globalid} does not exist",
158
162
  status_code=HTTP_404_NOT_FOUND,
159
163
  )
164
+ user_id: Optional[int] = None
165
+ if request.app.state.authentication_enabled and isinstance(request.user, PhoenixUser):
166
+ user_id = int(request.user.identity)
160
167
 
161
168
  # generate a semi-unique name for the experiment
162
169
  experiment_name = request_body.name or _generate_experiment_name(dataset_name)
@@ -172,9 +179,9 @@ async def create_experiment(
172
179
  repetitions=request_body.repetitions,
173
180
  metadata_=request_body.metadata or {},
174
181
  project_name=project_name,
182
+ user_id=user_id,
175
183
  )
176
- session.add(experiment)
177
- await session.flush()
184
+ await insert_experiment_with_examples_snapshot(session, experiment)
178
185
 
179
186
  dialect = SupportedSQLDialect(session.bind.dialect.name)
180
187
  project_rowid = await session.scalar(