arize-phoenix 11.38.0__py3-none-any.whl → 12.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic. See the registry's advisory page for more details.

Files changed (84)
  1. {arize_phoenix-11.38.0.dist-info → arize_phoenix-12.2.0.dist-info}/METADATA +3 -3
  2. {arize_phoenix-11.38.0.dist-info → arize_phoenix-12.2.0.dist-info}/RECORD +83 -58
  3. phoenix/config.py +1 -11
  4. phoenix/db/bulk_inserter.py +8 -0
  5. phoenix/db/facilitator.py +1 -1
  6. phoenix/db/helpers.py +202 -33
  7. phoenix/db/insertion/dataset.py +7 -0
  8. phoenix/db/insertion/document_annotation.py +1 -1
  9. phoenix/db/insertion/helpers.py +2 -2
  10. phoenix/db/insertion/session_annotation.py +176 -0
  11. phoenix/db/insertion/span_annotation.py +1 -1
  12. phoenix/db/insertion/trace_annotation.py +1 -1
  13. phoenix/db/insertion/types.py +29 -3
  14. phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
  15. phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
  16. phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
  17. phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
  18. phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
  19. phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
  20. phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
  21. phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
  22. phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
  23. phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
  24. phoenix/db/models.py +306 -46
  25. phoenix/server/api/context.py +15 -2
  26. phoenix/server/api/dataloaders/__init__.py +8 -2
  27. phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
  28. phoenix/server/api/dataloaders/dataset_labels.py +36 -0
  29. phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
  30. phoenix/server/api/dataloaders/table_fields.py +2 -2
  31. phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
  32. phoenix/server/api/helpers/playground_clients.py +66 -35
  33. phoenix/server/api/helpers/playground_users.py +26 -0
  34. phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
  35. phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
  36. phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
  37. phoenix/server/api/mutations/__init__.py +8 -0
  38. phoenix/server/api/mutations/chat_mutations.py +8 -3
  39. phoenix/server/api/mutations/dataset_label_mutations.py +291 -0
  40. phoenix/server/api/mutations/dataset_mutations.py +5 -0
  41. phoenix/server/api/mutations/dataset_split_mutations.py +423 -0
  42. phoenix/server/api/mutations/project_session_annotations_mutations.py +161 -0
  43. phoenix/server/api/queries.py +53 -0
  44. phoenix/server/api/routers/auth.py +5 -5
  45. phoenix/server/api/routers/oauth2.py +5 -23
  46. phoenix/server/api/routers/v1/__init__.py +2 -0
  47. phoenix/server/api/routers/v1/annotations.py +320 -0
  48. phoenix/server/api/routers/v1/datasets.py +5 -0
  49. phoenix/server/api/routers/v1/experiments.py +10 -3
  50. phoenix/server/api/routers/v1/sessions.py +111 -0
  51. phoenix/server/api/routers/v1/traces.py +1 -2
  52. phoenix/server/api/routers/v1/users.py +7 -0
  53. phoenix/server/api/subscriptions.py +5 -2
  54. phoenix/server/api/types/Dataset.py +8 -0
  55. phoenix/server/api/types/DatasetExample.py +18 -0
  56. phoenix/server/api/types/DatasetLabel.py +23 -0
  57. phoenix/server/api/types/DatasetSplit.py +32 -0
  58. phoenix/server/api/types/Experiment.py +0 -4
  59. phoenix/server/api/types/Project.py +16 -0
  60. phoenix/server/api/types/ProjectSession.py +88 -3
  61. phoenix/server/api/types/ProjectSessionAnnotation.py +68 -0
  62. phoenix/server/api/types/Prompt.py +18 -1
  63. phoenix/server/api/types/Span.py +5 -5
  64. phoenix/server/api/types/Trace.py +61 -0
  65. phoenix/server/app.py +13 -14
  66. phoenix/server/cost_tracking/model_cost_manifest.json +132 -2
  67. phoenix/server/dml_event.py +13 -0
  68. phoenix/server/static/.vite/manifest.json +39 -39
  69. phoenix/server/static/assets/{components-BQPHTBfv.js → components-BG6v0EM8.js} +705 -385
  70. phoenix/server/static/assets/{index-BL5BMgJU.js → index-CSVcULw1.js} +13 -13
  71. phoenix/server/static/assets/{pages-C0Y17J0T.js → pages-DgaM7kpM.js} +1356 -1155
  72. phoenix/server/static/assets/{vendor-BdjZxMii.js → vendor-BqTEkGQU.js} +183 -183
  73. phoenix/server/static/assets/{vendor-arizeai-CHYlS8jV.js → vendor-arizeai-DlOj0PQQ.js} +15 -24
  74. phoenix/server/static/assets/{vendor-codemirror-Di6t4HnH.js → vendor-codemirror-B2PHH5yZ.js} +3 -3
  75. phoenix/server/static/assets/{vendor-recharts-C9wCDYj3.js → vendor-recharts-CKsi4IjN.js} +1 -1
  76. phoenix/server/static/assets/{vendor-shiki-MNnmOotP.js → vendor-shiki-DN26BkKE.js} +1 -1
  77. phoenix/server/utils.py +74 -0
  78. phoenix/session/session.py +25 -5
  79. phoenix/version.py +1 -1
  80. phoenix/server/api/dataloaders/experiment_repetition_counts.py +0 -39
  81. {arize_phoenix-11.38.0.dist-info → arize_phoenix-12.2.0.dist-info}/WHEEL +0 -0
  82. {arize_phoenix-11.38.0.dist-info → arize_phoenix-12.2.0.dist-info}/entry_points.txt +0 -0
  83. {arize_phoenix-11.38.0.dist-info → arize_phoenix-12.2.0.dist-info}/licenses/IP_NOTICE +0 -0
  84. {arize_phoenix-11.38.0.dist-info → arize_phoenix-12.2.0.dist-info}/licenses/LICENSE +0 -0
@@ -26,6 +26,7 @@ from typing_extensions import TypeAlias, assert_never
26
26
  from phoenix.config import PLAYGROUND_PROJECT_NAME
27
27
  from phoenix.datetime_utils import local_now, normalize_datetime
28
28
  from phoenix.db import models
29
+ from phoenix.db.helpers import insert_experiment_with_examples_snapshot
29
30
  from phoenix.server.api.auth import IsLocked, IsNotReadOnly
30
31
  from phoenix.server.api.context import Context
31
32
  from phoenix.server.api.exceptions import BadRequest, CustomGraphQLError, NotFound
@@ -43,6 +44,7 @@ from phoenix.server.api.helpers.playground_spans import (
43
44
  get_db_trace,
44
45
  streaming_llm_span,
45
46
  )
47
+ from phoenix.server.api.helpers.playground_users import get_user
46
48
  from phoenix.server.api.helpers.prompts.models import PromptTemplateFormat
47
49
  from phoenix.server.api.input_types.ChatCompletionInput import (
48
50
  ChatCompletionInput,
@@ -302,6 +304,7 @@ class Subscription:
302
304
  description="Traces from prompt playground",
303
305
  )
304
306
  )
307
+ user_id = get_user(info)
305
308
  experiment = models.Experiment(
306
309
  dataset_id=from_global_id_with_expected_type(input.dataset_id, Dataset.__name__),
307
310
  dataset_version_id=resolved_version_id,
@@ -311,9 +314,9 @@ class Subscription:
311
314
  repetitions=input.repetitions,
312
315
  metadata_=input.experiment_metadata or dict(),
313
316
  project_name=project_name,
317
+ user_id=user_id,
314
318
  )
315
- session.add(experiment)
316
- await session.flush()
319
+ await insert_experiment_with_examples_snapshot(session, experiment)
317
320
  yield ChatCompletionSubscriptionExperiment(
318
321
  experiment=to_gql_experiment(experiment)
319
322
  ) # eagerly yields experiment so it can be linked by consumers of the subscription
@@ -18,6 +18,7 @@ from phoenix.server.api.types.DatasetExample import DatasetExample
18
18
  from phoenix.server.api.types.DatasetExperimentAnnotationSummary import (
19
19
  DatasetExperimentAnnotationSummary,
20
20
  )
21
+ from phoenix.server.api.types.DatasetLabel import DatasetLabel, to_gql_dataset_label
21
22
  from phoenix.server.api.types.DatasetVersion import DatasetVersion
22
23
  from phoenix.server.api.types.Experiment import Experiment, to_gql_experiment
23
24
  from phoenix.server.api.types.node import from_global_id_with_expected_type
@@ -303,6 +304,13 @@ class Dataset(Node):
303
304
  async for scores_tuple in await session.stream(query)
304
305
  ]
305
306
 
307
+ @strawberry.field
308
+ async def labels(self, info: Info[Context, None]) -> list[DatasetLabel]:
309
+ return [
310
+ to_gql_dataset_label(label)
311
+ for label in await info.context.data_loaders.dataset_labels.load(self.id_attr)
312
+ ]
313
+
306
314
  @strawberry.field
307
315
  def last_updated_at(self, info: Info[Context, None]) -> Optional[datetime]:
308
316
  return info.context.last_updated_at.get(self._table, self.id_attr)
@@ -12,6 +12,7 @@ from phoenix.db import models
12
12
  from phoenix.server.api.context import Context
13
13
  from phoenix.server.api.exceptions import BadRequest
14
14
  from phoenix.server.api.types.DatasetExampleRevision import DatasetExampleRevision
15
+ from phoenix.server.api.types.DatasetSplit import DatasetSplit, to_gql_dataset_split
15
16
  from phoenix.server.api.types.DatasetVersion import DatasetVersion
16
17
  from phoenix.server.api.types.ExperimentRepeatedRunGroup import (
17
18
  ExperimentRepeatedRunGroup,
@@ -131,3 +132,20 @@ class DatasetExample(Node):
131
132
  )
132
133
  for group in repeated_run_groups
133
134
  ]
135
+
136
+ @strawberry.field
137
+ async def dataset_splits(
138
+ self,
139
+ info: Info[Context, None],
140
+ ) -> list[DatasetSplit]:
141
+ return [
142
+ to_gql_dataset_split(split)
143
+ for split in await info.context.data_loaders.dataset_example_splits.load(self.id_attr)
144
+ ]
145
+
146
+
147
+ def to_gql_dataset_example(example: models.DatasetExample) -> DatasetExample:
148
+ return DatasetExample(
149
+ id_attr=example.id,
150
+ created_at=example.created_at,
151
+ )
@@ -0,0 +1,23 @@
1
+ from typing import Optional
2
+
3
+ import strawberry
4
+ from strawberry.relay import Node, NodeID
5
+
6
+ from phoenix.db import models
7
+
8
+
9
+ @strawberry.type
10
+ class DatasetLabel(Node):
11
+ id_attr: NodeID[int]
12
+ name: str
13
+ description: Optional[str]
14
+ color: str
15
+
16
+
17
+ def to_gql_dataset_label(dataset_label: models.DatasetLabel) -> DatasetLabel:
18
+ return DatasetLabel(
19
+ id_attr=dataset_label.id,
20
+ name=dataset_label.name,
21
+ description=dataset_label.description,
22
+ color=dataset_label.color,
23
+ )
@@ -0,0 +1,32 @@
1
+ from datetime import datetime
2
+ from typing import ClassVar, Optional
3
+
4
+ import strawberry
5
+ from strawberry.relay import Node, NodeID
6
+ from strawberry.scalars import JSON
7
+
8
+ from phoenix.db import models
9
+
10
+
11
+ @strawberry.type
12
+ class DatasetSplit(Node):
13
+ _table: ClassVar[type[models.Base]] = models.DatasetSplit
14
+ id_attr: NodeID[int]
15
+ name: str
16
+ description: Optional[str]
17
+ metadata: JSON
18
+ color: str
19
+ created_at: datetime
20
+ updated_at: datetime
21
+
22
+
23
+ def to_gql_dataset_split(dataset_split: models.DatasetSplit) -> DatasetSplit:
24
+ return DatasetSplit(
25
+ id_attr=dataset_split.id,
26
+ name=dataset_split.name,
27
+ description=dataset_split.description,
28
+ color=dataset_split.color or "#ffffff",
29
+ metadata=dataset_split.metadata_,
30
+ created_at=dataset_split.created_at,
31
+ updated_at=dataset_split.updated_at,
32
+ )
@@ -193,10 +193,6 @@ class Experiment(Node):
193
193
  async for token_type, is_prompt, cost, tokens in data
194
194
  ]
195
195
 
196
- @strawberry.field
197
- async def repetition_count(self, info: Info[Context, None]) -> int:
198
- return await info.context.data_loaders.experiment_repetition_counts.load(self.id_attr)
199
-
200
196
 
201
197
  def to_gql_experiment(
202
198
  experiment: models.Experiment,
@@ -588,6 +588,22 @@ class Project(Node):
588
588
  async with info.context.db() as session:
589
589
  return list(await session.scalars(stmt))
590
590
 
591
+ @strawberry.field(
592
+ description="Names of all available annotations for sessions. "
593
+ "(The list contains no duplicates.)"
594
+ ) # type: ignore
595
+ async def session_annotation_names(
596
+ self,
597
+ info: Info[Context, None],
598
+ ) -> list[str]:
599
+ stmt = (
600
+ select(distinct(models.ProjectSessionAnnotation.name))
601
+ .join(models.ProjectSession)
602
+ .where(models.ProjectSession.project_id == self.project_rowid)
603
+ )
604
+ async with info.context.db() as session:
605
+ return list(await session.scalars(stmt))
606
+
591
607
  @strawberry.field(
592
608
  description="Names of available document evaluations.",
593
609
  ) # type: ignore
@@ -1,14 +1,19 @@
1
+ from collections import defaultdict
2
+ from dataclasses import asdict, dataclass
1
3
  from datetime import datetime
2
4
  from typing import TYPE_CHECKING, Annotated, ClassVar, Optional, Type
3
5
 
6
+ import pandas as pd
4
7
  import strawberry
5
8
  from openinference.semconv.trace import SpanAttributes
6
9
  from sqlalchemy import select
7
10
  from strawberry import UNSET, Info, Private, lazy
8
- from strawberry.relay import Connection, GlobalID, Node, NodeID
11
+ from strawberry.relay import Connection, Node, NodeID
9
12
 
10
13
  from phoenix.db import models
11
14
  from phoenix.server.api.context import Context
15
+ from phoenix.server.api.input_types.AnnotationFilter import AnnotationFilter, satisfies_filter
16
+ from phoenix.server.api.types.AnnotationSummary import AnnotationSummary
12
17
  from phoenix.server.api.types.CostBreakdown import CostBreakdown
13
18
  from phoenix.server.api.types.MimeType import MimeType
14
19
  from phoenix.server.api.types.pagination import ConnectionArgs, CursorString, connection_from_list
@@ -18,6 +23,8 @@ from phoenix.server.api.types.SpanIOValue import SpanIOValue
18
23
  from phoenix.server.api.types.TokenUsage import TokenUsage
19
24
 
20
25
  if TYPE_CHECKING:
26
+ from phoenix.server.api.types.Project import Project
27
+ from phoenix.server.api.types.ProjectSessionAnnotation import ProjectSessionAnnotation
21
28
  from phoenix.server.api.types.Trace import Trace
22
29
 
23
30
 
@@ -31,10 +38,13 @@ class ProjectSession(Node):
31
38
  end_time: datetime
32
39
 
33
40
  @strawberry.field
34
- async def project_id(self) -> GlobalID:
41
+ async def project(
42
+ self,
43
+ info: Info[Context, None],
44
+ ) -> Annotated["Project", lazy(".Project")]:
35
45
  from phoenix.server.api.types.Project import Project
36
46
 
37
- return GlobalID(type_name=Project.__name__, node_id=str(self.project_rowid))
47
+ return Project(project_rowid=self.project_rowid)
38
48
 
39
49
  @strawberry.field
40
50
  async def num_traces(
@@ -165,6 +175,81 @@ class ProjectSession(Node):
165
175
  for entry in summary
166
176
  ]
167
177
 
178
+ @strawberry.field
179
+ async def session_annotations(
180
+ self,
181
+ info: Info[Context, None],
182
+ ) -> list[Annotated["ProjectSessionAnnotation", lazy(".ProjectSessionAnnotation")]]:
183
+ """Get all annotations for this session."""
184
+ from phoenix.server.api.types.ProjectSessionAnnotation import (
185
+ to_gql_project_session_annotation,
186
+ )
187
+
188
+ stmt = select(models.ProjectSessionAnnotation).filter_by(project_session_id=self.id_attr)
189
+ async with info.context.db() as session:
190
+ annotations = await session.stream_scalars(stmt)
191
+ return [
192
+ to_gql_project_session_annotation(annotation) async for annotation in annotations
193
+ ]
194
+
195
+ @strawberry.field(
196
+ description="Summarizes each annotation (by name) associated with the session"
197
+ ) # type: ignore
198
+ async def session_annotation_summaries(
199
+ self,
200
+ info: Info[Context, None],
201
+ filter: Optional[AnnotationFilter] = None,
202
+ ) -> list[AnnotationSummary]:
203
+ """
204
+ Retrieves and summarizes annotations associated with this span.
205
+
206
+ This method aggregates annotation data by name and label, calculating metrics
207
+ such as count of occurrences and sum of scores. The results are organized
208
+ into a structured format that can be easily converted to a DataFrame.
209
+
210
+ Args:
211
+ info: GraphQL context information
212
+ filter: Optional filter to apply to annotations before processing
213
+
214
+ Returns:
215
+ A list of AnnotationSummary objects, each containing:
216
+ - name: The name of the annotation
217
+ - data: A list of dictionaries with label statistics
218
+ """
219
+ # Load all annotations for this span from the data loader
220
+ annotations = await info.context.data_loaders.session_annotations_by_session.load(
221
+ self.id_attr
222
+ )
223
+
224
+ # Apply filter if provided to narrow down the annotations
225
+ if filter:
226
+ annotations = [
227
+ annotation for annotation in annotations if satisfies_filter(annotation, filter)
228
+ ]
229
+
230
+ @dataclass
231
+ class Metrics:
232
+ record_count: int = 0
233
+ label_count: int = 0
234
+ score_sum: float = 0
235
+ score_count: int = 0
236
+
237
+ summaries: defaultdict[str, defaultdict[Optional[str], Metrics]] = defaultdict(
238
+ lambda: defaultdict(Metrics)
239
+ )
240
+ for annotation in annotations:
241
+ metrics = summaries[annotation.name][annotation.label]
242
+ metrics.record_count += 1
243
+ metrics.label_count += int(annotation.label is not None)
244
+ metrics.score_sum += annotation.score or 0
245
+ metrics.score_count += int(annotation.score is not None)
246
+
247
+ result: list[AnnotationSummary] = []
248
+ for name, label_metrics in summaries.items():
249
+ rows = [{"label": label, **asdict(metrics)} for label, metrics in label_metrics.items()]
250
+ result.append(AnnotationSummary(name=name, df=pd.DataFrame(rows), simple_avg=True))
251
+ return result
252
+
168
253
 
169
254
  def to_gql_project_session(project_session: models.ProjectSession) -> ProjectSession:
170
255
  return ProjectSession(
@@ -0,0 +1,68 @@
1
+ from typing import Optional
2
+
3
+ import strawberry
4
+ from strawberry import Private
5
+ from strawberry.relay import GlobalID, Node, NodeID
6
+ from strawberry.scalars import JSON
7
+ from strawberry.types import Info
8
+
9
+ from phoenix.db import models
10
+ from phoenix.server.api.context import Context
11
+ from phoenix.server.api.types.AnnotatorKind import AnnotatorKind
12
+
13
+ from .AnnotationSource import AnnotationSource
14
+ from .User import User, to_gql_user
15
+
16
+
17
+ @strawberry.type
18
+ class ProjectSessionAnnotation(Node):
19
+ id_attr: NodeID[int]
20
+ user_id: Private[Optional[int]]
21
+ name: str
22
+ annotator_kind: AnnotatorKind
23
+ label: Optional[str]
24
+ score: Optional[float]
25
+ explanation: Optional[str]
26
+ metadata: JSON
27
+ _project_session_id: Private[Optional[int]]
28
+ identifier: str
29
+ source: AnnotationSource
30
+
31
+ @strawberry.field
32
+ async def project_session_id(self) -> GlobalID:
33
+ from phoenix.server.api.types.ProjectSession import ProjectSession
34
+
35
+ return GlobalID(type_name=ProjectSession.__name__, node_id=str(self._project_session_id))
36
+
37
+ @strawberry.field
38
+ async def user(
39
+ self,
40
+ info: Info[Context, None],
41
+ ) -> Optional[User]:
42
+ if self.user_id is None:
43
+ return None
44
+ user = await info.context.data_loaders.users.load(self.user_id)
45
+ if user is None:
46
+ return None
47
+ return to_gql_user(user)
48
+
49
+
50
+ def to_gql_project_session_annotation(
51
+ annotation: models.ProjectSessionAnnotation,
52
+ ) -> ProjectSessionAnnotation:
53
+ """
54
+ Converts an ORM projectSession annotation to a GraphQL ProjectSessionAnnotation.
55
+ """
56
+ return ProjectSessionAnnotation(
57
+ id_attr=annotation.id,
58
+ user_id=annotation.user_id,
59
+ _project_session_id=annotation.project_session_id,
60
+ name=annotation.name,
61
+ annotator_kind=AnnotatorKind(annotation.annotator_kind),
62
+ label=annotation.label,
63
+ score=annotation.score,
64
+ explanation=annotation.explanation,
65
+ metadata=JSON(annotation.metadata_),
66
+ identifier=annotation.identifier,
67
+ source=AnnotationSource(annotation.source),
68
+ )
@@ -9,6 +9,7 @@ from strawberry.relay import Connection, GlobalID, Node, NodeID
9
9
  from strawberry.types import Info
10
10
 
11
11
  from phoenix.db import models
12
+ from phoenix.db.types.identifier import Identifier as IdentifierModel
12
13
  from phoenix.server.api.context import Context
13
14
  from phoenix.server.api.exceptions import NotFound
14
15
  from phoenix.server.api.types.Identifier import Identifier
@@ -37,7 +38,10 @@ class Prompt(Node):
37
38
 
38
39
  @strawberry.field
39
40
  async def version(
40
- self, info: Info[Context, None], version_id: Optional[GlobalID] = None
41
+ self,
42
+ info: Info[Context, None],
43
+ version_id: Optional[GlobalID] = None,
44
+ tag_name: Optional[Identifier] = None,
41
45
  ) -> PromptVersion:
42
46
  async with info.context.db() as session:
43
47
  if version_id:
@@ -50,6 +54,19 @@ class Prompt(Node):
50
54
  )
51
55
  if not version:
52
56
  raise NotFound(f"Prompt version not found: {version_id}")
57
+ elif tag_name:
58
+ try:
59
+ name = IdentifierModel(tag_name)
60
+ except ValueError:
61
+ raise NotFound(f"Prompt version tag not found: {tag_name}")
62
+ version = await session.scalar(
63
+ select(models.PromptVersion)
64
+ .where(models.PromptVersion.prompt_id == self.id_attr)
65
+ .join_from(models.PromptVersion, models.PromptVersionTag)
66
+ .where(models.PromptVersionTag.name == name)
67
+ )
68
+ if not version:
69
+ raise NotFound(f"This prompt has no associated versions by tag {tag_name}")
53
70
  else:
54
71
  stmt = (
55
72
  select(models.PromptVersion)
@@ -23,11 +23,11 @@ from phoenix.server.api.helpers.dataset_helpers import (
23
23
  get_dataset_example_input,
24
24
  get_dataset_example_output,
25
25
  )
26
- from phoenix.server.api.input_types.InvocationParameters import InvocationParameter
27
- from phoenix.server.api.input_types.SpanAnnotationFilter import (
28
- SpanAnnotationFilter,
26
+ from phoenix.server.api.input_types.AnnotationFilter import (
27
+ AnnotationFilter,
29
28
  satisfies_filter,
30
29
  )
30
+ from phoenix.server.api.input_types.InvocationParameters import InvocationParameter
31
31
  from phoenix.server.api.input_types.SpanAnnotationSort import (
32
32
  SpanAnnotationColumn,
33
33
  SpanAnnotationSort,
@@ -547,7 +547,7 @@ class Span(Node):
547
547
  self,
548
548
  info: Info[Context, None],
549
549
  sort: Optional[SpanAnnotationSort] = UNSET,
550
- filter: Optional[SpanAnnotationFilter] = None,
550
+ filter: Optional[AnnotationFilter] = None,
551
551
  ) -> list[SpanAnnotation]:
552
552
  span_id = self.span_rowid
553
553
  annotations = await info.context.data_loaders.span_annotations.load(span_id)
@@ -580,7 +580,7 @@ class Span(Node):
580
580
  async def span_annotation_summaries(
581
581
  self,
582
582
  info: Info[Context, None],
583
- filter: Optional[SpanAnnotationFilter] = None,
583
+ filter: Optional[AnnotationFilter] = None,
584
584
  ) -> list[AnnotationSummary]:
585
585
  """
586
586
  Retrieves and summarizes annotations associated with this span.
@@ -1,8 +1,11 @@
1
1
  from __future__ import annotations
2
2
 
3
+ from collections import defaultdict
4
+ from dataclasses import asdict, dataclass
3
5
  from datetime import datetime
4
6
  from typing import TYPE_CHECKING, Annotated, Optional, Union
5
7
 
8
+ import pandas as pd
6
9
  import strawberry
7
10
  from openinference.semconv.trace import SpanAttributes
8
11
  from sqlalchemy import desc, select
@@ -13,7 +16,9 @@ from typing_extensions import TypeAlias
13
16
 
14
17
  from phoenix.db import models
15
18
  from phoenix.server.api.context import Context
19
+ from phoenix.server.api.input_types.AnnotationFilter import AnnotationFilter, satisfies_filter
16
20
  from phoenix.server.api.input_types.TraceAnnotationSort import TraceAnnotationSort
21
+ from phoenix.server.api.types.AnnotationSummary import AnnotationSummary
17
22
  from phoenix.server.api.types.CostBreakdown import CostBreakdown
18
23
  from phoenix.server.api.types.pagination import (
19
24
  ConnectionArgs,
@@ -229,6 +234,62 @@ class Trace(Node):
229
234
  annotations = await session.scalars(stmt)
230
235
  return [to_gql_trace_annotation(annotation) for annotation in annotations]
231
236
 
237
+ @strawberry.field(description="Summarizes each annotation (by name) associated with the trace") # type: ignore
238
+ async def trace_annotation_summaries(
239
+ self,
240
+ info: Info[Context, None],
241
+ filter: Optional[AnnotationFilter] = None,
242
+ ) -> list[AnnotationSummary]:
243
+ """
244
+ Retrieves and summarizes annotations associated with this span.
245
+
246
+ This method aggregates annotation data by name and label, calculating metrics
247
+ such as count of occurrences and sum of scores. The results are organized
248
+ into a structured format that can be easily converted to a DataFrame.
249
+
250
+ Args:
251
+ info: GraphQL context information
252
+ filter: Optional filter to apply to annotations before processing
253
+
254
+ Returns:
255
+ A list of AnnotationSummary objects, each containing:
256
+ - name: The name of the annotation
257
+ - data: A list of dictionaries with label statistics
258
+ """
259
+ # Load all annotations for this span from the data loader
260
+ annotations = await info.context.data_loaders.trace_annotations_by_trace.load(
261
+ self.trace_rowid
262
+ )
263
+
264
+ # Apply filter if provided to narrow down the annotations
265
+ if filter:
266
+ annotations = [
267
+ annotation for annotation in annotations if satisfies_filter(annotation, filter)
268
+ ]
269
+
270
+ @dataclass
271
+ class Metrics:
272
+ record_count: int = 0
273
+ label_count: int = 0
274
+ score_sum: float = 0
275
+ score_count: int = 0
276
+
277
+ summaries: defaultdict[str, defaultdict[Optional[str], Metrics]] = defaultdict(
278
+ lambda: defaultdict(Metrics)
279
+ )
280
+ for annotation in annotations:
281
+ metrics = summaries[annotation.name][annotation.label]
282
+ metrics.record_count += 1
283
+ metrics.label_count += int(annotation.label is not None)
284
+ metrics.score_sum += annotation.score or 0
285
+ metrics.score_count += int(annotation.score is not None)
286
+
287
+ result: list[AnnotationSummary] = []
288
+ for name, label_metrics in summaries.items():
289
+ rows = [{"label": label, **asdict(metrics)} for label, metrics in label_metrics.items()]
290
+ result.append(AnnotationSummary(name=name, df=pd.DataFrame(rows), simple_avg=True))
291
+ return result
292
+
232
293
  @strawberry.field
233
294
  async def cost_summary(
234
295
  self,
phoenix/server/app.py CHANGED
@@ -67,7 +67,6 @@ from phoenix.config import (
67
67
  get_env_gql_extension_paths,
68
68
  get_env_grpc_interceptor_paths,
69
69
  get_env_host,
70
- get_env_host_root_path,
71
70
  get_env_max_spans_queue_size,
72
71
  get_env_port,
73
72
  get_env_support_email,
@@ -92,6 +91,7 @@ from phoenix.server.api.dataloaders import (
92
91
  DatasetExampleRevisionsDataLoader,
93
92
  DatasetExamplesAndVersionsByExperimentRunDataLoader,
94
93
  DatasetExampleSpansDataLoader,
94
+ DatasetExampleSplitsDataLoader,
95
95
  DocumentEvaluationsDataLoader,
96
96
  DocumentEvaluationSummaryDataLoader,
97
97
  DocumentRetrievalMetricsDataLoader,
@@ -99,7 +99,6 @@ from phoenix.server.api.dataloaders import (
99
99
  ExperimentErrorRatesDataLoader,
100
100
  ExperimentRepeatedRunGroupAnnotationSummariesDataLoader,
101
101
  ExperimentRepeatedRunGroupsDataLoader,
102
- ExperimentRepetitionCountsDataLoader,
103
102
  ExperimentRunAnnotations,
104
103
  ExperimentRunCountsDataLoader,
105
104
  ExperimentSequenceNumberDataLoader,
@@ -112,6 +111,7 @@ from phoenix.server.api.dataloaders import (
112
111
  ProjectIdsByTraceRetentionPolicyIdDataLoader,
113
112
  PromptVersionSequenceNumberDataLoader,
114
113
  RecordCountDataLoader,
114
+ SessionAnnotationsBySessionDataLoader,
115
115
  SessionIODataLoader,
116
116
  SessionNumTracesDataLoader,
117
117
  SessionNumTracesWithErrorDataLoader,
@@ -137,12 +137,14 @@ from phoenix.server.api.dataloaders import (
137
137
  SpanProjectsDataLoader,
138
138
  TableFieldsDataLoader,
139
139
  TokenCountDataLoader,
140
+ TraceAnnotationsByTraceDataLoader,
140
141
  TraceByTraceIdsDataLoader,
141
142
  TraceRetentionPolicyIdByProjectIdDataLoader,
142
143
  TraceRootSpansDataLoader,
143
144
  UserRolesDataLoader,
144
145
  UsersDataLoader,
145
146
  )
147
+ from phoenix.server.api.dataloaders.dataset_labels import DatasetLabelsDataLoader
146
148
  from phoenix.server.api.routers import (
147
149
  auth_router,
148
150
  create_embeddings_router,
@@ -173,6 +175,7 @@ from phoenix.server.types import (
173
175
  LastUpdatedAt,
174
176
  TokenStore,
175
177
  )
178
+ from phoenix.server.utils import get_root_path, prepend_root_path
176
179
  from phoenix.settings import Settings
177
180
  from phoenix.trace.fixtures import (
178
181
  TracesFixture,
@@ -281,9 +284,6 @@ class Static(StaticFiles):
281
284
  return {}
282
285
  raise e
283
286
 
284
- def _sanitize_basename(self, basename: str) -> str:
285
- return basename[:-1] if basename.endswith("/") else basename
286
-
287
287
  async def get_response(self, path: str, scope: Scope) -> Response:
288
288
  # Redirect to the oauth2 login page if basic auth is disabled and auto_login is enabled
289
289
  # TODO: this needs to be refactored to be cleaner
@@ -292,14 +292,10 @@ class Static(StaticFiles):
292
292
  and self._app_config.basic_auth_disabled
293
293
  and self._app_config.auto_login_idp_name
294
294
  ):
295
- request = Request(scope)
296
- url = URL(
297
- str(
298
- Path(get_env_host_root_path())
299
- / f"oauth2/{self._app_config.auto_login_idp_name}/login"
300
- )
295
+ redirect_path = prepend_root_path(
296
+ scope, f"oauth2/{self._app_config.auto_login_idp_name}/login"
301
297
  )
302
- url = url.include_query_params(**request.query_params)
298
+ url = URL(redirect_path).include_query_params(**Request(scope).query_params)
303
299
  return RedirectResponse(url=url)
304
300
  try:
305
301
  response = await super().get_response(path, scope)
@@ -316,7 +312,7 @@ class Static(StaticFiles):
316
312
  "min_dist": self._app_config.min_dist,
317
313
  "n_neighbors": self._app_config.n_neighbors,
318
314
  "n_samples": self._app_config.n_samples,
319
- "basename": self._sanitize_basename(request.scope.get("root_path", "")),
315
+ "basename": get_root_path(scope),
320
316
  "platform_version": phoenix_version,
321
317
  "request": request,
322
318
  "is_development": self._app_config.is_development,
@@ -715,6 +711,8 @@ def create_graphql_router(
715
711
  dataset_examples_and_versions_by_experiment_run=DatasetExamplesAndVersionsByExperimentRunDataLoader(
716
712
  db
717
713
  ),
714
+ dataset_example_splits=DatasetExampleSplitsDataLoader(db),
715
+ dataset_labels=DatasetLabelsDataLoader(db),
718
716
  document_evaluation_summaries=DocumentEvaluationSummaryDataLoader(
719
717
  db,
720
718
  cache_map=(
@@ -737,7 +735,6 @@ def create_graphql_router(
737
735
  db
738
736
  ),
739
737
  experiment_repeated_run_groups=ExperimentRepeatedRunGroupsDataLoader(db),
740
- experiment_repetition_counts=ExperimentRepetitionCountsDataLoader(db),
741
738
  experiment_run_annotations=ExperimentRunAnnotations(db),
742
739
  experiment_run_counts=ExperimentRunCountsDataLoader(db),
743
740
  experiment_sequence_number=ExperimentSequenceNumberDataLoader(db),
@@ -769,6 +766,7 @@ def create_graphql_router(
769
766
  db,
770
767
  cache_map=cache_for_dataloaders.record_count if cache_for_dataloaders else None,
771
768
  ),
769
+ session_annotations_by_session=SessionAnnotationsBySessionDataLoader(db),
772
770
  session_first_inputs=SessionIODataLoader(db, "first_input"),
773
771
  session_last_outputs=SessionIODataLoader(db, "last_output"),
774
772
  session_num_traces=SessionNumTracesDataLoader(db),
@@ -815,6 +813,7 @@ def create_graphql_router(
815
813
  db,
816
814
  cache_map=cache_for_dataloaders.token_count if cache_for_dataloaders else None,
817
815
  ),
816
+ trace_annotations_by_trace=TraceAnnotationsByTraceDataLoader(db),
818
817
  trace_by_trace_ids=TraceByTraceIdsDataLoader(db),
819
818
  trace_fields=TableFieldsDataLoader(db, models.Trace),
820
819
  trace_retention_policy_id_by_project_id=TraceRetentionPolicyIdByProjectIdDataLoader(