arize-phoenix 11.23.1__py3-none-any.whl → 12.28.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (221) hide show
  1. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/METADATA +61 -36
  2. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/RECORD +212 -162
  3. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/WHEEL +1 -1
  4. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/IP_NOTICE +1 -1
  5. phoenix/__generated__/__init__.py +0 -0
  6. phoenix/__generated__/classification_evaluator_configs/__init__.py +20 -0
  7. phoenix/__generated__/classification_evaluator_configs/_document_relevance_classification_evaluator_config.py +17 -0
  8. phoenix/__generated__/classification_evaluator_configs/_hallucination_classification_evaluator_config.py +17 -0
  9. phoenix/__generated__/classification_evaluator_configs/_models.py +18 -0
  10. phoenix/__generated__/classification_evaluator_configs/_tool_selection_classification_evaluator_config.py +17 -0
  11. phoenix/__init__.py +2 -1
  12. phoenix/auth.py +27 -2
  13. phoenix/config.py +1594 -81
  14. phoenix/db/README.md +546 -28
  15. phoenix/db/bulk_inserter.py +119 -116
  16. phoenix/db/engines.py +140 -33
  17. phoenix/db/facilitator.py +22 -1
  18. phoenix/db/helpers.py +818 -65
  19. phoenix/db/iam_auth.py +64 -0
  20. phoenix/db/insertion/dataset.py +133 -1
  21. phoenix/db/insertion/document_annotation.py +9 -6
  22. phoenix/db/insertion/evaluation.py +2 -3
  23. phoenix/db/insertion/helpers.py +2 -2
  24. phoenix/db/insertion/session_annotation.py +176 -0
  25. phoenix/db/insertion/span_annotation.py +3 -4
  26. phoenix/db/insertion/trace_annotation.py +3 -4
  27. phoenix/db/insertion/types.py +41 -18
  28. phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
  29. phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
  30. phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
  31. phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
  32. phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
  33. phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
  34. phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
  35. phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
  36. phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
  37. phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
  38. phoenix/db/models.py +364 -56
  39. phoenix/db/pg_config.py +10 -0
  40. phoenix/db/types/trace_retention.py +7 -6
  41. phoenix/experiments/functions.py +69 -19
  42. phoenix/inferences/inferences.py +1 -2
  43. phoenix/server/api/auth.py +9 -0
  44. phoenix/server/api/auth_messages.py +46 -0
  45. phoenix/server/api/context.py +60 -0
  46. phoenix/server/api/dataloaders/__init__.py +36 -0
  47. phoenix/server/api/dataloaders/annotation_summaries.py +60 -8
  48. phoenix/server/api/dataloaders/average_experiment_repeated_run_group_latency.py +50 -0
  49. phoenix/server/api/dataloaders/average_experiment_run_latency.py +17 -24
  50. phoenix/server/api/dataloaders/cache/two_tier_cache.py +1 -2
  51. phoenix/server/api/dataloaders/dataset_dataset_splits.py +52 -0
  52. phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -1
  53. phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
  54. phoenix/server/api/dataloaders/dataset_examples_and_versions_by_experiment_run.py +47 -0
  55. phoenix/server/api/dataloaders/dataset_labels.py +36 -0
  56. phoenix/server/api/dataloaders/document_evaluation_summaries.py +2 -2
  57. phoenix/server/api/dataloaders/document_evaluations.py +6 -9
  58. phoenix/server/api/dataloaders/experiment_annotation_summaries.py +88 -34
  59. phoenix/server/api/dataloaders/experiment_dataset_splits.py +43 -0
  60. phoenix/server/api/dataloaders/experiment_error_rates.py +21 -28
  61. phoenix/server/api/dataloaders/experiment_repeated_run_group_annotation_summaries.py +77 -0
  62. phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +57 -0
  63. phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
  64. phoenix/server/api/dataloaders/latency_ms_quantile.py +40 -8
  65. phoenix/server/api/dataloaders/record_counts.py +37 -10
  66. phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
  67. phoenix/server/api/dataloaders/span_cost_summary_by_experiment_repeated_run_group.py +64 -0
  68. phoenix/server/api/dataloaders/span_cost_summary_by_project.py +28 -14
  69. phoenix/server/api/dataloaders/span_costs.py +3 -9
  70. phoenix/server/api/dataloaders/table_fields.py +2 -2
  71. phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
  72. phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
  73. phoenix/server/api/exceptions.py +5 -1
  74. phoenix/server/api/helpers/playground_clients.py +263 -83
  75. phoenix/server/api/helpers/playground_spans.py +2 -1
  76. phoenix/server/api/helpers/playground_users.py +26 -0
  77. phoenix/server/api/helpers/prompts/conversions/google.py +103 -0
  78. phoenix/server/api/helpers/prompts/models.py +61 -19
  79. phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
  80. phoenix/server/api/input_types/ChatCompletionInput.py +3 -0
  81. phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
  82. phoenix/server/api/input_types/DatasetFilter.py +5 -2
  83. phoenix/server/api/input_types/ExperimentRunSort.py +237 -0
  84. phoenix/server/api/input_types/GenerativeModelInput.py +3 -0
  85. phoenix/server/api/input_types/ProjectSessionSort.py +158 -1
  86. phoenix/server/api/input_types/PromptVersionInput.py +47 -1
  87. phoenix/server/api/input_types/SpanSort.py +3 -2
  88. phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
  89. phoenix/server/api/input_types/UserRoleInput.py +1 -0
  90. phoenix/server/api/mutations/__init__.py +8 -0
  91. phoenix/server/api/mutations/annotation_config_mutations.py +8 -8
  92. phoenix/server/api/mutations/api_key_mutations.py +15 -20
  93. phoenix/server/api/mutations/chat_mutations.py +106 -37
  94. phoenix/server/api/mutations/dataset_label_mutations.py +243 -0
  95. phoenix/server/api/mutations/dataset_mutations.py +21 -16
  96. phoenix/server/api/mutations/dataset_split_mutations.py +351 -0
  97. phoenix/server/api/mutations/experiment_mutations.py +2 -2
  98. phoenix/server/api/mutations/export_events_mutations.py +3 -3
  99. phoenix/server/api/mutations/model_mutations.py +11 -9
  100. phoenix/server/api/mutations/project_mutations.py +4 -4
  101. phoenix/server/api/mutations/project_session_annotations_mutations.py +158 -0
  102. phoenix/server/api/mutations/project_trace_retention_policy_mutations.py +8 -4
  103. phoenix/server/api/mutations/prompt_label_mutations.py +74 -65
  104. phoenix/server/api/mutations/prompt_mutations.py +65 -129
  105. phoenix/server/api/mutations/prompt_version_tag_mutations.py +11 -8
  106. phoenix/server/api/mutations/span_annotations_mutations.py +15 -10
  107. phoenix/server/api/mutations/trace_annotations_mutations.py +13 -8
  108. phoenix/server/api/mutations/trace_mutations.py +3 -3
  109. phoenix/server/api/mutations/user_mutations.py +55 -26
  110. phoenix/server/api/queries.py +501 -617
  111. phoenix/server/api/routers/__init__.py +2 -2
  112. phoenix/server/api/routers/auth.py +141 -87
  113. phoenix/server/api/routers/ldap.py +229 -0
  114. phoenix/server/api/routers/oauth2.py +349 -101
  115. phoenix/server/api/routers/v1/__init__.py +22 -4
  116. phoenix/server/api/routers/v1/annotation_configs.py +19 -30
  117. phoenix/server/api/routers/v1/annotations.py +455 -13
  118. phoenix/server/api/routers/v1/datasets.py +355 -68
  119. phoenix/server/api/routers/v1/documents.py +142 -0
  120. phoenix/server/api/routers/v1/evaluations.py +20 -28
  121. phoenix/server/api/routers/v1/experiment_evaluations.py +16 -6
  122. phoenix/server/api/routers/v1/experiment_runs.py +335 -59
  123. phoenix/server/api/routers/v1/experiments.py +475 -47
  124. phoenix/server/api/routers/v1/projects.py +16 -50
  125. phoenix/server/api/routers/v1/prompts.py +50 -39
  126. phoenix/server/api/routers/v1/sessions.py +108 -0
  127. phoenix/server/api/routers/v1/spans.py +156 -96
  128. phoenix/server/api/routers/v1/traces.py +51 -77
  129. phoenix/server/api/routers/v1/users.py +64 -24
  130. phoenix/server/api/routers/v1/utils.py +3 -7
  131. phoenix/server/api/subscriptions.py +257 -93
  132. phoenix/server/api/types/Annotation.py +90 -23
  133. phoenix/server/api/types/ApiKey.py +13 -17
  134. phoenix/server/api/types/AuthMethod.py +1 -0
  135. phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +1 -0
  136. phoenix/server/api/types/Dataset.py +199 -72
  137. phoenix/server/api/types/DatasetExample.py +88 -18
  138. phoenix/server/api/types/DatasetExperimentAnnotationSummary.py +10 -0
  139. phoenix/server/api/types/DatasetLabel.py +57 -0
  140. phoenix/server/api/types/DatasetSplit.py +98 -0
  141. phoenix/server/api/types/DatasetVersion.py +49 -4
  142. phoenix/server/api/types/DocumentAnnotation.py +212 -0
  143. phoenix/server/api/types/Experiment.py +215 -68
  144. phoenix/server/api/types/ExperimentComparison.py +3 -9
  145. phoenix/server/api/types/ExperimentRepeatedRunGroup.py +155 -0
  146. phoenix/server/api/types/ExperimentRepeatedRunGroupAnnotationSummary.py +9 -0
  147. phoenix/server/api/types/ExperimentRun.py +120 -70
  148. phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
  149. phoenix/server/api/types/GenerativeModel.py +95 -42
  150. phoenix/server/api/types/GenerativeProvider.py +1 -1
  151. phoenix/server/api/types/ModelInterface.py +7 -2
  152. phoenix/server/api/types/PlaygroundModel.py +12 -2
  153. phoenix/server/api/types/Project.py +218 -185
  154. phoenix/server/api/types/ProjectSession.py +146 -29
  155. phoenix/server/api/types/ProjectSessionAnnotation.py +187 -0
  156. phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
  157. phoenix/server/api/types/Prompt.py +119 -39
  158. phoenix/server/api/types/PromptLabel.py +42 -25
  159. phoenix/server/api/types/PromptVersion.py +11 -8
  160. phoenix/server/api/types/PromptVersionTag.py +65 -25
  161. phoenix/server/api/types/Span.py +130 -123
  162. phoenix/server/api/types/SpanAnnotation.py +189 -42
  163. phoenix/server/api/types/SystemApiKey.py +65 -1
  164. phoenix/server/api/types/Trace.py +184 -53
  165. phoenix/server/api/types/TraceAnnotation.py +149 -50
  166. phoenix/server/api/types/User.py +128 -33
  167. phoenix/server/api/types/UserApiKey.py +73 -26
  168. phoenix/server/api/types/node.py +10 -0
  169. phoenix/server/api/types/pagination.py +11 -2
  170. phoenix/server/app.py +154 -36
  171. phoenix/server/authorization.py +5 -4
  172. phoenix/server/bearer_auth.py +13 -5
  173. phoenix/server/cost_tracking/cost_model_lookup.py +42 -14
  174. phoenix/server/cost_tracking/model_cost_manifest.json +1085 -194
  175. phoenix/server/daemons/generative_model_store.py +61 -9
  176. phoenix/server/daemons/span_cost_calculator.py +10 -8
  177. phoenix/server/dml_event.py +13 -0
  178. phoenix/server/email/sender.py +29 -2
  179. phoenix/server/grpc_server.py +9 -9
  180. phoenix/server/jwt_store.py +8 -6
  181. phoenix/server/ldap.py +1449 -0
  182. phoenix/server/main.py +9 -3
  183. phoenix/server/oauth2.py +330 -12
  184. phoenix/server/prometheus.py +43 -6
  185. phoenix/server/rate_limiters.py +4 -9
  186. phoenix/server/retention.py +33 -20
  187. phoenix/server/session_filters.py +49 -0
  188. phoenix/server/static/.vite/manifest.json +51 -53
  189. phoenix/server/static/assets/components-BreFUQQa.js +6702 -0
  190. phoenix/server/static/assets/{index-BPCwGQr8.js → index-CTQoemZv.js} +42 -35
  191. phoenix/server/static/assets/pages-DBE5iYM3.js +9524 -0
  192. phoenix/server/static/assets/vendor-BGzfc4EU.css +1 -0
  193. phoenix/server/static/assets/vendor-DCE4v-Ot.js +920 -0
  194. phoenix/server/static/assets/vendor-codemirror-D5f205eT.js +25 -0
  195. phoenix/server/static/assets/{vendor-recharts-Bw30oz1A.js → vendor-recharts-V9cwpXsm.js} +7 -7
  196. phoenix/server/static/assets/{vendor-shiki-DZajAPeq.js → vendor-shiki-Do--csgv.js} +1 -1
  197. phoenix/server/static/assets/vendor-three-CmB8bl_y.js +3840 -0
  198. phoenix/server/templates/index.html +7 -1
  199. phoenix/server/thread_server.py +1 -2
  200. phoenix/server/utils.py +74 -0
  201. phoenix/session/client.py +55 -1
  202. phoenix/session/data_extractor.py +5 -0
  203. phoenix/session/evaluation.py +8 -4
  204. phoenix/session/session.py +44 -8
  205. phoenix/settings.py +2 -0
  206. phoenix/trace/attributes.py +80 -13
  207. phoenix/trace/dsl/query.py +2 -0
  208. phoenix/trace/projects.py +5 -0
  209. phoenix/utilities/template_formatters.py +1 -1
  210. phoenix/version.py +1 -1
  211. phoenix/server/api/types/Evaluation.py +0 -39
  212. phoenix/server/static/assets/components-D0DWAf0l.js +0 -5650
  213. phoenix/server/static/assets/pages-Creyamao.js +0 -8612
  214. phoenix/server/static/assets/vendor-CU36oj8y.js +0 -905
  215. phoenix/server/static/assets/vendor-CqDb5u4o.css +0 -1
  216. phoenix/server/static/assets/vendor-arizeai-Ctgw0e1G.js +0 -168
  217. phoenix/server/static/assets/vendor-codemirror-Cojjzqb9.js +0 -25
  218. phoenix/server/static/assets/vendor-three-BLWp5bic.js +0 -2998
  219. phoenix/utilities/deprecation.py +0 -31
  220. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/entry_points.txt +0 -0
  221. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/LICENSE +0 -0
@@ -1,9 +1,9 @@
1
1
  from collections.abc import AsyncIterable
2
2
  from datetime import datetime
3
- from typing import ClassVar, Optional, cast
3
+ from typing import Optional, cast
4
4
 
5
5
  import strawberry
6
- from sqlalchemy import and_, func, or_, select
6
+ from sqlalchemy import Text, and_, func, or_, select
7
7
  from sqlalchemy.sql.functions import count
8
8
  from strawberry import UNSET
9
9
  from strawberry.relay import Connection, GlobalID, Node, NodeID
@@ -15,9 +15,13 @@ from phoenix.server.api.context import Context
15
15
  from phoenix.server.api.exceptions import BadRequest
16
16
  from phoenix.server.api.input_types.DatasetVersionSort import DatasetVersionSort
17
17
  from phoenix.server.api.types.DatasetExample import DatasetExample
18
+ from phoenix.server.api.types.DatasetExperimentAnnotationSummary import (
19
+ DatasetExperimentAnnotationSummary,
20
+ )
21
+ from phoenix.server.api.types.DatasetLabel import DatasetLabel
22
+ from phoenix.server.api.types.DatasetSplit import DatasetSplit
18
23
  from phoenix.server.api.types.DatasetVersion import DatasetVersion
19
24
  from phoenix.server.api.types.Experiment import Experiment, to_gql_experiment
20
- from phoenix.server.api.types.ExperimentAnnotationSummary import ExperimentAnnotationSummary
21
25
  from phoenix.server.api.types.node import from_global_id_with_expected_type
22
26
  from phoenix.server.api.types.pagination import (
23
27
  ConnectionArgs,
@@ -29,13 +33,77 @@ from phoenix.server.api.types.SortDir import SortDir
29
33
 
30
34
  @strawberry.type
31
35
  class Dataset(Node):
32
- _table: ClassVar[type[models.Base]] = models.Experiment
33
- id_attr: NodeID[int]
34
- name: str
35
- description: Optional[str]
36
- metadata: JSON
37
- created_at: datetime
38
- updated_at: datetime
36
+ id: NodeID[int]
37
+ db_record: strawberry.Private[Optional[models.Dataset]] = None
38
+
39
+ def __post_init__(self) -> None:
40
+ if self.db_record and self.id != self.db_record.id:
41
+ raise ValueError("Dataset ID mismatch")
42
+
43
+ @strawberry.field
44
+ async def name(
45
+ self,
46
+ info: Info[Context, None],
47
+ ) -> str:
48
+ if self.db_record:
49
+ val = self.db_record.name
50
+ else:
51
+ val = await info.context.data_loaders.dataset_fields.load(
52
+ (self.id, models.Dataset.name),
53
+ )
54
+ return val
55
+
56
+ @strawberry.field
57
+ async def description(
58
+ self,
59
+ info: Info[Context, None],
60
+ ) -> Optional[str]:
61
+ if self.db_record:
62
+ val = self.db_record.description
63
+ else:
64
+ val = await info.context.data_loaders.dataset_fields.load(
65
+ (self.id, models.Dataset.description),
66
+ )
67
+ return val
68
+
69
+ @strawberry.field
70
+ async def metadata(
71
+ self,
72
+ info: Info[Context, None],
73
+ ) -> JSON:
74
+ if self.db_record:
75
+ val = self.db_record.metadata_
76
+ else:
77
+ val = await info.context.data_loaders.dataset_fields.load(
78
+ (self.id, models.Dataset.metadata_),
79
+ )
80
+ return val
81
+
82
+ @strawberry.field
83
+ async def created_at(
84
+ self,
85
+ info: Info[Context, None],
86
+ ) -> datetime:
87
+ if self.db_record:
88
+ val = self.db_record.created_at
89
+ else:
90
+ val = await info.context.data_loaders.dataset_fields.load(
91
+ (self.id, models.Dataset.created_at),
92
+ )
93
+ return val
94
+
95
+ @strawberry.field
96
+ async def updated_at(
97
+ self,
98
+ info: Info[Context, None],
99
+ ) -> datetime:
100
+ if self.db_record:
101
+ val = self.db_record.updated_at
102
+ else:
103
+ val = await info.context.data_loaders.dataset_fields.load(
104
+ (self.id, models.Dataset.updated_at),
105
+ )
106
+ return val
39
107
 
40
108
  @strawberry.field
41
109
  async def versions(
@@ -54,7 +122,7 @@ class Dataset(Node):
54
122
  before=before if isinstance(before, CursorString) else None,
55
123
  )
56
124
  async with info.context.db() as session:
57
- stmt = select(models.DatasetVersion).filter_by(dataset_id=self.id_attr)
125
+ stmt = select(models.DatasetVersion).filter_by(dataset_id=self.id)
58
126
  if sort:
59
127
  # For now assume the the column names match 1:1 with the enum values
60
128
  sort_col = getattr(models.DatasetVersion, sort.col.value)
@@ -65,15 +133,7 @@ class Dataset(Node):
65
133
  else:
66
134
  stmt = stmt.order_by(models.DatasetVersion.created_at.desc())
67
135
  versions = await session.scalars(stmt)
68
- data = [
69
- DatasetVersion(
70
- id_attr=version.id,
71
- description=version.description,
72
- metadata=version.metadata_,
73
- created_at=version.created_at,
74
- )
75
- for version in versions
76
- ]
136
+ data = [DatasetVersion(id=version.id, db_record=version) for version in versions]
77
137
  return connection_from_list(data=data, args=args)
78
138
 
79
139
  @strawberry.field(
@@ -84,8 +144,9 @@ class Dataset(Node):
84
144
  self,
85
145
  info: Info[Context, None],
86
146
  dataset_version_id: Optional[GlobalID] = UNSET,
147
+ split_ids: Optional[list[GlobalID]] = UNSET,
87
148
  ) -> int:
88
- dataset_id = self.id_attr
149
+ dataset_id = self.id
89
150
  version_id = (
90
151
  from_global_id_with_expected_type(
91
152
  global_id=dataset_version_id,
@@ -94,6 +155,20 @@ class Dataset(Node):
94
155
  if dataset_version_id
95
156
  else None
96
157
  )
158
+
159
+ # Parse split IDs if provided
160
+ split_rowids: Optional[list[int]] = None
161
+ if split_ids:
162
+ split_rowids = []
163
+ for split_id in split_ids:
164
+ try:
165
+ split_rowid = from_global_id_with_expected_type(
166
+ global_id=split_id, expected_type_name=models.DatasetSplit.__name__
167
+ )
168
+ split_rowids.append(split_rowid)
169
+ except Exception:
170
+ raise BadRequest(f"Invalid split ID: {split_id}")
171
+
97
172
  revision_ids = (
98
173
  select(func.max(models.DatasetExampleRevision.id))
99
174
  .join(models.DatasetExample)
@@ -110,11 +185,36 @@ class Dataset(Node):
110
185
  revision_ids = revision_ids.where(
111
186
  models.DatasetExampleRevision.dataset_version_id <= version_id_subquery
112
187
  )
113
- stmt = (
114
- select(count(models.DatasetExampleRevision.id))
115
- .where(models.DatasetExampleRevision.id.in_(revision_ids))
116
- .where(models.DatasetExampleRevision.revision_kind != "DELETE")
117
- )
188
+
189
+ # Build the count query
190
+ if split_rowids:
191
+ # When filtering by splits, count distinct examples that belong to those splits
192
+ stmt = (
193
+ select(count(models.DatasetExample.id.distinct()))
194
+ .join(
195
+ models.DatasetExampleRevision,
196
+ onclause=(
197
+ models.DatasetExample.id == models.DatasetExampleRevision.dataset_example_id
198
+ ),
199
+ )
200
+ .join(
201
+ models.DatasetSplitDatasetExample,
202
+ onclause=(
203
+ models.DatasetExample.id
204
+ == models.DatasetSplitDatasetExample.dataset_example_id
205
+ ),
206
+ )
207
+ .where(models.DatasetExampleRevision.id.in_(revision_ids))
208
+ .where(models.DatasetExampleRevision.revision_kind != "DELETE")
209
+ .where(models.DatasetSplitDatasetExample.dataset_split_id.in_(split_rowids))
210
+ )
211
+ else:
212
+ stmt = (
213
+ select(count(models.DatasetExampleRevision.id))
214
+ .where(models.DatasetExampleRevision.id.in_(revision_ids))
215
+ .where(models.DatasetExampleRevision.revision_kind != "DELETE")
216
+ )
217
+
118
218
  async with info.context.db() as session:
119
219
  return (await session.scalar(stmt)) or 0
120
220
 
@@ -123,10 +223,12 @@ class Dataset(Node):
123
223
  self,
124
224
  info: Info[Context, None],
125
225
  dataset_version_id: Optional[GlobalID] = UNSET,
226
+ split_ids: Optional[list[GlobalID]] = UNSET,
126
227
  first: Optional[int] = 50,
127
228
  last: Optional[int] = UNSET,
128
229
  after: Optional[CursorString] = UNSET,
129
230
  before: Optional[CursorString] = UNSET,
231
+ filter: Optional[str] = UNSET,
130
232
  ) -> Connection[DatasetExample]:
131
233
  args = ConnectionArgs(
132
234
  first=first,
@@ -134,7 +236,7 @@ class Dataset(Node):
134
236
  last=last,
135
237
  before=before if isinstance(before, CursorString) else None,
136
238
  )
137
- dataset_id = self.id_attr
239
+ dataset_id = self.id
138
240
  version_id = (
139
241
  from_global_id_with_expected_type(
140
242
  global_id=dataset_version_id, expected_type_name=DatasetVersion.__name__
@@ -142,6 +244,20 @@ class Dataset(Node):
142
244
  if dataset_version_id
143
245
  else None
144
246
  )
247
+
248
+ # Parse split IDs if provided
249
+ split_rowids: Optional[list[int]] = None
250
+ if split_ids:
251
+ split_rowids = []
252
+ for split_id in split_ids:
253
+ try:
254
+ split_rowid = from_global_id_with_expected_type(
255
+ global_id=split_id, expected_type_name=models.DatasetSplit.__name__
256
+ )
257
+ split_rowids.append(split_rowid)
258
+ except Exception:
259
+ raise BadRequest(f"Invalid split ID: {split_id}")
260
+
145
261
  revision_ids = (
146
262
  select(func.max(models.DatasetExampleRevision.id))
147
263
  .join(models.DatasetExample)
@@ -171,19 +287,51 @@ class Dataset(Node):
171
287
  models.DatasetExampleRevision.revision_kind != "DELETE",
172
288
  )
173
289
  )
174
- .order_by(models.DatasetExampleRevision.dataset_example_id.desc())
290
+ .order_by(models.DatasetExample.id.desc())
175
291
  )
292
+
293
+ # Filter by split IDs if provided
294
+ if split_rowids:
295
+ query = (
296
+ query.join(
297
+ models.DatasetSplitDatasetExample,
298
+ onclause=(
299
+ models.DatasetExample.id
300
+ == models.DatasetSplitDatasetExample.dataset_example_id
301
+ ),
302
+ )
303
+ .where(models.DatasetSplitDatasetExample.dataset_split_id.in_(split_rowids))
304
+ .distinct()
305
+ )
306
+ # Apply filter if provided - search through JSON fields (input, output, metadata)
307
+ if filter is not UNSET and filter:
308
+ # Create a filter that searches for the filter string in JSON fields
309
+ # Using PostgreSQL's JSON operators for case-insensitive text search
310
+ filter_condition = or_(
311
+ func.cast(models.DatasetExampleRevision.input, Text).ilike(f"%{filter}%"),
312
+ func.cast(models.DatasetExampleRevision.output, Text).ilike(f"%{filter}%"),
313
+ func.cast(models.DatasetExampleRevision.metadata_, Text).ilike(f"%{filter}%"),
314
+ )
315
+ query = query.where(filter_condition)
316
+
176
317
  async with info.context.db() as session:
177
318
  dataset_examples = [
178
319
  DatasetExample(
179
- id_attr=example.id,
320
+ id=example.id,
321
+ db_record=example,
180
322
  version_id=version_id,
181
- created_at=example.created_at,
182
323
  )
183
324
  async for example in await session.stream_scalars(query)
184
325
  ]
185
326
  return connection_from_list(data=dataset_examples, args=args)
186
327
 
328
+ @strawberry.field
329
+ async def splits(self, info: Info[Context, None]) -> list[DatasetSplit]:
330
+ return [
331
+ DatasetSplit(id=split.id, db_record=split)
332
+ for split in await info.context.data_loaders.dataset_dataset_splits.load(self.id)
333
+ ]
334
+
187
335
  @strawberry.field(
188
336
  description="Number of experiments for a specific version if version is specified, "
189
337
  "or for all versions if version is not specified."
@@ -193,9 +341,7 @@ class Dataset(Node):
193
341
  info: Info[Context, None],
194
342
  dataset_version_id: Optional[GlobalID] = UNSET,
195
343
  ) -> int:
196
- stmt = select(count(models.Experiment.id)).where(
197
- models.Experiment.dataset_id == self.id_attr
198
- )
344
+ stmt = select(count(models.Experiment.id)).where(models.Experiment.dataset_id == self.id)
199
345
  version_id = (
200
346
  from_global_id_with_expected_type(
201
347
  global_id=dataset_version_id,
@@ -228,7 +374,7 @@ class Dataset(Node):
228
374
  last=last,
229
375
  before=before if isinstance(before, CursorString) else None,
230
376
  )
231
- dataset_id = self.id_attr
377
+ dataset_id = self.id
232
378
  row_number = func.row_number().over(order_by=models.Experiment.id).label("row_number")
233
379
  query = (
234
380
  select(models.Experiment, row_number)
@@ -270,17 +416,15 @@ class Dataset(Node):
270
416
  @strawberry.field
271
417
  async def experiment_annotation_summaries(
272
418
  self, info: Info[Context, None]
273
- ) -> list[ExperimentAnnotationSummary]:
274
- dataset_id = self.id_attr
419
+ ) -> list[DatasetExperimentAnnotationSummary]:
420
+ dataset_id = self.id
275
421
  query = (
276
422
  select(
277
- models.ExperimentRunAnnotation.name,
278
- func.min(models.ExperimentRunAnnotation.score),
279
- func.max(models.ExperimentRunAnnotation.score),
280
- func.avg(models.ExperimentRunAnnotation.score),
281
- func.count(),
282
- func.count(models.ExperimentRunAnnotation.error),
423
+ models.ExperimentRunAnnotation.name.label("annotation_name"),
424
+ func.min(models.ExperimentRunAnnotation.score).label("min_score"),
425
+ func.max(models.ExperimentRunAnnotation.score).label("max_score"),
283
426
  )
427
+ .select_from(models.ExperimentRunAnnotation)
284
428
  .join(
285
429
  models.ExperimentRun,
286
430
  models.ExperimentRunAnnotation.experiment_run_id == models.ExperimentRun.id,
@@ -295,38 +439,21 @@ class Dataset(Node):
295
439
  )
296
440
  async with info.context.db() as session:
297
441
  return [
298
- ExperimentAnnotationSummary(
299
- annotation_name=annotation_name,
300
- min_score=min_score,
301
- max_score=max_score,
302
- mean_score=mean_score,
303
- count=count_,
304
- error_count=error_count,
442
+ DatasetExperimentAnnotationSummary(
443
+ annotation_name=scores_tuple.annotation_name,
444
+ min_score=scores_tuple.min_score,
445
+ max_score=scores_tuple.max_score,
305
446
  )
306
- async for (
307
- annotation_name,
308
- min_score,
309
- max_score,
310
- mean_score,
311
- count_,
312
- error_count,
313
- ) in await session.stream(query)
447
+ async for scores_tuple in await session.stream(query)
314
448
  ]
315
449
 
316
450
  @strawberry.field
317
- def last_updated_at(self, info: Info[Context, None]) -> Optional[datetime]:
318
- return info.context.last_updated_at.get(self._table, self.id_attr)
319
-
451
+ async def labels(self, info: Info[Context, None]) -> list[DatasetLabel]:
452
+ return [
453
+ DatasetLabel(id=label.id, db_record=label)
454
+ for label in await info.context.data_loaders.dataset_labels.load(self.id)
455
+ ]
320
456
 
321
- def to_gql_dataset(dataset: models.Dataset) -> Dataset:
322
- """
323
- Converts an ORM dataset to a GraphQL dataset.
324
- """
325
- return Dataset(
326
- id_attr=dataset.id,
327
- name=dataset.name,
328
- description=dataset.description,
329
- metadata=dataset.metadata_,
330
- created_at=dataset.created_at,
331
- updated_at=dataset.updated_at,
332
- )
457
+ @strawberry.field
458
+ def last_updated_at(self, info: Info[Context, None]) -> Optional[datetime]:
459
+ return info.context.last_updated_at.get(models.Dataset, self.id)
@@ -1,40 +1,59 @@
1
1
  from datetime import datetime
2
- from typing import Optional
2
+ from typing import TYPE_CHECKING, Annotated, Optional
3
3
 
4
4
  import strawberry
5
5
  from sqlalchemy import select
6
- from sqlalchemy.orm import joinedload
7
6
  from strawberry import UNSET
8
7
  from strawberry.relay.types import Connection, GlobalID, Node, NodeID
9
8
  from strawberry.types import Info
10
9
 
11
10
  from phoenix.db import models
12
11
  from phoenix.server.api.context import Context
12
+ from phoenix.server.api.exceptions import BadRequest
13
13
  from phoenix.server.api.types.DatasetExampleRevision import DatasetExampleRevision
14
+ from phoenix.server.api.types.DatasetSplit import DatasetSplit
14
15
  from phoenix.server.api.types.DatasetVersion import DatasetVersion
15
- from phoenix.server.api.types.ExperimentRun import ExperimentRun, to_gql_experiment_run
16
+ from phoenix.server.api.types.ExperimentRepeatedRunGroup import (
17
+ ExperimentRepeatedRunGroup,
18
+ )
19
+ from phoenix.server.api.types.ExperimentRun import ExperimentRun
16
20
  from phoenix.server.api.types.node import from_global_id_with_expected_type
17
21
  from phoenix.server.api.types.pagination import (
18
22
  ConnectionArgs,
19
23
  CursorString,
20
24
  connection_from_list,
21
25
  )
22
- from phoenix.server.api.types.Span import Span
26
+
27
+ if TYPE_CHECKING:
28
+ from .Span import Span
23
29
 
24
30
 
25
31
  @strawberry.type
26
32
  class DatasetExample(Node):
27
- id_attr: NodeID[int]
28
- created_at: datetime
33
+ id: NodeID[int]
34
+ db_record: strawberry.Private[Optional[models.DatasetExample]] = None
29
35
  version_id: strawberry.Private[Optional[int]] = None
30
36
 
37
+ def __post_init__(self) -> None:
38
+ if self.db_record and self.id != self.db_record.id:
39
+ raise ValueError("DatasetExample ID mismatch")
40
+
41
+ @strawberry.field
42
+ async def created_at(self, info: Info[Context, None]) -> datetime:
43
+ if self.db_record:
44
+ val = self.db_record.created_at
45
+ else:
46
+ val = await info.context.data_loaders.dataset_example_fields.load(
47
+ (self.id, models.DatasetExample.created_at),
48
+ )
49
+ return val
50
+
31
51
  @strawberry.field
32
52
  async def revision(
33
53
  self,
34
54
  info: Info[Context, None],
35
55
  dataset_version_id: Optional[GlobalID] = UNSET,
36
56
  ) -> DatasetExampleRevision:
37
- example_id = self.id_attr
38
57
  version_id: Optional[int] = None
39
58
  if dataset_version_id:
40
59
  version_id = from_global_id_with_expected_type(
@@ -42,18 +61,18 @@ class DatasetExample(Node):
42
61
  )
43
62
  elif self.version_id is not None:
44
63
  version_id = self.version_id
45
- return await info.context.data_loaders.dataset_example_revisions.load(
46
- (example_id, version_id)
47
- )
64
+ return await info.context.data_loaders.dataset_example_revisions.load((self.id, version_id))
48
65
 
49
66
  @strawberry.field
50
67
  async def span(
51
68
  self,
52
69
  info: Info[Context, None],
53
- ) -> Optional[Span]:
70
+ ) -> Optional[Annotated["Span", strawberry.lazy(".Span")]]:
71
+ from .Span import Span
72
+
54
73
  return (
55
- Span(span_rowid=span.id, db_span=span)
56
- if (span := await info.context.data_loaders.dataset_example_spans.load(self.id_attr))
74
+ Span(id=span.id, db_record=span)
75
+ if (span := await info.context.data_loaders.dataset_example_spans.load(self.id))
57
76
  else None
58
77
  )
59
78
 
@@ -65,6 +84,7 @@ class DatasetExample(Node):
65
84
  last: Optional[int] = UNSET,
66
85
  after: Optional[CursorString] = UNSET,
67
86
  before: Optional[CursorString] = UNSET,
87
+ experiment_ids: Optional[list[GlobalID]] = UNSET,
68
88
  ) -> Connection[ExperimentRun]:
69
89
  args = ConnectionArgs(
70
90
  first=first,
@@ -72,14 +92,64 @@ class DatasetExample(Node):
72
92
  last=last,
73
93
  before=before if isinstance(before, CursorString) else None,
74
94
  )
75
- example_id = self.id_attr
76
95
  query = (
77
96
  select(models.ExperimentRun)
78
- .options(joinedload(models.ExperimentRun.trace).load_only(models.Trace.trace_id))
79
97
  .join(models.Experiment, models.Experiment.id == models.ExperimentRun.experiment_id)
80
- .where(models.ExperimentRun.dataset_example_id == example_id)
81
- .order_by(models.Experiment.id.desc())
98
+ .where(models.ExperimentRun.dataset_example_id == self.id)
99
+ .order_by(
100
+ models.ExperimentRun.experiment_id.asc(),
101
+ models.ExperimentRun.repetition_number.asc(),
102
+ )
82
103
  )
104
+ if experiment_ids:
105
+ experiment_db_ids = [
106
+ from_global_id_with_expected_type(
107
+ global_id=experiment_id,
108
+ expected_type_name=models.Experiment.__name__,
109
+ )
110
+ for experiment_id in experiment_ids or []
111
+ ]
112
+ query = query.where(models.ExperimentRun.experiment_id.in_(experiment_db_ids))
83
113
  async with info.context.db() as session:
84
114
  runs = (await session.scalars(query)).all()
85
- return connection_from_list([to_gql_experiment_run(run) for run in runs], args)
115
+ return connection_from_list([ExperimentRun(id=run.id, db_record=run) for run in runs], args)
116
+
117
@strawberry.field
async def experiment_repeated_run_groups(
    self,
    info: Info[Context, None],
    experiment_ids: list[GlobalID],
) -> list[ExperimentRepeatedRunGroup]:
    """Load this example's repeated-run groups for the given experiments.

    Each global ID is decoded as an ``Experiment`` ID, the groups are loaded
    through the ``experiment_repeated_run_groups`` data loader keyed by
    ``(experiment row id, self.id)``, and every loaded group is wrapped in an
    ``ExperimentRepeatedRunGroup`` whose ``cached_runs`` are built from the
    group's runs.

    Args:
        info: Resolver info giving access to the request's data loaders.
        experiment_ids: Global IDs expected to refer to ``Experiment`` nodes.

    Raises:
        BadRequest: If an ID cannot be decoded as an ``Experiment`` global ID.
    """
    experiment_rowids = []
    for experiment_id in experiment_ids:
        try:
            experiment_rowid = from_global_id_with_expected_type(
                global_id=experiment_id,
                expected_type_name=models.Experiment.__name__,
            )
        except Exception as error:
            # Chain explicitly so the decoding failure is recorded as the cause
            # of the client-facing error rather than as incidental context.
            raise BadRequest(f"Invalid experiment ID: {experiment_id}") from error
        experiment_rowids.append(experiment_rowid)
    repeated_run_groups = await info.context.data_loaders.experiment_repeated_run_groups.load_many(
        [(experiment_rowid, self.id) for experiment_rowid in experiment_rowids]
    )
    return [
        ExperimentRepeatedRunGroup(
            experiment_rowid=group.experiment_rowid,
            dataset_example_rowid=group.dataset_example_rowid,
            cached_runs=[ExperimentRun(id=run.id, db_record=run) for run in group.runs],
        )
        for group in repeated_run_groups
    ]
146
+
147
@strawberry.field
async def dataset_splits(
    self,
    info: Info[Context, None],
) -> list[DatasetSplit]:
    """Return the splits the ``dataset_example_splits`` loader yields for this example.

    Each loaded row is wrapped in a ``DatasetSplit`` node that carries the row
    as its ``db_record``.
    """
    split_records = await info.context.data_loaders.dataset_example_splits.load(self.id)
    return [DatasetSplit(id=record.id, db_record=record) for record in split_records]
@@ -0,0 +1,10 @@
1
+ from typing import Optional
2
+
3
+ import strawberry
4
+
5
+
6
# Plain GraphQL object type with no resolvers of its own: all three values are
# supplied by whichever code constructs it.
# NOTE(review): presumably the score range observed for one annotation name
# across a dataset's experiments; confirm against the resolver that builds it.
@strawberry.type
class DatasetExperimentAnnotationSummary:
    # Name of the annotation being summarised.
    annotation_name: str
    # Lower bound of the summarised scores; may be None.
    min_score: Optional[float]
    # Upper bound of the summarised scores; may be None.
    max_score: Optional[float]
@@ -0,0 +1,57 @@
1
+ from typing import Optional
2
+
3
+ import strawberry
4
+ from strawberry.relay import Node, NodeID
5
+ from strawberry.types import Info
6
+
7
+ from phoenix.db import models
8
+ from phoenix.server.api.context import Context
9
+
10
+
11
@strawberry.type
class DatasetLabel(Node):
    """Relay node exposing the ``name``, ``description`` and ``color`` of a dataset label.

    Each resolver reads from ``db_record`` when one was supplied at
    construction; otherwise it requests that single column from the
    ``dataset_label_fields`` data loader, keyed by ``(self.id, column)``.
    """

    id: NodeID[int]
    # Optional preloaded ORM row; private, so it is not part of the GraphQL schema.
    db_record: strawberry.Private[Optional[models.DatasetLabel]] = None

    def __post_init__(self) -> None:
        """Raise ``ValueError`` if an attached ``db_record`` has a different id."""
        record = self.db_record
        if not record:
            return
        if self.id != record.id:
            raise ValueError("DatasetLabel ID mismatch")

    @strawberry.field
    async def name(
        self,
        info: Info[Context, None],
    ) -> str:
        """Resolve the label's ``name`` column."""
        if not self.db_record:
            return await info.context.data_loaders.dataset_label_fields.load(
                (self.id, models.DatasetLabel.name),
            )
        return self.db_record.name

    @strawberry.field
    async def description(
        self,
        info: Info[Context, None],
    ) -> Optional[str]:
        """Resolve the label's ``description`` column, which may be ``None``."""
        if not self.db_record:
            return await info.context.data_loaders.dataset_label_fields.load(
                (self.id, models.DatasetLabel.description),
            )
        return self.db_record.description

    @strawberry.field
    async def color(
        self,
        info: Info[Context, None],
    ) -> str:
        """Resolve the label's ``color`` column."""
        if not self.db_record:
            return await info.context.data_loaders.dataset_label_fields.load(
                (self.id, models.DatasetLabel.color),
            )
        return self.db_record.color