arize-phoenix 11.23.1__py3-none-any.whl → 12.28.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (221) hide show
  1. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/METADATA +61 -36
  2. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/RECORD +212 -162
  3. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/WHEEL +1 -1
  4. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/IP_NOTICE +1 -1
  5. phoenix/__generated__/__init__.py +0 -0
  6. phoenix/__generated__/classification_evaluator_configs/__init__.py +20 -0
  7. phoenix/__generated__/classification_evaluator_configs/_document_relevance_classification_evaluator_config.py +17 -0
  8. phoenix/__generated__/classification_evaluator_configs/_hallucination_classification_evaluator_config.py +17 -0
  9. phoenix/__generated__/classification_evaluator_configs/_models.py +18 -0
  10. phoenix/__generated__/classification_evaluator_configs/_tool_selection_classification_evaluator_config.py +17 -0
  11. phoenix/__init__.py +2 -1
  12. phoenix/auth.py +27 -2
  13. phoenix/config.py +1594 -81
  14. phoenix/db/README.md +546 -28
  15. phoenix/db/bulk_inserter.py +119 -116
  16. phoenix/db/engines.py +140 -33
  17. phoenix/db/facilitator.py +22 -1
  18. phoenix/db/helpers.py +818 -65
  19. phoenix/db/iam_auth.py +64 -0
  20. phoenix/db/insertion/dataset.py +133 -1
  21. phoenix/db/insertion/document_annotation.py +9 -6
  22. phoenix/db/insertion/evaluation.py +2 -3
  23. phoenix/db/insertion/helpers.py +2 -2
  24. phoenix/db/insertion/session_annotation.py +176 -0
  25. phoenix/db/insertion/span_annotation.py +3 -4
  26. phoenix/db/insertion/trace_annotation.py +3 -4
  27. phoenix/db/insertion/types.py +41 -18
  28. phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
  29. phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
  30. phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
  31. phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
  32. phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
  33. phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
  34. phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
  35. phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
  36. phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
  37. phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
  38. phoenix/db/models.py +364 -56
  39. phoenix/db/pg_config.py +10 -0
  40. phoenix/db/types/trace_retention.py +7 -6
  41. phoenix/experiments/functions.py +69 -19
  42. phoenix/inferences/inferences.py +1 -2
  43. phoenix/server/api/auth.py +9 -0
  44. phoenix/server/api/auth_messages.py +46 -0
  45. phoenix/server/api/context.py +60 -0
  46. phoenix/server/api/dataloaders/__init__.py +36 -0
  47. phoenix/server/api/dataloaders/annotation_summaries.py +60 -8
  48. phoenix/server/api/dataloaders/average_experiment_repeated_run_group_latency.py +50 -0
  49. phoenix/server/api/dataloaders/average_experiment_run_latency.py +17 -24
  50. phoenix/server/api/dataloaders/cache/two_tier_cache.py +1 -2
  51. phoenix/server/api/dataloaders/dataset_dataset_splits.py +52 -0
  52. phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -1
  53. phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
  54. phoenix/server/api/dataloaders/dataset_examples_and_versions_by_experiment_run.py +47 -0
  55. phoenix/server/api/dataloaders/dataset_labels.py +36 -0
  56. phoenix/server/api/dataloaders/document_evaluation_summaries.py +2 -2
  57. phoenix/server/api/dataloaders/document_evaluations.py +6 -9
  58. phoenix/server/api/dataloaders/experiment_annotation_summaries.py +88 -34
  59. phoenix/server/api/dataloaders/experiment_dataset_splits.py +43 -0
  60. phoenix/server/api/dataloaders/experiment_error_rates.py +21 -28
  61. phoenix/server/api/dataloaders/experiment_repeated_run_group_annotation_summaries.py +77 -0
  62. phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +57 -0
  63. phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
  64. phoenix/server/api/dataloaders/latency_ms_quantile.py +40 -8
  65. phoenix/server/api/dataloaders/record_counts.py +37 -10
  66. phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
  67. phoenix/server/api/dataloaders/span_cost_summary_by_experiment_repeated_run_group.py +64 -0
  68. phoenix/server/api/dataloaders/span_cost_summary_by_project.py +28 -14
  69. phoenix/server/api/dataloaders/span_costs.py +3 -9
  70. phoenix/server/api/dataloaders/table_fields.py +2 -2
  71. phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
  72. phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
  73. phoenix/server/api/exceptions.py +5 -1
  74. phoenix/server/api/helpers/playground_clients.py +263 -83
  75. phoenix/server/api/helpers/playground_spans.py +2 -1
  76. phoenix/server/api/helpers/playground_users.py +26 -0
  77. phoenix/server/api/helpers/prompts/conversions/google.py +103 -0
  78. phoenix/server/api/helpers/prompts/models.py +61 -19
  79. phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
  80. phoenix/server/api/input_types/ChatCompletionInput.py +3 -0
  81. phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
  82. phoenix/server/api/input_types/DatasetFilter.py +5 -2
  83. phoenix/server/api/input_types/ExperimentRunSort.py +237 -0
  84. phoenix/server/api/input_types/GenerativeModelInput.py +3 -0
  85. phoenix/server/api/input_types/ProjectSessionSort.py +158 -1
  86. phoenix/server/api/input_types/PromptVersionInput.py +47 -1
  87. phoenix/server/api/input_types/SpanSort.py +3 -2
  88. phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
  89. phoenix/server/api/input_types/UserRoleInput.py +1 -0
  90. phoenix/server/api/mutations/__init__.py +8 -0
  91. phoenix/server/api/mutations/annotation_config_mutations.py +8 -8
  92. phoenix/server/api/mutations/api_key_mutations.py +15 -20
  93. phoenix/server/api/mutations/chat_mutations.py +106 -37
  94. phoenix/server/api/mutations/dataset_label_mutations.py +243 -0
  95. phoenix/server/api/mutations/dataset_mutations.py +21 -16
  96. phoenix/server/api/mutations/dataset_split_mutations.py +351 -0
  97. phoenix/server/api/mutations/experiment_mutations.py +2 -2
  98. phoenix/server/api/mutations/export_events_mutations.py +3 -3
  99. phoenix/server/api/mutations/model_mutations.py +11 -9
  100. phoenix/server/api/mutations/project_mutations.py +4 -4
  101. phoenix/server/api/mutations/project_session_annotations_mutations.py +158 -0
  102. phoenix/server/api/mutations/project_trace_retention_policy_mutations.py +8 -4
  103. phoenix/server/api/mutations/prompt_label_mutations.py +74 -65
  104. phoenix/server/api/mutations/prompt_mutations.py +65 -129
  105. phoenix/server/api/mutations/prompt_version_tag_mutations.py +11 -8
  106. phoenix/server/api/mutations/span_annotations_mutations.py +15 -10
  107. phoenix/server/api/mutations/trace_annotations_mutations.py +13 -8
  108. phoenix/server/api/mutations/trace_mutations.py +3 -3
  109. phoenix/server/api/mutations/user_mutations.py +55 -26
  110. phoenix/server/api/queries.py +501 -617
  111. phoenix/server/api/routers/__init__.py +2 -2
  112. phoenix/server/api/routers/auth.py +141 -87
  113. phoenix/server/api/routers/ldap.py +229 -0
  114. phoenix/server/api/routers/oauth2.py +349 -101
  115. phoenix/server/api/routers/v1/__init__.py +22 -4
  116. phoenix/server/api/routers/v1/annotation_configs.py +19 -30
  117. phoenix/server/api/routers/v1/annotations.py +455 -13
  118. phoenix/server/api/routers/v1/datasets.py +355 -68
  119. phoenix/server/api/routers/v1/documents.py +142 -0
  120. phoenix/server/api/routers/v1/evaluations.py +20 -28
  121. phoenix/server/api/routers/v1/experiment_evaluations.py +16 -6
  122. phoenix/server/api/routers/v1/experiment_runs.py +335 -59
  123. phoenix/server/api/routers/v1/experiments.py +475 -47
  124. phoenix/server/api/routers/v1/projects.py +16 -50
  125. phoenix/server/api/routers/v1/prompts.py +50 -39
  126. phoenix/server/api/routers/v1/sessions.py +108 -0
  127. phoenix/server/api/routers/v1/spans.py +156 -96
  128. phoenix/server/api/routers/v1/traces.py +51 -77
  129. phoenix/server/api/routers/v1/users.py +64 -24
  130. phoenix/server/api/routers/v1/utils.py +3 -7
  131. phoenix/server/api/subscriptions.py +257 -93
  132. phoenix/server/api/types/Annotation.py +90 -23
  133. phoenix/server/api/types/ApiKey.py +13 -17
  134. phoenix/server/api/types/AuthMethod.py +1 -0
  135. phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +1 -0
  136. phoenix/server/api/types/Dataset.py +199 -72
  137. phoenix/server/api/types/DatasetExample.py +88 -18
  138. phoenix/server/api/types/DatasetExperimentAnnotationSummary.py +10 -0
  139. phoenix/server/api/types/DatasetLabel.py +57 -0
  140. phoenix/server/api/types/DatasetSplit.py +98 -0
  141. phoenix/server/api/types/DatasetVersion.py +49 -4
  142. phoenix/server/api/types/DocumentAnnotation.py +212 -0
  143. phoenix/server/api/types/Experiment.py +215 -68
  144. phoenix/server/api/types/ExperimentComparison.py +3 -9
  145. phoenix/server/api/types/ExperimentRepeatedRunGroup.py +155 -0
  146. phoenix/server/api/types/ExperimentRepeatedRunGroupAnnotationSummary.py +9 -0
  147. phoenix/server/api/types/ExperimentRun.py +120 -70
  148. phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
  149. phoenix/server/api/types/GenerativeModel.py +95 -42
  150. phoenix/server/api/types/GenerativeProvider.py +1 -1
  151. phoenix/server/api/types/ModelInterface.py +7 -2
  152. phoenix/server/api/types/PlaygroundModel.py +12 -2
  153. phoenix/server/api/types/Project.py +218 -185
  154. phoenix/server/api/types/ProjectSession.py +146 -29
  155. phoenix/server/api/types/ProjectSessionAnnotation.py +187 -0
  156. phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
  157. phoenix/server/api/types/Prompt.py +119 -39
  158. phoenix/server/api/types/PromptLabel.py +42 -25
  159. phoenix/server/api/types/PromptVersion.py +11 -8
  160. phoenix/server/api/types/PromptVersionTag.py +65 -25
  161. phoenix/server/api/types/Span.py +130 -123
  162. phoenix/server/api/types/SpanAnnotation.py +189 -42
  163. phoenix/server/api/types/SystemApiKey.py +65 -1
  164. phoenix/server/api/types/Trace.py +184 -53
  165. phoenix/server/api/types/TraceAnnotation.py +149 -50
  166. phoenix/server/api/types/User.py +128 -33
  167. phoenix/server/api/types/UserApiKey.py +73 -26
  168. phoenix/server/api/types/node.py +10 -0
  169. phoenix/server/api/types/pagination.py +11 -2
  170. phoenix/server/app.py +154 -36
  171. phoenix/server/authorization.py +5 -4
  172. phoenix/server/bearer_auth.py +13 -5
  173. phoenix/server/cost_tracking/cost_model_lookup.py +42 -14
  174. phoenix/server/cost_tracking/model_cost_manifest.json +1085 -194
  175. phoenix/server/daemons/generative_model_store.py +61 -9
  176. phoenix/server/daemons/span_cost_calculator.py +10 -8
  177. phoenix/server/dml_event.py +13 -0
  178. phoenix/server/email/sender.py +29 -2
  179. phoenix/server/grpc_server.py +9 -9
  180. phoenix/server/jwt_store.py +8 -6
  181. phoenix/server/ldap.py +1449 -0
  182. phoenix/server/main.py +9 -3
  183. phoenix/server/oauth2.py +330 -12
  184. phoenix/server/prometheus.py +43 -6
  185. phoenix/server/rate_limiters.py +4 -9
  186. phoenix/server/retention.py +33 -20
  187. phoenix/server/session_filters.py +49 -0
  188. phoenix/server/static/.vite/manifest.json +51 -53
  189. phoenix/server/static/assets/components-BreFUQQa.js +6702 -0
  190. phoenix/server/static/assets/{index-BPCwGQr8.js → index-CTQoemZv.js} +42 -35
  191. phoenix/server/static/assets/pages-DBE5iYM3.js +9524 -0
  192. phoenix/server/static/assets/vendor-BGzfc4EU.css +1 -0
  193. phoenix/server/static/assets/vendor-DCE4v-Ot.js +920 -0
  194. phoenix/server/static/assets/vendor-codemirror-D5f205eT.js +25 -0
  195. phoenix/server/static/assets/{vendor-recharts-Bw30oz1A.js → vendor-recharts-V9cwpXsm.js} +7 -7
  196. phoenix/server/static/assets/{vendor-shiki-DZajAPeq.js → vendor-shiki-Do--csgv.js} +1 -1
  197. phoenix/server/static/assets/vendor-three-CmB8bl_y.js +3840 -0
  198. phoenix/server/templates/index.html +7 -1
  199. phoenix/server/thread_server.py +1 -2
  200. phoenix/server/utils.py +74 -0
  201. phoenix/session/client.py +55 -1
  202. phoenix/session/data_extractor.py +5 -0
  203. phoenix/session/evaluation.py +8 -4
  204. phoenix/session/session.py +44 -8
  205. phoenix/settings.py +2 -0
  206. phoenix/trace/attributes.py +80 -13
  207. phoenix/trace/dsl/query.py +2 -0
  208. phoenix/trace/projects.py +5 -0
  209. phoenix/utilities/template_formatters.py +1 -1
  210. phoenix/version.py +1 -1
  211. phoenix/server/api/types/Evaluation.py +0 -39
  212. phoenix/server/static/assets/components-D0DWAf0l.js +0 -5650
  213. phoenix/server/static/assets/pages-Creyamao.js +0 -8612
  214. phoenix/server/static/assets/vendor-CU36oj8y.js +0 -905
  215. phoenix/server/static/assets/vendor-CqDb5u4o.css +0 -1
  216. phoenix/server/static/assets/vendor-arizeai-Ctgw0e1G.js +0 -168
  217. phoenix/server/static/assets/vendor-codemirror-Cojjzqb9.js +0 -25
  218. phoenix/server/static/assets/vendor-three-BLWp5bic.js +0 -2998
  219. phoenix/utilities/deprecation.py +0 -31
  220. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/entry_points.txt +0 -0
  221. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/LICENSE +0 -0
@@ -1,14 +1,19 @@
1
+ from collections import defaultdict
2
+ from dataclasses import asdict, dataclass
1
3
  from datetime import datetime
2
- from typing import TYPE_CHECKING, Annotated, ClassVar, Optional, Type
4
+ from typing import TYPE_CHECKING, Annotated, Optional
3
5
 
6
+ import pandas as pd
4
7
  import strawberry
5
8
  from openinference.semconv.trace import SpanAttributes
6
9
  from sqlalchemy import select
7
- from strawberry import UNSET, Info, Private, lazy
8
- from strawberry.relay import Connection, GlobalID, Node, NodeID
10
+ from strawberry import UNSET, Info, lazy
11
+ from strawberry.relay import Connection, Node, NodeID
9
12
 
10
13
  from phoenix.db import models
11
14
  from phoenix.server.api.context import Context
15
+ from phoenix.server.api.input_types.AnnotationFilter import AnnotationFilter, satisfies_filter
16
+ from phoenix.server.api.types.AnnotationSummary import AnnotationSummary
12
17
  from phoenix.server.api.types.CostBreakdown import CostBreakdown
13
18
  from phoenix.server.api.types.MimeType import MimeType
14
19
  from phoenix.server.api.types.pagination import ConnectionArgs, CursorString, connection_from_list
@@ -18,44 +23,94 @@ from phoenix.server.api.types.SpanIOValue import SpanIOValue
18
23
  from phoenix.server.api.types.TokenUsage import TokenUsage
19
24
 
20
25
  if TYPE_CHECKING:
26
+ from phoenix.server.api.types.Project import Project
27
+ from phoenix.server.api.types.ProjectSessionAnnotation import ProjectSessionAnnotation
21
28
  from phoenix.server.api.types.Trace import Trace
22
29
 
23
30
 
24
31
  @strawberry.type
25
32
  class ProjectSession(Node):
26
- _table: ClassVar[Type[models.ProjectSession]] = models.ProjectSession
27
- id_attr: NodeID[int]
28
- project_rowid: Private[int]
29
- session_id: str
30
- start_time: datetime
31
- end_time: datetime
33
+ id: NodeID[int]
34
+ db_record: strawberry.Private[Optional[models.ProjectSession]] = None
35
+
36
+ def __post_init__(self) -> None:
37
+ if self.db_record and self.id != self.db_record.id:
38
+ raise ValueError("ProjectSession ID mismatch")
32
39
 
33
40
  @strawberry.field
34
- async def project_id(self) -> GlobalID:
41
+ async def session_id(
42
+ self,
43
+ info: Info[Context, None],
44
+ ) -> str:
45
+ if self.db_record:
46
+ val = self.db_record.session_id
47
+ else:
48
+ val = await info.context.data_loaders.project_session_fields.load(
49
+ (self.id, models.ProjectSession.session_id),
50
+ )
51
+ return val
52
+
53
+ @strawberry.field
54
+ async def start_time(
55
+ self,
56
+ info: Info[Context, None],
57
+ ) -> datetime:
58
+ if self.db_record:
59
+ val = self.db_record.start_time
60
+ else:
61
+ val = await info.context.data_loaders.project_session_fields.load(
62
+ (self.id, models.ProjectSession.start_time),
63
+ )
64
+ return val
65
+
66
+ @strawberry.field
67
+ async def end_time(
68
+ self,
69
+ info: Info[Context, None],
70
+ ) -> datetime:
71
+ if self.db_record:
72
+ val = self.db_record.end_time
73
+ else:
74
+ val = await info.context.data_loaders.project_session_fields.load(
75
+ (self.id, models.ProjectSession.end_time),
76
+ )
77
+ return val
78
+
79
+ @strawberry.field
80
+ async def project(
81
+ self,
82
+ info: Info[Context, None],
83
+ ) -> Annotated["Project", lazy(".Project")]:
35
84
  from phoenix.server.api.types.Project import Project
36
85
 
37
- return GlobalID(type_name=Project.__name__, node_id=str(self.project_rowid))
86
+ if self.db_record:
87
+ project_rowid = self.db_record.project_id
88
+ else:
89
+ project_rowid = await info.context.data_loaders.project_session_fields.load(
90
+ (self.id, models.ProjectSession.project_id),
91
+ )
92
+ return Project(id=project_rowid)
38
93
 
39
94
  @strawberry.field
40
95
  async def num_traces(
41
96
  self,
42
97
  info: Info[Context, None],
43
98
  ) -> int:
44
- return await info.context.data_loaders.session_num_traces.load(self.id_attr)
99
+ return await info.context.data_loaders.session_num_traces.load(self.id)
45
100
 
46
101
  @strawberry.field
47
102
  async def num_traces_with_error(
48
103
  self,
49
104
  info: Info[Context, None],
50
105
  ) -> int:
51
- return await info.context.data_loaders.session_num_traces_with_error.load(self.id_attr)
106
+ return await info.context.data_loaders.session_num_traces_with_error.load(self.id)
52
107
 
53
108
  @strawberry.field
54
109
  async def first_input(
55
110
  self,
56
111
  info: Info[Context, None],
57
112
  ) -> Optional[SpanIOValue]:
58
- record = await info.context.data_loaders.session_first_inputs.load(self.id_attr)
113
+ record = await info.context.data_loaders.session_first_inputs.load(self.id)
59
114
  if record is None:
60
115
  return None
61
116
  return SpanIOValue(
@@ -68,7 +123,7 @@ class ProjectSession(Node):
68
123
  self,
69
124
  info: Info[Context, None],
70
125
  ) -> Optional[SpanIOValue]:
71
- record = await info.context.data_loaders.session_last_outputs.load(self.id_attr)
126
+ record = await info.context.data_loaders.session_last_outputs.load(self.id)
72
127
  if record is None:
73
128
  return None
74
129
  return SpanIOValue(
@@ -81,7 +136,7 @@ class ProjectSession(Node):
81
136
  self,
82
137
  info: Info[Context, None],
83
138
  ) -> TokenUsage:
84
- usage = await info.context.data_loaders.session_token_usages.load(self.id_attr)
139
+ usage = await info.context.data_loaders.session_token_usages.load(self.id)
85
140
  return TokenUsage(
86
141
  prompt=usage.prompt,
87
142
  completion=usage.completion,
@@ -106,12 +161,12 @@ class ProjectSession(Node):
106
161
  )
107
162
  stmt = (
108
163
  select(models.Trace)
109
- .filter_by(project_session_rowid=self.id_attr)
164
+ .filter_by(project_session_rowid=self.id)
110
165
  .order_by(models.Trace.start_time)
111
166
  )
112
167
  async with info.context.db() as session:
113
168
  traces = await session.stream_scalars(stmt)
114
- data = [Trace(trace_rowid=trace.id, db_trace=trace) async for trace in traces]
169
+ data = [Trace(id=trace.id, db_record=trace) async for trace in traces]
115
170
  return connection_from_list(data=data, args=args)
116
171
 
117
172
  @strawberry.field
@@ -121,7 +176,7 @@ class ProjectSession(Node):
121
176
  probability: float,
122
177
  ) -> Optional[float]:
123
178
  return await info.context.data_loaders.session_trace_latency_ms_quantile.load(
124
- (self.id_attr, probability)
179
+ (self.id, probability)
125
180
  )
126
181
 
127
182
  @strawberry.field
@@ -130,7 +185,7 @@ class ProjectSession(Node):
130
185
  info: Info[Context, None],
131
186
  ) -> SpanCostSummary:
132
187
  loader = info.context.data_loaders.span_cost_summary_by_project_session
133
- summary = await loader.load(self.id_attr)
188
+ summary = await loader.load(self.id)
134
189
  return SpanCostSummary(
135
190
  prompt=CostBreakdown(
136
191
  tokens=summary.prompt.tokens,
@@ -152,7 +207,7 @@ class ProjectSession(Node):
152
207
  info: Info[Context, None],
153
208
  ) -> list[SpanCostDetailSummaryEntry]:
154
209
  loader = info.context.data_loaders.span_cost_detail_summary_entries_by_project_session
155
- summary = await loader.load(self.id_attr)
210
+ summary = await loader.load(self.id)
156
211
  return [
157
212
  SpanCostDetailSummaryEntry(
158
213
  token_type=entry.token_type,
@@ -165,15 +220,77 @@ class ProjectSession(Node):
165
220
  for entry in summary
166
221
  ]
167
222
 
223
+ @strawberry.field
224
+ async def session_annotations(
225
+ self,
226
+ info: Info[Context, None],
227
+ ) -> list[Annotated["ProjectSessionAnnotation", lazy(".ProjectSessionAnnotation")]]:
228
+ """Get all annotations for this session."""
229
+ from .ProjectSessionAnnotation import ProjectSessionAnnotation
230
+
231
+ stmt = select(models.ProjectSessionAnnotation).filter_by(project_session_id=self.id)
232
+ async with info.context.db() as session:
233
+ annotations = await session.stream_scalars(stmt)
234
+ return [
235
+ ProjectSessionAnnotation(id=annotation.id, db_record=annotation)
236
+ async for annotation in annotations
237
+ ]
238
+
239
+ @strawberry.field(
240
+ description="Summarizes each annotation (by name) associated with the session"
241
+ ) # type: ignore
242
+ async def session_annotation_summaries(
243
+ self,
244
+ info: Info[Context, None],
245
+ filter: Optional[AnnotationFilter] = None,
246
+ ) -> list[AnnotationSummary]:
247
+ """
248
+ Retrieves and summarizes annotations associated with this session.
249
+
250
+ This method aggregates annotation data by name and label, calculating metrics
251
+ such as count of occurrences and sum of scores. The results are organized
252
+ into a structured format that can be easily converted to a DataFrame.
253
+
254
+ Args:
255
+ info: GraphQL context information
256
+ filter: Optional filter to apply to annotations before processing
257
+
258
+ Returns:
259
+ A list of AnnotationSummary objects, each containing:
260
+ - name: The name of the annotation
261
+ - data: A list of dictionaries with label statistics
262
+ """
263
+ # Load all annotations for this session from the data loader
264
+ annotations = await info.context.data_loaders.session_annotations_by_session.load(self.id)
265
+
266
+ # Apply filter if provided to narrow down the annotations
267
+ if filter:
268
+ annotations = [
269
+ annotation for annotation in annotations if satisfies_filter(annotation, filter)
270
+ ]
271
+
272
+ @dataclass
273
+ class Metrics:
274
+ record_count: int = 0
275
+ label_count: int = 0
276
+ score_sum: float = 0
277
+ score_count: int = 0
278
+
279
+ summaries: defaultdict[str, defaultdict[Optional[str], Metrics]] = defaultdict(
280
+ lambda: defaultdict(Metrics)
281
+ )
282
+ for annotation in annotations:
283
+ metrics = summaries[annotation.name][annotation.label]
284
+ metrics.record_count += 1
285
+ metrics.label_count += int(annotation.label is not None)
286
+ metrics.score_sum += annotation.score or 0
287
+ metrics.score_count += int(annotation.score is not None)
168
288
 
169
- def to_gql_project_session(project_session: models.ProjectSession) -> ProjectSession:
170
- return ProjectSession(
171
- id_attr=project_session.id,
172
- session_id=project_session.session_id,
173
- start_time=project_session.start_time,
174
- project_rowid=project_session.project_id,
175
- end_time=project_session.end_time,
176
- )
289
+ result: list[AnnotationSummary] = []
290
+ for name, label_metrics in summaries.items():
291
+ rows = [{"label": label, **asdict(metrics)} for label, metrics in label_metrics.items()]
292
+ result.append(AnnotationSummary(name=name, df=pd.DataFrame(rows), simple_avg=True))
293
+ return result
177
294
 
178
295
 
179
296
  INPUT_VALUE = SpanAttributes.INPUT_VALUE.split(".")
@@ -0,0 +1,187 @@
1
+ from math import isfinite
2
+ from typing import TYPE_CHECKING, Annotated, Optional
3
+
4
+ import strawberry
5
+ from strawberry.relay import GlobalID, Node, NodeID
6
+ from strawberry.scalars import JSON
7
+ from strawberry.types import Info
8
+
9
+ from phoenix.db import models
10
+ from phoenix.server.api.context import Context
11
+ from phoenix.server.api.types.AnnotatorKind import AnnotatorKind
12
+
13
+ from .Annotation import Annotation
14
+ from .AnnotationSource import AnnotationSource
15
+
16
+ if TYPE_CHECKING:
17
+ from .ProjectSession import ProjectSession
18
+ from .User import User
19
+
20
+
21
+ @strawberry.type
22
+ class ProjectSessionAnnotation(Node, Annotation):
23
+ id: NodeID[int]
24
+ db_record: strawberry.Private[Optional[models.ProjectSessionAnnotation]] = None
25
+
26
+ def __post_init__(self) -> None:
27
+ if self.db_record and self.id != self.db_record.id:
28
+ raise ValueError("ProjectSessionAnnotation ID mismatch")
29
+
30
+ @strawberry.field(description="Name of the annotation, e.g. 'helpfulness' or 'relevance'.") # type: ignore
31
+ async def name(
32
+ self,
33
+ info: Info[Context, None],
34
+ ) -> str:
35
+ if self.db_record:
36
+ val = self.db_record.name
37
+ else:
38
+ val = await info.context.data_loaders.project_session_annotation_fields.load(
39
+ (self.id, models.ProjectSessionAnnotation.name),
40
+ )
41
+ return val
42
+
43
+ @strawberry.field(description="The kind of annotator that produced the annotation.") # type: ignore
44
+ async def annotator_kind(
45
+ self,
46
+ info: Info[Context, None],
47
+ ) -> AnnotatorKind:
48
+ if self.db_record:
49
+ val = self.db_record.annotator_kind
50
+ else:
51
+ val = await info.context.data_loaders.project_session_annotation_fields.load(
52
+ (self.id, models.ProjectSessionAnnotation.annotator_kind),
53
+ )
54
+ return AnnotatorKind(val)
55
+
56
+ @strawberry.field(
57
+ description="Value of the annotation in the form of a string, e.g. 'helpful' or 'not helpful'. Note that the label is not necessarily binary." # noqa: E501
58
+ ) # type: ignore
59
+ async def label(
60
+ self,
61
+ info: Info[Context, None],
62
+ ) -> Optional[str]:
63
+ if self.db_record:
64
+ val = self.db_record.label
65
+ else:
66
+ val = await info.context.data_loaders.project_session_annotation_fields.load(
67
+ (self.id, models.ProjectSessionAnnotation.label),
68
+ )
69
+ return val
70
+
71
+ @strawberry.field(description="Value of the annotation in the form of a numeric score.") # type: ignore
72
+ async def score(
73
+ self,
74
+ info: Info[Context, None],
75
+ ) -> Optional[float]:
76
+ if self.db_record:
77
+ val = self.db_record.score
78
+ else:
79
+ val = await info.context.data_loaders.project_session_annotation_fields.load(
80
+ (self.id, models.ProjectSessionAnnotation.score),
81
+ )
82
+ return val if val is not None and isfinite(val) else None
83
+
84
+ @strawberry.field(
85
+ description="The annotator's explanation for the annotation result (i.e. score or label, or both) given to the subject." # noqa: E501
86
+ ) # type: ignore
87
+ async def explanation(
88
+ self,
89
+ info: Info[Context, None],
90
+ ) -> Optional[str]:
91
+ if self.db_record:
92
+ val = self.db_record.explanation
93
+ else:
94
+ val = await info.context.data_loaders.project_session_annotation_fields.load(
95
+ (self.id, models.ProjectSessionAnnotation.explanation),
96
+ )
97
+ return val
98
+
99
+ @strawberry.field(description="Metadata about the annotation.") # type: ignore
100
+ async def metadata(
101
+ self,
102
+ info: Info[Context, None],
103
+ ) -> JSON:
104
+ if self.db_record:
105
+ val = self.db_record.metadata_
106
+ else:
107
+ val = await info.context.data_loaders.project_session_annotation_fields.load(
108
+ (self.id, models.ProjectSessionAnnotation.metadata_),
109
+ )
110
+ return val
111
+
112
+ @strawberry.field(description="The identifier of the annotation.") # type: ignore
113
+ async def identifier(
114
+ self,
115
+ info: Info[Context, None],
116
+ ) -> str:
117
+ if self.db_record:
118
+ val = self.db_record.identifier
119
+ else:
120
+ val = await info.context.data_loaders.project_session_annotation_fields.load(
121
+ (self.id, models.ProjectSessionAnnotation.identifier),
122
+ )
123
+ return val
124
+
125
+ @strawberry.field(description="The source of the annotation.") # type: ignore
126
+ async def source(
127
+ self,
128
+ info: Info[Context, None],
129
+ ) -> AnnotationSource:
130
+ if self.db_record:
131
+ val = self.db_record.source
132
+ else:
133
+ val = await info.context.data_loaders.project_session_annotation_fields.load(
134
+ (self.id, models.ProjectSessionAnnotation.source),
135
+ )
136
+ return AnnotationSource(val)
137
+
138
+ @strawberry.field(description="The project session associated with the annotation.") # type: ignore
139
+ async def project_session_id(
140
+ self,
141
+ info: Info[Context, None],
142
+ ) -> GlobalID:
143
+ from phoenix.server.api.types.ProjectSession import ProjectSession
144
+
145
+ if self.db_record:
146
+ project_session_id = self.db_record.project_session_id
147
+ else:
148
+ project_session_id = (
149
+ await info.context.data_loaders.project_session_annotation_fields.load(
150
+ (self.id, models.ProjectSessionAnnotation.project_session_id),
151
+ )
152
+ )
153
+ return GlobalID(type_name=ProjectSession.__name__, node_id=str(project_session_id))
154
+
155
+ @strawberry.field(description="The project session associated with the annotation.") # type: ignore
156
+ async def project_session(
157
+ self,
158
+ info: Info[Context, None],
159
+ ) -> Annotated["ProjectSession", strawberry.lazy(".ProjectSession")]:
160
+ if self.db_record:
161
+ project_session_id = self.db_record.project_session_id
162
+ else:
163
+ project_session_id = (
164
+ await info.context.data_loaders.project_session_annotation_fields.load(
165
+ (self.id, models.ProjectSessionAnnotation.project_session_id),
166
+ )
167
+ )
168
+ from .ProjectSession import ProjectSession
169
+
170
+ return ProjectSession(id=project_session_id)
171
+
172
+ @strawberry.field(description="The user that produced the annotation.") # type: ignore
173
+ async def user(
174
+ self,
175
+ info: Info[Context, None],
176
+ ) -> Optional[Annotated["User", strawberry.lazy(".User")]]:
177
+ if self.db_record:
178
+ user_id = self.db_record.user_id
179
+ else:
180
+ user_id = await info.context.data_loaders.project_session_annotation_fields.load(
181
+ (self.id, models.ProjectSessionAnnotation.user_id),
182
+ )
183
+ if user_id is None:
184
+ return None
185
+ from .User import User
186
+
187
+ return User(id=user_id)
@@ -106,5 +106,5 @@ class ProjectTraceRetentionPolicy(Node):
106
106
  project_rowids = await info.context.data_loaders.projects_by_trace_retention_policy_id.load(
107
107
  self.id
108
108
  )
109
- data = [Project(project_rowid=project_rowid) for project_rowid in project_rowids]
109
+ data = [Project(id=project_rowid) for project_rowid in project_rowids]
110
110
  return connection_from_list(data=data, args=args)
@@ -6,9 +6,11 @@ import strawberry
6
6
  from sqlalchemy import func, select
7
7
  from strawberry import UNSET
8
8
  from strawberry.relay import Connection, GlobalID, Node, NodeID
9
+ from strawberry.scalars import JSON
9
10
  from strawberry.types import Info
10
11
 
11
12
  from phoenix.db import models
13
+ from phoenix.db.types.identifier import Identifier as IdentifierModel
12
14
  from phoenix.server.api.context import Context
13
15
  from phoenix.server.api.exceptions import NotFound
14
16
  from phoenix.server.api.types.Identifier import Identifier
@@ -19,24 +21,96 @@ from phoenix.server.api.types.pagination import (
19
21
  connection_from_list,
20
22
  )
21
23
 
24
+ from .PromptLabel import PromptLabel
22
25
  from .PromptVersion import (
23
26
  PromptVersion,
24
27
  to_gql_prompt_version,
25
28
  )
26
- from .PromptVersionTag import PromptVersionTag, to_gql_prompt_version_tag
29
+ from .PromptVersionTag import PromptVersionTag
27
30
 
28
31
 
29
32
  @strawberry.type
30
33
  class Prompt(Node):
31
- id_attr: NodeID[int]
32
- source_prompt_id: Optional[GlobalID]
33
- name: Identifier
34
- description: Optional[str]
35
- created_at: datetime
34
+ id: NodeID[int]
35
+ db_record: strawberry.Private[Optional[models.Prompt]] = None
36
+
37
+ def __post_init__(self) -> None:
38
+ if self.db_record and self.id != self.db_record.id:
39
+ raise ValueError("Prompt ID mismatch")
40
+
41
+ @strawberry.field
42
+ async def source_prompt_id(
43
+ self,
44
+ info: Info[Context, None],
45
+ ) -> Optional[GlobalID]:
46
+ if self.db_record:
47
+ source_id = self.db_record.source_prompt_id
48
+ else:
49
+ source_id = await info.context.data_loaders.prompt_fields.load(
50
+ (self.id, models.Prompt.source_prompt_id),
51
+ )
52
+ if not source_id:
53
+ return None
54
+ return GlobalID(Prompt.__name__, str(source_id))
55
+
56
+ @strawberry.field
57
+ async def name(
58
+ self,
59
+ info: Info[Context, None],
60
+ ) -> Identifier:
61
+ if self.db_record:
62
+ val = self.db_record.name
63
+ else:
64
+ val = await info.context.data_loaders.prompt_fields.load(
65
+ (self.id, models.Prompt.name),
66
+ )
67
+ return Identifier(val.root)
68
+
69
+ @strawberry.field
70
+ async def description(
71
+ self,
72
+ info: Info[Context, None],
73
+ ) -> Optional[str]:
74
+ if self.db_record:
75
+ val = self.db_record.description
76
+ else:
77
+ val = await info.context.data_loaders.prompt_fields.load(
78
+ (self.id, models.Prompt.description),
79
+ )
80
+ return val
81
+
82
+ @strawberry.field
83
+ async def metadata(
84
+ self,
85
+ info: Info[Context, None],
86
+ ) -> JSON:
87
+ if self.db_record:
88
+ val = self.db_record.metadata_
89
+ else:
90
+ val = await info.context.data_loaders.prompt_fields.load(
91
+ (self.id, models.Prompt.metadata_),
92
+ )
93
+ return val
94
+
95
+ @strawberry.field
96
+ async def created_at(
97
+ self,
98
+ info: Info[Context, None],
99
+ ) -> datetime:
100
+ if self.db_record:
101
+ val = self.db_record.created_at
102
+ else:
103
+ val = await info.context.data_loaders.prompt_fields.load(
104
+ (self.id, models.Prompt.created_at),
105
+ )
106
+ return val
36
107
 
37
108
  @strawberry.field
38
109
  async def version(
39
- self, info: Info[Context, None], version_id: Optional[GlobalID] = None
110
+ self,
111
+ info: Info[Context, None],
112
+ version_id: Optional[GlobalID] = None,
113
+ tag_name: Optional[Identifier] = None,
40
114
  ) -> PromptVersion:
41
115
  async with info.context.db() as session:
42
116
  if version_id:
@@ -44,15 +118,28 @@ class Prompt(Node):
44
118
  version = await session.scalar(
45
119
  select(models.PromptVersion).where(
46
120
  models.PromptVersion.id == v_id,
47
- models.PromptVersion.prompt_id == self.id_attr,
121
+ models.PromptVersion.prompt_id == self.id,
48
122
  )
49
123
  )
50
124
  if not version:
51
125
  raise NotFound(f"Prompt version not found: {version_id}")
126
+ elif tag_name:
127
+ try:
128
+ name = IdentifierModel(tag_name)
129
+ except ValueError:
130
+ raise NotFound(f"Prompt version tag not found: {tag_name}")
131
+ version = await session.scalar(
132
+ select(models.PromptVersion)
133
+ .where(models.PromptVersion.prompt_id == self.id)
134
+ .join_from(models.PromptVersion, models.PromptVersionTag)
135
+ .where(models.PromptVersionTag.name == name)
136
+ )
137
+ if not version:
138
+ raise NotFound(f"This prompt has no associated versions by tag {tag_name}")
52
139
  else:
53
140
  stmt = (
54
141
  select(models.PromptVersion)
55
- .where(models.PromptVersion.prompt_id == self.id_attr)
142
+ .where(models.PromptVersion.prompt_id == self.id)
56
143
  .order_by(models.PromptVersion.id.desc())
57
144
  .limit(1)
58
145
  )
@@ -65,10 +152,11 @@ class Prompt(Node):
65
152
  async def version_tags(self, info: Info[Context, None]) -> list[PromptVersionTag]:
66
153
  async with info.context.db() as session:
67
154
  stmt = select(models.PromptVersionTag).where(
68
- models.PromptVersionTag.prompt_id == self.id_attr
155
+ models.PromptVersionTag.prompt_id == self.id
69
156
  )
70
157
  return [
71
- to_gql_prompt_version_tag(tag) async for tag in await session.stream_scalars(stmt)
158
+ PromptVersionTag(id=tag.id, db_record=tag)
159
+ async for tag in await session.stream_scalars(stmt)
72
160
  ]
73
161
 
74
162
  @strawberry.field
@@ -89,7 +177,7 @@ class Prompt(Node):
89
177
  row_number = func.row_number().over(order_by=models.PromptVersion.id).label("row_number")
90
178
  stmt = (
91
179
  select(models.PromptVersion, row_number)
92
- .where(models.PromptVersion.prompt_id == self.id_attr)
180
+ .where(models.PromptVersion.prompt_id == self.id)
93
181
  .order_by(models.PromptVersion.id.desc())
94
182
  )
95
183
  async with info.context.db() as session:
@@ -101,34 +189,26 @@ class Prompt(Node):
101
189
 
102
190
  @strawberry.field
103
191
  async def source_prompt(self, info: Info[Context, None]) -> Optional["Prompt"]:
104
- if not self.source_prompt_id:
192
+ if self.db_record:
193
+ id_ = self.db_record.source_prompt_id
194
+ else:
195
+ id_ = await info.context.data_loaders.prompt_fields.load(
196
+ (self.id, models.Prompt.source_prompt_id),
197
+ )
198
+ if not id_:
105
199
  return None
200
+ async with info.context.db() as session:
201
+ source_prompt = await session.get(models.Prompt, id_)
202
+ if not source_prompt:
203
+ raise NotFound(f"Source prompt not found: {id_}")
204
+ return Prompt(id=source_prompt.id, db_record=source_prompt)
106
205
 
107
- source_prompt_id = from_global_id_with_expected_type(
108
- global_id=self.source_prompt_id, expected_type_name=Prompt.__name__
109
- )
110
-
206
+ @strawberry.field
207
+ async def labels(self, info: Info[Context, None]) -> list["PromptLabel"]:
111
208
  async with info.context.db() as session:
112
- source_prompt = await session.scalar(
113
- select(models.Prompt).where(models.Prompt.id == source_prompt_id)
209
+ labels = await session.scalars(
210
+ select(models.PromptLabel)
211
+ .join(models.PromptPromptLabel)
212
+ .where(models.PromptPromptLabel.prompt_id == self.id)
114
213
  )
115
- if not source_prompt:
116
- raise NotFound(f"Source prompt not found: {self.source_prompt_id}")
117
- return to_gql_prompt_from_orm(source_prompt)
118
-
119
-
120
- def to_gql_prompt_from_orm(orm_model: "models.Prompt") -> Prompt:
121
- if not orm_model.source_prompt_id:
122
- source_prompt_gid = None
123
- else:
124
- source_prompt_gid = GlobalID(
125
- Prompt.__name__,
126
- str(orm_model.source_prompt_id),
127
- )
128
- return Prompt(
129
- id_attr=orm_model.id,
130
- source_prompt_id=source_prompt_gid,
131
- name=Identifier(orm_model.name.root),
132
- description=orm_model.description,
133
- created_at=orm_model.created_at,
134
- )
214
+ return [PromptLabel(id=label.id, db_record=label) for label in labels]