arize-phoenix 3.16.1__py3-none-any.whl → 7.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of arize-phoenix might be problematic.
- arize_phoenix-7.7.1.dist-info/METADATA +261 -0
- arize_phoenix-7.7.1.dist-info/RECORD +345 -0
- {arize_phoenix-3.16.1.dist-info → arize_phoenix-7.7.1.dist-info}/WHEEL +1 -1
- arize_phoenix-7.7.1.dist-info/entry_points.txt +3 -0
- phoenix/__init__.py +86 -14
- phoenix/auth.py +309 -0
- phoenix/config.py +675 -45
- phoenix/core/model.py +32 -30
- phoenix/core/model_schema.py +102 -109
- phoenix/core/model_schema_adapter.py +48 -45
- phoenix/datetime_utils.py +24 -3
- phoenix/db/README.md +54 -0
- phoenix/db/__init__.py +4 -0
- phoenix/db/alembic.ini +85 -0
- phoenix/db/bulk_inserter.py +294 -0
- phoenix/db/engines.py +208 -0
- phoenix/db/enums.py +20 -0
- phoenix/db/facilitator.py +113 -0
- phoenix/db/helpers.py +159 -0
- phoenix/db/insertion/constants.py +2 -0
- phoenix/db/insertion/dataset.py +227 -0
- phoenix/db/insertion/document_annotation.py +171 -0
- phoenix/db/insertion/evaluation.py +191 -0
- phoenix/db/insertion/helpers.py +98 -0
- phoenix/db/insertion/span.py +193 -0
- phoenix/db/insertion/span_annotation.py +158 -0
- phoenix/db/insertion/trace_annotation.py +158 -0
- phoenix/db/insertion/types.py +256 -0
- phoenix/db/migrate.py +86 -0
- phoenix/db/migrations/data_migration_scripts/populate_project_sessions.py +199 -0
- phoenix/db/migrations/env.py +114 -0
- phoenix/db/migrations/script.py.mako +26 -0
- phoenix/db/migrations/versions/10460e46d750_datasets.py +317 -0
- phoenix/db/migrations/versions/3be8647b87d8_add_token_columns_to_spans_table.py +126 -0
- phoenix/db/migrations/versions/4ded9e43755f_create_project_sessions_table.py +66 -0
- phoenix/db/migrations/versions/cd164e83824f_users_and_tokens.py +157 -0
- phoenix/db/migrations/versions/cf03bd6bae1d_init.py +280 -0
- phoenix/db/models.py +807 -0
- phoenix/exceptions.py +5 -1
- phoenix/experiments/__init__.py +6 -0
- phoenix/experiments/evaluators/__init__.py +29 -0
- phoenix/experiments/evaluators/base.py +158 -0
- phoenix/experiments/evaluators/code_evaluators.py +184 -0
- phoenix/experiments/evaluators/llm_evaluators.py +473 -0
- phoenix/experiments/evaluators/utils.py +236 -0
- phoenix/experiments/functions.py +772 -0
- phoenix/experiments/tracing.py +86 -0
- phoenix/experiments/types.py +726 -0
- phoenix/experiments/utils.py +25 -0
- phoenix/inferences/__init__.py +0 -0
- phoenix/{datasets → inferences}/errors.py +6 -5
- phoenix/{datasets → inferences}/fixtures.py +49 -42
- phoenix/{datasets/dataset.py → inferences/inferences.py} +121 -105
- phoenix/{datasets → inferences}/schema.py +11 -11
- phoenix/{datasets → inferences}/validation.py +13 -14
- phoenix/logging/__init__.py +3 -0
- phoenix/logging/_config.py +90 -0
- phoenix/logging/_filter.py +6 -0
- phoenix/logging/_formatter.py +69 -0
- phoenix/metrics/__init__.py +5 -4
- phoenix/metrics/binning.py +4 -3
- phoenix/metrics/metrics.py +2 -1
- phoenix/metrics/mixins.py +7 -6
- phoenix/metrics/retrieval_metrics.py +2 -1
- phoenix/metrics/timeseries.py +5 -4
- phoenix/metrics/wrappers.py +9 -3
- phoenix/pointcloud/clustering.py +5 -5
- phoenix/pointcloud/pointcloud.py +7 -5
- phoenix/pointcloud/projectors.py +5 -6
- phoenix/pointcloud/umap_parameters.py +53 -52
- phoenix/server/api/README.md +28 -0
- phoenix/server/api/auth.py +44 -0
- phoenix/server/api/context.py +152 -9
- phoenix/server/api/dataloaders/__init__.py +91 -0
- phoenix/server/api/dataloaders/annotation_summaries.py +139 -0
- phoenix/server/api/dataloaders/average_experiment_run_latency.py +54 -0
- phoenix/server/api/dataloaders/cache/__init__.py +3 -0
- phoenix/server/api/dataloaders/cache/two_tier_cache.py +68 -0
- phoenix/server/api/dataloaders/dataset_example_revisions.py +131 -0
- phoenix/server/api/dataloaders/dataset_example_spans.py +38 -0
- phoenix/server/api/dataloaders/document_evaluation_summaries.py +144 -0
- phoenix/server/api/dataloaders/document_evaluations.py +31 -0
- phoenix/server/api/dataloaders/document_retrieval_metrics.py +89 -0
- phoenix/server/api/dataloaders/experiment_annotation_summaries.py +79 -0
- phoenix/server/api/dataloaders/experiment_error_rates.py +58 -0
- phoenix/server/api/dataloaders/experiment_run_annotations.py +36 -0
- phoenix/server/api/dataloaders/experiment_run_counts.py +49 -0
- phoenix/server/api/dataloaders/experiment_sequence_number.py +44 -0
- phoenix/server/api/dataloaders/latency_ms_quantile.py +188 -0
- phoenix/server/api/dataloaders/min_start_or_max_end_times.py +85 -0
- phoenix/server/api/dataloaders/project_by_name.py +31 -0
- phoenix/server/api/dataloaders/record_counts.py +116 -0
- phoenix/server/api/dataloaders/session_io.py +79 -0
- phoenix/server/api/dataloaders/session_num_traces.py +30 -0
- phoenix/server/api/dataloaders/session_num_traces_with_error.py +32 -0
- phoenix/server/api/dataloaders/session_token_usages.py +41 -0
- phoenix/server/api/dataloaders/session_trace_latency_ms_quantile.py +55 -0
- phoenix/server/api/dataloaders/span_annotations.py +26 -0
- phoenix/server/api/dataloaders/span_dataset_examples.py +31 -0
- phoenix/server/api/dataloaders/span_descendants.py +57 -0
- phoenix/server/api/dataloaders/span_projects.py +33 -0
- phoenix/server/api/dataloaders/token_counts.py +124 -0
- phoenix/server/api/dataloaders/trace_by_trace_ids.py +25 -0
- phoenix/server/api/dataloaders/trace_root_spans.py +32 -0
- phoenix/server/api/dataloaders/user_roles.py +30 -0
- phoenix/server/api/dataloaders/users.py +33 -0
- phoenix/server/api/exceptions.py +48 -0
- phoenix/server/api/helpers/__init__.py +12 -0
- phoenix/server/api/helpers/dataset_helpers.py +217 -0
- phoenix/server/api/helpers/experiment_run_filters.py +763 -0
- phoenix/server/api/helpers/playground_clients.py +948 -0
- phoenix/server/api/helpers/playground_registry.py +70 -0
- phoenix/server/api/helpers/playground_spans.py +455 -0
- phoenix/server/api/input_types/AddExamplesToDatasetInput.py +16 -0
- phoenix/server/api/input_types/AddSpansToDatasetInput.py +14 -0
- phoenix/server/api/input_types/ChatCompletionInput.py +38 -0
- phoenix/server/api/input_types/ChatCompletionMessageInput.py +24 -0
- phoenix/server/api/input_types/ClearProjectInput.py +15 -0
- phoenix/server/api/input_types/ClusterInput.py +2 -2
- phoenix/server/api/input_types/CreateDatasetInput.py +12 -0
- phoenix/server/api/input_types/CreateSpanAnnotationInput.py +18 -0
- phoenix/server/api/input_types/CreateTraceAnnotationInput.py +18 -0
- phoenix/server/api/input_types/DataQualityMetricInput.py +5 -2
- phoenix/server/api/input_types/DatasetExampleInput.py +14 -0
- phoenix/server/api/input_types/DatasetSort.py +17 -0
- phoenix/server/api/input_types/DatasetVersionSort.py +16 -0
- phoenix/server/api/input_types/DeleteAnnotationsInput.py +7 -0
- phoenix/server/api/input_types/DeleteDatasetExamplesInput.py +13 -0
- phoenix/server/api/input_types/DeleteDatasetInput.py +7 -0
- phoenix/server/api/input_types/DeleteExperimentsInput.py +7 -0
- phoenix/server/api/input_types/DimensionFilter.py +4 -4
- phoenix/server/api/input_types/GenerativeModelInput.py +17 -0
- phoenix/server/api/input_types/Granularity.py +1 -1
- phoenix/server/api/input_types/InvocationParameters.py +162 -0
- phoenix/server/api/input_types/PatchAnnotationInput.py +19 -0
- phoenix/server/api/input_types/PatchDatasetExamplesInput.py +35 -0
- phoenix/server/api/input_types/PatchDatasetInput.py +14 -0
- phoenix/server/api/input_types/PerformanceMetricInput.py +5 -2
- phoenix/server/api/input_types/ProjectSessionSort.py +29 -0
- phoenix/server/api/input_types/SpanAnnotationSort.py +17 -0
- phoenix/server/api/input_types/SpanSort.py +134 -69
- phoenix/server/api/input_types/TemplateOptions.py +10 -0
- phoenix/server/api/input_types/TraceAnnotationSort.py +17 -0
- phoenix/server/api/input_types/UserRoleInput.py +9 -0
- phoenix/server/api/mutations/__init__.py +28 -0
- phoenix/server/api/mutations/api_key_mutations.py +167 -0
- phoenix/server/api/mutations/chat_mutations.py +593 -0
- phoenix/server/api/mutations/dataset_mutations.py +591 -0
- phoenix/server/api/mutations/experiment_mutations.py +75 -0
- phoenix/server/api/{types/ExportEventsMutation.py → mutations/export_events_mutations.py} +21 -18
- phoenix/server/api/mutations/project_mutations.py +57 -0
- phoenix/server/api/mutations/span_annotations_mutations.py +128 -0
- phoenix/server/api/mutations/trace_annotations_mutations.py +127 -0
- phoenix/server/api/mutations/user_mutations.py +329 -0
- phoenix/server/api/openapi/__init__.py +0 -0
- phoenix/server/api/openapi/main.py +17 -0
- phoenix/server/api/openapi/schema.py +16 -0
- phoenix/server/api/queries.py +738 -0
- phoenix/server/api/routers/__init__.py +11 -0
- phoenix/server/api/routers/auth.py +284 -0
- phoenix/server/api/routers/embeddings.py +26 -0
- phoenix/server/api/routers/oauth2.py +488 -0
- phoenix/server/api/routers/v1/__init__.py +64 -0
- phoenix/server/api/routers/v1/datasets.py +1017 -0
- phoenix/server/api/routers/v1/evaluations.py +362 -0
- phoenix/server/api/routers/v1/experiment_evaluations.py +115 -0
- phoenix/server/api/routers/v1/experiment_runs.py +167 -0
- phoenix/server/api/routers/v1/experiments.py +308 -0
- phoenix/server/api/routers/v1/pydantic_compat.py +78 -0
- phoenix/server/api/routers/v1/spans.py +267 -0
- phoenix/server/api/routers/v1/traces.py +208 -0
- phoenix/server/api/routers/v1/utils.py +95 -0
- phoenix/server/api/schema.py +44 -241
- phoenix/server/api/subscriptions.py +597 -0
- phoenix/server/api/types/Annotation.py +21 -0
- phoenix/server/api/types/AnnotationSummary.py +55 -0
- phoenix/server/api/types/AnnotatorKind.py +16 -0
- phoenix/server/api/types/ApiKey.py +27 -0
- phoenix/server/api/types/AuthMethod.py +9 -0
- phoenix/server/api/types/ChatCompletionMessageRole.py +11 -0
- phoenix/server/api/types/ChatCompletionSubscriptionPayload.py +46 -0
- phoenix/server/api/types/Cluster.py +25 -24
- phoenix/server/api/types/CreateDatasetPayload.py +8 -0
- phoenix/server/api/types/DataQualityMetric.py +31 -13
- phoenix/server/api/types/Dataset.py +288 -63
- phoenix/server/api/types/DatasetExample.py +85 -0
- phoenix/server/api/types/DatasetExampleRevision.py +34 -0
- phoenix/server/api/types/DatasetVersion.py +14 -0
- phoenix/server/api/types/Dimension.py +32 -31
- phoenix/server/api/types/DocumentEvaluationSummary.py +9 -8
- phoenix/server/api/types/EmbeddingDimension.py +56 -49
- phoenix/server/api/types/Evaluation.py +25 -31
- phoenix/server/api/types/EvaluationSummary.py +30 -50
- phoenix/server/api/types/Event.py +20 -20
- phoenix/server/api/types/ExampleRevisionInterface.py +14 -0
- phoenix/server/api/types/Experiment.py +152 -0
- phoenix/server/api/types/ExperimentAnnotationSummary.py +13 -0
- phoenix/server/api/types/ExperimentComparison.py +17 -0
- phoenix/server/api/types/ExperimentRun.py +119 -0
- phoenix/server/api/types/ExperimentRunAnnotation.py +56 -0
- phoenix/server/api/types/GenerativeModel.py +9 -0
- phoenix/server/api/types/GenerativeProvider.py +85 -0
- phoenix/server/api/types/Inferences.py +80 -0
- phoenix/server/api/types/InferencesRole.py +23 -0
- phoenix/server/api/types/LabelFraction.py +7 -0
- phoenix/server/api/types/MimeType.py +2 -2
- phoenix/server/api/types/Model.py +54 -54
- phoenix/server/api/types/PerformanceMetric.py +8 -5
- phoenix/server/api/types/Project.py +407 -142
- phoenix/server/api/types/ProjectSession.py +139 -0
- phoenix/server/api/types/Segments.py +4 -4
- phoenix/server/api/types/Span.py +221 -176
- phoenix/server/api/types/SpanAnnotation.py +43 -0
- phoenix/server/api/types/SpanIOValue.py +15 -0
- phoenix/server/api/types/SystemApiKey.py +9 -0
- phoenix/server/api/types/TemplateLanguage.py +10 -0
- phoenix/server/api/types/TimeSeries.py +19 -15
- phoenix/server/api/types/TokenUsage.py +11 -0
- phoenix/server/api/types/Trace.py +154 -0
- phoenix/server/api/types/TraceAnnotation.py +45 -0
- phoenix/server/api/types/UMAPPoints.py +7 -7
- phoenix/server/api/types/User.py +60 -0
- phoenix/server/api/types/UserApiKey.py +45 -0
- phoenix/server/api/types/UserRole.py +15 -0
- phoenix/server/api/types/node.py +4 -112
- phoenix/server/api/types/pagination.py +156 -57
- phoenix/server/api/utils.py +34 -0
- phoenix/server/app.py +864 -115
- phoenix/server/bearer_auth.py +163 -0
- phoenix/server/dml_event.py +136 -0
- phoenix/server/dml_event_handler.py +256 -0
- phoenix/server/email/__init__.py +0 -0
- phoenix/server/email/sender.py +97 -0
- phoenix/server/email/templates/__init__.py +0 -0
- phoenix/server/email/templates/password_reset.html +19 -0
- phoenix/server/email/types.py +11 -0
- phoenix/server/grpc_server.py +102 -0
- phoenix/server/jwt_store.py +505 -0
- phoenix/server/main.py +305 -116
- phoenix/server/oauth2.py +52 -0
- phoenix/server/openapi/__init__.py +0 -0
- phoenix/server/prometheus.py +111 -0
- phoenix/server/rate_limiters.py +188 -0
- phoenix/server/static/.vite/manifest.json +87 -0
- phoenix/server/static/assets/components-Cy9nwIvF.js +2125 -0
- phoenix/server/static/assets/index-BKvHIxkk.js +113 -0
- phoenix/server/static/assets/pages-CUi2xCVQ.js +4449 -0
- phoenix/server/static/assets/vendor-DvC8cT4X.js +894 -0
- phoenix/server/static/assets/vendor-DxkFTwjz.css +1 -0
- phoenix/server/static/assets/vendor-arizeai-Do1793cv.js +662 -0
- phoenix/server/static/assets/vendor-codemirror-BzwZPyJM.js +24 -0
- phoenix/server/static/assets/vendor-recharts-_Jb7JjhG.js +59 -0
- phoenix/server/static/assets/vendor-shiki-Cl9QBraO.js +5 -0
- phoenix/server/static/assets/vendor-three-DwGkEfCM.js +2998 -0
- phoenix/server/telemetry.py +68 -0
- phoenix/server/templates/index.html +82 -23
- phoenix/server/thread_server.py +3 -3
- phoenix/server/types.py +275 -0
- phoenix/services.py +27 -18
- phoenix/session/client.py +743 -68
- phoenix/session/data_extractor.py +31 -7
- phoenix/session/evaluation.py +3 -9
- phoenix/session/session.py +263 -219
- phoenix/settings.py +22 -0
- phoenix/trace/__init__.py +2 -22
- phoenix/trace/attributes.py +338 -0
- phoenix/trace/dsl/README.md +116 -0
- phoenix/trace/dsl/filter.py +663 -213
- phoenix/trace/dsl/helpers.py +73 -21
- phoenix/trace/dsl/query.py +574 -201
- phoenix/trace/exporter.py +24 -19
- phoenix/trace/fixtures.py +368 -32
- phoenix/trace/otel.py +71 -219
- phoenix/trace/projects.py +3 -2
- phoenix/trace/schemas.py +33 -11
- phoenix/trace/span_evaluations.py +21 -16
- phoenix/trace/span_json_decoder.py +6 -4
- phoenix/trace/span_json_encoder.py +2 -2
- phoenix/trace/trace_dataset.py +47 -32
- phoenix/trace/utils.py +21 -4
- phoenix/utilities/__init__.py +0 -26
- phoenix/utilities/client.py +132 -0
- phoenix/utilities/deprecation.py +31 -0
- phoenix/utilities/error_handling.py +3 -2
- phoenix/utilities/json.py +109 -0
- phoenix/utilities/logging.py +8 -0
- phoenix/utilities/project.py +2 -2
- phoenix/utilities/re.py +49 -0
- phoenix/utilities/span_store.py +0 -23
- phoenix/utilities/template_formatters.py +99 -0
- phoenix/version.py +1 -1
- arize_phoenix-3.16.1.dist-info/METADATA +0 -495
- arize_phoenix-3.16.1.dist-info/RECORD +0 -178
- phoenix/core/project.py +0 -619
- phoenix/core/traces.py +0 -96
- phoenix/experimental/evals/__init__.py +0 -73
- phoenix/experimental/evals/evaluators.py +0 -413
- phoenix/experimental/evals/functions/__init__.py +0 -4
- phoenix/experimental/evals/functions/classify.py +0 -453
- phoenix/experimental/evals/functions/executor.py +0 -353
- phoenix/experimental/evals/functions/generate.py +0 -138
- phoenix/experimental/evals/functions/processing.py +0 -76
- phoenix/experimental/evals/models/__init__.py +0 -14
- phoenix/experimental/evals/models/anthropic.py +0 -175
- phoenix/experimental/evals/models/base.py +0 -170
- phoenix/experimental/evals/models/bedrock.py +0 -221
- phoenix/experimental/evals/models/litellm.py +0 -134
- phoenix/experimental/evals/models/openai.py +0 -448
- phoenix/experimental/evals/models/rate_limiters.py +0 -246
- phoenix/experimental/evals/models/vertex.py +0 -173
- phoenix/experimental/evals/models/vertexai.py +0 -186
- phoenix/experimental/evals/retrievals.py +0 -96
- phoenix/experimental/evals/templates/__init__.py +0 -50
- phoenix/experimental/evals/templates/default_templates.py +0 -472
- phoenix/experimental/evals/templates/template.py +0 -195
- phoenix/experimental/evals/utils/__init__.py +0 -172
- phoenix/experimental/evals/utils/threads.py +0 -27
- phoenix/server/api/helpers.py +0 -11
- phoenix/server/api/routers/evaluation_handler.py +0 -109
- phoenix/server/api/routers/span_handler.py +0 -70
- phoenix/server/api/routers/trace_handler.py +0 -60
- phoenix/server/api/types/DatasetRole.py +0 -23
- phoenix/server/static/index.css +0 -6
- phoenix/server/static/index.js +0 -7447
- phoenix/storage/span_store/__init__.py +0 -23
- phoenix/storage/span_store/text_file.py +0 -85
- phoenix/trace/dsl/missing.py +0 -60
- phoenix/trace/langchain/__init__.py +0 -3
- phoenix/trace/langchain/instrumentor.py +0 -35
- phoenix/trace/llama_index/__init__.py +0 -3
- phoenix/trace/llama_index/callback.py +0 -102
- phoenix/trace/openai/__init__.py +0 -3
- phoenix/trace/openai/instrumentor.py +0 -30
- {arize_phoenix-3.16.1.dist-info → arize_phoenix-7.7.1.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-3.16.1.dist-info → arize_phoenix-7.7.1.dist-info}/licenses/LICENSE +0 -0
- /phoenix/{datasets → db/insertion}/__init__.py +0 -0
- /phoenix/{experimental → db/migrations}/__init__.py +0 -0
- /phoenix/{storage → db/migrations/data_migration_scripts}/__init__.py +0 -0
phoenix/server/api/routers/v1/datasets.py
@@ -0,0 +1,1017 @@
+import csv
+import gzip
+import io
+import json
+import logging
+import zlib
+from asyncio import QueueFull
+from collections import Counter
+from collections.abc import Awaitable, Callable, Coroutine, Iterator, Mapping, Sequence
+from datetime import datetime
+from enum import Enum
+from functools import partial
+from typing import Any, Optional, Union, cast
+
+import pandas as pd
+import pyarrow as pa
+from fastapi import APIRouter, BackgroundTasks, HTTPException, Path, Query
+from fastapi.responses import PlainTextResponse, StreamingResponse
+from sqlalchemy import and_, delete, func, select
+from sqlalchemy.ext.asyncio import AsyncSession
+from starlette.concurrency import run_in_threadpool
+from starlette.datastructures import FormData, UploadFile
+from starlette.requests import Request
+from starlette.responses import Response
+from starlette.status import (
+    HTTP_200_OK,
+    HTTP_204_NO_CONTENT,
+    HTTP_404_NOT_FOUND,
+    HTTP_409_CONFLICT,
+    HTTP_422_UNPROCESSABLE_ENTITY,
+    HTTP_429_TOO_MANY_REQUESTS,
+)
+from strawberry.relay import GlobalID
+from typing_extensions import TypeAlias, assert_never
+
+from phoenix.db import models
+from phoenix.db.helpers import get_eval_trace_ids_for_datasets, get_project_names_for_datasets
+from phoenix.db.insertion.dataset import (
+    DatasetAction,
+    DatasetExampleAdditionEvent,
+    ExampleContent,
+    add_dataset_examples,
+)
+from phoenix.server.api.types.Dataset import Dataset as DatasetNodeType
+from phoenix.server.api.types.DatasetExample import DatasetExample as DatasetExampleNodeType
+from phoenix.server.api.types.DatasetVersion import DatasetVersion as DatasetVersionNodeType
+from phoenix.server.api.types.node import from_global_id_with_expected_type
+from phoenix.server.api.utils import delete_projects, delete_traces
+from phoenix.server.dml_event import DatasetInsertEvent
+
+from .pydantic_compat import V1RoutesBaseModel
+from .utils import (
+    PaginatedResponseBody,
+    ResponseBody,
+    add_errors_to_responses,
+    add_text_csv_content_to_responses,
+)
+
+logger = logging.getLogger(__name__)
+
+DATASET_NODE_NAME = DatasetNodeType.__name__
+DATASET_VERSION_NODE_NAME = DatasetVersionNodeType.__name__
+
+
+router = APIRouter(tags=["datasets"])
+
+
+class Dataset(V1RoutesBaseModel):
+    id: str
+    name: str
+    description: Optional[str]
+    metadata: dict[str, Any]
+    created_at: datetime
+    updated_at: datetime
+
+
+class ListDatasetsResponseBody(PaginatedResponseBody[Dataset]):
+    pass
+
+
+@router.get(
+    "/datasets",
+    operation_id="listDatasets",
+    summary="List datasets",
+    responses=add_errors_to_responses([HTTP_422_UNPROCESSABLE_ENTITY]),
+)
+async def list_datasets(
+    request: Request,
+    cursor: Optional[str] = Query(
+        default=None,
+        description="Cursor for pagination",
+    ),
+    name: Optional[str] = Query(default=None, description="An optional dataset name to filter by"),
+    limit: int = Query(
+        default=10, description="The max number of datasets to return at a time.", gt=0
+    ),
+) -> ListDatasetsResponseBody:
+    async with request.app.state.db() as session:
+        query = select(models.Dataset).order_by(models.Dataset.id.desc())
+
+        if cursor:
+            try:
+                cursor_id = GlobalID.from_id(cursor).node_id
+                query = query.filter(models.Dataset.id <= int(cursor_id))
+            except ValueError:
+                raise HTTPException(
+                    detail=f"Invalid cursor format: {cursor}",
+                    status_code=HTTP_422_UNPROCESSABLE_ENTITY,
+                )
+        if name:
+            query = query.filter(models.Dataset.name == name)
+
+        query = query.limit(limit + 1)
+        result = await session.execute(query)
+        datasets = result.scalars().all()
+
+        if not datasets:
+            return ListDatasetsResponseBody(next_cursor=None, data=[])
+
+        next_cursor = None
+        if len(datasets) == limit + 1:
+            next_cursor = str(GlobalID(DATASET_NODE_NAME, str(datasets[-1].id)))
+            datasets = datasets[:-1]
+
+        data = []
+        for dataset in datasets:
+            data.append(
+                Dataset(
+                    id=str(GlobalID(DATASET_NODE_NAME, str(dataset.id))),
+                    name=dataset.name,
+                    description=dataset.description,
+                    metadata=dataset.metadata_,
+                    created_at=dataset.created_at,
+                    updated_at=dataset.updated_at,
+                )
+            )
+
+        return ListDatasetsResponseBody(next_cursor=next_cursor, data=data)
+
+
+@router.delete(
+    "/datasets/{id}",
+    operation_id="deleteDatasetById",
+    summary="Delete dataset by ID",
+    status_code=HTTP_204_NO_CONTENT,
+    responses=add_errors_to_responses(
+        [
+            {"status_code": HTTP_404_NOT_FOUND, "description": "Dataset not found"},
+            {"status_code": HTTP_422_UNPROCESSABLE_ENTITY, "description": "Invalid dataset ID"},
+        ]
+    ),
+)
+async def delete_dataset(
+    request: Request, id: str = Path(description="The ID of the dataset to delete.")
+) -> None:
+    if id:
+        try:
+            dataset_id = from_global_id_with_expected_type(
+                GlobalID.from_id(id),
+                DATASET_NODE_NAME,
+            )
+        except ValueError:
+            raise HTTPException(
+                detail=f"Invalid Dataset ID: {id}", status_code=HTTP_422_UNPROCESSABLE_ENTITY
+            )
+    else:
+        raise HTTPException(detail="Missing Dataset ID", status_code=HTTP_422_UNPROCESSABLE_ENTITY)
+    project_names_stmt = get_project_names_for_datasets(dataset_id)
+    eval_trace_ids_stmt = get_eval_trace_ids_for_datasets(dataset_id)
+    stmt = (
+        delete(models.Dataset).where(models.Dataset.id == dataset_id).returning(models.Dataset.id)
+    )
+    async with request.app.state.db() as session:
+        project_names = await session.scalars(project_names_stmt)
+        eval_trace_ids = await session.scalars(eval_trace_ids_stmt)
+        if (await session.scalar(stmt)) is None:
+            raise HTTPException(detail="Dataset does not exist", status_code=HTTP_404_NOT_FOUND)
+    tasks = BackgroundTasks()
+    tasks.add_task(delete_projects, request.app.state.db, *project_names)
+    tasks.add_task(delete_traces, request.app.state.db, *eval_trace_ids)
+
+
+class DatasetWithExampleCount(Dataset):
+    example_count: int
+
+
+class GetDatasetResponseBody(ResponseBody[DatasetWithExampleCount]):
+    pass
+
+
+@router.get(
+    "/datasets/{id}",
+    operation_id="getDataset",
+    summary="Get dataset by ID",
+    responses=add_errors_to_responses([HTTP_404_NOT_FOUND]),
+)
+async def get_dataset(
+    request: Request, id: str = Path(description="The ID of the dataset")
+) -> GetDatasetResponseBody:
+    dataset_id = GlobalID.from_id(id)
+
+    if (type_name := dataset_id.type_name) != DATASET_NODE_NAME:
+        raise HTTPException(
+            detail=f"ID {dataset_id} refers to a f{type_name}", status_code=HTTP_404_NOT_FOUND
+        )
+    async with request.app.state.db() as session:
+        result = await session.execute(
+            select(models.Dataset, models.Dataset.example_count).filter(
+                models.Dataset.id == int(dataset_id.node_id)
+            )
+        )
+        dataset_query = result.first()
+        dataset = dataset_query[0] if dataset_query else None
+        example_count = dataset_query[1] if dataset_query else 0
+    if dataset is None:
+        raise HTTPException(
+            detail=f"Dataset with ID {dataset_id} not found", status_code=HTTP_404_NOT_FOUND
+        )
+
+    dataset = DatasetWithExampleCount(
+        id=str(dataset_id),
+        name=dataset.name,
+        description=dataset.description,
+        metadata=dataset.metadata_,
+        created_at=dataset.created_at,
+        updated_at=dataset.updated_at,
+        example_count=example_count,
+    )
+    return GetDatasetResponseBody(data=dataset)
+
+
+class DatasetVersion(V1RoutesBaseModel):
+    version_id: str
+    description: Optional[str]
+    metadata: dict[str, Any]
+    created_at: datetime
+
+
+class ListDatasetVersionsResponseBody(PaginatedResponseBody[DatasetVersion]):
+    pass
+
+
+@router.get(
+    "/datasets/{id}/versions",
+    operation_id="listDatasetVersionsByDatasetId",
+    summary="List dataset versions",
+    responses=add_errors_to_responses([HTTP_422_UNPROCESSABLE_ENTITY]),
+)
+async def list_dataset_versions(
+    request: Request,
+    id: str = Path(description="The ID of the dataset"),
+    cursor: Optional[str] = Query(
+        default=None,
+        description="Cursor for pagination",
+    ),
+    limit: int = Query(
+        default=10, description="The max number of dataset versions to return at a time", gt=0
+    ),
+) -> ListDatasetVersionsResponseBody:
+    if id:
+        try:
+            dataset_id = from_global_id_with_expected_type(
+                GlobalID.from_id(id),
+                DATASET_NODE_NAME,
+            )
+        except ValueError:
+            raise HTTPException(
+                detail=f"Invalid Dataset ID: {id}",
+                status_code=HTTP_422_UNPROCESSABLE_ENTITY,
+            )
+    else:
+        raise HTTPException(
+            detail="Missing Dataset ID",
+            status_code=HTTP_422_UNPROCESSABLE_ENTITY,
+        )
+    stmt = (
+        select(models.DatasetVersion)
+        .where(models.DatasetVersion.dataset_id == dataset_id)
+        .order_by(models.DatasetVersion.id.desc())
+        .limit(limit + 1)
+    )
+    if cursor:
+        try:
+            dataset_version_id = from_global_id_with_expected_type(
+                GlobalID.from_id(cursor), DATASET_VERSION_NODE_NAME
+            )
+        except ValueError:
+            raise HTTPException(
+                detail=f"Invalid cursor: {cursor}",
+                status_code=HTTP_422_UNPROCESSABLE_ENTITY,
+            )
+        max_dataset_version_id = (
+            select(models.DatasetVersion.id)
+            .where(models.DatasetVersion.id == dataset_version_id)
+            .where(models.DatasetVersion.dataset_id == dataset_id)
+        ).scalar_subquery()
+        stmt = stmt.filter(models.DatasetVersion.id <= max_dataset_version_id)
+    async with request.app.state.db() as session:
+        data = [
+            DatasetVersion(
+                version_id=str(GlobalID(DATASET_VERSION_NODE_NAME, str(version.id))),
+                description=version.description,
+                metadata=version.metadata_,
+                created_at=version.created_at,
+            )
+            async for version in await session.stream_scalars(stmt)
+        ]
+    next_cursor = data.pop().version_id if len(data) == limit + 1 else None
+    return ListDatasetVersionsResponseBody(data=data, next_cursor=next_cursor)
+
+
+class UploadDatasetData(V1RoutesBaseModel):
+    dataset_id: str
+
+
+class UploadDatasetResponseBody(ResponseBody[UploadDatasetData]):
+    pass
+
+
+@router.post(
+    "/datasets/upload",
+    operation_id="uploadDataset",
+    summary="Upload dataset from JSON, CSV, or PyArrow",
+    responses=add_errors_to_responses(
+        [
+            {
+                "status_code": HTTP_409_CONFLICT,
+                "description": "Dataset of the same name already exists",
+            },
+            {"status_code": HTTP_422_UNPROCESSABLE_ENTITY, "description": "Invalid request body"},
+        ]
+    ),
+    # FastAPI cannot generate the request body portion of the OpenAPI schema for
+    # routes that accept multiple request content types, so we have to provide
+    # this part of the schema manually. For context, see
+    # https://github.com/tiangolo/fastapi/discussions/7786 and
+    # https://github.com/tiangolo/fastapi/issues/990
+    openapi_extra={
+        "requestBody": {
+            "content": {
+                "application/json": {
+                    "schema": {
+                        "type": "object",
+                        "required": ["name", "inputs"],
+                        "properties": {
+                            "action": {"type": "string", "enum": ["create", "append"]},
+                            "name": {"type": "string"},
+                            "description": {"type": "string"},
+                            "inputs": {"type": "array", "items": {"type": "object"}},
+                            "outputs": {"type": "array", "items": {"type": "object"}},
+                            "metadata": {"type": "array", "items": {"type": "object"}},
+                        },
+                    }
+                },
+                "multipart/form-data": {
+                    "schema": {
+                        "type": "object",
+                        "required": ["name", "input_keys[]", "output_keys[]", "file"],
+                        "properties": {
+                            "action": {"type": "string", "enum": ["create", "append"]},
+                            "name": {"type": "string"},
+                            "description": {"type": "string"},
+                            "input_keys[]": {
+                                "type": "array",
+                                "items": {"type": "string"},
+                                "uniqueItems": True,
+                            },
+                            "output_keys[]": {
+                                "type": "array",
+                                "items": {"type": "string"},
+                                "uniqueItems": True,
+                            },
+                            "metadata_keys[]": {
+                                "type": "array",
+                                "items": {"type": "string"},
+                                "uniqueItems": True,
+                            },
+                            "file": {"type": "string", "format": "binary"},
+                        },
+                    }
+                },
+            }
+        },
+    },
+)
+async def upload_dataset(
+    request: Request,
+    sync: bool = Query(
+        default=False,
+        description="If true, fulfill request synchronously and return JSON containing dataset_id.",
+    ),
+) -> Optional[UploadDatasetResponseBody]:
+    request_content_type = request.headers["content-type"]
+    examples: Union[Examples, Awaitable[Examples]]
+    if request_content_type.startswith("application/json"):
+        try:
+            examples, action, name, description = await run_in_threadpool(
+                _process_json, await request.json()
+            )
+        except ValueError as e:
+            raise HTTPException(
+                detail=str(e),
+                status_code=HTTP_422_UNPROCESSABLE_ENTITY,
+            )
+        if action is DatasetAction.CREATE:
+            async with request.app.state.db() as session:
+                if await _check_table_exists(session, name):
+                    raise HTTPException(
+                        detail=f"Dataset with the same name already exists: {name=}",
+                        status_code=HTTP_409_CONFLICT,
+                    )
+    elif request_content_type.startswith("multipart/form-data"):
+        async with request.form() as form:
+            try:
+                (
+                    action,
+                    name,
+                    description,
+                    input_keys,
+                    output_keys,
+                    metadata_keys,
+                    file,
+                ) = await _parse_form_data(form)
+            except ValueError as e:
+                raise HTTPException(
+                    detail=str(e),
+                    status_code=HTTP_422_UNPROCESSABLE_ENTITY,
+                )
+            if action is DatasetAction.CREATE:
+                async with request.app.state.db() as session:
+                    if await _check_table_exists(session, name):
+                        raise HTTPException(
+                            detail=f"Dataset with the same name already exists: {name=}",
+                            status_code=HTTP_409_CONFLICT,
+                        )
+            content = await file.read()
+        try:
+            file_content_type = FileContentType(file.content_type)
+            if file_content_type is FileContentType.CSV:
+                encoding = FileContentEncoding(file.headers.get("content-encoding"))
+                examples = await _process_csv(
+                    content, encoding, input_keys, output_keys, metadata_keys
+                )
+            elif file_content_type is FileContentType.PYARROW:
+                examples = await _process_pyarrow(content, input_keys, output_keys, metadata_keys)
+            else:
+                assert_never(file_content_type)
+        except ValueError as e:
+            raise HTTPException(
+                detail=str(e),
+                status_code=HTTP_422_UNPROCESSABLE_ENTITY,
+            )
+    else:
+        raise HTTPException(
+            detail="Invalid request Content-Type",
+            status_code=HTTP_422_UNPROCESSABLE_ENTITY,
+        )
+    operation = cast(
+        Callable[[AsyncSession], Awaitable[DatasetExampleAdditionEvent]],
+        partial(
+            add_dataset_examples,
+            examples=examples,
+            action=action,
+            name=name,
+            description=description,
+        ),
+    )
+    if sync:
+        async with request.app.state.db() as session:
+            dataset_id = (await operation(session)).dataset_id
+        request.state.event_queue.put(DatasetInsertEvent((dataset_id,)))
+        return UploadDatasetResponseBody(
+            data=UploadDatasetData(dataset_id=str(GlobalID(Dataset.__name__, str(dataset_id))))
+        )
+    try:
+        request.state.enqueue_operation(operation)
+    except QueueFull:
+        if isinstance(examples, Coroutine):
+            examples.close()
+        raise HTTPException(detail="Too many requests.", status_code=HTTP_429_TOO_MANY_REQUESTS)
+    return None
+
+
+class FileContentType(Enum):
+    CSV = "text/csv"
+    PYARROW = "application/x-pandas-pyarrow"
+
+    @classmethod
+    def _missing_(cls, v: Any) -> "FileContentType":
+        if isinstance(v, str) and v and v.isascii() and not v.islower():
+            return cls(v.lower())
+        raise ValueError(f"Invalid file content type: {v}")
+
+
+class FileContentEncoding(Enum):
+    NONE = "none"
+    GZIP = "gzip"
+    DEFLATE = "deflate"
+
+    @classmethod
+    def _missing_(cls, v: Any) -> "FileContentEncoding":
+        if v is None:
+            return cls("none")
+        if isinstance(v, str) and v and v.isascii() and not v.islower():
+            return cls(v.lower())
+        raise ValueError(f"Invalid file content encoding: {v}")
+
+
+Name: TypeAlias = str
+Description: TypeAlias = Optional[str]
+InputKeys: TypeAlias = frozenset[str]
+OutputKeys: TypeAlias = frozenset[str]
+MetadataKeys: TypeAlias = frozenset[str]
+DatasetId: TypeAlias = int
+Examples: TypeAlias = Iterator[ExampleContent]
+
+
+def _process_json(
+    data: Mapping[str, Any],
+) -> tuple[Examples, DatasetAction, Name, Description]:
+    name = data.get("name")
+    if not name:
+        raise ValueError("Dataset name is required")
+    description = data.get("description") or ""
+    inputs = data.get("inputs")
+    if not inputs:
+        raise ValueError("input is required")
+    if not isinstance(inputs, list) or not _is_all_dict(inputs):
+        raise ValueError("Input should be a list containing only dictionary objects")
+    outputs, metadata = data.get("outputs"), data.get("metadata")
+    for k, v in {"outputs": outputs, "metadata": metadata}.items():
+        if v and not (isinstance(v, list) and len(v) == len(inputs) and _is_all_dict(v)):
+            raise ValueError(
+                f"{k} should be a list of same length as input containing only dictionary objects"
+            )
+    examples: list[ExampleContent] = []
+    for i, obj in enumerate(inputs):
+        example = ExampleContent(
+            input=obj,
+            output=outputs[i] if outputs else {},
+            metadata=metadata[i] if metadata else {},
+        )
+        examples.append(example)
+    action = DatasetAction(cast(Optional[str], data.get("action")) or "create")
+    return iter(examples), action, name, description
+
+
+async def _process_csv(
+    content: bytes,
+    content_encoding: FileContentEncoding,
+    input_keys: InputKeys,
+    output_keys: OutputKeys,
+    metadata_keys: MetadataKeys,
+) -> Examples:
+    if content_encoding is FileContentEncoding.GZIP:
+        content = await run_in_threadpool(gzip.decompress, content)
+    elif content_encoding is FileContentEncoding.DEFLATE:
+        content = await run_in_threadpool(zlib.decompress, content)
+    elif content_encoding is not FileContentEncoding.NONE:
+        assert_never(content_encoding)
+    reader = await run_in_threadpool(lambda c: csv.DictReader(io.StringIO(c.decode())), content)
+    if reader.fieldnames is None:
+        raise ValueError("Missing CSV column header")
+    (header, freq), *_ = Counter(reader.fieldnames).most_common(1)
+    if freq > 1:
+        raise ValueError(f"Duplicated column header in CSV file: {header}")
+    column_headers = frozenset(reader.fieldnames)
+    _check_keys_exist(column_headers, input_keys, output_keys, metadata_keys)
+    return (
+        ExampleContent(
+            input={k: row.get(k) for k in input_keys},
+            output={k: row.get(k) for k in output_keys},
+            metadata={k: row.get(k) for k in metadata_keys},
+        )
+        for row in iter(reader)
+    )
+
+
+async def _process_pyarrow(
+    content: bytes,
+    input_keys: InputKeys,
+    output_keys: OutputKeys,
+    metadata_keys: MetadataKeys,
+) -> Awaitable[Examples]:
+    try:
+        reader = pa.ipc.open_stream(content)
+    except pa.ArrowInvalid as e:
+        raise ValueError("File is not valid pyarrow") from e
+    column_headers = frozenset(reader.schema.names)
+    _check_keys_exist(column_headers, input_keys, output_keys, metadata_keys)
+
+    def get_examples() -> Iterator[ExampleContent]:
+        for row in reader.read_pandas().to_dict(orient="records"):
+            yield ExampleContent(
+                input={k: row.get(k) for k in input_keys},
+                output={k: row.get(k) for k in output_keys},
+                metadata={k: row.get(k) for k in metadata_keys},
+            )
+
+    return run_in_threadpool(get_examples)
+
+
+async def _check_table_exists(session: AsyncSession, name: str) -> bool:
+    return bool(
+        await session.scalar(
+            select(1).select_from(models.Dataset).where(models.Dataset.name == name)
+        )
+    )
+
+
+def _check_keys_exist(
+    column_headers: frozenset[str],
+    input_keys: InputKeys,
+    output_keys: OutputKeys,
+    metadata_keys: MetadataKeys,
+) -> None:
+    for desc, keys in (
+        ("input", input_keys),
+        ("output", output_keys),
+        ("metadata", metadata_keys),
+    ):
+        if keys and (diff := keys.difference(column_headers)):
+            raise ValueError(f"{desc} keys not found in column headers: {diff}")
+
+
+async def _parse_form_data(
+    form: FormData,
+) -> tuple[
+    DatasetAction,
+    Name,
+    Description,
+    InputKeys,
+    OutputKeys,
+    MetadataKeys,
+    UploadFile,
+]:
+    name = cast(Optional[str], form.get("name"))
+    if not name:
+        raise ValueError("Dataset name must not be empty")
+    action = DatasetAction(cast(Optional[str], form.get("action")) or "create")
+    file = form["file"]
+    if not isinstance(file, UploadFile):
+        raise ValueError("Malformed file in form data.")
+    description = cast(Optional[str], form.get("description")) or file.filename
+    input_keys = frozenset(filter(bool, cast(list[str], form.getlist("input_keys[]"))))
+    output_keys = frozenset(filter(bool, cast(list[str], form.getlist("output_keys[]"))))
+    metadata_keys = frozenset(filter(bool, cast(list[str], form.getlist("metadata_keys[]"))))
+    return (
+        action,
+        name,
+        description,
+        input_keys,
+        output_keys,
+        metadata_keys,
+        file,
+    )
+
+
+class DatasetExample(V1RoutesBaseModel):
+    id: str
+    input: dict[str, Any]
+    output: dict[str, Any]
+    metadata: dict[str, Any]
+    updated_at: datetime
+
+
+class ListDatasetExamplesData(V1RoutesBaseModel):
+    dataset_id: str
+    version_id: str
+    examples: list[DatasetExample]
+
+
+class ListDatasetExamplesResponseBody(ResponseBody[ListDatasetExamplesData]):
+    pass
+
+
+@router.get(
+    "/datasets/{id}/examples",
+    operation_id="getDatasetExamples",
+    summary="Get examples from a dataset",
+    responses=add_errors_to_responses([HTTP_404_NOT_FOUND]),
+)
+async def get_dataset_examples(
+    request: Request,
+    id: str = Path(description="The ID of the dataset"),
+    version_id: Optional[str] = Query(
+        default=None,
+        description=(
+            "The ID of the dataset version " "(if omitted, returns data from the latest version)"
+        ),
+    ),
+) -> ListDatasetExamplesResponseBody:
+    dataset_gid = GlobalID.from_id(id)
+    version_gid = GlobalID.from_id(version_id) if version_id else None
+
+    if (dataset_type := dataset_gid.type_name) != "Dataset":
+        raise HTTPException(
+            detail=f"ID {dataset_gid} refers to a {dataset_type}", status_code=HTTP_404_NOT_FOUND
+        )
+
+    if version_gid and (version_type := version_gid.type_name) != "DatasetVersion":
+        raise HTTPException(
+            detail=f"ID {version_gid} refers to a {version_type}", status_code=HTTP_404_NOT_FOUND
+        )
+
+    async with request.app.state.db() as session:
+        if (
+            resolved_dataset_id := await session.scalar(
+                select(models.Dataset.id).where(models.Dataset.id == int(dataset_gid.node_id))
+            )
+        ) is None:
+            raise HTTPException(
+                detail=f"No dataset with id {dataset_gid} can be found.",
+                status_code=HTTP_404_NOT_FOUND,
+            )
+
+        # Subquery to find the maximum created_at for each dataset_example_id
+        # timestamp tiebreaks are resolved by the largest id
+        partial_subquery = select(
+            func.max(models.DatasetExampleRevision.id).label("max_id"),
+        ).group_by(models.DatasetExampleRevision.dataset_example_id)
+
+        if version_gid:
+            if (
+                resolved_version_id := await session.scalar(
+                    select(models.DatasetVersion.id).where(
+                        and_(
+                            models.DatasetVersion.dataset_id == resolved_dataset_id,
+                            models.DatasetVersion.id == int(version_gid.node_id),
+                        )
+                    )
+                )
+            ) is None:
+                raise HTTPException(
+                    detail=f"No dataset version with id {version_id} can be found.",
+                    status_code=HTTP_404_NOT_FOUND,
+                )
+            # if a version_id is provided, filter the subquery to only include revisions from that
+            partial_subquery = partial_subquery.filter(
+                models.DatasetExampleRevision.dataset_version_id <= resolved_version_id
+            )
+        else:
+            if (
+                resolved_version_id := await session.scalar(
+                    select(func.max(models.DatasetVersion.id)).where(
+                        models.DatasetVersion.dataset_id == resolved_dataset_id
+                    )
+                )
+            ) is None:
+                raise HTTPException(
+                    detail="Dataset has no versions.",
+                    status_code=HTTP_404_NOT_FOUND,
+                )
+
+        subquery = partial_subquery.subquery()
+        # Query for the most recent example revisions that are not deleted
+        query = (
+            select(models.DatasetExample, models.DatasetExampleRevision)
+            .join(
+                models.DatasetExampleRevision,
+                models.DatasetExample.id == models.DatasetExampleRevision.dataset_example_id,
+            )
+            .join(
+                subquery,
+                (subquery.c.max_id == models.DatasetExampleRevision.id),
+            )
+            .filter(models.DatasetExample.dataset_id == resolved_dataset_id)
+            .filter(models.DatasetExampleRevision.revision_kind != "DELETE")
+            .order_by(models.DatasetExample.id.asc())
+        )
+        examples = [
+            DatasetExample(
+                id=str(GlobalID("DatasetExample", str(example.id))),
+                input=revision.input,
+                output=revision.output,
+                metadata=revision.metadata_,
+                updated_at=revision.created_at,
+            )
+            async for example, revision in await session.stream(query)
+        ]
+    return ListDatasetExamplesResponseBody(
+        data=ListDatasetExamplesData(
+            dataset_id=str(GlobalID("Dataset", str(resolved_dataset_id))),
+            version_id=str(GlobalID("DatasetVersion", str(resolved_version_id))),
+            examples=examples,
+        )
+    )
+
+
+@router.get(
+    "/datasets/{id}/csv",
+    operation_id="getDatasetCsv",
+    summary="Download dataset examples as CSV file",
+    response_class=StreamingResponse,
+    status_code=HTTP_200_OK,
+    responses={
+        **add_errors_to_responses([HTTP_422_UNPROCESSABLE_ENTITY]),
+        **add_text_csv_content_to_responses(HTTP_200_OK),
+    },
+)
+async def get_dataset_csv(
+    request: Request,
+    response: Response,
+    id: str = Path(description="The ID of the dataset"),
+    version_id: Optional[str] = Query(
+        default=None,
+        description=(
+            "The ID of the dataset version " "(if omitted, returns data from the latest version)"
+        ),
+    ),
+) -> Response:
+    try:
+        async with request.app.state.db() as session:
+            dataset_name, examples = await _get_db_examples(
+                session=session, id=id, version_id=version_id
+            )
+    except ValueError as e:
+        raise HTTPException(detail=str(e), status_code=HTTP_422_UNPROCESSABLE_ENTITY)
+    content = await run_in_threadpool(_get_content_csv, examples)
+    return Response(
+        content=content,
+        headers={
+            "content-disposition": f'attachment; filename="{dataset_name}.csv"',
+            "content-type": "text/csv",
+        },
+    )
+
+
+@router.get(
+    "/datasets/{id}/jsonl/openai_ft",
+    operation_id="getDatasetJSONLOpenAIFineTuning",
+    summary="Download dataset examples as OpenAI fine-tuning JSONL file",
+    response_class=PlainTextResponse,
+    responses=add_errors_to_responses(
+        [
+            {
+                "status_code": HTTP_422_UNPROCESSABLE_ENTITY,
+                "description": "Invalid dataset or version ID",
+            }
+        ]
+    ),
+)
+async def get_dataset_jsonl_openai_ft(
+    request: Request,
+    response: Response,
+    id: str = Path(description="The ID of the dataset"),
+    version_id: Optional[str] = Query(
+        default=None,
+        description=(
+            "The ID of the dataset version " "(if omitted, returns data from the latest version)"
+        ),
+    ),
+) -> bytes:
+    try:
+        async with request.app.state.db() as session:
+            dataset_name, examples = await _get_db_examples(
+                session=session, id=id, version_id=version_id
+            )
+    except ValueError as e:
+        raise HTTPException(detail=str(e), status_code=HTTP_422_UNPROCESSABLE_ENTITY)
+    content = await run_in_threadpool(_get_content_jsonl_openai_ft, examples)
+    response.headers["content-disposition"] = f'attachment; filename="{dataset_name}.jsonl"'
+    return content
+
+
+@router.get(
+    "/datasets/{id}/jsonl/openai_evals",
+    operation_id="getDatasetJSONLOpenAIEvals",
+    summary="Download dataset examples as OpenAI evals JSONL file",
+    response_class=PlainTextResponse,
+    responses=add_errors_to_responses(
+        [
+            {
+                "status_code": HTTP_422_UNPROCESSABLE_ENTITY,
+                "description": "Invalid dataset or version ID",
+            }
+        ]
+    ),
+)
+async def get_dataset_jsonl_openai_evals(
+    request: Request,
+    response: Response,
+    id: str = Path(description="The ID of the dataset"),
+    version_id: Optional[str] = Query(
+        default=None,
+        description=(
+            "The ID of the dataset version " "(if omitted, returns data from the latest version)"
+        ),
+    ),
+) -> bytes:
+    try:
+        async with request.app.state.db() as session:
+            dataset_name, examples = await _get_db_examples(
+                session=session, id=id, version_id=version_id
+            )
+    except ValueError as e:
+        raise HTTPException(detail=str(e), status_code=HTTP_422_UNPROCESSABLE_ENTITY)
+    content = await run_in_threadpool(_get_content_jsonl_openai_evals, examples)
+    response.headers["content-disposition"] = f'attachment; filename="{dataset_name}.jsonl"'
+    return content
+
+
+def _get_content_csv(examples: list[models.DatasetExampleRevision]) -> bytes:
+    records = [
+        {
+            "example_id": GlobalID(
+                type_name=DatasetExampleNodeType.__name__,
+                node_id=str(ex.dataset_example_id),
+            ),
+            **{f"input_{k}": v for k, v in ex.input.items()},
+            **{f"output_{k}": v for k, v in ex.output.items()},
+            **{f"metadata_{k}": v for k, v in ex.metadata_.items()},
+        }
+        for ex in examples
+    ]
+    return str(pd.DataFrame.from_records(records).to_csv(index=False)).encode()
+
+
+def _get_content_jsonl_openai_ft(examples: list[models.DatasetExampleRevision]) -> bytes:
+    records = io.BytesIO()
+    for ex in examples:
+        records.write(
+            (
+                json.dumps(
+                    {
+                        "messages": (
+                            ims if isinstance(ims := ex.input.get("messages"), list) else []
+                        )
+                        + (oms if isinstance(oms := ex.output.get("messages"), list) else [])
+                    },
+                    ensure_ascii=False,
+                )
+                + "\n"
+            ).encode()
+        )
+    records.seek(0)
+    return records.read()
+
+
+def _get_content_jsonl_openai_evals(examples: list[models.DatasetExampleRevision]) -> bytes:
+    records = io.BytesIO()
+    for ex in examples:
+        records.write(
+            (
+                json.dumps(
+                    {
+                        "messages": ims
+                        if isinstance(ims := ex.input.get("messages"), list)
+                        else [],
+                        "ideal": (
+                            ideal if isinstance(ideal := last_message.get("content"), str) else ""
+                        )
+                        if isinstance(oms := ex.output.get("messages"), list)
+                        and oms
+                        and hasattr(last_message := oms[-1], "get")
+                        else "",
+                    },
+                    ensure_ascii=False,
+                )
+                + "\n"
+            ).encode()
+        )
+    records.seek(0)
+    return records.read()
+
+
+async def _get_db_examples(
+    *, session: Any, id: str, version_id: Optional[str]
+) -> tuple[str, list[models.DatasetExampleRevision]]:
+    dataset_id = from_global_id_with_expected_type(GlobalID.from_id(id), DATASET_NODE_NAME)
+    dataset_version_id: Optional[int] = None
+    if version_id:
+        dataset_version_id = from_global_id_with_expected_type(
+            GlobalID.from_id(version_id), DATASET_VERSION_NODE_NAME
+        )
+    latest_version = (
+        select(
+            models.DatasetExampleRevision.dataset_example_id,
+            func.max(models.DatasetExampleRevision.dataset_version_id).label("dataset_version_id"),
+        )
+        .group_by(models.DatasetExampleRevision.dataset_example_id)
+        .join(models.DatasetExample)
+        .where(models.DatasetExample.dataset_id == dataset_id)
+    )
+    if dataset_version_id is not None:
+        max_dataset_version_id = (
+            select(models.DatasetVersion.id)
+            .where(models.DatasetVersion.id == dataset_version_id)
+            .where(models.DatasetVersion.dataset_id == dataset_id)
+        ).scalar_subquery()
+        latest_version = latest_version.where(
+            models.DatasetExampleRevision.dataset_version_id <= max_dataset_version_id
+        )
+    subq = latest_version.subquery("latest_version")
+    stmt = (
+        select(models.DatasetExampleRevision)
+        .join(
+            subq,
+            onclause=and_(
+                models.DatasetExampleRevision.dataset_example_id == subq.c.dataset_example_id,
+                models.DatasetExampleRevision.dataset_version_id == subq.c.dataset_version_id,
+            ),
+        )
+        .where(models.DatasetExampleRevision.revision_kind != "DELETE")
+        .order_by(models.DatasetExampleRevision.dataset_example_id)
+    )
+    dataset_name: Optional[str] = await session.scalar(
+        select(models.Dataset.name).where(models.Dataset.id == dataset_id)
+    )
+    if not dataset_name:
+        raise ValueError("Dataset does not exist.")
+    examples = [r async for r in await session.stream_scalars(stmt)]
+    return dataset_name, examples
+
+
+def _is_all_dict(seq: Sequence[Any]) -> bool:
+    return all(map(lambda obj: isinstance(obj, dict), seq))