arize-phoenix 4.4.4rc5__py3-none-any.whl → 4.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of arize-phoenix might be problematic.
- {arize_phoenix-4.4.4rc5.dist-info → arize_phoenix-4.5.0.dist-info}/METADATA +5 -5
- {arize_phoenix-4.4.4rc5.dist-info → arize_phoenix-4.5.0.dist-info}/RECORD +56 -117
- {arize_phoenix-4.4.4rc5.dist-info → arize_phoenix-4.5.0.dist-info}/WHEEL +1 -1
- phoenix/__init__.py +27 -0
- phoenix/config.py +7 -21
- phoenix/core/model.py +25 -25
- phoenix/core/model_schema.py +62 -64
- phoenix/core/model_schema_adapter.py +25 -27
- phoenix/db/bulk_inserter.py +14 -54
- phoenix/db/insertion/evaluation.py +6 -6
- phoenix/db/insertion/helpers.py +2 -13
- phoenix/db/migrations/versions/cf03bd6bae1d_init.py +28 -2
- phoenix/db/models.py +4 -236
- phoenix/inferences/fixtures.py +23 -23
- phoenix/inferences/inferences.py +7 -7
- phoenix/inferences/validation.py +1 -1
- phoenix/server/api/context.py +0 -18
- phoenix/server/api/dataloaders/__init__.py +0 -18
- phoenix/server/api/dataloaders/span_descendants.py +3 -2
- phoenix/server/api/routers/v1/__init__.py +2 -77
- phoenix/server/api/routers/v1/evaluations.py +2 -4
- phoenix/server/api/routers/v1/spans.py +1 -3
- phoenix/server/api/routers/v1/traces.py +4 -1
- phoenix/server/api/schema.py +303 -2
- phoenix/server/api/types/Cluster.py +19 -19
- phoenix/server/api/types/Dataset.py +63 -282
- phoenix/server/api/types/DatasetRole.py +23 -0
- phoenix/server/api/types/Dimension.py +29 -30
- phoenix/server/api/types/EmbeddingDimension.py +34 -40
- phoenix/server/api/types/Event.py +16 -16
- phoenix/server/api/{mutations/export_events_mutations.py → types/ExportEventsMutation.py} +14 -17
- phoenix/server/api/types/Model.py +42 -43
- phoenix/server/api/types/Project.py +12 -26
- phoenix/server/api/types/Span.py +2 -79
- phoenix/server/api/types/TimeSeries.py +6 -6
- phoenix/server/api/types/Trace.py +4 -15
- phoenix/server/api/types/UMAPPoints.py +1 -1
- phoenix/server/api/types/node.py +111 -5
- phoenix/server/api/types/pagination.py +52 -10
- phoenix/server/app.py +49 -101
- phoenix/server/main.py +27 -49
- phoenix/server/openapi/docs.py +0 -3
- phoenix/server/static/index.js +2595 -3523
- phoenix/server/templates/index.html +0 -1
- phoenix/services.py +15 -15
- phoenix/session/client.py +21 -438
- phoenix/session/session.py +37 -47
- phoenix/trace/exporter.py +9 -14
- phoenix/trace/fixtures.py +7 -133
- phoenix/trace/schemas.py +2 -1
- phoenix/trace/span_evaluations.py +3 -3
- phoenix/trace/trace_dataset.py +6 -6
- phoenix/version.py +1 -1
- phoenix/datasets/__init__.py +0 -0
- phoenix/datasets/evaluators/__init__.py +0 -18
- phoenix/datasets/evaluators/code_evaluators.py +0 -99
- phoenix/datasets/evaluators/llm_evaluators.py +0 -244
- phoenix/datasets/evaluators/utils.py +0 -292
- phoenix/datasets/experiments.py +0 -550
- phoenix/datasets/tracing.py +0 -85
- phoenix/datasets/types.py +0 -178
- phoenix/db/insertion/dataset.py +0 -237
- phoenix/db/migrations/types.py +0 -29
- phoenix/db/migrations/versions/10460e46d750_datasets.py +0 -291
- phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -100
- phoenix/server/api/dataloaders/dataset_example_spans.py +0 -43
- phoenix/server/api/dataloaders/experiment_annotation_summaries.py +0 -85
- phoenix/server/api/dataloaders/experiment_error_rates.py +0 -43
- phoenix/server/api/dataloaders/experiment_run_counts.py +0 -42
- phoenix/server/api/dataloaders/experiment_sequence_number.py +0 -49
- phoenix/server/api/dataloaders/project_by_name.py +0 -31
- phoenix/server/api/dataloaders/span_projects.py +0 -33
- phoenix/server/api/dataloaders/trace_row_ids.py +0 -39
- phoenix/server/api/helpers/dataset_helpers.py +0 -179
- phoenix/server/api/input_types/AddExamplesToDatasetInput.py +0 -16
- phoenix/server/api/input_types/AddSpansToDatasetInput.py +0 -14
- phoenix/server/api/input_types/ClearProjectInput.py +0 -15
- phoenix/server/api/input_types/CreateDatasetInput.py +0 -12
- phoenix/server/api/input_types/DatasetExampleInput.py +0 -14
- phoenix/server/api/input_types/DatasetSort.py +0 -17
- phoenix/server/api/input_types/DatasetVersionSort.py +0 -16
- phoenix/server/api/input_types/DeleteDatasetExamplesInput.py +0 -13
- phoenix/server/api/input_types/DeleteDatasetInput.py +0 -7
- phoenix/server/api/input_types/DeleteExperimentsInput.py +0 -9
- phoenix/server/api/input_types/PatchDatasetExamplesInput.py +0 -35
- phoenix/server/api/input_types/PatchDatasetInput.py +0 -14
- phoenix/server/api/mutations/__init__.py +0 -13
- phoenix/server/api/mutations/auth.py +0 -11
- phoenix/server/api/mutations/dataset_mutations.py +0 -520
- phoenix/server/api/mutations/experiment_mutations.py +0 -65
- phoenix/server/api/mutations/project_mutations.py +0 -47
- phoenix/server/api/openapi/__init__.py +0 -0
- phoenix/server/api/openapi/main.py +0 -6
- phoenix/server/api/openapi/schema.py +0 -16
- phoenix/server/api/queries.py +0 -503
- phoenix/server/api/routers/v1/dataset_examples.py +0 -178
- phoenix/server/api/routers/v1/datasets.py +0 -965
- phoenix/server/api/routers/v1/experiment_evaluations.py +0 -66
- phoenix/server/api/routers/v1/experiment_runs.py +0 -108
- phoenix/server/api/routers/v1/experiments.py +0 -174
- phoenix/server/api/types/AnnotatorKind.py +0 -10
- phoenix/server/api/types/CreateDatasetPayload.py +0 -8
- phoenix/server/api/types/DatasetExample.py +0 -85
- phoenix/server/api/types/DatasetExampleRevision.py +0 -34
- phoenix/server/api/types/DatasetVersion.py +0 -14
- phoenix/server/api/types/ExampleRevisionInterface.py +0 -14
- phoenix/server/api/types/Experiment.py +0 -140
- phoenix/server/api/types/ExperimentAnnotationSummary.py +0 -13
- phoenix/server/api/types/ExperimentComparison.py +0 -19
- phoenix/server/api/types/ExperimentRun.py +0 -91
- phoenix/server/api/types/ExperimentRunAnnotation.py +0 -57
- phoenix/server/api/types/Inferences.py +0 -80
- phoenix/server/api/types/InferencesRole.py +0 -23
- phoenix/utilities/json.py +0 -61
- phoenix/utilities/re.py +0 -50
- {arize_phoenix-4.4.4rc5.dist-info → arize_phoenix-4.5.0.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-4.4.4rc5.dist-info → arize_phoenix-4.5.0.dist-info}/licenses/LICENSE +0 -0
- /phoenix/server/api/{helpers/__init__.py → helpers.py} +0 -0
--- a/phoenix/server/api/routers/v1/datasets.py
+++ /dev/null
@@ -1,965 +0,0 @@
-import csv
-import gzip
-import io
-import json
-import logging
-import zlib
-from asyncio import QueueFull
-from collections import Counter
-from enum import Enum
-from functools import partial
-from typing import (
-    Any,
-    Awaitable,
-    Callable,
-    Coroutine,
-    FrozenSet,
-    Iterator,
-    List,
-    Mapping,
-    Optional,
-    Sequence,
-    Tuple,
-    Union,
-    cast,
-)
-
-import pandas as pd
-import pyarrow as pa
-from sqlalchemy import and_, func, select
-from sqlalchemy.ext.asyncio import AsyncSession
-from starlette.concurrency import run_in_threadpool
-from starlette.datastructures import FormData, UploadFile
-from starlette.requests import Request
-from starlette.responses import JSONResponse, Response
-from starlette.status import (
-    HTTP_404_NOT_FOUND,
-    HTTP_409_CONFLICT,
-    HTTP_422_UNPROCESSABLE_ENTITY,
-    HTTP_429_TOO_MANY_REQUESTS,
-)
-from strawberry.relay import GlobalID
-from typing_extensions import TypeAlias, assert_never
-
-from phoenix.db import models
-from phoenix.db.insertion.dataset import (
-    DatasetAction,
-    DatasetExampleAdditionEvent,
-    ExampleContent,
-    add_dataset_examples,
-)
-from phoenix.server.api.types.Dataset import Dataset
-from phoenix.server.api.types.DatasetExample import DatasetExample
-from phoenix.server.api.types.DatasetVersion import DatasetVersion
-from phoenix.server.api.types.node import from_global_id_with_expected_type
-
-logger = logging.getLogger(__name__)
-
-NODE_NAME = "Dataset"
-
-
-async def list_datasets(request: Request) -> Response:
-    """
-    summary: List datasets with cursor-based pagination
-    operationId: listDatasets
-    tags:
-      - datasets
-    parameters:
-      - in: query
-        name: cursor
-        required: false
-        schema:
-          type: string
-        description: Cursor for pagination
-      - in: query
-        name: limit
-        required: false
-        schema:
-          type: integer
-          default: 10
-      - in: query
-        name: name
-        required: false
-        schema:
-          type: string
-        description: match by dataset name
-    responses:
-      200:
-        description: A paginated list of datasets
-        content:
-          application/json:
-            schema:
-              type: object
-              properties:
-                next_cursor:
-                  type: string
-                data:
-                  type: array
-                  items:
-                    type: object
-                    properties:
-                      id:
-                        type: string
-                      name:
-                        type: string
-                      description:
-                        type: string
-                      metadata:
-                        type: object
-                      created_at:
-                        type: string
-                        format: date-time
-                      updated_at:
-                        type: string
-                        format: date-time
-      403:
-        description: Forbidden
-      404:
-        description: No datasets found
-    """
-    name = request.query_params.get("name")
-    cursor = request.query_params.get("cursor")
-    limit = int(request.query_params.get("limit", 10))
-    async with request.app.state.db() as session:
-        query = select(models.Dataset).order_by(models.Dataset.id.desc())
-
-        if cursor:
-            try:
-                cursor_id = GlobalID.from_id(cursor).node_id
-                query = query.filter(models.Dataset.id <= int(cursor_id))
-            except ValueError:
-                return Response(
-                    content=f"Invalid cursor format: {cursor}",
-                    status_code=HTTP_422_UNPROCESSABLE_ENTITY,
-                )
-        if name:
-            query = query.filter(models.Dataset.name.is_(name))
-
-        query = query.limit(limit + 1)
-        result = await session.execute(query)
-        datasets = result.scalars().all()
-
-    if not datasets:
-        return JSONResponse(content={"next_cursor": None, "data": []}, status_code=200)
-
-    next_cursor = None
-    if len(datasets) == limit + 1:
-        next_cursor = str(GlobalID(NODE_NAME, str(datasets[-1].id)))
-        datasets = datasets[:-1]
-
-    data = []
-    for dataset in datasets:
-        data.append(
-            {
-                "id": str(GlobalID(NODE_NAME, str(dataset.id))),
-                "name": dataset.name,
-                "description": dataset.description,
-                "metadata": dataset.metadata_,
-                "created_at": dataset.created_at.isoformat(),
-                "updated_at": dataset.updated_at.isoformat(),
-            }
-        )
-
-    return JSONResponse(content={"next_cursor": next_cursor, "data": data})
-
-
-async def get_dataset_by_id(request: Request) -> Response:
-    """
-    summary: Get dataset by ID
-    operationId: getDatasetById
-    tags:
-      - datasets
-    parameters:
-      - in: path
-        name: id
-        required: true
-        schema:
-          type: string
-    responses:
-      200:
-        description: Success
-        content:
-          application/json:
-            schema:
-              type: object
-              properties:
-                id:
-                  type: string
-                name:
-                  type: string
-                description:
-                  type: string
-                metadata:
-                  type: object
-                created_at:
-                  type: string
-                  format: date-time
-                updated_at:
-                  type: string
-                  format: date-time
-                example_count:
-                  type: integer
-      403:
-        description: Forbidden
-      404:
-        description: Dataset not found
-    """
-    dataset_id = GlobalID.from_id(request.path_params["id"])
-
-    if (type_name := dataset_id.type_name) != NODE_NAME:
-        return Response(
-            content=f"ID {dataset_id} refers to a f{type_name}", status_code=HTTP_404_NOT_FOUND
-        )
-    async with request.app.state.db() as session:
-        result = await session.execute(
-            select(models.Dataset, models.Dataset.example_count).filter(
-                models.Dataset.id == int(dataset_id.node_id)
-            )
-        )
-        dataset_query = result.first()
-        dataset = dataset_query[0] if dataset_query else None
-        example_count = dataset_query[1] if dataset_query else 0
-        if dataset is None:
-            return Response(
-                content=f"Dataset with ID {dataset_id} not found", status_code=HTTP_404_NOT_FOUND
-            )
-
-    output_dict = {
-        "id": str(dataset_id),
-        "name": dataset.name,
-        "description": dataset.description,
-        "metadata": dataset.metadata_,
-        "created_at": dataset.created_at.isoformat(),
-        "updated_at": dataset.updated_at.isoformat(),
-        "example_count": example_count,
-    }
-    return JSONResponse(content=output_dict)
-
-
-async def get_dataset_versions(request: Request) -> Response:
-    """
-    summary: Get dataset versions (sorted from latest to oldest)
-    operationId: getDatasetVersionsByDatasetId
-    tags:
-      - datasets
-    parameters:
-      - in: path
-        name: id
-        required: true
-        description: Dataset ID
-        schema:
-          type: string
-      - in: query
-        name: cursor
-        description: Cursor for pagination.
-        schema:
-          type: string
-      - in: query
-        name: limit
-        description: Maximum number versions to return.
-        schema:
-          type: integer
-          default: 10
-    responses:
-      200:
-        description: Success
-        content:
-          application/json:
-            schema:
-              type: object
-              properties:
-                next_cursor:
-                  type: string
-                data:
-                  type: array
-                  items:
-                    type: object
-                    properties:
-                      version_id:
-                        type: string
-                      description:
-                        type: string
-                      metadata:
-                        type: object
-                      created_at:
-                        type: string
-                        format: date-time
-      403:
-        description: Forbidden
-      422:
-        description: Dataset ID, cursor or limit is invalid.
-    """
-    if id_ := request.path_params.get("id"):
-        try:
-            dataset_id = from_global_id_with_expected_type(
-                GlobalID.from_id(id_),
-                Dataset.__name__,
-            )
-        except ValueError:
-            return Response(
-                content=f"Invalid Dataset ID: {id_}",
-                status_code=HTTP_422_UNPROCESSABLE_ENTITY,
-            )
-    else:
-        return Response(
-            content="Missing Dataset ID",
-            status_code=HTTP_422_UNPROCESSABLE_ENTITY,
-        )
-    try:
-        limit = int(request.query_params.get("limit", 10))
-        assert limit > 0
-    except (ValueError, AssertionError):
-        return Response(
-            content="Invalid limit parameter",
-            status_code=HTTP_422_UNPROCESSABLE_ENTITY,
-        )
-    stmt = (
-        select(models.DatasetVersion)
-        .where(models.DatasetVersion.dataset_id == dataset_id)
-        .order_by(models.DatasetVersion.id.desc())
-        .limit(limit + 1)
-    )
-    if cursor := request.query_params.get("cursor"):
-        try:
-            dataset_version_id = from_global_id_with_expected_type(
-                GlobalID.from_id(cursor),
-                DatasetVersion.__name__,
-            )
-        except ValueError:
-            return Response(
-                content=f"Invalid cursor: {cursor}",
-                status_code=HTTP_422_UNPROCESSABLE_ENTITY,
-            )
-        max_dataset_version_id = (
-            select(models.DatasetVersion.id)
-            .where(models.DatasetVersion.id == dataset_version_id)
-            .where(models.DatasetVersion.dataset_id == dataset_id)
-        ).scalar_subquery()
-        stmt = stmt.filter(models.DatasetVersion.id <= max_dataset_version_id)
-    async with request.app.state.db() as session:
-        data = [
-            {
-                "version_id": str(GlobalID(DatasetVersion.__name__, str(version.id))),
-                "description": version.description,
-                "metadata": version.metadata_,
-                "created_at": version.created_at.isoformat(),
-            }
-            async for version in await session.stream_scalars(stmt)
-        ]
-    next_cursor = data.pop()["version_id"] if len(data) == limit + 1 else None
-    return JSONResponse(content={"next_cursor": next_cursor, "data": data})
-
-
-async def post_datasets_upload(request: Request) -> Response:
-    """
-    summary: Upload dataset as either JSON or file (CSV or PyArrow)
-    operationId: uploadDataset
-    tags:
-      - datasets
-    parameters:
-      - in: query
-        name: sync
-        description: If true, fulfill request synchronously and return JSON containing dataset_id
-        schema:
-          type: boolean
-    requestBody:
-      content:
-        application/json:
-          schema:
-            type: object
-            required:
-              - name
-              - inputs
-            properties:
-              action:
-                type: string
-                enum: [create, append]
-              name:
-                type: string
-              description:
-                type: string
-              inputs:
-                type: array
-                items:
-                  type: object
-              outputs:
-                type: array
-                items:
-                  type: object
-              metadata:
-                type: array
-                items:
-                  type: object
-        multipart/form-data:
-          schema:
-            type: object
-            required:
-              - name
-              - input_keys[]
-              - output_keys[]
-              - file
-            properties:
-              action:
-                type: string
-                enum: [create, append]
-              name:
-                type: string
-              description:
-                type: string
-              input_keys[]:
-                type: array
-                items:
-                  type: string
-                uniqueItems: true
-              output_keys[]:
-                type: array
-                items:
-                  type: string
-                uniqueItems: true
-              metadata_keys[]:
-                type: array
-                items:
-                  type: string
-                uniqueItems: true
-              file:
-                type: string
-                format: binary
-    responses:
-      200:
-        description: Success
-      403:
-        description: Forbidden
-      409:
-        description: Dataset of the same name already exists
-      422:
-        description: Request body is invalid
-    """
-    request_content_type = request.headers["content-type"]
-    examples: Union[Examples, Awaitable[Examples]]
-    if request_content_type.startswith("application/json"):
-        try:
-            examples, action, name, description = await run_in_threadpool(
-                _process_json, await request.json()
-            )
-        except ValueError as e:
-            return Response(
-                content=str(e),
-                status_code=HTTP_422_UNPROCESSABLE_ENTITY,
-            )
-        if action is DatasetAction.CREATE:
-            async with request.app.state.db() as session:
-                if await _check_table_exists(session, name):
-                    return Response(
-                        content=f"Dataset with the same name already exists: {name=}",
-                        status_code=HTTP_409_CONFLICT,
-                    )
-    elif request_content_type.startswith("multipart/form-data"):
-        async with request.form() as form:
-            try:
-                (
-                    action,
-                    name,
-                    description,
-                    input_keys,
-                    output_keys,
-                    metadata_keys,
-                    file,
-                ) = await _parse_form_data(form)
-            except ValueError as e:
-                return Response(
-                    content=str(e),
-                    status_code=HTTP_422_UNPROCESSABLE_ENTITY,
-                )
-            if action is DatasetAction.CREATE:
-                async with request.app.state.db() as session:
-                    if await _check_table_exists(session, name):
-                        return Response(
-                            content=f"Dataset with the same name already exists: {name=}",
-                            status_code=HTTP_409_CONFLICT,
-                        )
-            content = await file.read()
-        try:
-            file_content_type = FileContentType(file.content_type)
-            if file_content_type is FileContentType.CSV:
-                encoding = FileContentEncoding(file.headers.get("content-encoding"))
-                examples = await _process_csv(
-                    content, encoding, input_keys, output_keys, metadata_keys
-                )
-            elif file_content_type is FileContentType.PYARROW:
-                examples = await _process_pyarrow(content, input_keys, output_keys, metadata_keys)
-            else:
-                assert_never(file_content_type)
-        except ValueError as e:
-            return Response(
-                content=str(e),
-                status_code=HTTP_422_UNPROCESSABLE_ENTITY,
-            )
-    else:
-        return Response(
-            content=str("Invalid request Content-Type"),
-            status_code=HTTP_422_UNPROCESSABLE_ENTITY,
-        )
-    operation = cast(
-        Callable[[AsyncSession], Awaitable[DatasetExampleAdditionEvent]],
-        partial(
-            add_dataset_examples,
-            examples=examples,
-            action=action,
-            name=name,
-            description=description,
-        ),
-    )
-    if request.query_params.get("sync") == "true":
-        async with request.app.state.db() as session:
-            dataset_id = (await operation(session)).dataset_id
-        return JSONResponse(
-            content={"data": {"dataset_id": str(GlobalID(Dataset.__name__, str(dataset_id)))}}
-        )
-    try:
-        request.state.enqueue_operation(operation)
-    except QueueFull:
-        if isinstance(examples, Coroutine):
-            examples.close()
-        return Response(status_code=HTTP_429_TOO_MANY_REQUESTS)
-    return Response()
-
-
-class FileContentType(Enum):
-    CSV = "text/csv"
-    PYARROW = "application/x-pandas-pyarrow"
-
-    @classmethod
-    def _missing_(cls, v: Any) -> "FileContentType":
-        if isinstance(v, str) and v and v.isascii() and not v.islower():
-            return cls(v.lower())
-        raise ValueError(f"Invalid file content type: {v}")
-
-
-class FileContentEncoding(Enum):
-    NONE = "none"
-    GZIP = "gzip"
-    DEFLATE = "deflate"
-
-    @classmethod
-    def _missing_(cls, v: Any) -> "FileContentEncoding":
-        if v is None:
-            return cls("none")
-        if isinstance(v, str) and v and v.isascii() and not v.islower():
-            return cls(v.lower())
-        raise ValueError(f"Invalid file content encoding: {v}")
-
-
-Name: TypeAlias = str
-Description: TypeAlias = Optional[str]
-InputKeys: TypeAlias = FrozenSet[str]
-OutputKeys: TypeAlias = FrozenSet[str]
-MetadataKeys: TypeAlias = FrozenSet[str]
-DatasetId: TypeAlias = int
-Examples: TypeAlias = Iterator[ExampleContent]
-
-
-def _process_json(
-    data: Mapping[str, Any],
-) -> Tuple[Examples, DatasetAction, Name, Description]:
-    name = data.get("name")
-    if not name:
-        raise ValueError("Dataset name is required")
-    description = data.get("description") or ""
-    inputs = data.get("inputs")
-    if not inputs:
-        raise ValueError("input is required")
-    if not isinstance(inputs, list) or not _is_all_dict(inputs):
-        raise ValueError("Input should be a list containing only dictionary objects")
-    outputs, metadata = data.get("outputs"), data.get("metadata")
-    for k, v in {"outputs": outputs, "metadata": metadata}.items():
-        if v and not (isinstance(v, list) and len(v) == len(inputs) and _is_all_dict(v)):
-            raise ValueError(
-                f"{k} should be a list of same length as input containing only dictionary objects"
-            )
-    examples: List[ExampleContent] = []
-    for i, obj in enumerate(inputs):
-        example = ExampleContent(
-            input=obj,
-            output=outputs[i] if outputs else {},
-            metadata=metadata[i] if metadata else {},
-        )
-        examples.append(example)
-    action = DatasetAction(cast(Optional[str], data.get("action")) or "create")
-    return iter(examples), action, name, description
-
-
-async def _process_csv(
-    content: bytes,
-    content_encoding: FileContentEncoding,
-    input_keys: InputKeys,
-    output_keys: OutputKeys,
-    metadata_keys: MetadataKeys,
-) -> Examples:
-    if content_encoding is FileContentEncoding.GZIP:
-        content = await run_in_threadpool(gzip.decompress, content)
-    elif content_encoding is FileContentEncoding.DEFLATE:
-        content = await run_in_threadpool(zlib.decompress, content)
-    elif content_encoding is not FileContentEncoding.NONE:
-        assert_never(content_encoding)
-    reader = await run_in_threadpool(lambda c: csv.DictReader(io.StringIO(c.decode())), content)
-    if reader.fieldnames is None:
-        raise ValueError("Missing CSV column header")
-    (header, freq), *_ = Counter(reader.fieldnames).most_common(1)
-    if freq > 1:
-        raise ValueError(f"Duplicated column header in CSV file: {header}")
-    column_headers = frozenset(reader.fieldnames)
-    _check_keys_exist(column_headers, input_keys, output_keys, metadata_keys)
-    return (
-        ExampleContent(
-            input={k: row.get(k) for k in input_keys},
-            output={k: row.get(k) for k in output_keys},
-            metadata={k: row.get(k) for k in metadata_keys},
-        )
-        for row in iter(reader)
-    )
-
-
-async def _process_pyarrow(
-    content: bytes,
-    input_keys: InputKeys,
-    output_keys: OutputKeys,
-    metadata_keys: MetadataKeys,
-) -> Awaitable[Examples]:
-    try:
-        reader = pa.ipc.open_stream(content)
-    except pa.ArrowInvalid as e:
-        raise ValueError("File is not valid pyarrow") from e
-    column_headers = frozenset(reader.schema.names)
-    _check_keys_exist(column_headers, input_keys, output_keys, metadata_keys)
-
-    def get_examples() -> Iterator[ExampleContent]:
-        for row in reader.read_pandas().to_dict(orient="records"):
-            yield ExampleContent(
-                input={k: row.get(k) for k in input_keys},
-                output={k: row.get(k) for k in output_keys},
-                metadata={k: row.get(k) for k in metadata_keys},
-            )
-
-    return run_in_threadpool(get_examples)
-
-
-async def _check_table_exists(session: AsyncSession, name: str) -> bool:
-    return bool(
-        await session.scalar(
-            select(1).select_from(models.Dataset).where(models.Dataset.name == name)
-        )
-    )
-
-
-def _check_keys_exist(
-    column_headers: FrozenSet[str],
-    input_keys: InputKeys,
-    output_keys: OutputKeys,
-    metadata_keys: MetadataKeys,
-) -> None:
-    for desc, keys in (
-        ("input", input_keys),
-        ("output", output_keys),
-        ("metadata", metadata_keys),
-    ):
-        if keys and (diff := keys.difference(column_headers)):
-            raise ValueError(f"{desc} keys not found in column headers: {diff}")
-
-
-async def _parse_form_data(
-    form: FormData,
-) -> Tuple[
-    DatasetAction,
-    Name,
-    Description,
-    InputKeys,
-    OutputKeys,
-    MetadataKeys,
-    UploadFile,
-]:
-    name = cast(Optional[str], form.get("name"))
-    if not name:
-        raise ValueError("Dataset name must not be empty")
-    action = DatasetAction(cast(Optional[str], form.get("action")) or "create")
-    file = form["file"]
-    if not isinstance(file, UploadFile):
-        raise ValueError("Malformed file in form data.")
-    description = cast(Optional[str], form.get("description")) or file.filename
-    input_keys = frozenset(filter(bool, cast(List[str], form.getlist("input_keys[]"))))
-    output_keys = frozenset(filter(bool, cast(List[str], form.getlist("output_keys[]"))))
-    metadata_keys = frozenset(filter(bool, cast(List[str], form.getlist("metadata_keys[]"))))
-    return (
-        action,
-        name,
-        description,
-        input_keys,
-        output_keys,
-        metadata_keys,
-        file,
-    )
-
-
-async def get_dataset_csv(request: Request) -> Response:
-    """
-    summary: Download dataset examples as CSV text file
-    operationId: getDatasetCsv
-    tags:
-      - datasets
-    parameters:
-      - in: path
-        name: id
-        required: true
-        schema:
-          type: string
-        description: Dataset ID
-      - in: query
-        name: version
-        schema:
-          type: string
-        description: Dataset version ID. If omitted, returns the latest version.
-    responses:
-      200:
-        description: Success
-        content:
-          text/csv:
-            schema:
-              type: string
-              contentMediaType: text/csv
-              contentEncoding: gzip
-      403:
-        description: Forbidden
-      404:
-        description: Dataset does not exist.
-      422:
-        description: Dataset ID or version ID is invalid.
-    """
-    try:
-        dataset_name, examples = await _get_db_examples(request)
-    except ValueError as e:
-        return Response(content=str(e), status_code=HTTP_422_UNPROCESSABLE_ENTITY)
-    content = await run_in_threadpool(_get_content_csv, examples)
-    return Response(
-        content=content,
-        headers={
-            "content-disposition": f'attachment; filename="{dataset_name}.csv"',
-            "content-type": "text/csv",
-            "content-encoding": "gzip",
-        },
-    )
-
-
-async def get_dataset_jsonl_openai_ft(request: Request) -> Response:
-    """
-    summary: Download dataset examples as OpenAI Fine-Tuning JSONL file
-    operationId: getDatasetJSONLOpenAIFineTuning
-    tags:
-      - datasets
-    parameters:
-      - in: path
-        name: id
-        required: true
-        schema:
-          type: string
-        description: Dataset ID
-      - in: query
-        name: version
-        schema:
-          type: string
-        description: Dataset version ID. If omitted, returns the latest version.
-    responses:
-      200:
-        description: Success
-        content:
-          text/plain:
-            schema:
-              type: string
-              contentMediaType: text/plain
-              contentEncoding: gzip
-      403:
-        description: Forbidden
-      404:
-        description: Dataset does not exist.
-      422:
-        description: Dataset ID or version ID is invalid.
-    """
-    try:
-        dataset_name, examples = await _get_db_examples(request)
-    except ValueError as e:
-        return Response(content=str(e), status_code=HTTP_422_UNPROCESSABLE_ENTITY)
-    content = await run_in_threadpool(_get_content_jsonl_openai_ft, examples)
-    return Response(
-        content=content,
-        headers={
-            "content-disposition": f'attachment; filename="{dataset_name}.jsonl"',
-            "content-type": "text/plain",
-            "content-encoding": "gzip",
-        },
-    )
-
-
-async def get_dataset_jsonl_openai_evals(request: Request) -> Response:
-    """
-    summary: Download dataset examples as OpenAI Evals JSONL file
-    operationId: getDatasetJSONLOpenAIEvals
-    tags:
-      - datasets
-    parameters:
-      - in: path
-        name: id
-        required: true
-        schema:
-          type: string
-        description: Dataset ID
-      - in: query
-        name: version
-        schema:
-          type: string
-        description: Dataset version ID. If omitted, returns the latest version.
-    responses:
-      200:
-        description: Success
-        content:
-          text/plain:
-            schema:
-              type: string
-              contentMediaType: text/plain
-              contentEncoding: gzip
-      403:
-        description: Forbidden
-      404:
-        description: Dataset does not exist.
-      422:
-        description: Dataset ID or version ID is invalid.
-    """
-    try:
-        dataset_name, examples = await _get_db_examples(request)
-    except ValueError as e:
-        return Response(content=str(e), status_code=HTTP_422_UNPROCESSABLE_ENTITY)
-    content = await run_in_threadpool(_get_content_jsonl_openai_evals, examples)
-    return Response(
-        content=content,
-        headers={
-            "content-disposition": f'attachment; filename="{dataset_name}.jsonl"',
-            "content-type": "text/plain",
-            "content-encoding": "gzip",
-        },
-    )
-
-
-def _get_content_csv(examples: List[models.DatasetExampleRevision]) -> bytes:
-    records = [
-        {
-            "example_id": GlobalID(
-                type_name=DatasetExample.__name__,
-                node_id=str(ex.dataset_example_id),
-            ),
-            **{f"input_{k}": v for k, v in ex.input.items()},
-            **{f"output_{k}": v for k, v in ex.output.items()},
-            **{f"metadata_{k}": v for k, v in ex.metadata_.items()},
-        }
-        for ex in examples
-    ]
-    return gzip.compress(pd.DataFrame.from_records(records).to_csv(index=False).encode())
-
-
-def _get_content_jsonl_openai_ft(examples: List[models.DatasetExampleRevision]) -> bytes:
-    records = io.BytesIO()
-    for ex in examples:
-        records.write(
-            (
-                json.dumps(
-                    {
-                        "messages": (
-                            ims if isinstance(ims := ex.input.get("messages"), list) else []
-                        )
-                        + (oms if isinstance(oms := ex.output.get("messages"), list) else [])
-                    },
-                    ensure_ascii=False,
-                )
-                + "\n"
-            ).encode()
-        )
-    records.seek(0)
-    return gzip.compress(records.read())
-
-
-def _get_content_jsonl_openai_evals(examples: List[models.DatasetExampleRevision]) -> bytes:
-    records = io.BytesIO()
-    for ex in examples:
-        records.write(
-            (
-                json.dumps(
-                    {
-                        "messages": ims
-                        if isinstance(ims := ex.input.get("messages"), list)
-                        else [],
-                        "ideal": (
-                            ideal if isinstance(ideal := last_message.get("content"), str) else ""
-                        )
-                        if isinstance(oms := ex.output.get("messages"), list)
-                        and oms
-                        and hasattr(last_message := oms[-1], "get")
-                        else "",
-                    },
-                    ensure_ascii=False,
-                )
-                + "\n"
-            ).encode()
-        )
-    records.seek(0)
-    return gzip.compress(records.read())
-
-
-async def _get_db_examples(request: Request) -> Tuple[str, List[models.DatasetExampleRevision]]:
-    if not (id_ := request.path_params.get("id")):
-        raise ValueError("Missing Dataset ID")
-    dataset_id = from_global_id_with_expected_type(GlobalID.from_id(id_), Dataset.__name__)
-    dataset_version_id: Optional[int] = None
-    if version := request.query_params.get("version"):
-        dataset_version_id = from_global_id_with_expected_type(
-            GlobalID.from_id(version),
-            DatasetVersion.__name__,
-        )
-    latest_version = (
-        select(
-            models.DatasetExampleRevision.dataset_example_id,
-            func.max(models.DatasetExampleRevision.dataset_version_id).label("dataset_version_id"),
-        )
-        .group_by(models.DatasetExampleRevision.dataset_example_id)
-        .join(models.DatasetExample)
-        .where(models.DatasetExample.dataset_id == dataset_id)
-    )
-    if dataset_version_id is not None:
-        max_dataset_version_id = (
-            select(models.DatasetVersion.id)
-            .where(models.DatasetVersion.id == dataset_version_id)
-            .where(models.DatasetVersion.dataset_id == dataset_id)
-        ).scalar_subquery()
-        latest_version = latest_version.where(
-            models.DatasetExampleRevision.dataset_version_id <= max_dataset_version_id
-        )
-    subq = latest_version.subquery("latest_version")
-    stmt = (
-        select(models.DatasetExampleRevision)
-        .join(
-            subq,
-            onclause=and_(
-                models.DatasetExampleRevision.dataset_example_id == subq.c.dataset_example_id,
-                models.DatasetExampleRevision.dataset_version_id == subq.c.dataset_version_id,
-            ),
-        )
-        .where(models.DatasetExampleRevision.revision_kind != "DELETE")
-        .order_by(models.DatasetExampleRevision.dataset_example_id)
-    )
-    async with request.app.state.db() as session:
-        dataset_name: Optional[str] = await session.scalar(
-            select(models.Dataset.name).where(models.Dataset.id == dataset_id)
-        )
-        if not dataset_name:
-            raise ValueError("Dataset does not exist.")
-        examples = [r async for r in await session.stream_scalars(stmt)]
-    return dataset_name, examples
-
-
-def _is_all_dict(seq: Sequence[Any]) -> bool:
-    return all(map(lambda obj: isinstance(obj, dict), seq))