arize-phoenix 4.4.2__py3-none-any.whl → 4.4.4rc0__py3-none-any.whl
This diff compares two publicly released versions of the package as published to one of the supported registries. It is provided for informational purposes only and reflects the package contents exactly as they appear in the public registry.
- {arize_phoenix-4.4.2.dist-info → arize_phoenix-4.4.4rc0.dist-info}/METADATA +12 -11
- {arize_phoenix-4.4.2.dist-info → arize_phoenix-4.4.4rc0.dist-info}/RECORD +110 -57
- phoenix/__init__.py +0 -27
- phoenix/config.py +21 -7
- phoenix/core/model.py +25 -25
- phoenix/core/model_schema.py +66 -64
- phoenix/core/model_schema_adapter.py +27 -25
- phoenix/datasets/__init__.py +0 -0
- phoenix/datasets/evaluators.py +275 -0
- phoenix/datasets/experiments.py +469 -0
- phoenix/datasets/tracing.py +66 -0
- phoenix/datasets/types.py +212 -0
- phoenix/db/bulk_inserter.py +54 -14
- phoenix/db/insertion/dataset.py +234 -0
- phoenix/db/insertion/evaluation.py +6 -6
- phoenix/db/insertion/helpers.py +13 -2
- phoenix/db/migrations/types.py +29 -0
- phoenix/db/migrations/versions/10460e46d750_datasets.py +291 -0
- phoenix/db/migrations/versions/cf03bd6bae1d_init.py +2 -28
- phoenix/db/models.py +230 -3
- phoenix/inferences/fixtures.py +23 -23
- phoenix/inferences/inferences.py +7 -7
- phoenix/inferences/validation.py +1 -1
- phoenix/metrics/binning.py +2 -2
- phoenix/server/api/context.py +16 -0
- phoenix/server/api/dataloaders/__init__.py +16 -0
- phoenix/server/api/dataloaders/dataset_example_revisions.py +100 -0
- phoenix/server/api/dataloaders/dataset_example_spans.py +43 -0
- phoenix/server/api/dataloaders/experiment_annotation_summaries.py +85 -0
- phoenix/server/api/dataloaders/experiment_error_rates.py +43 -0
- phoenix/server/api/dataloaders/experiment_sequence_number.py +49 -0
- phoenix/server/api/dataloaders/project_by_name.py +31 -0
- phoenix/server/api/dataloaders/span_descendants.py +2 -3
- phoenix/server/api/dataloaders/span_projects.py +33 -0
- phoenix/server/api/dataloaders/trace_row_ids.py +39 -0
- phoenix/server/api/helpers/dataset_helpers.py +178 -0
- phoenix/server/api/input_types/AddExamplesToDatasetInput.py +16 -0
- phoenix/server/api/input_types/AddSpansToDatasetInput.py +14 -0
- phoenix/server/api/input_types/CreateDatasetInput.py +12 -0
- phoenix/server/api/input_types/DatasetExampleInput.py +14 -0
- phoenix/server/api/input_types/DatasetSort.py +17 -0
- phoenix/server/api/input_types/DatasetVersionSort.py +16 -0
- phoenix/server/api/input_types/DeleteDatasetExamplesInput.py +13 -0
- phoenix/server/api/input_types/DeleteDatasetInput.py +7 -0
- phoenix/server/api/input_types/DeleteExperimentsInput.py +9 -0
- phoenix/server/api/input_types/PatchDatasetExamplesInput.py +35 -0
- phoenix/server/api/input_types/PatchDatasetInput.py +14 -0
- phoenix/server/api/mutations/__init__.py +13 -0
- phoenix/server/api/mutations/auth.py +11 -0
- phoenix/server/api/mutations/dataset_mutations.py +520 -0
- phoenix/server/api/mutations/experiment_mutations.py +65 -0
- phoenix/server/api/{types/ExportEventsMutation.py → mutations/export_events_mutations.py} +17 -14
- phoenix/server/api/mutations/project_mutations.py +42 -0
- phoenix/server/api/queries.py +503 -0
- phoenix/server/api/routers/v1/__init__.py +77 -2
- phoenix/server/api/routers/v1/dataset_examples.py +178 -0
- phoenix/server/api/routers/v1/datasets.py +861 -0
- phoenix/server/api/routers/v1/evaluations.py +4 -2
- phoenix/server/api/routers/v1/experiment_evaluations.py +65 -0
- phoenix/server/api/routers/v1/experiment_runs.py +108 -0
- phoenix/server/api/routers/v1/experiments.py +174 -0
- phoenix/server/api/routers/v1/spans.py +3 -1
- phoenix/server/api/routers/v1/traces.py +1 -4
- phoenix/server/api/schema.py +2 -303
- phoenix/server/api/types/AnnotatorKind.py +10 -0
- phoenix/server/api/types/Cluster.py +19 -19
- phoenix/server/api/types/CreateDatasetPayload.py +8 -0
- phoenix/server/api/types/Dataset.py +282 -63
- phoenix/server/api/types/DatasetExample.py +85 -0
- phoenix/server/api/types/DatasetExampleRevision.py +34 -0
- phoenix/server/api/types/DatasetVersion.py +14 -0
- phoenix/server/api/types/Dimension.py +30 -29
- phoenix/server/api/types/EmbeddingDimension.py +40 -34
- phoenix/server/api/types/Event.py +16 -16
- phoenix/server/api/types/ExampleRevisionInterface.py +14 -0
- phoenix/server/api/types/Experiment.py +135 -0
- phoenix/server/api/types/ExperimentAnnotationSummary.py +13 -0
- phoenix/server/api/types/ExperimentComparison.py +19 -0
- phoenix/server/api/types/ExperimentRun.py +91 -0
- phoenix/server/api/types/ExperimentRunAnnotation.py +57 -0
- phoenix/server/api/types/Inferences.py +80 -0
- phoenix/server/api/types/InferencesRole.py +23 -0
- phoenix/server/api/types/Model.py +43 -42
- phoenix/server/api/types/Project.py +26 -12
- phoenix/server/api/types/Segments.py +1 -1
- phoenix/server/api/types/Span.py +78 -2
- phoenix/server/api/types/TimeSeries.py +6 -6
- phoenix/server/api/types/Trace.py +15 -4
- phoenix/server/api/types/UMAPPoints.py +1 -1
- phoenix/server/api/types/node.py +5 -111
- phoenix/server/api/types/pagination.py +10 -52
- phoenix/server/app.py +99 -49
- phoenix/server/main.py +49 -27
- phoenix/server/openapi/docs.py +3 -0
- phoenix/server/static/index.js +2246 -1368
- phoenix/server/templates/index.html +1 -0
- phoenix/services.py +15 -15
- phoenix/session/client.py +316 -21
- phoenix/session/session.py +47 -37
- phoenix/trace/exporter.py +14 -9
- phoenix/trace/fixtures.py +133 -7
- phoenix/trace/span_evaluations.py +3 -3
- phoenix/trace/trace_dataset.py +6 -6
- phoenix/utilities/json.py +61 -0
- phoenix/utilities/re.py +50 -0
- phoenix/version.py +1 -1
- phoenix/server/api/types/DatasetRole.py +0 -23
- {arize_phoenix-4.4.2.dist-info → arize_phoenix-4.4.4rc0.dist-info}/WHEEL +0 -0
- {arize_phoenix-4.4.2.dist-info → arize_phoenix-4.4.4rc0.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-4.4.2.dist-info → arize_phoenix-4.4.4rc0.dist-info}/licenses/LICENSE +0 -0
- /phoenix/server/api/{helpers.py → helpers/__init__.py} +0 -0
phoenix/server/api/routers/v1/datasets.py (new file, +861 lines)

@@ -0,0 +1,861 @@
import csv
import gzip
import io
import json
import logging
import zlib
from asyncio import QueueFull
from collections import Counter
from enum import Enum
from functools import partial
from typing import (
    Any,
    Awaitable,
    Callable,
    Coroutine,
    Dict,
    FrozenSet,
    Iterator,
    List,
    Optional,
    Tuple,
    Union,
    cast,
)

import pandas as pd
import pyarrow as pa
from sqlalchemy import and_, func, select
from sqlalchemy.ext.asyncio import AsyncSession
from starlette.concurrency import run_in_threadpool
from starlette.datastructures import FormData, UploadFile
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
from starlette.status import (
    HTTP_403_FORBIDDEN,
    HTTP_404_NOT_FOUND,
    HTTP_422_UNPROCESSABLE_ENTITY,
    HTTP_429_TOO_MANY_REQUESTS,
)
from strawberry.relay import GlobalID
from typing_extensions import TypeAlias, assert_never

from phoenix.db import models
from phoenix.db.insertion.dataset import (
    DatasetAction,
    DatasetExampleAdditionEvent,
    add_dataset_examples,
)
from phoenix.server.api.types.Dataset import Dataset
from phoenix.server.api.types.DatasetExample import DatasetExample
from phoenix.server.api.types.DatasetVersion import DatasetVersion
from phoenix.server.api.types.node import from_global_id_with_expected_type

logger = logging.getLogger(__name__)

NODE_NAME = "Dataset"


async def list_datasets(request: Request) -> Response:
    """
    summary: List datasets with cursor-based pagination
    operationId: listDatasets
    tags:
      - datasets
    parameters:
      - in: query
        name: cursor
        required: false
        schema:
          type: string
        description: Cursor for pagination
      - in: query
        name: limit
        required: false
        schema:
          type: integer
          default: 10
      - in: query
        name: name
        required: false
        schema:
          type: string
        description: match by dataset name
    responses:
      200:
        description: A paginated list of datasets
        content:
          application/json:
            schema:
              type: object
              properties:
                next_cursor:
                  type: string
                data:
                  type: array
                  items:
                    type: object
                    properties:
                      id:
                        type: string
                      name:
                        type: string
                      description:
                        type: string
                      metadata:
                        type: object
                      created_at:
                        type: string
                        format: date-time
                      updated_at:
                        type: string
                        format: date-time
      403:
        description: Forbidden
      404:
        description: No datasets found
    """
    name = request.query_params.get("name")
    cursor = request.query_params.get("cursor")
    limit = int(request.query_params.get("limit", 10))
    async with request.app.state.db() as session:
        query = select(models.Dataset).order_by(models.Dataset.id.desc())

        if cursor:
            try:
                cursor_id = GlobalID.from_id(cursor).node_id
                query = query.filter(models.Dataset.id <= int(cursor_id))
            except ValueError:
                return Response(
                    content=f"Invalid cursor format: {cursor}",
                    status_code=HTTP_422_UNPROCESSABLE_ENTITY,
                )
        if name:
            query = query.filter(models.Dataset.name.is_(name))

        query = query.limit(limit + 1)
        result = await session.execute(query)
        datasets = result.scalars().all()

        if not datasets:
            return JSONResponse(content={"next_cursor": None, "data": []}, status_code=200)

        next_cursor = None
        if len(datasets) == limit + 1:
            next_cursor = str(GlobalID(NODE_NAME, str(datasets[-1].id)))
            datasets = datasets[:-1]

        data = []
        for dataset in datasets:
            data.append(
                {
                    "id": str(GlobalID(NODE_NAME, str(dataset.id))),
                    "name": dataset.name,
                    "description": dataset.description,
                    "metadata": dataset.metadata_,
                    "created_at": dataset.created_at.isoformat(),
                    "updated_at": dataset.updated_at.isoformat(),
                }
            )

        return JSONResponse(content={"next_cursor": next_cursor, "data": data})


async def get_dataset_by_id(request: Request) -> Response:
    """
    summary: Get dataset by ID
    operationId: getDatasetById
    tags:
      - datasets
    parameters:
      - in: path
        name: id
        required: true
        schema:
          type: string
    responses:
      200:
        description: Success
        content:
          application/json:
            schema:
              type: object
              properties:
                id:
                  type: string
                name:
                  type: string
                description:
                  type: string
                metadata:
                  type: object
                created_at:
                  type: string
                  format: date-time
                updated_at:
                  type: string
                  format: date-time
                example_count:
                  type: integer
      403:
        description: Forbidden
      404:
        description: Dataset not found
    """
    dataset_id = GlobalID.from_id(request.path_params["id"])

    if (type_name := dataset_id.type_name) != NODE_NAME:
        return Response(
            content=f"ID {dataset_id} refers to a f{type_name}", status_code=HTTP_404_NOT_FOUND
        )
    async with request.app.state.db() as session:
        result = await session.execute(
            select(models.Dataset, models.Dataset.example_count).filter(
                models.Dataset.id == int(dataset_id.node_id)
            )
        )
        dataset_query = result.first()
        dataset = dataset_query[0] if dataset_query else None
        example_count = dataset_query[1] if dataset_query else 0
        if dataset is None:
            return Response(
                content=f"Dataset with ID {dataset_id} not found", status_code=HTTP_404_NOT_FOUND
            )

        output_dict = {
            "id": str(dataset_id),
            "name": dataset.name,
            "description": dataset.description,
            "metadata": dataset.metadata_,
            "created_at": dataset.created_at.isoformat(),
            "updated_at": dataset.updated_at.isoformat(),
            "example_count": example_count,
        }
        return JSONResponse(content=output_dict)


async def get_dataset_versions(request: Request) -> Response:
    """
    summary: Get dataset versions (sorted from latest to oldest)
    operationId: getDatasetVersionsByDatasetId
    tags:
      - datasets
    parameters:
      - in: path
        name: id
        required: true
        description: Dataset ID
        schema:
          type: string
      - in: query
        name: cursor
        description: Cursor for pagination.
        schema:
          type: string
      - in: query
        name: limit
        description: Maximum number versions to return.
        schema:
          type: integer
          default: 10
    responses:
      200:
        description: Success
        content:
          application/json:
            schema:
              type: object
              properties:
                next_cursor:
                  type: string
                data:
                  type: array
                  items:
                    type: object
                    properties:
                      version_id:
                        type: string
                      description:
                        type: string
                      metadata:
                        type: object
                      created_at:
                        type: string
                        format: date-time
      403:
        description: Forbidden
      422:
        description: Dataset ID, cursor or limit is invalid.
    """
    if id_ := request.path_params.get("id"):
        try:
            dataset_id = from_global_id_with_expected_type(
                GlobalID.from_id(id_),
                Dataset.__name__,
            )
        except ValueError:
            return Response(
                content=f"Invalid Dataset ID: {id_}",
                status_code=HTTP_422_UNPROCESSABLE_ENTITY,
            )
    else:
        return Response(
            content="Missing Dataset ID",
            status_code=HTTP_422_UNPROCESSABLE_ENTITY,
        )
    try:
        limit = int(request.query_params.get("limit", 10))
        assert limit > 0
    except (ValueError, AssertionError):
        return Response(
            content="Invalid limit parameter",
            status_code=HTTP_422_UNPROCESSABLE_ENTITY,
        )
    stmt = (
        select(models.DatasetVersion)
        .where(models.DatasetVersion.dataset_id == dataset_id)
        .order_by(models.DatasetVersion.id.desc())
        .limit(limit + 1)
    )
    if cursor := request.query_params.get("cursor"):
        try:
            dataset_version_id = from_global_id_with_expected_type(
                GlobalID.from_id(cursor),
                DatasetVersion.__name__,
            )
        except ValueError:
            return Response(
                content=f"Invalid cursor: {cursor}",
                status_code=HTTP_422_UNPROCESSABLE_ENTITY,
            )
        max_dataset_version_id = (
            select(models.DatasetVersion.id)
            .where(models.DatasetVersion.id == dataset_version_id)
            .where(models.DatasetVersion.dataset_id == dataset_id)
        ).scalar_subquery()
        stmt = stmt.filter(models.DatasetVersion.id <= max_dataset_version_id)
    async with request.app.state.db() as session:
        data = [
            {
                "version_id": str(GlobalID(DatasetVersion.__name__, str(version.id))),
                "description": version.description,
                "metadata": version.metadata_,
                "created_at": version.created_at.isoformat(),
            }
            async for version in await session.stream_scalars(stmt)
        ]
    next_cursor = data.pop()["version_id"] if len(data) == limit + 1 else None
    return JSONResponse(content={"next_cursor": next_cursor, "data": data})


async def post_datasets_upload(request: Request) -> Response:
    """
    summary: Upload CSV or PyArrow file as dataset
    operationId: uploadDataset
    tags:
      - datasets
    parameters:
      - in: query
        name: sync
        description: If true, fulfill request synchronously and return JSON containing dataset_id
        schema:
          type: boolean
    requestBody:
      content:
        multipart/form-data:
          schema:
            type: object
            required:
              - name
              - input_keys[]
              - output_keys[]
              - file
            properties:
              action:
                type: string
                enum: [create, append]
              name:
                type: string
              description:
                type: string
              input_keys[]:
                type: array
                items:
                  type: string
                uniqueItems: true
              output_keys[]:
                type: array
                items:
                  type: string
                uniqueItems: true
              metadata_keys[]:
                type: array
                items:
                  type: string
                uniqueItems: true
              file:
                type: string
                format: binary
    responses:
      200:
        description: Success
      403:
        description: Forbidden
      422:
        description: Request body is invalid
    """
    if request.app.state.read_only:
        return Response(status_code=HTTP_403_FORBIDDEN)
    async with request.form() as form:
        try:
            (
                action,
                name,
                description,
                input_keys,
                output_keys,
                metadata_keys,
                file,
            ) = await _parse_form_data(form)
        except ValueError as e:
            return Response(
                content=str(e),
                status_code=HTTP_422_UNPROCESSABLE_ENTITY,
            )
        if action is DatasetAction.CREATE:
            async with request.app.state.db() as session:
                if await _check_table_exists(session, name):
                    return Response(
                        content=f"Dataset already exists: {name=}",
                        status_code=HTTP_422_UNPROCESSABLE_ENTITY,
                    )
        content = await file.read()
    try:
        examples: Union[Examples, Awaitable[Examples]]
        content_type = FileContentType(file.content_type)
        if content_type is FileContentType.CSV:
            encoding = FileContentEncoding(file.headers.get("content-encoding"))
            examples, column_headers = await _process_csv(content, encoding)
        elif content_type is FileContentType.PYARROW:
            examples, column_headers = await _process_pyarrow(content)
        else:
            assert_never(content_type)
        _check_keys_exist(column_headers, input_keys, output_keys, metadata_keys)
    except ValueError as e:
        return Response(
            content=str(e),
            status_code=HTTP_422_UNPROCESSABLE_ENTITY,
        )
    operation = cast(
        Callable[[AsyncSession], Awaitable[DatasetExampleAdditionEvent]],
        partial(
            add_dataset_examples,
            examples=examples,
            action=action,
            name=name,
            description=description,
            input_keys=input_keys,
            output_keys=output_keys,
            metadata_keys=metadata_keys,
        ),
    )
    if request.query_params.get("sync") == "true":
        async with request.app.state.db() as session:
            dataset_id = (await operation(session)).dataset_id
        return JSONResponse(
            content={"data": {"dataset_id": str(GlobalID(Dataset.__name__, str(dataset_id)))}}
        )
    try:
        request.state.enqueue_operation(operation)
    except QueueFull:
        if isinstance(examples, Coroutine):
            examples.close()
        return Response(status_code=HTTP_429_TOO_MANY_REQUESTS)
    return Response()


class FileContentType(Enum):
    CSV = "text/csv"
    PYARROW = "application/x-pandas-pyarrow"

    @classmethod
    def _missing_(cls, v: Any) -> "FileContentType":
        if isinstance(v, str) and v and v.isascii() and not v.islower():
            return cls(v.lower())
        raise ValueError(f"Invalid file content type: {v}")


class FileContentEncoding(Enum):
    NONE = "none"
    GZIP = "gzip"
    DEFLATE = "deflate"

    @classmethod
    def _missing_(cls, v: Any) -> "FileContentEncoding":
        if v is None:
            return cls("none")
        if isinstance(v, str) and v and v.isascii() and not v.islower():
            return cls(v.lower())
        raise ValueError(f"Invalid file content encoding: {v}")


Name: TypeAlias = str
Description: TypeAlias = Optional[str]
InputKeys: TypeAlias = FrozenSet[str]
OutputKeys: TypeAlias = FrozenSet[str]
MetadataKeys: TypeAlias = FrozenSet[str]
DatasetId: TypeAlias = int
Examples: TypeAlias = Iterator[Dict[str, Any]]


async def _process_csv(
    content: bytes,
    content_encoding: FileContentEncoding,
) -> Tuple[Examples, FrozenSet[str]]:
    if content_encoding is FileContentEncoding.GZIP:
        content = await run_in_threadpool(gzip.decompress, content)
    elif content_encoding is FileContentEncoding.DEFLATE:
        content = await run_in_threadpool(zlib.decompress, content)
    elif content_encoding is not FileContentEncoding.NONE:
        assert_never(content_encoding)
    reader = await run_in_threadpool(lambda c: csv.DictReader(io.StringIO(c.decode())), content)
    if reader.fieldnames is None:
        raise ValueError("Missing CSV column header")
    (header, freq), *_ = Counter(reader.fieldnames).most_common(1)
    if freq > 1:
        raise ValueError(f"Duplicated column header in CSV file: {header}")
    column_headers = frozenset(reader.fieldnames)
    return reader, column_headers


async def _process_pyarrow(
    content: bytes,
) -> Tuple[Awaitable[Examples], FrozenSet[str]]:
    try:
        reader = pa.ipc.open_stream(content)
    except pa.ArrowInvalid as e:
        raise ValueError("File is not valid pyarrow") from e
    column_headers = frozenset(reader.schema.names)

    def get_examples() -> Iterator[Dict[str, Any]]:
        yield from reader.read_pandas().to_dict(orient="records")

    return run_in_threadpool(get_examples), column_headers


async def _check_table_exists(session: AsyncSession, name: str) -> bool:
    return bool(
        await session.scalar(
            select(1).select_from(models.Dataset).where(models.Dataset.name == name)
        )
    )


def _check_keys_exist(
    column_headers: FrozenSet[str],
    input_keys: InputKeys,
    output_keys: OutputKeys,
    metadata_keys: MetadataKeys,
) -> None:
    for desc, keys in (
        ("input", input_keys),
        ("output", output_keys),
        ("metadata", metadata_keys),
    ):
        if keys and (diff := keys.difference(column_headers)):
            raise ValueError(f"{desc} keys not found in column headers: {diff}")


async def _parse_form_data(
    form: FormData,
) -> Tuple[
    DatasetAction,
    Name,
    Description,
    InputKeys,
    OutputKeys,
    MetadataKeys,
    UploadFile,
]:
    name = cast(Optional[str], form.get("name"))
    if not name:
        raise ValueError("Dataset name must not be empty")
    action = DatasetAction(cast(Optional[str], form.get("action")) or "create")
    file = form["file"]
    if not isinstance(file, UploadFile):
        raise ValueError("Malformed file in form data.")
    description = cast(Optional[str], form.get("description")) or file.filename
    input_keys = frozenset(filter(bool, cast(List[str], form.getlist("input_keys[]"))))
    output_keys = frozenset(filter(bool, cast(List[str], form.getlist("output_keys[]"))))
    metadata_keys = frozenset(filter(bool, cast(List[str], form.getlist("metadata_keys[]"))))
    return (
        action,
        name,
        description,
        input_keys,
        output_keys,
        metadata_keys,
        file,
    )


async def get_dataset_csv(request: Request) -> Response:
    """
    summary: Download dataset examples as CSV text file
    operationId: getDatasetCsv
    tags:
      - datasets
    parameters:
      - in: path
        name: id
        required: true
        schema:
          type: string
        description: Dataset ID
      - in: query
        name: version
        schema:
          type: string
        description: Dataset version ID. If omitted, returns the latest version.
    responses:
      200:
        description: Success
        content:
          text/csv:
            schema:
              type: string
              contentMediaType: text/csv
              contentEncoding: gzip
      403:
        description: Forbidden
      404:
        description: Dataset does not exist.
      422:
        description: Dataset ID or version ID is invalid.
    """
    try:
        dataset_name, examples = await _get_db_examples(request)
    except ValueError as e:
        return Response(content=str(e), status_code=HTTP_422_UNPROCESSABLE_ENTITY)
    content = await run_in_threadpool(_get_content_csv, examples)
    return Response(
        content=content,
        headers={
            "content-disposition": f'attachment; filename="{dataset_name}.csv"',
            "content-type": "text/csv",
            "content-encoding": "gzip",
        },
    )


async def get_dataset_jsonl_openai_ft(request: Request) -> Response:
    """
    summary: Download dataset examples as OpenAI Fine-Tuning JSONL file
    operationId: getDatasetJSONLOpenAIFineTuning
    tags:
      - datasets
    parameters:
      - in: path
        name: id
        required: true
        schema:
          type: string
        description: Dataset ID
      - in: query
        name: version
        schema:
          type: string
        description: Dataset version ID. If omitted, returns the latest version.
    responses:
      200:
        description: Success
        content:
          text/plain:
            schema:
              type: string
              contentMediaType: text/plain
              contentEncoding: gzip
      403:
        description: Forbidden
      404:
        description: Dataset does not exist.
      422:
        description: Dataset ID or version ID is invalid.
    """
    try:
        dataset_name, examples = await _get_db_examples(request)
    except ValueError as e:
        return Response(content=str(e), status_code=HTTP_422_UNPROCESSABLE_ENTITY)
    content = await run_in_threadpool(_get_content_jsonl_openai_ft, examples)
    return Response(
        content=content,
        headers={
            "content-disposition": f'attachment; filename="{dataset_name}.jsonl"',
            "content-type": "text/plain",
            "content-encoding": "gzip",
        },
    )


async def get_dataset_jsonl_openai_evals(request: Request) -> Response:
    """
    summary: Download dataset examples as OpenAI Evals JSONL file
    operationId: getDatasetJSONLOpenAIEvals
    tags:
      - datasets
    parameters:
      - in: path
        name: id
        required: true
        schema:
          type: string
        description: Dataset ID
      - in: query
        name: version
        schema:
          type: string
        description: Dataset version ID. If omitted, returns the latest version.
    responses:
      200:
        description: Success
        content:
          text/plain:
            schema:
              type: string
              contentMediaType: text/plain
              contentEncoding: gzip
      403:
        description: Forbidden
      404:
        description: Dataset does not exist.
      422:
        description: Dataset ID or version ID is invalid.
    """
    try:
        dataset_name, examples = await _get_db_examples(request)
    except ValueError as e:
        return Response(content=str(e), status_code=HTTP_422_UNPROCESSABLE_ENTITY)
    content = await run_in_threadpool(_get_content_jsonl_openai_evals, examples)
    return Response(
        content=content,
        headers={
            "content-disposition": f'attachment; filename="{dataset_name}.jsonl"',
            "content-type": "text/plain",
            "content-encoding": "gzip",
        },
    )


def _get_content_csv(examples: List[models.DatasetExampleRevision]) -> bytes:
    records = [
        {
            "example_id": GlobalID(
                type_name=DatasetExample.__name__,
                node_id=str(ex.dataset_example_id),
            ),
            **{f"input_{k}": v for k, v in ex.input.items()},
            **{f"output_{k}": v for k, v in ex.output.items()},
            **{f"metadata_{k}": v for k, v in ex.metadata_.items()},
        }
        for ex in examples
    ]
    return gzip.compress(pd.DataFrame.from_records(records).to_csv(index=False).encode())


def _get_content_jsonl_openai_ft(examples: List[models.DatasetExampleRevision]) -> bytes:
    records = io.BytesIO()
    for ex in examples:
        records.write(
            (
                json.dumps(
                    {
                        "messages": (
                            ims if isinstance(ims := ex.input.get("messages"), list) else []
                        )
                        + (oms if isinstance(oms := ex.output.get("messages"), list) else [])
                    },
                    ensure_ascii=False,
                )
                + "\n"
            ).encode()
        )
    records.seek(0)
    return gzip.compress(records.read())


def _get_content_jsonl_openai_evals(examples: List[models.DatasetExampleRevision]) -> bytes:
    records = io.BytesIO()
    for ex in examples:
        records.write(
            (
                json.dumps(
                    {
                        "messages": ims
                        if isinstance(ims := ex.input.get("messages"), list)
                        else [],
                        "ideal": (
                            ideal if isinstance(ideal := last_message.get("content"), str) else ""
                        )
                        if isinstance(oms := ex.output.get("messages"), list)
                        and oms
                        and hasattr(last_message := oms[-1], "get")
                        else "",
                    },
                    ensure_ascii=False,
                )
                + "\n"
            ).encode()
        )
    records.seek(0)
    return gzip.compress(records.read())


async def _get_db_examples(request: Request) -> Tuple[str, List[models.DatasetExampleRevision]]:
    if not (id_ := request.path_params.get("id")):
        raise ValueError("Missing Dataset ID")
    dataset_id = from_global_id_with_expected_type(GlobalID.from_id(id_), Dataset.__name__)
    dataset_version_id: Optional[int] = None
    if version := request.query_params.get("version"):
        dataset_version_id = from_global_id_with_expected_type(
            GlobalID.from_id(version),
            DatasetVersion.__name__,
        )
    latest_version = (
        select(
            models.DatasetExampleRevision.dataset_example_id,
            func.max(models.DatasetExampleRevision.dataset_version_id).label("dataset_version_id"),
        )
        .group_by(models.DatasetExampleRevision.dataset_example_id)
        .join(models.DatasetExample)
        .where(models.DatasetExample.dataset_id == dataset_id)
    )
    if dataset_version_id is not None:
        max_dataset_version_id = (
            select(models.DatasetVersion.id)
            .where(models.DatasetVersion.id == dataset_version_id)
            .where(models.DatasetVersion.dataset_id == dataset_id)
        ).scalar_subquery()
        latest_version = latest_version.where(
            models.DatasetExampleRevision.dataset_version_id <= max_dataset_version_id
        )
    subq = latest_version.subquery("latest_version")
    stmt = (
        select(models.DatasetExampleRevision)
        .join(
            subq,
            onclause=and_(
                models.DatasetExampleRevision.dataset_example_id == subq.c.dataset_example_id,
                models.DatasetExampleRevision.dataset_version_id == subq.c.dataset_version_id,
            ),
        )
        .where(models.DatasetExampleRevision.revision_kind != "DELETE")
        .order_by(models.DatasetExampleRevision.dataset_example_id)
    )
    async with request.app.state.db() as session:
        dataset_name: Optional[str] = await session.scalar(
            select(models.Dataset.name).where(models.Dataset.id == dataset_id)
        )
        if not dataset_name:
            raise ValueError("Dataset does not exist.")
        examples = [r async for r in await session.stream_scalars(stmt)]
    return dataset_name, examples
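The handlers above define the multipart form fields (name, action, input_keys[], output_keys[], metadata_keys[], file), the sync query parameter, and the paginated JSON shape, but not the URL prefix at which the v1 router mounts them; that lives in phoenix/server/api/routers/v1/__init__.py, whose diff is not expanded here. Below is a minimal client sketch, assuming a local Phoenix server at http://localhost:6006 and the conventional /v1/datasets paths; treat the routes and base URL as illustrative assumptions, not a documented contract.

    # Hedged sketch: exercising the new dataset upload and listing endpoints.
    # The field names, sync query parameter, and response shapes come from the
    # handlers shown above; the URL paths and port are assumptions.
    import io
    import requests

    base_url = "http://localhost:6006"  # assumed local Phoenix server

    csv_bytes = b"question,answer,topic\nWhat is 2+2?,4,math\n"

    # Upload a CSV as a new dataset; sync=true makes the server insert the
    # examples before responding and return the dataset_id in the JSON body.
    resp = requests.post(
        f"{base_url}/v1/datasets/upload",  # assumed mount path
        params={"sync": "true"},
        data={
            "action": "create",
            "name": "my-dataset",
            "description": "toy example",
            "input_keys[]": ["question"],  # must match CSV column headers
            "output_keys[]": ["answer"],
            "metadata_keys[]": ["topic"],
        },
        files={"file": ("my-dataset.csv", io.BytesIO(csv_bytes), "text/csv")},
    )
    resp.raise_for_status()
    dataset_id = resp.json()["data"]["dataset_id"]

    # Page through datasets using the cursor returned by listDatasets.
    cursor = None
    while True:
        page = requests.get(
            f"{base_url}/v1/datasets",  # assumed mount path
            params={"limit": 10, **({"cursor": cursor} if cursor else {})},
        ).json()
        for d in page["data"]:
            print(d["id"], d["name"])
        if not (cursor := page["next_cursor"]):
            break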