arize-phoenix 4.14.1__py3-none-any.whl → 4.16.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of arize-phoenix might be problematic.
- {arize_phoenix-4.14.1.dist-info → arize_phoenix-4.16.0.dist-info}/METADATA +5 -3
- {arize_phoenix-4.14.1.dist-info → arize_phoenix-4.16.0.dist-info}/RECORD +81 -71
- phoenix/db/bulk_inserter.py +131 -5
- phoenix/db/engines.py +2 -1
- phoenix/db/helpers.py +23 -1
- phoenix/db/insertion/constants.py +2 -0
- phoenix/db/insertion/document_annotation.py +157 -0
- phoenix/db/insertion/helpers.py +13 -0
- phoenix/db/insertion/span_annotation.py +144 -0
- phoenix/db/insertion/trace_annotation.py +144 -0
- phoenix/db/insertion/types.py +261 -0
- phoenix/experiments/functions.py +3 -2
- phoenix/experiments/types.py +3 -3
- phoenix/server/api/context.py +7 -9
- phoenix/server/api/dataloaders/__init__.py +2 -0
- phoenix/server/api/dataloaders/average_experiment_run_latency.py +3 -3
- phoenix/server/api/dataloaders/dataset_example_revisions.py +2 -4
- phoenix/server/api/dataloaders/dataset_example_spans.py +2 -4
- phoenix/server/api/dataloaders/document_evaluation_summaries.py +2 -4
- phoenix/server/api/dataloaders/document_evaluations.py +2 -4
- phoenix/server/api/dataloaders/document_retrieval_metrics.py +2 -4
- phoenix/server/api/dataloaders/evaluation_summaries.py +2 -4
- phoenix/server/api/dataloaders/experiment_annotation_summaries.py +2 -4
- phoenix/server/api/dataloaders/experiment_error_rates.py +2 -4
- phoenix/server/api/dataloaders/experiment_run_counts.py +2 -4
- phoenix/server/api/dataloaders/experiment_sequence_number.py +2 -4
- phoenix/server/api/dataloaders/latency_ms_quantile.py +2 -3
- phoenix/server/api/dataloaders/min_start_or_max_end_times.py +2 -4
- phoenix/server/api/dataloaders/project_by_name.py +3 -3
- phoenix/server/api/dataloaders/record_counts.py +2 -4
- phoenix/server/api/dataloaders/span_annotations.py +2 -4
- phoenix/server/api/dataloaders/span_dataset_examples.py +36 -0
- phoenix/server/api/dataloaders/span_descendants.py +2 -4
- phoenix/server/api/dataloaders/span_evaluations.py +2 -4
- phoenix/server/api/dataloaders/span_projects.py +3 -3
- phoenix/server/api/dataloaders/token_counts.py +2 -4
- phoenix/server/api/dataloaders/trace_evaluations.py +2 -4
- phoenix/server/api/dataloaders/trace_row_ids.py +2 -4
- phoenix/server/api/input_types/SpanAnnotationSort.py +17 -0
- phoenix/server/api/input_types/TraceAnnotationSort.py +17 -0
- phoenix/server/api/mutations/span_annotations_mutations.py +8 -3
- phoenix/server/api/mutations/trace_annotations_mutations.py +8 -3
- phoenix/server/api/openapi/main.py +18 -2
- phoenix/server/api/openapi/schema.py +12 -12
- phoenix/server/api/routers/v1/__init__.py +36 -83
- phoenix/server/api/routers/v1/datasets.py +515 -509
- phoenix/server/api/routers/v1/evaluations.py +164 -73
- phoenix/server/api/routers/v1/experiment_evaluations.py +68 -91
- phoenix/server/api/routers/v1/experiment_runs.py +98 -155
- phoenix/server/api/routers/v1/experiments.py +132 -181
- phoenix/server/api/routers/v1/pydantic_compat.py +78 -0
- phoenix/server/api/routers/v1/spans.py +164 -203
- phoenix/server/api/routers/v1/traces.py +134 -159
- phoenix/server/api/routers/v1/utils.py +95 -0
- phoenix/server/api/types/Span.py +27 -3
- phoenix/server/api/types/Trace.py +21 -4
- phoenix/server/api/utils.py +4 -4
- phoenix/server/app.py +172 -192
- phoenix/server/grpc_server.py +2 -2
- phoenix/server/main.py +5 -9
- phoenix/server/static/.vite/manifest.json +31 -31
- phoenix/server/static/assets/components-Ci5kMOk5.js +1175 -0
- phoenix/server/static/assets/{index-CQgXRwU0.js → index-BQG5WVX7.js} +2 -2
- phoenix/server/static/assets/{pages-hdjlFZhO.js → pages-BrevprVW.js} +451 -275
- phoenix/server/static/assets/{vendor-DPvSDRn3.js → vendor-CP0b0YG0.js} +2 -2
- phoenix/server/static/assets/{vendor-arizeai-CkvPT67c.js → vendor-arizeai-DTbiPGp6.js} +27 -27
- phoenix/server/static/assets/vendor-codemirror-DtdPDzrv.js +15 -0
- phoenix/server/static/assets/{vendor-recharts-5jlNaZuF.js → vendor-recharts-A0DA1O99.js} +1 -1
- phoenix/server/thread_server.py +2 -2
- phoenix/server/types.py +18 -0
- phoenix/session/client.py +5 -3
- phoenix/session/session.py +2 -2
- phoenix/trace/dsl/filter.py +2 -6
- phoenix/trace/fixtures.py +17 -23
- phoenix/trace/utils.py +23 -0
- phoenix/utilities/client.py +116 -0
- phoenix/utilities/project.py +1 -1
- phoenix/version.py +1 -1
- phoenix/server/api/routers/v1/dataset_examples.py +0 -178
- phoenix/server/openapi/docs.py +0 -221
- phoenix/server/static/assets/components-DeS0YEmv.js +0 -1142
- phoenix/server/static/assets/vendor-codemirror-Cqwpwlua.js +0 -12
- {arize_phoenix-4.14.1.dist-info → arize_phoenix-4.16.0.dist-info}/WHEEL +0 -0
- {arize_phoenix-4.14.1.dist-info → arize_phoenix-4.16.0.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-4.14.1.dist-info → arize_phoenix-4.16.0.dist-info}/licenses/LICENSE +0 -0
phoenix/server/api/routers/v1/datasets.py

@@ -6,6 +6,7 @@ import logging
 import zlib
 from asyncio import QueueFull
 from collections import Counter
+from datetime import datetime
 from enum import Enum
 from functools import partial
 from typing import (
@@ -13,6 +14,7 @@ from typing import (
     Awaitable,
     Callable,
     Coroutine,
+    Dict,
     FrozenSet,
     Iterator,
     List,
@@ -26,14 +28,16 @@ from typing import (
 
 import pandas as pd
 import pyarrow as pa
+from fastapi import APIRouter, BackgroundTasks, HTTPException, Path, Query
+from fastapi.responses import PlainTextResponse, StreamingResponse
 from sqlalchemy import and_, delete, func, select
 from sqlalchemy.ext.asyncio import AsyncSession
-from starlette.background import BackgroundTasks
 from starlette.concurrency import run_in_threadpool
 from starlette.datastructures import FormData, UploadFile
 from starlette.requests import Request
-from starlette.responses import …
+from starlette.responses import Response
 from starlette.status import (
+    HTTP_200_OK,
     HTTP_204_NO_CONTENT,
     HTTP_404_NOT_FOUND,
     HTTP_409_CONFLICT,
@@ -51,79 +55,59 @@ from phoenix.db.insertion.dataset import (
     ExampleContent,
     add_dataset_examples,
 )
-from phoenix.server.api.types.Dataset import Dataset
-from phoenix.server.api.types.DatasetExample import DatasetExample
-from phoenix.server.api.types.DatasetVersion import DatasetVersion
+from phoenix.server.api.types.Dataset import Dataset as DatasetNodeType
+from phoenix.server.api.types.DatasetExample import DatasetExample as DatasetExampleNodeType
+from phoenix.server.api.types.DatasetVersion import DatasetVersion as DatasetVersionNodeType
 from phoenix.server.api.types.node import from_global_id_with_expected_type
 from phoenix.server.api.utils import delete_projects, delete_traces
 
+from .pydantic_compat import V1RoutesBaseModel
+from .utils import (
+    PaginatedResponseBody,
+    ResponseBody,
+    add_errors_to_responses,
+    add_text_csv_content_to_responses,
+)
+
 logger = logging.getLogger(__name__)
 
- …
-            description
- …
-            type: string
-            data:
-            type: array
-            items:
-            type: object
-            properties:
-            id:
-            type: string
-            name:
-            type: string
-            description:
-            type: string
-            metadata:
-            type: object
-            created_at:
-            type: string
-            format: date-time
-            updated_at:
-            type: string
-            format: date-time
-            403:
-            description: Forbidden
-            404:
-            description: No datasets found
-    """
-    name = request.query_params.get("name")
-    cursor = request.query_params.get("cursor")
-    limit = int(request.query_params.get("limit", 10))
+DATASET_NODE_NAME = DatasetNodeType.__name__
+DATASET_VERSION_NODE_NAME = DatasetVersionNodeType.__name__
+
+
+router = APIRouter(tags=["datasets"])
+
+
+class Dataset(V1RoutesBaseModel):
+    id: str
+    name: str
+    description: Optional[str]
+    metadata: Dict[str, Any]
+    created_at: datetime
+    updated_at: datetime
+
+
+class ListDatasetsResponseBody(PaginatedResponseBody[Dataset]):
+    pass
+
+
+@router.get(
+    "/datasets",
+    operation_id="listDatasets",
+    summary="List datasets",
+    responses=add_errors_to_responses([HTTP_422_UNPROCESSABLE_ENTITY]),
+)
+async def list_datasets(
+    request: Request,
+    cursor: Optional[str] = Query(
+        default=None,
+        description="Cursor for pagination",
+    ),
+    name: Optional[str] = Query(default=None, description="An optional dataset name to filter by"),
+    limit: int = Query(
+        default=10, description="The max number of datasets to return at a time.", gt=0
+    ),
+) -> ListDatasetsResponseBody:
     async with request.app.state.db() as session:
         query = select(models.Dataset).order_by(models.Dataset.id.desc())
 
@@ -132,79 +116,68 @@ async def list_datasets(request: Request) -> Response:
                 cursor_id = GlobalID.from_id(cursor).node_id
                 query = query.filter(models.Dataset.id <= int(cursor_id))
             except ValueError:
- …
+                raise HTTPException(
+                    detail=f"Invalid cursor format: {cursor}",
                     status_code=HTTP_422_UNPROCESSABLE_ENTITY,
                 )
         if name:
-            query = query.filter(models.Dataset.name …
+            query = query.filter(models.Dataset.name == name)
 
         query = query.limit(limit + 1)
         result = await session.execute(query)
         datasets = result.scalars().all()
 
         if not datasets:
-            return …
+            return ListDatasetsResponseBody(next_cursor=None, data=[])
 
         next_cursor = None
         if len(datasets) == limit + 1:
-            next_cursor = str(GlobalID(…
+            next_cursor = str(GlobalID(DATASET_NODE_NAME, str(datasets[-1].id)))
             datasets = datasets[:-1]
 
         data = []
         for dataset in datasets:
             data.append(
- …
+                Dataset(
+                    id=str(GlobalID(DATASET_NODE_NAME, str(dataset.id))),
+                    name=dataset.name,
+                    description=dataset.description,
+                    metadata=dataset.metadata_,
+                    created_at=dataset.created_at,
+                    updated_at=dataset.updated_at,
+                )
             )
 
-    return …
- …
-    ""
- …
-            description: Forbidden
-            404:
-            description: Dataset not found
-            422:
-            description: Dataset ID is invalid
-    """
-    if id_ := request.path_params.get("id"):
+    return ListDatasetsResponseBody(next_cursor=next_cursor, data=data)
+
+
+@router.delete(
+    "/datasets/{id}",
+    operation_id="deleteDatasetById",
+    summary="Delete dataset by ID",
+    status_code=HTTP_204_NO_CONTENT,
+    responses=add_errors_to_responses(
+        [
+            {"status_code": HTTP_404_NOT_FOUND, "description": "Dataset not found"},
+            {"status_code": HTTP_422_UNPROCESSABLE_ENTITY, "description": "Invalid dataset ID"},
+        ]
+    ),
+)
+async def delete_dataset(
+    request: Request, id: str = Path(description="The ID of the dataset to delete.")
+) -> None:
+    if id:
         try:
             dataset_id = from_global_id_with_expected_type(
-                GlobalID.from_id(
- …
+                GlobalID.from_id(id),
+                DATASET_NODE_NAME,
             )
         except ValueError:
- …
-                status_code=HTTP_422_UNPROCESSABLE_ENTITY,
+            raise HTTPException(
+                detail=f"Invalid Dataset ID: {id}", status_code=HTTP_422_UNPROCESSABLE_ENTITY
             )
     else:
- …
-            content="Missing Dataset ID",
-            status_code=HTTP_422_UNPROCESSABLE_ENTITY,
-        )
+        raise HTTPException(detail="Missing Dataset ID", status_code=HTTP_422_UNPROCESSABLE_ENTITY)
     project_names_stmt = get_project_names_for_datasets(dataset_id)
     eval_trace_ids_stmt = get_eval_trace_ids_for_datasets(dataset_id)
     stmt = (
@@ -214,59 +187,34 @@ async def delete_dataset_by_id(request: Request) -> Response:
         project_names = await session.scalars(project_names_stmt)
         eval_trace_ids = await session.scalars(eval_trace_ids_stmt)
         if (await session.scalar(stmt)) is None:
- …
+            raise HTTPException(detail="Dataset does not exist", status_code=HTTP_404_NOT_FOUND)
     tasks = BackgroundTasks()
     tasks.add_task(delete_projects, request.app.state.db, *project_names)
     tasks.add_task(delete_traces, request.app.state.db, *eval_trace_ids)
- …
-            type: string
-            name:
-            type: string
-            description:
-            type: string
-            metadata:
-            type: object
-            created_at:
-            type: string
-            format: date-time
-            updated_at:
-            type: string
-            format: date-time
-            example_count:
-            type: integer
-            403:
-            description: Forbidden
-            404:
-            description: Dataset not found
-    """
-    dataset_id = GlobalID.from_id(request.path_params["id"])
-
-    if (type_name := dataset_id.type_name) != NODE_NAME:
-        return Response(
-            content=f"ID {dataset_id} refers to a f{type_name}", status_code=HTTP_404_NOT_FOUND
+
+
+class DatasetWithExampleCount(Dataset):
+    example_count: int
+
+
+class GetDatasetResponseBody(ResponseBody[DatasetWithExampleCount]):
+    pass
+
+
+@router.get(
+    "/datasets/{id}",
+    operation_id="getDataset",
+    summary="Get dataset by ID",
+    responses=add_errors_to_responses([HTTP_404_NOT_FOUND]),
+)
+async def get_dataset(
+    request: Request, id: str = Path(description="The ID of the dataset")
+) -> GetDatasetResponseBody:
+    dataset_id = GlobalID.from_id(id)
+
+    if (type_name := dataset_id.type_name) != DATASET_NODE_NAME:
+        raise HTTPException(
+            detail=f"ID {dataset_id} refers to a f{type_name}", status_code=HTTP_404_NOT_FOUND
         )
     async with request.app.state.db() as session:
         result = await session.execute(
@@ -278,97 +226,64 @@ async def get_dataset_by_id(request: Request) -> Response:
         dataset = dataset_query[0] if dataset_query else None
         example_count = dataset_query[1] if dataset_query else 0
         if dataset is None:
- …
+            raise HTTPException(
+                detail=f"Dataset with ID {dataset_id} not found", status_code=HTTP_404_NOT_FOUND
             )
 
- …
-        return …
- …
-            schema:
-            type: object
-            properties:
-            next_cursor:
-            type: string
-            data:
-            type: array
-            items:
-            type: object
-            properties:
-            version_id:
-            type: string
-            description:
-            type: string
-            metadata:
-            type: object
-            created_at:
-            type: string
-            format: date-time
-            403:
-            description: Forbidden
-            422:
-            description: Dataset ID, cursor or limit is invalid.
-    """
-    if id_ := request.path_params.get("id"):
+        dataset = DatasetWithExampleCount(
+            id=str(dataset_id),
+            name=dataset.name,
+            description=dataset.description,
+            metadata=dataset.metadata_,
+            created_at=dataset.created_at,
+            updated_at=dataset.updated_at,
+            example_count=example_count,
+        )
+        return GetDatasetResponseBody(data=dataset)
+
+
+class DatasetVersion(V1RoutesBaseModel):
+    version_id: str
+    description: Optional[str]
+    metadata: Dict[str, Any]
+    created_at: datetime
+
+
+class ListDatasetVersionsResponseBody(PaginatedResponseBody[DatasetVersion]):
+    pass
+
+
+@router.get(
+    "/datasets/{id}/versions",
+    operation_id="listDatasetVersionsByDatasetId",
+    summary="List dataset versions",
+    responses=add_errors_to_responses([HTTP_422_UNPROCESSABLE_ENTITY]),
+)
+async def list_dataset_versions(
+    request: Request,
+    id: str = Path(description="The ID of the dataset"),
+    cursor: Optional[str] = Query(
+        default=None,
+        description="Cursor for pagination",
+    ),
+    limit: int = Query(
+        default=10, description="The max number of dataset versions to return at a time", gt=0
+    ),
+) -> ListDatasetVersionsResponseBody:
+    if id:
         try:
             dataset_id = from_global_id_with_expected_type(
-                GlobalID.from_id(
- …
+                GlobalID.from_id(id),
+                DATASET_NODE_NAME,
             )
         except ValueError:
- …
+            raise HTTPException(
+                detail=f"Invalid Dataset ID: {id}",
                 status_code=HTTP_422_UNPROCESSABLE_ENTITY,
             )
     else:
- …
-            status_code=HTTP_422_UNPROCESSABLE_ENTITY,
-        )
-    try:
-        limit = int(request.query_params.get("limit", 10))
-        assert limit > 0
-    except (ValueError, AssertionError):
-        return Response(
-            content="Invalid limit parameter",
+        raise HTTPException(
+            detail="Missing Dataset ID",
             status_code=HTTP_422_UNPROCESSABLE_ENTITY,
         )
     stmt = (
@@ -377,15 +292,14 @@ async def get_dataset_versions(request: Request) -> Response:
         .order_by(models.DatasetVersion.id.desc())
         .limit(limit + 1)
     )
-    if cursor …
+    if cursor:
        try:
            dataset_version_id = from_global_id_with_expected_type(
-                GlobalID.from_id(cursor),
-                DatasetVersion.__name__,
+                GlobalID.from_id(cursor), DATASET_VERSION_NODE_NAME
            )
        except ValueError:
- …
+            raise HTTPException(
+                detail=f"Invalid cursor: {cursor}",
                status_code=HTTP_422_UNPROCESSABLE_ENTITY,
            )
        max_dataset_version_id = (
@@ -396,102 +310,99 @@ async def get_dataset_versions(request: Request) -> Response:
        stmt = stmt.filter(models.DatasetVersion.id <= max_dataset_version_id)
    async with request.app.state.db() as session:
        data = [
- …
+            DatasetVersion(
+                version_id=str(GlobalID(DATASET_VERSION_NODE_NAME, str(version.id))),
+                description=version.description,
+                metadata=version.metadata_,
+                created_at=version.created_at,
+            )
            async for version in await session.stream_scalars(stmt)
        ]
-        next_cursor = data.pop()
-        return …
- …
-            description
- …
-            422:
-            description: Request body is invalid
-    """
+        next_cursor = data.pop().version_id if len(data) == limit + 1 else None
+        return ListDatasetVersionsResponseBody(data=data, next_cursor=next_cursor)
+
+
+class UploadDatasetData(V1RoutesBaseModel):
+    dataset_id: str
+
+
+class UploadDatasetResponseBody(ResponseBody[UploadDatasetData]):
+    pass
+
+
+@router.post(
+    "/datasets/upload",
+    operation_id="uploadDataset",
+    summary="Upload dataset from JSON, CSV, or PyArrow",
+    responses=add_errors_to_responses(
+        [
+            {
+                "status_code": HTTP_409_CONFLICT,
+                "description": "Dataset of the same name already exists",
+            },
+            {"status_code": HTTP_422_UNPROCESSABLE_ENTITY, "description": "Invalid request body"},
+        ]
+    ),
+    # FastAPI cannot generate the request body portion of the OpenAPI schema for
+    # routes that accept multiple request content types, so we have to provide
+    # this part of the schema manually. For context, see
+    # https://github.com/tiangolo/fastapi/discussions/7786 and
+    # https://github.com/tiangolo/fastapi/issues/990
+    openapi_extra={
+        "requestBody": {
+            "content": {
+                "application/json": {
+                    "schema": {
+                        "type": "object",
+                        "required": ["name", "inputs"],
+                        "properties": {
+                            "action": {"type": "string", "enum": ["create", "append"]},
+                            "name": {"type": "string"},
+                            "description": {"type": "string"},
+                            "inputs": {"type": "array", "items": {"type": "object"}},
+                            "outputs": {"type": "array", "items": {"type": "object"}},
+                            "metadata": {"type": "array", "items": {"type": "object"}},
+                        },
+                    }
+                },
+                "multipart/form-data": {
+                    "schema": {
+                        "type": "object",
+                        "required": ["name", "input_keys[]", "output_keys[]", "file"],
+                        "properties": {
+                            "action": {"type": "string", "enum": ["create", "append"]},
+                            "name": {"type": "string"},
+                            "description": {"type": "string"},
+                            "input_keys[]": {
+                                "type": "array",
+                                "items": {"type": "string"},
+                                "uniqueItems": True,
+                            },
+                            "output_keys[]": {
+                                "type": "array",
+                                "items": {"type": "string"},
+                                "uniqueItems": True,
+                            },
+                            "metadata_keys[]": {
+                                "type": "array",
+                                "items": {"type": "string"},
+                                "uniqueItems": True,
+                            },
+                            "file": {"type": "string", "format": "binary"},
+                        },
+                    }
+                },
+            }
+        },
+    },
+)
+async def upload_dataset(
+    request: Request,
+    sync: bool = Query(
+        default=False,
+        description="If true, fulfill request synchronously and return JSON containing dataset_id.",
+    ),
+) -> Optional[UploadDatasetResponseBody]:
    request_content_type = request.headers["content-type"]
    examples: Union[Examples, Awaitable[Examples]]
    if request_content_type.startswith("application/json"):
@@ -500,15 +411,15 @@ async def post_datasets_upload(request: Request) -> Response:
                _process_json, await request.json()
            )
        except ValueError as e:
- …
+            raise HTTPException(
+                detail=str(e),
                status_code=HTTP_422_UNPROCESSABLE_ENTITY,
            )
        if action is DatasetAction.CREATE:
            async with request.app.state.db() as session:
                if await _check_table_exists(session, name):
- …
+                    raise HTTPException(
+                        detail=f"Dataset with the same name already exists: {name=}",
                        status_code=HTTP_409_CONFLICT,
                    )
    elif request_content_type.startswith("multipart/form-data"):
@@ -524,15 +435,15 @@ async def post_datasets_upload(request: Request) -> Response:
                file,
            ) = await _parse_form_data(form)
        except ValueError as e:
- …
+            raise HTTPException(
+                detail=str(e),
                status_code=HTTP_422_UNPROCESSABLE_ENTITY,
            )
        if action is DatasetAction.CREATE:
            async with request.app.state.db() as session:
                if await _check_table_exists(session, name):
- …
+                    raise HTTPException(
+                        detail=f"Dataset with the same name already exists: {name=}",
                        status_code=HTTP_409_CONFLICT,
                    )
        content = await file.read()
@@ -548,13 +459,13 @@ async def post_datasets_upload(request: Request) -> Response:
            else:
                assert_never(file_content_type)
        except ValueError as e:
- …
+            raise HTTPException(
+                detail=str(e),
                status_code=HTTP_422_UNPROCESSABLE_ENTITY,
            )
    else:
- …
+        raise HTTPException(
+            detail="Invalid request Content-Type",
            status_code=HTTP_422_UNPROCESSABLE_ENTITY,
        )
    operation = cast(
@@ -567,19 +478,19 @@ async def post_datasets_upload(request: Request) -> Response:
            description=description,
        ),
    )
-    if …
+    if sync:
        async with request.app.state.db() as session:
            dataset_id = (await operation(session)).dataset_id
-            return …
+            return UploadDatasetResponseBody(
+                data=UploadDatasetData(dataset_id=str(GlobalID(Dataset.__name__, str(dataset_id))))
            )
    try:
        request.state.enqueue_operation(operation)
    except QueueFull:
        if isinstance(examples, Coroutine):
            examples.close()
- …
-        return …
+        raise HTTPException(detail="Too many requests.", status_code=HTTP_429_TOO_MANY_REQUESTS)
+    return None
 
 
 class FileContentType(Enum):
@@ -757,158 +668,255 @@ async def _parse_form_data(
    )
 
 
- …
+class DatasetExample(V1RoutesBaseModel):
+    id: str
+    input: Dict[str, Any]
+    output: Dict[str, Any]
+    metadata: Dict[str, Any]
+    updated_at: datetime
+
+
+class ListDatasetExamplesData(V1RoutesBaseModel):
+    dataset_id: str
+    version_id: str
+    examples: List[DatasetExample]
+
+
+class ListDatasetExamplesResponseBody(ResponseBody[ListDatasetExamplesData]):
+    pass
+
+
+@router.get(
+    "/datasets/{id}/examples",
+    operation_id="getDatasetExamples",
+    summary="Get examples from a dataset",
+    responses=add_errors_to_responses([HTTP_404_NOT_FOUND]),
+)
+async def get_dataset_examples(
+    request: Request,
+    id: str = Path(description="The ID of the dataset"),
+    version_id: Optional[str] = Query(
+        default=None,
+        description=(
+            "The ID of the dataset version " "(if omitted, returns data from the latest version)"
+        ),
+    ),
+) -> ListDatasetExamplesResponseBody:
+    dataset_gid = GlobalID.from_id(id)
+    version_gid = GlobalID.from_id(version_id) if version_id else None
+
+    if (dataset_type := dataset_gid.type_name) != "Dataset":
+        raise HTTPException(
+            detail=f"ID {dataset_gid} refers to a {dataset_type}", status_code=HTTP_404_NOT_FOUND
+        )
+
+    if version_gid and (version_type := version_gid.type_name) != "DatasetVersion":
+        raise HTTPException(
+            detail=f"ID {version_gid} refers to a {version_type}", status_code=HTTP_404_NOT_FOUND
+        )
+
+    async with request.app.state.db() as session:
+        if (
+            resolved_dataset_id := await session.scalar(
+                select(models.Dataset.id).where(models.Dataset.id == int(dataset_gid.node_id))
+            )
+        ) is None:
+            raise HTTPException(
+                detail=f"No dataset with id {dataset_gid} can be found.",
+                status_code=HTTP_404_NOT_FOUND,
+            )
+
+        # Subquery to find the maximum created_at for each dataset_example_id
+        # timestamp tiebreaks are resolved by the largest id
+        partial_subquery = select(
+            func.max(models.DatasetExampleRevision.id).label("max_id"),
+        ).group_by(models.DatasetExampleRevision.dataset_example_id)
+
+        if version_gid:
+            if (
+                resolved_version_id := await session.scalar(
+                    select(models.DatasetVersion.id).where(
+                        and_(
+                            models.DatasetVersion.dataset_id == resolved_dataset_id,
+                            models.DatasetVersion.id == int(version_gid.node_id),
+                        )
+                    )
+                )
+            ) is None:
+                raise HTTPException(
+                    detail=f"No dataset version with id {version_id} can be found.",
+                    status_code=HTTP_404_NOT_FOUND,
+                )
+            # if a version_id is provided, filter the subquery to only include revisions from that
+            partial_subquery = partial_subquery.filter(
+                models.DatasetExampleRevision.dataset_version_id <= resolved_version_id
+            )
+        else:
+            if (
+                resolved_version_id := await session.scalar(
+                    select(func.max(models.DatasetVersion.id)).where(
+                        models.DatasetVersion.dataset_id == resolved_dataset_id
+                    )
+                )
+            ) is None:
+                raise HTTPException(
+                    detail="Dataset has no versions.",
+                    status_code=HTTP_404_NOT_FOUND,
+                )
+
+        subquery = partial_subquery.subquery()
+        # Query for the most recent example revisions that are not deleted
+        query = (
+            select(models.DatasetExample, models.DatasetExampleRevision)
+            .join(
+                models.DatasetExampleRevision,
+                models.DatasetExample.id == models.DatasetExampleRevision.dataset_example_id,
+            )
+            .join(
+                subquery,
+                (subquery.c.max_id == models.DatasetExampleRevision.id),
+            )
+            .filter(models.DatasetExample.dataset_id == resolved_dataset_id)
+            .filter(models.DatasetExampleRevision.revision_kind != "DELETE")
+            .order_by(models.DatasetExample.id.asc())
+        )
+        examples = [
+            DatasetExample(
+                id=str(GlobalID("DatasetExample", str(example.id))),
+                input=revision.input,
+                output=revision.output,
+                metadata=revision.metadata_,
+                updated_at=revision.created_at,
+            )
+            async for example, revision in await session.stream(query)
+        ]
+        return ListDatasetExamplesResponseBody(
+            data=ListDatasetExamplesData(
+                dataset_id=str(GlobalID("Dataset", str(resolved_dataset_id))),
+                version_id=str(GlobalID("DatasetVersion", str(resolved_version_id))),
+                examples=examples,
+            )
+        )
+
+
+@router.get(
+    "/datasets/{id}/csv",
+    operation_id="getDatasetCsv",
+    summary="Download dataset examples as CSV file",
+    response_class=StreamingResponse,
+    status_code=HTTP_200_OK,
+    responses={
+        **add_errors_to_responses([HTTP_422_UNPROCESSABLE_ENTITY]),
+        **add_text_csv_content_to_responses(HTTP_200_OK),
+    },
+)
+async def get_dataset_csv(
+    request: Request,
+    response: Response,
+    id: str = Path(description="The ID of the dataset"),
+    version_id: Optional[str] = Query(
+        default=None,
+        description=(
+            "The ID of the dataset version " "(if omitted, returns data from the latest version)"
+        ),
+    ),
+) -> Response:
    try:
- …
+        async with request.app.state.db() as session:
+            dataset_name, examples = await _get_db_examples(
+                session=session, id=id, version_id=version_id
+            )
    except ValueError as e:
- …
+        raise HTTPException(detail=str(e), status_code=HTTP_422_UNPROCESSABLE_ENTITY)
    content = await run_in_threadpool(_get_content_csv, examples)
    return Response(
        content=content,
        headers={
            "content-disposition": f'attachment; filename="{dataset_name}.csv"',
            "content-type": "text/csv",
-            "content-encoding": "gzip",
        },
    )
 
 
- …
-    ""
- …
-            description
- …
-            contentMediaType: text/plain
-            contentEncoding: gzip
-            403:
-            description: Forbidden
-            404:
-            description: Dataset does not exist.
-            422:
-            description: Dataset ID or version ID is invalid.
-    """
+@router.get(
+    "/datasets/{id}/jsonl/openai_ft",
+    operation_id="getDatasetJSONLOpenAIFineTuning",
+    summary="Download dataset examples as OpenAI fine-tuning JSONL file",
+    response_class=PlainTextResponse,
+    responses=add_errors_to_responses(
+        [
+            {
+                "status_code": HTTP_422_UNPROCESSABLE_ENTITY,
+                "description": "Invalid dataset or version ID",
+            }
+        ]
+    ),
+)
+async def get_dataset_jsonl_openai_ft(
+    request: Request,
+    response: Response,
+    id: str = Path(description="The ID of the dataset"),
+    version_id: Optional[str] = Query(
+        default=None,
+        description=(
+            "The ID of the dataset version " "(if omitted, returns data from the latest version)"
+        ),
+    ),
+) -> bytes:
    try:
- …
+        async with request.app.state.db() as session:
+            dataset_name, examples = await _get_db_examples(
+                session=session, id=id, version_id=version_id
+            )
    except ValueError as e:
- …
+        raise HTTPException(detail=str(e), status_code=HTTP_422_UNPROCESSABLE_ENTITY)
    content = await run_in_threadpool(_get_content_jsonl_openai_ft, examples)
- …
-        headers={
-            "content-disposition": f'attachment; filename="{dataset_name}.jsonl"',
-            "content-type": "text/plain",
-            "content-encoding": "gzip",
-        },
-    )
+    response.headers["content-disposition"] = f'attachment; filename="{dataset_name}.jsonl"'
+    return content
 
 
- …
-    ""
- …
-            description
- …
-            contentMediaType: text/plain
-            contentEncoding: gzip
-            403:
-            description: Forbidden
-            404:
-            description: Dataset does not exist.
-            422:
-            description: Dataset ID or version ID is invalid.
-    """
+@router.get(
+    "/datasets/{id}/jsonl/openai_evals",
+    operation_id="getDatasetJSONLOpenAIEvals",
+    summary="Download dataset examples as OpenAI evals JSONL file",
+    response_class=PlainTextResponse,
+    responses=add_errors_to_responses(
+        [
+            {
+                "status_code": HTTP_422_UNPROCESSABLE_ENTITY,
+                "description": "Invalid dataset or version ID",
+            }
+        ]
+    ),
+)
+async def get_dataset_jsonl_openai_evals(
+    request: Request,
+    response: Response,
+    id: str = Path(description="The ID of the dataset"),
+    version_id: Optional[str] = Query(
+        default=None,
+        description=(
+            "The ID of the dataset version " "(if omitted, returns data from the latest version)"
+        ),
+    ),
+) -> bytes:
    try:
- …
+        async with request.app.state.db() as session:
+            dataset_name, examples = await _get_db_examples(
+                session=session, id=id, version_id=version_id
+            )
    except ValueError as e:
- …
+        raise HTTPException(detail=str(e), status_code=HTTP_422_UNPROCESSABLE_ENTITY)
    content = await run_in_threadpool(_get_content_jsonl_openai_evals, examples)
- …
-        headers={
-            "content-disposition": f'attachment; filename="{dataset_name}.jsonl"',
-            "content-type": "text/plain",
-            "content-encoding": "gzip",
-        },
-    )
+    response.headers["content-disposition"] = f'attachment; filename="{dataset_name}.jsonl"'
+    return content
 
 
 def _get_content_csv(examples: List[models.DatasetExampleRevision]) -> bytes:
    records = [
        {
            "example_id": GlobalID(
-                type_name=…
+                type_name=DatasetExampleNodeType.__name__,
                node_id=str(ex.dataset_example_id),
            ),
            **{f"input_{k}": v for k, v in ex.input.items()},
@@ -917,7 +925,7 @@ def _get_content_csv(examples: List[models.DatasetExampleRevision]) -> bytes:
        }
        for ex in examples
    ]
-    return …
+    return str(pd.DataFrame.from_records(records).to_csv(index=False)).encode()
 
 
 def _get_content_jsonl_openai_ft(examples: List[models.DatasetExampleRevision]) -> bytes:
@@ -938,7 +946,7 @@ def _get_content_jsonl_openai_ft(examples: List[models.DatasetExampleRevision])
            ).encode()
        )
    records.seek(0)
-    return …
+    return records.read()
 
 
 def _get_content_jsonl_openai_evals(examples: List[models.DatasetExampleRevision]) -> bytes:
@@ -965,18 +973,17 @@ def _get_content_jsonl_openai_evals(examples: List[models.DatasetExampleRevision
            ).encode()
        )
    records.seek(0)
-    return …
+    return records.read()
 
 
-async def _get_db_examples(
- …
-    dataset_id = from_global_id_with_expected_type(GlobalID.from_id(…
+async def _get_db_examples(
+    *, session: Any, id: str, version_id: Optional[str]
+) -> Tuple[str, List[models.DatasetExampleRevision]]:
+    dataset_id = from_global_id_with_expected_type(GlobalID.from_id(id), DATASET_NODE_NAME)
    dataset_version_id: Optional[int] = None
-    if version_id …
+    if version_id:
        dataset_version_id = from_global_id_with_expected_type(
-            GlobalID.from_id(version_id),
-            DatasetVersion.__name__,
+            GlobalID.from_id(version_id), DATASET_VERSION_NODE_NAME
        )
    latest_version = (
        select(
@@ -1009,13 +1016,12 @@ async def _get_db_examples(request: Request) -> Tuple[str, List[models.DatasetEx
        .where(models.DatasetExampleRevision.revision_kind != "DELETE")
        .order_by(models.DatasetExampleRevision.dataset_example_id)
    )
- …
-    examples = [r async for r in await session.stream_scalars(stmt)]
+    dataset_name: Optional[str] = await session.scalar(
+        select(models.Dataset.name).where(models.Dataset.id == dataset_id)
+    )
+    if not dataset_name:
+        raise ValueError("Dataset does not exist.")
+    examples = [r async for r in await session.stream_scalars(stmt)]
    return dataset_name, examples
 
 
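
To make the shape of the migrated API concrete, here is a minimal client sketch (not part of the diff above) that pages through the new list-datasets route using the data and next_cursor fields of ListDatasetsResponseBody shown in the diff. The base URL, the /v1 mount prefix, and the use of httpx are assumptions about a typical Phoenix deployment, not something this diff specifies.

import httpx  # assumed HTTP client; any client that can issue GET requests works


def list_all_datasets(base_url: str = "http://localhost:6006") -> list:
    """Collect every dataset by following next_cursor until pagination is exhausted."""
    datasets, cursor = [], None
    with httpx.Client(base_url=base_url) as client:
        while True:
            params = {"limit": 50}
            if cursor is not None:
                params["cursor"] = cursor
            # "/v1" prefix is an assumption about where the router above is mounted
            resp = client.get("/v1/datasets", params=params)
            resp.raise_for_status()
            # Body shape per ListDatasetsResponseBody: {"data": [...], "next_cursor": ...}
            body = resp.json()
            datasets.extend(body["data"])
            cursor = body.get("next_cursor")
            if not cursor:
                return datasets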