arize-phoenix 4.10.2rc0__py3-none-any.whl → 4.10.2rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of arize-phoenix might be problematic.

Files changed (38)
  1. {arize_phoenix-4.10.2rc0.dist-info → arize_phoenix-4.10.2rc1.dist-info}/METADATA +4 -3
  2. {arize_phoenix-4.10.2rc0.dist-info → arize_phoenix-4.10.2rc1.dist-info}/RECORD +27 -35
  3. phoenix/server/api/context.py +3 -7
  4. phoenix/server/api/openapi/main.py +18 -2
  5. phoenix/server/api/openapi/schema.py +12 -12
  6. phoenix/server/api/routers/v1/__init__.py +36 -83
  7. phoenix/server/api/routers/v1/dataset_examples.py +102 -123
  8. phoenix/server/api/routers/v1/datasets.py +389 -507
  9. phoenix/server/api/routers/v1/evaluations.py +74 -64
  10. phoenix/server/api/routers/v1/experiment_evaluations.py +67 -91
  11. phoenix/server/api/routers/v1/experiment_runs.py +97 -155
  12. phoenix/server/api/routers/v1/experiments.py +131 -181
  13. phoenix/server/api/routers/v1/spans.py +141 -173
  14. phoenix/server/api/routers/v1/traces.py +113 -128
  15. phoenix/server/api/routers/v1/utils.py +94 -0
  16. phoenix/server/api/types/Span.py +0 -1
  17. phoenix/server/app.py +148 -192
  18. phoenix/server/main.py +0 -3
  19. phoenix/server/static/index.css +6 -0
  20. phoenix/server/static/index.js +8547 -0
  21. phoenix/server/templates/index.html +25 -76
  22. phoenix/server/thread_server.py +2 -2
  23. phoenix/trace/schemas.py +0 -1
  24. phoenix/version.py +1 -1
  25. phoenix/server/openapi/docs.py +0 -221
  26. phoenix/server/static/.vite/manifest.json +0 -78
  27. phoenix/server/static/assets/components-C8sm_r1F.js +0 -1142
  28. phoenix/server/static/assets/index-BEKPzgQs.js +0 -100
  29. phoenix/server/static/assets/pages-bN7juCjh.js +0 -2885
  30. phoenix/server/static/assets/vendor-CUDAPm8e.js +0 -641
  31. phoenix/server/static/assets/vendor-DxkFTwjz.css +0 -1
  32. phoenix/server/static/assets/vendor-arizeai-Do2HOmcL.js +0 -662
  33. phoenix/server/static/assets/vendor-codemirror-CrdxOlMs.js +0 -12
  34. phoenix/server/static/assets/vendor-recharts-PKRvByVe.js +0 -59
  35. phoenix/server/static/assets/vendor-three-DwGkEfCM.js +0 -2998
  36. {arize_phoenix-4.10.2rc0.dist-info → arize_phoenix-4.10.2rc1.dist-info}/WHEEL +0 -0
  37. {arize_phoenix-4.10.2rc0.dist-info → arize_phoenix-4.10.2rc1.dist-info}/licenses/IP_NOTICE +0 -0
  38. {arize_phoenix-4.10.2rc0.dist-info → arize_phoenix-4.10.2rc1.dist-info}/licenses/LICENSE +0 -0
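Note on the response envelopes used throughout the diff below: item 15 adds phoenix/server/api/routers/v1/utils.py (+94 lines), which supplies the ResponseBody and PaginatedResponseBody models, but its source is not included in this diff. As a hedged sketch only, assuming Pydantic v2 generics, envelopes consistent with how they are used below (a data field, plus next_cursor on paginated responses) could look roughly like this:

from typing import Generic, List, Optional, TypeVar

from pydantic import BaseModel

T = TypeVar("T")


class ResponseBody(BaseModel, Generic[T]):
    # Single-object envelope, e.g. GetDatasetResponseBody(data=dataset)
    data: T


class PaginatedResponseBody(BaseModel, Generic[T]):
    # Cursor-paginated envelope, e.g. ListDatasetsResponseBody(data=[...], next_cursor=None)
    data: List[T]
    next_cursor: Optional[str] = None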
phoenix/server/api/routers/v1/datasets.py
@@ -6,6 +6,7 @@ import logging
  import zlib
  from asyncio import QueueFull
  from collections import Counter
+ from datetime import datetime
  from enum import Enum
  from functools import partial
  from typing import (
@@ -13,6 +14,7 @@ from typing import (
  Awaitable,
  Callable,
  Coroutine,
+ Dict,
  FrozenSet,
  Iterator,
  List,
@@ -26,14 +28,17 @@ from typing import (

  import pandas as pd
  import pyarrow as pa
+ from fastapi import APIRouter, BackgroundTasks, HTTPException, Path, Query
+ from fastapi.responses import PlainTextResponse, StreamingResponse
+ from pydantic import BaseModel
  from sqlalchemy import and_, delete, func, select
  from sqlalchemy.ext.asyncio import AsyncSession
- from starlette.background import BackgroundTasks
  from starlette.concurrency import run_in_threadpool
  from starlette.datastructures import FormData, UploadFile
  from starlette.requests import Request
- from starlette.responses import JSONResponse, Response
+ from starlette.responses import Response
  from starlette.status import (
+ HTTP_200_OK,
  HTTP_204_NO_CONTENT,
  HTTP_404_NOT_FOUND,
  HTTP_409_CONFLICT,
@@ -51,79 +56,59 @@ from phoenix.db.insertion.dataset import (
  ExampleContent,
  add_dataset_examples,
  )
- from phoenix.server.api.types.Dataset import Dataset
+ from phoenix.server.api.types.Dataset import Dataset as DatasetNodeType
  from phoenix.server.api.types.DatasetExample import DatasetExample
- from phoenix.server.api.types.DatasetVersion import DatasetVersion
+ from phoenix.server.api.types.DatasetVersion import DatasetVersion as DatasetVersionNodeType
  from phoenix.server.api.types.node import from_global_id_with_expected_type
  from phoenix.server.api.utils import delete_projects, delete_traces

+ from .dataset_examples import router as dataset_examples_router
+ from .utils import (
+ PaginatedResponseBody,
+ ResponseBody,
+ add_errors_to_responses,
+ add_text_csv_content_to_responses,
+ )
+
  logger = logging.getLogger(__name__)

- NODE_NAME = "Dataset"
-
-
- async def list_datasets(request: Request) -> Response:
- """
- summary: List datasets with cursor-based pagination
- operationId: listDatasets
- tags:
- - datasets
- parameters:
- - in: query
- name: cursor
- required: false
- schema:
- type: string
- description: Cursor for pagination
- - in: query
- name: limit
- required: false
- schema:
- type: integer
- default: 10
- - in: query
- name: name
- required: false
- schema:
- type: string
- description: match by dataset name
- responses:
- 200:
- description: A paginated list of datasets
- content:
- application/json:
- schema:
- type: object
- properties:
- next_cursor:
- type: string
- data:
- type: array
- items:
- type: object
- properties:
- id:
- type: string
- name:
- type: string
- description:
- type: string
- metadata:
- type: object
- created_at:
- type: string
- format: date-time
- updated_at:
- type: string
- format: date-time
- 403:
- description: Forbidden
- 404:
- description: No datasets found
- """
- name = request.query_params.get("name")
- cursor = request.query_params.get("cursor")
- limit = int(request.query_params.get("limit", 10))
+ DATASET_NODE_NAME = DatasetNodeType.__name__
+ DATASET_VERSION_NODE_NAME = DatasetVersionNodeType.__name__
+
+
+ router = APIRouter(tags=["datasets"])
+
+
+ class Dataset(BaseModel):
+ id: str
+ name: str
+ description: Optional[str]
+ metadata: Dict[str, Any]
+ created_at: datetime
+ updated_at: datetime
+
+
+ class ListDatasetsResponseBody(PaginatedResponseBody[Dataset]):
+ pass
+
+
+ @router.get(
+ "/datasets",
+ operation_id="listDatasets",
+ summary="List datasets",
+ responses=add_errors_to_responses([HTTP_422_UNPROCESSABLE_ENTITY]),
+ )
+ async def list_datasets(
+ request: Request,
+ cursor: Optional[str] = Query(
+ default=None,
+ description="Cursor for pagination",
+ ),
+ name: Optional[str] = Query(default=None, description="An optional dataset name to filter by"),
+ limit: int = Query(
+ default=10, description="The max number of datasets to return at a time.", gt=0
+ ),
+ ) -> ListDatasetsResponseBody:
  async with request.app.state.db() as session:
  query = select(models.Dataset).order_by(models.Dataset.id.desc())

@@ -132,8 +117,8 @@ async def list_datasets(request: Request) -> Response:
  cursor_id = GlobalID.from_id(cursor).node_id
  query = query.filter(models.Dataset.id <= int(cursor_id))
  except ValueError:
- return Response(
- content=f"Invalid cursor format: {cursor}",
+ raise HTTPException(
+ detail=f"Invalid cursor format: {cursor}",
  status_code=HTTP_422_UNPROCESSABLE_ENTITY,
  )
  if name:
@@ -144,67 +129,56 @@ async def list_datasets(request: Request) -> Response:
  datasets = result.scalars().all()

  if not datasets:
- return JSONResponse(content={"next_cursor": None, "data": []}, status_code=200)
+ return ListDatasetsResponseBody(next_cursor=None, data=[])

  next_cursor = None
  if len(datasets) == limit + 1:
- next_cursor = str(GlobalID(NODE_NAME, str(datasets[-1].id)))
+ next_cursor = str(GlobalID(DATASET_NODE_NAME, str(datasets[-1].id)))
  datasets = datasets[:-1]

  data = []
  for dataset in datasets:
  data.append(
- {
- "id": str(GlobalID(NODE_NAME, str(dataset.id))),
- "name": dataset.name,
- "description": dataset.description,
- "metadata": dataset.metadata_,
- "created_at": dataset.created_at.isoformat(),
- "updated_at": dataset.updated_at.isoformat(),
- }
+ Dataset(
+ id=str(GlobalID(DATASET_NODE_NAME, str(dataset.id))),
+ name=dataset.name,
+ description=dataset.description,
+ metadata=dataset.metadata_,
+ created_at=dataset.created_at,
+ updated_at=dataset.updated_at,
+ )
  )

- return JSONResponse(content={"next_cursor": next_cursor, "data": data})
-
-
- async def delete_dataset_by_id(request: Request) -> Response:
- """
- summary: Delete dataset by ID
- operationId: deleteDatasetById
- tags:
- - datasets
- parameters:
- - in: path
- name: id
- required: true
- schema:
- type: string
- responses:
- 204:
- description: Success
- 403:
- description: Forbidden
- 404:
- description: Dataset not found
- 422:
- description: Dataset ID is invalid
- """
- if id_ := request.path_params.get("id"):
+ return ListDatasetsResponseBody(next_cursor=next_cursor, data=data)
+
+
+ @router.delete(
+ "/datasets/{id}",
+ operation_id="deleteDatasetById",
+ summary="Delete dataset by ID",
+ status_code=HTTP_204_NO_CONTENT,
+ responses=add_errors_to_responses(
+ [
+ {"status_code": HTTP_404_NOT_FOUND, "description": "Dataset not found"},
+ {"status_code": HTTP_422_UNPROCESSABLE_ENTITY, "description": "Invalid dataset ID"},
+ ]
+ ),
+ )
+ async def delete_dataset(
+ request: Request, id: str = Path(description="The ID of the dataset to delete.")
+ ) -> None:
+ if id:
  try:
  dataset_id = from_global_id_with_expected_type(
- GlobalID.from_id(id_),
- Dataset.__name__,
+ GlobalID.from_id(id),
+ DATASET_NODE_NAME,
  )
  except ValueError:
- return Response(
- content=f"Invalid Dataset ID: {id_}",
- status_code=HTTP_422_UNPROCESSABLE_ENTITY,
+ raise HTTPException(
+ detail=f"Invalid Dataset ID: {id}", status_code=HTTP_422_UNPROCESSABLE_ENTITY
  )
  else:
- return Response(
- content="Missing Dataset ID",
- status_code=HTTP_422_UNPROCESSABLE_ENTITY,
- )
+ raise HTTPException(detail="Missing Dataset ID", status_code=HTTP_422_UNPROCESSABLE_ENTITY)
  project_names_stmt = get_project_names_for_datasets(dataset_id)
  eval_trace_ids_stmt = get_eval_trace_ids_for_datasets(dataset_id)
  stmt = (
@@ -214,59 +188,34 @@ async def delete_dataset_by_id(request: Request) -> Response:
  project_names = await session.scalars(project_names_stmt)
  eval_trace_ids = await session.scalars(eval_trace_ids_stmt)
  if (await session.scalar(stmt)) is None:
- return Response(content="Dataset does not exist", status_code=HTTP_404_NOT_FOUND)
+ raise HTTPException(detail="Dataset does not exist", status_code=HTTP_404_NOT_FOUND)
  tasks = BackgroundTasks()
  tasks.add_task(delete_projects, request.app.state.db, *project_names)
  tasks.add_task(delete_traces, request.app.state.db, *eval_trace_ids)
- return Response(status_code=HTTP_204_NO_CONTENT, background=tasks)
-
-
- async def get_dataset_by_id(request: Request) -> Response:
- """
- summary: Get dataset by ID
- operationId: getDatasetById
- tags:
- - datasets
- parameters:
- - in: path
- name: id
- required: true
- schema:
- type: string
- responses:
- 200:
- description: Success
- content:
- application/json:
- schema:
- type: object
- properties:
- id:
- type: string
- name:
- type: string
- description:
- type: string
- metadata:
- type: object
- created_at:
- type: string
- format: date-time
- updated_at:
- type: string
- format: date-time
- example_count:
- type: integer
- 403:
- description: Forbidden
- 404:
- description: Dataset not found
- """
- dataset_id = GlobalID.from_id(request.path_params["id"])
-
- if (type_name := dataset_id.type_name) != NODE_NAME:
- return Response(
- content=f"ID {dataset_id} refers to a f{type_name}", status_code=HTTP_404_NOT_FOUND
+
+
+ class DatasetWithExampleCount(Dataset):
+ example_count: int
+
+
+ class GetDatasetResponseBody(ResponseBody[DatasetWithExampleCount]):
+ pass
+
+
+ @router.get(
+ "/datasets/{id}",
+ operation_id="getDataset",
+ summary="Get dataset by ID",
+ responses=add_errors_to_responses([HTTP_404_NOT_FOUND]),
+ )
+ async def get_dataset(
+ request: Request, id: str = Path(description="The ID of the dataset")
+ ) -> GetDatasetResponseBody:
+ dataset_id = GlobalID.from_id(id)
+
+ if (type_name := dataset_id.type_name) != DATASET_NODE_NAME:
+ raise HTTPException(
+ detail=f"ID {dataset_id} refers to a f{type_name}", status_code=HTTP_404_NOT_FOUND
  )
  async with request.app.state.db() as session:
  result = await session.execute(
@@ -278,97 +227,64 @@ async def get_dataset_by_id(request: Request) -> Response:
  dataset = dataset_query[0] if dataset_query else None
  example_count = dataset_query[1] if dataset_query else 0
  if dataset is None:
- return Response(
- content=f"Dataset with ID {dataset_id} not found", status_code=HTTP_404_NOT_FOUND
+ raise HTTPException(
+ detail=f"Dataset with ID {dataset_id} not found", status_code=HTTP_404_NOT_FOUND
  )

- output_dict = {
- "id": str(dataset_id),
- "name": dataset.name,
- "description": dataset.description,
- "metadata": dataset.metadata_,
- "created_at": dataset.created_at.isoformat(),
- "updated_at": dataset.updated_at.isoformat(),
- "example_count": example_count,
- }
- return JSONResponse(content={"data": output_dict})
-
-
- async def get_dataset_versions(request: Request) -> Response:
- """
- summary: Get dataset versions (sorted from latest to oldest)
- operationId: getDatasetVersionsByDatasetId
- tags:
- - datasets
- parameters:
- - in: path
- name: id
- required: true
- description: Dataset ID
- schema:
- type: string
- - in: query
- name: cursor
- description: Cursor for pagination.
- schema:
- type: string
- - in: query
- name: limit
- description: Maximum number versions to return.
- schema:
- type: integer
- default: 10
- responses:
- 200:
- description: Success
- content:
- application/json:
- schema:
- type: object
- properties:
- next_cursor:
- type: string
- data:
- type: array
- items:
- type: object
- properties:
- version_id:
- type: string
- description:
- type: string
- metadata:
- type: object
- created_at:
- type: string
- format: date-time
- 403:
- description: Forbidden
- 422:
- description: Dataset ID, cursor or limit is invalid.
- """
- if id_ := request.path_params.get("id"):
+ dataset = DatasetWithExampleCount(
+ id=str(dataset_id),
+ name=dataset.name,
+ description=dataset.description,
+ metadata=dataset.metadata_,
+ created_at=dataset.created_at,
+ updated_at=dataset.updated_at,
+ example_count=example_count,
+ )
+ return GetDatasetResponseBody(data=dataset)
+
+
+ class DatasetVersion(BaseModel):
+ version_id: str
+ description: str
+ metadata: Dict[str, Any]
+ created_at: datetime
+
+
+ class ListDatasetVersionsResponseBody(PaginatedResponseBody[DatasetVersion]):
+ pass
+
+
+ @router.get(
+ "/datasets/{id}/versions",
+ operation_id="listDatasetVersionsByDatasetId",
+ summary="List dataset versions",
+ responses=add_errors_to_responses([HTTP_422_UNPROCESSABLE_ENTITY]),
+ )
+ async def list_dataset_versions(
+ request: Request,
+ id: str = Path(description="The ID of the dataset"),
+ cursor: Optional[str] = Query(
+ default=None,
+ description="Cursor for pagination",
+ ),
+ limit: int = Query(
+ default=10, description="The max number of dataset versions to return at a time", gt=0
+ ),
+ ) -> ListDatasetVersionsResponseBody:
+ if id:
  try:
  dataset_id = from_global_id_with_expected_type(
- GlobalID.from_id(id_),
- Dataset.__name__,
+ GlobalID.from_id(id),
+ DATASET_NODE_NAME,
  )
  except ValueError:
- return Response(
- content=f"Invalid Dataset ID: {id_}",
+ raise HTTPException(
+ detail=f"Invalid Dataset ID: {id}",
  status_code=HTTP_422_UNPROCESSABLE_ENTITY,
  )
  else:
- return Response(
- content="Missing Dataset ID",
- status_code=HTTP_422_UNPROCESSABLE_ENTITY,
- )
- try:
- limit = int(request.query_params.get("limit", 10))
- assert limit > 0
- except (ValueError, AssertionError):
- return Response(
- content="Invalid limit parameter",
+ raise HTTPException(
+ detail="Missing Dataset ID",
  status_code=HTTP_422_UNPROCESSABLE_ENTITY,
  )
  stmt = (
@@ -377,15 +293,14 @@ async def get_dataset_versions(request: Request) -> Response:
  .order_by(models.DatasetVersion.id.desc())
  .limit(limit + 1)
  )
- if cursor := request.query_params.get("cursor"):
+ if cursor:
  try:
  dataset_version_id = from_global_id_with_expected_type(
- GlobalID.from_id(cursor),
- DatasetVersion.__name__,
+ GlobalID.from_id(cursor), DATASET_VERSION_NODE_NAME
  )
  except ValueError:
- return Response(
- content=f"Invalid cursor: {cursor}",
+ raise HTTPException(
+ detail=f"Invalid cursor: {cursor}",
  status_code=HTTP_422_UNPROCESSABLE_ENTITY,
  )
  max_dataset_version_id = (
@@ -396,102 +311,99 @@ async def get_dataset_versions(request: Request) -> Response:
  stmt = stmt.filter(models.DatasetVersion.id <= max_dataset_version_id)
  async with request.app.state.db() as session:
  data = [
- {
- "version_id": str(GlobalID(DatasetVersion.__name__, str(version.id))),
- "description": version.description,
- "metadata": version.metadata_,
- "created_at": version.created_at.isoformat(),
- }
+ DatasetVersion(
+ version_id=str(GlobalID(DATASET_VERSION_NODE_NAME, str(version.id))),
+ description=version.description,
+ metadata=version.metadata_,
+ created_at=version.created_at,
+ )
  async for version in await session.stream_scalars(stmt)
  ]
- next_cursor = data.pop()["version_id"] if len(data) == limit + 1 else None
- return JSONResponse(content={"next_cursor": next_cursor, "data": data})
-
-
- async def post_datasets_upload(request: Request) -> Response:
- """
- summary: Upload dataset as either JSON or file (CSV or PyArrow)
- operationId: uploadDataset
- tags:
- - datasets
- parameters:
- - in: query
- name: sync
- description: If true, fulfill request synchronously and return JSON containing dataset_id
- schema:
- type: boolean
- requestBody:
- content:
- application/json:
- schema:
- type: object
- required:
- - name
- - inputs
- properties:
- action:
- type: string
- enum: [create, append]
- name:
- type: string
- description:
- type: string
- inputs:
- type: array
- items:
- type: object
- outputs:
- type: array
- items:
- type: object
- metadata:
- type: array
- items:
- type: object
- multipart/form-data:
- schema:
- type: object
- required:
- - name
- - input_keys[]
- - output_keys[]
- - file
- properties:
- action:
- type: string
- enum: [create, append]
- name:
- type: string
- description:
- type: string
- input_keys[]:
- type: array
- items:
- type: string
- uniqueItems: true
- output_keys[]:
- type: array
- items:
- type: string
- uniqueItems: true
- metadata_keys[]:
- type: array
- items:
- type: string
- uniqueItems: true
- file:
- type: string
- format: binary
- responses:
- 200:
- description: Success
- 403:
- description: Forbidden
- 409:
- description: Dataset of the same name already exists
- 422:
- description: Request body is invalid
- """
+ next_cursor = data.pop().version_id if len(data) == limit + 1 else None
+ return ListDatasetVersionsResponseBody(data=data, next_cursor=next_cursor)
+
+
+ class UploadDatasetData(BaseModel):
+ dataset_id: str
+
+
+ class UploadDatasetResponseBody(ResponseBody[UploadDatasetData]):
+ pass
+
+
+ @router.post(
+ "/datasets/upload",
+ operation_id="uploadDataset",
+ summary="Upload dataset from JSON, CSV, or PyArrow",
+ responses=add_errors_to_responses(
+ [
+ {
+ "status_code": HTTP_409_CONFLICT,
+ "description": "Dataset of the same name already exists",
+ },
+ {"status_code": HTTP_422_UNPROCESSABLE_ENTITY, "description": "Invalid request body"},
+ ]
+ ),
+ # FastAPI cannot generate the request body portion of the OpenAPI schema for
+ # routes that accept multiple request content types, so we have to provide
+ # this part of the schema manually. For context, see
+ # https://github.com/tiangolo/fastapi/discussions/7786 and
+ # https://github.com/tiangolo/fastapi/issues/990
+ openapi_extra={
+ "requestBody": {
+ "content": {
+ "application/json": {
+ "schema": {
+ "type": "object",
+ "required": ["name", "inputs"],
+ "properties": {
+ "action": {"type": "string", "enum": ["create", "append"]},
+ "name": {"type": "string"},
+ "description": {"type": "string"},
+ "inputs": {"type": "array", "items": {"type": "object"}},
+ "outputs": {"type": "array", "items": {"type": "object"}},
+ "metadata": {"type": "array", "items": {"type": "object"}},
+ },
+ }
+ },
+ "multipart/form-data": {
+ "schema": {
+ "type": "object",
+ "required": ["name", "input_keys[]", "output_keys[]", "file"],
+ "properties": {
+ "action": {"type": "string", "enum": ["create", "append"]},
+ "name": {"type": "string"},
+ "description": {"type": "string"},
+ "input_keys[]": {
+ "type": "array",
+ "items": {"type": "string"},
+ "uniqueItems": True,
+ },
+ "output_keys[]": {
+ "type": "array",
+ "items": {"type": "string"},
+ "uniqueItems": True,
+ },
+ "metadata_keys[]": {
+ "type": "array",
+ "items": {"type": "string"},
+ "uniqueItems": True,
+ },
+ "file": {"type": "string", "format": "binary"},
+ },
+ }
+ },
+ }
+ },
+ },
+ )
+ async def upload_dataset(
+ request: Request,
+ sync: bool = Query(
+ default=False,
+ description="If true, fulfill request synchronously and return JSON containing dataset_id.",
+ ),
+ ) -> Optional[UploadDatasetResponseBody]:
  request_content_type = request.headers["content-type"]
  examples: Union[Examples, Awaitable[Examples]]
  if request_content_type.startswith("application/json"):
@@ -500,15 +412,15 @@ async def post_datasets_upload(request: Request) -> Response:
  _process_json, await request.json()
  )
  except ValueError as e:
- return Response(
- content=str(e),
+ raise HTTPException(
+ detail=str(e),
  status_code=HTTP_422_UNPROCESSABLE_ENTITY,
  )
  if action is DatasetAction.CREATE:
  async with request.app.state.db() as session:
  if await _check_table_exists(session, name):
- return Response(
- content=f"Dataset with the same name already exists: {name=}",
+ raise HTTPException(
+ detail=f"Dataset with the same name already exists: {name=}",
  status_code=HTTP_409_CONFLICT,
  )
  elif request_content_type.startswith("multipart/form-data"):
@@ -524,15 +436,15 @@ async def post_datasets_upload(request: Request) -> Response:
  file,
  ) = await _parse_form_data(form)
  except ValueError as e:
- return Response(
- content=str(e),
+ raise HTTPException(
+ detail=str(e),
  status_code=HTTP_422_UNPROCESSABLE_ENTITY,
  )
  if action is DatasetAction.CREATE:
  async with request.app.state.db() as session:
  if await _check_table_exists(session, name):
- return Response(
- content=f"Dataset with the same name already exists: {name=}",
+ raise HTTPException(
+ detail=f"Dataset with the same name already exists: {name=}",
  status_code=HTTP_409_CONFLICT,
  )
  content = await file.read()
@@ -548,13 +460,13 @@ async def post_datasets_upload(request: Request) -> Response:
  else:
  assert_never(file_content_type)
  except ValueError as e:
- return Response(
- content=str(e),
+ raise HTTPException(
+ detail=str(e),
  status_code=HTTP_422_UNPROCESSABLE_ENTITY,
  )
  else:
- return Response(
- content=str("Invalid request Content-Type"),
+ raise HTTPException(
+ detail="Invalid request Content-Type",
  status_code=HTTP_422_UNPROCESSABLE_ENTITY,
  )
  operation = cast(
@@ -567,19 +479,17 @@ async def post_datasets_upload(request: Request) -> Response:
  description=description,
  ),
  )
- if request.query_params.get("sync") == "true":
+ if sync:
  async with request.app.state.db() as session:
  dataset_id = (await operation(session)).dataset_id
- return JSONResponse(
- content={"data": {"dataset_id": str(GlobalID(Dataset.__name__, str(dataset_id)))}}
- )
+ return UploadDatasetResponseBody(data=UploadDatasetData(dataset_id=str(dataset_id)))
  try:
  request.state.enqueue_operation(operation)
  except QueueFull:
  if isinstance(examples, Coroutine):
  examples.close()
- return Response(status_code=HTTP_429_TOO_MANY_REQUESTS)
- return Response()
+ raise HTTPException(detail="Too many requests.", status_code=HTTP_429_TOO_MANY_REQUESTS)
+ return None


  class FileContentType(Enum):
@@ -757,151 +667,125 @@ async def _parse_form_data(
  )


- async def get_dataset_csv(request: Request) -> Response:
- """
- summary: Download dataset examples as CSV text file
- operationId: getDatasetCsv
- tags:
- - datasets
- parameters:
- - in: path
- name: id
- required: true
- schema:
- type: string
- description: Dataset ID
- - in: query
- name: version_id
- schema:
- type: string
- description: Dataset version ID. If omitted, returns the latest version.
- responses:
- 200:
- description: Success
- content:
- text/csv:
- schema:
- type: string
- contentMediaType: text/csv
- contentEncoding: gzip
- 403:
- description: Forbidden
- 404:
- description: Dataset does not exist.
- 422:
- description: Dataset ID or version ID is invalid.
- """
+ # including the dataset examples router here ensures the dataset example routes
+ # are included in a natural order in the openapi schema and the swagger ui
+ #
+ # todo: move the dataset examples routes here and remove the dataset_examples
+ # sub-module
+ router.include_router(dataset_examples_router)
+
+
+ @router.get(
+ "/datasets/{id}/csv",
+ operation_id="getDatasetCsv",
+ summary="Download dataset examples as CSV file",
+ response_class=StreamingResponse,
+ status_code=HTTP_200_OK,
+ responses={
+ **add_errors_to_responses([HTTP_422_UNPROCESSABLE_ENTITY]),
+ **add_text_csv_content_to_responses(HTTP_200_OK),
+ },
+ )
+ async def get_dataset_csv(
+ request: Request,
+ response: Response,
+ id: str = Path(description="The ID of the dataset"),
+ version_id: Optional[str] = Query(
+ default=None,
+ description=(
+ "The ID of the dataset version " "(if omitted, returns data from the latest version)"
+ ),
+ ),
+ ) -> Response:
  try:
- dataset_name, examples = await _get_db_examples(request)
+ async with request.app.state.db() as session:
+ dataset_name, examples = await _get_db_examples(
+ session=session, id=id, version_id=version_id
+ )
  except ValueError as e:
- return Response(content=str(e), status_code=HTTP_422_UNPROCESSABLE_ENTITY)
+ raise HTTPException(detail=str(e), status_code=HTTP_422_UNPROCESSABLE_ENTITY)
  content = await run_in_threadpool(_get_content_csv, examples)
  return Response(
  content=content,
  headers={
  "content-disposition": f'attachment; filename="{dataset_name}.csv"',
  "content-type": "text/csv",
- "content-encoding": "gzip",
  },
  )


- async def get_dataset_jsonl_openai_ft(request: Request) -> Response:
- """
- summary: Download dataset examples as OpenAI Fine-Tuning JSONL file
- operationId: getDatasetJSONLOpenAIFineTuning
- tags:
- - datasets
- parameters:
- - in: path
- name: id
- required: true
- schema:
- type: string
- description: Dataset ID
- - in: query
- name: version_id
- schema:
- type: string
- description: Dataset version ID. If omitted, returns the latest version.
- responses:
- 200:
- description: Success
- content:
- text/plain:
- schema:
- type: string
- contentMediaType: text/plain
- contentEncoding: gzip
- 403:
- description: Forbidden
- 404:
- description: Dataset does not exist.
- 422:
- description: Dataset ID or version ID is invalid.
- """
+ @router.get(
+ "/datasets/{id}/jsonl/openai_ft",
+ operation_id="getDatasetJSONLOpenAIFineTuning",
+ summary="Download dataset examples as OpenAI fine-tuning JSONL file",
+ response_class=PlainTextResponse,
+ responses=add_errors_to_responses(
+ [
+ {
+ "status_code": HTTP_422_UNPROCESSABLE_ENTITY,
+ "description": "Invalid dataset or version ID",
+ }
+ ]
+ ),
+ )
+ async def get_dataset_jsonl_openai_ft(
+ request: Request,
+ response: Response,
+ id: str = Path(description="The ID of the dataset"),
+ version_id: Optional[str] = Query(
+ default=None,
+ description=(
+ "The ID of the dataset version " "(if omitted, returns data from the latest version)"
+ ),
+ ),
+ ) -> bytes:
  try:
- dataset_name, examples = await _get_db_examples(request)
+ async with request.app.state.db() as session:
+ dataset_name, examples = await _get_db_examples(
+ session=session, id=id, version_id=version_id
+ )
  except ValueError as e:
- return Response(content=str(e), status_code=HTTP_422_UNPROCESSABLE_ENTITY)
+ raise HTTPException(detail=str(e), status_code=HTTP_422_UNPROCESSABLE_ENTITY)
  content = await run_in_threadpool(_get_content_jsonl_openai_ft, examples)
- return Response(
- content=content,
- headers={
- "content-disposition": f'attachment; filename="{dataset_name}.jsonl"',
- "content-type": "text/plain",
- "content-encoding": "gzip",
- },
- )
+ response.headers["content-disposition"] = f'attachment; filename="{dataset_name}.jsonl"'
+ return content


- async def get_dataset_jsonl_openai_evals(request: Request) -> Response:
- """
- summary: Download dataset examples as OpenAI Evals JSONL file
- operationId: getDatasetJSONLOpenAIEvals
- tags:
- - datasets
- parameters:
- - in: path
- name: id
- required: true
- schema:
- type: string
- description: Dataset ID
- - in: query
- name: version_id
- schema:
- type: string
- description: Dataset version ID. If omitted, returns the latest version.
- responses:
- 200:
- description: Success
- content:
- text/plain:
- schema:
- type: string
- contentMediaType: text/plain
- contentEncoding: gzip
- 403:
- description: Forbidden
- 404:
- description: Dataset does not exist.
- 422:
- description: Dataset ID or version ID is invalid.
- """
+ @router.get(
+ "/datasets/{id}/jsonl/openai_evals",
+ operation_id="getDatasetJSONLOpenAIEvals",
+ summary="Download dataset examples as OpenAI evals JSONL file",
+ response_class=PlainTextResponse,
+ responses=add_errors_to_responses(
+ [
+ {
+ "status_code": HTTP_422_UNPROCESSABLE_ENTITY,
+ "description": "Invalid dataset or version ID",
+ }
+ ]
+ ),
+ )
+ async def get_dataset_jsonl_openai_evals(
+ request: Request,
+ response: Response,
+ id: str = Path(description="The ID of the dataset"),
+ version_id: Optional[str] = Query(
+ default=None,
+ description=(
+ "The ID of the dataset version " "(if omitted, returns data from the latest version)"
+ ),
+ ),
+ ) -> bytes:
  try:
- dataset_name, examples = await _get_db_examples(request)
+ async with request.app.state.db() as session:
+ dataset_name, examples = await _get_db_examples(
+ session=session, id=id, version_id=version_id
+ )
  except ValueError as e:
- return Response(content=str(e), status_code=HTTP_422_UNPROCESSABLE_ENTITY)
+ raise HTTPException(detail=str(e), status_code=HTTP_422_UNPROCESSABLE_ENTITY)
  content = await run_in_threadpool(_get_content_jsonl_openai_evals, examples)
- return Response(
- content=content,
- headers={
- "content-disposition": f'attachment; filename="{dataset_name}.jsonl"',
- "content-type": "text/plain",
- "content-encoding": "gzip",
- },
- )
+ response.headers["content-disposition"] = f'attachment; filename="{dataset_name}.jsonl"'
+ return content


  def _get_content_csv(examples: List[models.DatasetExampleRevision]) -> bytes:
@@ -917,7 +801,7 @@ def _get_content_csv(examples: List[models.DatasetExampleRevision]) -> bytes:
  }
  for ex in examples
  ]
- return gzip.compress(pd.DataFrame.from_records(records).to_csv(index=False).encode())
+ return str(pd.DataFrame.from_records(records).to_csv(index=False)).encode()


  def _get_content_jsonl_openai_ft(examples: List[models.DatasetExampleRevision]) -> bytes:
@@ -938,7 +822,7 @@ def _get_content_jsonl_openai_ft(examples: List[models.DatasetExampleRevision])
  ).encode()
  )
  records.seek(0)
- return gzip.compress(records.read())
+ return records.read()


  def _get_content_jsonl_openai_evals(examples: List[models.DatasetExampleRevision]) -> bytes:
@@ -965,18 +849,17 @@ def _get_content_jsonl_openai_evals(examples: List[models.DatasetExampleRevision
  ).encode()
  )
  records.seek(0)
- return gzip.compress(records.read())
+ return records.read()


- async def _get_db_examples(request: Request) -> Tuple[str, List[models.DatasetExampleRevision]]:
- if not (id_ := request.path_params.get("id")):
- raise ValueError("Missing Dataset ID")
- dataset_id = from_global_id_with_expected_type(GlobalID.from_id(id_), Dataset.__name__)
+ async def _get_db_examples(
+ *, session: Any, id: str, version_id: Optional[str]
+ ) -> Tuple[str, List[models.DatasetExampleRevision]]:
+ dataset_id = from_global_id_with_expected_type(GlobalID.from_id(id), DATASET_NODE_NAME)
  dataset_version_id: Optional[int] = None
- if version_id := request.query_params.get("version_id"):
+ if version_id:
  dataset_version_id = from_global_id_with_expected_type(
- GlobalID.from_id(version_id),
- DatasetVersion.__name__,
+ GlobalID.from_id(version_id), DATASET_VERSION_NODE_NAME
  )
  latest_version = (
  select(
@@ -1009,13 +892,12 @@ async def _get_db_examples(request: Request) -> Tuple[str, List[models.DatasetEx
  .where(models.DatasetExampleRevision.revision_kind != "DELETE")
  .order_by(models.DatasetExampleRevision.dataset_example_id)
  )
- async with request.app.state.db() as session:
- dataset_name: Optional[str] = await session.scalar(
- select(models.Dataset.name).where(models.Dataset.id == dataset_id)
- )
- if not dataset_name:
- raise ValueError("Dataset does not exist.")
- examples = [r async for r in await session.stream_scalars(stmt)]
+ dataset_name: Optional[str] = await session.scalar(
+ select(models.Dataset.name).where(models.Dataset.id == dataset_id)
+ )
+ if not dataset_name:
+ raise ValueError("Dataset does not exist.")
+ examples = [r async for r in await session.stream_scalars(stmt)]
  return dataset_name, examples
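For orientation, here is a minimal client-side sketch of the reworked routes defined above. It assumes the router is mounted under the /v1 prefix of a locally running Phoenix server at http://localhost:6006 and uses httpx purely for illustration; the paths, query parameters, and the data / next_cursor envelope come from the route definitions in this diff, while the dataset name "my-dataset" is a hypothetical placeholder.

import httpx

# Assumed base URL: local Phoenix server with the v1 router mounted at /v1.
BASE_URL = "http://localhost:6006/v1"

with httpx.Client(base_url=BASE_URL) as client:
    # GET /datasets supports cursor-based pagination and an optional name filter.
    page = client.get("/datasets", params={"limit": 10, "name": "my-dataset"}).json()
    datasets, next_cursor = page["data"], page["next_cursor"]

    if datasets:
        dataset_id = datasets[0]["id"]

        # GET /datasets/{id} returns the dataset plus an example_count field.
        detail = client.get(f"/datasets/{dataset_id}").json()["data"]

        # GET /datasets/{id}/csv returns the examples as a CSV attachment;
        # after this change the body is plain CSV rather than gzip-compressed.
        csv_bytes = client.get(f"/datasets/{dataset_id}/csv").content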