arize-phoenix 11.37.0__py3-none-any.whl → 12.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic; see the package's registry page for more details.

Files changed (75)
  1. {arize_phoenix-11.37.0.dist-info → arize_phoenix-12.0.0.dist-info}/METADATA +3 -3
  2. {arize_phoenix-11.37.0.dist-info → arize_phoenix-12.0.0.dist-info}/RECORD +74 -53
  3. phoenix/config.py +1 -11
  4. phoenix/db/bulk_inserter.py +8 -0
  5. phoenix/db/facilitator.py +1 -1
  6. phoenix/db/helpers.py +202 -33
  7. phoenix/db/insertion/dataset.py +7 -0
  8. phoenix/db/insertion/helpers.py +2 -2
  9. phoenix/db/insertion/session_annotation.py +176 -0
  10. phoenix/db/insertion/types.py +30 -0
  11. phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
  12. phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
  13. phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
  14. phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
  15. phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
  16. phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
  17. phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
  18. phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
  19. phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
  20. phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
  21. phoenix/db/models.py +285 -46
  22. phoenix/server/api/context.py +13 -2
  23. phoenix/server/api/dataloaders/__init__.py +6 -2
  24. phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
  25. phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
  26. phoenix/server/api/dataloaders/table_fields.py +2 -2
  27. phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
  28. phoenix/server/api/helpers/playground_clients.py +65 -35
  29. phoenix/server/api/helpers/playground_spans.py +2 -1
  30. phoenix/server/api/helpers/playground_users.py +26 -0
  31. phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
  32. phoenix/server/api/input_types/ChatCompletionInput.py +2 -0
  33. phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
  34. phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
  35. phoenix/server/api/mutations/__init__.py +6 -0
  36. phoenix/server/api/mutations/chat_mutations.py +24 -9
  37. phoenix/server/api/mutations/dataset_mutations.py +5 -0
  38. phoenix/server/api/mutations/dataset_split_mutations.py +387 -0
  39. phoenix/server/api/mutations/project_session_annotations_mutations.py +161 -0
  40. phoenix/server/api/queries.py +32 -0
  41. phoenix/server/api/routers/v1/__init__.py +2 -0
  42. phoenix/server/api/routers/v1/annotations.py +320 -0
  43. phoenix/server/api/routers/v1/datasets.py +5 -0
  44. phoenix/server/api/routers/v1/experiments.py +10 -3
  45. phoenix/server/api/routers/v1/sessions.py +111 -0
  46. phoenix/server/api/routers/v1/traces.py +1 -2
  47. phoenix/server/api/routers/v1/users.py +7 -0
  48. phoenix/server/api/subscriptions.py +25 -7
  49. phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +1 -0
  50. phoenix/server/api/types/DatasetExample.py +11 -0
  51. phoenix/server/api/types/DatasetSplit.py +32 -0
  52. phoenix/server/api/types/Experiment.py +0 -4
  53. phoenix/server/api/types/Project.py +16 -0
  54. phoenix/server/api/types/ProjectSession.py +88 -3
  55. phoenix/server/api/types/ProjectSessionAnnotation.py +68 -0
  56. phoenix/server/api/types/Span.py +5 -5
  57. phoenix/server/api/types/Trace.py +61 -0
  58. phoenix/server/app.py +6 -2
  59. phoenix/server/cost_tracking/model_cost_manifest.json +132 -2
  60. phoenix/server/dml_event.py +13 -0
  61. phoenix/server/static/.vite/manifest.json +39 -39
  62. phoenix/server/static/assets/{components-CFzdBkk_.js → components-Dl9SUw1U.js} +371 -327
  63. phoenix/server/static/assets/{index-DayUA9lQ.js → index-CqQS0dTo.js} +2 -2
  64. phoenix/server/static/assets/{pages-CvUhOO9h.js → pages-DKSjVA_E.js} +771 -518
  65. phoenix/server/static/assets/{vendor-BdjZxMii.js → vendor-CtbHQYl8.js} +1 -1
  66. phoenix/server/static/assets/{vendor-arizeai-CHYlS8jV.js → vendor-arizeai-D-lWOwIS.js} +1 -1
  67. phoenix/server/static/assets/{vendor-codemirror-Di6t4HnH.js → vendor-codemirror-BRBpy3_z.js} +3 -3
  68. phoenix/server/static/assets/{vendor-recharts-C9wCDYj3.js → vendor-recharts--KdSwB3m.js} +1 -1
  69. phoenix/server/static/assets/{vendor-shiki-MNnmOotP.js → vendor-shiki-CvRzZnIo.js} +1 -1
  70. phoenix/version.py +1 -1
  71. phoenix/server/api/dataloaders/experiment_repetition_counts.py +0 -39
  72. {arize_phoenix-11.37.0.dist-info → arize_phoenix-12.0.0.dist-info}/WHEEL +0 -0
  73. {arize_phoenix-11.37.0.dist-info → arize_phoenix-12.0.0.dist-info}/entry_points.txt +0 -0
  74. {arize_phoenix-11.37.0.dist-info → arize_phoenix-12.0.0.dist-info}/licenses/IP_NOTICE +0 -0
  75. {arize_phoenix-11.37.0.dist-info → arize_phoenix-12.0.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,111 @@
1
+ from __future__ import annotations
2
+
3
+ import warnings
4
+ from typing import Optional
5
+
6
+ from fastapi import APIRouter, Depends, HTTPException, Query
7
+ from pydantic import Field
8
+ from sqlalchemy import select
9
+ from starlette.requests import Request
10
+ from starlette.status import HTTP_404_NOT_FOUND
11
+
12
+ from phoenix.db import models
13
+ from phoenix.db.helpers import SupportedSQLDialect
14
+ from phoenix.db.insertion.helpers import as_kv, insert_on_conflict
15
+ from phoenix.server.api.routers.v1.models import V1RoutesBaseModel
16
+ from phoenix.server.authorization import is_not_locked
17
+ from phoenix.server.bearer_auth import PhoenixUser
18
+
19
+ from .annotations import SessionAnnotationData
20
+ from .utils import RequestBody, ResponseBody, add_errors_to_responses
21
+
22
+ router = APIRouter(tags=["sessions"])
23
+
24
+
25
+ class InsertedSessionAnnotation(V1RoutesBaseModel):
26
+ id: str = Field(description="The ID of the inserted session annotation")
27
+
28
+
29
+ class AnnotateSessionsRequestBody(RequestBody[list[SessionAnnotationData]]):
30
+ pass
31
+
32
+
33
+ class AnnotateSessionsResponseBody(ResponseBody[list[InsertedSessionAnnotation]]):
34
+ pass
35
+
36
+
37
+ @router.post(
38
+ "/session_annotations",
39
+ dependencies=[Depends(is_not_locked)],
40
+ operation_id="annotateSessions",
41
+ summary="Create session annotations",
42
+ responses=add_errors_to_responses(
43
+ [{"status_code": HTTP_404_NOT_FOUND, "description": "Session not found"}]
44
+ ),
45
+ response_description="Session annotations inserted successfully",
46
+ include_in_schema=True,
47
+ )
48
+ async def annotate_sessions(
49
+ request: Request,
50
+ request_body: AnnotateSessionsRequestBody,
51
+ sync: bool = Query(default=False, description="If true, fulfill request synchronously."),
52
+ ) -> AnnotateSessionsResponseBody:
53
+ if not request_body.data:
54
+ return AnnotateSessionsResponseBody(data=[])
55
+
56
+ user_id: Optional[int] = None
57
+ if request.app.state.authentication_enabled and isinstance(request.user, PhoenixUser):
58
+ user_id = int(request.user.identity)
59
+
60
+ session_annotations = request_body.data
61
+ filtered_session_annotations = list(filter(lambda d: d.name != "note", session_annotations))
62
+ if len(filtered_session_annotations) != len(session_annotations):
63
+ warnings.warn(
64
+ (
65
+ "Session annotations with the name 'note' are not supported in this endpoint. "
66
+ "They will be ignored."
67
+ ),
68
+ UserWarning,
69
+ )
70
+ precursors = [d.as_precursor(user_id=user_id) for d in filtered_session_annotations]
71
+ if not sync:
72
+ await request.state.enqueue_annotations(*precursors)
73
+ return AnnotateSessionsResponseBody(data=[])
74
+
75
+ session_ids = {p.session_id for p in precursors}
76
+ async with request.app.state.db() as session:
77
+ existing_sessions = {
78
+ session_id: rowid
79
+ async for session_id, rowid in await session.stream(
80
+ select(models.ProjectSession.session_id, models.ProjectSession.id).filter(
81
+ models.ProjectSession.session_id.in_(session_ids)
82
+ )
83
+ )
84
+ }
85
+
86
+ missing_session_ids = session_ids - set(existing_sessions.keys())
87
+ # We prefer to fail the entire operation if there are missing sessions in sync mode
88
+ if missing_session_ids:
89
+ raise HTTPException(
90
+ detail=f"Sessions with IDs {', '.join(missing_session_ids)} do not exist.",
91
+ status_code=HTTP_404_NOT_FOUND,
92
+ )
93
+
94
+ async with request.app.state.db() as session:
95
+ inserted_ids = []
96
+ dialect = SupportedSQLDialect(session.bind.dialect.name)
97
+ for p in precursors:
98
+ values = dict(as_kv(p.as_insertable(existing_sessions[p.session_id]).row))
99
+ session_annotation_id = await session.scalar(
100
+ insert_on_conflict(
101
+ values,
102
+ dialect=dialect,
103
+ table=models.ProjectSessionAnnotation,
104
+ unique_by=("name", "project_session_id", "identifier"),
105
+ ).returning(models.ProjectSessionAnnotation.id)
106
+ )
107
+ inserted_ids.append(session_annotation_id)
108
+
109
+ return AnnotateSessionsResponseBody(
110
+ data=[InsertedSessionAnnotation(id=str(inserted_id)) for inserted_id in inserted_ids]
111
+ )
@@ -144,12 +144,11 @@ class AnnotateTracesResponseBody(ResponseBody[list[InsertedTraceAnnotation]]):
144
144
  responses=add_errors_to_responses(
145
145
  [{"status_code": HTTP_404_NOT_FOUND, "description": "Trace not found"}]
146
146
  ),
147
- include_in_schema=False,
148
147
  )
149
148
  async def annotate_traces(
150
149
  request: Request,
151
150
  request_body: AnnotateTracesRequestBody,
152
- sync: bool = Query(default=True, description="If true, fulfill request synchronously."),
151
+ sync: bool = Query(default=False, description="If true, fulfill request synchronously."),
153
152
  ) -> AnnotateTracesResponseBody:
154
153
  if not request_body.data:
155
154
  return AnnotateTracesResponseBody(data=[])
@@ -217,6 +217,13 @@ async def create_user(
217
217
  detail="Cannot create users with SYSTEM role",
218
218
  )
219
219
 
220
+ # TODO: Implement VIEWER role
221
+ if role == "VIEWER":
222
+ raise HTTPException(
223
+ status_code=HTTP_400_BAD_REQUEST,
224
+ detail="VIEWER role not yet implemented",
225
+ )
226
+
220
227
  user: models.User
221
228
  if isinstance(user_data, LocalUserData):
222
229
  password = (user_data.password or secrets.token_hex()).strip()
@@ -26,6 +26,7 @@ from typing_extensions import TypeAlias, assert_never
26
26
  from phoenix.config import PLAYGROUND_PROJECT_NAME
27
27
  from phoenix.datetime_utils import local_now, normalize_datetime
28
28
  from phoenix.db import models
29
+ from phoenix.db.helpers import insert_experiment_with_examples_snapshot
29
30
  from phoenix.server.api.auth import IsLocked, IsNotReadOnly
30
31
  from phoenix.server.api.context import Context
31
32
  from phoenix.server.api.exceptions import BadRequest, CustomGraphQLError, NotFound
@@ -43,6 +44,7 @@ from phoenix.server.api.helpers.playground_spans import (
43
44
  get_db_trace,
44
45
  streaming_llm_span,
45
46
  )
47
+ from phoenix.server.api.helpers.playground_users import get_user
46
48
  from phoenix.server.api.helpers.prompts.models import PromptTemplateFormat
47
49
  from phoenix.server.api.input_types.ChatCompletionInput import (
48
50
  ChatCompletionInput,
@@ -302,18 +304,19 @@ class Subscription:
302
304
  description="Traces from prompt playground",
303
305
  )
304
306
  )
307
+ user_id = get_user(info)
305
308
  experiment = models.Experiment(
306
309
  dataset_id=from_global_id_with_expected_type(input.dataset_id, Dataset.__name__),
307
310
  dataset_version_id=resolved_version_id,
308
311
  name=input.experiment_name
309
312
  or _default_playground_experiment_name(input.prompt_name),
310
313
  description=input.experiment_description,
311
- repetitions=1,
314
+ repetitions=input.repetitions,
312
315
  metadata_=input.experiment_metadata or dict(),
313
316
  project_name=project_name,
317
+ user_id=user_id,
314
318
  )
315
- session.add(experiment)
316
- await session.flush()
319
+ await insert_experiment_with_examples_snapshot(session, experiment)
317
320
  yield ChatCompletionSubscriptionExperiment(
318
321
  experiment=to_gql_experiment(experiment)
319
322
  ) # eagerly yields experiment so it can be linked by consumers of the subscription
@@ -327,11 +330,13 @@ class Subscription:
327
330
  llm_client=llm_client,
328
331
  revision=revision,
329
332
  results=results,
333
+ repetition_number=repetition_number,
330
334
  experiment_id=experiment.id,
331
335
  project_id=playground_project_id,
332
336
  ),
333
337
  )
334
338
  for revision in revisions
339
+ for repetition_number in range(1, input.repetitions + 1)
335
340
  ]
336
341
  in_progress: list[
337
342
  tuple[
@@ -409,6 +414,7 @@ async def _stream_chat_completion_over_dataset_example(
409
414
  input: ChatCompletionOverDatasetInput,
410
415
  llm_client: PlaygroundStreamingClient,
411
416
  revision: models.DatasetExampleRevision,
417
+ repetition_number: int,
412
418
  results: asyncio.Queue[ChatCompletionResult],
413
419
  experiment_id: int,
414
420
  project_id: int,
@@ -435,7 +441,11 @@ async def _stream_chat_completion_over_dataset_example(
435
441
  )
436
442
  except TemplateFormatterError as error:
437
443
  format_end_time = cast(datetime, normalize_datetime(dt=local_now(), tz=timezone.utc))
438
- yield ChatCompletionSubscriptionError(message=str(error), dataset_example_id=example_id)
444
+ yield ChatCompletionSubscriptionError(
445
+ message=str(error),
446
+ dataset_example_id=example_id,
447
+ repetition_number=repetition_number,
448
+ )
439
449
  await results.put(
440
450
  (
441
451
  example_id,
@@ -445,7 +455,7 @@ async def _stream_chat_completion_over_dataset_example(
445
455
  dataset_example_id=revision.dataset_example_id,
446
456
  trace_id=None,
447
457
  output={},
448
- repetition_number=1,
458
+ repetition_number=repetition_number,
449
459
  start_time=format_start_time,
450
460
  end_time=format_end_time,
451
461
  error=str(error),
@@ -465,17 +475,24 @@ async def _stream_chat_completion_over_dataset_example(
465
475
  ):
466
476
  span.add_response_chunk(chunk)
467
477
  chunk.dataset_example_id = example_id
478
+ chunk.repetition_number = repetition_number
468
479
  yield chunk
469
480
  span.set_attributes(llm_client.attributes)
470
481
  db_trace = get_db_trace(span, project_id)
471
482
  db_span = get_db_span(span, db_trace)
472
483
  db_run = get_db_experiment_run(
473
- db_span, db_trace, experiment_id=experiment_id, example_id=revision.dataset_example_id
484
+ db_span,
485
+ db_trace,
486
+ experiment_id=experiment_id,
487
+ example_id=revision.dataset_example_id,
488
+ repetition_number=repetition_number,
474
489
  )
475
490
  await results.put((example_id, db_span, db_run))
476
491
  if span.status_message is not None:
477
492
  yield ChatCompletionSubscriptionError(
478
- message=span.status_message, dataset_example_id=example_id
493
+ message=span.status_message,
494
+ dataset_example_id=example_id,
495
+ repetition_number=repetition_number,
479
496
  )
480
497
 
481
498
 
@@ -511,6 +528,7 @@ async def _chat_completion_result_payloads(
511
528
  span=Span(span_rowid=span.id, db_span=span) if span else None,
512
529
  experiment_run=to_gql_experiment_run(run),
513
530
  dataset_example_id=example_id,
531
+ repetition_number=run.repetition_number,
514
532
  )
515
533
 
516
534
 
@@ -11,6 +11,7 @@ from .Span import Span
11
11
  @strawberry.interface
12
12
  class ChatCompletionSubscriptionPayload:
13
13
  dataset_example_id: Optional[GlobalID] = None
14
+ repetition_number: Optional[int] = None
14
15
 
15
16
 
16
17
  @strawberry.type
@@ -12,6 +12,7 @@ from phoenix.db import models
12
12
  from phoenix.server.api.context import Context
13
13
  from phoenix.server.api.exceptions import BadRequest
14
14
  from phoenix.server.api.types.DatasetExampleRevision import DatasetExampleRevision
15
+ from phoenix.server.api.types.DatasetSplit import DatasetSplit, to_gql_dataset_split
15
16
  from phoenix.server.api.types.DatasetVersion import DatasetVersion
16
17
  from phoenix.server.api.types.ExperimentRepeatedRunGroup import (
17
18
  ExperimentRepeatedRunGroup,
@@ -131,3 +132,13 @@ class DatasetExample(Node):
131
132
  )
132
133
  for group in repeated_run_groups
133
134
  ]
135
+
136
+ @strawberry.field
137
+ async def dataset_splits(
138
+ self,
139
+ info: Info[Context, None],
140
+ ) -> list[DatasetSplit]:
141
+ return [
142
+ to_gql_dataset_split(split)
143
+ for split in await info.context.data_loaders.dataset_example_splits.load(self.id_attr)
144
+ ]
@@ -0,0 +1,32 @@
1
+ from datetime import datetime
2
+ from typing import ClassVar, Optional
3
+
4
+ import strawberry
5
+ from strawberry.relay import Node, NodeID
6
+ from strawberry.scalars import JSON
7
+
8
+ from phoenix.db import models
9
+
10
+
11
+ @strawberry.type
12
+ class DatasetSplit(Node):
13
+ _table: ClassVar[type[models.Base]] = models.DatasetSplit
14
+ id_attr: NodeID[int]
15
+ name: str
16
+ description: Optional[str]
17
+ metadata: JSON
18
+ color: str
19
+ created_at: datetime
20
+ updated_at: datetime
21
+
22
+
23
+ def to_gql_dataset_split(dataset_split: models.DatasetSplit) -> DatasetSplit:
24
+ return DatasetSplit(
25
+ id_attr=dataset_split.id,
26
+ name=dataset_split.name,
27
+ description=dataset_split.description,
28
+ color=dataset_split.color or "#ffffff",
29
+ metadata=dataset_split.metadata_,
30
+ created_at=dataset_split.created_at,
31
+ updated_at=dataset_split.updated_at,
32
+ )
@@ -193,10 +193,6 @@ class Experiment(Node):
193
193
  async for token_type, is_prompt, cost, tokens in data
194
194
  ]
195
195
 
196
- @strawberry.field
197
- async def repetition_count(self, info: Info[Context, None]) -> int:
198
- return await info.context.data_loaders.experiment_repetition_counts.load(self.id_attr)
199
-
200
196
 
201
197
  def to_gql_experiment(
202
198
  experiment: models.Experiment,
@@ -588,6 +588,22 @@ class Project(Node):
588
588
  async with info.context.db() as session:
589
589
  return list(await session.scalars(stmt))
590
590
 
591
+ @strawberry.field(
592
+ description="Names of all available annotations for sessions. "
593
+ "(The list contains no duplicates.)"
594
+ ) # type: ignore
595
+ async def session_annotation_names(
596
+ self,
597
+ info: Info[Context, None],
598
+ ) -> list[str]:
599
+ stmt = (
600
+ select(distinct(models.ProjectSessionAnnotation.name))
601
+ .join(models.ProjectSession)
602
+ .where(models.ProjectSession.project_id == self.project_rowid)
603
+ )
604
+ async with info.context.db() as session:
605
+ return list(await session.scalars(stmt))
606
+
591
607
  @strawberry.field(
592
608
  description="Names of available document evaluations.",
593
609
  ) # type: ignore
@@ -1,14 +1,19 @@
1
+ from collections import defaultdict
2
+ from dataclasses import asdict, dataclass
1
3
  from datetime import datetime
2
4
  from typing import TYPE_CHECKING, Annotated, ClassVar, Optional, Type
3
5
 
6
+ import pandas as pd
4
7
  import strawberry
5
8
  from openinference.semconv.trace import SpanAttributes
6
9
  from sqlalchemy import select
7
10
  from strawberry import UNSET, Info, Private, lazy
8
- from strawberry.relay import Connection, GlobalID, Node, NodeID
11
+ from strawberry.relay import Connection, Node, NodeID
9
12
 
10
13
  from phoenix.db import models
11
14
  from phoenix.server.api.context import Context
15
+ from phoenix.server.api.input_types.AnnotationFilter import AnnotationFilter, satisfies_filter
16
+ from phoenix.server.api.types.AnnotationSummary import AnnotationSummary
12
17
  from phoenix.server.api.types.CostBreakdown import CostBreakdown
13
18
  from phoenix.server.api.types.MimeType import MimeType
14
19
  from phoenix.server.api.types.pagination import ConnectionArgs, CursorString, connection_from_list
@@ -18,6 +23,8 @@ from phoenix.server.api.types.SpanIOValue import SpanIOValue
18
23
  from phoenix.server.api.types.TokenUsage import TokenUsage
19
24
 
20
25
  if TYPE_CHECKING:
26
+ from phoenix.server.api.types.Project import Project
27
+ from phoenix.server.api.types.ProjectSessionAnnotation import ProjectSessionAnnotation
21
28
  from phoenix.server.api.types.Trace import Trace
22
29
 
23
30
 
@@ -31,10 +38,13 @@ class ProjectSession(Node):
31
38
  end_time: datetime
32
39
 
33
40
  @strawberry.field
34
- async def project_id(self) -> GlobalID:
41
+ async def project(
42
+ self,
43
+ info: Info[Context, None],
44
+ ) -> Annotated["Project", lazy(".Project")]:
35
45
  from phoenix.server.api.types.Project import Project
36
46
 
37
- return GlobalID(type_name=Project.__name__, node_id=str(self.project_rowid))
47
+ return Project(project_rowid=self.project_rowid)
38
48
 
39
49
  @strawberry.field
40
50
  async def num_traces(
@@ -165,6 +175,81 @@ class ProjectSession(Node):
165
175
  for entry in summary
166
176
  ]
167
177
 
178
+ @strawberry.field
179
+ async def session_annotations(
180
+ self,
181
+ info: Info[Context, None],
182
+ ) -> list[Annotated["ProjectSessionAnnotation", lazy(".ProjectSessionAnnotation")]]:
183
+ """Get all annotations for this session."""
184
+ from phoenix.server.api.types.ProjectSessionAnnotation import (
185
+ to_gql_project_session_annotation,
186
+ )
187
+
188
+ stmt = select(models.ProjectSessionAnnotation).filter_by(project_session_id=self.id_attr)
189
+ async with info.context.db() as session:
190
+ annotations = await session.stream_scalars(stmt)
191
+ return [
192
+ to_gql_project_session_annotation(annotation) async for annotation in annotations
193
+ ]
194
+
195
+ @strawberry.field(
196
+ description="Summarizes each annotation (by name) associated with the session"
197
+ ) # type: ignore
198
+ async def session_annotation_summaries(
199
+ self,
200
+ info: Info[Context, None],
201
+ filter: Optional[AnnotationFilter] = None,
202
+ ) -> list[AnnotationSummary]:
203
+ """
204
+ Retrieves and summarizes annotations associated with this span.
205
+
206
+ This method aggregates annotation data by name and label, calculating metrics
207
+ such as count of occurrences and sum of scores. The results are organized
208
+ into a structured format that can be easily converted to a DataFrame.
209
+
210
+ Args:
211
+ info: GraphQL context information
212
+ filter: Optional filter to apply to annotations before processing
213
+
214
+ Returns:
215
+ A list of AnnotationSummary objects, each containing:
216
+ - name: The name of the annotation
217
+ - data: A list of dictionaries with label statistics
218
+ """
219
+ # Load all annotations for this span from the data loader
220
+ annotations = await info.context.data_loaders.session_annotations_by_session.load(
221
+ self.id_attr
222
+ )
223
+
224
+ # Apply filter if provided to narrow down the annotations
225
+ if filter:
226
+ annotations = [
227
+ annotation for annotation in annotations if satisfies_filter(annotation, filter)
228
+ ]
229
+
230
+ @dataclass
231
+ class Metrics:
232
+ record_count: int = 0
233
+ label_count: int = 0
234
+ score_sum: float = 0
235
+ score_count: int = 0
236
+
237
+ summaries: defaultdict[str, defaultdict[Optional[str], Metrics]] = defaultdict(
238
+ lambda: defaultdict(Metrics)
239
+ )
240
+ for annotation in annotations:
241
+ metrics = summaries[annotation.name][annotation.label]
242
+ metrics.record_count += 1
243
+ metrics.label_count += int(annotation.label is not None)
244
+ metrics.score_sum += annotation.score or 0
245
+ metrics.score_count += int(annotation.score is not None)
246
+
247
+ result: list[AnnotationSummary] = []
248
+ for name, label_metrics in summaries.items():
249
+ rows = [{"label": label, **asdict(metrics)} for label, metrics in label_metrics.items()]
250
+ result.append(AnnotationSummary(name=name, df=pd.DataFrame(rows), simple_avg=True))
251
+ return result
252
+
168
253
 
169
254
  def to_gql_project_session(project_session: models.ProjectSession) -> ProjectSession:
170
255
  return ProjectSession(
@@ -0,0 +1,68 @@
1
+ from typing import Optional
2
+
3
+ import strawberry
4
+ from strawberry import Private
5
+ from strawberry.relay import GlobalID, Node, NodeID
6
+ from strawberry.scalars import JSON
7
+ from strawberry.types import Info
8
+
9
+ from phoenix.db import models
10
+ from phoenix.server.api.context import Context
11
+ from phoenix.server.api.types.AnnotatorKind import AnnotatorKind
12
+
13
+ from .AnnotationSource import AnnotationSource
14
+ from .User import User, to_gql_user
15
+
16
+
17
+ @strawberry.type
18
+ class ProjectSessionAnnotation(Node):
19
+ id_attr: NodeID[int]
20
+ user_id: Private[Optional[int]]
21
+ name: str
22
+ annotator_kind: AnnotatorKind
23
+ label: Optional[str]
24
+ score: Optional[float]
25
+ explanation: Optional[str]
26
+ metadata: JSON
27
+ _project_session_id: Private[Optional[int]]
28
+ identifier: str
29
+ source: AnnotationSource
30
+
31
+ @strawberry.field
32
+ async def project_session_id(self) -> GlobalID:
33
+ from phoenix.server.api.types.ProjectSession import ProjectSession
34
+
35
+ return GlobalID(type_name=ProjectSession.__name__, node_id=str(self._project_session_id))
36
+
37
+ @strawberry.field
38
+ async def user(
39
+ self,
40
+ info: Info[Context, None],
41
+ ) -> Optional[User]:
42
+ if self.user_id is None:
43
+ return None
44
+ user = await info.context.data_loaders.users.load(self.user_id)
45
+ if user is None:
46
+ return None
47
+ return to_gql_user(user)
48
+
49
+
50
+ def to_gql_project_session_annotation(
51
+ annotation: models.ProjectSessionAnnotation,
52
+ ) -> ProjectSessionAnnotation:
53
+ """
54
+ Converts an ORM projectSession annotation to a GraphQL ProjectSessionAnnotation.
55
+ """
56
+ return ProjectSessionAnnotation(
57
+ id_attr=annotation.id,
58
+ user_id=annotation.user_id,
59
+ _project_session_id=annotation.project_session_id,
60
+ name=annotation.name,
61
+ annotator_kind=AnnotatorKind(annotation.annotator_kind),
62
+ label=annotation.label,
63
+ score=annotation.score,
64
+ explanation=annotation.explanation,
65
+ metadata=JSON(annotation.metadata_),
66
+ identifier=annotation.identifier,
67
+ source=AnnotationSource(annotation.source),
68
+ )
@@ -23,11 +23,11 @@ from phoenix.server.api.helpers.dataset_helpers import (
23
23
  get_dataset_example_input,
24
24
  get_dataset_example_output,
25
25
  )
26
- from phoenix.server.api.input_types.InvocationParameters import InvocationParameter
27
- from phoenix.server.api.input_types.SpanAnnotationFilter import (
28
- SpanAnnotationFilter,
26
+ from phoenix.server.api.input_types.AnnotationFilter import (
27
+ AnnotationFilter,
29
28
  satisfies_filter,
30
29
  )
30
+ from phoenix.server.api.input_types.InvocationParameters import InvocationParameter
31
31
  from phoenix.server.api.input_types.SpanAnnotationSort import (
32
32
  SpanAnnotationColumn,
33
33
  SpanAnnotationSort,
@@ -547,7 +547,7 @@ class Span(Node):
547
547
  self,
548
548
  info: Info[Context, None],
549
549
  sort: Optional[SpanAnnotationSort] = UNSET,
550
- filter: Optional[SpanAnnotationFilter] = None,
550
+ filter: Optional[AnnotationFilter] = None,
551
551
  ) -> list[SpanAnnotation]:
552
552
  span_id = self.span_rowid
553
553
  annotations = await info.context.data_loaders.span_annotations.load(span_id)
@@ -580,7 +580,7 @@ class Span(Node):
580
580
  async def span_annotation_summaries(
581
581
  self,
582
582
  info: Info[Context, None],
583
- filter: Optional[SpanAnnotationFilter] = None,
583
+ filter: Optional[AnnotationFilter] = None,
584
584
  ) -> list[AnnotationSummary]:
585
585
  """
586
586
  Retrieves and summarizes annotations associated with this span.
@@ -1,8 +1,11 @@
1
1
  from __future__ import annotations
2
2
 
3
+ from collections import defaultdict
4
+ from dataclasses import asdict, dataclass
3
5
  from datetime import datetime
4
6
  from typing import TYPE_CHECKING, Annotated, Optional, Union
5
7
 
8
+ import pandas as pd
6
9
  import strawberry
7
10
  from openinference.semconv.trace import SpanAttributes
8
11
  from sqlalchemy import desc, select
@@ -13,7 +16,9 @@ from typing_extensions import TypeAlias
13
16
 
14
17
  from phoenix.db import models
15
18
  from phoenix.server.api.context import Context
19
+ from phoenix.server.api.input_types.AnnotationFilter import AnnotationFilter, satisfies_filter
16
20
  from phoenix.server.api.input_types.TraceAnnotationSort import TraceAnnotationSort
21
+ from phoenix.server.api.types.AnnotationSummary import AnnotationSummary
17
22
  from phoenix.server.api.types.CostBreakdown import CostBreakdown
18
23
  from phoenix.server.api.types.pagination import (
19
24
  ConnectionArgs,
@@ -229,6 +234,62 @@ class Trace(Node):
229
234
  annotations = await session.scalars(stmt)
230
235
  return [to_gql_trace_annotation(annotation) for annotation in annotations]
231
236
 
237
+ @strawberry.field(description="Summarizes each annotation (by name) associated with the trace") # type: ignore
238
+ async def trace_annotation_summaries(
239
+ self,
240
+ info: Info[Context, None],
241
+ filter: Optional[AnnotationFilter] = None,
242
+ ) -> list[AnnotationSummary]:
243
+ """
244
+ Retrieves and summarizes annotations associated with this span.
245
+
246
+ This method aggregates annotation data by name and label, calculating metrics
247
+ such as count of occurrences and sum of scores. The results are organized
248
+ into a structured format that can be easily converted to a DataFrame.
249
+
250
+ Args:
251
+ info: GraphQL context information
252
+ filter: Optional filter to apply to annotations before processing
253
+
254
+ Returns:
255
+ A list of AnnotationSummary objects, each containing:
256
+ - name: The name of the annotation
257
+ - data: A list of dictionaries with label statistics
258
+ """
259
+ # Load all annotations for this span from the data loader
260
+ annotations = await info.context.data_loaders.trace_annotations_by_trace.load(
261
+ self.trace_rowid
262
+ )
263
+
264
+ # Apply filter if provided to narrow down the annotations
265
+ if filter:
266
+ annotations = [
267
+ annotation for annotation in annotations if satisfies_filter(annotation, filter)
268
+ ]
269
+
270
+ @dataclass
271
+ class Metrics:
272
+ record_count: int = 0
273
+ label_count: int = 0
274
+ score_sum: float = 0
275
+ score_count: int = 0
276
+
277
+ summaries: defaultdict[str, defaultdict[Optional[str], Metrics]] = defaultdict(
278
+ lambda: defaultdict(Metrics)
279
+ )
280
+ for annotation in annotations:
281
+ metrics = summaries[annotation.name][annotation.label]
282
+ metrics.record_count += 1
283
+ metrics.label_count += int(annotation.label is not None)
284
+ metrics.score_sum += annotation.score or 0
285
+ metrics.score_count += int(annotation.score is not None)
286
+
287
+ result: list[AnnotationSummary] = []
288
+ for name, label_metrics in summaries.items():
289
+ rows = [{"label": label, **asdict(metrics)} for label, metrics in label_metrics.items()]
290
+ result.append(AnnotationSummary(name=name, df=pd.DataFrame(rows), simple_avg=True))
291
+ return result
292
+
232
293
  @strawberry.field
233
294
  async def cost_summary(
234
295
  self,