arize-phoenix 11.23.1__py3-none-any.whl → 12.28.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (221)
  1. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/METADATA +61 -36
  2. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/RECORD +212 -162
  3. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/WHEEL +1 -1
  4. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/IP_NOTICE +1 -1
  5. phoenix/__generated__/__init__.py +0 -0
  6. phoenix/__generated__/classification_evaluator_configs/__init__.py +20 -0
  7. phoenix/__generated__/classification_evaluator_configs/_document_relevance_classification_evaluator_config.py +17 -0
  8. phoenix/__generated__/classification_evaluator_configs/_hallucination_classification_evaluator_config.py +17 -0
  9. phoenix/__generated__/classification_evaluator_configs/_models.py +18 -0
  10. phoenix/__generated__/classification_evaluator_configs/_tool_selection_classification_evaluator_config.py +17 -0
  11. phoenix/__init__.py +2 -1
  12. phoenix/auth.py +27 -2
  13. phoenix/config.py +1594 -81
  14. phoenix/db/README.md +546 -28
  15. phoenix/db/bulk_inserter.py +119 -116
  16. phoenix/db/engines.py +140 -33
  17. phoenix/db/facilitator.py +22 -1
  18. phoenix/db/helpers.py +818 -65
  19. phoenix/db/iam_auth.py +64 -0
  20. phoenix/db/insertion/dataset.py +133 -1
  21. phoenix/db/insertion/document_annotation.py +9 -6
  22. phoenix/db/insertion/evaluation.py +2 -3
  23. phoenix/db/insertion/helpers.py +2 -2
  24. phoenix/db/insertion/session_annotation.py +176 -0
  25. phoenix/db/insertion/span_annotation.py +3 -4
  26. phoenix/db/insertion/trace_annotation.py +3 -4
  27. phoenix/db/insertion/types.py +41 -18
  28. phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
  29. phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
  30. phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
  31. phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
  32. phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
  33. phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
  34. phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
  35. phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
  36. phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
  37. phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
  38. phoenix/db/models.py +364 -56
  39. phoenix/db/pg_config.py +10 -0
  40. phoenix/db/types/trace_retention.py +7 -6
  41. phoenix/experiments/functions.py +69 -19
  42. phoenix/inferences/inferences.py +1 -2
  43. phoenix/server/api/auth.py +9 -0
  44. phoenix/server/api/auth_messages.py +46 -0
  45. phoenix/server/api/context.py +60 -0
  46. phoenix/server/api/dataloaders/__init__.py +36 -0
  47. phoenix/server/api/dataloaders/annotation_summaries.py +60 -8
  48. phoenix/server/api/dataloaders/average_experiment_repeated_run_group_latency.py +50 -0
  49. phoenix/server/api/dataloaders/average_experiment_run_latency.py +17 -24
  50. phoenix/server/api/dataloaders/cache/two_tier_cache.py +1 -2
  51. phoenix/server/api/dataloaders/dataset_dataset_splits.py +52 -0
  52. phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -1
  53. phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
  54. phoenix/server/api/dataloaders/dataset_examples_and_versions_by_experiment_run.py +47 -0
  55. phoenix/server/api/dataloaders/dataset_labels.py +36 -0
  56. phoenix/server/api/dataloaders/document_evaluation_summaries.py +2 -2
  57. phoenix/server/api/dataloaders/document_evaluations.py +6 -9
  58. phoenix/server/api/dataloaders/experiment_annotation_summaries.py +88 -34
  59. phoenix/server/api/dataloaders/experiment_dataset_splits.py +43 -0
  60. phoenix/server/api/dataloaders/experiment_error_rates.py +21 -28
  61. phoenix/server/api/dataloaders/experiment_repeated_run_group_annotation_summaries.py +77 -0
  62. phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +57 -0
  63. phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
  64. phoenix/server/api/dataloaders/latency_ms_quantile.py +40 -8
  65. phoenix/server/api/dataloaders/record_counts.py +37 -10
  66. phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
  67. phoenix/server/api/dataloaders/span_cost_summary_by_experiment_repeated_run_group.py +64 -0
  68. phoenix/server/api/dataloaders/span_cost_summary_by_project.py +28 -14
  69. phoenix/server/api/dataloaders/span_costs.py +3 -9
  70. phoenix/server/api/dataloaders/table_fields.py +2 -2
  71. phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
  72. phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
  73. phoenix/server/api/exceptions.py +5 -1
  74. phoenix/server/api/helpers/playground_clients.py +263 -83
  75. phoenix/server/api/helpers/playground_spans.py +2 -1
  76. phoenix/server/api/helpers/playground_users.py +26 -0
  77. phoenix/server/api/helpers/prompts/conversions/google.py +103 -0
  78. phoenix/server/api/helpers/prompts/models.py +61 -19
  79. phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
  80. phoenix/server/api/input_types/ChatCompletionInput.py +3 -0
  81. phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
  82. phoenix/server/api/input_types/DatasetFilter.py +5 -2
  83. phoenix/server/api/input_types/ExperimentRunSort.py +237 -0
  84. phoenix/server/api/input_types/GenerativeModelInput.py +3 -0
  85. phoenix/server/api/input_types/ProjectSessionSort.py +158 -1
  86. phoenix/server/api/input_types/PromptVersionInput.py +47 -1
  87. phoenix/server/api/input_types/SpanSort.py +3 -2
  88. phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
  89. phoenix/server/api/input_types/UserRoleInput.py +1 -0
  90. phoenix/server/api/mutations/__init__.py +8 -0
  91. phoenix/server/api/mutations/annotation_config_mutations.py +8 -8
  92. phoenix/server/api/mutations/api_key_mutations.py +15 -20
  93. phoenix/server/api/mutations/chat_mutations.py +106 -37
  94. phoenix/server/api/mutations/dataset_label_mutations.py +243 -0
  95. phoenix/server/api/mutations/dataset_mutations.py +21 -16
  96. phoenix/server/api/mutations/dataset_split_mutations.py +351 -0
  97. phoenix/server/api/mutations/experiment_mutations.py +2 -2
  98. phoenix/server/api/mutations/export_events_mutations.py +3 -3
  99. phoenix/server/api/mutations/model_mutations.py +11 -9
  100. phoenix/server/api/mutations/project_mutations.py +4 -4
  101. phoenix/server/api/mutations/project_session_annotations_mutations.py +158 -0
  102. phoenix/server/api/mutations/project_trace_retention_policy_mutations.py +8 -4
  103. phoenix/server/api/mutations/prompt_label_mutations.py +74 -65
  104. phoenix/server/api/mutations/prompt_mutations.py +65 -129
  105. phoenix/server/api/mutations/prompt_version_tag_mutations.py +11 -8
  106. phoenix/server/api/mutations/span_annotations_mutations.py +15 -10
  107. phoenix/server/api/mutations/trace_annotations_mutations.py +13 -8
  108. phoenix/server/api/mutations/trace_mutations.py +3 -3
  109. phoenix/server/api/mutations/user_mutations.py +55 -26
  110. phoenix/server/api/queries.py +501 -617
  111. phoenix/server/api/routers/__init__.py +2 -2
  112. phoenix/server/api/routers/auth.py +141 -87
  113. phoenix/server/api/routers/ldap.py +229 -0
  114. phoenix/server/api/routers/oauth2.py +349 -101
  115. phoenix/server/api/routers/v1/__init__.py +22 -4
  116. phoenix/server/api/routers/v1/annotation_configs.py +19 -30
  117. phoenix/server/api/routers/v1/annotations.py +455 -13
  118. phoenix/server/api/routers/v1/datasets.py +355 -68
  119. phoenix/server/api/routers/v1/documents.py +142 -0
  120. phoenix/server/api/routers/v1/evaluations.py +20 -28
  121. phoenix/server/api/routers/v1/experiment_evaluations.py +16 -6
  122. phoenix/server/api/routers/v1/experiment_runs.py +335 -59
  123. phoenix/server/api/routers/v1/experiments.py +475 -47
  124. phoenix/server/api/routers/v1/projects.py +16 -50
  125. phoenix/server/api/routers/v1/prompts.py +50 -39
  126. phoenix/server/api/routers/v1/sessions.py +108 -0
  127. phoenix/server/api/routers/v1/spans.py +156 -96
  128. phoenix/server/api/routers/v1/traces.py +51 -77
  129. phoenix/server/api/routers/v1/users.py +64 -24
  130. phoenix/server/api/routers/v1/utils.py +3 -7
  131. phoenix/server/api/subscriptions.py +257 -93
  132. phoenix/server/api/types/Annotation.py +90 -23
  133. phoenix/server/api/types/ApiKey.py +13 -17
  134. phoenix/server/api/types/AuthMethod.py +1 -0
  135. phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +1 -0
  136. phoenix/server/api/types/Dataset.py +199 -72
  137. phoenix/server/api/types/DatasetExample.py +88 -18
  138. phoenix/server/api/types/DatasetExperimentAnnotationSummary.py +10 -0
  139. phoenix/server/api/types/DatasetLabel.py +57 -0
  140. phoenix/server/api/types/DatasetSplit.py +98 -0
  141. phoenix/server/api/types/DatasetVersion.py +49 -4
  142. phoenix/server/api/types/DocumentAnnotation.py +212 -0
  143. phoenix/server/api/types/Experiment.py +215 -68
  144. phoenix/server/api/types/ExperimentComparison.py +3 -9
  145. phoenix/server/api/types/ExperimentRepeatedRunGroup.py +155 -0
  146. phoenix/server/api/types/ExperimentRepeatedRunGroupAnnotationSummary.py +9 -0
  147. phoenix/server/api/types/ExperimentRun.py +120 -70
  148. phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
  149. phoenix/server/api/types/GenerativeModel.py +95 -42
  150. phoenix/server/api/types/GenerativeProvider.py +1 -1
  151. phoenix/server/api/types/ModelInterface.py +7 -2
  152. phoenix/server/api/types/PlaygroundModel.py +12 -2
  153. phoenix/server/api/types/Project.py +218 -185
  154. phoenix/server/api/types/ProjectSession.py +146 -29
  155. phoenix/server/api/types/ProjectSessionAnnotation.py +187 -0
  156. phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
  157. phoenix/server/api/types/Prompt.py +119 -39
  158. phoenix/server/api/types/PromptLabel.py +42 -25
  159. phoenix/server/api/types/PromptVersion.py +11 -8
  160. phoenix/server/api/types/PromptVersionTag.py +65 -25
  161. phoenix/server/api/types/Span.py +130 -123
  162. phoenix/server/api/types/SpanAnnotation.py +189 -42
  163. phoenix/server/api/types/SystemApiKey.py +65 -1
  164. phoenix/server/api/types/Trace.py +184 -53
  165. phoenix/server/api/types/TraceAnnotation.py +149 -50
  166. phoenix/server/api/types/User.py +128 -33
  167. phoenix/server/api/types/UserApiKey.py +73 -26
  168. phoenix/server/api/types/node.py +10 -0
  169. phoenix/server/api/types/pagination.py +11 -2
  170. phoenix/server/app.py +154 -36
  171. phoenix/server/authorization.py +5 -4
  172. phoenix/server/bearer_auth.py +13 -5
  173. phoenix/server/cost_tracking/cost_model_lookup.py +42 -14
  174. phoenix/server/cost_tracking/model_cost_manifest.json +1085 -194
  175. phoenix/server/daemons/generative_model_store.py +61 -9
  176. phoenix/server/daemons/span_cost_calculator.py +10 -8
  177. phoenix/server/dml_event.py +13 -0
  178. phoenix/server/email/sender.py +29 -2
  179. phoenix/server/grpc_server.py +9 -9
  180. phoenix/server/jwt_store.py +8 -6
  181. phoenix/server/ldap.py +1449 -0
  182. phoenix/server/main.py +9 -3
  183. phoenix/server/oauth2.py +330 -12
  184. phoenix/server/prometheus.py +43 -6
  185. phoenix/server/rate_limiters.py +4 -9
  186. phoenix/server/retention.py +33 -20
  187. phoenix/server/session_filters.py +49 -0
  188. phoenix/server/static/.vite/manifest.json +51 -53
  189. phoenix/server/static/assets/components-BreFUQQa.js +6702 -0
  190. phoenix/server/static/assets/{index-BPCwGQr8.js → index-CTQoemZv.js} +42 -35
  191. phoenix/server/static/assets/pages-DBE5iYM3.js +9524 -0
  192. phoenix/server/static/assets/vendor-BGzfc4EU.css +1 -0
  193. phoenix/server/static/assets/vendor-DCE4v-Ot.js +920 -0
  194. phoenix/server/static/assets/vendor-codemirror-D5f205eT.js +25 -0
  195. phoenix/server/static/assets/{vendor-recharts-Bw30oz1A.js → vendor-recharts-V9cwpXsm.js} +7 -7
  196. phoenix/server/static/assets/{vendor-shiki-DZajAPeq.js → vendor-shiki-Do--csgv.js} +1 -1
  197. phoenix/server/static/assets/vendor-three-CmB8bl_y.js +3840 -0
  198. phoenix/server/templates/index.html +7 -1
  199. phoenix/server/thread_server.py +1 -2
  200. phoenix/server/utils.py +74 -0
  201. phoenix/session/client.py +55 -1
  202. phoenix/session/data_extractor.py +5 -0
  203. phoenix/session/evaluation.py +8 -4
  204. phoenix/session/session.py +44 -8
  205. phoenix/settings.py +2 -0
  206. phoenix/trace/attributes.py +80 -13
  207. phoenix/trace/dsl/query.py +2 -0
  208. phoenix/trace/projects.py +5 -0
  209. phoenix/utilities/template_formatters.py +1 -1
  210. phoenix/version.py +1 -1
  211. phoenix/server/api/types/Evaluation.py +0 -39
  212. phoenix/server/static/assets/components-D0DWAf0l.js +0 -5650
  213. phoenix/server/static/assets/pages-Creyamao.js +0 -8612
  214. phoenix/server/static/assets/vendor-CU36oj8y.js +0 -905
  215. phoenix/server/static/assets/vendor-CqDb5u4o.css +0 -1
  216. phoenix/server/static/assets/vendor-arizeai-Ctgw0e1G.js +0 -168
  217. phoenix/server/static/assets/vendor-codemirror-Cojjzqb9.js +0 -25
  218. phoenix/server/static/assets/vendor-three-BLWp5bic.js +0 -2998
  219. phoenix/utilities/deprecation.py +0 -31
  220. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/entry_points.txt +0 -0
  221. {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/LICENSE +0 -0
@@ -24,12 +24,7 @@ from starlette.datastructures import FormData, UploadFile
  from starlette.requests import Request
  from starlette.responses import Response
  from starlette.status import (
- HTTP_200_OK,
- HTTP_204_NO_CONTENT,
  HTTP_404_NOT_FOUND,
- HTTP_409_CONFLICT,
- HTTP_422_UNPROCESSABLE_ENTITY,
- HTTP_429_TOO_MANY_REQUESTS,
  )
  from strawberry.relay import GlobalID
  from typing_extensions import TypeAlias, assert_never
@@ -42,12 +37,15 @@ from phoenix.db.insertion.dataset import (
  ExampleContent,
  add_dataset_examples,
  )
+ from phoenix.db.types.db_models import UNDEFINED
  from phoenix.server.api.types.Dataset import Dataset as DatasetNodeType
  from phoenix.server.api.types.DatasetExample import DatasetExample as DatasetExampleNodeType
+ from phoenix.server.api.types.DatasetSplit import DatasetSplit as DatasetSplitNodeType
  from phoenix.server.api.types.DatasetVersion import DatasetVersion as DatasetVersionNodeType
  from phoenix.server.api.types.node import from_global_id_with_expected_type
  from phoenix.server.api.utils import delete_projects, delete_traces
  from phoenix.server.authorization import is_not_locked
+ from phoenix.server.bearer_auth import PhoenixUser
  from phoenix.server.dml_event import DatasetInsertEvent

  from .models import V1RoutesBaseModel
@@ -90,7 +88,7 @@ class ListDatasetsResponseBody(PaginatedResponseBody[Dataset]):
  "/datasets",
  operation_id="listDatasets",
  summary="List datasets",
- responses=add_errors_to_responses([HTTP_422_UNPROCESSABLE_ENTITY]),
+ responses=add_errors_to_responses([422]),
  )
  async def list_datasets(
  request: Request,
@@ -124,7 +122,7 @@ async def list_datasets(
  except ValueError:
  raise HTTPException(
  detail=f"Invalid cursor format: {cursor}",
- status_code=HTTP_422_UNPROCESSABLE_ENTITY,
+ status_code=422,
  )
  if name:
  query = query.filter(models.Dataset.name == name)
@@ -163,11 +161,11 @@ async def list_datasets(
  "/datasets/{id}",
  operation_id="deleteDatasetById",
  summary="Delete dataset by ID",
- status_code=HTTP_204_NO_CONTENT,
+ status_code=204,
  responses=add_errors_to_responses(
  [
- {"status_code": HTTP_404_NOT_FOUND, "description": "Dataset not found"},
- {"status_code": HTTP_422_UNPROCESSABLE_ENTITY, "description": "Invalid dataset ID"},
+ {"status_code": 404, "description": "Dataset not found"},
+ {"status_code": 422, "description": "Invalid dataset ID"},
  ]
  ),
  )
@@ -181,11 +179,9 @@ async def delete_dataset(
  DATASET_NODE_NAME,
  )
  except ValueError:
- raise HTTPException(
- detail=f"Invalid Dataset ID: {id}", status_code=HTTP_422_UNPROCESSABLE_ENTITY
- )
+ raise HTTPException(detail=f"Invalid Dataset ID: {id}", status_code=422)
  else:
- raise HTTPException(detail="Missing Dataset ID", status_code=HTTP_422_UNPROCESSABLE_ENTITY)
+ raise HTTPException(detail="Missing Dataset ID", status_code=422)
  project_names_stmt = get_project_names_for_datasets(dataset_id)
  eval_trace_ids_stmt = get_eval_trace_ids_for_datasets(dataset_id)
  stmt = (
@@ -195,7 +191,7 @@ async def delete_dataset(
  project_names = await session.scalars(project_names_stmt)
  eval_trace_ids = await session.scalars(eval_trace_ids_stmt)
  if (await session.scalar(stmt)) is None:
- raise HTTPException(detail="Dataset does not exist", status_code=HTTP_404_NOT_FOUND)
+ raise HTTPException(detail="Dataset does not exist", status_code=404)
  tasks = BackgroundTasks()
  tasks.add_task(delete_projects, request.app.state.db, *project_names)
  tasks.add_task(delete_traces, request.app.state.db, *eval_trace_ids)
@@ -213,17 +209,21 @@ class GetDatasetResponseBody(ResponseBody[DatasetWithExampleCount]):
  "/datasets/{id}",
  operation_id="getDataset",
  summary="Get dataset by ID",
- responses=add_errors_to_responses([HTTP_404_NOT_FOUND]),
+ responses=add_errors_to_responses([404]),
  )
  async def get_dataset(
  request: Request, id: str = Path(description="The ID of the dataset")
  ) -> GetDatasetResponseBody:
- dataset_id = GlobalID.from_id(id)
+ try:
+ dataset_id = GlobalID.from_id(id)
+ except Exception as e:
+ raise HTTPException(
+ detail=f"Invalid dataset ID format: {id}",
+ status_code=422,
+ ) from e

  if (type_name := dataset_id.type_name) != DATASET_NODE_NAME:
- raise HTTPException(
- detail=f"ID {dataset_id} refers to a f{type_name}", status_code=HTTP_404_NOT_FOUND
- )
+ raise HTTPException(detail=f"ID {dataset_id} refers to a f{type_name}", status_code=404)
  async with request.app.state.db() as session:
  result = await session.execute(
  select(models.Dataset, models.Dataset.example_count).filter(
@@ -234,9 +234,7 @@ async def get_dataset(
  dataset = dataset_query[0] if dataset_query else None
  example_count = dataset_query[1] if dataset_query else 0
  if dataset is None:
- raise HTTPException(
- detail=f"Dataset with ID {dataset_id} not found", status_code=HTTP_404_NOT_FOUND
- )
+ raise HTTPException(detail=f"Dataset with ID {dataset_id} not found", status_code=404)

  dataset = DatasetWithExampleCount(
  id=str(dataset_id),
@@ -265,7 +263,7 @@ class ListDatasetVersionsResponseBody(PaginatedResponseBody[DatasetVersion]):
  "/datasets/{id}/versions",
  operation_id="listDatasetVersionsByDatasetId",
  summary="List dataset versions",
- responses=add_errors_to_responses([HTTP_422_UNPROCESSABLE_ENTITY]),
+ responses=add_errors_to_responses([422]),
  )
  async def list_dataset_versions(
  request: Request,
@@ -287,12 +285,12 @@ async def list_dataset_versions(
  except ValueError:
  raise HTTPException(
  detail=f"Invalid Dataset ID: {id}",
- status_code=HTTP_422_UNPROCESSABLE_ENTITY,
+ status_code=422,
  )
  else:
  raise HTTPException(
  detail="Missing Dataset ID",
- status_code=HTTP_422_UNPROCESSABLE_ENTITY,
+ status_code=422,
  )
  stmt = (
  select(models.DatasetVersion)
@@ -308,7 +306,7 @@ async def list_dataset_versions(
  except ValueError:
  raise HTTPException(
  detail=f"Invalid cursor: {cursor}",
- status_code=HTTP_422_UNPROCESSABLE_ENTITY,
+ status_code=422,
  )
  max_dataset_version_id = (
  select(models.DatasetVersion.id)
@@ -343,14 +341,14 @@ class UploadDatasetResponseBody(ResponseBody[UploadDatasetData]):
  "/datasets/upload",
  dependencies=[Depends(is_not_locked)],
  operation_id="uploadDataset",
- summary="Upload dataset from JSON, CSV, or PyArrow",
+ summary="Upload dataset from JSON, JSONL, CSV, or PyArrow",
  responses=add_errors_to_responses(
  [
  {
- "status_code": HTTP_409_CONFLICT,
+ "status_code": 409,
  "description": "Dataset of the same name already exists",
  },
- {"status_code": HTTP_422_UNPROCESSABLE_ENTITY, "description": "Invalid request body"},
+ {"status_code": 422, "description": "Invalid request body"},
  ]
  ),
  # FastAPI cannot generate the request body portion of the OpenAPI schema for
@@ -372,6 +370,17 @@ class UploadDatasetResponseBody(ResponseBody[UploadDatasetData]):
  "inputs": {"type": "array", "items": {"type": "object"}},
  "outputs": {"type": "array", "items": {"type": "object"}},
  "metadata": {"type": "array", "items": {"type": "object"}},
+ "splits": {
+ "type": "array",
+ "items": {
+ "oneOf": [
+ {"type": "string"},
+ {"type": "array", "items": {"type": "string"}},
+ {"type": "null"},
+ ]
+ },
+ "description": "Split per example: string, string array, or null",
+ },
  },
  }
  },
@@ -398,6 +407,12 @@ class UploadDatasetResponseBody(ResponseBody[UploadDatasetData]):
  "items": {"type": "string"},
  "uniqueItems": True,
  },
+ "split_keys[]": {
+ "type": "array",
+ "items": {"type": "string"},
+ "uniqueItems": True,
+ "description": "Column names for auto-assigning examples to splits",
+ },
  "file": {"type": "string", "format": "binary"},
  },
  }
@@ -413,7 +428,12 @@ async def upload_dataset(
  description="If true, fulfill request synchronously and return JSON containing dataset_id.",
  ),
  ) -> Optional[UploadDatasetResponseBody]:
- request_content_type = request.headers["content-type"]
+ request_content_type = request.headers.get("content-type")
+ if not request_content_type:
+ raise HTTPException(
+ detail="Missing content-type header",
+ status_code=400,
+ )
  examples: Union[Examples, Awaitable[Examples]]
  if request_content_type.startswith("application/json"):
  try:
@@ -423,14 +443,14 @@ async def upload_dataset(
  except ValueError as e:
  raise HTTPException(
  detail=str(e),
- status_code=HTTP_422_UNPROCESSABLE_ENTITY,
+ status_code=422,
  )
  if action is DatasetAction.CREATE:
  async with request.app.state.db() as session:
  if await _check_table_exists(session, name):
  raise HTTPException(
  detail=f"Dataset with the same name already exists: {name=}",
- status_code=HTTP_409_CONFLICT,
+ status_code=409,
  )
  elif request_content_type.startswith("multipart/form-data"):
  async with request.form() as form:
@@ -442,19 +462,20 @@ async def upload_dataset(
  input_keys,
  output_keys,
  metadata_keys,
+ split_keys,
  file,
  ) = await _parse_form_data(form)
  except ValueError as e:
  raise HTTPException(
  detail=str(e),
- status_code=HTTP_422_UNPROCESSABLE_ENTITY,
+ status_code=422,
  )
  if action is DatasetAction.CREATE:
  async with request.app.state.db() as session:
  if await _check_table_exists(session, name):
  raise HTTPException(
  detail=f"Dataset with the same name already exists: {name=}",
- status_code=HTTP_409_CONFLICT,
+ status_code=409,
  )
  content = await file.read()
  try:
@@ -462,22 +483,32 @@ async def upload_dataset(
  if file_content_type is FileContentType.CSV:
  encoding = FileContentEncoding(file.headers.get("content-encoding"))
  examples = await _process_csv(
- content, encoding, input_keys, output_keys, metadata_keys
+ content, encoding, input_keys, output_keys, metadata_keys, split_keys
  )
  elif file_content_type is FileContentType.PYARROW:
- examples = await _process_pyarrow(content, input_keys, output_keys, metadata_keys)
+ examples = await _process_pyarrow(
+ content, input_keys, output_keys, metadata_keys, split_keys
+ )
+ elif file_content_type is FileContentType.JSONL:
+ encoding = FileContentEncoding(file.headers.get("content-encoding"))
+ examples = await _process_jsonl(
+ content, encoding, input_keys, output_keys, metadata_keys, split_keys
+ )
  else:
  assert_never(file_content_type)
  except ValueError as e:
  raise HTTPException(
  detail=str(e),
- status_code=HTTP_422_UNPROCESSABLE_ENTITY,
+ status_code=422,
  )
  else:
  raise HTTPException(
  detail="Invalid request Content-Type",
- status_code=HTTP_422_UNPROCESSABLE_ENTITY,
+ status_code=422,
  )
+ user_id: Optional[int] = None
+ if request.app.state.authentication_enabled and isinstance(request.user, PhoenixUser):
+ user_id = int(request.user.identity)
  operation = cast(
  Callable[[AsyncSession], Awaitable[DatasetExampleAdditionEvent]],
  partial(
@@ -486,6 +517,7 @@ async def upload_dataset(
  action=action,
  name=name,
  description=description,
+ user_id=user_id,
  ),
  )
  if sync:
@@ -505,13 +537,14 @@ async def upload_dataset(
  except QueueFull:
  if isinstance(examples, Coroutine):
  examples.close()
- raise HTTPException(detail="Too many requests.", status_code=HTTP_429_TOO_MANY_REQUESTS)
+ raise HTTPException(detail="Too many requests.", status_code=429)
  return None


  class FileContentType(Enum):
  CSV = "text/csv"
  PYARROW = "application/x-pandas-pyarrow"
+ JSONL = "application/jsonl"

  @classmethod
  def _missing_(cls, v: Any) -> "FileContentType":
@@ -539,6 +572,7 @@ Description: TypeAlias = Optional[str]
  InputKeys: TypeAlias = frozenset[str]
  OutputKeys: TypeAlias = frozenset[str]
  MetadataKeys: TypeAlias = frozenset[str]
+ SplitKeys: TypeAlias = frozenset[str]
  DatasetId: TypeAlias = int
  Examples: TypeAlias = Iterator[ExampleContent]

@@ -555,18 +589,55 @@ def _process_json(
  raise ValueError("input is required")
  if not isinstance(inputs, list) or not _is_all_dict(inputs):
  raise ValueError("Input should be a list containing only dictionary objects")
- outputs, metadata = data.get("outputs"), data.get("metadata")
+ outputs, metadata, splits = data.get("outputs"), data.get("metadata"), data.get("splits")
  for k, v in {"outputs": outputs, "metadata": metadata}.items():
  if v and not (isinstance(v, list) and len(v) == len(inputs) and _is_all_dict(v)):
  raise ValueError(
  f"{k} should be a list of same length as input containing only dictionary objects"
  )
+
+ # Validate splits format if provided
+ if splits is not None:
+ if not isinstance(splits, list):
+ raise ValueError("splits must be a list")
+ if len(splits) != len(inputs):
+ raise ValueError(
+ f"splits must have same length as inputs ({len(splits)} != {len(inputs)})"
+ )
  examples: list[ExampleContent] = []
  for i, obj in enumerate(inputs):
+ # Extract split values, validating they're non-empty strings
+ split_set: set[str] = set()
+ if splits:
+ split_value = splits[i]
+ if split_value is None:
+ # Sparse assignment: None means no splits for this example
+ pass
+ elif isinstance(split_value, str):
+ # Format 1: Single string value
+ if split_value.strip():
+ split_set.add(split_value.strip())
+ elif isinstance(split_value, list):
+ # Format 2: List of strings (multiple splits)
+ for v in split_value:
+ if v is None:
+ continue # Skip None values in the list
+ if not isinstance(v, str):
+ raise ValueError(
+ f"Split value must be a string or None, got {type(v).__name__}"
+ )
+ if v.strip():
+ split_set.add(v.strip())
+ else:
+ raise ValueError(
+ f"Split value must be a string, list of strings, or None, "
+ f"got {type(split_value).__name__}"
+ )
  example = ExampleContent(
  input=obj,
  output=outputs[i] if outputs else {},
  metadata=metadata[i] if metadata else {},
+ splits=frozenset(split_set),
  )
  examples.append(example)
  action = DatasetAction(cast(Optional[str], data.get("action")) or "create")
@@ -579,6 +650,7 @@ async def _process_csv(
  input_keys: InputKeys,
  output_keys: OutputKeys,
  metadata_keys: MetadataKeys,
+ split_keys: SplitKeys,
  ) -> Examples:
  if content_encoding is FileContentEncoding.GZIP:
  content = await run_in_threadpool(gzip.decompress, content)
@@ -593,12 +665,15 @@ async def _process_csv(
  if freq > 1:
  raise ValueError(f"Duplicated column header in CSV file: {header}")
  column_headers = frozenset(reader.fieldnames)
- _check_keys_exist(column_headers, input_keys, output_keys, metadata_keys)
+ _check_keys_exist(column_headers, input_keys, output_keys, metadata_keys, split_keys)
  return (
  ExampleContent(
  input={k: row.get(k) for k in input_keys},
  output={k: row.get(k) for k in output_keys},
  metadata={k: row.get(k) for k in metadata_keys},
+ splits=frozenset(
+ str(v).strip() for k in split_keys if (v := row.get(k)) and str(v).strip()
+ ), # Only include non-empty, non-whitespace split values
  )
  for row in iter(reader)
  )
@@ -609,13 +684,14 @@ async def _process_pyarrow(
  input_keys: InputKeys,
  output_keys: OutputKeys,
  metadata_keys: MetadataKeys,
+ split_keys: SplitKeys,
  ) -> Awaitable[Examples]:
  try:
  reader = pa.ipc.open_stream(content)
  except pa.ArrowInvalid as e:
  raise ValueError("File is not valid pyarrow") from e
  column_headers = frozenset(reader.schema.names)
- _check_keys_exist(column_headers, input_keys, output_keys, metadata_keys)
+ _check_keys_exist(column_headers, input_keys, output_keys, metadata_keys, split_keys)

  def get_examples() -> Iterator[ExampleContent]:
  for row in reader.read_pandas().to_dict(orient="records"):
@@ -623,11 +699,48 @@ async def _process_pyarrow(
  input={k: row.get(k) for k in input_keys},
  output={k: row.get(k) for k in output_keys},
  metadata={k: row.get(k) for k in metadata_keys},
+ splits=frozenset(
+ str(v).strip() for k in split_keys if (v := row.get(k)) and str(v).strip()
+ ), # Only include non-empty, non-whitespace split values
  )

  return run_in_threadpool(get_examples)


+ async def _process_jsonl(
+ content: bytes,
+ encoding: FileContentEncoding,
+ input_keys: InputKeys,
+ output_keys: OutputKeys,
+ metadata_keys: MetadataKeys,
+ split_keys: SplitKeys,
+ ) -> Examples:
+ if encoding is FileContentEncoding.GZIP:
+ content = await run_in_threadpool(gzip.decompress, content)
+ elif encoding is FileContentEncoding.DEFLATE:
+ content = await run_in_threadpool(zlib.decompress, content)
+ elif encoding is not FileContentEncoding.NONE:
+ assert_never(encoding)
+ # content is a newline delimited list of JSON objects
+ # parse within a threadpool
+ reader = await run_in_threadpool(
+ lambda c: [json.loads(line) for line in c.decode().splitlines()], content
+ )
+
+ examples: list[ExampleContent] = []
+ for obj in reader:
+ example = ExampleContent(
+ input={k: obj.get(k) for k in input_keys},
+ output={k: obj.get(k) for k in output_keys},
+ metadata={k: obj.get(k) for k in metadata_keys},
+ splits=frozenset(
+ str(v).strip() for k in split_keys if (v := obj.get(k)) and str(v).strip()
+ ), # Only include non-empty, non-whitespace split values
+ )
+ examples.append(example)
+ return iter(examples)
+
+
  async def _check_table_exists(session: AsyncSession, name: str) -> bool:
  return bool(
  await session.scalar(
@@ -641,11 +754,13 @@ def _check_keys_exist(
  input_keys: InputKeys,
  output_keys: OutputKeys,
  metadata_keys: MetadataKeys,
+ split_keys: SplitKeys,
  ) -> None:
  for desc, keys in (
  ("input", input_keys),
  ("output", output_keys),
  ("metadata", metadata_keys),
+ ("split", split_keys),
  ):
  if keys and (diff := keys.difference(column_headers)):
  raise ValueError(f"{desc} keys not found in column headers: {diff}")
@@ -660,6 +775,7 @@ async def _parse_form_data(
  InputKeys,
  OutputKeys,
  MetadataKeys,
+ SplitKeys,
  UploadFile,
  ]:
  name = cast(Optional[str], form.get("name"))
@@ -673,6 +789,7 @@ async def _parse_form_data(
  input_keys = frozenset(filter(bool, cast(list[str], form.getlist("input_keys[]"))))
  output_keys = frozenset(filter(bool, cast(list[str], form.getlist("output_keys[]"))))
  metadata_keys = frozenset(filter(bool, cast(list[str], form.getlist("metadata_keys[]"))))
+ split_keys = frozenset(filter(bool, cast(list[str], form.getlist("split_keys[]"))))
  return (
  action,
  name,
@@ -680,6 +797,7 @@ async def _parse_form_data(
  input_keys,
  output_keys,
  metadata_keys,
+ split_keys,
  file,
  )

@@ -695,6 +813,7 @@ class DatasetExample(V1RoutesBaseModel):
  class ListDatasetExamplesData(V1RoutesBaseModel):
  dataset_id: str
  version_id: str
+ filtered_splits: list[str] = UNDEFINED
  examples: list[DatasetExample]


@@ -706,7 +825,7 @@ class ListDatasetExamplesResponseBody(ResponseBody[ListDatasetExamplesData]):
  "/datasets/{id}/examples",
  operation_id="getDatasetExamples",
  summary="Get examples from a dataset",
- responses=add_errors_to_responses([HTTP_404_NOT_FOUND]),
+ responses=add_errors_to_responses([404]),
  )
  async def get_dataset_examples(
  request: Request,
@@ -717,19 +836,35 @@ async def get_dataset_examples(
  "The ID of the dataset version (if omitted, returns data from the latest version)"
  ),
  ),
+ split: Optional[list[str]] = Query(
+ default=None,
+ description="List of dataset split identifiers (GlobalIDs or names) to filter by",
+ ),
  ) -> ListDatasetExamplesResponseBody:
- dataset_gid = GlobalID.from_id(id)
- version_gid = GlobalID.from_id(version_id) if version_id else None
+ try:
+ dataset_gid = GlobalID.from_id(id)
+ except Exception as e:
+ raise HTTPException(
+ detail=f"Invalid dataset ID format: {id}",
+ status_code=422,
+ ) from e
+
+ if version_id:
+ try:
+ version_gid = GlobalID.from_id(version_id)
+ except Exception as e:
+ raise HTTPException(
+ detail=f"Invalid dataset version ID format: {version_id}",
+ status_code=422,
+ ) from e
+ else:
+ version_gid = None

  if (dataset_type := dataset_gid.type_name) != "Dataset":
- raise HTTPException(
- detail=f"ID {dataset_gid} refers to a {dataset_type}", status_code=HTTP_404_NOT_FOUND
- )
+ raise HTTPException(detail=f"ID {dataset_gid} refers to a {dataset_type}", status_code=404)

  if version_gid and (version_type := version_gid.type_name) != "DatasetVersion":
- raise HTTPException(
- detail=f"ID {version_gid} refers to a {version_type}", status_code=HTTP_404_NOT_FOUND
- )
+ raise HTTPException(detail=f"ID {version_gid} refers to a {version_type}", status_code=404)

  async with request.app.state.db() as session:
  if (
@@ -739,7 +874,7 @@ async def get_dataset_examples(
  ) is None:
  raise HTTPException(
  detail=f"No dataset with id {dataset_gid} can be found.",
- status_code=HTTP_404_NOT_FOUND,
+ status_code=404,
  )

  # Subquery to find the maximum created_at for each dataset_example_id
@@ -761,7 +896,7 @@ async def get_dataset_examples(
  ) is None:
  raise HTTPException(
  detail=f"No dataset version with id {version_id} can be found.",
- status_code=HTTP_404_NOT_FOUND,
+ status_code=404,
  )
  # if a version_id is provided, filter the subquery to only include revisions from that
  partial_subquery = partial_subquery.filter(
@@ -777,13 +912,17 @@ async def get_dataset_examples(
  ) is None:
  raise HTTPException(
  detail="Dataset has no versions.",
- status_code=HTTP_404_NOT_FOUND,
+ status_code=404,
  )

  subquery = partial_subquery.subquery()
+
  # Query for the most recent example revisions that are not deleted
  query = (
- select(models.DatasetExample, models.DatasetExampleRevision)
+ select(
+ models.DatasetExample,
+ models.DatasetExampleRevision,
+ )
  .join(
  models.DatasetExampleRevision,
  models.DatasetExample.id == models.DatasetExampleRevision.dataset_example_id,
@@ -796,6 +935,28 @@ async def get_dataset_examples(
  .filter(models.DatasetExampleRevision.revision_kind != "DELETE")
  .order_by(models.DatasetExample.id.asc())
  )
+
+ # If splits are provided, filter by dataset splits
+ resolved_split_names: list[str] = []
+ if split:
+ # Resolve split identifiers (IDs or names) to IDs and names
+ resolved_split_ids, resolved_split_names = await _resolve_split_identifiers(
+ session, split
+ )
+
+ # Add filter for splits (join with the association table)
+ # Use distinct() to prevent duplicates when an example belongs to
+ # multiple splits
+ query = (
+ query.join(
+ models.DatasetSplitDatasetExample,
+ models.DatasetExample.id
+ == models.DatasetSplitDatasetExample.dataset_example_id,
+ )
+ .filter(models.DatasetSplitDatasetExample.dataset_split_id.in_(resolved_split_ids))
+ .distinct()
+ )
+
  examples = [
  DatasetExample(
  id=str(GlobalID("DatasetExample", str(example.id))),
@@ -810,6 +971,7 @@ async def get_dataset_examples(
  data=ListDatasetExamplesData(
  dataset_id=str(GlobalID("Dataset", str(resolved_dataset_id))),
  version_id=str(GlobalID("DatasetVersion", str(resolved_version_id))),
+ filtered_splits=resolved_split_names,
  examples=examples,
  )
  )
@@ -820,10 +982,10 @@ async def get_dataset_examples(
  operation_id="getDatasetCsv",
  summary="Download dataset examples as CSV file",
  response_class=StreamingResponse,
- status_code=HTTP_200_OK,
+ status_code=200,
  responses={
- **add_errors_to_responses([HTTP_422_UNPROCESSABLE_ENTITY]),
- **add_text_csv_content_to_responses(HTTP_200_OK),
+ **add_errors_to_responses([422]),
+ **add_text_csv_content_to_responses(200),
  },
  )
  async def get_dataset_csv(
@@ -843,7 +1005,7 @@ async def get_dataset_csv(
  session=session, id=id, version_id=version_id
  )
  except ValueError as e:
- raise HTTPException(detail=str(e), status_code=HTTP_422_UNPROCESSABLE_ENTITY)
+ raise HTTPException(detail=str(e), status_code=422)
  content = await run_in_threadpool(_get_content_csv, examples)
  encoded_dataset_name = urllib.parse.quote(dataset_name)
  return Response(
@@ -863,7 +1025,7 @@ async def get_dataset_csv(
  responses=add_errors_to_responses(
  [
  {
- "status_code": HTTP_422_UNPROCESSABLE_ENTITY,
+ "status_code": 422,
  "description": "Invalid dataset or version ID",
  }
  ]
@@ -886,7 +1048,7 @@ async def get_dataset_jsonl_openai_ft(
  session=session, id=id, version_id=version_id
  )
  except ValueError as e:
- raise HTTPException(detail=str(e), status_code=HTTP_422_UNPROCESSABLE_ENTITY)
+ raise HTTPException(detail=str(e), status_code=422)
  content = await run_in_threadpool(_get_content_jsonl_openai_ft, examples)
  encoded_dataset_name = urllib.parse.quote(dataset_name)
  response.headers["content-disposition"] = (
@@ -903,7 +1065,7 @@ async def get_dataset_jsonl_openai_ft(
  responses=add_errors_to_responses(
  [
  {
- "status_code": HTTP_422_UNPROCESSABLE_ENTITY,
+ "status_code": 422,
  "description": "Invalid dataset or version ID",
  }
  ]
@@ -926,7 +1088,7 @@ async def get_dataset_jsonl_openai_evals(
  session=session, id=id, version_id=version_id
  )
  except ValueError as e:
- raise HTTPException(detail=str(e), status_code=HTTP_422_UNPROCESSABLE_ENTITY)
+ raise HTTPException(detail=str(e), status_code=422)
  content = await run_in_threadpool(_get_content_jsonl_openai_evals, examples)
  encoded_dataset_name = urllib.parse.quote(dataset_name)
  response.headers["content-disposition"] = (
@@ -1005,12 +1167,25 @@ def _get_content_jsonl_openai_evals(examples: list[models.DatasetExampleRevision
  async def _get_db_examples(
  *, session: Any, id: str, version_id: Optional[str]
  ) -> tuple[str, list[models.DatasetExampleRevision]]:
- dataset_id = from_global_id_with_expected_type(GlobalID.from_id(id), DATASET_NODE_NAME)
+ try:
+ dataset_id = from_global_id_with_expected_type(GlobalID.from_id(id), DATASET_NODE_NAME)
+ except Exception as e:
+ raise HTTPException(
+ detail=f"Invalid dataset ID format: {id}",
+ status_code=422,
+ ) from e
+
  dataset_version_id: Optional[int] = None
  if version_id:
- dataset_version_id = from_global_id_with_expected_type(
- GlobalID.from_id(version_id), DATASET_VERSION_NODE_NAME
- )
+ try:
+ dataset_version_id = from_global_id_with_expected_type(
+ GlobalID.from_id(version_id), DATASET_VERSION_NODE_NAME
+ )
+ except Exception as e:
+ raise HTTPException(
+ detail=f"Invalid dataset version ID format: {version_id}",
+ status_code=422,
+ ) from e
  latest_version = (
  select(
  models.DatasetExampleRevision.dataset_example_id,
@@ -1053,3 +1228,115 @@ async def _get_db_examples(

  def _is_all_dict(seq: Sequence[Any]) -> bool:
  return all(map(lambda obj: isinstance(obj, dict), seq))
+
+
+ # Split identifier helper types and functions
+ class _SplitId(int): ...
+
+
+ _SplitIdentifier: TypeAlias = Union[_SplitId, str]
+
+
+ def _parse_split_identifier(split_identifier: str) -> _SplitIdentifier:
+ """
+ Parse a split identifier as either a GlobalID or a name.
+
+ Args:
+ split_identifier: The identifier string (GlobalID or name)
+
+ Returns:
+ Either a _SplitId or an Identifier
+
+ Raises:
+ HTTPException: If the identifier format is invalid
+ """
+ if not split_identifier:
+ raise HTTPException(422, "Invalid split identifier")
+ try:
+ split_id = from_global_id_with_expected_type(
+ GlobalID.from_id(split_identifier),
+ DatasetSplitNodeType.__name__,
+ )
+ except ValueError:
+ return split_identifier
+ return _SplitId(split_id)
+
+
+ async def _resolve_split_identifiers(
+ session: AsyncSession,
+ split_identifiers: list[str],
+ ) -> tuple[list[int], list[str]]:
+ """
+ Resolve a list of split identifiers (IDs or names) to split IDs and names.
+
+ Args:
+ session: The database session
+ split_identifiers: List of split identifiers (GlobalIDs or names)
+
+ Returns:
+ Tuple of (list of split IDs, list of split names)
+
+ Raises:
+ HTTPException: If any split identifier is invalid or not found
+ """
+ split_ids: list[int] = []
+ split_names: list[str] = []
+
+ # Parse all identifiers first
+ parsed_identifiers: list[_SplitIdentifier] = []
+ for identifier_str in split_identifiers:
+ parsed_identifiers.append(_parse_split_identifier(identifier_str.strip()))
+
+ # Separate IDs and names
+ requested_ids: list[int] = []
+ requested_names: list[str] = []
+ for identifier in parsed_identifiers:
+ if isinstance(identifier, _SplitId):
+ requested_ids.append(int(identifier))
+ elif isinstance(identifier, str):
+ requested_names.append(identifier)
+ else:
+ assert_never(identifier)
+
+ # Query for splits by ID
+ if requested_ids:
+ id_results = await session.stream(
+ select(models.DatasetSplit.id, models.DatasetSplit.name).where(
+ models.DatasetSplit.id.in_(requested_ids)
+ )
+ )
+ async for split_id, split_name in id_results:
+ split_ids.append(split_id)
+ split_names.append(split_name)
+
+ # Check if all requested IDs were found
+ found_ids = set(split_ids[-len(requested_ids) :] if requested_ids else [])
+ missing_ids = [sid for sid in requested_ids if sid not in found_ids]
+ if missing_ids:
+ raise HTTPException(
+ status_code=HTTP_404_NOT_FOUND,
+ detail=f"Dataset splits not found for IDs: {', '.join(map(str, missing_ids))}",
+ )
+
+ # Query for splits by name
+ if requested_names:
+ name_results = await session.stream(
+ select(models.DatasetSplit.id, models.DatasetSplit.name).where(
+ models.DatasetSplit.name.in_(requested_names)
+ )
+ )
+ name_to_id: dict[str, int] = {}
+ async for split_id, split_name in name_results:
+ split_ids.append(split_id)
+ split_names.append(split_name)
+ name_to_id[split_name] = split_id
+
+ # Check if all requested names were found
+ missing_names = [name for name in requested_names if name not in name_to_id]
+ if missing_names:
+ raise HTTPException(
+ status_code=HTTP_404_NOT_FOUND,
+ detail=f"Dataset splits not found: {', '.join(missing_names)}",
+ )
+
+ return split_ids, split_names
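
The hunks above appear to be from phoenix/server/api/routers/v1/datasets.py (entry 118 in the file list). They add dataset-split support to the upload and example-listing routes: a splits field in the JSON body, a split_keys[] multipart form field, a new JSONL content type (application/jsonl), and a split query parameter on /datasets/{id}/examples. Below is a minimal client-side sketch of those request shapes; the server URL, the /v1 prefix, the auth header, the placement of the sync parameter, and the exact response JSON layout are assumptions, while the form fields, query parameters, and content types come directly from the route code above.

# Minimal sketch; assumes a Phoenix server at http://localhost:6006 with the v1 REST API
# mounted at /v1 and API-key auth. Only the form fields, query parameters, and content
# types are taken from the diff above; everything else is an assumption.
import requests

BASE_URL = "http://localhost:6006/v1"            # assumed server location and prefix
HEADERS = {"Authorization": "Bearer <api-key>"}  # assumed auth scheme; drop if auth is disabled

# 1) Upload a JSONL dataset (new FileContentType.JSONL = "application/jsonl"), assigning
#    each example to a split via the new split_keys[] form field.
jsonl = b"\n".join(
    [
        b'{"question": "What is Phoenix?", "answer": "An observability tool", "split": "train"}',
        b'{"question": "What is a span?", "answer": "A unit of work in a trace", "split": "test"}',
    ]
)
upload = requests.post(
    f"{BASE_URL}/datasets/upload",
    params={"sync": "true"},  # assumed to be a query parameter; returns dataset_id synchronously
    headers=HEADERS,
    data={
        "action": "create",
        "name": "qa-dataset",
        "input_keys[]": ["question"],
        "output_keys[]": ["answer"],
        "split_keys[]": ["split"],  # column(s) used to auto-assign examples to splits
    },
    files={"file": ("examples.jsonl", jsonl, "application/jsonl")},
)
upload.raise_for_status()
dataset_id = upload.json()["data"]["dataset_id"]  # response layout assumed from UploadDatasetData

# 2) List only the "train" examples via the new split query parameter, which accepts
#    split names or GlobalIDs; filtered_splits comes from ListDatasetExamplesData.
examples = requests.get(
    f"{BASE_URL}/datasets/{dataset_id}/examples",
    params={"split": ["train"]},
    headers=HEADERS,
)
examples.raise_for_status()
print(examples.json()["data"]["filtered_splits"])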