arize-phoenix 4.4.3__py3-none-any.whl → 4.4.4rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109)
  1. {arize_phoenix-4.4.3.dist-info → arize_phoenix-4.4.4rc0.dist-info}/METADATA +4 -4
  2. {arize_phoenix-4.4.3.dist-info → arize_phoenix-4.4.4rc0.dist-info}/RECORD +108 -55
  3. phoenix/__init__.py +0 -27
  4. phoenix/config.py +21 -7
  5. phoenix/core/model.py +25 -25
  6. phoenix/core/model_schema.py +64 -62
  7. phoenix/core/model_schema_adapter.py +27 -25
  8. phoenix/datasets/__init__.py +0 -0
  9. phoenix/datasets/evaluators.py +275 -0
  10. phoenix/datasets/experiments.py +469 -0
  11. phoenix/datasets/tracing.py +66 -0
  12. phoenix/datasets/types.py +212 -0
  13. phoenix/db/bulk_inserter.py +54 -14
  14. phoenix/db/insertion/dataset.py +234 -0
  15. phoenix/db/insertion/evaluation.py +6 -6
  16. phoenix/db/insertion/helpers.py +13 -2
  17. phoenix/db/migrations/types.py +29 -0
  18. phoenix/db/migrations/versions/10460e46d750_datasets.py +291 -0
  19. phoenix/db/migrations/versions/cf03bd6bae1d_init.py +2 -28
  20. phoenix/db/models.py +230 -3
  21. phoenix/inferences/fixtures.py +23 -23
  22. phoenix/inferences/inferences.py +7 -7
  23. phoenix/inferences/validation.py +1 -1
  24. phoenix/server/api/context.py +16 -0
  25. phoenix/server/api/dataloaders/__init__.py +16 -0
  26. phoenix/server/api/dataloaders/dataset_example_revisions.py +100 -0
  27. phoenix/server/api/dataloaders/dataset_example_spans.py +43 -0
  28. phoenix/server/api/dataloaders/experiment_annotation_summaries.py +85 -0
  29. phoenix/server/api/dataloaders/experiment_error_rates.py +43 -0
  30. phoenix/server/api/dataloaders/experiment_sequence_number.py +49 -0
  31. phoenix/server/api/dataloaders/project_by_name.py +31 -0
  32. phoenix/server/api/dataloaders/span_descendants.py +2 -3
  33. phoenix/server/api/dataloaders/span_projects.py +33 -0
  34. phoenix/server/api/dataloaders/trace_row_ids.py +39 -0
  35. phoenix/server/api/helpers/dataset_helpers.py +178 -0
  36. phoenix/server/api/input_types/AddExamplesToDatasetInput.py +16 -0
  37. phoenix/server/api/input_types/AddSpansToDatasetInput.py +14 -0
  38. phoenix/server/api/input_types/CreateDatasetInput.py +12 -0
  39. phoenix/server/api/input_types/DatasetExampleInput.py +14 -0
  40. phoenix/server/api/input_types/DatasetSort.py +17 -0
  41. phoenix/server/api/input_types/DatasetVersionSort.py +16 -0
  42. phoenix/server/api/input_types/DeleteDatasetExamplesInput.py +13 -0
  43. phoenix/server/api/input_types/DeleteDatasetInput.py +7 -0
  44. phoenix/server/api/input_types/DeleteExperimentsInput.py +9 -0
  45. phoenix/server/api/input_types/PatchDatasetExamplesInput.py +35 -0
  46. phoenix/server/api/input_types/PatchDatasetInput.py +14 -0
  47. phoenix/server/api/mutations/__init__.py +13 -0
  48. phoenix/server/api/mutations/auth.py +11 -0
  49. phoenix/server/api/mutations/dataset_mutations.py +520 -0
  50. phoenix/server/api/mutations/experiment_mutations.py +65 -0
  51. phoenix/server/api/{types/ExportEventsMutation.py → mutations/export_events_mutations.py} +17 -14
  52. phoenix/server/api/mutations/project_mutations.py +42 -0
  53. phoenix/server/api/queries.py +503 -0
  54. phoenix/server/api/routers/v1/__init__.py +77 -2
  55. phoenix/server/api/routers/v1/dataset_examples.py +178 -0
  56. phoenix/server/api/routers/v1/datasets.py +861 -0
  57. phoenix/server/api/routers/v1/evaluations.py +4 -2
  58. phoenix/server/api/routers/v1/experiment_evaluations.py +65 -0
  59. phoenix/server/api/routers/v1/experiment_runs.py +108 -0
  60. phoenix/server/api/routers/v1/experiments.py +174 -0
  61. phoenix/server/api/routers/v1/spans.py +3 -1
  62. phoenix/server/api/routers/v1/traces.py +1 -4
  63. phoenix/server/api/schema.py +2 -303
  64. phoenix/server/api/types/AnnotatorKind.py +10 -0
  65. phoenix/server/api/types/Cluster.py +19 -19
  66. phoenix/server/api/types/CreateDatasetPayload.py +8 -0
  67. phoenix/server/api/types/Dataset.py +282 -63
  68. phoenix/server/api/types/DatasetExample.py +85 -0
  69. phoenix/server/api/types/DatasetExampleRevision.py +34 -0
  70. phoenix/server/api/types/DatasetVersion.py +14 -0
  71. phoenix/server/api/types/Dimension.py +30 -29
  72. phoenix/server/api/types/EmbeddingDimension.py +40 -34
  73. phoenix/server/api/types/Event.py +16 -16
  74. phoenix/server/api/types/ExampleRevisionInterface.py +14 -0
  75. phoenix/server/api/types/Experiment.py +135 -0
  76. phoenix/server/api/types/ExperimentAnnotationSummary.py +13 -0
  77. phoenix/server/api/types/ExperimentComparison.py +19 -0
  78. phoenix/server/api/types/ExperimentRun.py +91 -0
  79. phoenix/server/api/types/ExperimentRunAnnotation.py +57 -0
  80. phoenix/server/api/types/Inferences.py +80 -0
  81. phoenix/server/api/types/InferencesRole.py +23 -0
  82. phoenix/server/api/types/Model.py +43 -42
  83. phoenix/server/api/types/Project.py +26 -12
  84. phoenix/server/api/types/Span.py +78 -2
  85. phoenix/server/api/types/TimeSeries.py +6 -6
  86. phoenix/server/api/types/Trace.py +15 -4
  87. phoenix/server/api/types/UMAPPoints.py +1 -1
  88. phoenix/server/api/types/node.py +5 -111
  89. phoenix/server/api/types/pagination.py +10 -52
  90. phoenix/server/app.py +99 -49
  91. phoenix/server/main.py +49 -27
  92. phoenix/server/openapi/docs.py +3 -0
  93. phoenix/server/static/index.js +2246 -1368
  94. phoenix/server/templates/index.html +1 -0
  95. phoenix/services.py +15 -15
  96. phoenix/session/client.py +316 -21
  97. phoenix/session/session.py +47 -37
  98. phoenix/trace/exporter.py +14 -9
  99. phoenix/trace/fixtures.py +133 -7
  100. phoenix/trace/span_evaluations.py +3 -3
  101. phoenix/trace/trace_dataset.py +6 -6
  102. phoenix/utilities/json.py +61 -0
  103. phoenix/utilities/re.py +50 -0
  104. phoenix/version.py +1 -1
  105. phoenix/server/api/types/DatasetRole.py +0 -23
  106. {arize_phoenix-4.4.3.dist-info → arize_phoenix-4.4.4rc0.dist-info}/WHEEL +0 -0
  107. {arize_phoenix-4.4.3.dist-info → arize_phoenix-4.4.4rc0.dist-info}/licenses/IP_NOTICE +0 -0
  108. {arize_phoenix-4.4.3.dist-info → arize_phoenix-4.4.4rc0.dist-info}/licenses/LICENSE +0 -0
  109. /phoenix/server/api/{helpers.py → helpers/__init__.py} +0 -0
phoenix/server/api/routers/v1/datasets.py
@@ -0,0 +1,861 @@
+ import csv
+ import gzip
+ import io
+ import json
+ import logging
+ import zlib
+ from asyncio import QueueFull
+ from collections import Counter
+ from enum import Enum
+ from functools import partial
+ from typing import (
+     Any,
+     Awaitable,
+     Callable,
+     Coroutine,
+     Dict,
+     FrozenSet,
+     Iterator,
+     List,
+     Optional,
+     Tuple,
+     Union,
+     cast,
+ )
+
+ import pandas as pd
+ import pyarrow as pa
+ from sqlalchemy import and_, func, select
+ from sqlalchemy.ext.asyncio import AsyncSession
+ from starlette.concurrency import run_in_threadpool
+ from starlette.datastructures import FormData, UploadFile
+ from starlette.requests import Request
+ from starlette.responses import JSONResponse, Response
+ from starlette.status import (
+     HTTP_403_FORBIDDEN,
+     HTTP_404_NOT_FOUND,
+     HTTP_422_UNPROCESSABLE_ENTITY,
+     HTTP_429_TOO_MANY_REQUESTS,
+ )
+ from strawberry.relay import GlobalID
+ from typing_extensions import TypeAlias, assert_never
+
+ from phoenix.db import models
+ from phoenix.db.insertion.dataset import (
+     DatasetAction,
+     DatasetExampleAdditionEvent,
+     add_dataset_examples,
+ )
+ from phoenix.server.api.types.Dataset import Dataset
+ from phoenix.server.api.types.DatasetExample import DatasetExample
+ from phoenix.server.api.types.DatasetVersion import DatasetVersion
+ from phoenix.server.api.types.node import from_global_id_with_expected_type
+
+ logger = logging.getLogger(__name__)
+
+ NODE_NAME = "Dataset"
+
+
+ async def list_datasets(request: Request) -> Response:
+     """
+     summary: List datasets with cursor-based pagination
+     operationId: listDatasets
+     tags:
+       - datasets
+     parameters:
+       - in: query
+         name: cursor
+         required: false
+         schema:
+           type: string
+         description: Cursor for pagination
+       - in: query
+         name: limit
+         required: false
+         schema:
+           type: integer
+           default: 10
+       - in: query
+         name: name
+         required: false
+         schema:
+           type: string
+         description: match by dataset name
+     responses:
+       200:
+         description: A paginated list of datasets
+         content:
+           application/json:
+             schema:
+               type: object
+               properties:
+                 next_cursor:
+                   type: string
+                 data:
+                   type: array
+                   items:
+                     type: object
+                     properties:
+                       id:
+                         type: string
+                       name:
+                         type: string
+                       description:
+                         type: string
+                       metadata:
+                         type: object
+                       created_at:
+                         type: string
+                         format: date-time
+                       updated_at:
+                         type: string
+                         format: date-time
+       403:
+         description: Forbidden
+       404:
+         description: No datasets found
+     """
+     name = request.query_params.get("name")
+     cursor = request.query_params.get("cursor")
+     limit = int(request.query_params.get("limit", 10))
+     async with request.app.state.db() as session:
+         query = select(models.Dataset).order_by(models.Dataset.id.desc())
+
+         if cursor:
+             try:
+                 cursor_id = GlobalID.from_id(cursor).node_id
+                 query = query.filter(models.Dataset.id <= int(cursor_id))
+             except ValueError:
+                 return Response(
+                     content=f"Invalid cursor format: {cursor}",
+                     status_code=HTTP_422_UNPROCESSABLE_ENTITY,
+                 )
+         if name:
+             query = query.filter(models.Dataset.name.is_(name))
+
+         query = query.limit(limit + 1)
+         result = await session.execute(query)
+         datasets = result.scalars().all()
+
+         if not datasets:
+             return JSONResponse(content={"next_cursor": None, "data": []}, status_code=200)
+
+         next_cursor = None
+         if len(datasets) == limit + 1:
+             next_cursor = str(GlobalID(NODE_NAME, str(datasets[-1].id)))
+             datasets = datasets[:-1]
+
+         data = []
+         for dataset in datasets:
+             data.append(
+                 {
+                     "id": str(GlobalID(NODE_NAME, str(dataset.id))),
+                     "name": dataset.name,
+                     "description": dataset.description,
+                     "metadata": dataset.metadata_,
+                     "created_at": dataset.created_at.isoformat(),
+                     "updated_at": dataset.updated_at.isoformat(),
+                 }
+             )
+
+         return JSONResponse(content={"next_cursor": next_cursor, "data": data})
+
+
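The handler above pages newest-first and returns next_cursor only while more rows remain. A minimal client-side sketch of walking those pages with the requests library, assuming the router is mounted under /v1 on a local server (host, port, and path are assumptions; the cursor is treated as an opaque string):

import requests  # any HTTP client works; requests is used here for brevity

def iter_datasets(base_url="http://localhost:6006", page_size=10):
    # Follow next_cursor until the server stops returning one.
    cursor = None
    while True:
        params = {"limit": page_size}
        if cursor:
            params["cursor"] = cursor  # opaque cursor from the previous page
        resp = requests.get(f"{base_url}/v1/datasets", params=params)
        resp.raise_for_status()
        payload = resp.json()
        yield from payload["data"]  # items carry id, name, description, metadata, timestamps
        cursor = payload["next_cursor"]
        if not cursor:
            break

for dataset in iter_datasets():
    print(dataset["id"], dataset["name"])
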
+ async def get_dataset_by_id(request: Request) -> Response:
+     """
+     summary: Get dataset by ID
+     operationId: getDatasetById
+     tags:
+       - datasets
+     parameters:
+       - in: path
+         name: id
+         required: true
+         schema:
+           type: string
+     responses:
+       200:
+         description: Success
+         content:
+           application/json:
+             schema:
+               type: object
+               properties:
+                 id:
+                   type: string
+                 name:
+                   type: string
+                 description:
+                   type: string
+                 metadata:
+                   type: object
+                 created_at:
+                   type: string
+                   format: date-time
+                 updated_at:
+                   type: string
+                   format: date-time
+                 example_count:
+                   type: integer
+       403:
+         description: Forbidden
+       404:
+         description: Dataset not found
+     """
+     dataset_id = GlobalID.from_id(request.path_params["id"])
+
+     if (type_name := dataset_id.type_name) != NODE_NAME:
+         return Response(
+             content=f"ID {dataset_id} refers to a {type_name}", status_code=HTTP_404_NOT_FOUND
+         )
+     async with request.app.state.db() as session:
+         result = await session.execute(
+             select(models.Dataset, models.Dataset.example_count).filter(
+                 models.Dataset.id == int(dataset_id.node_id)
+             )
+         )
+         dataset_query = result.first()
+         dataset = dataset_query[0] if dataset_query else None
+         example_count = dataset_query[1] if dataset_query else 0
+         if dataset is None:
+             return Response(
+                 content=f"Dataset with ID {dataset_id} not found", status_code=HTTP_404_NOT_FOUND
+             )
+
+         output_dict = {
+             "id": str(dataset_id),
+             "name": dataset.name,
+             "description": dataset.description,
+             "metadata": dataset.metadata_,
+             "created_at": dataset.created_at.isoformat(),
+             "updated_at": dataset.updated_at.isoformat(),
+             "example_count": example_count,
+         }
+         return JSONResponse(content=output_dict)
+
+
+ async def get_dataset_versions(request: Request) -> Response:
+     """
+     summary: Get dataset versions (sorted from latest to oldest)
+     operationId: getDatasetVersionsByDatasetId
+     tags:
+       - datasets
+     parameters:
+       - in: path
+         name: id
+         required: true
+         description: Dataset ID
+         schema:
+           type: string
+       - in: query
+         name: cursor
+         description: Cursor for pagination.
+         schema:
+           type: string
+       - in: query
+         name: limit
+         description: Maximum number of versions to return.
+         schema:
+           type: integer
+           default: 10
+     responses:
+       200:
+         description: Success
+         content:
+           application/json:
+             schema:
+               type: object
+               properties:
+                 next_cursor:
+                   type: string
+                 data:
+                   type: array
+                   items:
+                     type: object
+                     properties:
+                       version_id:
+                         type: string
+                       description:
+                         type: string
+                       metadata:
+                         type: object
+                       created_at:
+                         type: string
+                         format: date-time
+       403:
+         description: Forbidden
+       422:
+         description: Dataset ID, cursor or limit is invalid.
+     """
+     if id_ := request.path_params.get("id"):
+         try:
+             dataset_id = from_global_id_with_expected_type(
+                 GlobalID.from_id(id_),
+                 Dataset.__name__,
+             )
+         except ValueError:
+             return Response(
+                 content=f"Invalid Dataset ID: {id_}",
+                 status_code=HTTP_422_UNPROCESSABLE_ENTITY,
+             )
+     else:
+         return Response(
+             content="Missing Dataset ID",
+             status_code=HTTP_422_UNPROCESSABLE_ENTITY,
+         )
+     try:
+         limit = int(request.query_params.get("limit", 10))
+         assert limit > 0
+     except (ValueError, AssertionError):
+         return Response(
+             content="Invalid limit parameter",
+             status_code=HTTP_422_UNPROCESSABLE_ENTITY,
+         )
+     stmt = (
+         select(models.DatasetVersion)
+         .where(models.DatasetVersion.dataset_id == dataset_id)
+         .order_by(models.DatasetVersion.id.desc())
+         .limit(limit + 1)
+     )
+     if cursor := request.query_params.get("cursor"):
+         try:
+             dataset_version_id = from_global_id_with_expected_type(
+                 GlobalID.from_id(cursor),
+                 DatasetVersion.__name__,
+             )
+         except ValueError:
+             return Response(
+                 content=f"Invalid cursor: {cursor}",
+                 status_code=HTTP_422_UNPROCESSABLE_ENTITY,
+             )
+         max_dataset_version_id = (
+             select(models.DatasetVersion.id)
+             .where(models.DatasetVersion.id == dataset_version_id)
+             .where(models.DatasetVersion.dataset_id == dataset_id)
+         ).scalar_subquery()
+         stmt = stmt.filter(models.DatasetVersion.id <= max_dataset_version_id)
+     async with request.app.state.db() as session:
+         data = [
+             {
+                 "version_id": str(GlobalID(DatasetVersion.__name__, str(version.id))),
+                 "description": version.description,
+                 "metadata": version.metadata_,
+                 "created_at": version.created_at.isoformat(),
+             }
+             async for version in await session.stream_scalars(stmt)
+         ]
+     next_cursor = data.pop()["version_id"] if len(data) == limit + 1 else None
+     return JSONResponse(content={"next_cursor": next_cursor, "data": data})
+
+
+ async def post_datasets_upload(request: Request) -> Response:
+     """
+     summary: Upload CSV or PyArrow file as dataset
+     operationId: uploadDataset
+     tags:
+       - datasets
+     parameters:
+       - in: query
+         name: sync
+         description: If true, fulfill request synchronously and return JSON containing dataset_id
+         schema:
+           type: boolean
+     requestBody:
+       content:
+         multipart/form-data:
+           schema:
+             type: object
+             required:
+               - name
+               - input_keys[]
+               - output_keys[]
+               - file
+             properties:
+               action:
+                 type: string
+                 enum: [create, append]
+               name:
+                 type: string
+               description:
+                 type: string
+               input_keys[]:
+                 type: array
+                 items:
+                   type: string
+                 uniqueItems: true
+               output_keys[]:
+                 type: array
+                 items:
+                   type: string
+                 uniqueItems: true
+               metadata_keys[]:
+                 type: array
+                 items:
+                   type: string
+                 uniqueItems: true
+               file:
+                 type: string
+                 format: binary
+     responses:
+       200:
+         description: Success
+       403:
+         description: Forbidden
+       422:
+         description: Request body is invalid
+     """
+     if request.app.state.read_only:
+         return Response(status_code=HTTP_403_FORBIDDEN)
+     async with request.form() as form:
+         try:
+             (
+                 action,
+                 name,
+                 description,
+                 input_keys,
+                 output_keys,
+                 metadata_keys,
+                 file,
+             ) = await _parse_form_data(form)
+         except ValueError as e:
+             return Response(
+                 content=str(e),
+                 status_code=HTTP_422_UNPROCESSABLE_ENTITY,
+             )
+         if action is DatasetAction.CREATE:
+             async with request.app.state.db() as session:
+                 if await _check_table_exists(session, name):
+                     return Response(
+                         content=f"Dataset already exists: {name=}",
+                         status_code=HTTP_422_UNPROCESSABLE_ENTITY,
+                     )
+         content = await file.read()
+         try:
+             examples: Union[Examples, Awaitable[Examples]]
+             content_type = FileContentType(file.content_type)
+             if content_type is FileContentType.CSV:
+                 encoding = FileContentEncoding(file.headers.get("content-encoding"))
+                 examples, column_headers = await _process_csv(content, encoding)
+             elif content_type is FileContentType.PYARROW:
+                 examples, column_headers = await _process_pyarrow(content)
+             else:
+                 assert_never(content_type)
+             _check_keys_exist(column_headers, input_keys, output_keys, metadata_keys)
+         except ValueError as e:
+             return Response(
+                 content=str(e),
+                 status_code=HTTP_422_UNPROCESSABLE_ENTITY,
+             )
+     operation = cast(
+         Callable[[AsyncSession], Awaitable[DatasetExampleAdditionEvent]],
+         partial(
+             add_dataset_examples,
+             examples=examples,
+             action=action,
+             name=name,
+             description=description,
+             input_keys=input_keys,
+             output_keys=output_keys,
+             metadata_keys=metadata_keys,
+         ),
+     )
+     if request.query_params.get("sync") == "true":
+         async with request.app.state.db() as session:
+             dataset_id = (await operation(session)).dataset_id
+         return JSONResponse(
+             content={"data": {"dataset_id": str(GlobalID(Dataset.__name__, str(dataset_id)))}}
+         )
+     try:
+         request.state.enqueue_operation(operation)
+     except QueueFull:
+         if isinstance(examples, Coroutine):
+             examples.close()
+         return Response(status_code=HTTP_429_TOO_MANY_REQUESTS)
+     return Response()
+
+
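A hedged sketch of exercising the multipart upload described in the docstring above, again with requests against an assumed local server and route path. The form field names (action, name, input_keys[], output_keys[], metadata_keys[], file) come straight from the schema; the URL and CSV file are placeholders:

import requests

# qa_pairs.csv is assumed to have columns: question, answer, category
with open("qa_pairs.csv", "rb") as f:
    resp = requests.post(
        "http://localhost:6006/v1/datasets/upload",  # placeholder host and path
        params={"sync": "true"},  # block until inserted and receive the dataset_id
        data={
            "action": "create",
            "name": "qa-pairs",
            "description": "toy CSV dataset",
            "input_keys[]": ["question"],  # must match CSV column headers
            "output_keys[]": ["answer"],
            "metadata_keys[]": ["category"],
        },
        files={"file": ("qa_pairs.csv", f, "text/csv")},
    )
resp.raise_for_status()
print(resp.json())  # {"data": {"dataset_id": "..."}} when sync=true
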
+ class FileContentType(Enum):
+     CSV = "text/csv"
+     PYARROW = "application/x-pandas-pyarrow"
+
+     @classmethod
+     def _missing_(cls, v: Any) -> "FileContentType":
+         if isinstance(v, str) and v and v.isascii() and not v.islower():
+             return cls(v.lower())
+         raise ValueError(f"Invalid file content type: {v}")
+
+
+ class FileContentEncoding(Enum):
+     NONE = "none"
+     GZIP = "gzip"
+     DEFLATE = "deflate"
+
+     @classmethod
+     def _missing_(cls, v: Any) -> "FileContentEncoding":
+         if v is None:
+             return cls("none")
+         if isinstance(v, str) and v and v.isascii() and not v.islower():
+             return cls(v.lower())
+         raise ValueError(f"Invalid file content encoding: {v}")
+
+
+ Name: TypeAlias = str
+ Description: TypeAlias = Optional[str]
+ InputKeys: TypeAlias = FrozenSet[str]
+ OutputKeys: TypeAlias = FrozenSet[str]
+ MetadataKeys: TypeAlias = FrozenSet[str]
+ DatasetId: TypeAlias = int
+ Examples: TypeAlias = Iterator[Dict[str, Any]]
+
+
+ async def _process_csv(
+     content: bytes,
+     content_encoding: FileContentEncoding,
+ ) -> Tuple[Examples, FrozenSet[str]]:
+     if content_encoding is FileContentEncoding.GZIP:
+         content = await run_in_threadpool(gzip.decompress, content)
+     elif content_encoding is FileContentEncoding.DEFLATE:
+         content = await run_in_threadpool(zlib.decompress, content)
+     elif content_encoding is not FileContentEncoding.NONE:
+         assert_never(content_encoding)
+     reader = await run_in_threadpool(lambda c: csv.DictReader(io.StringIO(c.decode())), content)
+     if reader.fieldnames is None:
+         raise ValueError("Missing CSV column header")
+     (header, freq), *_ = Counter(reader.fieldnames).most_common(1)
+     if freq > 1:
+         raise ValueError(f"Duplicated column header in CSV file: {header}")
+     column_headers = frozenset(reader.fieldnames)
+     return reader, column_headers
+
+
+ async def _process_pyarrow(
+     content: bytes,
+ ) -> Tuple[Awaitable[Examples], FrozenSet[str]]:
+     try:
+         reader = pa.ipc.open_stream(content)
+     except pa.ArrowInvalid as e:
+         raise ValueError("File is not valid pyarrow") from e
+     column_headers = frozenset(reader.schema.names)
+
+     def get_examples() -> Iterator[Dict[str, Any]]:
+         yield from reader.read_pandas().to_dict(orient="records")
+
+     return run_in_threadpool(get_examples), column_headers
+
+
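Since _process_pyarrow above reads the upload with pa.ipc.open_stream, a client can produce the application/x-pandas-pyarrow payload by writing a DataFrame out as an Arrow IPC stream. A small sketch, with a round-trip check mirroring the server side:

import io

import pandas as pd
import pyarrow as pa

df = pd.DataFrame({"question": ["What is 2 + 2?"], "answer": ["4"]})
table = pa.Table.from_pandas(df)

sink = io.BytesIO()
with pa.ipc.new_stream(sink, table.schema) as writer:  # Arrow IPC streaming format
    writer.write_table(table)
payload = sink.getvalue()  # bytes to send as the "file" part of the upload

# Round-trip check: this is what the server does with the payload.
reader = pa.ipc.open_stream(payload)
assert frozenset(reader.schema.names) == {"question", "answer"}
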
+ async def _check_table_exists(session: AsyncSession, name: str) -> bool:
+     return bool(
+         await session.scalar(
+             select(1).select_from(models.Dataset).where(models.Dataset.name == name)
+         )
+     )
+
+
+ def _check_keys_exist(
+     column_headers: FrozenSet[str],
+     input_keys: InputKeys,
+     output_keys: OutputKeys,
+     metadata_keys: MetadataKeys,
+ ) -> None:
+     for desc, keys in (
+         ("input", input_keys),
+         ("output", output_keys),
+         ("metadata", metadata_keys),
+     ):
+         if keys and (diff := keys.difference(column_headers)):
+             raise ValueError(f"{desc} keys not found in column headers: {diff}")
+
+
+ async def _parse_form_data(
+     form: FormData,
+ ) -> Tuple[
+     DatasetAction,
+     Name,
+     Description,
+     InputKeys,
+     OutputKeys,
+     MetadataKeys,
+     UploadFile,
+ ]:
+     name = cast(Optional[str], form.get("name"))
+     if not name:
+         raise ValueError("Dataset name must not be empty")
+     action = DatasetAction(cast(Optional[str], form.get("action")) or "create")
+     file = form["file"]
+     if not isinstance(file, UploadFile):
+         raise ValueError("Malformed file in form data.")
+     description = cast(Optional[str], form.get("description")) or file.filename
+     input_keys = frozenset(filter(bool, cast(List[str], form.getlist("input_keys[]"))))
+     output_keys = frozenset(filter(bool, cast(List[str], form.getlist("output_keys[]"))))
+     metadata_keys = frozenset(filter(bool, cast(List[str], form.getlist("metadata_keys[]"))))
+     return (
+         action,
+         name,
+         description,
+         input_keys,
+         output_keys,
+         metadata_keys,
+         file,
+     )
+
+
+ async def get_dataset_csv(request: Request) -> Response:
+     """
+     summary: Download dataset examples as CSV text file
+     operationId: getDatasetCsv
+     tags:
+       - datasets
+     parameters:
+       - in: path
+         name: id
+         required: true
+         schema:
+           type: string
+         description: Dataset ID
+       - in: query
+         name: version
+         schema:
+           type: string
+         description: Dataset version ID. If omitted, returns the latest version.
+     responses:
+       200:
+         description: Success
+         content:
+           text/csv:
+             schema:
+               type: string
+               contentMediaType: text/csv
+               contentEncoding: gzip
+       403:
+         description: Forbidden
+       404:
+         description: Dataset does not exist.
+       422:
+         description: Dataset ID or version ID is invalid.
+     """
+     try:
+         dataset_name, examples = await _get_db_examples(request)
+     except ValueError as e:
+         return Response(content=str(e), status_code=HTTP_422_UNPROCESSABLE_ENTITY)
+     content = await run_in_threadpool(_get_content_csv, examples)
+     return Response(
+         content=content,
+         headers={
+             "content-disposition": f'attachment; filename="{dataset_name}.csv"',
+             "content-type": "text/csv",
+             "content-encoding": "gzip",
+         },
+     )
+
+
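For completeness, a sketch of pulling the CSV export into pandas; the route path and dataset ID below are placeholders, and requests transparently undoes the gzip content-encoding, so the body arrives as plain CSV text:

import io

import pandas as pd
import requests

dataset_id = "RGF0YXNldDox"  # placeholder dataset GlobalID
resp = requests.get(f"http://localhost:6006/v1/datasets/{dataset_id}/csv")  # assumed path
resp.raise_for_status()
df = pd.read_csv(io.StringIO(resp.text))
print(df.columns.tolist())  # example_id plus input_*, output_*, metadata_* columns
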
+ async def get_dataset_jsonl_openai_ft(request: Request) -> Response:
+     """
+     summary: Download dataset examples as OpenAI Fine-Tuning JSONL file
+     operationId: getDatasetJSONLOpenAIFineTuning
+     tags:
+       - datasets
+     parameters:
+       - in: path
+         name: id
+         required: true
+         schema:
+           type: string
+         description: Dataset ID
+       - in: query
+         name: version
+         schema:
+           type: string
+         description: Dataset version ID. If omitted, returns the latest version.
+     responses:
+       200:
+         description: Success
+         content:
+           text/plain:
+             schema:
+               type: string
+               contentMediaType: text/plain
+               contentEncoding: gzip
+       403:
+         description: Forbidden
+       404:
+         description: Dataset does not exist.
+       422:
+         description: Dataset ID or version ID is invalid.
+     """
+     try:
+         dataset_name, examples = await _get_db_examples(request)
+     except ValueError as e:
+         return Response(content=str(e), status_code=HTTP_422_UNPROCESSABLE_ENTITY)
+     content = await run_in_threadpool(_get_content_jsonl_openai_ft, examples)
+     return Response(
+         content=content,
+         headers={
+             "content-disposition": f'attachment; filename="{dataset_name}.jsonl"',
+             "content-type": "text/plain",
+             "content-encoding": "gzip",
+         },
+     )
+
+
+ async def get_dataset_jsonl_openai_evals(request: Request) -> Response:
+     """
+     summary: Download dataset examples as OpenAI Evals JSONL file
+     operationId: getDatasetJSONLOpenAIEvals
+     tags:
+       - datasets
+     parameters:
+       - in: path
+         name: id
+         required: true
+         schema:
+           type: string
+         description: Dataset ID
+       - in: query
+         name: version
+         schema:
+           type: string
+         description: Dataset version ID. If omitted, returns the latest version.
+     responses:
+       200:
+         description: Success
+         content:
+           text/plain:
+             schema:
+               type: string
+               contentMediaType: text/plain
+               contentEncoding: gzip
+       403:
+         description: Forbidden
+       404:
+         description: Dataset does not exist.
+       422:
+         description: Dataset ID or version ID is invalid.
+     """
+     try:
+         dataset_name, examples = await _get_db_examples(request)
+     except ValueError as e:
+         return Response(content=str(e), status_code=HTTP_422_UNPROCESSABLE_ENTITY)
+     content = await run_in_threadpool(_get_content_jsonl_openai_evals, examples)
+     return Response(
+         content=content,
+         headers={
+             "content-disposition": f'attachment; filename="{dataset_name}.jsonl"',
+             "content-type": "text/plain",
+             "content-encoding": "gzip",
+         },
+     )
+
+
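The two JSONL endpoints above gzip their bodies and emit one JSON object per line (see _get_content_jsonl_openai_ft and _get_content_jsonl_openai_evals below). A small sketch of loading a saved fine-tuning export, assuming the compressed body was written to disk as-is:

import gzip
import json

def load_ft_records(path="dataset.jsonl.gz"):
    # Use plain open() instead if the HTTP client already decompressed the body.
    with gzip.open(path, "rt", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]

for record in load_ft_records()[:3]:
    print(len(record["messages"]), "messages")
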
+ def _get_content_csv(examples: List[models.DatasetExampleRevision]) -> bytes:
+     records = [
+         {
+             "example_id": GlobalID(
+                 type_name=DatasetExample.__name__,
+                 node_id=str(ex.dataset_example_id),
+             ),
+             **{f"input_{k}": v for k, v in ex.input.items()},
+             **{f"output_{k}": v for k, v in ex.output.items()},
+             **{f"metadata_{k}": v for k, v in ex.metadata_.items()},
+         }
+         for ex in examples
+     ]
+     return gzip.compress(pd.DataFrame.from_records(records).to_csv(index=False).encode())
+
+
+ def _get_content_jsonl_openai_ft(examples: List[models.DatasetExampleRevision]) -> bytes:
+     records = io.BytesIO()
+     for ex in examples:
+         records.write(
+             (
+                 json.dumps(
+                     {
+                         "messages": (
+                             ims if isinstance(ims := ex.input.get("messages"), list) else []
+                         )
+                         + (oms if isinstance(oms := ex.output.get("messages"), list) else [])
+                     },
+                     ensure_ascii=False,
+                 )
+                 + "\n"
+             ).encode()
+         )
+     records.seek(0)
+     return gzip.compress(records.read())
+
+
+ def _get_content_jsonl_openai_evals(examples: List[models.DatasetExampleRevision]) -> bytes:
+     records = io.BytesIO()
+     for ex in examples:
+         records.write(
+             (
+                 json.dumps(
+                     {
+                         "messages": ims
+                         if isinstance(ims := ex.input.get("messages"), list)
+                         else [],
+                         "ideal": (
+                             ideal if isinstance(ideal := last_message.get("content"), str) else ""
+                         )
+                         if isinstance(oms := ex.output.get("messages"), list)
+                         and oms
+                         and hasattr(last_message := oms[-1], "get")
+                         else "",
+                     },
+                     ensure_ascii=False,
+                 )
+                 + "\n"
+             ).encode()
+         )
+     records.seek(0)
+     return gzip.compress(records.read())
+
+
+ async def _get_db_examples(request: Request) -> Tuple[str, List[models.DatasetExampleRevision]]:
+     if not (id_ := request.path_params.get("id")):
+         raise ValueError("Missing Dataset ID")
+     dataset_id = from_global_id_with_expected_type(GlobalID.from_id(id_), Dataset.__name__)
+     dataset_version_id: Optional[int] = None
+     if version := request.query_params.get("version"):
+         dataset_version_id = from_global_id_with_expected_type(
+             GlobalID.from_id(version),
+             DatasetVersion.__name__,
+         )
+     latest_version = (
+         select(
+             models.DatasetExampleRevision.dataset_example_id,
+             func.max(models.DatasetExampleRevision.dataset_version_id).label("dataset_version_id"),
+         )
+         .group_by(models.DatasetExampleRevision.dataset_example_id)
+         .join(models.DatasetExample)
+         .where(models.DatasetExample.dataset_id == dataset_id)
+     )
+     if dataset_version_id is not None:
+         max_dataset_version_id = (
+             select(models.DatasetVersion.id)
+             .where(models.DatasetVersion.id == dataset_version_id)
+             .where(models.DatasetVersion.dataset_id == dataset_id)
+         ).scalar_subquery()
+         latest_version = latest_version.where(
+             models.DatasetExampleRevision.dataset_version_id <= max_dataset_version_id
+         )
+     subq = latest_version.subquery("latest_version")
+     stmt = (
+         select(models.DatasetExampleRevision)
+         .join(
+             subq,
+             onclause=and_(
+                 models.DatasetExampleRevision.dataset_example_id == subq.c.dataset_example_id,
+                 models.DatasetExampleRevision.dataset_version_id == subq.c.dataset_version_id,
+             ),
+         )
+         .where(models.DatasetExampleRevision.revision_kind != "DELETE")
+         .order_by(models.DatasetExampleRevision.dataset_example_id)
+     )
+     async with request.app.state.db() as session:
+         dataset_name: Optional[str] = await session.scalar(
+             select(models.Dataset.name).where(models.Dataset.id == dataset_id)
+         )
+         if not dataset_name:
+             raise ValueError("Dataset does not exist.")
+         examples = [r async for r in await session.stream_scalars(stmt)]
+     return dataset_name, examples