arize-phoenix 4.10.2rc2__py3-none-any.whl → 4.12.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.

This version of arize-phoenix has been flagged as potentially problematic.

Files changed (30)
  1. {arize_phoenix-4.10.2rc2.dist-info → arize_phoenix-4.12.0.dist-info}/METADATA +3 -4
  2. {arize_phoenix-4.10.2rc2.dist-info → arize_phoenix-4.12.0.dist-info}/RECORD +29 -29
  3. phoenix/server/api/context.py +7 -3
  4. phoenix/server/api/openapi/main.py +2 -18
  5. phoenix/server/api/openapi/schema.py +12 -12
  6. phoenix/server/api/routers/v1/__init__.py +83 -36
  7. phoenix/server/api/routers/v1/dataset_examples.py +123 -102
  8. phoenix/server/api/routers/v1/datasets.py +507 -389
  9. phoenix/server/api/routers/v1/evaluations.py +66 -73
  10. phoenix/server/api/routers/v1/experiment_evaluations.py +91 -67
  11. phoenix/server/api/routers/v1/experiment_runs.py +155 -97
  12. phoenix/server/api/routers/v1/experiments.py +181 -131
  13. phoenix/server/api/routers/v1/spans.py +173 -143
  14. phoenix/server/api/routers/v1/traces.py +128 -114
  15. phoenix/server/api/types/Span.py +1 -0
  16. phoenix/server/app.py +176 -148
  17. phoenix/server/openapi/docs.py +221 -0
  18. phoenix/server/static/index.js +574 -573
  19. phoenix/server/thread_server.py +2 -2
  20. phoenix/session/client.py +5 -0
  21. phoenix/session/data_extractor.py +20 -1
  22. phoenix/session/session.py +4 -0
  23. phoenix/trace/attributes.py +2 -1
  24. phoenix/trace/schemas.py +1 -0
  25. phoenix/trace/span_json_decoder.py +1 -1
  26. phoenix/version.py +1 -1
  27. phoenix/server/api/routers/v1/utils.py +0 -94
  28. {arize_phoenix-4.10.2rc2.dist-info → arize_phoenix-4.12.0.dist-info}/WHEEL +0 -0
  29. {arize_phoenix-4.10.2rc2.dist-info → arize_phoenix-4.12.0.dist-info}/licenses/IP_NOTICE +0 -0
  30. {arize_phoenix-4.10.2rc2.dist-info → arize_phoenix-4.12.0.dist-info}/licenses/LICENSE +0 -0
@@ -6,7 +6,6 @@ import logging
  import zlib
  from asyncio import QueueFull
  from collections import Counter
- from datetime import datetime
  from enum import Enum
  from functools import partial
  from typing import (
@@ -14,7 +13,6 @@ from typing import (
  Awaitable,
  Callable,
  Coroutine,
- Dict,
  FrozenSet,
  Iterator,
  List,
@@ -28,17 +26,14 @@ from typing import (

  import pandas as pd
  import pyarrow as pa
- from fastapi import APIRouter, BackgroundTasks, HTTPException, Path, Query
- from fastapi.responses import PlainTextResponse, StreamingResponse
- from pydantic import BaseModel
  from sqlalchemy import and_, delete, func, select
  from sqlalchemy.ext.asyncio import AsyncSession
+ from starlette.background import BackgroundTasks
  from starlette.concurrency import run_in_threadpool
  from starlette.datastructures import FormData, UploadFile
  from starlette.requests import Request
- from starlette.responses import Response
+ from starlette.responses import JSONResponse, Response
  from starlette.status import (
- HTTP_200_OK,
  HTTP_204_NO_CONTENT,
  HTTP_404_NOT_FOUND,
  HTTP_409_CONFLICT,
@@ -56,59 +51,79 @@ from phoenix.db.insertion.dataset import (
  ExampleContent,
  add_dataset_examples,
  )
- from phoenix.server.api.types.Dataset import Dataset as DatasetNodeType
+ from phoenix.server.api.types.Dataset import Dataset
  from phoenix.server.api.types.DatasetExample import DatasetExample
- from phoenix.server.api.types.DatasetVersion import DatasetVersion as DatasetVersionNodeType
+ from phoenix.server.api.types.DatasetVersion import DatasetVersion
  from phoenix.server.api.types.node import from_global_id_with_expected_type
  from phoenix.server.api.utils import delete_projects, delete_traces

- from .dataset_examples import router as dataset_examples_router
- from .utils import (
- PaginatedResponseBody,
- ResponseBody,
- add_errors_to_responses,
- add_text_csv_content_to_responses,
- )
-
  logger = logging.getLogger(__name__)

- DATASET_NODE_NAME = DatasetNodeType.__name__
- DATASET_VERSION_NODE_NAME = DatasetVersionNodeType.__name__
-
-
- router = APIRouter(tags=["datasets"])
-
-
- class Dataset(BaseModel):
- id: str
- name: str
- description: Optional[str]
- metadata: Dict[str, Any]
- created_at: datetime
- updated_at: datetime
-
-
- class ListDatasetsResponseBody(PaginatedResponseBody[Dataset]):
- pass
-
-
- @router.get(
- "/datasets",
- operation_id="listDatasets",
- summary="List datasets",
- responses=add_errors_to_responses([HTTP_422_UNPROCESSABLE_ENTITY]),
- )
- async def list_datasets(
- request: Request,
- cursor: Optional[str] = Query(
- default=None,
- description="Cursor for pagination",
- ),
- name: Optional[str] = Query(default=None, description="An optional dataset name to filter by"),
- limit: int = Query(
- default=10, description="The max number of datasets to return at a time.", gt=0
- ),
- ) -> ListDatasetsResponseBody:
+ NODE_NAME = "Dataset"
+
+
+ async def list_datasets(request: Request) -> Response:
+ """
+ summary: List datasets with cursor-based pagination
+ operationId: listDatasets
+ tags:
+ - datasets
+ parameters:
+ - in: query
+ name: cursor
+ required: false
+ schema:
+ type: string
+ description: Cursor for pagination
+ - in: query
+ name: limit
+ required: false
+ schema:
+ type: integer
+ default: 10
+ - in: query
+ name: name
+ required: false
+ schema:
+ type: string
+ description: match by dataset name
+ responses:
+ 200:
+ description: A paginated list of datasets
+ content:
+ application/json:
+ schema:
+ type: object
+ properties:
+ next_cursor:
+ type: string
+ data:
+ type: array
+ items:
+ type: object
+ properties:
+ id:
+ type: string
+ name:
+ type: string
+ description:
+ type: string
+ metadata:
+ type: object
+ created_at:
+ type: string
+ format: date-time
+ updated_at:
+ type: string
+ format: date-time
+ 403:
+ description: Forbidden
+ 404:
+ description: No datasets found
+ """
+ name = request.query_params.get("name")
+ cursor = request.query_params.get("cursor")
+ limit = int(request.query_params.get("limit", 10))
  async with request.app.state.db() as session:
  query = select(models.Dataset).order_by(models.Dataset.id.desc())

@@ -117,8 +132,8 @@ async def list_datasets(
  cursor_id = GlobalID.from_id(cursor).node_id
  query = query.filter(models.Dataset.id <= int(cursor_id))
  except ValueError:
- raise HTTPException(
- detail=f"Invalid cursor format: {cursor}",
+ return Response(
+ content=f"Invalid cursor format: {cursor}",
  status_code=HTTP_422_UNPROCESSABLE_ENTITY,
  )
  if name:
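The handlers in this module are now plain Starlette callables (Request -> Response) that return Response objects directly instead of raising FastAPI's HTTPException, with the OpenAPI metadata carried in YAML docstrings. The route wiring is not shown in this file; per the file list it now lives in phoenix/server/api/routers/v1/__init__.py. A minimal sketch of that registration pattern, using an illustrative stand-in handler and an assumed path:

```python
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
from starlette.routing import Route


async def list_datasets(request: Request) -> Response:
    # Stand-in for the handler above: return an empty page of results.
    return JSONResponse({"next_cursor": None, "data": []})


# The "/v1" prefix and the exact mounting are assumptions for illustration only.
app = Starlette(routes=[Route("/v1/datasets", list_datasets, methods=["GET"])])
```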
@@ -129,56 +144,67 @@ async def list_datasets(
  datasets = result.scalars().all()

  if not datasets:
- return ListDatasetsResponseBody(next_cursor=None, data=[])
+ return JSONResponse(content={"next_cursor": None, "data": []}, status_code=200)

  next_cursor = None
  if len(datasets) == limit + 1:
- next_cursor = str(GlobalID(DATASET_NODE_NAME, str(datasets[-1].id)))
+ next_cursor = str(GlobalID(NODE_NAME, str(datasets[-1].id)))
  datasets = datasets[:-1]

  data = []
  for dataset in datasets:
  data.append(
- Dataset(
- id=str(GlobalID(DATASET_NODE_NAME, str(dataset.id))),
- name=dataset.name,
- description=dataset.description,
- metadata=dataset.metadata_,
- created_at=dataset.created_at,
- updated_at=dataset.updated_at,
- )
+ {
+ "id": str(GlobalID(NODE_NAME, str(dataset.id))),
+ "name": dataset.name,
+ "description": dataset.description,
+ "metadata": dataset.metadata_,
+ "created_at": dataset.created_at.isoformat(),
+ "updated_at": dataset.updated_at.isoformat(),
+ }
  )

- return ListDatasetsResponseBody(next_cursor=next_cursor, data=data)
-
-
- @router.delete(
- "/datasets/{id}",
- operation_id="deleteDatasetById",
- summary="Delete dataset by ID",
- status_code=HTTP_204_NO_CONTENT,
- responses=add_errors_to_responses(
- [
- {"status_code": HTTP_404_NOT_FOUND, "description": "Dataset not found"},
- {"status_code": HTTP_422_UNPROCESSABLE_ENTITY, "description": "Invalid dataset ID"},
- ]
- ),
- )
- async def delete_dataset(
- request: Request, id: str = Path(description="The ID of the dataset to delete.")
- ) -> None:
- if id:
+ return JSONResponse(content={"next_cursor": next_cursor, "data": data})
+
+
+ async def delete_dataset_by_id(request: Request) -> Response:
+ """
+ summary: Delete dataset by ID
+ operationId: deleteDatasetById
+ tags:
+ - datasets
+ parameters:
+ - in: path
+ name: id
+ required: true
+ schema:
+ type: string
+ responses:
+ 204:
+ description: Success
+ 403:
+ description: Forbidden
+ 404:
+ description: Dataset not found
+ 422:
+ description: Dataset ID is invalid
+ """
+ if id_ := request.path_params.get("id"):
  try:
  dataset_id = from_global_id_with_expected_type(
- GlobalID.from_id(id),
- DATASET_NODE_NAME,
+ GlobalID.from_id(id_),
+ Dataset.__name__,
  )
  except ValueError:
- raise HTTPException(
- detail=f"Invalid Dataset ID: {id}", status_code=HTTP_422_UNPROCESSABLE_ENTITY
+ return Response(
+ content=f"Invalid Dataset ID: {id_}",
+ status_code=HTTP_422_UNPROCESSABLE_ENTITY,
  )
  else:
- raise HTTPException(detail="Missing Dataset ID", status_code=HTTP_422_UNPROCESSABLE_ENTITY)
+ return Response(
+ content="Missing Dataset ID",
+ status_code=HTTP_422_UNPROCESSABLE_ENTITY,
+ )
  project_names_stmt = get_project_names_for_datasets(dataset_id)
  eval_trace_ids_stmt = get_eval_trace_ids_for_datasets(dataset_id)
  stmt = (
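list_datasets above pages with an opaque GlobalID cursor: it fetches limit + 1 rows and, when the extra row exists, returns its ID as next_cursor. A hedged client-side sketch of walking those pages (the base URL and /v1 prefix are assumptions):

```python
import requests

url = "http://localhost:6006/v1/datasets"  # assumed local Phoenix server and prefix
cursor = None
while True:
    params = {"limit": 10}
    if cursor:
        params["cursor"] = cursor
    page = requests.get(url, params=params).json()
    for dataset in page["data"]:
        print(dataset["id"], dataset["name"])
    cursor = page["next_cursor"]
    if not cursor:  # no further pages
        break
```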
@@ -188,34 +214,59 @@ async def delete_dataset(
  project_names = await session.scalars(project_names_stmt)
  eval_trace_ids = await session.scalars(eval_trace_ids_stmt)
  if (await session.scalar(stmt)) is None:
- raise HTTPException(detail="Dataset does not exist", status_code=HTTP_404_NOT_FOUND)
+ return Response(content="Dataset does not exist", status_code=HTTP_404_NOT_FOUND)
  tasks = BackgroundTasks()
  tasks.add_task(delete_projects, request.app.state.db, *project_names)
  tasks.add_task(delete_traces, request.app.state.db, *eval_trace_ids)
-
-
- class DatasetWithExampleCount(Dataset):
- example_count: int
-
-
- class GetDatasetResponseBody(ResponseBody[DatasetWithExampleCount]):
- pass
-
-
- @router.get(
- "/datasets/{id}",
- operation_id="getDataset",
- summary="Get dataset by ID",
- responses=add_errors_to_responses([HTTP_404_NOT_FOUND]),
- )
- async def get_dataset(
- request: Request, id: str = Path(description="The ID of the dataset")
- ) -> GetDatasetResponseBody:
- dataset_id = GlobalID.from_id(id)
-
- if (type_name := dataset_id.type_name) != DATASET_NODE_NAME:
- raise HTTPException(
- detail=f"ID {dataset_id} refers to a f{type_name}", status_code=HTTP_404_NOT_FOUND
+ return Response(status_code=HTTP_204_NO_CONTENT, background=tasks)
+
+
+ async def get_dataset_by_id(request: Request) -> Response:
+ """
+ summary: Get dataset by ID
+ operationId: getDatasetById
+ tags:
+ - datasets
+ parameters:
+ - in: path
+ name: id
+ required: true
+ schema:
+ type: string
+ responses:
+ 200:
+ description: Success
+ content:
+ application/json:
+ schema:
+ type: object
+ properties:
+ id:
+ type: string
+ name:
+ type: string
+ description:
+ type: string
+ metadata:
+ type: object
+ created_at:
+ type: string
+ format: date-time
+ updated_at:
+ type: string
+ format: date-time
+ example_count:
+ type: integer
+ 403:
+ description: Forbidden
+ 404:
+ description: Dataset not found
+ """
+ dataset_id = GlobalID.from_id(request.path_params["id"])
+
+ if (type_name := dataset_id.type_name) != NODE_NAME:
+ return Response(
+ content=f"ID {dataset_id} refers to a f{type_name}", status_code=HTTP_404_NOT_FOUND
  )
  async with request.app.state.db() as session:
  result = await session.execute(
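delete_dataset_by_id now attaches Starlette BackgroundTasks to the 204 response, so the project and trace cleanup runs only after the response has been sent. A minimal sketch of that pattern, with an illustrative stand-in for the cleanup functions:

```python
from starlette.background import BackgroundTasks
from starlette.responses import Response
from starlette.status import HTTP_204_NO_CONTENT


async def delete_projects_stub(*names: str) -> None:
    # Illustrative stand-in for phoenix.server.api.utils.delete_projects.
    print("deleting projects:", names)


def build_delete_response(project_names: list) -> Response:
    tasks = BackgroundTasks()
    tasks.add_task(delete_projects_stub, *project_names)
    # The tasks execute after the 204 has been returned to the client.
    return Response(status_code=HTTP_204_NO_CONTENT, background=tasks)
```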
@@ -227,64 +278,97 @@ async def get_dataset(
  dataset = dataset_query[0] if dataset_query else None
  example_count = dataset_query[1] if dataset_query else 0
  if dataset is None:
- raise HTTPException(
- detail=f"Dataset with ID {dataset_id} not found", status_code=HTTP_404_NOT_FOUND
+ return Response(
+ content=f"Dataset with ID {dataset_id} not found", status_code=HTTP_404_NOT_FOUND
  )

- dataset = DatasetWithExampleCount(
- id=str(dataset_id),
- name=dataset.name,
- description=dataset.description,
- metadata=dataset.metadata_,
- created_at=dataset.created_at,
- updated_at=dataset.updated_at,
- example_count=example_count,
- )
- return GetDatasetResponseBody(data=dataset)
-
-
- class DatasetVersion(BaseModel):
- version_id: str
- description: str
- metadata: Dict[str, Any]
- created_at: datetime
-
-
- class ListDatasetVersionsResponseBody(PaginatedResponseBody[DatasetVersion]):
- pass
-
-
- @router.get(
- "/datasets/{id}/versions",
- operation_id="listDatasetVersionsByDatasetId",
- summary="List dataset versions",
- responses=add_errors_to_responses([HTTP_422_UNPROCESSABLE_ENTITY]),
- )
- async def list_dataset_versions(
- request: Request,
- id: str = Path(description="The ID of the dataset"),
- cursor: Optional[str] = Query(
- default=None,
- description="Cursor for pagination",
- ),
- limit: int = Query(
- default=10, description="The max number of dataset versions to return at a time", gt=0
- ),
- ) -> ListDatasetVersionsResponseBody:
- if id:
+ output_dict = {
+ "id": str(dataset_id),
+ "name": dataset.name,
+ "description": dataset.description,
+ "metadata": dataset.metadata_,
+ "created_at": dataset.created_at.isoformat(),
+ "updated_at": dataset.updated_at.isoformat(),
+ "example_count": example_count,
+ }
+ return JSONResponse(content={"data": output_dict})
+
+
+ async def get_dataset_versions(request: Request) -> Response:
+ """
+ summary: Get dataset versions (sorted from latest to oldest)
+ operationId: getDatasetVersionsByDatasetId
+ tags:
+ - datasets
+ parameters:
+ - in: path
+ name: id
+ required: true
+ description: Dataset ID
+ schema:
+ type: string
+ - in: query
+ name: cursor
+ description: Cursor for pagination.
+ schema:
+ type: string
+ - in: query
+ name: limit
+ description: Maximum number versions to return.
+ schema:
+ type: integer
+ default: 10
+ responses:
+ 200:
+ description: Success
+ content:
+ application/json:
+ schema:
+ type: object
+ properties:
+ next_cursor:
+ type: string
+ data:
+ type: array
+ items:
+ type: object
+ properties:
+ version_id:
+ type: string
+ description:
+ type: string
+ metadata:
+ type: object
+ created_at:
+ type: string
+ format: date-time
+ 403:
+ description: Forbidden
+ 422:
+ description: Dataset ID, cursor or limit is invalid.
+ """
+ if id_ := request.path_params.get("id"):
  try:
  dataset_id = from_global_id_with_expected_type(
- GlobalID.from_id(id),
- DATASET_NODE_NAME,
+ GlobalID.from_id(id_),
+ Dataset.__name__,
  )
  except ValueError:
- raise HTTPException(
- detail=f"Invalid Dataset ID: {id}",
+ return Response(
+ content=f"Invalid Dataset ID: {id_}",
  status_code=HTTP_422_UNPROCESSABLE_ENTITY,
  )
  else:
- raise HTTPException(
- detail="Missing Dataset ID",
+ return Response(
+ content="Missing Dataset ID",
+ status_code=HTTP_422_UNPROCESSABLE_ENTITY,
+ )
+ try:
+ limit = int(request.query_params.get("limit", 10))
+ assert limit > 0
+ except (ValueError, AssertionError):
+ return Response(
+ content="Invalid limit parameter",
  status_code=HTTP_422_UNPROCESSABLE_ENTITY,
  )
  stmt = (
@@ -293,14 +377,15 @@ async def list_dataset_versions(
  .order_by(models.DatasetVersion.id.desc())
  .limit(limit + 1)
  )
- if cursor:
+ if cursor := request.query_params.get("cursor"):
  try:
  dataset_version_id = from_global_id_with_expected_type(
- GlobalID.from_id(cursor), DATASET_VERSION_NODE_NAME
+ GlobalID.from_id(cursor),
+ DatasetVersion.__name__,
  )
  except ValueError:
- raise HTTPException(
- detail=f"Invalid cursor: {cursor}",
+ return Response(
+ content=f"Invalid cursor: {cursor}",
  status_code=HTTP_422_UNPROCESSABLE_ENTITY,
  )
  max_dataset_version_id = (
@@ -311,99 +396,102 @@ async def list_dataset_versions(
  stmt = stmt.filter(models.DatasetVersion.id <= max_dataset_version_id)
  async with request.app.state.db() as session:
  data = [
- DatasetVersion(
- version_id=str(GlobalID(DATASET_VERSION_NODE_NAME, str(version.id))),
- description=version.description,
- metadata=version.metadata_,
- created_at=version.created_at,
- )
- async for version in await session.stream_scalars(stmt)
- ]
- next_cursor = data.pop().version_id if len(data) == limit + 1 else None
- return ListDatasetVersionsResponseBody(data=data, next_cursor=next_cursor)
-
-
- class UploadDatasetData(BaseModel):
- dataset_id: str
-
-
- class UploadDatasetResponseBody(ResponseBody[UploadDatasetData]):
- pass
-
-
- @router.post(
- "/datasets/upload",
- operation_id="uploadDataset",
- summary="Upload dataset from JSON, CSV, or PyArrow",
- responses=add_errors_to_responses(
- [
  {
- "status_code": HTTP_409_CONFLICT,
- "description": "Dataset of the same name already exists",
- },
- {"status_code": HTTP_422_UNPROCESSABLE_ENTITY, "description": "Invalid request body"},
- ]
- ),
- # FastAPI cannot generate the request body portion of the OpenAPI schema for
- # routes that accept multiple request content types, so we have to provide
- # this part of the schema manually. For context, see
- # https://github.com/tiangolo/fastapi/discussions/7786 and
- # https://github.com/tiangolo/fastapi/issues/990
- openapi_extra={
- "requestBody": {
- "content": {
- "application/json": {
- "schema": {
- "type": "object",
- "required": ["name", "inputs"],
- "properties": {
- "action": {"type": "string", "enum": ["create", "append"]},
- "name": {"type": "string"},
- "description": {"type": "string"},
- "inputs": {"type": "array", "items": {"type": "object"}},
- "outputs": {"type": "array", "items": {"type": "object"}},
- "metadata": {"type": "array", "items": {"type": "object"}},
- },
- }
- },
- "multipart/form-data": {
- "schema": {
- "type": "object",
- "required": ["name", "input_keys[]", "output_keys[]", "file"],
- "properties": {
- "action": {"type": "string", "enum": ["create", "append"]},
- "name": {"type": "string"},
- "description": {"type": "string"},
- "input_keys[]": {
- "type": "array",
- "items": {"type": "string"},
- "uniqueItems": True,
- },
- "output_keys[]": {
- "type": "array",
- "items": {"type": "string"},
- "uniqueItems": True,
- },
- "metadata_keys[]": {
- "type": "array",
- "items": {"type": "string"},
- "uniqueItems": True,
- },
- "file": {"type": "string", "format": "binary"},
- },
- }
- },
+ "version_id": str(GlobalID(DatasetVersion.__name__, str(version.id))),
+ "description": version.description,
+ "metadata": version.metadata_,
+ "created_at": version.created_at.isoformat(),
  }
- },
- },
- )
- async def upload_dataset(
- request: Request,
- sync: bool = Query(
- default=False,
- description="If true, fulfill request synchronously and return JSON containing dataset_id.",
- ),
- ) -> Optional[UploadDatasetResponseBody]:
+ async for version in await session.stream_scalars(stmt)
+ ]
+ next_cursor = data.pop()["version_id"] if len(data) == limit + 1 else None
+ return JSONResponse(content={"next_cursor": next_cursor, "data": data})
+
+
+ async def post_datasets_upload(request: Request) -> Response:
+ """
+ summary: Upload dataset as either JSON or file (CSV or PyArrow)
+ operationId: uploadDataset
+ tags:
+ - datasets
+ parameters:
+ - in: query
+ name: sync
+ description: If true, fulfill request synchronously and return JSON containing dataset_id
+ schema:
+ type: boolean
+ requestBody:
+ content:
+ application/json:
+ schema:
+ type: object
+ required:
+ - name
+ - inputs
+ properties:
+ action:
+ type: string
+ enum: [create, append]
+ name:
+ type: string
+ description:
+ type: string
+ inputs:
+ type: array
+ items:
+ type: object
+ outputs:
+ type: array
+ items:
+ type: object
+ metadata:
+ type: array
+ items:
+ type: object
+ multipart/form-data:
+ schema:
+ type: object
+ required:
+ - name
+ - input_keys[]
+ - output_keys[]
+ - file
+ properties:
+ action:
+ type: string
+ enum: [create, append]
+ name:
+ type: string
+ description:
+ type: string
+ input_keys[]:
+ type: array
+ items:
+ type: string
+ uniqueItems: true
+ output_keys[]:
+ type: array
+ items:
+ type: string
+ uniqueItems: true
+ metadata_keys[]:
+ type: array
+ items:
+ type: string
+ uniqueItems: true
+ file:
+ type: string
+ format: binary
+ responses:
+ 200:
+ description: Success
+ 403:
+ description: Forbidden
+ 409:
+ description: Dataset of the same name already exists
+ 422:
+ description: Request body is invalid
+ """
  request_content_type = request.headers["content-type"]
  examples: Union[Examples, Awaitable[Examples]]
  if request_content_type.startswith("application/json"):
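The requestBody schema above documents two upload forms; the JSON form takes name, inputs, and optional outputs/metadata, and sync=true makes the handler return the new dataset's GlobalID. A hedged usage sketch (the server URL and /v1 prefix are assumptions):

```python
import requests

payload = {
    "action": "create",
    "name": "my-dataset",
    "inputs": [{"question": "What is Phoenix?"}],
    "outputs": [{"answer": "An LLM tracing and evaluation tool."}],
}
resp = requests.post(
    "http://localhost:6006/v1/datasets/upload",  # assumed base URL and prefix
    params={"sync": "true"},  # the handler compares the raw query value to "true"
    json=payload,
)
resp.raise_for_status()
print(resp.json()["data"]["dataset_id"])
```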
@@ -412,15 +500,15 @@ async def upload_dataset(
  _process_json, await request.json()
  )
  except ValueError as e:
- raise HTTPException(
- detail=str(e),
+ return Response(
+ content=str(e),
  status_code=HTTP_422_UNPROCESSABLE_ENTITY,
  )
  if action is DatasetAction.CREATE:
  async with request.app.state.db() as session:
  if await _check_table_exists(session, name):
- raise HTTPException(
- detail=f"Dataset with the same name already exists: {name=}",
+ return Response(
+ content=f"Dataset with the same name already exists: {name=}",
  status_code=HTTP_409_CONFLICT,
  )
  elif request_content_type.startswith("multipart/form-data"):
@@ -436,15 +524,15 @@ async def upload_dataset(
  file,
  ) = await _parse_form_data(form)
  except ValueError as e:
- raise HTTPException(
- detail=str(e),
+ return Response(
+ content=str(e),
  status_code=HTTP_422_UNPROCESSABLE_ENTITY,
  )
  if action is DatasetAction.CREATE:
  async with request.app.state.db() as session:
  if await _check_table_exists(session, name):
- raise HTTPException(
- detail=f"Dataset with the same name already exists: {name=}",
+ return Response(
+ content=f"Dataset with the same name already exists: {name=}",
  status_code=HTTP_409_CONFLICT,
  )
  content = await file.read()
@@ -460,13 +548,13 @@ async def upload_dataset(
  else:
  assert_never(file_content_type)
  except ValueError as e:
- raise HTTPException(
- detail=str(e),
+ return Response(
+ content=str(e),
  status_code=HTTP_422_UNPROCESSABLE_ENTITY,
  )
  else:
- raise HTTPException(
- detail="Invalid request Content-Type",
+ return Response(
+ content=str("Invalid request Content-Type"),
  status_code=HTTP_422_UNPROCESSABLE_ENTITY,
  )
  operation = cast(
@@ -479,17 +567,19 @@ async def upload_dataset(
  description=description,
  ),
  )
- if sync:
+ if request.query_params.get("sync") == "true":
  async with request.app.state.db() as session:
  dataset_id = (await operation(session)).dataset_id
- return UploadDatasetResponseBody(data=UploadDatasetData(dataset_id=str(dataset_id)))
+ return JSONResponse(
+ content={"data": {"dataset_id": str(GlobalID(Dataset.__name__, str(dataset_id)))}}
+ )
  try:
  request.state.enqueue_operation(operation)
  except QueueFull:
  if isinstance(examples, Coroutine):
  examples.close()
- raise HTTPException(detail="Too many requests.", status_code=HTTP_429_TOO_MANY_REQUESTS)
- return None
+ return Response(status_code=HTTP_429_TOO_MANY_REQUESTS)
+ return Response()


  class FileContentType(Enum):
@@ -667,125 +757,151 @@ async def _parse_form_data(
  )


- # including the dataset examples router here ensures the dataset example routes
- # are included in a natural order in the openapi schema and the swagger ui
- #
- # todo: move the dataset examples routes here and remove the dataset_examples
- # sub-module
- router.include_router(dataset_examples_router)
-
-
- @router.get(
- "/datasets/{id}/csv",
- operation_id="getDatasetCsv",
- summary="Download dataset examples as CSV file",
- response_class=StreamingResponse,
- status_code=HTTP_200_OK,
- responses={
- **add_errors_to_responses([HTTP_422_UNPROCESSABLE_ENTITY]),
- **add_text_csv_content_to_responses(HTTP_200_OK),
- },
- )
- async def get_dataset_csv(
- request: Request,
- response: Response,
- id: str = Path(description="The ID of the dataset"),
- version_id: Optional[str] = Query(
- default=None,
- description=(
- "The ID of the dataset version " "(if omitted, returns data from the latest version)"
- ),
- ),
- ) -> Response:
+ async def get_dataset_csv(request: Request) -> Response:
+ """
+ summary: Download dataset examples as CSV text file
+ operationId: getDatasetCsv
+ tags:
+ - datasets
+ parameters:
+ - in: path
+ name: id
+ required: true
+ schema:
+ type: string
+ description: Dataset ID
+ - in: query
+ name: version_id
+ schema:
+ type: string
+ description: Dataset version ID. If omitted, returns the latest version.
+ responses:
+ 200:
+ description: Success
+ content:
+ text/csv:
+ schema:
+ type: string
+ contentMediaType: text/csv
+ contentEncoding: gzip
+ 403:
+ description: Forbidden
+ 404:
+ description: Dataset does not exist.
+ 422:
+ description: Dataset ID or version ID is invalid.
+ """
  try:
- async with request.app.state.db() as session:
- dataset_name, examples = await _get_db_examples(
- session=session, id=id, version_id=version_id
- )
+ dataset_name, examples = await _get_db_examples(request)
  except ValueError as e:
- raise HTTPException(detail=str(e), status_code=HTTP_422_UNPROCESSABLE_ENTITY)
+ return Response(content=str(e), status_code=HTTP_422_UNPROCESSABLE_ENTITY)
  content = await run_in_threadpool(_get_content_csv, examples)
  return Response(
  content=content,
  headers={
  "content-disposition": f'attachment; filename="{dataset_name}.csv"',
  "content-type": "text/csv",
+ "content-encoding": "gzip",
  },
  )


- @router.get(
- "/datasets/{id}/jsonl/openai_ft",
- operation_id="getDatasetJSONLOpenAIFineTuning",
- summary="Download dataset examples as OpenAI fine-tuning JSONL file",
- response_class=PlainTextResponse,
- responses=add_errors_to_responses(
- [
- {
- "status_code": HTTP_422_UNPROCESSABLE_ENTITY,
- "description": "Invalid dataset or version ID",
- }
- ]
- ),
- )
- async def get_dataset_jsonl_openai_ft(
- request: Request,
- response: Response,
- id: str = Path(description="The ID of the dataset"),
- version_id: Optional[str] = Query(
- default=None,
- description=(
- "The ID of the dataset version " "(if omitted, returns data from the latest version)"
- ),
- ),
- ) -> bytes:
+ async def get_dataset_jsonl_openai_ft(request: Request) -> Response:
+ """
+ summary: Download dataset examples as OpenAI Fine-Tuning JSONL file
+ operationId: getDatasetJSONLOpenAIFineTuning
+ tags:
+ - datasets
+ parameters:
+ - in: path
+ name: id
+ required: true
+ schema:
+ type: string
+ description: Dataset ID
+ - in: query
+ name: version_id
+ schema:
+ type: string
+ description: Dataset version ID. If omitted, returns the latest version.
+ responses:
+ 200:
+ description: Success
+ content:
+ text/plain:
+ schema:
+ type: string
+ contentMediaType: text/plain
+ contentEncoding: gzip
+ 403:
+ description: Forbidden
+ 404:
+ description: Dataset does not exist.
+ 422:
+ description: Dataset ID or version ID is invalid.
+ """
  try:
- async with request.app.state.db() as session:
- dataset_name, examples = await _get_db_examples(
- session=session, id=id, version_id=version_id
- )
+ dataset_name, examples = await _get_db_examples(request)
  except ValueError as e:
- raise HTTPException(detail=str(e), status_code=HTTP_422_UNPROCESSABLE_ENTITY)
+ return Response(content=str(e), status_code=HTTP_422_UNPROCESSABLE_ENTITY)
  content = await run_in_threadpool(_get_content_jsonl_openai_ft, examples)
- response.headers["content-disposition"] = f'attachment; filename="{dataset_name}.jsonl"'
- return content
+ return Response(
+ content=content,
+ headers={
+ "content-disposition": f'attachment; filename="{dataset_name}.jsonl"',
+ "content-type": "text/plain",
+ "content-encoding": "gzip",
+ },
+ )


- @router.get(
- "/datasets/{id}/jsonl/openai_evals",
- operation_id="getDatasetJSONLOpenAIEvals",
- summary="Download dataset examples as OpenAI evals JSONL file",
- response_class=PlainTextResponse,
- responses=add_errors_to_responses(
- [
- {
- "status_code": HTTP_422_UNPROCESSABLE_ENTITY,
- "description": "Invalid dataset or version ID",
- }
- ]
- ),
- )
- async def get_dataset_jsonl_openai_evals(
- request: Request,
- response: Response,
- id: str = Path(description="The ID of the dataset"),
- version_id: Optional[str] = Query(
- default=None,
- description=(
- "The ID of the dataset version " "(if omitted, returns data from the latest version)"
- ),
- ),
- ) -> bytes:
+ async def get_dataset_jsonl_openai_evals(request: Request) -> Response:
+ """
+ summary: Download dataset examples as OpenAI Evals JSONL file
+ operationId: getDatasetJSONLOpenAIEvals
+ tags:
+ - datasets
+ parameters:
+ - in: path
+ name: id
+ required: true
+ schema:
+ type: string
+ description: Dataset ID
+ - in: query
+ name: version_id
+ schema:
+ type: string
+ description: Dataset version ID. If omitted, returns the latest version.
+ responses:
+ 200:
+ description: Success
+ content:
+ text/plain:
+ schema:
+ type: string
+ contentMediaType: text/plain
+ contentEncoding: gzip
+ 403:
+ description: Forbidden
+ 404:
+ description: Dataset does not exist.
+ 422:
+ description: Dataset ID or version ID is invalid.
+ """
  try:
- async with request.app.state.db() as session:
- dataset_name, examples = await _get_db_examples(
- session=session, id=id, version_id=version_id
- )
+ dataset_name, examples = await _get_db_examples(request)
  except ValueError as e:
- raise HTTPException(detail=str(e), status_code=HTTP_422_UNPROCESSABLE_ENTITY)
+ return Response(content=str(e), status_code=HTTP_422_UNPROCESSABLE_ENTITY)
  content = await run_in_threadpool(_get_content_jsonl_openai_evals, examples)
- response.headers["content-disposition"] = f'attachment; filename="{dataset_name}.jsonl"'
- return content
+ return Response(
+ content=content,
+ headers={
+ "content-disposition": f'attachment; filename="{dataset_name}.jsonl"',
+ "content-type": "text/plain",
+ "content-encoding": "gzip",
+ },
+ )


  def _get_content_csv(examples: List[models.DatasetExampleRevision]) -> bytes:
@@ -801,7 +917,7 @@ def _get_content_csv(examples: List[models.DatasetExampleRevision]) -> bytes:
  }
  for ex in examples
  ]
- return str(pd.DataFrame.from_records(records).to_csv(index=False)).encode()
+ return gzip.compress(pd.DataFrame.from_records(records).to_csv(index=False).encode())


  def _get_content_jsonl_openai_ft(examples: List[models.DatasetExampleRevision]) -> bytes:
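_get_content_csv now gzip-compresses the CSV bytes, and the route above advertises this with a content-encoding: gzip header, so well-behaved HTTP clients decompress the body transparently. A small sketch of the round trip:

```python
import gzip

csv_text = "input,output\nhello,world\n"  # stand-in for the DataFrame CSV above
compressed = gzip.compress(csv_text.encode())
assert gzip.decompress(compressed).decode() == csv_text
```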
@@ -822,7 +938,7 @@ def _get_content_jsonl_openai_ft(examples: List[models.DatasetExampleRevision])
  ).encode()
  )
  records.seek(0)
- return records.read()
+ return gzip.compress(records.read())


  def _get_content_jsonl_openai_evals(examples: List[models.DatasetExampleRevision]) -> bytes:
@@ -849,17 +965,18 @@ def _get_content_jsonl_openai_evals(examples: List[models.DatasetExampleRevision
  ).encode()
  )
  records.seek(0)
- return records.read()
+ return gzip.compress(records.read())


- async def _get_db_examples(
- *, session: Any, id: str, version_id: Optional[str]
- ) -> Tuple[str, List[models.DatasetExampleRevision]]:
- dataset_id = from_global_id_with_expected_type(GlobalID.from_id(id), DATASET_NODE_NAME)
+ async def _get_db_examples(request: Request) -> Tuple[str, List[models.DatasetExampleRevision]]:
+ if not (id_ := request.path_params.get("id")):
+ raise ValueError("Missing Dataset ID")
+ dataset_id = from_global_id_with_expected_type(GlobalID.from_id(id_), Dataset.__name__)
  dataset_version_id: Optional[int] = None
- if version_id:
+ if version_id := request.query_params.get("version_id"):
  dataset_version_id = from_global_id_with_expected_type(
- GlobalID.from_id(version_id), DATASET_VERSION_NODE_NAME
+ GlobalID.from_id(version_id),
+ DatasetVersion.__name__,
  )
  latest_version = (
  select(
@@ -892,12 +1009,13 @@ async def _get_db_examples(
  .where(models.DatasetExampleRevision.revision_kind != "DELETE")
  .order_by(models.DatasetExampleRevision.dataset_example_id)
  )
- dataset_name: Optional[str] = await session.scalar(
- select(models.Dataset.name).where(models.Dataset.id == dataset_id)
- )
- if not dataset_name:
- raise ValueError("Dataset does not exist.")
- examples = [r async for r in await session.stream_scalars(stmt)]
+ async with request.app.state.db() as session:
+ dataset_name: Optional[str] = await session.scalar(
+ select(models.Dataset.name).where(models.Dataset.id == dataset_id)
+ )
+ if not dataset_name:
+ raise ValueError("Dataset does not exist.")
+ examples = [r async for r in await session.stream_scalars(stmt)]
  return dataset_name, examples

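End to end, the export endpoints now stream gzip-encoded attachments. A hedged client sketch for the CSV route (the base URL, /v1 prefix, and the dataset GlobalID below are assumptions; requests undoes the gzip encoding automatically):

```python
import requests

resp = requests.get(
    "http://localhost:6006/v1/datasets/RGF0YXNldDox/csv",  # illustrative dataset GlobalID
    params={},  # or {"version_id": "..."} to pin a specific dataset version
)
resp.raise_for_status()
print(resp.text.splitlines()[0])  # CSV header row
```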