arize-phoenix 4.14.1__py3-none-any.whl → 4.16.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arize-phoenix might be problematic.

Files changed (85)
  1. {arize_phoenix-4.14.1.dist-info → arize_phoenix-4.16.0.dist-info}/METADATA +5 -3
  2. {arize_phoenix-4.14.1.dist-info → arize_phoenix-4.16.0.dist-info}/RECORD +81 -71
  3. phoenix/db/bulk_inserter.py +131 -5
  4. phoenix/db/engines.py +2 -1
  5. phoenix/db/helpers.py +23 -1
  6. phoenix/db/insertion/constants.py +2 -0
  7. phoenix/db/insertion/document_annotation.py +157 -0
  8. phoenix/db/insertion/helpers.py +13 -0
  9. phoenix/db/insertion/span_annotation.py +144 -0
  10. phoenix/db/insertion/trace_annotation.py +144 -0
  11. phoenix/db/insertion/types.py +261 -0
  12. phoenix/experiments/functions.py +3 -2
  13. phoenix/experiments/types.py +3 -3
  14. phoenix/server/api/context.py +7 -9
  15. phoenix/server/api/dataloaders/__init__.py +2 -0
  16. phoenix/server/api/dataloaders/average_experiment_run_latency.py +3 -3
  17. phoenix/server/api/dataloaders/dataset_example_revisions.py +2 -4
  18. phoenix/server/api/dataloaders/dataset_example_spans.py +2 -4
  19. phoenix/server/api/dataloaders/document_evaluation_summaries.py +2 -4
  20. phoenix/server/api/dataloaders/document_evaluations.py +2 -4
  21. phoenix/server/api/dataloaders/document_retrieval_metrics.py +2 -4
  22. phoenix/server/api/dataloaders/evaluation_summaries.py +2 -4
  23. phoenix/server/api/dataloaders/experiment_annotation_summaries.py +2 -4
  24. phoenix/server/api/dataloaders/experiment_error_rates.py +2 -4
  25. phoenix/server/api/dataloaders/experiment_run_counts.py +2 -4
  26. phoenix/server/api/dataloaders/experiment_sequence_number.py +2 -4
  27. phoenix/server/api/dataloaders/latency_ms_quantile.py +2 -3
  28. phoenix/server/api/dataloaders/min_start_or_max_end_times.py +2 -4
  29. phoenix/server/api/dataloaders/project_by_name.py +3 -3
  30. phoenix/server/api/dataloaders/record_counts.py +2 -4
  31. phoenix/server/api/dataloaders/span_annotations.py +2 -4
  32. phoenix/server/api/dataloaders/span_dataset_examples.py +36 -0
  33. phoenix/server/api/dataloaders/span_descendants.py +2 -4
  34. phoenix/server/api/dataloaders/span_evaluations.py +2 -4
  35. phoenix/server/api/dataloaders/span_projects.py +3 -3
  36. phoenix/server/api/dataloaders/token_counts.py +2 -4
  37. phoenix/server/api/dataloaders/trace_evaluations.py +2 -4
  38. phoenix/server/api/dataloaders/trace_row_ids.py +2 -4
  39. phoenix/server/api/input_types/SpanAnnotationSort.py +17 -0
  40. phoenix/server/api/input_types/TraceAnnotationSort.py +17 -0
  41. phoenix/server/api/mutations/span_annotations_mutations.py +8 -3
  42. phoenix/server/api/mutations/trace_annotations_mutations.py +8 -3
  43. phoenix/server/api/openapi/main.py +18 -2
  44. phoenix/server/api/openapi/schema.py +12 -12
  45. phoenix/server/api/routers/v1/__init__.py +36 -83
  46. phoenix/server/api/routers/v1/datasets.py +515 -509
  47. phoenix/server/api/routers/v1/evaluations.py +164 -73
  48. phoenix/server/api/routers/v1/experiment_evaluations.py +68 -91
  49. phoenix/server/api/routers/v1/experiment_runs.py +98 -155
  50. phoenix/server/api/routers/v1/experiments.py +132 -181
  51. phoenix/server/api/routers/v1/pydantic_compat.py +78 -0
  52. phoenix/server/api/routers/v1/spans.py +164 -203
  53. phoenix/server/api/routers/v1/traces.py +134 -159
  54. phoenix/server/api/routers/v1/utils.py +95 -0
  55. phoenix/server/api/types/Span.py +27 -3
  56. phoenix/server/api/types/Trace.py +21 -4
  57. phoenix/server/api/utils.py +4 -4
  58. phoenix/server/app.py +172 -192
  59. phoenix/server/grpc_server.py +2 -2
  60. phoenix/server/main.py +5 -9
  61. phoenix/server/static/.vite/manifest.json +31 -31
  62. phoenix/server/static/assets/components-Ci5kMOk5.js +1175 -0
  63. phoenix/server/static/assets/{index-CQgXRwU0.js → index-BQG5WVX7.js} +2 -2
  64. phoenix/server/static/assets/{pages-hdjlFZhO.js → pages-BrevprVW.js} +451 -275
  65. phoenix/server/static/assets/{vendor-DPvSDRn3.js → vendor-CP0b0YG0.js} +2 -2
  66. phoenix/server/static/assets/{vendor-arizeai-CkvPT67c.js → vendor-arizeai-DTbiPGp6.js} +27 -27
  67. phoenix/server/static/assets/vendor-codemirror-DtdPDzrv.js +15 -0
  68. phoenix/server/static/assets/{vendor-recharts-5jlNaZuF.js → vendor-recharts-A0DA1O99.js} +1 -1
  69. phoenix/server/thread_server.py +2 -2
  70. phoenix/server/types.py +18 -0
  71. phoenix/session/client.py +5 -3
  72. phoenix/session/session.py +2 -2
  73. phoenix/trace/dsl/filter.py +2 -6
  74. phoenix/trace/fixtures.py +17 -23
  75. phoenix/trace/utils.py +23 -0
  76. phoenix/utilities/client.py +116 -0
  77. phoenix/utilities/project.py +1 -1
  78. phoenix/version.py +1 -1
  79. phoenix/server/api/routers/v1/dataset_examples.py +0 -178
  80. phoenix/server/openapi/docs.py +0 -221
  81. phoenix/server/static/assets/components-DeS0YEmv.js +0 -1142
  82. phoenix/server/static/assets/vendor-codemirror-Cqwpwlua.js +0 -12
  83. {arize_phoenix-4.14.1.dist-info → arize_phoenix-4.16.0.dist-info}/WHEEL +0 -0
  84. {arize_phoenix-4.14.1.dist-info → arize_phoenix-4.16.0.dist-info}/licenses/IP_NOTICE +0 -0
  85. {arize_phoenix-4.14.1.dist-info → arize_phoenix-4.16.0.dist-info}/licenses/LICENSE +0 -0
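The rewritten datasets.py below imports V1RoutesBaseModel from the new pydantic_compat module and ResponseBody, PaginatedResponseBody, and add_errors_to_responses from the new utils module; neither file's contents appear in this diff. A minimal sketch of what those helpers plausibly look like, inferred only from how they are used below — the names exist in the release, but these shapes (pydantic v2-style generics, the exact return value of add_errors_to_responses) are assumptions, not the actual Phoenix source:

# Hypothetical sketch of phoenix/server/api/routers/v1/{pydantic_compat,utils}.py,
# reconstructed from usage in the diff below. Shapes are assumptions.
from typing import Any, Dict, Generic, List, Optional, TypeVar, Union

from pydantic import BaseModel  # assumes pydantic v2-style generic models


class V1RoutesBaseModel(BaseModel):
    """Assumed shared base model for v1 route payloads."""


T = TypeVar("T")


class ResponseBody(BaseModel, Generic[T]):
    """Assumed envelope for single-object responses: {"data": ...}."""

    data: T


class PaginatedResponseBody(BaseModel, Generic[T]):
    """Assumed envelope for paginated responses: {"data": [...], "next_cursor": ...}."""

    data: List[T]
    next_cursor: Optional[str] = None


def add_errors_to_responses(
    errors: List[Union[int, Dict[str, Any]]],
) -> Dict[Union[int, str], Dict[str, Any]]:
    """Assumed helper that builds a FastAPI ``responses=`` mapping from plain
    status codes or {"status_code": ..., "description": ...} dicts."""
    responses: Dict[Union[int, str], Dict[str, Any]] = {}
    for error in errors:
        if isinstance(error, int):
            responses[error] = {"description": "Error"}
        else:
            responses[error["status_code"]] = {"description": error["description"]}
    return responses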
@@ -6,6 +6,7 @@ import logging
  import zlib
  from asyncio import QueueFull
  from collections import Counter
+ from datetime import datetime
  from enum import Enum
  from functools import partial
  from typing import (
@@ -13,6 +14,7 @@ from typing import (
  Awaitable,
  Callable,
  Coroutine,
+ Dict,
  FrozenSet,
  Iterator,
  List,
@@ -26,14 +28,16 @@ from typing import (
 
  import pandas as pd
  import pyarrow as pa
+ from fastapi import APIRouter, BackgroundTasks, HTTPException, Path, Query
+ from fastapi.responses import PlainTextResponse, StreamingResponse
  from sqlalchemy import and_, delete, func, select
  from sqlalchemy.ext.asyncio import AsyncSession
- from starlette.background import BackgroundTasks
  from starlette.concurrency import run_in_threadpool
  from starlette.datastructures import FormData, UploadFile
  from starlette.requests import Request
- from starlette.responses import JSONResponse, Response
+ from starlette.responses import Response
  from starlette.status import (
+ HTTP_200_OK,
  HTTP_204_NO_CONTENT,
  HTTP_404_NOT_FOUND,
  HTTP_409_CONFLICT,
@@ -51,79 +55,59 @@ from phoenix.db.insertion.dataset import (
  ExampleContent,
  add_dataset_examples,
  )
- from phoenix.server.api.types.Dataset import Dataset
- from phoenix.server.api.types.DatasetExample import DatasetExample
- from phoenix.server.api.types.DatasetVersion import DatasetVersion
+ from phoenix.server.api.types.Dataset import Dataset as DatasetNodeType
+ from phoenix.server.api.types.DatasetExample import DatasetExample as DatasetExampleNodeType
+ from phoenix.server.api.types.DatasetVersion import DatasetVersion as DatasetVersionNodeType
  from phoenix.server.api.types.node import from_global_id_with_expected_type
  from phoenix.server.api.utils import delete_projects, delete_traces
 
+ from .pydantic_compat import V1RoutesBaseModel
+ from .utils import (
+ PaginatedResponseBody,
+ ResponseBody,
+ add_errors_to_responses,
+ add_text_csv_content_to_responses,
+ )
+
  logger = logging.getLogger(__name__)
 
- NODE_NAME = "Dataset"
-
-
- async def list_datasets(request: Request) -> Response:
- """
- summary: List datasets with cursor-based pagination
- operationId: listDatasets
- tags:
- - datasets
- parameters:
- - in: query
- name: cursor
- required: false
- schema:
- type: string
- description: Cursor for pagination
- - in: query
- name: limit
- required: false
- schema:
- type: integer
- default: 10
- - in: query
- name: name
- required: false
- schema:
- type: string
- description: match by dataset name
- responses:
- 200:
- description: A paginated list of datasets
- content:
- application/json:
- schema:
- type: object
- properties:
- next_cursor:
- type: string
- data:
- type: array
- items:
- type: object
- properties:
- id:
- type: string
- name:
- type: string
- description:
- type: string
- metadata:
- type: object
- created_at:
- type: string
- format: date-time
- updated_at:
- type: string
- format: date-time
- 403:
- description: Forbidden
- 404:
- description: No datasets found
- """
- name = request.query_params.get("name")
- cursor = request.query_params.get("cursor")
- limit = int(request.query_params.get("limit", 10))
+ DATASET_NODE_NAME = DatasetNodeType.__name__
+ DATASET_VERSION_NODE_NAME = DatasetVersionNodeType.__name__
+
+
+ router = APIRouter(tags=["datasets"])
+
+
+ class Dataset(V1RoutesBaseModel):
+ id: str
+ name: str
+ description: Optional[str]
+ metadata: Dict[str, Any]
+ created_at: datetime
+ updated_at: datetime
+
+
+ class ListDatasetsResponseBody(PaginatedResponseBody[Dataset]):
+ pass
+
+
+ @router.get(
+ "/datasets",
+ operation_id="listDatasets",
+ summary="List datasets",
+ responses=add_errors_to_responses([HTTP_422_UNPROCESSABLE_ENTITY]),
+ )
+ async def list_datasets(
+ request: Request,
+ cursor: Optional[str] = Query(
+ default=None,
+ description="Cursor for pagination",
+ ),
+ name: Optional[str] = Query(default=None, description="An optional dataset name to filter by"),
+ limit: int = Query(
+ default=10, description="The max number of datasets to return at a time.", gt=0
+ ),
+ ) -> ListDatasetsResponseBody:
  async with request.app.state.db() as session:
  query = select(models.Dataset).order_by(models.Dataset.id.desc())
 
@@ -132,79 +116,68 @@ async def list_datasets(request: Request) -> Response:
  cursor_id = GlobalID.from_id(cursor).node_id
  query = query.filter(models.Dataset.id <= int(cursor_id))
  except ValueError:
- return Response(
- content=f"Invalid cursor format: {cursor}",
+ raise HTTPException(
+ detail=f"Invalid cursor format: {cursor}",
  status_code=HTTP_422_UNPROCESSABLE_ENTITY,
  )
  if name:
- query = query.filter(models.Dataset.name.is_(name))
+ query = query.filter(models.Dataset.name == name)
 
  query = query.limit(limit + 1)
  result = await session.execute(query)
  datasets = result.scalars().all()
 
  if not datasets:
- return JSONResponse(content={"next_cursor": None, "data": []}, status_code=200)
+ return ListDatasetsResponseBody(next_cursor=None, data=[])
 
  next_cursor = None
  if len(datasets) == limit + 1:
- next_cursor = str(GlobalID(NODE_NAME, str(datasets[-1].id)))
+ next_cursor = str(GlobalID(DATASET_NODE_NAME, str(datasets[-1].id)))
  datasets = datasets[:-1]
 
  data = []
  for dataset in datasets:
  data.append(
- {
- "id": str(GlobalID(NODE_NAME, str(dataset.id))),
- "name": dataset.name,
- "description": dataset.description,
- "metadata": dataset.metadata_,
- "created_at": dataset.created_at.isoformat(),
- "updated_at": dataset.updated_at.isoformat(),
- }
+ Dataset(
+ id=str(GlobalID(DATASET_NODE_NAME, str(dataset.id))),
+ name=dataset.name,
+ description=dataset.description,
+ metadata=dataset.metadata_,
+ created_at=dataset.created_at,
+ updated_at=dataset.updated_at,
+ )
  )
 
- return JSONResponse(content={"next_cursor": next_cursor, "data": data})
-
-
- async def delete_dataset_by_id(request: Request) -> Response:
- """
- summary: Delete dataset by ID
- operationId: deleteDatasetById
- tags:
- - datasets
- parameters:
- - in: path
- name: id
- required: true
- schema:
- type: string
- responses:
- 204:
- description: Success
- 403:
- description: Forbidden
- 404:
- description: Dataset not found
- 422:
- description: Dataset ID is invalid
- """
- if id_ := request.path_params.get("id"):
+ return ListDatasetsResponseBody(next_cursor=next_cursor, data=data)
+
+
+ @router.delete(
+ "/datasets/{id}",
+ operation_id="deleteDatasetById",
+ summary="Delete dataset by ID",
+ status_code=HTTP_204_NO_CONTENT,
+ responses=add_errors_to_responses(
+ [
+ {"status_code": HTTP_404_NOT_FOUND, "description": "Dataset not found"},
+ {"status_code": HTTP_422_UNPROCESSABLE_ENTITY, "description": "Invalid dataset ID"},
+ ]
+ ),
+ )
+ async def delete_dataset(
+ request: Request, id: str = Path(description="The ID of the dataset to delete.")
+ ) -> None:
+ if id:
  try:
  dataset_id = from_global_id_with_expected_type(
- GlobalID.from_id(id_),
- Dataset.__name__,
+ GlobalID.from_id(id),
+ DATASET_NODE_NAME,
  )
  except ValueError:
- return Response(
- content=f"Invalid Dataset ID: {id_}",
- status_code=HTTP_422_UNPROCESSABLE_ENTITY,
+ raise HTTPException(
+ detail=f"Invalid Dataset ID: {id}", status_code=HTTP_422_UNPROCESSABLE_ENTITY
  )
  else:
- return Response(
- content="Missing Dataset ID",
- status_code=HTTP_422_UNPROCESSABLE_ENTITY,
- )
+ raise HTTPException(detail="Missing Dataset ID", status_code=HTTP_422_UNPROCESSABLE_ENTITY)
  project_names_stmt = get_project_names_for_datasets(dataset_id)
  eval_trace_ids_stmt = get_eval_trace_ids_for_datasets(dataset_id)
  stmt = (
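Note on the new listDatasets handler above: it implements keyset pagination by fetching limit + 1 rows ordered by descending id and, when the extra row comes back, returning its GlobalID as next_cursor. A hedged client-side sketch of paging through the endpoint — the host and the /v1 mount prefix are assumptions not shown in this diff; only the cursor/name/limit query parameters and the {"data": [...], "next_cursor": ...} envelope come from the handler itself:

# Hypothetical pagination loop against the new listDatasets route.
import httpx

BASE_URL = "http://localhost:6006/v1"  # assumed deployment address and prefix


def iter_datasets(limit: int = 10):
    cursor = None
    with httpx.Client(base_url=BASE_URL) as client:
        while True:
            params = {"limit": limit}
            if cursor:
                params["cursor"] = cursor
            response = client.get("/datasets", params=params)
            response.raise_for_status()
            body = response.json()
            yield from body["data"]
            cursor = body["next_cursor"]
            if cursor is None:  # no extra row was fetched, so this was the last page
                break


for dataset in iter_datasets():
    print(dataset["id"], dataset["name"])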
@@ -214,59 +187,34 @@ async def delete_dataset_by_id(request: Request) -> Response:
  project_names = await session.scalars(project_names_stmt)
  eval_trace_ids = await session.scalars(eval_trace_ids_stmt)
  if (await session.scalar(stmt)) is None:
- return Response(content="Dataset does not exist", status_code=HTTP_404_NOT_FOUND)
+ raise HTTPException(detail="Dataset does not exist", status_code=HTTP_404_NOT_FOUND)
  tasks = BackgroundTasks()
  tasks.add_task(delete_projects, request.app.state.db, *project_names)
  tasks.add_task(delete_traces, request.app.state.db, *eval_trace_ids)
- return Response(status_code=HTTP_204_NO_CONTENT, background=tasks)
-
-
- async def get_dataset_by_id(request: Request) -> Response:
- """
- summary: Get dataset by ID
- operationId: getDatasetById
- tags:
- - datasets
- parameters:
- - in: path
- name: id
- required: true
- schema:
- type: string
- responses:
- 200:
- description: Success
- content:
- application/json:
- schema:
- type: object
- properties:
- id:
- type: string
- name:
- type: string
- description:
- type: string
- metadata:
- type: object
- created_at:
- type: string
- format: date-time
- updated_at:
- type: string
- format: date-time
- example_count:
- type: integer
- 403:
- description: Forbidden
- 404:
- description: Dataset not found
- """
- dataset_id = GlobalID.from_id(request.path_params["id"])
-
- if (type_name := dataset_id.type_name) != NODE_NAME:
- return Response(
- content=f"ID {dataset_id} refers to a f{type_name}", status_code=HTTP_404_NOT_FOUND
+
+
+ class DatasetWithExampleCount(Dataset):
+ example_count: int
+
+
+ class GetDatasetResponseBody(ResponseBody[DatasetWithExampleCount]):
+ pass
+
+
+ @router.get(
+ "/datasets/{id}",
+ operation_id="getDataset",
+ summary="Get dataset by ID",
+ responses=add_errors_to_responses([HTTP_404_NOT_FOUND]),
+ )
+ async def get_dataset(
+ request: Request, id: str = Path(description="The ID of the dataset")
+ ) -> GetDatasetResponseBody:
+ dataset_id = GlobalID.from_id(id)
+
+ if (type_name := dataset_id.type_name) != DATASET_NODE_NAME:
+ raise HTTPException(
+ detail=f"ID {dataset_id} refers to a f{type_name}", status_code=HTTP_404_NOT_FOUND
  )
  async with request.app.state.db() as session:
  result = await session.execute(
@@ -278,97 +226,64 @@ async def get_dataset_by_id(request: Request) -> Response:
  dataset = dataset_query[0] if dataset_query else None
  example_count = dataset_query[1] if dataset_query else 0
  if dataset is None:
- return Response(
- content=f"Dataset with ID {dataset_id} not found", status_code=HTTP_404_NOT_FOUND
+ raise HTTPException(
+ detail=f"Dataset with ID {dataset_id} not found", status_code=HTTP_404_NOT_FOUND
  )
 
- output_dict = {
- "id": str(dataset_id),
- "name": dataset.name,
- "description": dataset.description,
- "metadata": dataset.metadata_,
- "created_at": dataset.created_at.isoformat(),
- "updated_at": dataset.updated_at.isoformat(),
- "example_count": example_count,
- }
- return JSONResponse(content={"data": output_dict})
-
-
- async def get_dataset_versions(request: Request) -> Response:
- """
- summary: Get dataset versions (sorted from latest to oldest)
- operationId: getDatasetVersionsByDatasetId
- tags:
- - datasets
- parameters:
- - in: path
- name: id
- required: true
- description: Dataset ID
- schema:
- type: string
- - in: query
- name: cursor
- description: Cursor for pagination.
- schema:
- type: string
- - in: query
- name: limit
- description: Maximum number versions to return.
- schema:
- type: integer
- default: 10
- responses:
- 200:
- description: Success
- content:
- application/json:
- schema:
- type: object
- properties:
- next_cursor:
- type: string
- data:
- type: array
- items:
- type: object
- properties:
- version_id:
- type: string
- description:
- type: string
- metadata:
- type: object
- created_at:
- type: string
- format: date-time
- 403:
- description: Forbidden
- 422:
- description: Dataset ID, cursor or limit is invalid.
- """
- if id_ := request.path_params.get("id"):
+ dataset = DatasetWithExampleCount(
+ id=str(dataset_id),
+ name=dataset.name,
+ description=dataset.description,
+ metadata=dataset.metadata_,
+ created_at=dataset.created_at,
+ updated_at=dataset.updated_at,
+ example_count=example_count,
+ )
+ return GetDatasetResponseBody(data=dataset)
+
+
+ class DatasetVersion(V1RoutesBaseModel):
+ version_id: str
+ description: Optional[str]
+ metadata: Dict[str, Any]
+ created_at: datetime
+
+
+ class ListDatasetVersionsResponseBody(PaginatedResponseBody[DatasetVersion]):
+ pass
+
+
+ @router.get(
+ "/datasets/{id}/versions",
+ operation_id="listDatasetVersionsByDatasetId",
+ summary="List dataset versions",
+ responses=add_errors_to_responses([HTTP_422_UNPROCESSABLE_ENTITY]),
+ )
+ async def list_dataset_versions(
+ request: Request,
+ id: str = Path(description="The ID of the dataset"),
+ cursor: Optional[str] = Query(
+ default=None,
+ description="Cursor for pagination",
+ ),
+ limit: int = Query(
+ default=10, description="The max number of dataset versions to return at a time", gt=0
+ ),
+ ) -> ListDatasetVersionsResponseBody:
+ if id:
  try:
  dataset_id = from_global_id_with_expected_type(
- GlobalID.from_id(id_),
- Dataset.__name__,
+ GlobalID.from_id(id),
+ DATASET_NODE_NAME,
  )
  except ValueError:
- return Response(
- content=f"Invalid Dataset ID: {id_}",
+ raise HTTPException(
+ detail=f"Invalid Dataset ID: {id}",
  status_code=HTTP_422_UNPROCESSABLE_ENTITY,
  )
  else:
- return Response(
- content="Missing Dataset ID",
- status_code=HTTP_422_UNPROCESSABLE_ENTITY,
- )
- try:
- limit = int(request.query_params.get("limit", 10))
- assert limit > 0
- except (ValueError, AssertionError):
- return Response(
- content="Invalid limit parameter",
+ raise HTTPException(
+ detail="Missing Dataset ID",
  status_code=HTTP_422_UNPROCESSABLE_ENTITY,
  )
  stmt = (
@@ -377,15 +292,14 @@ async def get_dataset_versions(request: Request) -> Response:
  .order_by(models.DatasetVersion.id.desc())
  .limit(limit + 1)
  )
- if cursor := request.query_params.get("cursor"):
+ if cursor:
  try:
  dataset_version_id = from_global_id_with_expected_type(
- GlobalID.from_id(cursor),
- DatasetVersion.__name__,
+ GlobalID.from_id(cursor), DATASET_VERSION_NODE_NAME
  )
  except ValueError:
- return Response(
- content=f"Invalid cursor: {cursor}",
+ raise HTTPException(
+ detail=f"Invalid cursor: {cursor}",
  status_code=HTTP_422_UNPROCESSABLE_ENTITY,
  )
  max_dataset_version_id = (
@@ -396,102 +310,99 @@ async def get_dataset_versions(request: Request) -> Response:
  stmt = stmt.filter(models.DatasetVersion.id <= max_dataset_version_id)
  async with request.app.state.db() as session:
  data = [
- {
- "version_id": str(GlobalID(DatasetVersion.__name__, str(version.id))),
- "description": version.description,
- "metadata": version.metadata_,
- "created_at": version.created_at.isoformat(),
- }
+ DatasetVersion(
+ version_id=str(GlobalID(DATASET_VERSION_NODE_NAME, str(version.id))),
+ description=version.description,
+ metadata=version.metadata_,
+ created_at=version.created_at,
+ )
  async for version in await session.stream_scalars(stmt)
  ]
- next_cursor = data.pop()["version_id"] if len(data) == limit + 1 else None
- return JSONResponse(content={"next_cursor": next_cursor, "data": data})
-
-
- async def post_datasets_upload(request: Request) -> Response:
- """
- summary: Upload dataset as either JSON or file (CSV or PyArrow)
- operationId: uploadDataset
- tags:
- - datasets
- parameters:
- - in: query
- name: sync
- description: If true, fulfill request synchronously and return JSON containing dataset_id
- schema:
- type: boolean
- requestBody:
- content:
- application/json:
- schema:
- type: object
- required:
- - name
- - inputs
- properties:
- action:
- type: string
- enum: [create, append]
- name:
- type: string
- description:
- type: string
- inputs:
- type: array
- items:
- type: object
- outputs:
- type: array
- items:
- type: object
- metadata:
- type: array
- items:
- type: object
- multipart/form-data:
- schema:
- type: object
- required:
- - name
- - input_keys[]
- - output_keys[]
- - file
- properties:
- action:
- type: string
- enum: [create, append]
- name:
- type: string
- description:
- type: string
- input_keys[]:
- type: array
- items:
- type: string
- uniqueItems: true
- output_keys[]:
- type: array
- items:
- type: string
- uniqueItems: true
- metadata_keys[]:
- type: array
- items:
- type: string
- uniqueItems: true
- file:
- type: string
- format: binary
- responses:
- 200:
- description: Success
- 403:
- description: Forbidden
- 409:
- description: Dataset of the same name already exists
- 422:
- description: Request body is invalid
- """
+ next_cursor = data.pop().version_id if len(data) == limit + 1 else None
+ return ListDatasetVersionsResponseBody(data=data, next_cursor=next_cursor)
+
+
+ class UploadDatasetData(V1RoutesBaseModel):
+ dataset_id: str
+
+
+ class UploadDatasetResponseBody(ResponseBody[UploadDatasetData]):
+ pass
+
+
+ @router.post(
+ "/datasets/upload",
+ operation_id="uploadDataset",
+ summary="Upload dataset from JSON, CSV, or PyArrow",
+ responses=add_errors_to_responses(
+ [
+ {
+ "status_code": HTTP_409_CONFLICT,
+ "description": "Dataset of the same name already exists",
+ },
+ {"status_code": HTTP_422_UNPROCESSABLE_ENTITY, "description": "Invalid request body"},
+ ]
+ ),
+ # FastAPI cannot generate the request body portion of the OpenAPI schema for
+ # routes that accept multiple request content types, so we have to provide
+ # this part of the schema manually. For context, see
+ # https://github.com/tiangolo/fastapi/discussions/7786 and
+ # https://github.com/tiangolo/fastapi/issues/990
+ openapi_extra={
+ "requestBody": {
+ "content": {
+ "application/json": {
+ "schema": {
+ "type": "object",
+ "required": ["name", "inputs"],
+ "properties": {
+ "action": {"type": "string", "enum": ["create", "append"]},
+ "name": {"type": "string"},
+ "description": {"type": "string"},
+ "inputs": {"type": "array", "items": {"type": "object"}},
+ "outputs": {"type": "array", "items": {"type": "object"}},
+ "metadata": {"type": "array", "items": {"type": "object"}},
+ },
+ }
+ },
+ "multipart/form-data": {
+ "schema": {
+ "type": "object",
+ "required": ["name", "input_keys[]", "output_keys[]", "file"],
+ "properties": {
+ "action": {"type": "string", "enum": ["create", "append"]},
+ "name": {"type": "string"},
+ "description": {"type": "string"},
+ "input_keys[]": {
+ "type": "array",
+ "items": {"type": "string"},
+ "uniqueItems": True,
+ },
+ "output_keys[]": {
+ "type": "array",
+ "items": {"type": "string"},
+ "uniqueItems": True,
+ },
+ "metadata_keys[]": {
+ "type": "array",
+ "items": {"type": "string"},
+ "uniqueItems": True,
+ },
+ "file": {"type": "string", "format": "binary"},
+ },
+ }
+ },
+ }
+ },
+ },
+ )
+ async def upload_dataset(
+ request: Request,
+ sync: bool = Query(
+ default=False,
+ description="If true, fulfill request synchronously and return JSON containing dataset_id.",
+ ),
+ ) -> Optional[UploadDatasetResponseBody]:
  request_content_type = request.headers["content-type"]
  examples: Union[Examples, Awaitable[Examples]]
  if request_content_type.startswith("application/json"):
@@ -500,15 +411,15 @@ async def post_datasets_upload(request: Request) -> Response:
  _process_json, await request.json()
  )
  except ValueError as e:
- return Response(
- content=str(e),
+ raise HTTPException(
+ detail=str(e),
  status_code=HTTP_422_UNPROCESSABLE_ENTITY,
  )
  if action is DatasetAction.CREATE:
  async with request.app.state.db() as session:
  if await _check_table_exists(session, name):
- return Response(
- content=f"Dataset with the same name already exists: {name=}",
+ raise HTTPException(
+ detail=f"Dataset with the same name already exists: {name=}",
  status_code=HTTP_409_CONFLICT,
  )
  elif request_content_type.startswith("multipart/form-data"):
@@ -524,15 +435,15 @@ async def post_datasets_upload(request: Request) -> Response:
  file,
  ) = await _parse_form_data(form)
  except ValueError as e:
- return Response(
- content=str(e),
+ raise HTTPException(
+ detail=str(e),
  status_code=HTTP_422_UNPROCESSABLE_ENTITY,
  )
  if action is DatasetAction.CREATE:
  async with request.app.state.db() as session:
  if await _check_table_exists(session, name):
- return Response(
- content=f"Dataset with the same name already exists: {name=}",
+ raise HTTPException(
+ detail=f"Dataset with the same name already exists: {name=}",
  status_code=HTTP_409_CONFLICT,
  )
  content = await file.read()
@@ -548,13 +459,13 @@ async def post_datasets_upload(request: Request) -> Response:
  else:
  assert_never(file_content_type)
  except ValueError as e:
- return Response(
- content=str(e),
+ raise HTTPException(
+ detail=str(e),
  status_code=HTTP_422_UNPROCESSABLE_ENTITY,
  )
  else:
- return Response(
- content=str("Invalid request Content-Type"),
+ raise HTTPException(
+ detail="Invalid request Content-Type",
  status_code=HTTP_422_UNPROCESSABLE_ENTITY,
  )
  operation = cast(
@@ -567,19 +478,19 @@ async def post_datasets_upload(request: Request) -> Response:
  description=description,
  ),
  )
- if request.query_params.get("sync") == "true":
+ if sync:
  async with request.app.state.db() as session:
  dataset_id = (await operation(session)).dataset_id
- return JSONResponse(
- content={"data": {"dataset_id": str(GlobalID(Dataset.__name__, str(dataset_id)))}}
+ return UploadDatasetResponseBody(
+ data=UploadDatasetData(dataset_id=str(GlobalID(Dataset.__name__, str(dataset_id))))
  )
  try:
  request.state.enqueue_operation(operation)
  except QueueFull:
  if isinstance(examples, Coroutine):
  examples.close()
- return Response(status_code=HTTP_429_TOO_MANY_REQUESTS)
- return Response()
+ raise HTTPException(detail="Too many requests.", status_code=HTTP_429_TOO_MANY_REQUESTS)
+ return None
 
 
  class FileContentType(Enum):
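Note on the upload route above: it accepts either application/json or multipart/form-data, with the request-body schema supplied through openapi_extra because FastAPI cannot derive it for multi-content-type routes (per the comment in the diff). A hedged sketch of calling it both ways — the host and the /v1 prefix are assumptions; the field names and the sync query parameter come from the schema and handler shown above, and the payload values are placeholders:

# Hypothetical calls to the uploadDataset route shown above.
import httpx

BASE_URL = "http://localhost:6006/v1"  # assumed deployment address and prefix

with httpx.Client(base_url=BASE_URL) as client:
    # JSON upload: parallel lists of input/output/metadata objects.
    json_response = client.post(
        "/datasets/upload",
        params={"sync": True},  # wait for the dataset_id instead of enqueueing
        json={
            "action": "create",
            "name": "my-dataset",
            "inputs": [{"question": "What is Phoenix?"}],
            "outputs": [{"answer": "An observability library."}],
            "metadata": [{"source": "example"}],
        },
    )
    json_response.raise_for_status()
    print(json_response.json()["data"]["dataset_id"])

    # CSV upload: multipart form naming which CSV columns are inputs and outputs.
    files = {
        "file": ("examples.csv", b"question,answer\nWhat is Phoenix?,A library.\n", "text/csv")
    }
    form = {
        "action": "append",
        "name": "my-dataset",
        "input_keys[]": "question",
        "output_keys[]": "answer",
    }
    csv_response = client.post("/datasets/upload", params={"sync": True}, data=form, files=files)
    csv_response.raise_for_status()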
@@ -757,158 +668,255 @@ async def _parse_form_data(
  )
 
 
- async def get_dataset_csv(request: Request) -> Response:
- """
- summary: Download dataset examples as CSV text file
- operationId: getDatasetCsv
- tags:
- - datasets
- parameters:
- - in: path
- name: id
- required: true
- schema:
- type: string
- description: Dataset ID
- - in: query
- name: version_id
- schema:
- type: string
- description: Dataset version ID. If omitted, returns the latest version.
- responses:
- 200:
- description: Success
- content:
- text/csv:
- schema:
- type: string
- contentMediaType: text/csv
- contentEncoding: gzip
- 403:
- description: Forbidden
- 404:
- description: Dataset does not exist.
- 422:
- description: Dataset ID or version ID is invalid.
- """
+ class DatasetExample(V1RoutesBaseModel):
+ id: str
+ input: Dict[str, Any]
+ output: Dict[str, Any]
+ metadata: Dict[str, Any]
+ updated_at: datetime
+
+
+ class ListDatasetExamplesData(V1RoutesBaseModel):
+ dataset_id: str
+ version_id: str
+ examples: List[DatasetExample]
+
+
+ class ListDatasetExamplesResponseBody(ResponseBody[ListDatasetExamplesData]):
+ pass
+
+
+ @router.get(
+ "/datasets/{id}/examples",
+ operation_id="getDatasetExamples",
+ summary="Get examples from a dataset",
+ responses=add_errors_to_responses([HTTP_404_NOT_FOUND]),
+ )
+ async def get_dataset_examples(
+ request: Request,
+ id: str = Path(description="The ID of the dataset"),
+ version_id: Optional[str] = Query(
+ default=None,
+ description=(
+ "The ID of the dataset version " "(if omitted, returns data from the latest version)"
+ ),
+ ),
+ ) -> ListDatasetExamplesResponseBody:
+ dataset_gid = GlobalID.from_id(id)
+ version_gid = GlobalID.from_id(version_id) if version_id else None
+
+ if (dataset_type := dataset_gid.type_name) != "Dataset":
+ raise HTTPException(
+ detail=f"ID {dataset_gid} refers to a {dataset_type}", status_code=HTTP_404_NOT_FOUND
+ )
+
+ if version_gid and (version_type := version_gid.type_name) != "DatasetVersion":
+ raise HTTPException(
+ detail=f"ID {version_gid} refers to a {version_type}", status_code=HTTP_404_NOT_FOUND
+ )
+
+ async with request.app.state.db() as session:
+ if (
+ resolved_dataset_id := await session.scalar(
+ select(models.Dataset.id).where(models.Dataset.id == int(dataset_gid.node_id))
+ )
+ ) is None:
+ raise HTTPException(
+ detail=f"No dataset with id {dataset_gid} can be found.",
+ status_code=HTTP_404_NOT_FOUND,
+ )
+
+ # Subquery to find the maximum created_at for each dataset_example_id
+ # timestamp tiebreaks are resolved by the largest id
+ partial_subquery = select(
+ func.max(models.DatasetExampleRevision.id).label("max_id"),
+ ).group_by(models.DatasetExampleRevision.dataset_example_id)
+
+ if version_gid:
+ if (
+ resolved_version_id := await session.scalar(
+ select(models.DatasetVersion.id).where(
+ and_(
+ models.DatasetVersion.dataset_id == resolved_dataset_id,
+ models.DatasetVersion.id == int(version_gid.node_id),
+ )
+ )
+ )
+ ) is None:
+ raise HTTPException(
+ detail=f"No dataset version with id {version_id} can be found.",
+ status_code=HTTP_404_NOT_FOUND,
+ )
+ # if a version_id is provided, filter the subquery to only include revisions from that
+ partial_subquery = partial_subquery.filter(
+ models.DatasetExampleRevision.dataset_version_id <= resolved_version_id
+ )
+ else:
+ if (
+ resolved_version_id := await session.scalar(
+ select(func.max(models.DatasetVersion.id)).where(
+ models.DatasetVersion.dataset_id == resolved_dataset_id
+ )
+ )
+ ) is None:
+ raise HTTPException(
+ detail="Dataset has no versions.",
+ status_code=HTTP_404_NOT_FOUND,
+ )
+
+ subquery = partial_subquery.subquery()
+ # Query for the most recent example revisions that are not deleted
+ query = (
+ select(models.DatasetExample, models.DatasetExampleRevision)
+ .join(
+ models.DatasetExampleRevision,
+ models.DatasetExample.id == models.DatasetExampleRevision.dataset_example_id,
+ )
+ .join(
+ subquery,
+ (subquery.c.max_id == models.DatasetExampleRevision.id),
+ )
+ .filter(models.DatasetExample.dataset_id == resolved_dataset_id)
+ .filter(models.DatasetExampleRevision.revision_kind != "DELETE")
+ .order_by(models.DatasetExample.id.asc())
+ )
+ examples = [
+ DatasetExample(
+ id=str(GlobalID("DatasetExample", str(example.id))),
+ input=revision.input,
+ output=revision.output,
+ metadata=revision.metadata_,
+ updated_at=revision.created_at,
+ )
+ async for example, revision in await session.stream(query)
+ ]
+ return ListDatasetExamplesResponseBody(
+ data=ListDatasetExamplesData(
+ dataset_id=str(GlobalID("Dataset", str(resolved_dataset_id))),
+ version_id=str(GlobalID("DatasetVersion", str(resolved_version_id))),
+ examples=examples,
+ )
+ )
+
+
+ @router.get(
+ "/datasets/{id}/csv",
+ operation_id="getDatasetCsv",
+ summary="Download dataset examples as CSV file",
+ response_class=StreamingResponse,
+ status_code=HTTP_200_OK,
+ responses={
+ **add_errors_to_responses([HTTP_422_UNPROCESSABLE_ENTITY]),
+ **add_text_csv_content_to_responses(HTTP_200_OK),
+ },
+ )
+ async def get_dataset_csv(
+ request: Request,
+ response: Response,
+ id: str = Path(description="The ID of the dataset"),
+ version_id: Optional[str] = Query(
+ default=None,
+ description=(
+ "The ID of the dataset version " "(if omitted, returns data from the latest version)"
+ ),
+ ),
+ ) -> Response:
  try:
- dataset_name, examples = await _get_db_examples(request)
+ async with request.app.state.db() as session:
+ dataset_name, examples = await _get_db_examples(
+ session=session, id=id, version_id=version_id
+ )
  except ValueError as e:
- return Response(content=str(e), status_code=HTTP_422_UNPROCESSABLE_ENTITY)
+ raise HTTPException(detail=str(e), status_code=HTTP_422_UNPROCESSABLE_ENTITY)
  content = await run_in_threadpool(_get_content_csv, examples)
  return Response(
  content=content,
  headers={
  "content-disposition": f'attachment; filename="{dataset_name}.csv"',
  "content-type": "text/csv",
- "content-encoding": "gzip",
  },
  )
 
 
- async def get_dataset_jsonl_openai_ft(request: Request) -> Response:
- """
- summary: Download dataset examples as OpenAI Fine-Tuning JSONL file
- operationId: getDatasetJSONLOpenAIFineTuning
- tags:
- - datasets
- parameters:
- - in: path
- name: id
- required: true
- schema:
- type: string
- description: Dataset ID
- - in: query
- name: version_id
- schema:
- type: string
- description: Dataset version ID. If omitted, returns the latest version.
- responses:
- 200:
- description: Success
- content:
- text/plain:
- schema:
- type: string
- contentMediaType: text/plain
- contentEncoding: gzip
- 403:
- description: Forbidden
- 404:
- description: Dataset does not exist.
- 422:
- description: Dataset ID or version ID is invalid.
- """
+ @router.get(
+ "/datasets/{id}/jsonl/openai_ft",
+ operation_id="getDatasetJSONLOpenAIFineTuning",
+ summary="Download dataset examples as OpenAI fine-tuning JSONL file",
+ response_class=PlainTextResponse,
+ responses=add_errors_to_responses(
+ [
+ {
+ "status_code": HTTP_422_UNPROCESSABLE_ENTITY,
+ "description": "Invalid dataset or version ID",
+ }
+ ]
+ ),
+ )
+ async def get_dataset_jsonl_openai_ft(
+ request: Request,
+ response: Response,
+ id: str = Path(description="The ID of the dataset"),
+ version_id: Optional[str] = Query(
+ default=None,
+ description=(
+ "The ID of the dataset version " "(if omitted, returns data from the latest version)"
+ ),
+ ),
+ ) -> bytes:
  try:
- dataset_name, examples = await _get_db_examples(request)
+ async with request.app.state.db() as session:
+ dataset_name, examples = await _get_db_examples(
+ session=session, id=id, version_id=version_id
+ )
  except ValueError as e:
- return Response(content=str(e), status_code=HTTP_422_UNPROCESSABLE_ENTITY)
+ raise HTTPException(detail=str(e), status_code=HTTP_422_UNPROCESSABLE_ENTITY)
  content = await run_in_threadpool(_get_content_jsonl_openai_ft, examples)
- return Response(
- content=content,
- headers={
- "content-disposition": f'attachment; filename="{dataset_name}.jsonl"',
- "content-type": "text/plain",
- "content-encoding": "gzip",
- },
- )
+ response.headers["content-disposition"] = f'attachment; filename="{dataset_name}.jsonl"'
+ return content
 
 
- async def get_dataset_jsonl_openai_evals(request: Request) -> Response:
- """
- summary: Download dataset examples as OpenAI Evals JSONL file
- operationId: getDatasetJSONLOpenAIEvals
- tags:
- - datasets
- parameters:
- - in: path
- name: id
- required: true
- schema:
- type: string
- description: Dataset ID
- - in: query
- name: version_id
- schema:
- type: string
- description: Dataset version ID. If omitted, returns the latest version.
- responses:
- 200:
- description: Success
- content:
- text/plain:
- schema:
- type: string
- contentMediaType: text/plain
- contentEncoding: gzip
- 403:
- description: Forbidden
- 404:
- description: Dataset does not exist.
- 422:
- description: Dataset ID or version ID is invalid.
- """
+ @router.get(
+ "/datasets/{id}/jsonl/openai_evals",
+ operation_id="getDatasetJSONLOpenAIEvals",
+ summary="Download dataset examples as OpenAI evals JSONL file",
+ response_class=PlainTextResponse,
+ responses=add_errors_to_responses(
+ [
+ {
+ "status_code": HTTP_422_UNPROCESSABLE_ENTITY,
+ "description": "Invalid dataset or version ID",
+ }
+ ]
+ ),
+ )
+ async def get_dataset_jsonl_openai_evals(
+ request: Request,
+ response: Response,
+ id: str = Path(description="The ID of the dataset"),
+ version_id: Optional[str] = Query(
+ default=None,
+ description=(
+ "The ID of the dataset version " "(if omitted, returns data from the latest version)"
+ ),
+ ),
+ ) -> bytes:
  try:
- dataset_name, examples = await _get_db_examples(request)
+ async with request.app.state.db() as session:
+ dataset_name, examples = await _get_db_examples(
+ session=session, id=id, version_id=version_id
+ )
  except ValueError as e:
- return Response(content=str(e), status_code=HTTP_422_UNPROCESSABLE_ENTITY)
+ raise HTTPException(detail=str(e), status_code=HTTP_422_UNPROCESSABLE_ENTITY)
  content = await run_in_threadpool(_get_content_jsonl_openai_evals, examples)
- return Response(
- content=content,
- headers={
- "content-disposition": f'attachment; filename="{dataset_name}.jsonl"',
- "content-type": "text/plain",
- "content-encoding": "gzip",
- },
- )
+ response.headers["content-disposition"] = f'attachment; filename="{dataset_name}.jsonl"'
+ return content
 
 
  def _get_content_csv(examples: List[models.DatasetExampleRevision]) -> bytes:
  records = [
  {
  "example_id": GlobalID(
- type_name=DatasetExample.__name__,
+ type_name=DatasetExampleNodeType.__name__,
  node_id=str(ex.dataset_example_id),
  ),
  **{f"input_{k}": v for k, v in ex.input.items()},
@@ -917,7 +925,7 @@ def _get_content_csv(examples: List[models.DatasetExampleRevision]) -> bytes:
  }
  for ex in examples
  ]
- return gzip.compress(pd.DataFrame.from_records(records).to_csv(index=False).encode())
+ return str(pd.DataFrame.from_records(records).to_csv(index=False)).encode()
 
 
  def _get_content_jsonl_openai_ft(examples: List[models.DatasetExampleRevision]) -> bytes:
@@ -938,7 +946,7 @@ def _get_content_jsonl_openai_ft(examples: List[models.DatasetExampleRevision])
  ).encode()
  )
  records.seek(0)
- return gzip.compress(records.read())
+ return records.read()
 
 
  def _get_content_jsonl_openai_evals(examples: List[models.DatasetExampleRevision]) -> bytes:
@@ -965,18 +973,17 @@ def _get_content_jsonl_openai_evals(examples: List[models.DatasetExampleRevision
  ).encode()
  )
  records.seek(0)
- return gzip.compress(records.read())
+ return records.read()
 
 
- async def _get_db_examples(request: Request) -> Tuple[str, List[models.DatasetExampleRevision]]:
- if not (id_ := request.path_params.get("id")):
- raise ValueError("Missing Dataset ID")
- dataset_id = from_global_id_with_expected_type(GlobalID.from_id(id_), Dataset.__name__)
+ async def _get_db_examples(
+ *, session: Any, id: str, version_id: Optional[str]
+ ) -> Tuple[str, List[models.DatasetExampleRevision]]:
+ dataset_id = from_global_id_with_expected_type(GlobalID.from_id(id), DATASET_NODE_NAME)
  dataset_version_id: Optional[int] = None
- if version_id := request.query_params.get("version_id"):
+ if version_id:
  dataset_version_id = from_global_id_with_expected_type(
- GlobalID.from_id(version_id),
- DatasetVersion.__name__,
+ GlobalID.from_id(version_id), DATASET_VERSION_NODE_NAME
  )
  latest_version = (
  select(
@@ -1009,13 +1016,12 @@ async def _get_db_examples(request: Request) -> Tuple[str, List[models.DatasetEx
  .where(models.DatasetExampleRevision.revision_kind != "DELETE")
  .order_by(models.DatasetExampleRevision.dataset_example_id)
  )
- async with request.app.state.db() as session:
- dataset_name: Optional[str] = await session.scalar(
- select(models.Dataset.name).where(models.Dataset.id == dataset_id)
- )
- if not dataset_name:
- raise ValueError("Dataset does not exist.")
- examples = [r async for r in await session.stream_scalars(stmt)]
+ dataset_name: Optional[str] = await session.scalar(
+ select(models.Dataset.name).where(models.Dataset.id == dataset_id)
+ )
+ if not dataset_name:
+ raise ValueError("Dataset does not exist.")
+ examples = [r async for r in await session.stream_scalars(stmt)]
  return dataset_name, examples
 
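A final note on the export routes in this diff: the CSV and JSONL downloads no longer gzip their payloads or set a content-encoding header; _get_content_csv and the JSONL helpers now return plain bytes and the handlers only set content-disposition and the content type. A hedged client-side sketch of saving a CSV export — the host, /v1 prefix, and dataset GlobalID are placeholders, not values from this diff:

# Hypothetical download via the getDatasetCsv route shown above.
import httpx

BASE_URL = "http://localhost:6006/v1"  # assumed deployment address and prefix
DATASET_ID = "RGF0YXNldDox"  # placeholder dataset GlobalID

with httpx.Client(base_url=BASE_URL) as client:
    response = client.get(f"/datasets/{DATASET_ID}/csv")
    response.raise_for_status()
    # The body is plain CSV text now; there is no gzip content-encoding to undo.
    with open("dataset.csv", "wb") as f:
        f.write(response.content)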