arize-phoenix 11.23.1__py3-none-any.whl → 12.28.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (221)
  1. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/METADATA +61 -36
  2. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/RECORD +212 -162
  3. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/WHEEL +1 -1
  4. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/IP_NOTICE +1 -1
  5. phoenix/__generated__/__init__.py +0 -0
  6. phoenix/__generated__/classification_evaluator_configs/__init__.py +20 -0
  7. phoenix/__generated__/classification_evaluator_configs/_document_relevance_classification_evaluator_config.py +17 -0
  8. phoenix/__generated__/classification_evaluator_configs/_hallucination_classification_evaluator_config.py +17 -0
  9. phoenix/__generated__/classification_evaluator_configs/_models.py +18 -0
  10. phoenix/__generated__/classification_evaluator_configs/_tool_selection_classification_evaluator_config.py +17 -0
  11. phoenix/__init__.py +2 -1
  12. phoenix/auth.py +27 -2
  13. phoenix/config.py +1594 -81
  14. phoenix/db/README.md +546 -28
  15. phoenix/db/bulk_inserter.py +119 -116
  16. phoenix/db/engines.py +140 -33
  17. phoenix/db/facilitator.py +22 -1
  18. phoenix/db/helpers.py +818 -65
  19. phoenix/db/iam_auth.py +64 -0
  20. phoenix/db/insertion/dataset.py +133 -1
  21. phoenix/db/insertion/document_annotation.py +9 -6
  22. phoenix/db/insertion/evaluation.py +2 -3
  23. phoenix/db/insertion/helpers.py +2 -2
  24. phoenix/db/insertion/session_annotation.py +176 -0
  25. phoenix/db/insertion/span_annotation.py +3 -4
  26. phoenix/db/insertion/trace_annotation.py +3 -4
  27. phoenix/db/insertion/types.py +41 -18
  28. phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
  29. phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
  30. phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
  31. phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
  32. phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
  33. phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
  34. phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
  35. phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
  36. phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
  37. phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
  38. phoenix/db/models.py +364 -56
  39. phoenix/db/pg_config.py +10 -0
  40. phoenix/db/types/trace_retention.py +7 -6
  41. phoenix/experiments/functions.py +69 -19
  42. phoenix/inferences/inferences.py +1 -2
  43. phoenix/server/api/auth.py +9 -0
  44. phoenix/server/api/auth_messages.py +46 -0
  45. phoenix/server/api/context.py +60 -0
  46. phoenix/server/api/dataloaders/__init__.py +36 -0
  47. phoenix/server/api/dataloaders/annotation_summaries.py +60 -8
  48. phoenix/server/api/dataloaders/average_experiment_repeated_run_group_latency.py +50 -0
  49. phoenix/server/api/dataloaders/average_experiment_run_latency.py +17 -24
  50. phoenix/server/api/dataloaders/cache/two_tier_cache.py +1 -2
  51. phoenix/server/api/dataloaders/dataset_dataset_splits.py +52 -0
  52. phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -1
  53. phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
  54. phoenix/server/api/dataloaders/dataset_examples_and_versions_by_experiment_run.py +47 -0
  55. phoenix/server/api/dataloaders/dataset_labels.py +36 -0
  56. phoenix/server/api/dataloaders/document_evaluation_summaries.py +2 -2
  57. phoenix/server/api/dataloaders/document_evaluations.py +6 -9
  58. phoenix/server/api/dataloaders/experiment_annotation_summaries.py +88 -34
  59. phoenix/server/api/dataloaders/experiment_dataset_splits.py +43 -0
  60. phoenix/server/api/dataloaders/experiment_error_rates.py +21 -28
  61. phoenix/server/api/dataloaders/experiment_repeated_run_group_annotation_summaries.py +77 -0
  62. phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +57 -0
  63. phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
  64. phoenix/server/api/dataloaders/latency_ms_quantile.py +40 -8
  65. phoenix/server/api/dataloaders/record_counts.py +37 -10
  66. phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
  67. phoenix/server/api/dataloaders/span_cost_summary_by_experiment_repeated_run_group.py +64 -0
  68. phoenix/server/api/dataloaders/span_cost_summary_by_project.py +28 -14
  69. phoenix/server/api/dataloaders/span_costs.py +3 -9
  70. phoenix/server/api/dataloaders/table_fields.py +2 -2
  71. phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
  72. phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
  73. phoenix/server/api/exceptions.py +5 -1
  74. phoenix/server/api/helpers/playground_clients.py +263 -83
  75. phoenix/server/api/helpers/playground_spans.py +2 -1
  76. phoenix/server/api/helpers/playground_users.py +26 -0
  77. phoenix/server/api/helpers/prompts/conversions/google.py +103 -0
  78. phoenix/server/api/helpers/prompts/models.py +61 -19
  79. phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
  80. phoenix/server/api/input_types/ChatCompletionInput.py +3 -0
  81. phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
  82. phoenix/server/api/input_types/DatasetFilter.py +5 -2
  83. phoenix/server/api/input_types/ExperimentRunSort.py +237 -0
  84. phoenix/server/api/input_types/GenerativeModelInput.py +3 -0
  85. phoenix/server/api/input_types/ProjectSessionSort.py +158 -1
  86. phoenix/server/api/input_types/PromptVersionInput.py +47 -1
  87. phoenix/server/api/input_types/SpanSort.py +3 -2
  88. phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
  89. phoenix/server/api/input_types/UserRoleInput.py +1 -0
  90. phoenix/server/api/mutations/__init__.py +8 -0
  91. phoenix/server/api/mutations/annotation_config_mutations.py +8 -8
  92. phoenix/server/api/mutations/api_key_mutations.py +15 -20
  93. phoenix/server/api/mutations/chat_mutations.py +106 -37
  94. phoenix/server/api/mutations/dataset_label_mutations.py +243 -0
  95. phoenix/server/api/mutations/dataset_mutations.py +21 -16
  96. phoenix/server/api/mutations/dataset_split_mutations.py +351 -0
  97. phoenix/server/api/mutations/experiment_mutations.py +2 -2
  98. phoenix/server/api/mutations/export_events_mutations.py +3 -3
  99. phoenix/server/api/mutations/model_mutations.py +11 -9
  100. phoenix/server/api/mutations/project_mutations.py +4 -4
  101. phoenix/server/api/mutations/project_session_annotations_mutations.py +158 -0
  102. phoenix/server/api/mutations/project_trace_retention_policy_mutations.py +8 -4
  103. phoenix/server/api/mutations/prompt_label_mutations.py +74 -65
  104. phoenix/server/api/mutations/prompt_mutations.py +65 -129
  105. phoenix/server/api/mutations/prompt_version_tag_mutations.py +11 -8
  106. phoenix/server/api/mutations/span_annotations_mutations.py +15 -10
  107. phoenix/server/api/mutations/trace_annotations_mutations.py +13 -8
  108. phoenix/server/api/mutations/trace_mutations.py +3 -3
  109. phoenix/server/api/mutations/user_mutations.py +55 -26
  110. phoenix/server/api/queries.py +501 -617
  111. phoenix/server/api/routers/__init__.py +2 -2
  112. phoenix/server/api/routers/auth.py +141 -87
  113. phoenix/server/api/routers/ldap.py +229 -0
  114. phoenix/server/api/routers/oauth2.py +349 -101
  115. phoenix/server/api/routers/v1/__init__.py +22 -4
  116. phoenix/server/api/routers/v1/annotation_configs.py +19 -30
  117. phoenix/server/api/routers/v1/annotations.py +455 -13
  118. phoenix/server/api/routers/v1/datasets.py +355 -68
  119. phoenix/server/api/routers/v1/documents.py +142 -0
  120. phoenix/server/api/routers/v1/evaluations.py +20 -28
  121. phoenix/server/api/routers/v1/experiment_evaluations.py +16 -6
  122. phoenix/server/api/routers/v1/experiment_runs.py +335 -59
  123. phoenix/server/api/routers/v1/experiments.py +475 -47
  124. phoenix/server/api/routers/v1/projects.py +16 -50
  125. phoenix/server/api/routers/v1/prompts.py +50 -39
  126. phoenix/server/api/routers/v1/sessions.py +108 -0
  127. phoenix/server/api/routers/v1/spans.py +156 -96
  128. phoenix/server/api/routers/v1/traces.py +51 -77
  129. phoenix/server/api/routers/v1/users.py +64 -24
  130. phoenix/server/api/routers/v1/utils.py +3 -7
  131. phoenix/server/api/subscriptions.py +257 -93
  132. phoenix/server/api/types/Annotation.py +90 -23
  133. phoenix/server/api/types/ApiKey.py +13 -17
  134. phoenix/server/api/types/AuthMethod.py +1 -0
  135. phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +1 -0
  136. phoenix/server/api/types/Dataset.py +199 -72
  137. phoenix/server/api/types/DatasetExample.py +88 -18
  138. phoenix/server/api/types/DatasetExperimentAnnotationSummary.py +10 -0
  139. phoenix/server/api/types/DatasetLabel.py +57 -0
  140. phoenix/server/api/types/DatasetSplit.py +98 -0
  141. phoenix/server/api/types/DatasetVersion.py +49 -4
  142. phoenix/server/api/types/DocumentAnnotation.py +212 -0
  143. phoenix/server/api/types/Experiment.py +215 -68
  144. phoenix/server/api/types/ExperimentComparison.py +3 -9
  145. phoenix/server/api/types/ExperimentRepeatedRunGroup.py +155 -0
  146. phoenix/server/api/types/ExperimentRepeatedRunGroupAnnotationSummary.py +9 -0
  147. phoenix/server/api/types/ExperimentRun.py +120 -70
  148. phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
  149. phoenix/server/api/types/GenerativeModel.py +95 -42
  150. phoenix/server/api/types/GenerativeProvider.py +1 -1
  151. phoenix/server/api/types/ModelInterface.py +7 -2
  152. phoenix/server/api/types/PlaygroundModel.py +12 -2
  153. phoenix/server/api/types/Project.py +218 -185
  154. phoenix/server/api/types/ProjectSession.py +146 -29
  155. phoenix/server/api/types/ProjectSessionAnnotation.py +187 -0
  156. phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
  157. phoenix/server/api/types/Prompt.py +119 -39
  158. phoenix/server/api/types/PromptLabel.py +42 -25
  159. phoenix/server/api/types/PromptVersion.py +11 -8
  160. phoenix/server/api/types/PromptVersionTag.py +65 -25
  161. phoenix/server/api/types/Span.py +130 -123
  162. phoenix/server/api/types/SpanAnnotation.py +189 -42
  163. phoenix/server/api/types/SystemApiKey.py +65 -1
  164. phoenix/server/api/types/Trace.py +184 -53
  165. phoenix/server/api/types/TraceAnnotation.py +149 -50
  166. phoenix/server/api/types/User.py +128 -33
  167. phoenix/server/api/types/UserApiKey.py +73 -26
  168. phoenix/server/api/types/node.py +10 -0
  169. phoenix/server/api/types/pagination.py +11 -2
  170. phoenix/server/app.py +154 -36
  171. phoenix/server/authorization.py +5 -4
  172. phoenix/server/bearer_auth.py +13 -5
  173. phoenix/server/cost_tracking/cost_model_lookup.py +42 -14
  174. phoenix/server/cost_tracking/model_cost_manifest.json +1085 -194
  175. phoenix/server/daemons/generative_model_store.py +61 -9
  176. phoenix/server/daemons/span_cost_calculator.py +10 -8
  177. phoenix/server/dml_event.py +13 -0
  178. phoenix/server/email/sender.py +29 -2
  179. phoenix/server/grpc_server.py +9 -9
  180. phoenix/server/jwt_store.py +8 -6
  181. phoenix/server/ldap.py +1449 -0
  182. phoenix/server/main.py +9 -3
  183. phoenix/server/oauth2.py +330 -12
  184. phoenix/server/prometheus.py +43 -6
  185. phoenix/server/rate_limiters.py +4 -9
  186. phoenix/server/retention.py +33 -20
  187. phoenix/server/session_filters.py +49 -0
  188. phoenix/server/static/.vite/manifest.json +51 -53
  189. phoenix/server/static/assets/components-BreFUQQa.js +6702 -0
  190. phoenix/server/static/assets/{index-BPCwGQr8.js → index-CTQoemZv.js} +42 -35
  191. phoenix/server/static/assets/pages-DBE5iYM3.js +9524 -0
  192. phoenix/server/static/assets/vendor-BGzfc4EU.css +1 -0
  193. phoenix/server/static/assets/vendor-DCE4v-Ot.js +920 -0
  194. phoenix/server/static/assets/vendor-codemirror-D5f205eT.js +25 -0
  195. phoenix/server/static/assets/{vendor-recharts-Bw30oz1A.js → vendor-recharts-V9cwpXsm.js} +7 -7
  196. phoenix/server/static/assets/{vendor-shiki-DZajAPeq.js → vendor-shiki-Do--csgv.js} +1 -1
  197. phoenix/server/static/assets/vendor-three-CmB8bl_y.js +3840 -0
  198. phoenix/server/templates/index.html +7 -1
  199. phoenix/server/thread_server.py +1 -2
  200. phoenix/server/utils.py +74 -0
  201. phoenix/session/client.py +55 -1
  202. phoenix/session/data_extractor.py +5 -0
  203. phoenix/session/evaluation.py +8 -4
  204. phoenix/session/session.py +44 -8
  205. phoenix/settings.py +2 -0
  206. phoenix/trace/attributes.py +80 -13
  207. phoenix/trace/dsl/query.py +2 -0
  208. phoenix/trace/projects.py +5 -0
  209. phoenix/utilities/template_formatters.py +1 -1
  210. phoenix/version.py +1 -1
  211. phoenix/server/api/types/Evaluation.py +0 -39
  212. phoenix/server/static/assets/components-D0DWAf0l.js +0 -5650
  213. phoenix/server/static/assets/pages-Creyamao.js +0 -8612
  214. phoenix/server/static/assets/vendor-CU36oj8y.js +0 -905
  215. phoenix/server/static/assets/vendor-CqDb5u4o.css +0 -1
  216. phoenix/server/static/assets/vendor-arizeai-Ctgw0e1G.js +0 -168
  217. phoenix/server/static/assets/vendor-codemirror-Cojjzqb9.js +0 -25
  218. phoenix/server/static/assets/vendor-three-BLWp5bic.js +0 -2998
  219. phoenix/utilities/deprecation.py +0 -31
  220. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/entry_points.txt +0 -0
  221. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/LICENSE +0 -0
--- a/phoenix/server/api/types/Experiment.py
+++ b/phoenix/server/api/types/Experiment.py
@@ -1,41 +1,155 @@
 from datetime import datetime
-from typing import ClassVar, Optional
+from typing import TYPE_CHECKING, Annotated, Optional
 
 import strawberry
 from sqlalchemy import func, select
-from sqlalchemy.orm import joinedload
-from sqlalchemy.sql.functions import coalesce
 from strawberry import UNSET, Private
-from strawberry.relay import Connection, Node, NodeID
+from strawberry.relay import Connection, GlobalID, Node, NodeID
 from strawberry.scalars import JSON
 from strawberry.types import Info
 
 from phoenix.db import models
 from phoenix.server.api.context import Context
+from phoenix.server.api.exceptions import BadRequest
+from phoenix.server.api.input_types.ExperimentRunSort import (
+    ExperimentRunSort,
+    add_order_by_and_page_start_to_query,
+    get_experiment_run_cursor,
+)
 from phoenix.server.api.types.CostBreakdown import CostBreakdown
+from phoenix.server.api.types.DatasetSplit import DatasetSplit
+from phoenix.server.api.types.DatasetVersion import DatasetVersion
 from phoenix.server.api.types.ExperimentAnnotationSummary import ExperimentAnnotationSummary
-from phoenix.server.api.types.ExperimentRun import ExperimentRun, to_gql_experiment_run
+from phoenix.server.api.types.ExperimentRun import ExperimentRun
 from phoenix.server.api.types.pagination import (
     ConnectionArgs,
+    Cursor,
     CursorString,
+    connection_from_cursors_and_nodes,
     connection_from_list,
 )
-from phoenix.server.api.types.Project import Project
 from phoenix.server.api.types.SpanCostDetailSummaryEntry import SpanCostDetailSummaryEntry
 from phoenix.server.api.types.SpanCostSummary import SpanCostSummary
 
+_DEFAULT_EXPERIMENT_RUNS_PAGE_SIZE = 50
+
+if TYPE_CHECKING:
+    from .Project import Project
+
 
 @strawberry.type
 class Experiment(Node):
-    _table: ClassVar[type[models.Base]] = models.Experiment
+    id: NodeID[int]
+    db_record: strawberry.Private[Optional[models.Experiment]] = None
     cached_sequence_number: Private[Optional[int]] = None
-    id_attr: NodeID[int]
-    name: str
-    project_name: Optional[str]
-    description: Optional[str]
-    metadata: JSON
-    created_at: datetime
-    updated_at: datetime
+
+    def __post_init__(self) -> None:
+        if self.db_record and self.id != self.db_record.id:
+            raise ValueError("Experiment ID mismatch")
+
+    @strawberry.field
+    async def name(
+        self,
+        info: Info[Context, None],
+    ) -> str:
+        if self.db_record:
+            val = self.db_record.name
+        else:
+            val = await info.context.data_loaders.experiment_fields.load(
+                (self.id, models.Experiment.name),
+            )
+        return val
+
+    @strawberry.field
+    async def project_name(
+        self,
+        info: Info[Context, None],
+    ) -> Optional[str]:
+        if self.db_record:
+            val = self.db_record.project_name
+        else:
+            val = await info.context.data_loaders.experiment_fields.load(
+                (self.id, models.Experiment.project_name),
+            )
+        return val
+
+    @strawberry.field
+    async def description(
+        self,
+        info: Info[Context, None],
+    ) -> Optional[str]:
+        if self.db_record:
+            val = self.db_record.description
+        else:
+            val = await info.context.data_loaders.experiment_fields.load(
+                (self.id, models.Experiment.description),
+            )
+        return val
+
+    @strawberry.field
+    async def repetitions(
+        self,
+        info: Info[Context, None],
+    ) -> int:
+        if self.db_record:
+            val = self.db_record.repetitions
+        else:
+            val = await info.context.data_loaders.experiment_fields.load(
+                (self.id, models.Experiment.repetitions),
+            )
+        return val
+
+    @strawberry.field
+    async def dataset_version_id(
+        self,
+        info: Info[Context, None],
+    ) -> GlobalID:
+        if self.db_record:
+            version_id = self.db_record.dataset_version_id
+        else:
+            version_id = await info.context.data_loaders.experiment_fields.load(
+                (self.id, models.Experiment.dataset_version_id),
+            )
+        return GlobalID(DatasetVersion.__name__, str(version_id))
+
+    @strawberry.field
+    async def metadata(
+        self,
+        info: Info[Context, None],
+    ) -> JSON:
+        if self.db_record:
+            val = self.db_record.metadata_
+        else:
+            val = await info.context.data_loaders.experiment_fields.load(
+                (self.id, models.Experiment.metadata_),
+            )
+        return val
+
+    @strawberry.field
+    async def created_at(
+        self,
+        info: Info[Context, None],
+    ) -> datetime:
+        if self.db_record:
+            val = self.db_record.created_at
+        else:
+            val = await info.context.data_loaders.experiment_fields.load(
+                (self.id, models.Experiment.created_at),
+            )
+        return val
+
+    @strawberry.field
+    async def updated_at(
+        self,
+        info: Info[Context, None],
+    ) -> datetime:
+        if self.db_record:
+            val = self.db_record.updated_at
+        else:
+            val = await info.context.data_loaders.experiment_fields.load(
+                (self.id, models.Experiment.updated_at),
+            )
+        return val
 
     @strawberry.field(
         description="Sequence number (1-based) of experiments belonging to the same dataset"
@@ -45,9 +159,9 @@ class Experiment(Node):
         info: Info[Context, None],
     ) -> int:
         if self.cached_sequence_number is None:
-            seq_num = await info.context.data_loaders.experiment_sequence_number.load(self.id_attr)
+            seq_num = await info.context.data_loaders.experiment_sequence_number.load(self.id)
             if seq_num is None:
-                raise ValueError(f"invalid experiment: id={self.id_attr}")
+                raise ValueError(f"invalid experiment: id={self.id}")
             self.cached_sequence_number = seq_num
         return self.cached_sequence_number
 
@@ -55,41 +169,68 @@ class Experiment(Node):
     async def runs(
         self,
         info: Info[Context, None],
-        first: Optional[int] = 50,
-        last: Optional[int] = UNSET,
+        first: Optional[int] = _DEFAULT_EXPERIMENT_RUNS_PAGE_SIZE,
         after: Optional[CursorString] = UNSET,
-        before: Optional[CursorString] = UNSET,
+        sort: Optional[ExperimentRunSort] = UNSET,
     ) -> Connection[ExperimentRun]:
-        args = ConnectionArgs(
-            first=first,
-            after=after if isinstance(after, CursorString) else None,
-            last=last,
-            before=before if isinstance(before, CursorString) else None,
+        if first is not None and first <= 0:
+            raise BadRequest("first must be a positive integer if set")
+        page_size = first or _DEFAULT_EXPERIMENT_RUNS_PAGE_SIZE
+        experiment_runs_query = (
+            select(models.ExperimentRun)
+            .where(models.ExperimentRun.experiment_id == self.id)
+            .limit(page_size + 1)
+        )
+
+        after_experiment_run_rowid = None
+        after_sort_column_value = None
+        if after:
+            cursor = Cursor.from_string(after)
+            after_experiment_run_rowid = cursor.rowid
+            if cursor.sort_column is not None:
+                after_sort_column_value = cursor.sort_column.value
+
+        experiment_runs_query = add_order_by_and_page_start_to_query(
+            query=experiment_runs_query,
+            sort=sort,
+            experiment_rowid=self.id,
+            after_experiment_run_rowid=after_experiment_run_rowid,
+            after_sort_column_value=after_sort_column_value,
         )
-        experiment_id = self.id_attr
+
         async with info.context.db() as session:
-            runs = (
-                await session.scalars(
-                    select(models.ExperimentRun)
-                    .where(models.ExperimentRun.experiment_id == experiment_id)
-                    .order_by(models.ExperimentRun.id.desc())
-                    .options(
-                        joinedload(models.ExperimentRun.trace).load_only(models.Trace.trace_id)
-                    )
-                )
-            ).all()
-            return connection_from_list([to_gql_experiment_run(run) for run in runs], args)
+            results = (await session.execute(experiment_runs_query)).all()
+
+        has_next_page = False
+        if len(results) > page_size:
+            results = results[:page_size]
+            has_next_page = True
+
+        cursors_and_nodes = []
+        for result in results:
+            run = result[0]
+            annotation_score = result[1] if len(result) > 1 else None
+            gql_run = ExperimentRun(id=run.id, db_record=run)
+            cursor = get_experiment_run_cursor(
+                run=run, annotation_score=annotation_score, sort=sort
+            )
+            cursors_and_nodes.append((cursor, gql_run))
+
+        return connection_from_cursors_and_nodes(
+            cursors_and_nodes=cursors_and_nodes,
+            has_previous_page=False,  # set to false since we are only doing forward pagination (https://relay.dev/graphql/connections.htm#sec-undefined.PageInfo.Fields) # noqa: E501
+            has_next_page=has_next_page,
+        )
 
     @strawberry.field
     async def run_count(self, info: Info[Context, None]) -> int:
-        experiment_id = self.id_attr
-        return await info.context.data_loaders.experiment_run_counts.load(experiment_id)
+        return await info.context.data_loaders.experiment_run_counts.load(self.id)
 
     @strawberry.field
     async def annotation_summaries(
         self, info: Info[Context, None]
     ) -> list[ExperimentAnnotationSummary]:
-        experiment_id = self.id_attr
+        experiment_id = self.id
         return [
             ExperimentAnnotationSummary(
                 annotation_name=summary.annotation_name,
@@ -106,40 +247,42 @@ class Experiment(Node):
 
     @strawberry.field
     async def error_rate(self, info: Info[Context, None]) -> Optional[float]:
-        return await info.context.data_loaders.experiment_error_rates.load(self.id_attr)
+        return await info.context.data_loaders.experiment_error_rates.load(self.id)
 
     @strawberry.field
     async def average_run_latency_ms(self, info: Info[Context, None]) -> Optional[float]:
-        latency_seconds = await info.context.data_loaders.average_experiment_run_latency.load(
-            self.id_attr
-        )
-        return latency_seconds * 1000 if latency_seconds is not None else None
+        latency_ms = await info.context.data_loaders.average_experiment_run_latency.load(self.id)
+        return latency_ms
 
     @strawberry.field
-    async def project(self, info: Info[Context, None]) -> Optional[Project]:
-        if self.project_name is None:
+    async def project(
+        self, info: Info[Context, None]
+    ) -> Optional[Annotated["Project", strawberry.lazy(".Project")]]:
+        if self.db_record:
+            project_name = self.db_record.project_name
+        else:
+            project_name = await info.context.data_loaders.experiment_fields.load(
+                (self.id, models.Experiment.project_name),
+            )
+
+        if project_name is None:
             return None
 
-        db_project = await info.context.data_loaders.project_by_name.load(self.project_name)
+        db_project = await info.context.data_loaders.project_by_name.load(project_name)
 
         if db_project is None:
             return None
+        from .Project import Project
 
-        return Project(
-            project_rowid=db_project.id,
-            db_project=db_project,
-        )
+        return Project(id=db_project.id, db_record=db_project)
 
     @strawberry.field
     def last_updated_at(self, info: Info[Context, None]) -> Optional[datetime]:
-        return info.context.last_updated_at.get(self._table, self.id_attr)
+        return info.context.last_updated_at.get(models.Experiment, self.id)
 
     @strawberry.field
     async def cost_summary(self, info: Info[Context, None]) -> SpanCostSummary:
-        experiment_id = self.id_attr
-        summary = await info.context.data_loaders.span_cost_summary_by_experiment.load(
-            experiment_id
-        )
+        summary = await info.context.data_loaders.span_cost_summary_by_experiment.load(self.id)
         return SpanCostSummary(
             prompt=CostBreakdown(
                 tokens=summary.prompt.tokens,
@@ -159,21 +302,19 @@
     async def cost_detail_summary_entries(
         self, info: Info[Context, None]
     ) -> list[SpanCostDetailSummaryEntry]:
-        experiment_id = self.id_attr
-
         stmt = (
             select(
                 models.SpanCostDetail.token_type,
                 models.SpanCostDetail.is_prompt,
-                coalesce(func.sum(models.SpanCostDetail.cost), 0).label("cost"),
-                coalesce(func.sum(models.SpanCostDetail.tokens), 0).label("tokens"),
+                func.sum(models.SpanCostDetail.cost).label("cost"),
+                func.sum(models.SpanCostDetail.tokens).label("tokens"),
             )
             .select_from(models.SpanCostDetail)
             .join(models.SpanCost, models.SpanCostDetail.span_cost_id == models.SpanCost.id)
             .join(models.Span, models.SpanCost.span_rowid == models.Span.id)
             .join(models.Trace, models.Span.trace_rowid == models.Trace.id)
             .join(models.ExperimentRun, models.ExperimentRun.trace_id == models.Trace.trace_id)
-            .where(models.ExperimentRun.experiment_id == experiment_id)
+            .where(models.ExperimentRun.experiment_id == self.id)
             .group_by(models.SpanCostDetail.token_type, models.SpanCostDetail.is_prompt)
         )
 
@@ -188,6 +329,17 @@
                 async for token_type, is_prompt, cost, tokens in data
             ]
 
+    @strawberry.field
+    async def dataset_splits(
+        self,
+        info: Info[Context, None],
+    ) -> Connection[DatasetSplit]:
+        """Returns the dataset splits associated with this experiment."""
+        splits = await info.context.data_loaders.experiment_dataset_splits.load(self.id)
+        return connection_from_list(
+            [DatasetSplit(id=split.id, db_record=split) for split in splits], ConnectionArgs()
+        )
+
 
 def to_gql_experiment(
     experiment: models.Experiment,
@@ -197,12 +349,7 @@ def to_gql_experiment(
     Converts an ORM experiment to a GraphQL Experiment.
     """
    return Experiment(
+        id=experiment.id,
+        db_record=experiment,
         cached_sequence_number=sequence_number,
-        id_attr=experiment.id,
-        name=experiment.name,
-        project_name=experiment.project_name,
-        description=experiment.description,
-        metadata=experiment.metadata_,
-        created_at=experiment.created_at,
-        updated_at=experiment.updated_at,
     )
--- a/phoenix/server/api/types/ExperimentComparison.py
+++ b/phoenix/server/api/types/ExperimentComparison.py
@@ -1,18 +1,12 @@
 import strawberry
-from strawberry.relay import GlobalID, Node, NodeID
+from strawberry.relay import Node, NodeID
 
 from phoenix.server.api.types.DatasetExample import DatasetExample
-from phoenix.server.api.types.ExperimentRun import ExperimentRun
-
-
-@strawberry.type
-class RunComparisonItem:
-    experiment_id: GlobalID
-    runs: list[ExperimentRun]
+from phoenix.server.api.types.ExperimentRepeatedRunGroup import ExperimentRepeatedRunGroup
 
 
 @strawberry.type
 class ExperimentComparison(Node):
     id_attr: NodeID[int]
     example: DatasetExample
-    run_comparison_items: list[RunComparisonItem]
+    repeated_run_groups: list[ExperimentRepeatedRunGroup]
--- /dev/null
+++ b/phoenix/server/api/types/ExperimentRepeatedRunGroup.py
@@ -0,0 +1,155 @@
+import re
+from base64 import b64decode
+from typing import Optional
+
+import strawberry
+from sqlalchemy import func, select
+from strawberry.relay import GlobalID, Node
+from strawberry.types import Info
+from typing_extensions import Self, TypeAlias
+
+from phoenix.db import models
+from phoenix.server.api.context import Context
+from phoenix.server.api.types.CostBreakdown import CostBreakdown
+from phoenix.server.api.types.ExperimentRepeatedRunGroupAnnotationSummary import (
+    ExperimentRepeatedRunGroupAnnotationSummary,
+)
+from phoenix.server.api.types.ExperimentRun import ExperimentRun
+from phoenix.server.api.types.SpanCostDetailSummaryEntry import SpanCostDetailSummaryEntry
+from phoenix.server.api.types.SpanCostSummary import SpanCostSummary
+
+ExperimentRowId: TypeAlias = int
+DatasetExampleRowId: TypeAlias = int
+
+
+@strawberry.type
+class ExperimentRepeatedRunGroup(Node):
+    experiment_rowid: strawberry.Private[ExperimentRowId]
+    dataset_example_rowid: strawberry.Private[DatasetExampleRowId]
+    cached_runs: strawberry.Private[Optional[list[ExperimentRun]]] = None
+
+    @strawberry.field
+    async def runs(self, info: Info[Context, None]) -> list[ExperimentRun]:
+        if self.cached_runs is not None:
+            return self.cached_runs
+        runs = await info.context.data_loaders.experiment_runs_by_experiment_and_example.load(
+            (self.experiment_rowid, self.dataset_example_rowid)
+        )
+        return [ExperimentRun(id=run.id, db_record=run) for run in runs]
+
+    @classmethod
+    def resolve_id(
+        cls,
+        root: Self,
+        *,
+        info: Info,
+    ) -> str:
+        return (
+            f"experiment_id={root.experiment_rowid}:dataset_example_id={root.dataset_example_rowid}"
+        )
+
+    @strawberry.field
+    def experiment_id(self) -> strawberry.ID:
+        from phoenix.server.api.types.Experiment import Experiment
+
+        return strawberry.ID(str(GlobalID(Experiment.__name__, str(self.experiment_rowid))))
+
+    @strawberry.field
+    async def average_latency_ms(self, info: Info[Context, None]) -> Optional[float]:
+        return await info.context.data_loaders.average_experiment_repeated_run_group_latency.load(
+            (self.experiment_rowid, self.dataset_example_rowid)
+        )
+
+    @strawberry.field
+    async def cost_summary(self, info: Info[Context, None]) -> SpanCostSummary:
+        experiment_id = self.experiment_rowid
+        example_id = self.dataset_example_rowid
+        summary = (
+            await info.context.data_loaders.span_cost_summary_by_experiment_repeated_run_group.load(
+                (experiment_id, example_id)
+            )
+        )
+        return SpanCostSummary(
+            prompt=CostBreakdown(
+                tokens=summary.prompt.tokens,
+                cost=summary.prompt.cost,
+            ),
+            completion=CostBreakdown(
+                tokens=summary.completion.tokens,
+                cost=summary.completion.cost,
+            ),
+            total=CostBreakdown(
+                tokens=summary.total.tokens,
+                cost=summary.total.cost,
+            ),
+        )
+
+    @strawberry.field
+    async def cost_detail_summary_entries(
+        self, info: Info[Context, None]
+    ) -> list[SpanCostDetailSummaryEntry]:
+        experiment_id = self.experiment_rowid
+        example_id = self.dataset_example_rowid
+        stmt = (
+            select(
+                models.SpanCostDetail.token_type,
+                models.SpanCostDetail.is_prompt,
+                func.sum(models.SpanCostDetail.cost).label("cost"),
+                func.sum(models.SpanCostDetail.tokens).label("tokens"),
+            )
+            .select_from(models.SpanCostDetail)
+            .join(models.SpanCost, models.SpanCostDetail.span_cost_id == models.SpanCost.id)
+            .join(models.Trace, models.SpanCost.trace_rowid == models.Trace.id)
+            .join(models.ExperimentRun, models.ExperimentRun.trace_id == models.Trace.trace_id)
+            .where(models.ExperimentRun.experiment_id == experiment_id)
+            .where(models.ExperimentRun.dataset_example_id == example_id)
+            .group_by(models.SpanCostDetail.token_type, models.SpanCostDetail.is_prompt)
+        )
+
+        async with info.context.db() as session:
+            data = await session.stream(stmt)
+            return [
+                SpanCostDetailSummaryEntry(
+                    token_type=token_type,
+                    is_prompt=is_prompt,
+                    value=CostBreakdown(tokens=tokens, cost=cost),
+                )
+                async for token_type, is_prompt, cost, tokens in data
+            ]
+
+    @strawberry.field
+    async def annotation_summaries(
+        self,
+        info: Info[Context, None],
+    ) -> list[ExperimentRepeatedRunGroupAnnotationSummary]:
+        loader = info.context.data_loaders.experiment_repeated_run_group_annotation_summaries
+        summaries = await loader.load((self.experiment_rowid, self.dataset_example_rowid))
+        return [
+            ExperimentRepeatedRunGroupAnnotationSummary(
+                annotation_name=summary.annotation_name,
+                mean_score=summary.mean_score,
+            )
+            for summary in summaries
+        ]
+
+
+_EXPERIMENT_REPEATED_RUN_GROUP_NODE_ID_PATTERN = re.compile(
+    r"ExperimentRepeatedRunGroup:experiment_id=(\d+):dataset_example_id=(\d+)"
+)
+
+
+def parse_experiment_repeated_run_group_node_id(
+    node_id: str,
+) -> tuple[ExperimentRowId, DatasetExampleRowId]:
+    decoded_node_id = _base64_decode(node_id)
+    match = re.match(_EXPERIMENT_REPEATED_RUN_GROUP_NODE_ID_PATTERN, decoded_node_id)
+    if not match:
+        raise ValueError(f"Invalid node ID format: {node_id}")
+
+    experiment_id = int(match.group(1))
+    dataset_example_id = int(match.group(2))
+    return experiment_id, dataset_example_id
+
+
+def _base64_decode(string: str) -> str:
+    return b64decode(string.encode()).decode()
--- /dev/null
+++ b/phoenix/server/api/types/ExperimentRepeatedRunGroupAnnotationSummary.py
@@ -0,0 +1,9 @@
+from typing import Optional
+
+import strawberry
+
+
+@strawberry.type
+class ExperimentRepeatedRunGroupAnnotationSummary:
+    annotation_name: str
+    mean_score: Optional[float]