arize-phoenix 10.0.4__py3-none-any.whl → 12.28.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/METADATA +124 -72
- arize_phoenix-12.28.1.dist-info/RECORD +499 -0
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/WHEEL +1 -1
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/IP_NOTICE +1 -1
- phoenix/__generated__/__init__.py +0 -0
- phoenix/__generated__/classification_evaluator_configs/__init__.py +20 -0
- phoenix/__generated__/classification_evaluator_configs/_document_relevance_classification_evaluator_config.py +17 -0
- phoenix/__generated__/classification_evaluator_configs/_hallucination_classification_evaluator_config.py +17 -0
- phoenix/__generated__/classification_evaluator_configs/_models.py +18 -0
- phoenix/__generated__/classification_evaluator_configs/_tool_selection_classification_evaluator_config.py +17 -0
- phoenix/__init__.py +5 -4
- phoenix/auth.py +39 -2
- phoenix/config.py +1763 -91
- phoenix/datetime_utils.py +120 -2
- phoenix/db/README.md +595 -25
- phoenix/db/bulk_inserter.py +145 -103
- phoenix/db/engines.py +140 -33
- phoenix/db/enums.py +3 -12
- phoenix/db/facilitator.py +302 -35
- phoenix/db/helpers.py +1000 -65
- phoenix/db/iam_auth.py +64 -0
- phoenix/db/insertion/dataset.py +135 -2
- phoenix/db/insertion/document_annotation.py +9 -6
- phoenix/db/insertion/evaluation.py +2 -3
- phoenix/db/insertion/helpers.py +17 -2
- phoenix/db/insertion/session_annotation.py +176 -0
- phoenix/db/insertion/span.py +15 -11
- phoenix/db/insertion/span_annotation.py +3 -4
- phoenix/db/insertion/trace_annotation.py +3 -4
- phoenix/db/insertion/types.py +50 -20
- phoenix/db/migrations/versions/01a8342c9cdf_add_user_id_on_datasets.py +40 -0
- phoenix/db/migrations/versions/0df286449799_add_session_annotations_table.py +105 -0
- phoenix/db/migrations/versions/272b66ff50f8_drop_single_indices.py +119 -0
- phoenix/db/migrations/versions/58228d933c91_dataset_labels.py +67 -0
- phoenix/db/migrations/versions/699f655af132_experiment_tags.py +57 -0
- phoenix/db/migrations/versions/735d3d93c33e_add_composite_indices.py +41 -0
- phoenix/db/migrations/versions/a20694b15f82_cost.py +196 -0
- phoenix/db/migrations/versions/ab513d89518b_add_user_id_on_dataset_versions.py +40 -0
- phoenix/db/migrations/versions/d0690a79ea51_users_on_experiments.py +40 -0
- phoenix/db/migrations/versions/deb2c81c0bb2_dataset_splits.py +139 -0
- phoenix/db/migrations/versions/e76cbd66ffc3_add_experiments_dataset_examples.py +87 -0
- phoenix/db/models.py +669 -56
- phoenix/db/pg_config.py +10 -0
- phoenix/db/types/model_provider.py +4 -0
- phoenix/db/types/token_price_customization.py +29 -0
- phoenix/db/types/trace_retention.py +23 -15
- phoenix/experiments/evaluators/utils.py +3 -3
- phoenix/experiments/functions.py +160 -52
- phoenix/experiments/tracing.py +2 -2
- phoenix/experiments/types.py +1 -1
- phoenix/inferences/inferences.py +1 -2
- phoenix/server/api/auth.py +38 -7
- phoenix/server/api/auth_messages.py +46 -0
- phoenix/server/api/context.py +100 -4
- phoenix/server/api/dataloaders/__init__.py +79 -5
- phoenix/server/api/dataloaders/annotation_configs_by_project.py +31 -0
- phoenix/server/api/dataloaders/annotation_summaries.py +60 -8
- phoenix/server/api/dataloaders/average_experiment_repeated_run_group_latency.py +50 -0
- phoenix/server/api/dataloaders/average_experiment_run_latency.py +17 -24
- phoenix/server/api/dataloaders/cache/two_tier_cache.py +1 -2
- phoenix/server/api/dataloaders/dataset_dataset_splits.py +52 -0
- phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -1
- phoenix/server/api/dataloaders/dataset_example_splits.py +40 -0
- phoenix/server/api/dataloaders/dataset_examples_and_versions_by_experiment_run.py +47 -0
- phoenix/server/api/dataloaders/dataset_labels.py +36 -0
- phoenix/server/api/dataloaders/document_evaluation_summaries.py +2 -2
- phoenix/server/api/dataloaders/document_evaluations.py +6 -9
- phoenix/server/api/dataloaders/experiment_annotation_summaries.py +88 -34
- phoenix/server/api/dataloaders/experiment_dataset_splits.py +43 -0
- phoenix/server/api/dataloaders/experiment_error_rates.py +21 -28
- phoenix/server/api/dataloaders/experiment_repeated_run_group_annotation_summaries.py +77 -0
- phoenix/server/api/dataloaders/experiment_repeated_run_groups.py +57 -0
- phoenix/server/api/dataloaders/experiment_runs_by_experiment_and_example.py +44 -0
- phoenix/server/api/dataloaders/last_used_times_by_generative_model_id.py +35 -0
- phoenix/server/api/dataloaders/latency_ms_quantile.py +40 -8
- phoenix/server/api/dataloaders/record_counts.py +37 -10
- phoenix/server/api/dataloaders/session_annotations_by_session.py +29 -0
- phoenix/server/api/dataloaders/span_cost_by_span.py +24 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_generative_model.py +56 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_project_session.py +57 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_span.py +43 -0
- phoenix/server/api/dataloaders/span_cost_detail_summary_entries_by_trace.py +56 -0
- phoenix/server/api/dataloaders/span_cost_details_by_span_cost.py +27 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment.py +57 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment_repeated_run_group.py +64 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_experiment_run.py +58 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_generative_model.py +55 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_project.py +152 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_project_session.py +56 -0
- phoenix/server/api/dataloaders/span_cost_summary_by_trace.py +55 -0
- phoenix/server/api/dataloaders/span_costs.py +29 -0
- phoenix/server/api/dataloaders/table_fields.py +2 -2
- phoenix/server/api/dataloaders/token_prices_by_model.py +30 -0
- phoenix/server/api/dataloaders/trace_annotations_by_trace.py +27 -0
- phoenix/server/api/dataloaders/types.py +29 -0
- phoenix/server/api/exceptions.py +11 -1
- phoenix/server/api/helpers/dataset_helpers.py +5 -1
- phoenix/server/api/helpers/playground_clients.py +1243 -292
- phoenix/server/api/helpers/playground_registry.py +2 -2
- phoenix/server/api/helpers/playground_spans.py +8 -4
- phoenix/server/api/helpers/playground_users.py +26 -0
- phoenix/server/api/helpers/prompts/conversions/aws.py +83 -0
- phoenix/server/api/helpers/prompts/conversions/google.py +103 -0
- phoenix/server/api/helpers/prompts/models.py +205 -22
- phoenix/server/api/input_types/{SpanAnnotationFilter.py → AnnotationFilter.py} +22 -14
- phoenix/server/api/input_types/ChatCompletionInput.py +6 -2
- phoenix/server/api/input_types/CreateProjectInput.py +27 -0
- phoenix/server/api/input_types/CreateProjectSessionAnnotationInput.py +37 -0
- phoenix/server/api/input_types/DatasetFilter.py +17 -0
- phoenix/server/api/input_types/ExperimentRunSort.py +237 -0
- phoenix/server/api/input_types/GenerativeCredentialInput.py +9 -0
- phoenix/server/api/input_types/GenerativeModelInput.py +5 -0
- phoenix/server/api/input_types/ProjectSessionSort.py +161 -1
- phoenix/server/api/input_types/PromptFilter.py +14 -0
- phoenix/server/api/input_types/PromptVersionInput.py +52 -1
- phoenix/server/api/input_types/SpanSort.py +44 -7
- phoenix/server/api/input_types/TimeBinConfig.py +23 -0
- phoenix/server/api/input_types/UpdateAnnotationInput.py +34 -0
- phoenix/server/api/input_types/UserRoleInput.py +1 -0
- phoenix/server/api/mutations/__init__.py +10 -0
- phoenix/server/api/mutations/annotation_config_mutations.py +8 -8
- phoenix/server/api/mutations/api_key_mutations.py +19 -23
- phoenix/server/api/mutations/chat_mutations.py +154 -47
- phoenix/server/api/mutations/dataset_label_mutations.py +243 -0
- phoenix/server/api/mutations/dataset_mutations.py +21 -16
- phoenix/server/api/mutations/dataset_split_mutations.py +351 -0
- phoenix/server/api/mutations/experiment_mutations.py +2 -2
- phoenix/server/api/mutations/export_events_mutations.py +3 -3
- phoenix/server/api/mutations/model_mutations.py +210 -0
- phoenix/server/api/mutations/project_mutations.py +49 -10
- phoenix/server/api/mutations/project_session_annotations_mutations.py +158 -0
- phoenix/server/api/mutations/project_trace_retention_policy_mutations.py +8 -4
- phoenix/server/api/mutations/prompt_label_mutations.py +74 -65
- phoenix/server/api/mutations/prompt_mutations.py +65 -129
- phoenix/server/api/mutations/prompt_version_tag_mutations.py +11 -8
- phoenix/server/api/mutations/span_annotations_mutations.py +15 -10
- phoenix/server/api/mutations/trace_annotations_mutations.py +14 -10
- phoenix/server/api/mutations/trace_mutations.py +47 -3
- phoenix/server/api/mutations/user_mutations.py +66 -41
- phoenix/server/api/queries.py +768 -293
- phoenix/server/api/routers/__init__.py +2 -2
- phoenix/server/api/routers/auth.py +154 -88
- phoenix/server/api/routers/ldap.py +229 -0
- phoenix/server/api/routers/oauth2.py +369 -106
- phoenix/server/api/routers/v1/__init__.py +24 -4
- phoenix/server/api/routers/v1/annotation_configs.py +23 -31
- phoenix/server/api/routers/v1/annotations.py +481 -17
- phoenix/server/api/routers/v1/datasets.py +395 -81
- phoenix/server/api/routers/v1/documents.py +142 -0
- phoenix/server/api/routers/v1/evaluations.py +24 -31
- phoenix/server/api/routers/v1/experiment_evaluations.py +19 -8
- phoenix/server/api/routers/v1/experiment_runs.py +337 -59
- phoenix/server/api/routers/v1/experiments.py +479 -48
- phoenix/server/api/routers/v1/models.py +7 -0
- phoenix/server/api/routers/v1/projects.py +18 -49
- phoenix/server/api/routers/v1/prompts.py +54 -40
- phoenix/server/api/routers/v1/sessions.py +108 -0
- phoenix/server/api/routers/v1/spans.py +1091 -81
- phoenix/server/api/routers/v1/traces.py +132 -78
- phoenix/server/api/routers/v1/users.py +389 -0
- phoenix/server/api/routers/v1/utils.py +3 -7
- phoenix/server/api/subscriptions.py +305 -88
- phoenix/server/api/types/Annotation.py +90 -23
- phoenix/server/api/types/ApiKey.py +13 -17
- phoenix/server/api/types/AuthMethod.py +1 -0
- phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +1 -0
- phoenix/server/api/types/CostBreakdown.py +12 -0
- phoenix/server/api/types/Dataset.py +226 -72
- phoenix/server/api/types/DatasetExample.py +88 -18
- phoenix/server/api/types/DatasetExperimentAnnotationSummary.py +10 -0
- phoenix/server/api/types/DatasetLabel.py +57 -0
- phoenix/server/api/types/DatasetSplit.py +98 -0
- phoenix/server/api/types/DatasetVersion.py +49 -4
- phoenix/server/api/types/DocumentAnnotation.py +212 -0
- phoenix/server/api/types/Experiment.py +264 -59
- phoenix/server/api/types/ExperimentComparison.py +5 -10
- phoenix/server/api/types/ExperimentRepeatedRunGroup.py +155 -0
- phoenix/server/api/types/ExperimentRepeatedRunGroupAnnotationSummary.py +9 -0
- phoenix/server/api/types/ExperimentRun.py +169 -65
- phoenix/server/api/types/ExperimentRunAnnotation.py +158 -39
- phoenix/server/api/types/GenerativeModel.py +245 -3
- phoenix/server/api/types/GenerativeProvider.py +70 -11
- phoenix/server/api/types/{Model.py → InferenceModel.py} +1 -1
- phoenix/server/api/types/ModelInterface.py +16 -0
- phoenix/server/api/types/PlaygroundModel.py +20 -0
- phoenix/server/api/types/Project.py +1278 -216
- phoenix/server/api/types/ProjectSession.py +188 -28
- phoenix/server/api/types/ProjectSessionAnnotation.py +187 -0
- phoenix/server/api/types/ProjectTraceRetentionPolicy.py +1 -1
- phoenix/server/api/types/Prompt.py +119 -39
- phoenix/server/api/types/PromptLabel.py +42 -25
- phoenix/server/api/types/PromptVersion.py +11 -8
- phoenix/server/api/types/PromptVersionTag.py +65 -25
- phoenix/server/api/types/ServerStatus.py +6 -0
- phoenix/server/api/types/Span.py +167 -123
- phoenix/server/api/types/SpanAnnotation.py +189 -42
- phoenix/server/api/types/SpanCostDetailSummaryEntry.py +10 -0
- phoenix/server/api/types/SpanCostSummary.py +10 -0
- phoenix/server/api/types/SystemApiKey.py +65 -1
- phoenix/server/api/types/TokenPrice.py +16 -0
- phoenix/server/api/types/TokenUsage.py +3 -3
- phoenix/server/api/types/Trace.py +223 -51
- phoenix/server/api/types/TraceAnnotation.py +149 -50
- phoenix/server/api/types/User.py +137 -32
- phoenix/server/api/types/UserApiKey.py +73 -26
- phoenix/server/api/types/node.py +10 -0
- phoenix/server/api/types/pagination.py +11 -2
- phoenix/server/app.py +290 -45
- phoenix/server/authorization.py +38 -3
- phoenix/server/bearer_auth.py +34 -24
- phoenix/server/cost_tracking/cost_details_calculator.py +196 -0
- phoenix/server/cost_tracking/cost_model_lookup.py +179 -0
- phoenix/server/cost_tracking/helpers.py +68 -0
- phoenix/server/cost_tracking/model_cost_manifest.json +3657 -830
- phoenix/server/cost_tracking/regex_specificity.py +397 -0
- phoenix/server/cost_tracking/token_cost_calculator.py +57 -0
- phoenix/server/daemons/__init__.py +0 -0
- phoenix/server/daemons/db_disk_usage_monitor.py +214 -0
- phoenix/server/daemons/generative_model_store.py +103 -0
- phoenix/server/daemons/span_cost_calculator.py +99 -0
- phoenix/server/dml_event.py +17 -0
- phoenix/server/dml_event_handler.py +5 -0
- phoenix/server/email/sender.py +56 -3
- phoenix/server/email/templates/db_disk_usage_notification.html +19 -0
- phoenix/server/email/types.py +11 -0
- phoenix/server/experiments/__init__.py +0 -0
- phoenix/server/experiments/utils.py +14 -0
- phoenix/server/grpc_server.py +11 -11
- phoenix/server/jwt_store.py +17 -15
- phoenix/server/ldap.py +1449 -0
- phoenix/server/main.py +26 -10
- phoenix/server/oauth2.py +330 -12
- phoenix/server/prometheus.py +66 -6
- phoenix/server/rate_limiters.py +4 -9
- phoenix/server/retention.py +33 -20
- phoenix/server/session_filters.py +49 -0
- phoenix/server/static/.vite/manifest.json +55 -51
- phoenix/server/static/assets/components-BreFUQQa.js +6702 -0
- phoenix/server/static/assets/{index-E0M82BdE.js → index-CTQoemZv.js} +140 -56
- phoenix/server/static/assets/pages-DBE5iYM3.js +9524 -0
- phoenix/server/static/assets/vendor-BGzfc4EU.css +1 -0
- phoenix/server/static/assets/vendor-DCE4v-Ot.js +920 -0
- phoenix/server/static/assets/vendor-codemirror-D5f205eT.js +25 -0
- phoenix/server/static/assets/vendor-recharts-V9cwpXsm.js +37 -0
- phoenix/server/static/assets/vendor-shiki-Do--csgv.js +5 -0
- phoenix/server/static/assets/vendor-three-CmB8bl_y.js +3840 -0
- phoenix/server/templates/index.html +40 -6
- phoenix/server/thread_server.py +1 -2
- phoenix/server/types.py +14 -4
- phoenix/server/utils.py +74 -0
- phoenix/session/client.py +56 -3
- phoenix/session/data_extractor.py +5 -0
- phoenix/session/evaluation.py +14 -5
- phoenix/session/session.py +45 -9
- phoenix/settings.py +5 -0
- phoenix/trace/attributes.py +80 -13
- phoenix/trace/dsl/helpers.py +90 -1
- phoenix/trace/dsl/query.py +8 -6
- phoenix/trace/projects.py +5 -0
- phoenix/utilities/template_formatters.py +1 -1
- phoenix/version.py +1 -1
- arize_phoenix-10.0.4.dist-info/RECORD +0 -405
- phoenix/server/api/types/Evaluation.py +0 -39
- phoenix/server/cost_tracking/cost_lookup.py +0 -255
- phoenix/server/static/assets/components-DULKeDfL.js +0 -4365
- phoenix/server/static/assets/pages-Cl0A-0U2.js +0 -7430
- phoenix/server/static/assets/vendor-WIZid84E.css +0 -1
- phoenix/server/static/assets/vendor-arizeai-Dy-0mSNw.js +0 -649
- phoenix/server/static/assets/vendor-codemirror-DBtifKNr.js +0 -33
- phoenix/server/static/assets/vendor-oB4u9zuV.js +0 -905
- phoenix/server/static/assets/vendor-recharts-D-T4KPz2.js +0 -59
- phoenix/server/static/assets/vendor-shiki-BMn4O_9F.js +0 -5
- phoenix/server/static/assets/vendor-three-C5WAXd5r.js +0 -2998
- phoenix/utilities/deprecation.py +0 -31
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/entry_points.txt +0 -0
- {arize_phoenix-10.0.4.dist-info → arize_phoenix-12.28.1.dist-info}/licenses/LICENSE +0 -0
phoenix/server/api/routers/v1/datasets.py

@@ -15,21 +15,16 @@ from typing import Any, Optional, Union, cast

 import pandas as pd
 import pyarrow as pa
-from fastapi import APIRouter, BackgroundTasks, HTTPException, Path, Query
+from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Path, Query
 from fastapi.responses import PlainTextResponse, StreamingResponse
-from sqlalchemy import and_, delete, func, select
+from sqlalchemy import and_, case, delete, func, select
 from sqlalchemy.ext.asyncio import AsyncSession
 from starlette.concurrency import run_in_threadpool
 from starlette.datastructures import FormData, UploadFile
 from starlette.requests import Request
 from starlette.responses import Response
 from starlette.status import (
-    HTTP_200_OK,
-    HTTP_204_NO_CONTENT,
     HTTP_404_NOT_FOUND,
-    HTTP_409_CONFLICT,
-    HTTP_422_UNPROCESSABLE_ENTITY,
-    HTTP_429_TOO_MANY_REQUESTS,
 )
 from strawberry.relay import GlobalID
 from typing_extensions import TypeAlias, assert_never

@@ -42,11 +37,15 @@ from phoenix.db.insertion.dataset import (
     ExampleContent,
     add_dataset_examples,
 )
+from phoenix.db.types.db_models import UNDEFINED
 from phoenix.server.api.types.Dataset import Dataset as DatasetNodeType
 from phoenix.server.api.types.DatasetExample import DatasetExample as DatasetExampleNodeType
+from phoenix.server.api.types.DatasetSplit import DatasetSplit as DatasetSplitNodeType
 from phoenix.server.api.types.DatasetVersion import DatasetVersion as DatasetVersionNodeType
 from phoenix.server.api.types.node import from_global_id_with_expected_type
 from phoenix.server.api.utils import delete_projects, delete_traces
+from phoenix.server.authorization import is_not_locked
+from phoenix.server.bearer_auth import PhoenixUser
 from phoenix.server.dml_event import DatasetInsertEvent

 from .models import V1RoutesBaseModel

@@ -57,6 +56,11 @@ from .utils import (
     add_text_csv_content_to_responses,
 )

+csv.field_size_limit(
+    1_000_000_000  # allows large field sizes for CSV upload (1GB)
+)
+
+
 logger = logging.getLogger(__name__)

 DATASET_NODE_NAME = DatasetNodeType.__name__

@@ -73,6 +77,7 @@ class Dataset(V1RoutesBaseModel):
     metadata: dict[str, Any]
     created_at: datetime
     updated_at: datetime
+    example_count: int


 class ListDatasetsResponseBody(PaginatedResponseBody[Dataset]):

@@ -83,7 +88,7 @@ class ListDatasetsResponseBody(PaginatedResponseBody[Dataset]):
     "/datasets",
     operation_id="listDatasets",
     summary="List datasets",
-    responses=add_errors_to_responses([
+    responses=add_errors_to_responses([422]),
 )
 async def list_datasets(
     request: Request,

@@ -97,7 +102,18 @@ async def list_datasets(
     ),
 ) -> ListDatasetsResponseBody:
     async with request.app.state.db() as session:
-
+        value = case(
+            (models.DatasetExampleRevision.revision_kind == "CREATE", 1),
+            (models.DatasetExampleRevision.revision_kind == "DELETE", -1),
+        )
+        query = (
+            select(models.Dataset)
+            .add_columns(func.coalesce(func.sum(value), 0).label("example_count"))
+            .outerjoin_from(models.Dataset, models.DatasetExample)
+            .outerjoin_from(models.DatasetExample, models.DatasetExampleRevision)
+            .group_by(models.Dataset.id)
+            .order_by(models.Dataset.id.desc())
+        )

         if cursor:
             try:

@@ -106,25 +122,26 @@ async def list_datasets(
             except ValueError:
                 raise HTTPException(
                     detail=f"Invalid cursor format: {cursor}",
-                    status_code=
+                    status_code=422,
                 )
         if name:
             query = query.filter(models.Dataset.name == name)

         query = query.limit(limit + 1)
         result = await session.execute(query)
-        datasets = result.
-
+        datasets = result.all()
         if not datasets:
             return ListDatasetsResponseBody(next_cursor=None, data=[])

         next_cursor = None
         if len(datasets) == limit + 1:
-
+            dataset = datasets[-1][0]
+            next_cursor = str(GlobalID(DATASET_NODE_NAME, str(dataset.id)))
            datasets = datasets[:-1]

         data = []
-        for
+        for row in datasets:
+            dataset = row[0]
             data.append(
                 Dataset(
                     id=str(GlobalID(DATASET_NODE_NAME, str(dataset.id))),

@@ -133,6 +150,7 @@ async def list_datasets(
                     metadata=dataset.metadata_,
                     created_at=dataset.created_at,
                     updated_at=dataset.updated_at,
+                    example_count=row[1],
                 )
             )

@@ -143,11 +161,11 @@ async def list_datasets(
     "/datasets/{id}",
     operation_id="deleteDatasetById",
     summary="Delete dataset by ID",
-    status_code=
+    status_code=204,
     responses=add_errors_to_responses(
         [
-            {"status_code":
-            {"status_code":
+            {"status_code": 404, "description": "Dataset not found"},
+            {"status_code": 422, "description": "Invalid dataset ID"},
         ]
     ),
 )

@@ -161,11 +179,9 @@ async def delete_dataset(
             DATASET_NODE_NAME,
         )
     except ValueError:
-        raise HTTPException(
-            detail=f"Invalid Dataset ID: {id}", status_code=HTTP_422_UNPROCESSABLE_ENTITY
-        )
+        raise HTTPException(detail=f"Invalid Dataset ID: {id}", status_code=422)
     else:
-        raise HTTPException(detail="Missing Dataset ID", status_code=
+        raise HTTPException(detail="Missing Dataset ID", status_code=422)
     project_names_stmt = get_project_names_for_datasets(dataset_id)
     eval_trace_ids_stmt = get_eval_trace_ids_for_datasets(dataset_id)
     stmt = (

@@ -175,7 +191,7 @@ async def delete_dataset(
         project_names = await session.scalars(project_names_stmt)
         eval_trace_ids = await session.scalars(eval_trace_ids_stmt)
         if (await session.scalar(stmt)) is None:
-            raise HTTPException(detail="Dataset does not exist", status_code=
+            raise HTTPException(detail="Dataset does not exist", status_code=404)
         tasks = BackgroundTasks()
         tasks.add_task(delete_projects, request.app.state.db, *project_names)
         tasks.add_task(delete_traces, request.app.state.db, *eval_trace_ids)

@@ -193,17 +209,21 @@ class GetDatasetResponseBody(ResponseBody[DatasetWithExampleCount]):
     "/datasets/{id}",
     operation_id="getDataset",
     summary="Get dataset by ID",
-    responses=add_errors_to_responses([
+    responses=add_errors_to_responses([404]),
 )
 async def get_dataset(
     request: Request, id: str = Path(description="The ID of the dataset")
 ) -> GetDatasetResponseBody:
-
+    try:
+        dataset_id = GlobalID.from_id(id)
+    except Exception as e:
+        raise HTTPException(
+            detail=f"Invalid dataset ID format: {id}",
+            status_code=422,
+        ) from e

     if (type_name := dataset_id.type_name) != DATASET_NODE_NAME:
-        raise HTTPException(
-            detail=f"ID {dataset_id} refers to a f{type_name}", status_code=HTTP_404_NOT_FOUND
-        )
+        raise HTTPException(detail=f"ID {dataset_id} refers to a f{type_name}", status_code=404)
     async with request.app.state.db() as session:
         result = await session.execute(
             select(models.Dataset, models.Dataset.example_count).filter(

@@ -214,9 +234,7 @@ async def get_dataset(
         dataset = dataset_query[0] if dataset_query else None
         example_count = dataset_query[1] if dataset_query else 0
         if dataset is None:
-            raise HTTPException(
-                detail=f"Dataset with ID {dataset_id} not found", status_code=HTTP_404_NOT_FOUND
-            )
+            raise HTTPException(detail=f"Dataset with ID {dataset_id} not found", status_code=404)

         dataset = DatasetWithExampleCount(
             id=str(dataset_id),

@@ -245,7 +263,7 @@ class ListDatasetVersionsResponseBody(PaginatedResponseBody[DatasetVersion]):
     "/datasets/{id}/versions",
     operation_id="listDatasetVersionsByDatasetId",
     summary="List dataset versions",
-    responses=add_errors_to_responses([
+    responses=add_errors_to_responses([422]),
 )
 async def list_dataset_versions(
     request: Request,

@@ -267,12 +285,12 @@ async def list_dataset_versions(
         except ValueError:
             raise HTTPException(
                 detail=f"Invalid Dataset ID: {id}",
-                status_code=
+                status_code=422,
             )
     else:
         raise HTTPException(
             detail="Missing Dataset ID",
-            status_code=
+            status_code=422,
         )
     stmt = (
         select(models.DatasetVersion)

@@ -288,7 +306,7 @@ async def list_dataset_versions(
         except ValueError:
             raise HTTPException(
                 detail=f"Invalid cursor: {cursor}",
-                status_code=
+                status_code=422,
             )
         max_dataset_version_id = (
             select(models.DatasetVersion.id)

@@ -312,6 +330,7 @@ async def list_dataset_versions(

 class UploadDatasetData(V1RoutesBaseModel):
     dataset_id: str
+    version_id: str


 class UploadDatasetResponseBody(ResponseBody[UploadDatasetData]):

@@ -320,15 +339,16 @@ class UploadDatasetResponseBody(ResponseBody[UploadDatasetData]):

 @router.post(
     "/datasets/upload",
+    dependencies=[Depends(is_not_locked)],
     operation_id="uploadDataset",
-    summary="Upload dataset from JSON, CSV, or PyArrow",
+    summary="Upload dataset from JSON, JSONL, CSV, or PyArrow",
     responses=add_errors_to_responses(
         [
             {
-                "status_code":
+                "status_code": 409,
                 "description": "Dataset of the same name already exists",
             },
-            {"status_code":
+            {"status_code": 422, "description": "Invalid request body"},
         ]
     ),
     # FastAPI cannot generate the request body portion of the OpenAPI schema for

@@ -350,6 +370,17 @@ class UploadDatasetResponseBody(ResponseBody[UploadDatasetData]):
                     "inputs": {"type": "array", "items": {"type": "object"}},
                     "outputs": {"type": "array", "items": {"type": "object"}},
                     "metadata": {"type": "array", "items": {"type": "object"}},
+                    "splits": {
+                        "type": "array",
+                        "items": {
+                            "oneOf": [
+                                {"type": "string"},
+                                {"type": "array", "items": {"type": "string"}},
+                                {"type": "null"},
+                            ]
+                        },
+                        "description": "Split per example: string, string array, or null",
+                    },
                 },
             }
         },

@@ -376,6 +407,12 @@ class UploadDatasetResponseBody(ResponseBody[UploadDatasetData]):
                     "items": {"type": "string"},
                     "uniqueItems": True,
                 },
+                "split_keys[]": {
+                    "type": "array",
+                    "items": {"type": "string"},
+                    "uniqueItems": True,
+                    "description": "Column names for auto-assigning examples to splits",
+                },
                 "file": {"type": "string", "format": "binary"},
             },
         }

@@ -391,7 +428,12 @@ async def upload_dataset(
         description="If true, fulfill request synchronously and return JSON containing dataset_id.",
     ),
 ) -> Optional[UploadDatasetResponseBody]:
-    request_content_type = request.headers
+    request_content_type = request.headers.get("content-type")
+    if not request_content_type:
+        raise HTTPException(
+            detail="Missing content-type header",
+            status_code=400,
+        )
     examples: Union[Examples, Awaitable[Examples]]
     if request_content_type.startswith("application/json"):
         try:

@@ -401,14 +443,14 @@ async def upload_dataset(
         except ValueError as e:
             raise HTTPException(
                 detail=str(e),
-                status_code=
+                status_code=422,
             )
         if action is DatasetAction.CREATE:
             async with request.app.state.db() as session:
                 if await _check_table_exists(session, name):
                     raise HTTPException(
                         detail=f"Dataset with the same name already exists: {name=}",
-                        status_code=
+                        status_code=409,
                     )
     elif request_content_type.startswith("multipart/form-data"):
         async with request.form() as form:

@@ -420,19 +462,20 @@ async def upload_dataset(
                     input_keys,
                     output_keys,
                     metadata_keys,
+                    split_keys,
                     file,
                 ) = await _parse_form_data(form)
             except ValueError as e:
                 raise HTTPException(
                     detail=str(e),
-                    status_code=
+                    status_code=422,
                 )
             if action is DatasetAction.CREATE:
                 async with request.app.state.db() as session:
                     if await _check_table_exists(session, name):
                         raise HTTPException(
                             detail=f"Dataset with the same name already exists: {name=}",
-                            status_code=
+                            status_code=409,
                         )
             content = await file.read()
             try:

@@ -440,22 +483,32 @@ async def upload_dataset(
                 if file_content_type is FileContentType.CSV:
                     encoding = FileContentEncoding(file.headers.get("content-encoding"))
                     examples = await _process_csv(
-                        content, encoding, input_keys, output_keys, metadata_keys
+                        content, encoding, input_keys, output_keys, metadata_keys, split_keys
                     )
                 elif file_content_type is FileContentType.PYARROW:
-                    examples = await _process_pyarrow(
+                    examples = await _process_pyarrow(
+                        content, input_keys, output_keys, metadata_keys, split_keys
+                    )
+                elif file_content_type is FileContentType.JSONL:
+                    encoding = FileContentEncoding(file.headers.get("content-encoding"))
+                    examples = await _process_jsonl(
+                        content, encoding, input_keys, output_keys, metadata_keys, split_keys
+                    )
                 else:
                     assert_never(file_content_type)
             except ValueError as e:
                 raise HTTPException(
                     detail=str(e),
-                    status_code=
+                    status_code=422,
                 )
     else:
         raise HTTPException(
             detail="Invalid request Content-Type",
-            status_code=
+            status_code=422,
         )
+    user_id: Optional[int] = None
+    if request.app.state.authentication_enabled and isinstance(request.user, PhoenixUser):
+        user_id = int(request.user.identity)
     operation = cast(
         Callable[[AsyncSession], Awaitable[DatasetExampleAdditionEvent]],
         partial(

@@ -464,27 +517,34 @@ async def upload_dataset(
             action=action,
             name=name,
             description=description,
+            user_id=user_id,
         ),
     )
     if sync:
         async with request.app.state.db() as session:
-
+            event = await operation(session)
+            dataset_id = event.dataset_id
+            version_id = event.dataset_version_id
             request.state.event_queue.put(DatasetInsertEvent((dataset_id,)))
             return UploadDatasetResponseBody(
-                data=UploadDatasetData(
+                data=UploadDatasetData(
+                    dataset_id=str(GlobalID(Dataset.__name__, str(dataset_id))),
+                    version_id=str(GlobalID(DatasetVersion.__name__, str(version_id))),
+                )
             )
     try:
         request.state.enqueue_operation(operation)
     except QueueFull:
         if isinstance(examples, Coroutine):
             examples.close()
-        raise HTTPException(detail="Too many requests.", status_code=
+        raise HTTPException(detail="Too many requests.", status_code=429)
     return None


 class FileContentType(Enum):
     CSV = "text/csv"
     PYARROW = "application/x-pandas-pyarrow"
+    JSONL = "application/jsonl"

     @classmethod
     def _missing_(cls, v: Any) -> "FileContentType":

@@ -512,6 +572,7 @@ Description: TypeAlias = Optional[str]
 InputKeys: TypeAlias = frozenset[str]
 OutputKeys: TypeAlias = frozenset[str]
 MetadataKeys: TypeAlias = frozenset[str]
+SplitKeys: TypeAlias = frozenset[str]
 DatasetId: TypeAlias = int
 Examples: TypeAlias = Iterator[ExampleContent]

@@ -528,18 +589,55 @@ def _process_json(
         raise ValueError("input is required")
     if not isinstance(inputs, list) or not _is_all_dict(inputs):
         raise ValueError("Input should be a list containing only dictionary objects")
-    outputs, metadata = data.get("outputs"), data.get("metadata")
+    outputs, metadata, splits = data.get("outputs"), data.get("metadata"), data.get("splits")
     for k, v in {"outputs": outputs, "metadata": metadata}.items():
         if v and not (isinstance(v, list) and len(v) == len(inputs) and _is_all_dict(v)):
             raise ValueError(
                 f"{k} should be a list of same length as input containing only dictionary objects"
             )
+
+    # Validate splits format if provided
+    if splits is not None:
+        if not isinstance(splits, list):
+            raise ValueError("splits must be a list")
+        if len(splits) != len(inputs):
+            raise ValueError(
+                f"splits must have same length as inputs ({len(splits)} != {len(inputs)})"
+            )
     examples: list[ExampleContent] = []
     for i, obj in enumerate(inputs):
+        # Extract split values, validating they're non-empty strings
+        split_set: set[str] = set()
+        if splits:
+            split_value = splits[i]
+            if split_value is None:
+                # Sparse assignment: None means no splits for this example
+                pass
+            elif isinstance(split_value, str):
+                # Format 1: Single string value
+                if split_value.strip():
+                    split_set.add(split_value.strip())
+            elif isinstance(split_value, list):
+                # Format 2: List of strings (multiple splits)
+                for v in split_value:
+                    if v is None:
+                        continue  # Skip None values in the list
+                    if not isinstance(v, str):
+                        raise ValueError(
+                            f"Split value must be a string or None, got {type(v).__name__}"
+                        )
+                    if v.strip():
+                        split_set.add(v.strip())
+            else:
+                raise ValueError(
+                    f"Split value must be a string, list of strings, or None, "
+                    f"got {type(split_value).__name__}"
+                )
         example = ExampleContent(
             input=obj,
             output=outputs[i] if outputs else {},
             metadata=metadata[i] if metadata else {},
+            splits=frozenset(split_set),
         )
         examples.append(example)
     action = DatasetAction(cast(Optional[str], data.get("action")) or "create")

@@ -552,6 +650,7 @@ async def _process_csv(
     input_keys: InputKeys,
     output_keys: OutputKeys,
     metadata_keys: MetadataKeys,
+    split_keys: SplitKeys,
 ) -> Examples:
     if content_encoding is FileContentEncoding.GZIP:
         content = await run_in_threadpool(gzip.decompress, content)

@@ -566,12 +665,15 @@ async def _process_csv(
         if freq > 1:
             raise ValueError(f"Duplicated column header in CSV file: {header}")
     column_headers = frozenset(reader.fieldnames)
-    _check_keys_exist(column_headers, input_keys, output_keys, metadata_keys)
+    _check_keys_exist(column_headers, input_keys, output_keys, metadata_keys, split_keys)
     return (
         ExampleContent(
             input={k: row.get(k) for k in input_keys},
             output={k: row.get(k) for k in output_keys},
             metadata={k: row.get(k) for k in metadata_keys},
+            splits=frozenset(
+                str(v).strip() for k in split_keys if (v := row.get(k)) and str(v).strip()
+            ),  # Only include non-empty, non-whitespace split values
         )
         for row in iter(reader)
     )

@@ -582,13 +684,14 @@ async def _process_pyarrow(
     input_keys: InputKeys,
     output_keys: OutputKeys,
     metadata_keys: MetadataKeys,
+    split_keys: SplitKeys,
 ) -> Awaitable[Examples]:
     try:
         reader = pa.ipc.open_stream(content)
     except pa.ArrowInvalid as e:
         raise ValueError("File is not valid pyarrow") from e
     column_headers = frozenset(reader.schema.names)
-    _check_keys_exist(column_headers, input_keys, output_keys, metadata_keys)
+    _check_keys_exist(column_headers, input_keys, output_keys, metadata_keys, split_keys)

     def get_examples() -> Iterator[ExampleContent]:
         for row in reader.read_pandas().to_dict(orient="records"):

@@ -596,11 +699,48 @@ async def _process_pyarrow(
                 input={k: row.get(k) for k in input_keys},
                 output={k: row.get(k) for k in output_keys},
                 metadata={k: row.get(k) for k in metadata_keys},
+                splits=frozenset(
+                    str(v).strip() for k in split_keys if (v := row.get(k)) and str(v).strip()
+                ),  # Only include non-empty, non-whitespace split values
             )

     return run_in_threadpool(get_examples)


+async def _process_jsonl(
+    content: bytes,
+    encoding: FileContentEncoding,
+    input_keys: InputKeys,
+    output_keys: OutputKeys,
+    metadata_keys: MetadataKeys,
+    split_keys: SplitKeys,
+) -> Examples:
+    if encoding is FileContentEncoding.GZIP:
+        content = await run_in_threadpool(gzip.decompress, content)
+    elif encoding is FileContentEncoding.DEFLATE:
+        content = await run_in_threadpool(zlib.decompress, content)
+    elif encoding is not FileContentEncoding.NONE:
+        assert_never(encoding)
+    # content is a newline delimited list of JSON objects
+    # parse within a threadpool
+    reader = await run_in_threadpool(
+        lambda c: [json.loads(line) for line in c.decode().splitlines()], content
+    )
+
+    examples: list[ExampleContent] = []
+    for obj in reader:
+        example = ExampleContent(
+            input={k: obj.get(k) for k in input_keys},
+            output={k: obj.get(k) for k in output_keys},
+            metadata={k: obj.get(k) for k in metadata_keys},
+            splits=frozenset(
+                str(v).strip() for k in split_keys if (v := obj.get(k)) and str(v).strip()
+            ),  # Only include non-empty, non-whitespace split values
+        )
+        examples.append(example)
+    return iter(examples)
+
+
 async def _check_table_exists(session: AsyncSession, name: str) -> bool:
     return bool(
         await session.scalar(

@@ -614,11 +754,13 @@ def _check_keys_exist(
     input_keys: InputKeys,
     output_keys: OutputKeys,
     metadata_keys: MetadataKeys,
+    split_keys: SplitKeys,
 ) -> None:
     for desc, keys in (
         ("input", input_keys),
         ("output", output_keys),
         ("metadata", metadata_keys),
+        ("split", split_keys),
     ):
         if keys and (diff := keys.difference(column_headers)):
             raise ValueError(f"{desc} keys not found in column headers: {diff}")

@@ -633,6 +775,7 @@ async def _parse_form_data(
     InputKeys,
     OutputKeys,
     MetadataKeys,
+    SplitKeys,
     UploadFile,
 ]:
     name = cast(Optional[str], form.get("name"))

@@ -646,6 +789,7 @@ async def _parse_form_data(
     input_keys = frozenset(filter(bool, cast(list[str], form.getlist("input_keys[]"))))
     output_keys = frozenset(filter(bool, cast(list[str], form.getlist("output_keys[]"))))
     metadata_keys = frozenset(filter(bool, cast(list[str], form.getlist("metadata_keys[]"))))
+    split_keys = frozenset(filter(bool, cast(list[str], form.getlist("split_keys[]"))))
     return (
         action,
         name,

@@ -653,6 +797,7 @@ async def _parse_form_data(
         input_keys,
         output_keys,
         metadata_keys,
+        split_keys,
         file,
     )

@@ -668,6 +813,7 @@ class DatasetExample(V1RoutesBaseModel):
 class ListDatasetExamplesData(V1RoutesBaseModel):
     dataset_id: str
     version_id: str
+    filtered_splits: list[str] = UNDEFINED
     examples: list[DatasetExample]


@@ -679,7 +825,7 @@ class ListDatasetExamplesResponseBody(ResponseBody[ListDatasetExamplesData]):
     "/datasets/{id}/examples",
     operation_id="getDatasetExamples",
     summary="Get examples from a dataset",
-    responses=add_errors_to_responses([
+    responses=add_errors_to_responses([404]),
 )
 async def get_dataset_examples(
     request: Request,

@@ -687,22 +833,38 @@ async def get_dataset_examples(
     version_id: Optional[str] = Query(
         default=None,
         description=(
-            "The ID of the dataset version
+            "The ID of the dataset version (if omitted, returns data from the latest version)"
         ),
     ),
+    split: Optional[list[str]] = Query(
+        default=None,
+        description="List of dataset split identifiers (GlobalIDs or names) to filter by",
+    ),
 ) -> ListDatasetExamplesResponseBody:
-
-
+    try:
+        dataset_gid = GlobalID.from_id(id)
+    except Exception as e:
+        raise HTTPException(
+            detail=f"Invalid dataset ID format: {id}",
+            status_code=422,
+        ) from e
+
+    if version_id:
+        try:
+            version_gid = GlobalID.from_id(version_id)
+        except Exception as e:
+            raise HTTPException(
+                detail=f"Invalid dataset version ID format: {version_id}",
+                status_code=422,
+            ) from e
+    else:
+        version_gid = None

     if (dataset_type := dataset_gid.type_name) != "Dataset":
-        raise HTTPException(
-            detail=f"ID {dataset_gid} refers to a {dataset_type}", status_code=HTTP_404_NOT_FOUND
-        )
+        raise HTTPException(detail=f"ID {dataset_gid} refers to a {dataset_type}", status_code=404)

     if version_gid and (version_type := version_gid.type_name) != "DatasetVersion":
-        raise HTTPException(
-            detail=f"ID {version_gid} refers to a {version_type}", status_code=HTTP_404_NOT_FOUND
-        )
+        raise HTTPException(detail=f"ID {version_gid} refers to a {version_type}", status_code=404)

     async with request.app.state.db() as session:
         if (

@@ -712,7 +874,7 @@ async def get_dataset_examples(
         ) is None:
             raise HTTPException(
                 detail=f"No dataset with id {dataset_gid} can be found.",
-                status_code=
+                status_code=404,
             )

         # Subquery to find the maximum created_at for each dataset_example_id

@@ -734,7 +896,7 @@ async def get_dataset_examples(
             ) is None:
                 raise HTTPException(
                     detail=f"No dataset version with id {version_id} can be found.",
-                    status_code=
+                    status_code=404,
                 )
             # if a version_id is provided, filter the subquery to only include revisions from that
             partial_subquery = partial_subquery.filter(

@@ -750,13 +912,17 @@ async def get_dataset_examples(
             ) is None:
                 raise HTTPException(
                     detail="Dataset has no versions.",
-                    status_code=
+                    status_code=404,
                 )

         subquery = partial_subquery.subquery()
+
         # Query for the most recent example revisions that are not deleted
         query = (
-            select(
+            select(
+                models.DatasetExample,
+                models.DatasetExampleRevision,
+            )
             .join(
                 models.DatasetExampleRevision,
                 models.DatasetExample.id == models.DatasetExampleRevision.dataset_example_id,

@@ -769,6 +935,28 @@ async def get_dataset_examples(
             .filter(models.DatasetExampleRevision.revision_kind != "DELETE")
             .order_by(models.DatasetExample.id.asc())
         )
+
+        # If splits are provided, filter by dataset splits
+        resolved_split_names: list[str] = []
+        if split:
+            # Resolve split identifiers (IDs or names) to IDs and names
+            resolved_split_ids, resolved_split_names = await _resolve_split_identifiers(
+                session, split
+            )
+
+            # Add filter for splits (join with the association table)
+            # Use distinct() to prevent duplicates when an example belongs to
+            # multiple splits
+            query = (
+                query.join(
+                    models.DatasetSplitDatasetExample,
+                    models.DatasetExample.id
+                    == models.DatasetSplitDatasetExample.dataset_example_id,
+                )
+                .filter(models.DatasetSplitDatasetExample.dataset_split_id.in_(resolved_split_ids))
+                .distinct()
+            )
+
         examples = [
             DatasetExample(
                 id=str(GlobalID("DatasetExample", str(example.id))),

@@ -783,6 +971,7 @@ async def get_dataset_examples(
         data=ListDatasetExamplesData(
             dataset_id=str(GlobalID("Dataset", str(resolved_dataset_id))),
             version_id=str(GlobalID("DatasetVersion", str(resolved_version_id))),
+            filtered_splits=resolved_split_names,
             examples=examples,
         )
     )

@@ -793,10 +982,10 @@ async def get_dataset_examples(
     operation_id="getDatasetCsv",
     summary="Download dataset examples as CSV file",
     response_class=StreamingResponse,
-    status_code=
+    status_code=200,
     responses={
-        **add_errors_to_responses([
-        **add_text_csv_content_to_responses(
+        **add_errors_to_responses([422]),
+        **add_text_csv_content_to_responses(200),
     },
 )
 async def get_dataset_csv(

@@ -806,7 +995,7 @@ async def get_dataset_csv(
     version_id: Optional[str] = Query(
         default=None,
         description=(
-            "The ID of the dataset version
+            "The ID of the dataset version (if omitted, returns data from the latest version)"
         ),
     ),
 ) -> Response:

@@ -816,7 +1005,7 @@ async def get_dataset_csv(
             session=session, id=id, version_id=version_id
         )
     except ValueError as e:
-        raise HTTPException(detail=str(e), status_code=
+        raise HTTPException(detail=str(e), status_code=422)
     content = await run_in_threadpool(_get_content_csv, examples)
     encoded_dataset_name = urllib.parse.quote(dataset_name)
     return Response(

@@ -836,7 +1025,7 @@ async def get_dataset_csv(
     responses=add_errors_to_responses(
         [
             {
-                "status_code":
+                "status_code": 422,
                 "description": "Invalid dataset or version ID",
             }
         ]

@@ -849,7 +1038,7 @@ async def get_dataset_jsonl_openai_ft(
     version_id: Optional[str] = Query(
         default=None,
         description=(
-            "The ID of the dataset version
+            "The ID of the dataset version (if omitted, returns data from the latest version)"
        ),
    ),
 ) -> bytes:

@@ -859,7 +1048,7 @@ async def get_dataset_jsonl_openai_ft(
             session=session, id=id, version_id=version_id
         )
     except ValueError as e:
-        raise HTTPException(detail=str(e), status_code=
+        raise HTTPException(detail=str(e), status_code=422)
     content = await run_in_threadpool(_get_content_jsonl_openai_ft, examples)
     encoded_dataset_name = urllib.parse.quote(dataset_name)
     response.headers["content-disposition"] = (

@@ -876,7 +1065,7 @@ async def get_dataset_jsonl_openai_ft(
     responses=add_errors_to_responses(
         [
             {
-                "status_code":
+                "status_code": 422,
                 "description": "Invalid dataset or version ID",
             }
         ]

@@ -889,7 +1078,7 @@ async def get_dataset_jsonl_openai_evals(
     version_id: Optional[str] = Query(
         default=None,
         description=(
-            "The ID of the dataset version
+            "The ID of the dataset version (if omitted, returns data from the latest version)"
         ),
     ),
 ) -> bytes:

@@ -899,7 +1088,7 @@ async def get_dataset_jsonl_openai_evals(
             session=session, id=id, version_id=version_id
         )
     except ValueError as e:
-        raise HTTPException(detail=str(e), status_code=
+        raise HTTPException(detail=str(e), status_code=422)
     content = await run_in_threadpool(_get_content_jsonl_openai_evals, examples)
     encoded_dataset_name = urllib.parse.quote(dataset_name)
     response.headers["content-disposition"] = (

@@ -978,12 +1167,25 @@ def _get_content_jsonl_openai_evals(examples: list[models.DatasetExampleRevision
 async def _get_db_examples(
     *, session: Any, id: str, version_id: Optional[str]
 ) -> tuple[str, list[models.DatasetExampleRevision]]:
-
+    try:
+        dataset_id = from_global_id_with_expected_type(GlobalID.from_id(id), DATASET_NODE_NAME)
+    except Exception as e:
+        raise HTTPException(
+            detail=f"Invalid dataset ID format: {id}",
+            status_code=422,
+        ) from e
+
     dataset_version_id: Optional[int] = None
     if version_id:
-
-
-
+        try:
+            dataset_version_id = from_global_id_with_expected_type(
+                GlobalID.from_id(version_id), DATASET_VERSION_NODE_NAME
+            )
+        except Exception as e:
+            raise HTTPException(
+                detail=f"Invalid dataset version ID format: {version_id}",
+                status_code=422,
+            ) from e
     latest_version = (
         select(
             models.DatasetExampleRevision.dataset_example_id,

@@ -1026,3 +1228,115 @@ async def _get_db_examples(

 def _is_all_dict(seq: Sequence[Any]) -> bool:
     return all(map(lambda obj: isinstance(obj, dict), seq))
+
+
+# Split identifier helper types and functions
+class _SplitId(int): ...
+
+
+_SplitIdentifier: TypeAlias = Union[_SplitId, str]
+
+
+def _parse_split_identifier(split_identifier: str) -> _SplitIdentifier:
+    """
+    Parse a split identifier as either a GlobalID or a name.
+
+    Args:
+        split_identifier: The identifier string (GlobalID or name)
+
+    Returns:
+        Either a _SplitId or an Identifier
+
+    Raises:
+        HTTPException: If the identifier format is invalid
+    """
+    if not split_identifier:
+        raise HTTPException(422, "Invalid split identifier")
+    try:
+        split_id = from_global_id_with_expected_type(
+            GlobalID.from_id(split_identifier),
+            DatasetSplitNodeType.__name__,
+        )
+    except ValueError:
+        return split_identifier
+    return _SplitId(split_id)
+
+
+async def _resolve_split_identifiers(
+    session: AsyncSession,
+    split_identifiers: list[str],
+) -> tuple[list[int], list[str]]:
+    """
+    Resolve a list of split identifiers (IDs or names) to split IDs and names.
+
+    Args:
+        session: The database session
+        split_identifiers: List of split identifiers (GlobalIDs or names)
+
+    Returns:
+        Tuple of (list of split IDs, list of split names)
+
+    Raises:
+        HTTPException: If any split identifier is invalid or not found
+    """
+    split_ids: list[int] = []
+    split_names: list[str] = []
+
+    # Parse all identifiers first
+    parsed_identifiers: list[_SplitIdentifier] = []
+    for identifier_str in split_identifiers:
+        parsed_identifiers.append(_parse_split_identifier(identifier_str.strip()))
+
+    # Separate IDs and names
+    requested_ids: list[int] = []
+    requested_names: list[str] = []
+    for identifier in parsed_identifiers:
+        if isinstance(identifier, _SplitId):
+            requested_ids.append(int(identifier))
+        elif isinstance(identifier, str):
+            requested_names.append(identifier)
+        else:
+            assert_never(identifier)
+
+    # Query for splits by ID
+    if requested_ids:
+        id_results = await session.stream(
+            select(models.DatasetSplit.id, models.DatasetSplit.name).where(
+                models.DatasetSplit.id.in_(requested_ids)
+            )
+        )
+        async for split_id, split_name in id_results:
+            split_ids.append(split_id)
+            split_names.append(split_name)
+
+        # Check if all requested IDs were found
+        found_ids = set(split_ids[-len(requested_ids) :] if requested_ids else [])
+        missing_ids = [sid for sid in requested_ids if sid not in found_ids]
+        if missing_ids:
+            raise HTTPException(
+                status_code=HTTP_404_NOT_FOUND,
+                detail=f"Dataset splits not found for IDs: {', '.join(map(str, missing_ids))}",
+            )

+    # Query for splits by name
+    if requested_names:
+        name_results = await session.stream(
+            select(models.DatasetSplit.id, models.DatasetSplit.name).where(
+                models.DatasetSplit.name.in_(requested_names)
+            )
+        )
+        name_to_id: dict[str, int] = {}
+        async for split_id, split_name in name_results:
+            split_ids.append(split_id)
+            split_names.append(split_name)
+            name_to_id[split_name] = split_id
+
+        # Check if all requested names were found
+        missing_names = [name for name in requested_names if name not in name_to_id]
+        if missing_names:
+            raise HTTPException(
+                status_code=HTTP_404_NOT_FOUND,
+                detail=f"Dataset splits not found: {', '.join(missing_names)}",
+            )
+
+    return split_ids, split_names