arize-phoenix 11.23.1__py3-none-any.whl → 12.28.1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/METADATA +61 -36
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/RECORD +212 -162
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/WHEEL +1 -1
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/IP_NOTICE +1 -1
- phoenix/__generated__/__init__.py +0 -0
- phoenix/__generated__/classification_evaluator_configs/__init__.py +20 -0
- phoenix/__generated__/classification_evaluator_configs/_document_relevance_classification_evaluator_config.py +17 -0
- phoenix/__generated__/classification_evaluator_configs/_hallucination_classification_evaluator_config.py +17 -0
- phoenix/__generated__/classification_evaluator_configs/_models.py +18 -0
- phoenix/__generated__/classification_evaluator_configs/_tool_selection_classification_evaluator_config.py +17 -0
- phoenix/__init__.py +2 -1
- phoenix/auth.py +27 -2
- phoenix/config.py +1594 -81
- phoenix/db/README.md +546 -28
- phoenix/db/bulk_inserter.py +119 -116
- phoenix/db/engines.py +140 -33
- phoenix/db/facilitator.py +22 -1
- phoenix/db/helpers.py +818 -65
- phoenix/db/iam_auth.py +64 -0
- phoenix/db/insertion/dataset.py +133 -1
- phoenix/db/insertion/document_annotation.py +9 -6
- phoenix/db/insertion/evaluation.py +2 -3
- phoenix/db/insertion/helpers.py +2 -2
- phoenix/db/insertion/session_annotation.py +176 -0
- phoenix/db/insertion/span_annotation.py +3 -4
- phoenix/db/insertion/trace_annotation.py +3 -4
- phoenix/db/insertion/types.py +41 -18
- phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
- phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
- phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
- phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
- phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
- phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
- phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
- phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
- phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
- phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
- phoenix/db/models.py +364 -56
- phoenix/db/pg_config.py +10 -0
- phoenix/db/types/trace_retention.py +7 -6
- phoenix/experiments/functions.py +69 -19
- phoenix/inferences/inferences.py +1 -2
- phoenix/server/api/auth.py +9 -0
- phoenix/server/api/auth_messages.py +46 -0
- phoenix/server/api/context.py +60 -0
- phoenix/server/api/dataloaders/__init__.py +36 -0
- phoenix/server/api/dataloaders/annotation_summaries.py +60 -8
- phoenix/server/api/dataloaders/average_experiment_repeated_run_group_latency.py +50 -0
- phoenix/server/api/dataloaders/average_experiment_run_latency.py +17 -24
- phoenix/server/api/dataloaders/cache/two_tier_cache.py +1 -2
- phoenix/server/api/dataloaders/dataset_dataset_splits.py +52 -0
- phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -1
- phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
- phoenix/server/api/dataloaders/dataset_examples_and_versions_by_experiment_run.py +47 -0
- phoenix/server/api/dataloaders/dataset_labels.py +36 -0
- phoenix/server/api/dataloaders/document_evaluation_summaries.py +2 -2
- phoenix/server/api/dataloaders/document_evaluations.py +6 -9
- phoenix/server/api/dataloaders/experiment_annotation_summaries.py +88 -34
- phoenix/server/api/dataloaders/experiment_dataset_splits.py +43 -0
- phoenix/server/api/dataloaders/experiment_error_rates.py +21 -28
- phoenix/server/api/dataloaders/experiment_repeated_run_group_annotation_summaries.py +77 -0
- phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +57 -0
- phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
- phoenix/server/api/dataloaders/latency_ms_quantile.py +40 -8
- phoenix/server/api/dataloaders/record_counts.py +37 -10
- phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment_repeated_run_group.py +64 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_project.py +28 -14
- phoenix/server/api/dataloaders/span_costs.py +3 -9
- phoenix/server/api/dataloaders/table_fields.py +2 -2
- phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
- phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
- phoenix/server/api/exceptions.py +5 -1
- phoenix/server/api/helpers/playground_clients.py +263 -83
- phoenix/server/api/helpers/playground_spans.py +2 -1
- phoenix/server/api/helpers/playground_users.py +26 -0
- phoenix/server/api/helpers/prompts/conversions/google.py +103 -0
- phoenix/server/api/helpers/prompts/models.py +61 -19
- phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
- phoenix/server/api/input_types/ChatCompletionInput.py +3 -0
- phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
- phoenix/server/api/input_types/DatasetFilter.py +5 -2
- phoenix/server/api/input_types/ExperimentRunSort.py +237 -0
- phoenix/server/api/input_types/GenerativeModelInput.py +3 -0
- phoenix/server/api/input_types/ProjectSessionSort.py +158 -1
- phoenix/server/api/input_types/PromptVersionInput.py +47 -1
- phoenix/server/api/input_types/SpanSort.py +3 -2
- phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
- phoenix/server/api/input_types/UserRoleInput.py +1 -0
- phoenix/server/api/mutations/__init__.py +8 -0
- phoenix/server/api/mutations/annotation_config_mutations.py +8 -8
- phoenix/server/api/mutations/api_key_mutations.py +15 -20
- phoenix/server/api/mutations/chat_mutations.py +106 -37
- phoenix/server/api/mutations/dataset_label_mutations.py +243 -0
- phoenix/server/api/mutations/dataset_mutations.py +21 -16
- phoenix/server/api/mutations/dataset_split_mutations.py +351 -0
- phoenix/server/api/mutations/experiment_mutations.py +2 -2
- phoenix/server/api/mutations/export_events_mutations.py +3 -3
- phoenix/server/api/mutations/model_mutations.py +11 -9
- phoenix/server/api/mutations/project_mutations.py +4 -4
- phoenix/server/api/mutations/project_session_annotations_mutations.py +158 -0
- phoenix/server/api/mutations/project_trace_retention_policy_mutations.py +8 -4
- phoenix/server/api/mutations/prompt_label_mutations.py +74 -65
- phoenix/server/api/mutations/prompt_mutations.py +65 -129
- phoenix/server/api/mutations/prompt_version_tag_mutations.py +11 -8
- phoenix/server/api/mutations/span_annotations_mutations.py +15 -10
- phoenix/server/api/mutations/trace_annotations_mutations.py +13 -8
- phoenix/server/api/mutations/trace_mutations.py +3 -3
- phoenix/server/api/mutations/user_mutations.py +55 -26
- phoenix/server/api/queries.py +501 -617
- phoenix/server/api/routers/__init__.py +2 -2
- phoenix/server/api/routers/auth.py +141 -87
- phoenix/server/api/routers/ldap.py +229 -0
- phoenix/server/api/routers/oauth2.py +349 -101
- phoenix/server/api/routers/v1/__init__.py +22 -4
- phoenix/server/api/routers/v1/annotation_configs.py +19 -30
- phoenix/server/api/routers/v1/annotations.py +455 -13
- phoenix/server/api/routers/v1/datasets.py +355 -68
- phoenix/server/api/routers/v1/documents.py +142 -0
- phoenix/server/api/routers/v1/evaluations.py +20 -28
- phoenix/server/api/routers/v1/experiment_evaluations.py +16 -6
- phoenix/server/api/routers/v1/experiment_runs.py +335 -59
- phoenix/server/api/routers/v1/experiments.py +475 -47
- phoenix/server/api/routers/v1/projects.py +16 -50
- phoenix/server/api/routers/v1/prompts.py +50 -39
- phoenix/server/api/routers/v1/sessions.py +108 -0
- phoenix/server/api/routers/v1/spans.py +156 -96
- phoenix/server/api/routers/v1/traces.py +51 -77
- phoenix/server/api/routers/v1/users.py +64 -24
- phoenix/server/api/routers/v1/utils.py +3 -7
- phoenix/server/api/subscriptions.py +257 -93
- phoenix/server/api/types/Annotation.py +90 -23
- phoenix/server/api/types/ApiKey.py +13 -17
- phoenix/server/api/types/AuthMethod.py +1 -0
- phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +1 -0
- phoenix/server/api/types/Dataset.py +199 -72
- phoenix/server/api/types/DatasetExample.py +88 -18
- phoenix/server/api/types/DatasetExperimentAnnotationSummary.py +10 -0
- phoenix/server/api/types/DatasetLabel.py +57 -0
- phoenix/server/api/types/DatasetSplit.py +98 -0
- phoenix/server/api/types/DatasetVersion.py +49 -4
- phoenix/server/api/types/DocumentAnnotation.py +212 -0
- phoenix/server/api/types/Experiment.py +215 -68
- phoenix/server/api/types/ExperimentComparison.py +3 -9
- phoenix/server/api/types/ExperimentRepeatedRunGroup.py +155 -0
- phoenix/server/api/types/ExperimentRepeatedRunGroupAnnotationSummary.py +9 -0
- phoenix/server/api/types/ExperimentRun.py +120 -70
- phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
- phoenix/server/api/types/GenerativeModel.py +95 -42
- phoenix/server/api/types/GenerativeProvider.py +1 -1
- phoenix/server/api/types/ModelInterface.py +7 -2
- phoenix/server/api/types/PlaygroundModel.py +12 -2
- phoenix/server/api/types/Project.py +218 -185
- phoenix/server/api/types/ProjectSession.py +146 -29
- phoenix/server/api/types/ProjectSessionAnnotation.py +187 -0
- phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
- phoenix/server/api/types/Prompt.py +119 -39
- phoenix/server/api/types/PromptLabel.py +42 -25
- phoenix/server/api/types/PromptVersion.py +11 -8
- phoenix/server/api/types/PromptVersionTag.py +65 -25
- phoenix/server/api/types/Span.py +130 -123
- phoenix/server/api/types/SpanAnnotation.py +189 -42
- phoenix/server/api/types/SystemApiKey.py +65 -1
- phoenix/server/api/types/Trace.py +184 -53
- phoenix/server/api/types/TraceAnnotation.py +149 -50
- phoenix/server/api/types/User.py +128 -33
- phoenix/server/api/types/UserApiKey.py +73 -26
- phoenix/server/api/types/node.py +10 -0
- phoenix/server/api/types/pagination.py +11 -2
- phoenix/server/app.py +154 -36
- phoenix/server/authorization.py +5 -4
- phoenix/server/bearer_auth.py +13 -5
- phoenix/server/cost_tracking/cost_model_lookup.py +42 -14
- phoenix/server/cost_tracking/model_cost_manifest.json +1085 -194
- phoenix/server/daemons/generative_model_store.py +61 -9
- phoenix/server/daemons/span_cost_calculator.py +10 -8
- phoenix/server/dml_event.py +13 -0
- phoenix/server/email/sender.py +29 -2
- phoenix/server/grpc_server.py +9 -9
- phoenix/server/jwt_store.py +8 -6
- phoenix/server/ldap.py +1449 -0
- phoenix/server/main.py +9 -3
- phoenix/server/oauth2.py +330 -12
- phoenix/server/prometheus.py +43 -6
- phoenix/server/rate_limiters.py +4 -9
- phoenix/server/retention.py +33 -20
- phoenix/server/session_filters.py +49 -0
- phoenix/server/static/.vite/manifest.json +51 -53
- phoenix/server/static/assets/components-BreFUQQa.js +6702 -0
- phoenix/server/static/assets/{index-BPCwGQr8.js → index-CTQoemZv.js} +42 -35
- phoenix/server/static/assets/pages-DBE5iYM3.js +9524 -0
- phoenix/server/static/assets/vendor-BGzfc4EU.css +1 -0
- phoenix/server/static/assets/vendor-DCE4v-Ot.js +920 -0
- phoenix/server/static/assets/vendor-codemirror-D5f205eT.js +25 -0
- phoenix/server/static/assets/{vendor-recharts-Bw30oz1A.js → vendor-recharts-V9cwpXsm.js} +7 -7
- phoenix/server/static/assets/{vendor-shiki-DZajAPeq.js → vendor-shiki-Do--csgv.js} +1 -1
- phoenix/server/static/assets/vendor-three-CmB8bl_y.js +3840 -0
- phoenix/server/templates/index.html +7 -1
- phoenix/server/thread_server.py +1 -2
- phoenix/server/utils.py +74 -0
- phoenix/session/client.py +55 -1
- phoenix/session/data_extractor.py +5 -0
- phoenix/session/evaluation.py +8 -4
- phoenix/session/session.py +44 -8
- phoenix/settings.py +2 -0
- phoenix/trace/attributes.py +80 -13
- phoenix/trace/dsl/query.py +2 -0
- phoenix/trace/projects.py +5 -0
- phoenix/utilities/template_formatters.py +1 -1
- phoenix/version.py +1 -1
- phoenix/server/api/types/Evaluation.py +0 -39
- phoenix/server/static/assets/components-D0DWAf0l.js +0 -5650
- phoenix/server/static/assets/pages-Creyamao.js +0 -8612
- phoenix/server/static/assets/vendor-CU36oj8y.js +0 -905
- phoenix/server/static/assets/vendor-CqDb5u4o.css +0 -1
- phoenix/server/static/assets/vendor-arizeai-Ctgw0e1G.js +0 -168
- phoenix/server/static/assets/vendor-codemirror-Cojjzqb9.js +0 -25
- phoenix/server/static/assets/vendor-three-BLWp5bic.js +0 -2998
- phoenix/utilities/deprecation.py +0 -31
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-11.23.1.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/LICENSE +0 -0
phoenix/server/api/routers/v1/datasets.py:

@@ -24,12 +24,7 @@ from starlette.datastructures import FormData, UploadFile
 from starlette.requests import Request
 from starlette.responses import Response
 from starlette.status import (
-    HTTP_200_OK,
-    HTTP_204_NO_CONTENT,
     HTTP_404_NOT_FOUND,
-    HTTP_409_CONFLICT,
-    HTTP_422_UNPROCESSABLE_ENTITY,
-    HTTP_429_TOO_MANY_REQUESTS,
 )
 from strawberry.relay import GlobalID
 from typing_extensions import TypeAlias, assert_never

@@ -42,12 +37,15 @@ from phoenix.db.insertion.dataset import (
     ExampleContent,
     add_dataset_examples,
 )
+from phoenix.db.types.db_models import UNDEFINED
 from phoenix.server.api.types.Dataset import Dataset as DatasetNodeType
 from phoenix.server.api.types.DatasetExample import DatasetExample as DatasetExampleNodeType
+from phoenix.server.api.types.DatasetSplit import DatasetSplit as DatasetSplitNodeType
 from phoenix.server.api.types.DatasetVersion import DatasetVersion as DatasetVersionNodeType
 from phoenix.server.api.types.node import from_global_id_with_expected_type
 from phoenix.server.api.utils import delete_projects, delete_traces
 from phoenix.server.authorization import is_not_locked
+from phoenix.server.bearer_auth import PhoenixUser
 from phoenix.server.dml_event import DatasetInsertEvent

 from .models import V1RoutesBaseModel

@@ -90,7 +88,7 @@ class ListDatasetsResponseBody(PaginatedResponseBody[Dataset]):
     "/datasets",
     operation_id="listDatasets",
     summary="List datasets",
-    responses=add_errors_to_responses([HTTP_422_UNPROCESSABLE_ENTITY]),
+    responses=add_errors_to_responses([422]),
 )
 async def list_datasets(
     request: Request,
@@ -124,7 +122,7 @@ async def list_datasets(
         except ValueError:
             raise HTTPException(
                 detail=f"Invalid cursor format: {cursor}",
-                status_code=HTTP_422_UNPROCESSABLE_ENTITY,
+                status_code=422,
             )
     if name:
         query = query.filter(models.Dataset.name == name)

@@ -163,11 +161,11 @@ async def list_datasets(
     "/datasets/{id}",
     operation_id="deleteDatasetById",
     summary="Delete dataset by ID",
-    status_code=HTTP_204_NO_CONTENT,
+    status_code=204,
     responses=add_errors_to_responses(
         [
-            {"status_code": HTTP_404_NOT_FOUND, "description": "Dataset not found"},
-            {"status_code": HTTP_422_UNPROCESSABLE_ENTITY, "description": "Invalid dataset ID"},
+            {"status_code": 404, "description": "Dataset not found"},
+            {"status_code": 422, "description": "Invalid dataset ID"},
         ]
     ),
 )

@@ -181,11 +179,9 @@ async def delete_dataset(
                 DATASET_NODE_NAME,
             )
         except ValueError:
-            raise HTTPException(
-                detail=f"Invalid Dataset ID: {id}", status_code=HTTP_422_UNPROCESSABLE_ENTITY
-            )
+            raise HTTPException(detail=f"Invalid Dataset ID: {id}", status_code=422)
     else:
-        raise HTTPException(detail="Missing Dataset ID", status_code=HTTP_422_UNPROCESSABLE_ENTITY)
+        raise HTTPException(detail="Missing Dataset ID", status_code=422)
     project_names_stmt = get_project_names_for_datasets(dataset_id)
     eval_trace_ids_stmt = get_eval_trace_ids_for_datasets(dataset_id)
     stmt = (

@@ -195,7 +191,7 @@ async def delete_dataset(
         project_names = await session.scalars(project_names_stmt)
         eval_trace_ids = await session.scalars(eval_trace_ids_stmt)
         if (await session.scalar(stmt)) is None:
-            raise HTTPException(detail="Dataset does not exist", status_code=HTTP_404_NOT_FOUND)
+            raise HTTPException(detail="Dataset does not exist", status_code=404)
     tasks = BackgroundTasks()
     tasks.add_task(delete_projects, request.app.state.db, *project_names)
     tasks.add_task(delete_traces, request.app.state.db, *eval_trace_ids)

@@ -213,17 +209,21 @@ class GetDatasetResponseBody(ResponseBody[DatasetWithExampleCount]):
     "/datasets/{id}",
     operation_id="getDataset",
     summary="Get dataset by ID",
-    responses=add_errors_to_responses([HTTP_404_NOT_FOUND]),
+    responses=add_errors_to_responses([404]),
 )
 async def get_dataset(
     request: Request, id: str = Path(description="The ID of the dataset")
 ) -> GetDatasetResponseBody:
-    dataset_id = GlobalID.from_id(id)
+    try:
+        dataset_id = GlobalID.from_id(id)
+    except Exception as e:
+        raise HTTPException(
+            detail=f"Invalid dataset ID format: {id}",
+            status_code=422,
+        ) from e

     if (type_name := dataset_id.type_name) != DATASET_NODE_NAME:
-        raise HTTPException(
-            detail=f"ID {dataset_id} refers to a f{type_name}", status_code=HTTP_404_NOT_FOUND
-        )
+        raise HTTPException(detail=f"ID {dataset_id} refers to a f{type_name}", status_code=404)
     async with request.app.state.db() as session:
         result = await session.execute(
             select(models.Dataset, models.Dataset.example_count).filter(

@@ -234,9 +234,7 @@ async def get_dataset(
         dataset = dataset_query[0] if dataset_query else None
         example_count = dataset_query[1] if dataset_query else 0
         if dataset is None:
-            raise HTTPException(
-                detail=f"Dataset with ID {dataset_id} not found", status_code=HTTP_404_NOT_FOUND
-            )
+            raise HTTPException(detail=f"Dataset with ID {dataset_id} not found", status_code=404)

         dataset = DatasetWithExampleCount(
             id=str(dataset_id),

@@ -265,7 +263,7 @@ class ListDatasetVersionsResponseBody(PaginatedResponseBody[DatasetVersion]):
     "/datasets/{id}/versions",
     operation_id="listDatasetVersionsByDatasetId",
     summary="List dataset versions",
-    responses=add_errors_to_responses([HTTP_422_UNPROCESSABLE_ENTITY]),
+    responses=add_errors_to_responses([422]),
 )
 async def list_dataset_versions(
     request: Request,

@@ -287,12 +285,12 @@ async def list_dataset_versions(
         except ValueError:
             raise HTTPException(
                 detail=f"Invalid Dataset ID: {id}",
-                status_code=HTTP_422_UNPROCESSABLE_ENTITY,
+                status_code=422,
             )
     else:
         raise HTTPException(
             detail="Missing Dataset ID",
-            status_code=HTTP_422_UNPROCESSABLE_ENTITY,
+            status_code=422,
         )
     stmt = (
         select(models.DatasetVersion)

@@ -308,7 +306,7 @@ async def list_dataset_versions(
         except ValueError:
             raise HTTPException(
                 detail=f"Invalid cursor: {cursor}",
-                status_code=HTTP_422_UNPROCESSABLE_ENTITY,
+                status_code=422,
             )
     max_dataset_version_id = (
         select(models.DatasetVersion.id)

@@ -343,14 +341,14 @@ class UploadDatasetResponseBody(ResponseBody[UploadDatasetData]):
     "/datasets/upload",
     dependencies=[Depends(is_not_locked)],
     operation_id="uploadDataset",
-    summary="Upload dataset from JSON, CSV, or PyArrow",
+    summary="Upload dataset from JSON, JSONL, CSV, or PyArrow",
     responses=add_errors_to_responses(
         [
             {
-                "status_code": HTTP_409_CONFLICT,
+                "status_code": 409,
                 "description": "Dataset of the same name already exists",
             },
-            {"status_code": HTTP_422_UNPROCESSABLE_ENTITY, "description": "Invalid request body"},
+            {"status_code": 422, "description": "Invalid request body"},
         ]
     ),
     # FastAPI cannot generate the request body portion of the OpenAPI schema for
@@ -372,6 +370,17 @@ class UploadDatasetResponseBody(ResponseBody[UploadDatasetData]):
                     "inputs": {"type": "array", "items": {"type": "object"}},
                     "outputs": {"type": "array", "items": {"type": "object"}},
                     "metadata": {"type": "array", "items": {"type": "object"}},
+                    "splits": {
+                        "type": "array",
+                        "items": {
+                            "oneOf": [
+                                {"type": "string"},
+                                {"type": "array", "items": {"type": "string"}},
+                                {"type": "null"},
+                            ]
+                        },
+                        "description": "Split per example: string, string array, or null",
+                    },
                 },
             }
         },

@@ -398,6 +407,12 @@ class UploadDatasetResponseBody(ResponseBody[UploadDatasetData]):
                         "items": {"type": "string"},
                         "uniqueItems": True,
                     },
+                    "split_keys[]": {
+                        "type": "array",
+                        "items": {"type": "string"},
+                        "uniqueItems": True,
+                        "description": "Column names for auto-assigning examples to splits",
+                    },
                     "file": {"type": "string", "format": "binary"},
                 },
             }
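The JSON upload schema above gains a per-example `splits` field: one entry per example, each a single split name, a list of names, or null, and the array must be the same length as `inputs`. A minimal sketch of a request using it, assuming a local Phoenix server on its default port and that this router is mounted under the `/v1` prefix (neither assumption is shown in this diff):

```python
import httpx

payload = {
    "action": "create",
    "name": "qa-demo",  # hypothetical dataset name
    "inputs": [{"q": "1+1?"}, {"q": "2+2?"}, {"q": "3+3?"}],
    "outputs": [{"a": "2"}, {"a": "4"}, {"a": "6"}],
    "metadata": [{}, {}, {}],
    # One entry per example: a single split, several splits, or no split at all.
    "splits": ["train", ["train", "eval"], None],
}
resp = httpx.post(
    "http://localhost:6006/v1/datasets/upload",  # assumed base URL and route prefix
    json=payload,
    params={"sync": True},  # assumed to be a query parameter, per the handler signature
)
resp.raise_for_status()
```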
@@ -413,7 +428,12 @@ async def upload_dataset(
         description="If true, fulfill request synchronously and return JSON containing dataset_id.",
     ),
 ) -> Optional[UploadDatasetResponseBody]:
-    request_content_type = request.headers["content-type"]
+    request_content_type = request.headers.get("content-type")
+    if not request_content_type:
+        raise HTTPException(
+            detail="Missing content-type header",
+            status_code=400,
+        )
     examples: Union[Examples, Awaitable[Examples]]
     if request_content_type.startswith("application/json"):
         try:

@@ -423,14 +443,14 @@ async def upload_dataset(
         except ValueError as e:
             raise HTTPException(
                 detail=str(e),
-                status_code=HTTP_422_UNPROCESSABLE_ENTITY,
+                status_code=422,
             )
         if action is DatasetAction.CREATE:
             async with request.app.state.db() as session:
                 if await _check_table_exists(session, name):
                     raise HTTPException(
                         detail=f"Dataset with the same name already exists: {name=}",
-                        status_code=HTTP_409_CONFLICT,
+                        status_code=409,
                     )
     elif request_content_type.startswith("multipart/form-data"):
         async with request.form() as form:

@@ -442,19 +462,20 @@ async def upload_dataset(
                     input_keys,
                     output_keys,
                     metadata_keys,
+                    split_keys,
                     file,
                 ) = await _parse_form_data(form)
             except ValueError as e:
                 raise HTTPException(
                     detail=str(e),
-                    status_code=HTTP_422_UNPROCESSABLE_ENTITY,
+                    status_code=422,
                 )
             if action is DatasetAction.CREATE:
                 async with request.app.state.db() as session:
                     if await _check_table_exists(session, name):
                         raise HTTPException(
                             detail=f"Dataset with the same name already exists: {name=}",
-                            status_code=HTTP_409_CONFLICT,
+                            status_code=409,
                         )
             content = await file.read()
             try:

@@ -462,22 +483,32 @@ async def upload_dataset(
                 if file_content_type is FileContentType.CSV:
                     encoding = FileContentEncoding(file.headers.get("content-encoding"))
                     examples = await _process_csv(
-                        content, encoding, input_keys, output_keys, metadata_keys
+                        content, encoding, input_keys, output_keys, metadata_keys, split_keys
                     )
                 elif file_content_type is FileContentType.PYARROW:
-                    examples = await _process_pyarrow(content, input_keys, output_keys, metadata_keys)
+                    examples = await _process_pyarrow(
+                        content, input_keys, output_keys, metadata_keys, split_keys
+                    )
+                elif file_content_type is FileContentType.JSONL:
+                    encoding = FileContentEncoding(file.headers.get("content-encoding"))
+                    examples = await _process_jsonl(
+                        content, encoding, input_keys, output_keys, metadata_keys, split_keys
+                    )
                 else:
                     assert_never(file_content_type)
             except ValueError as e:
                 raise HTTPException(
                     detail=str(e),
-                    status_code=HTTP_422_UNPROCESSABLE_ENTITY,
+                    status_code=422,
                 )
     else:
         raise HTTPException(
             detail="Invalid request Content-Type",
-            status_code=HTTP_422_UNPROCESSABLE_ENTITY,
+            status_code=422,
         )
+    user_id: Optional[int] = None
+    if request.app.state.authentication_enabled and isinstance(request.user, PhoenixUser):
+        user_id = int(request.user.identity)
     operation = cast(
         Callable[[AsyncSession], Awaitable[DatasetExampleAdditionEvent]],
         partial(

@@ -486,6 +517,7 @@ async def upload_dataset(
             action=action,
             name=name,
             description=description,
+            user_id=user_id,
         ),
     )
     if sync:

@@ -505,13 +537,14 @@ async def upload_dataset(
     except QueueFull:
         if isinstance(examples, Coroutine):
             examples.close()
-        raise HTTPException(detail="Too many requests.", status_code=HTTP_429_TOO_MANY_REQUESTS)
+        raise HTTPException(detail="Too many requests.", status_code=429)
     return None


 class FileContentType(Enum):
     CSV = "text/csv"
     PYARROW = "application/x-pandas-pyarrow"
+    JSONL = "application/jsonl"

     @classmethod
     def _missing_(cls, v: Any) -> "FileContentType":
@@ -539,6 +572,7 @@ Description: TypeAlias = Optional[str]
 InputKeys: TypeAlias = frozenset[str]
 OutputKeys: TypeAlias = frozenset[str]
 MetadataKeys: TypeAlias = frozenset[str]
+SplitKeys: TypeAlias = frozenset[str]
 DatasetId: TypeAlias = int
 Examples: TypeAlias = Iterator[ExampleContent]

@@ -555,18 +589,55 @@ def _process_json(
         raise ValueError("input is required")
     if not isinstance(inputs, list) or not _is_all_dict(inputs):
         raise ValueError("Input should be a list containing only dictionary objects")
-    outputs, metadata = data.get("outputs"), data.get("metadata")
+    outputs, metadata, splits = data.get("outputs"), data.get("metadata"), data.get("splits")
     for k, v in {"outputs": outputs, "metadata": metadata}.items():
         if v and not (isinstance(v, list) and len(v) == len(inputs) and _is_all_dict(v)):
             raise ValueError(
                 f"{k} should be a list of same length as input containing only dictionary objects"
             )
+
+    # Validate splits format if provided
+    if splits is not None:
+        if not isinstance(splits, list):
+            raise ValueError("splits must be a list")
+        if len(splits) != len(inputs):
+            raise ValueError(
+                f"splits must have same length as inputs ({len(splits)} != {len(inputs)})"
+            )
     examples: list[ExampleContent] = []
     for i, obj in enumerate(inputs):
+        # Extract split values, validating they're non-empty strings
+        split_set: set[str] = set()
+        if splits:
+            split_value = splits[i]
+            if split_value is None:
+                # Sparse assignment: None means no splits for this example
+                pass
+            elif isinstance(split_value, str):
+                # Format 1: Single string value
+                if split_value.strip():
+                    split_set.add(split_value.strip())
+            elif isinstance(split_value, list):
+                # Format 2: List of strings (multiple splits)
+                for v in split_value:
+                    if v is None:
+                        continue  # Skip None values in the list
+                    if not isinstance(v, str):
+                        raise ValueError(
+                            f"Split value must be a string or None, got {type(v).__name__}"
+                        )
+                    if v.strip():
+                        split_set.add(v.strip())
+            else:
+                raise ValueError(
+                    f"Split value must be a string, list of strings, or None, "
+                    f"got {type(split_value).__name__}"
+                )
         example = ExampleContent(
             input=obj,
             output=outputs[i] if outputs else {},
             metadata=metadata[i] if metadata else {},
+            splits=frozenset(split_set),
         )
         examples.append(example)
     action = DatasetAction(cast(Optional[str], data.get("action")) or "create")

@@ -579,6 +650,7 @@ async def _process_csv(
     input_keys: InputKeys,
     output_keys: OutputKeys,
     metadata_keys: MetadataKeys,
+    split_keys: SplitKeys,
 ) -> Examples:
     if content_encoding is FileContentEncoding.GZIP:
         content = await run_in_threadpool(gzip.decompress, content)

@@ -593,12 +665,15 @@ async def _process_csv(
         if freq > 1:
             raise ValueError(f"Duplicated column header in CSV file: {header}")
     column_headers = frozenset(reader.fieldnames)
-    _check_keys_exist(column_headers, input_keys, output_keys, metadata_keys)
+    _check_keys_exist(column_headers, input_keys, output_keys, metadata_keys, split_keys)
     return (
         ExampleContent(
             input={k: row.get(k) for k in input_keys},
            output={k: row.get(k) for k in output_keys},
             metadata={k: row.get(k) for k in metadata_keys},
+            splits=frozenset(
+                str(v).strip() for k in split_keys if (v := row.get(k)) and str(v).strip()
+            ),  # Only include non-empty, non-whitespace split values
         )
         for row in iter(reader)
     )

@@ -609,13 +684,14 @@ async def _process_pyarrow(
     input_keys: InputKeys,
     output_keys: OutputKeys,
     metadata_keys: MetadataKeys,
+    split_keys: SplitKeys,
 ) -> Awaitable[Examples]:
     try:
         reader = pa.ipc.open_stream(content)
     except pa.ArrowInvalid as e:
         raise ValueError("File is not valid pyarrow") from e
     column_headers = frozenset(reader.schema.names)
-    _check_keys_exist(column_headers, input_keys, output_keys, metadata_keys)
+    _check_keys_exist(column_headers, input_keys, output_keys, metadata_keys, split_keys)

     def get_examples() -> Iterator[ExampleContent]:
         for row in reader.read_pandas().to_dict(orient="records"):
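`_process_csv` now takes `split_keys`: values found in the named CSV columns become split assignments for each row, with empty and whitespace-only cells skipped. A rough sketch of a multipart upload that names a split column via the `split_keys[]` form field; the base URL and `/v1` prefix are assumptions:

```python
import httpx

# Hypothetical CSV: the "split" column drives split assignment; the last row gets none.
csv_bytes = b"q,a,split\n1+1?,2,train\n2+2?,4,eval\n3+3?,6,\n"
resp = httpx.post(
    "http://localhost:6006/v1/datasets/upload",  # assumed base URL and route prefix
    data={
        "action": "create",
        "name": "qa-from-csv",
        "input_keys[]": ["q"],
        "output_keys[]": ["a"],
        "split_keys[]": ["split"],  # repeated form field naming the split column(s)
    },
    files={"file": ("examples.csv", csv_bytes, "text/csv")},
)
resp.raise_for_status()
```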
@@ -623,11 +699,48 @@ async def _process_pyarrow(
                 input={k: row.get(k) for k in input_keys},
                 output={k: row.get(k) for k in output_keys},
                 metadata={k: row.get(k) for k in metadata_keys},
+                splits=frozenset(
+                    str(v).strip() for k in split_keys if (v := row.get(k)) and str(v).strip()
+                ),  # Only include non-empty, non-whitespace split values
             )

     return run_in_threadpool(get_examples)


+async def _process_jsonl(
+    content: bytes,
+    encoding: FileContentEncoding,
+    input_keys: InputKeys,
+    output_keys: OutputKeys,
+    metadata_keys: MetadataKeys,
+    split_keys: SplitKeys,
+) -> Examples:
+    if encoding is FileContentEncoding.GZIP:
+        content = await run_in_threadpool(gzip.decompress, content)
+    elif encoding is FileContentEncoding.DEFLATE:
+        content = await run_in_threadpool(zlib.decompress, content)
+    elif encoding is not FileContentEncoding.NONE:
+        assert_never(encoding)
+    # content is a newline delimited list of JSON objects
+    # parse within a threadpool
+    reader = await run_in_threadpool(
+        lambda c: [json.loads(line) for line in c.decode().splitlines()], content
+    )
+
+    examples: list[ExampleContent] = []
+    for obj in reader:
+        example = ExampleContent(
+            input={k: obj.get(k) for k in input_keys},
+            output={k: obj.get(k) for k in output_keys},
+            metadata={k: obj.get(k) for k in metadata_keys},
+            splits=frozenset(
+                str(v).strip() for k in split_keys if (v := obj.get(k)) and str(v).strip()
+            ),  # Only include non-empty, non-whitespace split values
+        )
+        examples.append(example)
+    return iter(examples)
+
+
 async def _check_table_exists(session: AsyncSession, name: str) -> bool:
     return bool(
         await session.scalar(

@@ -641,11 +754,13 @@ def _check_keys_exist(
     input_keys: InputKeys,
     output_keys: OutputKeys,
     metadata_keys: MetadataKeys,
+    split_keys: SplitKeys,
 ) -> None:
     for desc, keys in (
         ("input", input_keys),
         ("output", output_keys),
         ("metadata", metadata_keys),
+        ("split", split_keys),
     ):
         if keys and (diff := keys.difference(column_headers)):
             raise ValueError(f"{desc} keys not found in column headers: {diff}")

@@ -660,6 +775,7 @@ async def _parse_form_data(
     InputKeys,
     OutputKeys,
     MetadataKeys,
+    SplitKeys,
     UploadFile,
 ]:
     name = cast(Optional[str], form.get("name"))

@@ -673,6 +789,7 @@ async def _parse_form_data(
     input_keys = frozenset(filter(bool, cast(list[str], form.getlist("input_keys[]"))))
     output_keys = frozenset(filter(bool, cast(list[str], form.getlist("output_keys[]"))))
     metadata_keys = frozenset(filter(bool, cast(list[str], form.getlist("metadata_keys[]"))))
+    split_keys = frozenset(filter(bool, cast(list[str], form.getlist("split_keys[]"))))
     return (
         action,
         name,

@@ -680,6 +797,7 @@ async def _parse_form_data(
         input_keys,
         output_keys,
         metadata_keys,
+        split_keys,
         file,
     )

@@ -695,6 +813,7 @@ class DatasetExample(V1RoutesBaseModel):
 class ListDatasetExamplesData(V1RoutesBaseModel):
     dataset_id: str
     version_id: str
+    filtered_splits: list[str] = UNDEFINED
     examples: list[DatasetExample]


@@ -706,7 +825,7 @@ class ListDatasetExamplesResponseBody(ResponseBody[ListDatasetExamplesData]):
     "/datasets/{id}/examples",
     operation_id="getDatasetExamples",
     summary="Get examples from a dataset",
-    responses=add_errors_to_responses([HTTP_404_NOT_FOUND]),
+    responses=add_errors_to_responses([404]),
 )
 async def get_dataset_examples(
     request: Request,
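`_process_jsonl` above parses one JSON object per line and pulls inputs, outputs, metadata, and splits out of each object by key, so a JSONL upload looks like the CSV case but with a file part typed `application/jsonl` (the new `FileContentType.JSONL` value). A sketch under the same base-URL assumption:

```python
import httpx

jsonl_bytes = (
    b'{"q": "1+1?", "a": "2", "split": "train"}\n'
    b'{"q": "2+2?", "a": "4", "split": "eval"}\n'
)
resp = httpx.post(
    "http://localhost:6006/v1/datasets/upload",  # assumed base URL and route prefix
    data={
        "action": "create",
        "name": "qa-from-jsonl",
        "input_keys[]": ["q"],
        "output_keys[]": ["a"],
        "split_keys[]": ["split"],
    },
    files={"file": ("examples.jsonl", jsonl_bytes, "application/jsonl")},
)
resp.raise_for_status()
```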
@@ -717,19 +836,35 @@ async def get_dataset_examples(
             "The ID of the dataset version (if omitted, returns data from the latest version)"
         ),
     ),
+    split: Optional[list[str]] = Query(
+        default=None,
+        description="List of dataset split identifiers (GlobalIDs or names) to filter by",
+    ),
 ) -> ListDatasetExamplesResponseBody:
-    dataset_gid = GlobalID.from_id(id)
-    version_gid = GlobalID.from_id(version_id) if version_id else None
+    try:
+        dataset_gid = GlobalID.from_id(id)
+    except Exception as e:
+        raise HTTPException(
+            detail=f"Invalid dataset ID format: {id}",
+            status_code=422,
+        ) from e
+
+    if version_id:
+        try:
+            version_gid = GlobalID.from_id(version_id)
+        except Exception as e:
+            raise HTTPException(
+                detail=f"Invalid dataset version ID format: {version_id}",
+                status_code=422,
+            ) from e
+    else:
+        version_gid = None

     if (dataset_type := dataset_gid.type_name) != "Dataset":
-        raise HTTPException(
-            detail=f"ID {dataset_gid} refers to a {dataset_type}", status_code=HTTP_404_NOT_FOUND
-        )
+        raise HTTPException(detail=f"ID {dataset_gid} refers to a {dataset_type}", status_code=404)

     if version_gid and (version_type := version_gid.type_name) != "DatasetVersion":
-        raise HTTPException(
-            detail=f"ID {version_gid} refers to a {version_type}", status_code=HTTP_404_NOT_FOUND
-        )
+        raise HTTPException(detail=f"ID {version_gid} refers to a {version_type}", status_code=404)

     async with request.app.state.db() as session:
         if (

@@ -739,7 +874,7 @@ async def get_dataset_examples(
         ) is None:
             raise HTTPException(
                 detail=f"No dataset with id {dataset_gid} can be found.",
-                status_code=HTTP_404_NOT_FOUND,
+                status_code=404,
             )

         # Subquery to find the maximum created_at for each dataset_example_id

@@ -761,7 +896,7 @@ async def get_dataset_examples(
             ) is None:
                 raise HTTPException(
                     detail=f"No dataset version with id {version_id} can be found.",
-                    status_code=HTTP_404_NOT_FOUND,
+                    status_code=404,
                 )
             # if a version_id is provided, filter the subquery to only include revisions from that
             partial_subquery = partial_subquery.filter(

@@ -777,13 +912,17 @@ async def get_dataset_examples(
             ) is None:
                 raise HTTPException(
                     detail="Dataset has no versions.",
-                    status_code=HTTP_404_NOT_FOUND,
+                    status_code=404,
                 )

         subquery = partial_subquery.subquery()
+
         # Query for the most recent example revisions that are not deleted
         query = (
-            select(models.DatasetExample, models.DatasetExampleRevision)
+            select(
+                models.DatasetExample,
+                models.DatasetExampleRevision,
+            )
             .join(
                 models.DatasetExampleRevision,
                 models.DatasetExample.id == models.DatasetExampleRevision.dataset_example_id,

@@ -796,6 +935,28 @@ async def get_dataset_examples(
             .filter(models.DatasetExampleRevision.revision_kind != "DELETE")
             .order_by(models.DatasetExample.id.asc())
         )
+
+        # If splits are provided, filter by dataset splits
+        resolved_split_names: list[str] = []
+        if split:
+            # Resolve split identifiers (IDs or names) to IDs and names
+            resolved_split_ids, resolved_split_names = await _resolve_split_identifiers(
+                session, split
+            )
+
+            # Add filter for splits (join with the association table)
+            # Use distinct() to prevent duplicates when an example belongs to
+            # multiple splits
+            query = (
+                query.join(
+                    models.DatasetSplitDatasetExample,
+                    models.DatasetExample.id
+                    == models.DatasetSplitDatasetExample.dataset_example_id,
+                )
+                .filter(models.DatasetSplitDatasetExample.dataset_split_id.in_(resolved_split_ids))
+                .distinct()
+            )
+
         examples = [
             DatasetExample(
                 id=str(GlobalID("DatasetExample", str(example.id))),

@@ -810,6 +971,7 @@ async def get_dataset_examples(
         data=ListDatasetExamplesData(
             dataset_id=str(GlobalID("Dataset", str(resolved_dataset_id))),
             version_id=str(GlobalID("DatasetVersion", str(resolved_version_id))),
+            filtered_splits=resolved_split_names,
            examples=examples,
         )
     )
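The examples endpoint now takes a repeatable `split` query parameter, accepting either split names or DatasetSplit GlobalIDs, and echoes the resolved names back in `filtered_splits`. A sketch of a filtered read, with an illustrative dataset GlobalID and the same assumed base URL:

```python
import httpx

dataset_id = "RGF0YXNldDox"  # illustrative GlobalID; substitute a real dataset ID
resp = httpx.get(
    f"http://localhost:6006/v1/datasets/{dataset_id}/examples",  # assumed base URL and prefix
    params={"split": ["train", "eval"]},  # repeated query params: names or GlobalIDs
)
resp.raise_for_status()
data = resp.json()["data"]
print(data["filtered_splits"])  # the split names the filter resolved to
print(len(data["examples"]))   # only examples belonging to those splits
```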
@@ -820,10 +982,10 @@ async def get_dataset_examples(
     operation_id="getDatasetCsv",
     summary="Download dataset examples as CSV file",
     response_class=StreamingResponse,
-    status_code=HTTP_200_OK,
+    status_code=200,
     responses={
-        **add_errors_to_responses([HTTP_422_UNPROCESSABLE_ENTITY]),
-        **add_text_csv_content_to_responses(HTTP_200_OK),
+        **add_errors_to_responses([422]),
+        **add_text_csv_content_to_responses(200),
     },
 )
 async def get_dataset_csv(

@@ -843,7 +1005,7 @@ async def get_dataset_csv(
                 session=session, id=id, version_id=version_id
             )
     except ValueError as e:
-        raise HTTPException(detail=str(e), status_code=HTTP_422_UNPROCESSABLE_ENTITY)
+        raise HTTPException(detail=str(e), status_code=422)
     content = await run_in_threadpool(_get_content_csv, examples)
     encoded_dataset_name = urllib.parse.quote(dataset_name)
     return Response(

@@ -863,7 +1025,7 @@ async def get_dataset_csv(
     responses=add_errors_to_responses(
         [
             {
-                "status_code": HTTP_422_UNPROCESSABLE_ENTITY,
+                "status_code": 422,
                 "description": "Invalid dataset or version ID",
             }
         ]

@@ -886,7 +1048,7 @@ async def get_dataset_jsonl_openai_ft(
                 session=session, id=id, version_id=version_id
             )
     except ValueError as e:
-        raise HTTPException(detail=str(e), status_code=HTTP_422_UNPROCESSABLE_ENTITY)
+        raise HTTPException(detail=str(e), status_code=422)
     content = await run_in_threadpool(_get_content_jsonl_openai_ft, examples)
     encoded_dataset_name = urllib.parse.quote(dataset_name)
     response.headers["content-disposition"] = (

@@ -903,7 +1065,7 @@ async def get_dataset_jsonl_openai_ft(
     responses=add_errors_to_responses(
         [
             {
-                "status_code": HTTP_422_UNPROCESSABLE_ENTITY,
+                "status_code": 422,
                 "description": "Invalid dataset or version ID",
             }
         ]

@@ -926,7 +1088,7 @@ async def get_dataset_jsonl_openai_evals(
                 session=session, id=id, version_id=version_id
             )
     except ValueError as e:
-        raise HTTPException(detail=str(e), status_code=HTTP_422_UNPROCESSABLE_ENTITY)
+        raise HTTPException(detail=str(e), status_code=422)
     content = await run_in_threadpool(_get_content_jsonl_openai_evals, examples)
     encoded_dataset_name = urllib.parse.quote(dataset_name)
     response.headers["content-disposition"] = (

@@ -1005,12 +1167,25 @@ def _get_content_jsonl_openai_evals(examples: list[models.DatasetExampleRevision
 async def _get_db_examples(
     *, session: Any, id: str, version_id: Optional[str]
 ) -> tuple[str, list[models.DatasetExampleRevision]]:
-    dataset_id = from_global_id_with_expected_type(GlobalID.from_id(id), DATASET_NODE_NAME)
+    try:
+        dataset_id = from_global_id_with_expected_type(GlobalID.from_id(id), DATASET_NODE_NAME)
+    except Exception as e:
+        raise HTTPException(
+            detail=f"Invalid dataset ID format: {id}",
+            status_code=422,
+        ) from e
+
     dataset_version_id: Optional[int] = None
     if version_id:
-        dataset_version_id = from_global_id_with_expected_type(
-            GlobalID.from_id(version_id), DATASET_VERSION_NODE_NAME
-        )
+        try:
+            dataset_version_id = from_global_id_with_expected_type(
+                GlobalID.from_id(version_id), DATASET_VERSION_NODE_NAME
+            )
+        except Exception as e:
+            raise HTTPException(
+                detail=f"Invalid dataset version ID format: {version_id}",
+                status_code=422,
+            ) from e
     latest_version = (
         select(
             models.DatasetExampleRevision.dataset_example_id,

@@ -1053,3 +1228,115 @@ async def _get_db_examples(

 def _is_all_dict(seq: Sequence[Any]) -> bool:
     return all(map(lambda obj: isinstance(obj, dict), seq))
+
+
+# Split identifier helper types and functions
+class _SplitId(int): ...
+
+
+_SplitIdentifier: TypeAlias = Union[_SplitId, str]
+
+
+def _parse_split_identifier(split_identifier: str) -> _SplitIdentifier:
+    """
+    Parse a split identifier as either a GlobalID or a name.
+
+    Args:
+        split_identifier: The identifier string (GlobalID or name)
+
+    Returns:
+        Either a _SplitId or an Identifier
+
+    Raises:
+        HTTPException: If the identifier format is invalid
+    """
+    if not split_identifier:
+        raise HTTPException(422, "Invalid split identifier")
+    try:
+        split_id = from_global_id_with_expected_type(
+            GlobalID.from_id(split_identifier),
+            DatasetSplitNodeType.__name__,
+        )
+    except ValueError:
+        return split_identifier
+    return _SplitId(split_id)
+
+
+async def _resolve_split_identifiers(
+    session: AsyncSession,
+    split_identifiers: list[str],
+) -> tuple[list[int], list[str]]:
+    """
+    Resolve a list of split identifiers (IDs or names) to split IDs and names.
+
+    Args:
+        session: The database session
+        split_identifiers: List of split identifiers (GlobalIDs or names)
+
+    Returns:
+        Tuple of (list of split IDs, list of split names)
+
+    Raises:
+        HTTPException: If any split identifier is invalid or not found
+    """
+    split_ids: list[int] = []
+    split_names: list[str] = []
+
+    # Parse all identifiers first
+    parsed_identifiers: list[_SplitIdentifier] = []
+    for identifier_str in split_identifiers:
+        parsed_identifiers.append(_parse_split_identifier(identifier_str.strip()))
+
+    # Separate IDs and names
+    requested_ids: list[int] = []
+    requested_names: list[str] = []
+    for identifier in parsed_identifiers:
+        if isinstance(identifier, _SplitId):
+            requested_ids.append(int(identifier))
+        elif isinstance(identifier, str):
+            requested_names.append(identifier)
+        else:
+            assert_never(identifier)
+
+    # Query for splits by ID
+    if requested_ids:
+        id_results = await session.stream(
+            select(models.DatasetSplit.id, models.DatasetSplit.name).where(
+                models.DatasetSplit.id.in_(requested_ids)
+            )
+        )
+        async for split_id, split_name in id_results:
+            split_ids.append(split_id)
+            split_names.append(split_name)
+
+        # Check if all requested IDs were found
+        found_ids = set(split_ids[-len(requested_ids) :] if requested_ids else [])
+        missing_ids = [sid for sid in requested_ids if sid not in found_ids]
+        if missing_ids:
+            raise HTTPException(
+                status_code=HTTP_404_NOT_FOUND,
+                detail=f"Dataset splits not found for IDs: {', '.join(map(str, missing_ids))}",
+            )
+
+    # Query for splits by name
+    if requested_names:
+        name_results = await session.stream(
+            select(models.DatasetSplit.id, models.DatasetSplit.name).where(
+                models.DatasetSplit.name.in_(requested_names)
+            )
+        )
+        name_to_id: dict[str, int] = {}
+        async for split_id, split_name in name_results:
+            split_ids.append(split_id)
+            split_names.append(split_name)
+            name_to_id[split_name] = split_id
+
+        # Check if all requested names were found
+        missing_names = [name for name in requested_names if name not in name_to_id]
+        if missing_names:
+            raise HTTPException(
+                status_code=HTTP_404_NOT_FOUND,
+                detail=f"Dataset splits not found: {', '.join(missing_names)}",
+            )
+
+    return split_ids, split_names