arize-phoenix 10.0.4__py3-none-any.whl → 12.28.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (276)
  1. {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/METADATA +124 -72
  2. arize_phoenix-12.28.1.dist-info/RECORD +499 -0
  3. {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/WHEEL +1 -1
  4. {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/IP_NOTICE +1 -1
  5. phoenix/__generated__/__init__.py +0 -0
  6. phoenix/__generated__/classification_evaluator_configs/__init__.py +20 -0
  7. phoenix/__generated__/classification_evaluator_configs/_document_relevance_classification_evaluator_config.py +17 -0
  8. phoenix/__generated__/classification_evaluator_configs/_hallucination_classification_evaluator_config.py +17 -0
  9. phoenix/__generated__/classification_evaluator_configs/_models.py +18 -0
  10. phoenix/__generated__/classification_evaluator_configs/_tool_selection_classification_evaluator_config.py +17 -0
  11. phoenix/__init__.py +5 -4
  12. phoenix/auth.py +39 -2
  13. phoenix/config.py +1763 -91
  14. phoenix/datetime_utils.py +120 -2
  15. phoenix/db/README.md +595 -25
  16. phoenix/db/bulk_inserter.py +145 -103
  17. phoenix/db/engines.py +140 -33
  18. phoenix/db/enums.py +3 -12
  19. phoenix/db/facilitator.py +302 -35
  20. phoenix/db/helpers.py +1000 -65
  21. phoenix/db/iam_auth.py +64 -0
  22. phoenix/db/insertion/dataset.py +135 -2
  23. phoenix/db/insertion/document_annotation.py +9 -6
  24. phoenix/db/insertion/evaluation.py +2 -3
  25. phoenix/db/insertion/helpers.py +17 -2
  26. phoenix/db/insertion/session_annotation.py +176 -0
  27. phoenix/db/insertion/span.py +15 -11
  28. phoenix/db/insertion/span_annotation.py +3 -4
  29. phoenix/db/insertion/trace_annotation.py +3 -4
  30. phoenix/db/insertion/types.py +50 -20
  31. phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
  32. phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
  33. phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
  34. phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
  35. phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
  36. phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
  37. phoenix/db/migrations/versions/a20694b15f82_cost.py +196 -0
  38. phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
  39. phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
  40. phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
  41. phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
  42. phoenix/db/models.py +669 -56
  43. phoenix/db/pg_config.py +10 -0
  44. phoenix/db/types/model_provider.py +4 -0
  45. phoenix/db/types/token_price_customization.py +29 -0
  46. phoenix/db/types/trace_retention.py +23 -15
  47. phoenix/experiments/evaluators/utils.py +3 -3
  48. phoenix/experiments/functions.py +160 -52
  49. phoenix/experiments/tracing.py +2 -2
  50. phoenix/experiments/types.py +1 -1
  51. phoenix/inferences/inferences.py +1 -2
  52. phoenix/server/api/auth.py +38 -7
  53. phoenix/server/api/auth_messages.py +46 -0
  54. phoenix/server/api/context.py +100 -4
  55. phoenix/server/api/dataloaders/__init__.py +79 -5
  56. phoenix/server/api/dataloaders/annotation_configs_by_project.py +31 -0
  57. phoenix/server/api/dataloaders/annotation_summaries.py +60 -8
  58. phoenix/server/api/dataloaders/average_experiment_repeated_run_group_latency.py +50 -0
  59. phoenix/server/api/dataloaders/average_experiment_run_latency.py +17 -24
  60. phoenix/server/api/dataloaders/cache/two_tier_cache.py +1 -2
  61. phoenix/server/api/dataloaders/dataset_dataset_splits.py +52 -0
  62. phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -1
  63. phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
  64. phoenix/server/api/dataloaders/dataset_examples_and_versions_by_experiment_run.py +47 -0
  65. phoenix/server/api/dataloaders/dataset_labels.py +36 -0
  66. phoenix/server/api/dataloaders/document_evaluation_summaries.py +2 -2
  67. phoenix/server/api/dataloaders/document_evaluations.py +6 -9
  68. phoenix/server/api/dataloaders/experiment_annotation_summaries.py +88 -34
  69. phoenix/server/api/dataloaders/experiment_dataset_splits.py +43 -0
  70. phoenix/server/api/dataloaders/experiment_error_rates.py +21 -28
  71. phoenix/server/api/dataloaders/experiment_repeated_run_group_annotation_summaries.py +77 -0
  72. phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +57 -0
  73. phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
  74. phoenix/server/api/dataloaders/last_used_times_by_generative_model_id.py +35 -0
  75. phoenix/server/api/dataloaders/latency_ms_quantile.py +40 -8
  76. phoenix/server/api/dataloaders/record_counts.py +37 -10
  77. phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
  78. phoenix/server/api/dataloaders/span_cost_by_span.py +24 -0
  79. phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_generative_model.py +56 -0
  80. phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_project_session.py +57 -0
  81. phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_span.py +43 -0
  82. phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_trace.py +56 -0
  83. phoenix/server/api/dataloaders/span_cost_details_by_span_cost.py +27 -0
  84. phoenix/server/api/dataloaders/span_cost_summary_by_experiment.py +57 -0
  85. phoenix/server/api/dataloaders/span_cost_summary_by_experiment_repeated_run_group.py +64 -0
  86. phoenix/server/api/dataloaders/span_cost_summary_by_experiment_run.py +58 -0
  87. phoenix/server/api/dataloaders/span_cost_summary_by_generative_model.py +55 -0
  88. phoenix/server/api/dataloaders/span_cost_summary_by_project.py +152 -0
  89. phoenix/server/api/dataloaders/span_cost_summary_by_project_session.py +56 -0
  90. phoenix/server/api/dataloaders/span_cost_summary_by_trace.py +55 -0
  91. phoenix/server/api/dataloaders/span_costs.py +29 -0
  92. phoenix/server/api/dataloaders/table_fields.py +2 -2
  93. phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
  94. phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
  95. phoenix/server/api/dataloaders/types.py +29 -0
  96. phoenix/server/api/exceptions.py +11 -1
  97. phoenix/server/api/helpers/dataset_helpers.py +5 -1
  98. phoenix/server/api/helpers/playground_clients.py +1243 -292
  99. phoenix/server/api/helpers/playground_registry.py +2 -2
  100. phoenix/server/api/helpers/playground_spans.py +8 -4
  101. phoenix/server/api/helpers/playground_users.py +26 -0
  102. phoenix/server/api/helpers/prompts/conversions/aws.py +83 -0
  103. phoenix/server/api/helpers/prompts/conversions/google.py +103 -0
  104. phoenix/server/api/helpers/prompts/models.py +205 -22
  105. phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
  106. phoenix/server/api/input_types/ChatCompletionInput.py +6 -2
  107. phoenix/server/api/input_types/CreateProjectInput.py +27 -0
  108. phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
  109. phoenix/server/api/input_types/DatasetFilter.py +17 -0
  110. phoenix/server/api/input_types/ExperimentRunSort.py +237 -0
  111. phoenix/server/api/input_types/GenerativeCredentialInput.py +9 -0
  112. phoenix/server/api/input_types/GenerativeModelInput.py +5 -0
  113. phoenix/server/api/input_types/ProjectSessionSort.py +161 -1
  114. phoenix/server/api/input_types/PromptFilter.py +14 -0
  115. phoenix/server/api/input_types/PromptVersionInput.py +52 -1
  116. phoenix/server/api/input_types/SpanSort.py +44 -7
  117. phoenix/server/api/input_types/TimeBinConfig.py +23 -0
  118. phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
  119. phoenix/server/api/input_types/UserRoleInput.py +1 -0
  120. phoenix/server/api/mutations/__init__.py +10 -0
  121. phoenix/server/api/mutations/annotation_config_mutations.py +8 -8
  122. phoenix/server/api/mutations/api_key_mutations.py +19 -23
  123. phoenix/server/api/mutations/chat_mutations.py +154 -47
  124. phoenix/server/api/mutations/dataset_label_mutations.py +243 -0
  125. phoenix/server/api/mutations/dataset_mutations.py +21 -16
  126. phoenix/server/api/mutations/dataset_split_mutations.py +351 -0
  127. phoenix/server/api/mutations/experiment_mutations.py +2 -2
  128. phoenix/server/api/mutations/export_events_mutations.py +3 -3
  129. phoenix/server/api/mutations/model_mutations.py +210 -0
  130. phoenix/server/api/mutations/project_mutations.py +49 -10
  131. phoenix/server/api/mutations/project_session_annotations_mutations.py +158 -0
  132. phoenix/server/api/mutations/project_trace_retention_policy_mutations.py +8 -4
  133. phoenix/server/api/mutations/prompt_label_mutations.py +74 -65
  134. phoenix/server/api/mutations/prompt_mutations.py +65 -129
  135. phoenix/server/api/mutations/prompt_version_tag_mutations.py +11 -8
  136. phoenix/server/api/mutations/span_annotations_mutations.py +15 -10
  137. phoenix/server/api/mutations/trace_annotations_mutations.py +14 -10
  138. phoenix/server/api/mutations/trace_mutations.py +47 -3
  139. phoenix/server/api/mutations/user_mutations.py +66 -41
  140. phoenix/server/api/queries.py +768 -293
  141. phoenix/server/api/routers/__init__.py +2 -2
  142. phoenix/server/api/routers/auth.py +154 -88
  143. phoenix/server/api/routers/ldap.py +229 -0
  144. phoenix/server/api/routers/oauth2.py +369 -106
  145. phoenix/server/api/routers/v1/__init__.py +24 -4
  146. phoenix/server/api/routers/v1/annotation_configs.py +23 -31
  147. phoenix/server/api/routers/v1/annotations.py +481 -17
  148. phoenix/server/api/routers/v1/datasets.py +395 -81
  149. phoenix/server/api/routers/v1/documents.py +142 -0
  150. phoenix/server/api/routers/v1/evaluations.py +24 -31
  151. phoenix/server/api/routers/v1/experiment_evaluations.py +19 -8
  152. phoenix/server/api/routers/v1/experiment_runs.py +337 -59
  153. phoenix/server/api/routers/v1/experiments.py +479 -48
  154. phoenix/server/api/routers/v1/models.py +7 -0
  155. phoenix/server/api/routers/v1/projects.py +18 -49
  156. phoenix/server/api/routers/v1/prompts.py +54 -40
  157. phoenix/server/api/routers/v1/sessions.py +108 -0
  158. phoenix/server/api/routers/v1/spans.py +1091 -81
  159. phoenix/server/api/routers/v1/traces.py +132 -78
  160. phoenix/server/api/routers/v1/users.py +389 -0
  161. phoenix/server/api/routers/v1/utils.py +3 -7
  162. phoenix/server/api/subscriptions.py +305 -88
  163. phoenix/server/api/types/Annotation.py +90 -23
  164. phoenix/server/api/types/ApiKey.py +13 -17
  165. phoenix/server/api/types/AuthMethod.py +1 -0
  166. phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +1 -0
  167. phoenix/server/api/types/CostBreakdown.py +12 -0
  168. phoenix/server/api/types/Dataset.py +226 -72
  169. phoenix/server/api/types/DatasetExample.py +88 -18
  170. phoenix/server/api/types/DatasetExperimentAnnotationSummary.py +10 -0
  171. phoenix/server/api/types/DatasetLabel.py +57 -0
  172. phoenix/server/api/types/DatasetSplit.py +98 -0
  173. phoenix/server/api/types/DatasetVersion.py +49 -4
  174. phoenix/server/api/types/DocumentAnnotation.py +212 -0
  175. phoenix/server/api/types/Experiment.py +264 -59
  176. phoenix/server/api/types/ExperimentComparison.py +5 -10
  177. phoenix/server/api/types/ExperimentRepeatedRunGroup.py +155 -0
  178. phoenix/server/api/types/ExperimentRepeatedRunGroupAnnotationSummary.py +9 -0
  179. phoenix/server/api/types/ExperimentRun.py +169 -65
  180. phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
  181. phoenix/server/api/types/GenerativeModel.py +245 -3
  182. phoenix/server/api/types/GenerativeProvider.py +70 -11
  183. phoenix/server/api/types/{Model.py → InferenceModel.py} +1 -1
  184. phoenix/server/api/types/ModelInterface.py +16 -0
  185. phoenix/server/api/types/PlaygroundModel.py +20 -0
  186. phoenix/server/api/types/Project.py +1278 -216
  187. phoenix/server/api/types/ProjectSession.py +188 -28
  188. phoenix/server/api/types/ProjectSessionAnnotation.py +187 -0
  189. phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
  190. phoenix/server/api/types/Prompt.py +119 -39
  191. phoenix/server/api/types/PromptLabel.py +42 -25
  192. phoenix/server/api/types/PromptVersion.py +11 -8
  193. phoenix/server/api/types/PromptVersionTag.py +65 -25
  194. phoenix/server/api/types/ServerStatus.py +6 -0
  195. phoenix/server/api/types/Span.py +167 -123
  196. phoenix/server/api/types/SpanAnnotation.py +189 -42
  197. phoenix/server/api/types/SpanCostDetailSummaryEntry.py +10 -0
  198. phoenix/server/api/types/SpanCostSummary.py +10 -0
  199. phoenix/server/api/types/SystemApiKey.py +65 -1
  200. phoenix/server/api/types/TokenPrice.py +16 -0
  201. phoenix/server/api/types/TokenUsage.py +3 -3
  202. phoenix/server/api/types/Trace.py +223 -51
  203. phoenix/server/api/types/TraceAnnotation.py +149 -50
  204. phoenix/server/api/types/User.py +137 -32
  205. phoenix/server/api/types/UserApiKey.py +73 -26
  206. phoenix/server/api/types/node.py +10 -0
  207. phoenix/server/api/types/pagination.py +11 -2
  208. phoenix/server/app.py +290 -45
  209. phoenix/server/authorization.py +38 -3
  210. phoenix/server/bearer_auth.py +34 -24
  211. phoenix/server/cost_tracking/cost_details_calculator.py +196 -0
  212. phoenix/server/cost_tracking/cost_model_lookup.py +179 -0
  213. phoenix/server/cost_tracking/helpers.py +68 -0
  214. phoenix/server/cost_tracking/model_cost_manifest.json +3657 -830
  215. phoenix/server/cost_tracking/regex_specificity.py +397 -0
  216. phoenix/server/cost_tracking/token_cost_calculator.py +57 -0
  217. phoenix/server/daemons/__init__.py +0 -0
  218. phoenix/server/daemons/db_disk_usage_monitor.py +214 -0
  219. phoenix/server/daemons/generative_model_store.py +103 -0
  220. phoenix/server/daemons/span_cost_calculator.py +99 -0
  221. phoenix/server/dml_event.py +17 -0
  222. phoenix/server/dml_event_handler.py +5 -0
  223. phoenix/server/email/sender.py +56 -3
  224. phoenix/server/email/templates/db_disk_usage_notification.html +19 -0
  225. phoenix/server/email/types.py +11 -0
  226. phoenix/server/experiments/__init__.py +0 -0
  227. phoenix/server/experiments/utils.py +14 -0
  228. phoenix/server/grpc_server.py +11 -11
  229. phoenix/server/jwt_store.py +17 -15
  230. phoenix/server/ldap.py +1449 -0
  231. phoenix/server/main.py +26 -10
  232. phoenix/server/oauth2.py +330 -12
  233. phoenix/server/prometheus.py +66 -6
  234. phoenix/server/rate_limiters.py +4 -9
  235. phoenix/server/retention.py +33 -20
  236. phoenix/server/session_filters.py +49 -0
  237. phoenix/server/static/.vite/manifest.json +55 -51
  238. phoenix/server/static/assets/components-BreFUQQa.js +6702 -0
  239. phoenix/server/static/assets/{index-E0M82BdE.js → index-CTQoemZv.js} +140 -56
  240. phoenix/server/static/assets/pages-DBE5iYM3.js +9524 -0
  241. phoenix/server/static/assets/vendor-BGzfc4EU.css +1 -0
  242. phoenix/server/static/assets/vendor-DCE4v-Ot.js +920 -0
  243. phoenix/server/static/assets/vendor-codemirror-D5f205eT.js +25 -0
  244. phoenix/server/static/assets/vendor-recharts-V9cwpXsm.js +37 -0
  245. phoenix/server/static/assets/vendor-shiki-Do--csgv.js +5 -0
  246. phoenix/server/static/assets/vendor-three-CmB8bl_y.js +3840 -0
  247. phoenix/server/templates/index.html +40 -6
  248. phoenix/server/thread_server.py +1 -2
  249. phoenix/server/types.py +14 -4
  250. phoenix/server/utils.py +74 -0
  251. phoenix/session/client.py +56 -3
  252. phoenix/session/data_extractor.py +5 -0
  253. phoenix/session/evaluation.py +14 -5
  254. phoenix/session/session.py +45 -9
  255. phoenix/settings.py +5 -0
  256. phoenix/trace/attributes.py +80 -13
  257. phoenix/trace/dsl/helpers.py +90 -1
  258. phoenix/trace/dsl/query.py +8 -6
  259. phoenix/trace/projects.py +5 -0
  260. phoenix/utilities/template_formatters.py +1 -1
  261. phoenix/version.py +1 -1
  262. arize_phoenix-10.0.4.dist-info/RECORD +0 -405
  263. phoenix/server/api/types/Evaluation.py +0 -39
  264. phoenix/server/cost_tracking/cost_lookup.py +0 -255
  265. phoenix/server/static/assets/components-DULKeDfL.js +0 -4365
  266. phoenix/server/static/assets/pages-Cl0A-0U2.js +0 -7430
  267. phoenix/server/static/assets/vendor-WIZid84E.css +0 -1
  268. phoenix/server/static/assets/vendor-arizeai-Dy-0mSNw.js +0 -649
  269. phoenix/server/static/assets/vendor-codemirror-DBtifKNr.js +0 -33
  270. phoenix/server/static/assets/vendor-oB4u9zuV.js +0 -905
  271. phoenix/server/static/assets/vendor-recharts-D-T4KPz2.js +0 -59
  272. phoenix/server/static/assets/vendor-shiki-BMn4O_9F.js +0 -5
  273. phoenix/server/static/assets/vendor-three-C5WAXd5r.js +0 -2998
  274. phoenix/utilities/deprecation.py +0 -31
  275. {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/entry_points.txt +0 -0
  276. {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/LICENSE +0 -0
@@ -15,21 +15,16 @@ from typing import Any, Optional, Union, cast
 
 import pandas as pd
 import pyarrow as pa
-from fastapi import APIRouter, BackgroundTasks, HTTPException, Path, Query
+from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Path, Query
 from fastapi.responses import PlainTextResponse, StreamingResponse
-from sqlalchemy import and_, delete, func, select
+from sqlalchemy import and_, case, delete, func, select
 from sqlalchemy.ext.asyncio import AsyncSession
 from starlette.concurrency import run_in_threadpool
 from starlette.datastructures import FormData, UploadFile
 from starlette.requests import Request
 from starlette.responses import Response
 from starlette.status import (
-    HTTP_200_OK,
-    HTTP_204_NO_CONTENT,
     HTTP_404_NOT_FOUND,
-    HTTP_409_CONFLICT,
-    HTTP_422_UNPROCESSABLE_ENTITY,
-    HTTP_429_TOO_MANY_REQUESTS,
 )
 from strawberry.relay import GlobalID
 from typing_extensions import TypeAlias, assert_never
@@ -42,11 +37,15 @@ from phoenix.db.insertion.dataset import (
     ExampleContent,
     add_dataset_examples,
 )
+from phoenix.db.types.db_models import UNDEFINED
 from phoenix.server.api.types.Dataset import Dataset as DatasetNodeType
 from phoenix.server.api.types.DatasetExample import DatasetExample as DatasetExampleNodeType
+from phoenix.server.api.types.DatasetSplit import DatasetSplit as DatasetSplitNodeType
 from phoenix.server.api.types.DatasetVersion import DatasetVersion as DatasetVersionNodeType
 from phoenix.server.api.types.node import from_global_id_with_expected_type
 from phoenix.server.api.utils import delete_projects, delete_traces
+from phoenix.server.authorization import is_not_locked
+from phoenix.server.bearer_auth import PhoenixUser
 from phoenix.server.dml_event import DatasetInsertEvent
 
 from .models import V1RoutesBaseModel
@@ -57,6 +56,11 @@ from .utils import (
     add_text_csv_content_to_responses,
 )
 
+csv.field_size_limit(
+    1_000_000_000  # allows large field sizes for CSV upload (1GB)
+)
+
+
 logger = logging.getLogger(__name__)
 
 DATASET_NODE_NAME = DatasetNodeType.__name__
@@ -73,6 +77,7 @@ class Dataset(V1RoutesBaseModel):
     metadata: dict[str, Any]
     created_at: datetime
     updated_at: datetime
+    example_count: int
 
 
 class ListDatasetsResponseBody(PaginatedResponseBody[Dataset]):
@@ -83,7 +88,7 @@ class ListDatasetsResponseBody(PaginatedResponseBody[Dataset]):
     "/datasets",
     operation_id="listDatasets",
     summary="List datasets",
-    responses=add_errors_to_responses([HTTP_422_UNPROCESSABLE_ENTITY]),
+    responses=add_errors_to_responses([422]),
 )
 async def list_datasets(
     request: Request,
@@ -97,7 +102,18 @@ async def list_datasets(
     ),
 ) -> ListDatasetsResponseBody:
     async with request.app.state.db() as session:
-        query = select(models.Dataset).order_by(models.Dataset.id.desc())
+        value = case(
+            (models.DatasetExampleRevision.revision_kind == "CREATE", 1),
+            (models.DatasetExampleRevision.revision_kind == "DELETE", -1),
+        )
+        query = (
+            select(models.Dataset)
+            .add_columns(func.coalesce(func.sum(value), 0).label("example_count"))
+            .outerjoin_from(models.Dataset, models.DatasetExample)
+            .outerjoin_from(models.DatasetExample, models.DatasetExampleRevision)
+            .group_by(models.Dataset.id)
+            .order_by(models.Dataset.id.desc())
+        )
 
         if cursor:
             try:
@@ -106,25 +122,26 @@ async def list_datasets(
             except ValueError:
                 raise HTTPException(
                     detail=f"Invalid cursor format: {cursor}",
-                    status_code=HTTP_422_UNPROCESSABLE_ENTITY,
+                    status_code=422,
                 )
         if name:
             query = query.filter(models.Dataset.name == name)
 
         query = query.limit(limit + 1)
         result = await session.execute(query)
-        datasets = result.scalars().all()
-
+        datasets = result.all()
         if not datasets:
            return ListDatasetsResponseBody(next_cursor=None, data=[])
 
         next_cursor = None
         if len(datasets) == limit + 1:
-            next_cursor = str(GlobalID(DATASET_NODE_NAME, str(datasets[-1].id)))
+            dataset = datasets[-1][0]
+            next_cursor = str(GlobalID(DATASET_NODE_NAME, str(dataset.id)))
             datasets = datasets[:-1]
 
         data = []
-        for dataset in datasets:
+        for row in datasets:
+            dataset = row[0]
             data.append(
                 Dataset(
                     id=str(GlobalID(DATASET_NODE_NAME, str(dataset.id))),
@@ -133,6 +150,7 @@ async def list_datasets(
                     metadata=dataset.metadata_,
                     created_at=dataset.created_at,
                     updated_at=dataset.updated_at,
+                    example_count=row[1],
                 )
             )
 
@@ -143,11 +161,11 @@ async def list_datasets(
     "/datasets/{id}",
     operation_id="deleteDatasetById",
     summary="Delete dataset by ID",
-    status_code=HTTP_204_NO_CONTENT,
+    status_code=204,
     responses=add_errors_to_responses(
         [
-            {"status_code": HTTP_404_NOT_FOUND, "description": "Dataset not found"},
-            {"status_code": HTTP_422_UNPROCESSABLE_ENTITY, "description": "Invalid dataset ID"},
+            {"status_code": 404, "description": "Dataset not found"},
+            {"status_code": 422, "description": "Invalid dataset ID"},
         ]
     ),
 )
@@ -161,11 +179,9 @@ async def delete_dataset(
                 DATASET_NODE_NAME,
             )
         except ValueError:
-            raise HTTPException(
-                detail=f"Invalid Dataset ID: {id}", status_code=HTTP_422_UNPROCESSABLE_ENTITY
-            )
+            raise HTTPException(detail=f"Invalid Dataset ID: {id}", status_code=422)
     else:
-        raise HTTPException(detail="Missing Dataset ID", status_code=HTTP_422_UNPROCESSABLE_ENTITY)
+        raise HTTPException(detail="Missing Dataset ID", status_code=422)
     project_names_stmt = get_project_names_for_datasets(dataset_id)
     eval_trace_ids_stmt = get_eval_trace_ids_for_datasets(dataset_id)
     stmt = (
@@ -175,7 +191,7 @@ async def delete_dataset(
         project_names = await session.scalars(project_names_stmt)
         eval_trace_ids = await session.scalars(eval_trace_ids_stmt)
         if (await session.scalar(stmt)) is None:
-            raise HTTPException(detail="Dataset does not exist", status_code=HTTP_404_NOT_FOUND)
+            raise HTTPException(detail="Dataset does not exist", status_code=404)
     tasks = BackgroundTasks()
     tasks.add_task(delete_projects, request.app.state.db, *project_names)
     tasks.add_task(delete_traces, request.app.state.db, *eval_trace_ids)
@@ -193,17 +209,21 @@ class GetDatasetResponseBody(ResponseBody[DatasetWithExampleCount]):
     "/datasets/{id}",
     operation_id="getDataset",
     summary="Get dataset by ID",
-    responses=add_errors_to_responses([HTTP_404_NOT_FOUND]),
+    responses=add_errors_to_responses([404]),
 )
 async def get_dataset(
     request: Request, id: str = Path(description="The ID of the dataset")
 ) -> GetDatasetResponseBody:
-    dataset_id = GlobalID.from_id(id)
+    try:
+        dataset_id = GlobalID.from_id(id)
+    except Exception as e:
+        raise HTTPException(
+            detail=f"Invalid dataset ID format: {id}",
+            status_code=422,
+        ) from e
 
     if (type_name := dataset_id.type_name) != DATASET_NODE_NAME:
-        raise HTTPException(
-            detail=f"ID {dataset_id} refers to a f{type_name}", status_code=HTTP_404_NOT_FOUND
-        )
+        raise HTTPException(detail=f"ID {dataset_id} refers to a f{type_name}", status_code=404)
     async with request.app.state.db() as session:
         result = await session.execute(
             select(models.Dataset, models.Dataset.example_count).filter(
@@ -214,9 +234,7 @@ async def get_dataset(
         dataset = dataset_query[0] if dataset_query else None
         example_count = dataset_query[1] if dataset_query else 0
         if dataset is None:
-            raise HTTPException(
-                detail=f"Dataset with ID {dataset_id} not found", status_code=HTTP_404_NOT_FOUND
-            )
+            raise HTTPException(detail=f"Dataset with ID {dataset_id} not found", status_code=404)
 
         dataset = DatasetWithExampleCount(
             id=str(dataset_id),
@@ -245,7 +263,7 @@ class ListDatasetVersionsResponseBody(PaginatedResponseBody[DatasetVersion]):
     "/datasets/{id}/versions",
     operation_id="listDatasetVersionsByDatasetId",
     summary="List dataset versions",
-    responses=add_errors_to_responses([HTTP_422_UNPROCESSABLE_ENTITY]),
+    responses=add_errors_to_responses([422]),
 )
 async def list_dataset_versions(
     request: Request,
@@ -267,12 +285,12 @@ async def list_dataset_versions(
         except ValueError:
             raise HTTPException(
                 detail=f"Invalid Dataset ID: {id}",
-                status_code=HTTP_422_UNPROCESSABLE_ENTITY,
+                status_code=422,
             )
     else:
         raise HTTPException(
             detail="Missing Dataset ID",
-            status_code=HTTP_422_UNPROCESSABLE_ENTITY,
+            status_code=422,
         )
     stmt = (
         select(models.DatasetVersion)
@@ -288,7 +306,7 @@ async def list_dataset_versions(
         except ValueError:
             raise HTTPException(
                 detail=f"Invalid cursor: {cursor}",
-                status_code=HTTP_422_UNPROCESSABLE_ENTITY,
+                status_code=422,
             )
         max_dataset_version_id = (
             select(models.DatasetVersion.id)
@@ -312,6 +330,7 @@
 
 class UploadDatasetData(V1RoutesBaseModel):
     dataset_id: str
+    version_id: str
 
 
 class UploadDatasetResponseBody(ResponseBody[UploadDatasetData]):
@@ -320,15 +339,16 @@ class UploadDatasetResponseBody(ResponseBody[UploadDatasetData]):
 
 @router.post(
     "/datasets/upload",
+    dependencies=[Depends(is_not_locked)],
     operation_id="uploadDataset",
-    summary="Upload dataset from JSON, CSV, or PyArrow",
+    summary="Upload dataset from JSON, JSONL, CSV, or PyArrow",
     responses=add_errors_to_responses(
         [
             {
-                "status_code": HTTP_409_CONFLICT,
+                "status_code": 409,
                 "description": "Dataset of the same name already exists",
             },
-            {"status_code": HTTP_422_UNPROCESSABLE_ENTITY, "description": "Invalid request body"},
+            {"status_code": 422, "description": "Invalid request body"},
         ]
     ),
     # FastAPI cannot generate the request body portion of the OpenAPI schema for
@@ -350,6 +370,17 @@ class UploadDatasetResponseBody(ResponseBody[UploadDatasetData]):
                             "inputs": {"type": "array", "items": {"type": "object"}},
                             "outputs": {"type": "array", "items": {"type": "object"}},
                             "metadata": {"type": "array", "items": {"type": "object"}},
+                            "splits": {
+                                "type": "array",
+                                "items": {
+                                    "oneOf": [
+                                        {"type": "string"},
+                                        {"type": "array", "items": {"type": "string"}},
+                                        {"type": "null"},
+                                    ]
+                                },
+                                "description": "Split per example: string, string array, or null",
+                            },
                         },
                     }
                 },
@@ -376,6 +407,12 @@ class UploadDatasetResponseBody(ResponseBody[UploadDatasetData]):
                                 "items": {"type": "string"},
                                 "uniqueItems": True,
                             },
+                            "split_keys[]": {
+                                "type": "array",
+                                "items": {"type": "string"},
+                                "uniqueItems": True,
+                                "description": "Column names for auto-assigning examples to splits",
+                            },
                             "file": {"type": "string", "format": "binary"},
                         },
                     }
@@ -391,7 +428,12 @@ async def upload_dataset(
         description="If true, fulfill request synchronously and return JSON containing dataset_id.",
     ),
 ) -> Optional[UploadDatasetResponseBody]:
-    request_content_type = request.headers["content-type"]
+    request_content_type = request.headers.get("content-type")
+    if not request_content_type:
+        raise HTTPException(
+            detail="Missing content-type header",
+            status_code=400,
+        )
     examples: Union[Examples, Awaitable[Examples]]
     if request_content_type.startswith("application/json"):
         try:
@@ -401,14 +443,14 @@ async def upload_dataset(
         except ValueError as e:
             raise HTTPException(
                 detail=str(e),
-                status_code=HTTP_422_UNPROCESSABLE_ENTITY,
+                status_code=422,
             )
         if action is DatasetAction.CREATE:
             async with request.app.state.db() as session:
                 if await _check_table_exists(session, name):
                     raise HTTPException(
                         detail=f"Dataset with the same name already exists: {name=}",
-                        status_code=HTTP_409_CONFLICT,
+                        status_code=409,
                     )
     elif request_content_type.startswith("multipart/form-data"):
         async with request.form() as form:
@@ -420,19 +462,20 @@ async def upload_dataset(
                     input_keys,
                     output_keys,
                     metadata_keys,
+                    split_keys,
                     file,
                 ) = await _parse_form_data(form)
             except ValueError as e:
                 raise HTTPException(
                     detail=str(e),
-                    status_code=HTTP_422_UNPROCESSABLE_ENTITY,
+                    status_code=422,
                 )
             if action is DatasetAction.CREATE:
                 async with request.app.state.db() as session:
                     if await _check_table_exists(session, name):
                         raise HTTPException(
                             detail=f"Dataset with the same name already exists: {name=}",
-                            status_code=HTTP_409_CONFLICT,
+                            status_code=409,
                         )
             content = await file.read()
             try:
@@ -440,22 +483,32 @@ async def upload_dataset(
                 if file_content_type is FileContentType.CSV:
                     encoding = FileContentEncoding(file.headers.get("content-encoding"))
                     examples = await _process_csv(
-                        content, encoding, input_keys, output_keys, metadata_keys
+                        content, encoding, input_keys, output_keys, metadata_keys, split_keys
                     )
                 elif file_content_type is FileContentType.PYARROW:
-                    examples = await _process_pyarrow(content, input_keys, output_keys, metadata_keys)
+                    examples = await _process_pyarrow(
+                        content, input_keys, output_keys, metadata_keys, split_keys
+                    )
+                elif file_content_type is FileContentType.JSONL:
+                    encoding = FileContentEncoding(file.headers.get("content-encoding"))
+                    examples = await _process_jsonl(
+                        content, encoding, input_keys, output_keys, metadata_keys, split_keys
+                    )
                 else:
                     assert_never(file_content_type)
             except ValueError as e:
                 raise HTTPException(
                     detail=str(e),
-                    status_code=HTTP_422_UNPROCESSABLE_ENTITY,
+                    status_code=422,
                 )
     else:
         raise HTTPException(
             detail="Invalid request Content-Type",
-            status_code=HTTP_422_UNPROCESSABLE_ENTITY,
+            status_code=422,
         )
+    user_id: Optional[int] = None
+    if request.app.state.authentication_enabled and isinstance(request.user, PhoenixUser):
+        user_id = int(request.user.identity)
    operation = cast(
         Callable[[AsyncSession], Awaitable[DatasetExampleAdditionEvent]],
         partial(
@@ -464,27 +517,34 @@ async def upload_dataset(
             add_dataset_examples,
             action=action,
             name=name,
             description=description,
+            user_id=user_id,
         ),
     )
     if sync:
         async with request.app.state.db() as session:
-            dataset_id = (await operation(session)).dataset_id
+            event = await operation(session)
+            dataset_id = event.dataset_id
+            version_id = event.dataset_version_id
         request.state.event_queue.put(DatasetInsertEvent((dataset_id,)))
         return UploadDatasetResponseBody(
-            data=UploadDatasetData(dataset_id=str(GlobalID(Dataset.__name__, str(dataset_id))))
+            data=UploadDatasetData(
+                dataset_id=str(GlobalID(Dataset.__name__, str(dataset_id))),
+                version_id=str(GlobalID(DatasetVersion.__name__, str(version_id))),
+            )
         )
     try:
         request.state.enqueue_operation(operation)
     except QueueFull:
         if isinstance(examples, Coroutine):
             examples.close()
-        raise HTTPException(detail="Too many requests.", status_code=HTTP_429_TOO_MANY_REQUESTS)
+        raise HTTPException(detail="Too many requests.", status_code=429)
     return None
 
 
 class FileContentType(Enum):
     CSV = "text/csv"
     PYARROW = "application/x-pandas-pyarrow"
+    JSONL = "application/jsonl"
 
     @classmethod
@@ -512,6 +572,7 @@ Description: TypeAlias = Optional[str]
 InputKeys: TypeAlias = frozenset[str]
 OutputKeys: TypeAlias = frozenset[str]
 MetadataKeys: TypeAlias = frozenset[str]
+SplitKeys: TypeAlias = frozenset[str]
 DatasetId: TypeAlias = int
 Examples: TypeAlias = Iterator[ExampleContent]
 
@@ -528,18 +589,55 @@ def _process_json(
         raise ValueError("input is required")
     if not isinstance(inputs, list) or not _is_all_dict(inputs):
         raise ValueError("Input should be a list containing only dictionary objects")
-    outputs, metadata = data.get("outputs"), data.get("metadata")
+    outputs, metadata, splits = data.get("outputs"), data.get("metadata"), data.get("splits")
     for k, v in {"outputs": outputs, "metadata": metadata}.items():
         if v and not (isinstance(v, list) and len(v) == len(inputs) and _is_all_dict(v)):
             raise ValueError(
                 f"{k} should be a list of same length as input containing only dictionary objects"
             )
+
+    # Validate splits format if provided
+    if splits is not None:
+        if not isinstance(splits, list):
+            raise ValueError("splits must be a list")
+        if len(splits) != len(inputs):
+            raise ValueError(
+                f"splits must have same length as inputs ({len(splits)} != {len(inputs)})"
+            )
     examples: list[ExampleContent] = []
     for i, obj in enumerate(inputs):
+        # Extract split values, validating they're non-empty strings
+        split_set: set[str] = set()
+        if splits:
+            split_value = splits[i]
+            if split_value is None:
+                # Sparse assignment: None means no splits for this example
+                pass
+            elif isinstance(split_value, str):
+                # Format 1: Single string value
+                if split_value.strip():
+                    split_set.add(split_value.strip())
+            elif isinstance(split_value, list):
+                # Format 2: List of strings (multiple splits)
+                for v in split_value:
+                    if v is None:
+                        continue  # Skip None values in the list
+                    if not isinstance(v, str):
+                        raise ValueError(
+                            f"Split value must be a string or None, got {type(v).__name__}"
+                        )
+                    if v.strip():
+                        split_set.add(v.strip())
+            else:
+                raise ValueError(
+                    f"Split value must be a string, list of strings, or None, "
+                    f"got {type(split_value).__name__}"
+                )
         example = ExampleContent(
             input=obj,
             output=outputs[i] if outputs else {},
             metadata=metadata[i] if metadata else {},
+            splits=frozenset(split_set),
         )
         examples.append(example)
     action = DatasetAction(cast(Optional[str], data.get("action")) or "create")
@@ -552,6 +650,7 @@ async def _process_csv(
     input_keys: InputKeys,
     output_keys: OutputKeys,
     metadata_keys: MetadataKeys,
+    split_keys: SplitKeys,
 ) -> Examples:
     if content_encoding is FileContentEncoding.GZIP:
         content = await run_in_threadpool(gzip.decompress, content)
@@ -566,12 +665,15 @@ async def _process_csv(
         if freq > 1:
             raise ValueError(f"Duplicated column header in CSV file: {header}")
     column_headers = frozenset(reader.fieldnames)
-    _check_keys_exist(column_headers, input_keys, output_keys, metadata_keys)
+    _check_keys_exist(column_headers, input_keys, output_keys, metadata_keys, split_keys)
     return (
         ExampleContent(
             input={k: row.get(k) for k in input_keys},
             output={k: row.get(k) for k in output_keys},
             metadata={k: row.get(k) for k in metadata_keys},
+            splits=frozenset(
+                str(v).strip() for k in split_keys if (v := row.get(k)) and str(v).strip()
+            ),  # Only include non-empty, non-whitespace split values
         )
         for row in iter(reader)
     )
@@ -582,13 +684,14 @@ async def _process_pyarrow(
     input_keys: InputKeys,
     output_keys: OutputKeys,
     metadata_keys: MetadataKeys,
+    split_keys: SplitKeys,
 ) -> Awaitable[Examples]:
     try:
         reader = pa.ipc.open_stream(content)
     except pa.ArrowInvalid as e:
         raise ValueError("File is not valid pyarrow") from e
     column_headers = frozenset(reader.schema.names)
-    _check_keys_exist(column_headers, input_keys, output_keys, metadata_keys)
+    _check_keys_exist(column_headers, input_keys, output_keys, metadata_keys, split_keys)
 
     def get_examples() -> Iterator[ExampleContent]:
         for row in reader.read_pandas().to_dict(orient="records"):
@@ -596,11 +699,48 @@ async def _process_pyarrow(
                 input={k: row.get(k) for k in input_keys},
                 output={k: row.get(k) for k in output_keys},
                 metadata={k: row.get(k) for k in metadata_keys},
+                splits=frozenset(
+                    str(v).strip() for k in split_keys if (v := row.get(k)) and str(v).strip()
+                ),  # Only include non-empty, non-whitespace split values
             )
 
     return run_in_threadpool(get_examples)
 
 
+async def _process_jsonl(
+    content: bytes,
+    encoding: FileContentEncoding,
+    input_keys: InputKeys,
+    output_keys: OutputKeys,
+    metadata_keys: MetadataKeys,
+    split_keys: SplitKeys,
+) -> Examples:
+    if encoding is FileContentEncoding.GZIP:
+        content = await run_in_threadpool(gzip.decompress, content)
+    elif encoding is FileContentEncoding.DEFLATE:
+        content = await run_in_threadpool(zlib.decompress, content)
+    elif encoding is not FileContentEncoding.NONE:
+        assert_never(encoding)
+    # content is a newline delimited list of JSON objects
+    # parse within a threadpool
+    reader = await run_in_threadpool(
+        lambda c: [json.loads(line) for line in c.decode().splitlines()], content
+    )
+
+    examples: list[ExampleContent] = []
+    for obj in reader:
+        example = ExampleContent(
+            input={k: obj.get(k) for k in input_keys},
+            output={k: obj.get(k) for k in output_keys},
+            metadata={k: obj.get(k) for k in metadata_keys},
+            splits=frozenset(
+                str(v).strip() for k in split_keys if (v := obj.get(k)) and str(v).strip()
+            ),  # Only include non-empty, non-whitespace split values
+        )
+        examples.append(example)
+    return iter(examples)
+
+
 async def _check_table_exists(session: AsyncSession, name: str) -> bool:
     return bool(
         await session.scalar(
@@ -614,11 +754,13 @@ def _check_keys_exist(
     input_keys: InputKeys,
     output_keys: OutputKeys,
     metadata_keys: MetadataKeys,
+    split_keys: SplitKeys,
 ) -> None:
     for desc, keys in (
         ("input", input_keys),
         ("output", output_keys),
         ("metadata", metadata_keys),
+        ("split", split_keys),
     ):
         if keys and (diff := keys.difference(column_headers)):
             raise ValueError(f"{desc} keys not found in column headers: {diff}")
@@ -633,6 +775,7 @@ async def _parse_form_data(
     InputKeys,
     OutputKeys,
     MetadataKeys,
+    SplitKeys,
     UploadFile,
 ]:
     name = cast(Optional[str], form.get("name"))
@@ -646,6 +789,7 @@ async def _parse_form_data(
     input_keys = frozenset(filter(bool, cast(list[str], form.getlist("input_keys[]"))))
     output_keys = frozenset(filter(bool, cast(list[str], form.getlist("output_keys[]"))))
     metadata_keys = frozenset(filter(bool, cast(list[str], form.getlist("metadata_keys[]"))))
+    split_keys = frozenset(filter(bool, cast(list[str], form.getlist("split_keys[]"))))
     return (
         action,
         name,
@@ -653,6 +797,7 @@ async def _parse_form_data(
         input_keys,
         output_keys,
         metadata_keys,
+        split_keys,
         file,
     )
 
@@ -668,6 +813,7 @@ class DatasetExample(V1RoutesBaseModel):
 class ListDatasetExamplesData(V1RoutesBaseModel):
     dataset_id: str
     version_id: str
+    filtered_splits: list[str] = UNDEFINED
     examples: list[DatasetExample]
 
 
@@ -679,7 +825,7 @@ class ListDatasetExamplesResponseBody(ResponseBody[ListDatasetExamplesData]):
     "/datasets/{id}/examples",
     operation_id="getDatasetExamples",
     summary="Get examples from a dataset",
-    responses=add_errors_to_responses([HTTP_404_NOT_FOUND]),
+    responses=add_errors_to_responses([404]),
 )
 async def get_dataset_examples(
     request: Request,
@@ -687,22 +833,38 @@ async def get_dataset_examples(
     version_id: Optional[str] = Query(
         default=None,
         description=(
-            "The ID of the dataset version " "(if omitted, returns data from the latest version)"
+            "The ID of the dataset version (if omitted, returns data from the latest version)"
         ),
     ),
+    split: Optional[list[str]] = Query(
+        default=None,
+        description="List of dataset split identifiers (GlobalIDs or names) to filter by",
+    ),
 ) -> ListDatasetExamplesResponseBody:
-    dataset_gid = GlobalID.from_id(id)
-    version_gid = GlobalID.from_id(version_id) if version_id else None
+    try:
+        dataset_gid = GlobalID.from_id(id)
+    except Exception as e:
+        raise HTTPException(
+            detail=f"Invalid dataset ID format: {id}",
+            status_code=422,
+        ) from e
+
+    if version_id:
+        try:
+            version_gid = GlobalID.from_id(version_id)
+        except Exception as e:
+            raise HTTPException(
+                detail=f"Invalid dataset version ID format: {version_id}",
+                status_code=422,
+            ) from e
+    else:
+        version_gid = None
 
     if (dataset_type := dataset_gid.type_name) != "Dataset":
-        raise HTTPException(
-            detail=f"ID {dataset_gid} refers to a {dataset_type}", status_code=HTTP_404_NOT_FOUND
-        )
+        raise HTTPException(detail=f"ID {dataset_gid} refers to a {dataset_type}", status_code=404)
 
     if version_gid and (version_type := version_gid.type_name) != "DatasetVersion":
-        raise HTTPException(
-            detail=f"ID {version_gid} refers to a {version_type}", status_code=HTTP_404_NOT_FOUND
-        )
+        raise HTTPException(detail=f"ID {version_gid} refers to a {version_type}", status_code=404)
 
     async with request.app.state.db() as session:
         if (
@@ -712,7 +874,7 @@ async def get_dataset_examples(
         ) is None:
             raise HTTPException(
                 detail=f"No dataset with id {dataset_gid} can be found.",
-                status_code=HTTP_404_NOT_FOUND,
+                status_code=404,
             )
 
         # Subquery to find the maximum created_at for each dataset_example_id
@@ -734,7 +896,7 @@ async def get_dataset_examples(
             ) is None:
                 raise HTTPException(
                     detail=f"No dataset version with id {version_id} can be found.",
-                    status_code=HTTP_404_NOT_FOUND,
+                    status_code=404,
                 )
             # if a version_id is provided, filter the subquery to only include revisions from that
             partial_subquery = partial_subquery.filter(
@@ -750,13 +912,17 @@ async def get_dataset_examples(
         ) is None:
             raise HTTPException(
                 detail="Dataset has no versions.",
-                status_code=HTTP_404_NOT_FOUND,
+                status_code=404,
             )
 
         subquery = partial_subquery.subquery()
+
         # Query for the most recent example revisions that are not deleted
         query = (
-            select(models.DatasetExample, models.DatasetExampleRevision)
+            select(
+                models.DatasetExample,
+                models.DatasetExampleRevision,
+            )
             .join(
                 models.DatasetExampleRevision,
                 models.DatasetExample.id == models.DatasetExampleRevision.dataset_example_id,
@@ -769,6 +935,28 @@ async def get_dataset_examples(
             .filter(models.DatasetExampleRevision.revision_kind != "DELETE")
             .order_by(models.DatasetExample.id.asc())
         )
+
+        # If splits are provided, filter by dataset splits
+        resolved_split_names: list[str] = []
+        if split:
+            # Resolve split identifiers (IDs or names) to IDs and names
+            resolved_split_ids, resolved_split_names = await _resolve_split_identifiers(
+                session, split
+            )
+
+            # Add filter for splits (join with the association table)
+            # Use distinct() to prevent duplicates when an example belongs to
+            # multiple splits
+            query = (
+                query.join(
+                    models.DatasetSplitDatasetExample,
+                    models.DatasetExample.id
+                    == models.DatasetSplitDatasetExample.dataset_example_id,
+                )
+                .filter(models.DatasetSplitDatasetExample.dataset_split_id.in_(resolved_split_ids))
+                .distinct()
+            )
+
         examples = [
             DatasetExample(
                 id=str(GlobalID("DatasetExample", str(example.id))),
@@ -783,6 +971,7 @@ async def get_dataset_examples(
         data=ListDatasetExamplesData(
             dataset_id=str(GlobalID("Dataset", str(resolved_dataset_id))),
             version_id=str(GlobalID("DatasetVersion", str(resolved_version_id))),
+            filtered_splits=resolved_split_names,
             examples=examples,
         )
     )
@@ -793,10 +982,10 @@ async def get_dataset_examples(
     operation_id="getDatasetCsv",
     summary="Download dataset examples as CSV file",
     response_class=StreamingResponse,
-    status_code=HTTP_200_OK,
+    status_code=200,
     responses={
-        **add_errors_to_responses([HTTP_422_UNPROCESSABLE_ENTITY]),
-        **add_text_csv_content_to_responses(HTTP_200_OK),
+        **add_errors_to_responses([422]),
+        **add_text_csv_content_to_responses(200),
     },
 )
 async def get_dataset_csv(
@@ -806,7 +995,7 @@ async def get_dataset_csv(
     version_id: Optional[str] = Query(
         default=None,
         description=(
-            "The ID of the dataset version " "(if omitted, returns data from the latest version)"
+            "The ID of the dataset version (if omitted, returns data from the latest version)"
        ),
     ),
 ) -> Response:
@@ -816,7 +1005,7 @@ async def get_dataset_csv(
                 session=session, id=id, version_id=version_id
             )
         except ValueError as e:
-            raise HTTPException(detail=str(e), status_code=HTTP_422_UNPROCESSABLE_ENTITY)
+            raise HTTPException(detail=str(e), status_code=422)
     content = await run_in_threadpool(_get_content_csv, examples)
     encoded_dataset_name = urllib.parse.quote(dataset_name)
     return Response(
@@ -836,7 +1025,7 @@ async def get_dataset_csv(
     responses=add_errors_to_responses(
         [
             {
-                "status_code": HTTP_422_UNPROCESSABLE_ENTITY,
+                "status_code": 422,
                 "description": "Invalid dataset or version ID",
             }
         ]
@@ -849,7 +1038,7 @@ async def get_dataset_jsonl_openai_ft(
     version_id: Optional[str] = Query(
         default=None,
         description=(
-            "The ID of the dataset version " "(if omitted, returns data from the latest version)"
+            "The ID of the dataset version (if omitted, returns data from the latest version)"
         ),
     ),
 ) -> bytes:
@@ -859,7 +1048,7 @@ async def get_dataset_jsonl_openai_ft(
                 session=session, id=id, version_id=version_id
             )
         except ValueError as e:
-            raise HTTPException(detail=str(e), status_code=HTTP_422_UNPROCESSABLE_ENTITY)
+            raise HTTPException(detail=str(e), status_code=422)
     content = await run_in_threadpool(_get_content_jsonl_openai_ft, examples)
     encoded_dataset_name = urllib.parse.quote(dataset_name)
     response.headers["content-disposition"] = (
@@ -876,7 +1065,7 @@ async def get_dataset_jsonl_openai_ft(
     responses=add_errors_to_responses(
         [
             {
-                "status_code": HTTP_422_UNPROCESSABLE_ENTITY,
+                "status_code": 422,
                 "description": "Invalid dataset or version ID",
             }
         ]
@@ -889,7 +1078,7 @@ async def get_dataset_jsonl_openai_evals(
     version_id: Optional[str] = Query(
         default=None,
         description=(
-            "The ID of the dataset version " "(if omitted, returns data from the latest version)"
+            "The ID of the dataset version (if omitted, returns data from the latest version)"
         ),
     ),
 ) -> bytes:
@@ -899,7 +1088,7 @@ async def get_dataset_jsonl_openai_evals(
                 session=session, id=id, version_id=version_id
            )
         except ValueError as e:
-            raise HTTPException(detail=str(e), status_code=HTTP_422_UNPROCESSABLE_ENTITY)
+            raise HTTPException(detail=str(e), status_code=422)
     content = await run_in_threadpool(_get_content_jsonl_openai_evals, examples)
     encoded_dataset_name = urllib.parse.quote(dataset_name)
     response.headers["content-disposition"] = (
@@ -978,12 +1167,25 @@ def _get_content_jsonl_openai_evals(examples: list[models.DatasetExampleRevision
 async def _get_db_examples(
     *, session: Any, id: str, version_id: Optional[str]
 ) -> tuple[str, list[models.DatasetExampleRevision]]:
-    dataset_id = from_global_id_with_expected_type(GlobalID.from_id(id), DATASET_NODE_NAME)
+    try:
+        dataset_id = from_global_id_with_expected_type(GlobalID.from_id(id), DATASET_NODE_NAME)
+    except Exception as e:
+        raise HTTPException(
+            detail=f"Invalid dataset ID format: {id}",
+            status_code=422,
+        ) from e
+
     dataset_version_id: Optional[int] = None
     if version_id:
-        dataset_version_id = from_global_id_with_expected_type(
-            GlobalID.from_id(version_id), DATASET_VERSION_NODE_NAME
-        )
+        try:
+            dataset_version_id = from_global_id_with_expected_type(
+                GlobalID.from_id(version_id), DATASET_VERSION_NODE_NAME
+            )
+        except Exception as e:
+            raise HTTPException(
+                detail=f"Invalid dataset version ID format: {version_id}",
+                status_code=422,
+            ) from e
     latest_version = (
         select(
             models.DatasetExampleRevision.dataset_example_id,
@@ -1026,3 +1228,115 @@
 
 def _is_all_dict(seq: Sequence[Any]) -> bool:
     return all(map(lambda obj: isinstance(obj, dict), seq))
+
+
+# Split identifier helper types and functions
+class _SplitId(int): ...
+
+
+_SplitIdentifier: TypeAlias = Union[_SplitId, str]
+
+
+def _parse_split_identifier(split_identifier: str) -> _SplitIdentifier:
+    """
+    Parse a split identifier as either a GlobalID or a name.
+
+    Args:
+        split_identifier: The identifier string (GlobalID or name)
+
+    Returns:
+        Either a _SplitId or an Identifier
+
+    Raises:
+        HTTPException: If the identifier format is invalid
+    """
+    if not split_identifier:
+        raise HTTPException(422, "Invalid split identifier")
+    try:
+        split_id = from_global_id_with_expected_type(
+            GlobalID.from_id(split_identifier),
+            DatasetSplitNodeType.__name__,
+        )
+    except ValueError:
+        return split_identifier
+    return _SplitId(split_id)
+
+
+async def _resolve_split_identifiers(
+    session: AsyncSession,
+    split_identifiers: list[str],
+) -> tuple[list[int], list[str]]:
+    """
+    Resolve a list of split identifiers (IDs or names) to split IDs and names.
+
+    Args:
+        session: The database session
+        split_identifiers: List of split identifiers (GlobalIDs or names)
+
+    Returns:
+        Tuple of (list of split IDs, list of split names)
+
+    Raises:
+        HTTPException: If any split identifier is invalid or not found
+    """
+    split_ids: list[int] = []
+    split_names: list[str] = []
+
+    # Parse all identifiers first
+    parsed_identifiers: list[_SplitIdentifier] = []
+    for identifier_str in split_identifiers:
+        parsed_identifiers.append(_parse_split_identifier(identifier_str.strip()))
+
+    # Separate IDs and names
+    requested_ids: list[int] = []
+    requested_names: list[str] = []
+    for identifier in parsed_identifiers:
+        if isinstance(identifier, _SplitId):
+            requested_ids.append(int(identifier))
+        elif isinstance(identifier, str):
+            requested_names.append(identifier)
+        else:
+            assert_never(identifier)
+
+    # Query for splits by ID
+    if requested_ids:
+        id_results = await session.stream(
+            select(models.DatasetSplit.id, models.DatasetSplit.name).where(
+                models.DatasetSplit.id.in_(requested_ids)
+            )
+        )
+        async for split_id, split_name in id_results:
+            split_ids.append(split_id)
+            split_names.append(split_name)
+
+        # Check if all requested IDs were found
+        found_ids = set(split_ids[-len(requested_ids) :] if requested_ids else [])
+        missing_ids = [sid for sid in requested_ids if sid not in found_ids]
+        if missing_ids:
+            raise HTTPException(
+                status_code=HTTP_404_NOT_FOUND,
+                detail=f"Dataset splits not found for IDs: {', '.join(map(str, missing_ids))}",
+            )
+
+    # Query for splits by name
+    if requested_names:
+        name_results = await session.stream(
+            select(models.DatasetSplit.id, models.DatasetSplit.name).where(
+                models.DatasetSplit.name.in_(requested_names)
+            )
+        )
+        name_to_id: dict[str, int] = {}
+        async for split_id, split_name in name_results:
+            split_ids.append(split_id)
+            split_names.append(split_name)
+            name_to_id[split_name] = split_id
+
+        # Check if all requested names were found
+        missing_names = [name for name in requested_names if name not in name_to_id]
+        if missing_names:
+            raise HTTPException(
+                status_code=HTTP_404_NOT_FOUND,
+                detail=f"Dataset splits not found: {', '.join(missing_names)}",
+            )
+
+    return split_ids, split_names
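Editor's note: the hunks above add dataset split filtering (a repeatable "split" query parameter and a "filtered_splits" response field on the getDatasetExamples route) alongside JSONL upload support. The following is a rough usage sketch only, not part of this diff; it assumes a Phoenix server listening on localhost:6006 with these v1 routes mounted under /v1, an existing dataset GlobalID, and a split named "train", all of which are illustrative assumptions.

import httpx

# Illustrative sketch: call the new "split" query parameter shown in the diff above.
# Replace the host, port, and dataset GlobalID with values from your own deployment,
# and add authentication headers if your server requires them.
dataset_id = "RGF0YXNldDox"  # hypothetical GlobalID for a Dataset node
response = httpx.get(
    f"http://localhost:6006/v1/datasets/{dataset_id}/examples",
    params=[("split", "train")],  # repeat the key to filter by several splits at once
)
response.raise_for_status()
data = response.json()["data"]
print(data["filtered_splits"])  # split names resolved by the server
print(len(data["examples"]))  # examples restricted to the requested split(s)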