arize-phoenix 11.23.1__py3-none-any.whl → 12.28.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (221) hide show
  1. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/METADATA +61 -36
  2. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/RECORD +212 -162
  3. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/WHEEL +1 -1
  4. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/IP_NOTICE +1 -1
  5. phoenix/__generated__/__init__.py +0 -0
  6. phoenix/__generated__/classification_evaluator_configs/__init__.py +20 -0
  7. phoenix/__generated__/classification_evaluator_configs/_document_relevance_classification_evaluator_config.py +17 -0
  8. phoenix/__generated__/classification_evaluator_configs/_hallucination_classification_evaluator_config.py +17 -0
  9. phoenix/__generated__/classification_evaluator_configs/_models.py +18 -0
  10. phoenix/__generated__/classification_evaluator_configs/_tool_selection_classification_evaluator_config.py +17 -0
  11. phoenix/__init__.py +2 -1
  12. phoenix/auth.py +27 -2
  13. phoenix/config.py +1594 -81
  14. phoenix/db/README.md +546 -28
  15. phoenix/db/bulk_inserter.py +119 -116
  16. phoenix/db/engines.py +140 -33
  17. phoenix/db/facilitator.py +22 -1
  18. phoenix/db/helpers.py +818 -65
  19. phoenix/db/iam_auth.py +64 -0
  20. phoenix/db/insertion/dataset.py +133 -1
  21. phoenix/db/insertion/document_annotation.py +9 -6
  22. phoenix/db/insertion/evaluation.py +2 -3
  23. phoenix/db/insertion/helpers.py +2 -2
  24. phoenix/db/insertion/session_annotation.py +176 -0
  25. phoenix/db/insertion/span_annotation.py +3 -4
  26. phoenix/db/insertion/trace_annotation.py +3 -4
  27. phoenix/db/insertion/types.py +41 -18
  28. phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
  29. phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
  30. phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
  31. phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
  32. phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
  33. phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
  34. phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
  35. phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
  36. phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
  37. phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
  38. phoenix/db/models.py +364 -56
  39. phoenix/db/pg_config.py +10 -0
  40. phoenix/db/types/trace_retention.py +7 -6
  41. phoenix/experiments/functions.py +69 -19
  42. phoenix/inferences/inferences.py +1 -2
  43. phoenix/server/api/auth.py +9 -0
  44. phoenix/server/api/auth_messages.py +46 -0
  45. phoenix/server/api/context.py +60 -0
  46. phoenix/server/api/dataloaders/__init__.py +36 -0
  47. phoenix/server/api/dataloaders/annotation_summaries.py +60 -8
  48. phoenix/server/api/dataloaders/average_experiment_repeated_run_group_latency.py +50 -0
  49. phoenix/server/api/dataloaders/average_experiment_run_latency.py +17 -24
  50. phoenix/server/api/dataloaders/cache/two_tier_cache.py +1 -2
  51. phoenix/server/api/dataloaders/dataset_dataset_splits.py +52 -0
  52. phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -1
  53. phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
  54. phoenix/server/api/dataloaders/dataset_examples_and_versions_by_experiment_run.py +47 -0
  55. phoenix/server/api/dataloaders/dataset_labels.py +36 -0
  56. phoenix/server/api/dataloaders/document_evaluation_summaries.py +2 -2
  57. phoenix/server/api/dataloaders/document_evaluations.py +6 -9
  58. phoenix/server/api/dataloaders/experiment_annotation_summaries.py +88 -34
  59. phoenix/server/api/dataloaders/experiment_dataset_splits.py +43 -0
  60. phoenix/server/api/dataloaders/experiment_error_rates.py +21 -28
  61. phoenix/server/api/dataloaders/experiment_repeated_run_group_annotation_summaries.py +77 -0
  62. phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +57 -0
  63. phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
  64. phoenix/server/api/dataloaders/latency_ms_quantile.py +40 -8
  65. phoenix/server/api/dataloaders/record_counts.py +37 -10
  66. phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
  67. phoenix/server/api/dataloaders/span_cost_summary_by_experiment_repeated_run_group.py +64 -0
  68. phoenix/server/api/dataloaders/span_cost_summary_by_project.py +28 -14
  69. phoenix/server/api/dataloaders/span_costs.py +3 -9
  70. phoenix/server/api/dataloaders/table_fields.py +2 -2
  71. phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
  72. phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
  73. phoenix/server/api/exceptions.py +5 -1
  74. phoenix/server/api/helpers/playground_clients.py +263 -83
  75. phoenix/server/api/helpers/playground_spans.py +2 -1
  76. phoenix/server/api/helpers/playground_users.py +26 -0
  77. phoenix/server/api/helpers/prompts/conversions/google.py +103 -0
  78. phoenix/server/api/helpers/prompts/models.py +61 -19
  79. phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
  80. phoenix/server/api/input_types/ChatCompletionInput.py +3 -0
  81. phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
  82. phoenix/server/api/input_types/DatasetFilter.py +5 -2
  83. phoenix/server/api/input_types/ExperimentRunSort.py +237 -0
  84. phoenix/server/api/input_types/GenerativeModelInput.py +3 -0
  85. phoenix/server/api/input_types/ProjectSessionSort.py +158 -1
  86. phoenix/server/api/input_types/PromptVersionInput.py +47 -1
  87. phoenix/server/api/input_types/SpanSort.py +3 -2
  88. phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
  89. phoenix/server/api/input_types/UserRoleInput.py +1 -0
  90. phoenix/server/api/mutations/__init__.py +8 -0
  91. phoenix/server/api/mutations/annotation_config_mutations.py +8 -8
  92. phoenix/server/api/mutations/api_key_mutations.py +15 -20
  93. phoenix/server/api/mutations/chat_mutations.py +106 -37
  94. phoenix/server/api/mutations/dataset_label_mutations.py +243 -0
  95. phoenix/server/api/mutations/dataset_mutations.py +21 -16
  96. phoenix/server/api/mutations/dataset_split_mutations.py +351 -0
  97. phoenix/server/api/mutations/experiment_mutations.py +2 -2
  98. phoenix/server/api/mutations/export_events_mutations.py +3 -3
  99. phoenix/server/api/mutations/model_mutations.py +11 -9
  100. phoenix/server/api/mutations/project_mutations.py +4 -4
  101. phoenix/server/api/mutations/project_session_annotations_mutations.py +158 -0
  102. phoenix/server/api/mutations/project_trace_retention_policy_mutations.py +8 -4
  103. phoenix/server/api/mutations/prompt_label_mutations.py +74 -65
  104. phoenix/server/api/mutations/prompt_mutations.py +65 -129
  105. phoenix/server/api/mutations/prompt_version_tag_mutations.py +11 -8
  106. phoenix/server/api/mutations/span_annotations_mutations.py +15 -10
  107. phoenix/server/api/mutations/trace_annotations_mutations.py +13 -8
  108. phoenix/server/api/mutations/trace_mutations.py +3 -3
  109. phoenix/server/api/mutations/user_mutations.py +55 -26
  110. phoenix/server/api/queries.py +501 -617
  111. phoenix/server/api/routers/__init__.py +2 -2
  112. phoenix/server/api/routers/auth.py +141 -87
  113. phoenix/server/api/routers/ldap.py +229 -0
  114. phoenix/server/api/routers/oauth2.py +349 -101
  115. phoenix/server/api/routers/v1/__init__.py +22 -4
  116. phoenix/server/api/routers/v1/annotation_configs.py +19 -30
  117. phoenix/server/api/routers/v1/annotations.py +455 -13
  118. phoenix/server/api/routers/v1/datasets.py +355 -68
  119. phoenix/server/api/routers/v1/documents.py +142 -0
  120. phoenix/server/api/routers/v1/evaluations.py +20 -28
  121. phoenix/server/api/routers/v1/experiment_evaluations.py +16 -6
  122. phoenix/server/api/routers/v1/experiment_runs.py +335 -59
  123. phoenix/server/api/routers/v1/experiments.py +475 -47
  124. phoenix/server/api/routers/v1/projects.py +16 -50
  125. phoenix/server/api/routers/v1/prompts.py +50 -39
  126. phoenix/server/api/routers/v1/sessions.py +108 -0
  127. phoenix/server/api/routers/v1/spans.py +156 -96
  128. phoenix/server/api/routers/v1/traces.py +51 -77
  129. phoenix/server/api/routers/v1/users.py +64 -24
  130. phoenix/server/api/routers/v1/utils.py +3 -7
  131. phoenix/server/api/subscriptions.py +257 -93
  132. phoenix/server/api/types/Annotation.py +90 -23
  133. phoenix/server/api/types/ApiKey.py +13 -17
  134. phoenix/server/api/types/AuthMethod.py +1 -0
  135. phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +1 -0
  136. phoenix/server/api/types/Dataset.py +199 -72
  137. phoenix/server/api/types/DatasetExample.py +88 -18
  138. phoenix/server/api/types/DatasetExperimentAnnotationSummary.py +10 -0
  139. phoenix/server/api/types/DatasetLabel.py +57 -0
  140. phoenix/server/api/types/DatasetSplit.py +98 -0
  141. phoenix/server/api/types/DatasetVersion.py +49 -4
  142. phoenix/server/api/types/DocumentAnnotation.py +212 -0
  143. phoenix/server/api/types/Experiment.py +215 -68
  144. phoenix/server/api/types/ExperimentComparison.py +3 -9
  145. phoenix/server/api/types/ExperimentRepeatedRunGroup.py +155 -0
  146. phoenix/server/api/types/ExperimentRepeatedRunGroupAnnotationSummary.py +9 -0
  147. phoenix/server/api/types/ExperimentRun.py +120 -70
  148. phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
  149. phoenix/server/api/types/GenerativeModel.py +95 -42
  150. phoenix/server/api/types/GenerativeProvider.py +1 -1
  151. phoenix/server/api/types/ModelInterface.py +7 -2
  152. phoenix/server/api/types/PlaygroundModel.py +12 -2
  153. phoenix/server/api/types/Project.py +218 -185
  154. phoenix/server/api/types/ProjectSession.py +146 -29
  155. phoenix/server/api/types/ProjectSessionAnnotation.py +187 -0
  156. phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
  157. phoenix/server/api/types/Prompt.py +119 -39
  158. phoenix/server/api/types/PromptLabel.py +42 -25
  159. phoenix/server/api/types/PromptVersion.py +11 -8
  160. phoenix/server/api/types/PromptVersionTag.py +65 -25
  161. phoenix/server/api/types/Span.py +130 -123
  162. phoenix/server/api/types/SpanAnnotation.py +189 -42
  163. phoenix/server/api/types/SystemApiKey.py +65 -1
  164. phoenix/server/api/types/Trace.py +184 -53
  165. phoenix/server/api/types/TraceAnnotation.py +149 -50
  166. phoenix/server/api/types/User.py +128 -33
  167. phoenix/server/api/types/UserApiKey.py +73 -26
  168. phoenix/server/api/types/node.py +10 -0
  169. phoenix/server/api/types/pagination.py +11 -2
  170. phoenix/server/app.py +154 -36
  171. phoenix/server/authorization.py +5 -4
  172. phoenix/server/bearer_auth.py +13 -5
  173. phoenix/server/cost_tracking/cost_model_lookup.py +42 -14
  174. phoenix/server/cost_tracking/model_cost_manifest.json +1085 -194
  175. phoenix/server/daemons/generative_model_store.py +61 -9
  176. phoenix/server/daemons/span_cost_calculator.py +10 -8
  177. phoenix/server/dml_event.py +13 -0
  178. phoenix/server/email/sender.py +29 -2
  179. phoenix/server/grpc_server.py +9 -9
  180. phoenix/server/jwt_store.py +8 -6
  181. phoenix/server/ldap.py +1449 -0
  182. phoenix/server/main.py +9 -3
  183. phoenix/server/oauth2.py +330 -12
  184. phoenix/server/prometheus.py +43 -6
  185. phoenix/server/rate_limiters.py +4 -9
  186. phoenix/server/retention.py +33 -20
  187. phoenix/server/session_filters.py +49 -0
  188. phoenix/server/static/.vite/manifest.json +51 -53
  189. phoenix/server/static/assets/components-BreFUQQa.js +6702 -0
  190. phoenix/server/static/assets/{index-BPCwGQr8.js → index-CTQoemZv.js} +42 -35
  191. phoenix/server/static/assets/pages-DBE5iYM3.js +9524 -0
  192. phoenix/server/static/assets/vendor-BGzfc4EU.css +1 -0
  193. phoenix/server/static/assets/vendor-DCE4v-Ot.js +920 -0
  194. phoenix/server/static/assets/vendor-codemirror-D5f205eT.js +25 -0
  195. phoenix/server/static/assets/{vendor-recharts-Bw30oz1A.js → vendor-recharts-V9cwpXsm.js} +7 -7
  196. phoenix/server/static/assets/{vendor-shiki-DZajAPeq.js → vendor-shiki-Do--csgv.js} +1 -1
  197. phoenix/server/static/assets/vendor-three-CmB8bl_y.js +3840 -0
  198. phoenix/server/templates/index.html +7 -1
  199. phoenix/server/thread_server.py +1 -2
  200. phoenix/server/utils.py +74 -0
  201. phoenix/session/client.py +55 -1
  202. phoenix/session/data_extractor.py +5 -0
  203. phoenix/session/evaluation.py +8 -4
  204. phoenix/session/session.py +44 -8
  205. phoenix/settings.py +2 -0
  206. phoenix/trace/attributes.py +80 -13
  207. phoenix/trace/dsl/query.py +2 -0
  208. phoenix/trace/projects.py +5 -0
  209. phoenix/utilities/template_formatters.py +1 -1
  210. phoenix/version.py +1 -1
  211. phoenix/server/api/types/Evaluation.py +0 -39
  212. phoenix/server/static/assets/components-D0DWAf0l.js +0 -5650
  213. phoenix/server/static/assets/pages-Creyamao.js +0 -8612
  214. phoenix/server/static/assets/vendor-CU36oj8y.js +0 -905
  215. phoenix/server/static/assets/vendor-CqDb5u4o.css +0 -1
  216. phoenix/server/static/assets/vendor-arizeai-Ctgw0e1G.js +0 -168
  217. phoenix/server/static/assets/vendor-codemirror-Cojjzqb9.js +0 -25
  218. phoenix/server/static/assets/vendor-three-BLWp5bic.js +0 -2998
  219. phoenix/utilities/deprecation.py +0 -31
  220. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/entry_points.txt +0 -0
  221. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/LICENSE +0 -0
@@ -1,30 +1,36 @@
1
1
  from __future__ import annotations
2
2
 
3
+ from collections import defaultdict
4
+ from dataclasses import asdict, dataclass
3
5
  from datetime import datetime
4
6
  from typing import TYPE_CHECKING, Annotated, Optional, Union
5
7
 
8
+ import pandas as pd
6
9
  import strawberry
10
+ from aioitertools.itertools import islice
7
11
  from openinference.semconv.trace import SpanAttributes
8
- from sqlalchemy import desc, select
9
- from strawberry import ID, UNSET, Private, lazy
12
+ from sqlalchemy import desc, or_, select
13
+ from strawberry import ID, UNSET, lazy
10
14
  from strawberry.relay import Connection, GlobalID, Node, NodeID
11
15
  from strawberry.types import Info
12
16
  from typing_extensions import TypeAlias
13
17
 
14
18
  from phoenix.db import models
15
19
  from phoenix.server.api.context import Context
20
+ from phoenix.server.api.input_types.AnnotationFilter import AnnotationFilter, satisfies_filter
16
21
  from phoenix.server.api.input_types.TraceAnnotationSort import TraceAnnotationSort
22
+ from phoenix.server.api.types.AnnotationSummary import AnnotationSummary
17
23
  from phoenix.server.api.types.CostBreakdown import CostBreakdown
18
24
  from phoenix.server.api.types.pagination import (
19
- ConnectionArgs,
25
+ Cursor,
20
26
  CursorString,
21
- connection_from_list,
27
+ connection_from_cursors_and_nodes,
22
28
  )
23
29
  from phoenix.server.api.types.SortDir import SortDir
24
30
  from phoenix.server.api.types.Span import Span
25
31
  from phoenix.server.api.types.SpanCostDetailSummaryEntry import SpanCostDetailSummaryEntry
26
32
  from phoenix.server.api.types.SpanCostSummary import SpanCostSummary
27
- from phoenix.server.api.types.TraceAnnotation import TraceAnnotation, to_gql_trace_annotation
33
+ from phoenix.server.api.types.TraceAnnotation import TraceAnnotation
28
34
 
29
35
  if TYPE_CHECKING:
30
36
  from phoenix.server.api.types.Project import Project
@@ -36,11 +42,11 @@ TraceRowId: TypeAlias = int
36
42
 
37
43
  @strawberry.type
38
44
  class Trace(Node):
39
- trace_rowid: NodeID[TraceRowId]
40
- db_trace: Private[models.Trace] = UNSET
45
+ id: NodeID[TraceRowId]
46
+ db_record: strawberry.Private[Optional[models.Trace]] = None
41
47
 
42
48
  def __post_init__(self) -> None:
43
- if self.db_trace and self.trace_rowid != self.db_trace.id:
49
+ if self.db_record and self.id != self.db_record.id:
44
50
  raise ValueError("Trace ID mismatch")
45
51
 
46
52
  @strawberry.field
@@ -48,11 +54,11 @@ class Trace(Node):
48
54
  self,
49
55
  info: Info[Context, None],
50
56
  ) -> ID:
51
- if self.db_trace:
52
- trace_id = self.db_trace.trace_id
57
+ if self.db_record:
58
+ trace_id = self.db_record.trace_id
53
59
  else:
54
60
  trace_id = await info.context.data_loaders.trace_fields.load(
55
- (self.trace_rowid, models.Trace.trace_id),
61
+ (self.id, models.Trace.trace_id),
56
62
  )
57
63
  return ID(trace_id)
58
64
 
@@ -61,11 +67,11 @@ class Trace(Node):
61
67
  self,
62
68
  info: Info[Context, None],
63
69
  ) -> datetime:
64
- if self.db_trace:
65
- start_time = self.db_trace.start_time
70
+ if self.db_record:
71
+ start_time = self.db_record.start_time
66
72
  else:
67
73
  start_time = await info.context.data_loaders.trace_fields.load(
68
- (self.trace_rowid, models.Trace.start_time),
74
+ (self.id, models.Trace.start_time),
69
75
  )
70
76
  return start_time
71
77
 
@@ -74,11 +80,11 @@ class Trace(Node):
74
80
  self,
75
81
  info: Info[Context, None],
76
82
  ) -> datetime:
77
- if self.db_trace:
78
- end_time = self.db_trace.end_time
83
+ if self.db_record:
84
+ end_time = self.db_record.end_time
79
85
  else:
80
86
  end_time = await info.context.data_loaders.trace_fields.load(
81
- (self.trace_rowid, models.Trace.end_time),
87
+ (self.id, models.Trace.end_time),
82
88
  )
83
89
  return end_time
84
90
 
@@ -87,11 +93,11 @@ class Trace(Node):
87
93
  self,
88
94
  info: Info[Context, None],
89
95
  ) -> Optional[float]:
90
- if self.db_trace:
91
- latency_ms = self.db_trace.latency_ms
96
+ if self.db_record:
97
+ latency_ms = self.db_record.latency_ms
92
98
  else:
93
99
  latency_ms = await info.context.data_loaders.trace_fields.load(
94
- (self.trace_rowid, models.Trace.latency_ms),
100
+ (self.id, models.Trace.latency_ms),
95
101
  )
96
102
  return latency_ms
97
103
 
@@ -100,26 +106,26 @@ class Trace(Node):
100
106
  self,
101
107
  info: Info[Context, None],
102
108
  ) -> Annotated["Project", strawberry.lazy(".Project")]:
103
- if self.db_trace:
104
- project_rowid = self.db_trace.project_rowid
109
+ if self.db_record:
110
+ project_rowid = self.db_record.project_rowid
105
111
  else:
106
112
  project_rowid = await info.context.data_loaders.trace_fields.load(
107
- (self.trace_rowid, models.Trace.project_rowid),
113
+ (self.id, models.Trace.project_rowid),
108
114
  )
109
115
  from phoenix.server.api.types.Project import Project
110
116
 
111
- return Project(project_rowid=project_rowid)
117
+ return Project(id=project_rowid)
112
118
 
113
119
  @strawberry.field
114
120
  async def project_id(
115
121
  self,
116
122
  info: Info[Context, None],
117
123
  ) -> GlobalID:
118
- if self.db_trace:
119
- project_rowid = self.db_trace.project_rowid
124
+ if self.db_record:
125
+ project_rowid = self.db_record.project_rowid
120
126
  else:
121
127
  project_rowid = await info.context.data_loaders.trace_fields.load(
122
- (self.trace_rowid, models.Trace.project_rowid),
128
+ (self.id, models.Trace.project_rowid),
123
129
  )
124
130
  from phoenix.server.api.types.Project import Project
125
131
 
@@ -130,11 +136,11 @@ class Trace(Node):
130
136
  self,
131
137
  info: Info[Context, None],
132
138
  ) -> Optional[GlobalID]:
133
- if self.db_trace:
134
- project_session_rowid = self.db_trace.project_session_rowid
139
+ if self.db_record:
140
+ project_session_rowid = self.db_record.project_session_rowid
135
141
  else:
136
142
  project_session_rowid = await info.context.data_loaders.trace_fields.load(
137
- (self.trace_rowid, models.Trace.project_session_rowid),
143
+ (self.id, models.Trace.project_session_rowid),
138
144
  )
139
145
  if project_session_rowid is None:
140
146
  return None
@@ -147,39 +153,40 @@ class Trace(Node):
147
153
  self,
148
154
  info: Info[Context, None],
149
155
  ) -> Union[Annotated["ProjectSession", lazy(".ProjectSession")], None]:
150
- if self.db_trace:
151
- project_session_rowid = self.db_trace.project_session_rowid
156
+ if self.db_record:
157
+ project_session_rowid = self.db_record.project_session_rowid
152
158
  else:
153
159
  project_session_rowid = await info.context.data_loaders.trace_fields.load(
154
- (self.trace_rowid, models.Trace.project_session_rowid),
160
+ (self.id, models.Trace.project_session_rowid),
155
161
  )
156
162
  if project_session_rowid is None:
157
163
  return None
158
- from phoenix.server.api.types.ProjectSession import to_gql_project_session
159
164
 
160
165
  stmt = select(models.ProjectSession).filter_by(id=project_session_rowid)
161
166
  async with info.context.db() as session:
162
167
  project_session = await session.scalar(stmt)
163
168
  if project_session is None:
164
169
  return None
165
- return to_gql_project_session(project_session)
170
+ from .ProjectSession import ProjectSession
171
+
172
+ return ProjectSession(id=project_session.id, db_record=project_session)
166
173
 
167
174
  @strawberry.field
168
175
  async def root_span(
169
176
  self,
170
177
  info: Info[Context, None],
171
178
  ) -> Optional[Span]:
172
- span_rowid = await info.context.data_loaders.trace_root_spans.load(self.trace_rowid)
179
+ span_rowid = await info.context.data_loaders.trace_root_spans.load(self.id)
173
180
  if span_rowid is None:
174
181
  return None
175
- return Span(span_rowid=span_rowid)
182
+ return Span(id=span_rowid)
176
183
 
177
184
  @strawberry.field
178
185
  async def num_spans(
179
186
  self,
180
187
  info: Info[Context, None],
181
188
  ) -> int:
182
- return await info.context.data_loaders.num_spans_per_trace.load(self.trace_rowid)
189
+ return await info.context.data_loaders.num_spans_per_trace.load(self.id)
183
190
 
184
191
  @strawberry.field
185
192
  async def spans(
@@ -189,26 +196,94 @@ class Trace(Node):
189
196
  last: Optional[int] = UNSET,
190
197
  after: Optional[CursorString] = UNSET,
191
198
  before: Optional[CursorString] = UNSET,
199
+ root_spans_only: Optional[bool] = UNSET,
200
+ orphan_span_as_root_span: Optional[bool] = True,
192
201
  ) -> Connection[Span]:
193
- args = ConnectionArgs(
194
- first=first,
195
- after=after if isinstance(after, CursorString) else None,
196
- last=last,
197
- before=before if isinstance(before, CursorString) else None,
198
- )
199
- stmt = (
202
+ # Validate pagination arguments
203
+ if isinstance(first, int) and first <= 0:
204
+ raise ValueError('Argument "first" must be a positive int')
205
+
206
+ # Build base query for spans in this trace
207
+ base_query = (
200
208
  select(models.Span.id)
201
209
  .join(models.Trace)
202
- .where(models.Trace.id == self.trace_rowid)
210
+ .where(models.Trace.id == self.id)
203
211
  # Sort descending because the root span tends to show up later
204
212
  # in the ingestion process.
205
213
  .order_by(desc(models.Span.id))
206
- .limit(first)
207
214
  )
215
+ # Handle cursor pagination (forward pagination only)
216
+ if after is not UNSET and after is not None:
217
+ # Type narrowing: after is guaranteed to be str at this point
218
+ assert after is not None # Type narrowing for mypy
219
+ try:
220
+ cursor = Cursor.from_string(after)
221
+ except Exception as e:
222
+ raise ValueError(f"Invalid cursor format: {after}") from e
223
+ # For descending order, "after" means we want spans with smaller IDs
224
+ # (going forward in descending order)
225
+ base_query = base_query.where(models.Span.id < cursor.rowid)
226
+ # Note: backward pagination (last/before) is not yet implemented
227
+ # as it requires more complex handling with reversed ordering
228
+ if before is not UNSET or (last is not UNSET and last is not None):
229
+ raise ValueError("Backward pagination (last/before) is not yet supported")
230
+
231
+ # Build final query based on filtering requirements
232
+ if root_spans_only:
233
+ if orphan_span_as_root_span:
234
+ # A root span is either a span with no parent_id or an orphan span
235
+ # (a span whose parent_id references a span that doesn't exist in the current trace)
236
+ # We need parent_id to check for orphan spans, so add it to the query
237
+ # and create a CTE
238
+ candidate_spans = base_query.add_columns(models.Span.parent_id).cte(
239
+ "candidate_spans"
240
+ )
241
+ # Subquery to get all span_ids that exist in this trace
242
+ parent_spans_in_trace = (
243
+ select(models.Span.span_id)
244
+ .where(models.Span.trace_rowid == self.id)
245
+ .alias("parent_spans")
246
+ )
247
+ # Filter candidates to only root spans (NULL parent_id or orphan spans)
248
+ stmt = (
249
+ select(candidate_spans.c.id)
250
+ .where(
251
+ or_(
252
+ candidate_spans.c.parent_id.is_(None),
253
+ ~select(1)
254
+ .where(candidate_spans.c.parent_id == parent_spans_in_trace.c.span_id)
255
+ .exists(),
256
+ )
257
+ )
258
+ .order_by(desc(candidate_spans.c.id))
259
+ )
260
+ else:
261
+ # Only include explicit root spans (spans with parent_id = NULL)
262
+ stmt = base_query.where(models.Span.parent_id.is_(None))
263
+ else:
264
+ # Return all spans (no root span filtering)
265
+ stmt = base_query
266
+
267
+ # Over-fetch by one to determine whether there's a next page
268
+ limit = first if isinstance(first, int) else 50
269
+ stmt = stmt.limit(limit + 1)
270
+
271
+ cursors_and_nodes = []
208
272
  async with info.context.db() as session:
209
273
  span_rowids = await session.stream_scalars(stmt)
210
- data = [Span(span_rowid=span_rowid) async for span_rowid in span_rowids]
211
- return connection_from_list(data=data, args=args)
274
+ async for span_rowid in islice(span_rowids, limit):
275
+ cursor = Cursor(rowid=span_rowid)
276
+ cursors_and_nodes.append((cursor, Span(id=span_rowid)))
277
+ has_next_page = True
278
+ try:
279
+ await span_rowids.__anext__()
280
+ except StopAsyncIteration:
281
+ has_next_page = False
282
+ return connection_from_cursors_and_nodes(
283
+ cursors_and_nodes,
284
+ has_previous_page=False,
285
+ has_next_page=has_next_page,
286
+ )
212
287
 
213
288
  @strawberry.field(description="Annotations associated with the trace.") # type: ignore
214
289
  async def trace_annotations(
@@ -217,7 +292,7 @@ class Trace(Node):
217
292
  sort: Optional[TraceAnnotationSort] = None,
218
293
  ) -> list[TraceAnnotation]:
219
294
  async with info.context.db() as session:
220
- stmt = select(models.TraceAnnotation).filter_by(trace_rowid=self.trace_rowid)
295
+ stmt = select(models.TraceAnnotation).filter_by(trace_rowid=self.id)
221
296
  if sort:
222
297
  sort_col = getattr(models.TraceAnnotation, sort.col.value)
223
298
  if sort.dir is SortDir.desc:
@@ -227,7 +302,63 @@ class Trace(Node):
227
302
  else:
228
303
  stmt = stmt.order_by(models.TraceAnnotation.created_at.desc())
229
304
  annotations = await session.scalars(stmt)
230
- return [to_gql_trace_annotation(annotation) for annotation in annotations]
305
+ return [
306
+ TraceAnnotation(id=annotation.id, db_record=annotation) for annotation in annotations
307
+ ]
308
+
309
+ @strawberry.field(description="Summarizes each annotation (by name) associated with the trace") # type: ignore
310
+ async def trace_annotation_summaries(
311
+ self,
312
+ info: Info[Context, None],
313
+ filter: Optional[AnnotationFilter] = None,
314
+ ) -> list[AnnotationSummary]:
315
+ """
316
+ Retrieves and summarizes annotations associated with this span.
317
+
318
+ This method aggregates annotation data by name and label, calculating metrics
319
+ such as count of occurrences and sum of scores. The results are organized
320
+ into a structured format that can be easily converted to a DataFrame.
321
+
322
+ Args:
323
+ info: GraphQL context information
324
+ filter: Optional filter to apply to annotations before processing
325
+
326
+ Returns:
327
+ A list of AnnotationSummary objects, each containing:
328
+ - name: The name of the annotation
329
+ - data: A list of dictionaries with label statistics
330
+ """
331
+ # Load all annotations for this span from the data loader
332
+ annotations = await info.context.data_loaders.trace_annotations_by_trace.load(self.id)
333
+
334
+ # Apply filter if provided to narrow down the annotations
335
+ if filter:
336
+ annotations = [
337
+ annotation for annotation in annotations if satisfies_filter(annotation, filter)
338
+ ]
339
+
340
+ @dataclass
341
+ class Metrics:
342
+ record_count: int = 0
343
+ label_count: int = 0
344
+ score_sum: float = 0
345
+ score_count: int = 0
346
+
347
+ summaries: defaultdict[str, defaultdict[Optional[str], Metrics]] = defaultdict(
348
+ lambda: defaultdict(Metrics)
349
+ )
350
+ for annotation in annotations:
351
+ metrics = summaries[annotation.name][annotation.label]
352
+ metrics.record_count += 1
353
+ metrics.label_count += int(annotation.label is not None)
354
+ metrics.score_sum += annotation.score or 0
355
+ metrics.score_count += int(annotation.score is not None)
356
+
357
+ result: list[AnnotationSummary] = []
358
+ for name, label_metrics in summaries.items():
359
+ rows = [{"label": label, **asdict(metrics)} for label, metrics in label_metrics.items()]
360
+ result.append(AnnotationSummary(name=name, df=pd.DataFrame(rows), simple_avg=True))
361
+ return result
231
362
 
232
363
  @strawberry.field
233
364
  async def cost_summary(
@@ -235,7 +366,7 @@ class Trace(Node):
235
366
  info: Info[Context, None],
236
367
  ) -> SpanCostSummary:
237
368
  loader = info.context.data_loaders.span_cost_summary_by_trace
238
- summary = await loader.load(self.trace_rowid)
369
+ summary = await loader.load(self.id)
239
370
  return SpanCostSummary(
240
371
  prompt=CostBreakdown(
241
372
  tokens=summary.prompt.tokens,
@@ -257,7 +388,7 @@ class Trace(Node):
257
388
  info: Info[Context, None],
258
389
  ) -> list[SpanCostDetailSummaryEntry]:
259
390
  loader = info.context.data_loaders.span_cost_detail_summary_entries_by_trace
260
- entries = await loader.load(self.trace_rowid)
391
+ entries = await loader.load(self.id)
261
392
  return [
262
393
  SpanCostDetailSummaryEntry(
263
394
  token_type=entry.token_type,
@@ -1,8 +1,8 @@
1
- from typing import Optional
1
+ from math import isfinite
2
+ from typing import TYPE_CHECKING, Annotated, Optional
2
3
 
3
4
  import strawberry
4
- from strawberry import Private
5
- from strawberry.relay import GlobalID, Node, NodeID
5
+ from strawberry.relay import Node, NodeID
6
6
  from strawberry.scalars import JSON
7
7
  from strawberry.types import Info
8
8
 
@@ -11,58 +11,157 @@ from phoenix.server.api.context import Context
11
11
  from phoenix.server.api.types.AnnotatorKind import AnnotatorKind
12
12
 
13
13
  from .AnnotationSource import AnnotationSource
14
- from .User import User, to_gql_user
14
+
15
+ if TYPE_CHECKING:
16
+ from .Trace import Trace
17
+ from .User import User
15
18
 
16
19
 
17
20
  @strawberry.type
18
21
  class TraceAnnotation(Node):
19
- id_attr: NodeID[int]
20
- user_id: Private[Optional[int]]
21
- name: str
22
- annotator_kind: AnnotatorKind
23
- label: Optional[str]
24
- score: Optional[float]
25
- explanation: Optional[str]
26
- metadata: JSON
27
- trace_rowid: Private[Optional[int]]
28
- identifier: str
29
- source: AnnotationSource
30
-
31
- @strawberry.field
32
- async def trace_id(self) -> GlobalID:
33
- from phoenix.server.api.types.Trace import Trace
34
-
35
- return GlobalID(type_name=Trace.__name__, node_id=str(self.trace_rowid))
36
-
37
- @strawberry.field
22
+ id: NodeID[int]
23
+ db_record: strawberry.Private[Optional[models.TraceAnnotation]] = None
24
+
25
+ def __post_init__(self) -> None:
26
+ if self.db_record and self.id != self.db_record.id:
27
+ raise ValueError("TraceAnnotation ID mismatch")
28
+
29
+ @strawberry.field(description="Name of the annotation, e.g. 'helpfulness' or 'relevance'.") # type: ignore
30
+ async def name(
31
+ self,
32
+ info: Info[Context, None],
33
+ ) -> str:
34
+ if self.db_record:
35
+ val = self.db_record.name
36
+ else:
37
+ val = await info.context.data_loaders.trace_annotation_fields.load(
38
+ (self.id, models.TraceAnnotation.name),
39
+ )
40
+ return val
41
+
42
+ @strawberry.field(description="The kind of annotator that produced the annotation.") # type: ignore
43
+ async def annotator_kind(
44
+ self,
45
+ info: Info[Context, None],
46
+ ) -> AnnotatorKind:
47
+ if self.db_record:
48
+ val = self.db_record.annotator_kind
49
+ else:
50
+ val = await info.context.data_loaders.trace_annotation_fields.load(
51
+ (self.id, models.TraceAnnotation.annotator_kind),
52
+ )
53
+ return AnnotatorKind(val)
54
+
55
+ @strawberry.field(
56
+ description="Value of the annotation in the form of a string, e.g. 'helpful' or 'not helpful'. Note that the label is not necessarily binary." # noqa: E501
57
+ ) # type: ignore
58
+ async def label(
59
+ self,
60
+ info: Info[Context, None],
61
+ ) -> Optional[str]:
62
+ if self.db_record:
63
+ val = self.db_record.label
64
+ else:
65
+ val = await info.context.data_loaders.trace_annotation_fields.load(
66
+ (self.id, models.TraceAnnotation.label),
67
+ )
68
+ return val
69
+
70
+ @strawberry.field(description="Value of the annotation in the form of a numeric score.") # type: ignore
71
+ async def score(
72
+ self,
73
+ info: Info[Context, None],
74
+ ) -> Optional[float]:
75
+ if self.db_record:
76
+ val = self.db_record.score
77
+ else:
78
+ val = await info.context.data_loaders.trace_annotation_fields.load(
79
+ (self.id, models.TraceAnnotation.score),
80
+ )
81
+ return val if val is not None and isfinite(val) else None
82
+
83
+ @strawberry.field(
84
+ description="The annotator's explanation for the annotation result (i.e. score or label, or both) given to the subject." # noqa: E501
85
+ ) # type: ignore
86
+ async def explanation(
87
+ self,
88
+ info: Info[Context, None],
89
+ ) -> Optional[str]:
90
+ if self.db_record:
91
+ val = self.db_record.explanation
92
+ else:
93
+ val = await info.context.data_loaders.trace_annotation_fields.load(
94
+ (self.id, models.TraceAnnotation.explanation),
95
+ )
96
+ return val
97
+
98
+ @strawberry.field(description="Metadata about the annotation.") # type: ignore
99
+ async def metadata(
100
+ self,
101
+ info: Info[Context, None],
102
+ ) -> JSON:
103
+ if self.db_record:
104
+ val = self.db_record.metadata_
105
+ else:
106
+ val = await info.context.data_loaders.trace_annotation_fields.load(
107
+ (self.id, models.TraceAnnotation.metadata_),
108
+ )
109
+ return val
110
+
111
+ @strawberry.field(description="The identifier of the annotation.") # type: ignore
112
+ async def identifier(
113
+ self,
114
+ info: Info[Context, None],
115
+ ) -> str:
116
+ if self.db_record:
117
+ val = self.db_record.identifier
118
+ else:
119
+ val = await info.context.data_loaders.trace_annotation_fields.load(
120
+ (self.id, models.TraceAnnotation.identifier),
121
+ )
122
+ return val
123
+
124
+ @strawberry.field(description="The source of the annotation.") # type: ignore
125
+ async def source(
126
+ self,
127
+ info: Info[Context, None],
128
+ ) -> AnnotationSource:
129
+ if self.db_record:
130
+ val = self.db_record.source
131
+ else:
132
+ val = await info.context.data_loaders.trace_annotation_fields.load(
133
+ (self.id, models.TraceAnnotation.source),
134
+ )
135
+ return AnnotationSource(val)
136
+
137
+ @strawberry.field(description="The trace associated with the annotation.") # type: ignore
138
+ async def trace(
139
+ self,
140
+ info: Info[Context, None],
141
+ ) -> Annotated["Trace", strawberry.lazy(".Trace")]:
142
+ if self.db_record:
143
+ trace_rowid = self.db_record.trace_rowid
144
+ else:
145
+ trace_rowid = await info.context.data_loaders.trace_annotation_fields.load(
146
+ (self.id, models.TraceAnnotation.trace_rowid),
147
+ )
148
+ from .Trace import Trace
149
+
150
+ return Trace(id=trace_rowid)
151
+
152
+ @strawberry.field(description="The user that produced the annotation.") # type: ignore
38
153
  async def user(
39
154
  self,
40
155
  info: Info[Context, None],
41
- ) -> Optional[User]:
42
- if self.user_id is None:
156
+ ) -> Optional[Annotated["User", strawberry.lazy(".User")]]:
157
+ if self.db_record:
158
+ user_id = self.db_record.user_id
159
+ else:
160
+ user_id = await info.context.data_loaders.trace_annotation_fields.load(
161
+ (self.id, models.TraceAnnotation.user_id),
162
+ )
163
+ if user_id is None:
43
164
  return None
44
- user = await info.context.data_loaders.users.load(self.user_id)
45
- if user is None:
46
- return None
47
- return to_gql_user(user)
48
-
49
-
50
- def to_gql_trace_annotation(
51
- annotation: models.TraceAnnotation,
52
- ) -> TraceAnnotation:
53
- """
54
- Converts an ORM trace annotation to a GraphQL TraceAnnotation.
55
- """
56
- return TraceAnnotation(
57
- id_attr=annotation.id,
58
- user_id=annotation.user_id,
59
- trace_rowid=annotation.trace_rowid,
60
- name=annotation.name,
61
- annotator_kind=AnnotatorKind(annotation.annotator_kind),
62
- label=annotation.label,
63
- score=annotation.score,
64
- explanation=annotation.explanation,
65
- metadata=annotation.metadata_,
66
- identifier=annotation.identifier,
67
- source=AnnotationSource(annotation.source),
68
- )
165
+ from .User import User
166
+
167
+ return User(id=user_id)