arize-phoenix 4.4.4rc5__py3-none-any.whl → 4.5.0__py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.

Potentially problematic release: this version of arize-phoenix might be problematic.

Files changed (118)
  1. {arize_phoenix-4.4.4rc5.dist-info → arize_phoenix-4.5.0.dist-info}/METADATA +5 -5
  2. {arize_phoenix-4.4.4rc5.dist-info → arize_phoenix-4.5.0.dist-info}/RECORD +56 -117
  3. {arize_phoenix-4.4.4rc5.dist-info → arize_phoenix-4.5.0.dist-info}/WHEEL +1 -1
  4. phoenix/__init__.py +27 -0
  5. phoenix/config.py +7 -21
  6. phoenix/core/model.py +25 -25
  7. phoenix/core/model_schema.py +62 -64
  8. phoenix/core/model_schema_adapter.py +25 -27
  9. phoenix/db/bulk_inserter.py +14 -54
  10. phoenix/db/insertion/evaluation.py +6 -6
  11. phoenix/db/insertion/helpers.py +2 -13
  12. phoenix/db/migrations/versions/cf03bd6bae1d_init.py +28 -2
  13. phoenix/db/models.py +4 -236
  14. phoenix/inferences/fixtures.py +23 -23
  15. phoenix/inferences/inferences.py +7 -7
  16. phoenix/inferences/validation.py +1 -1
  17. phoenix/server/api/context.py +0 -18
  18. phoenix/server/api/dataloaders/__init__.py +0 -18
  19. phoenix/server/api/dataloaders/span_descendants.py +3 -2
  20. phoenix/server/api/routers/v1/__init__.py +2 -77
  21. phoenix/server/api/routers/v1/evaluations.py +2 -4
  22. phoenix/server/api/routers/v1/spans.py +1 -3
  23. phoenix/server/api/routers/v1/traces.py +4 -1
  24. phoenix/server/api/schema.py +303 -2
  25. phoenix/server/api/types/Cluster.py +19 -19
  26. phoenix/server/api/types/Dataset.py +63 -282
  27. phoenix/server/api/types/DatasetRole.py +23 -0
  28. phoenix/server/api/types/Dimension.py +29 -30
  29. phoenix/server/api/types/EmbeddingDimension.py +34 -40
  30. phoenix/server/api/types/Event.py +16 -16
  31. phoenix/server/api/{mutations/export_events_mutations.py → types/ExportEventsMutation.py} +14 -17
  32. phoenix/server/api/types/Model.py +42 -43
  33. phoenix/server/api/types/Project.py +12 -26
  34. phoenix/server/api/types/Span.py +2 -79
  35. phoenix/server/api/types/TimeSeries.py +6 -6
  36. phoenix/server/api/types/Trace.py +4 -15
  37. phoenix/server/api/types/UMAPPoints.py +1 -1
  38. phoenix/server/api/types/node.py +111 -5
  39. phoenix/server/api/types/pagination.py +52 -10
  40. phoenix/server/app.py +49 -101
  41. phoenix/server/main.py +27 -49
  42. phoenix/server/openapi/docs.py +0 -3
  43. phoenix/server/static/index.js +2595 -3523
  44. phoenix/server/templates/index.html +0 -1
  45. phoenix/services.py +15 -15
  46. phoenix/session/client.py +21 -438
  47. phoenix/session/session.py +37 -47
  48. phoenix/trace/exporter.py +9 -14
  49. phoenix/trace/fixtures.py +7 -133
  50. phoenix/trace/schemas.py +2 -1
  51. phoenix/trace/span_evaluations.py +3 -3
  52. phoenix/trace/trace_dataset.py +6 -6
  53. phoenix/version.py +1 -1
  54. phoenix/datasets/__init__.py +0 -0
  55. phoenix/datasets/evaluators/__init__.py +0 -18
  56. phoenix/datasets/evaluators/code_evaluators.py +0 -99
  57. phoenix/datasets/evaluators/llm_evaluators.py +0 -244
  58. phoenix/datasets/evaluators/utils.py +0 -292
  59. phoenix/datasets/experiments.py +0 -550
  60. phoenix/datasets/tracing.py +0 -85
  61. phoenix/datasets/types.py +0 -178
  62. phoenix/db/insertion/dataset.py +0 -237
  63. phoenix/db/migrations/types.py +0 -29
  64. phoenix/db/migrations/versions/10460e46d750_datasets.py +0 -291
  65. phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -100
  66. phoenix/server/api/dataloaders/dataset_example_spans.py +0 -43
  67. phoenix/server/api/dataloaders/experiment_annotation_summaries.py +0 -85
  68. phoenix/server/api/dataloaders/experiment_error_rates.py +0 -43
  69. phoenix/server/api/dataloaders/experiment_run_counts.py +0 -42
  70. phoenix/server/api/dataloaders/experiment_sequence_number.py +0 -49
  71. phoenix/server/api/dataloaders/project_by_name.py +0 -31
  72. phoenix/server/api/dataloaders/span_projects.py +0 -33
  73. phoenix/server/api/dataloaders/trace_row_ids.py +0 -39
  74. phoenix/server/api/helpers/dataset_helpers.py +0 -179
  75. phoenix/server/api/input_types/AddExamplesToDatasetInput.py +0 -16
  76. phoenix/server/api/input_types/AddSpansToDatasetInput.py +0 -14
  77. phoenix/server/api/input_types/ClearProjectInput.py +0 -15
  78. phoenix/server/api/input_types/CreateDatasetInput.py +0 -12
  79. phoenix/server/api/input_types/DatasetExampleInput.py +0 -14
  80. phoenix/server/api/input_types/DatasetSort.py +0 -17
  81. phoenix/server/api/input_types/DatasetVersionSort.py +0 -16
  82. phoenix/server/api/input_types/DeleteDatasetExamplesInput.py +0 -13
  83. phoenix/server/api/input_types/DeleteDatasetInput.py +0 -7
  84. phoenix/server/api/input_types/DeleteExperimentsInput.py +0 -9
  85. phoenix/server/api/input_types/PatchDatasetExamplesInput.py +0 -35
  86. phoenix/server/api/input_types/PatchDatasetInput.py +0 -14
  87. phoenix/server/api/mutations/__init__.py +0 -13
  88. phoenix/server/api/mutations/auth.py +0 -11
  89. phoenix/server/api/mutations/dataset_mutations.py +0 -520
  90. phoenix/server/api/mutations/experiment_mutations.py +0 -65
  91. phoenix/server/api/mutations/project_mutations.py +0 -47
  92. phoenix/server/api/openapi/__init__.py +0 -0
  93. phoenix/server/api/openapi/main.py +0 -6
  94. phoenix/server/api/openapi/schema.py +0 -16
  95. phoenix/server/api/queries.py +0 -503
  96. phoenix/server/api/routers/v1/dataset_examples.py +0 -178
  97. phoenix/server/api/routers/v1/datasets.py +0 -965
  98. phoenix/server/api/routers/v1/experiment_evaluations.py +0 -66
  99. phoenix/server/api/routers/v1/experiment_runs.py +0 -108
  100. phoenix/server/api/routers/v1/experiments.py +0 -174
  101. phoenix/server/api/types/AnnotatorKind.py +0 -10
  102. phoenix/server/api/types/CreateDatasetPayload.py +0 -8
  103. phoenix/server/api/types/DatasetExample.py +0 -85
  104. phoenix/server/api/types/DatasetExampleRevision.py +0 -34
  105. phoenix/server/api/types/DatasetVersion.py +0 -14
  106. phoenix/server/api/types/ExampleRevisionInterface.py +0 -14
  107. phoenix/server/api/types/Experiment.py +0 -140
  108. phoenix/server/api/types/ExperimentAnnotationSummary.py +0 -13
  109. phoenix/server/api/types/ExperimentComparison.py +0 -19
  110. phoenix/server/api/types/ExperimentRun.py +0 -91
  111. phoenix/server/api/types/ExperimentRunAnnotation.py +0 -57
  112. phoenix/server/api/types/Inferences.py +0 -80
  113. phoenix/server/api/types/InferencesRole.py +0 -23
  114. phoenix/utilities/json.py +0 -61
  115. phoenix/utilities/re.py +0 -50
  116. {arize_phoenix-4.4.4rc5.dist-info → arize_phoenix-4.5.0.dist-info}/licenses/IP_NOTICE +0 -0
  117. {arize_phoenix-4.4.4rc5.dist-info → arize_phoenix-4.5.0.dist-info}/licenses/LICENSE +0 -0
  118. /phoenix/server/api/{helpers/__init__.py → helpers.py} +0 -0
phoenix/server/api/routers/v1/datasets.py
@@ -1,965 +0,0 @@
- import csv
- import gzip
- import io
- import json
- import logging
- import zlib
- from asyncio import QueueFull
- from collections import Counter
- from enum import Enum
- from functools import partial
- from typing import (
- Any,
- Awaitable,
- Callable,
- Coroutine,
- FrozenSet,
- Iterator,
- List,
- Mapping,
- Optional,
- Sequence,
- Tuple,
- Union,
- cast,
- )
-
- import pandas as pd
- import pyarrow as pa
- from sqlalchemy import and_, func, select
- from sqlalchemy.ext.asyncio import AsyncSession
- from starlette.concurrency import run_in_threadpool
- from starlette.datastructures import FormData, UploadFile
- from starlette.requests import Request
- from starlette.responses import JSONResponse, Response
- from starlette.status import (
- HTTP_404_NOT_FOUND,
- HTTP_409_CONFLICT,
- HTTP_422_UNPROCESSABLE_ENTITY,
- HTTP_429_TOO_MANY_REQUESTS,
- )
- from strawberry.relay import GlobalID
- from typing_extensions import TypeAlias, assert_never
-
- from phoenix.db import models
- from phoenix.db.insertion.dataset import (
- DatasetAction,
- DatasetExampleAdditionEvent,
- ExampleContent,
- add_dataset_examples,
- )
- from phoenix.server.api.types.Dataset import Dataset
- from phoenix.server.api.types.DatasetExample import DatasetExample
- from phoenix.server.api.types.DatasetVersion import DatasetVersion
- from phoenix.server.api.types.node import from_global_id_with_expected_type
-
- logger = logging.getLogger(__name__)
-
- NODE_NAME = "Dataset"
-
-
- async def list_datasets(request: Request) -> Response:
- """
- summary: List datasets with cursor-based pagination
- operationId: listDatasets
- tags:
- - datasets
- parameters:
- - in: query
- name: cursor
- required: false
- schema:
- type: string
- description: Cursor for pagination
- - in: query
- name: limit
- required: false
- schema:
- type: integer
- default: 10
- - in: query
- name: name
- required: false
- schema:
- type: string
- description: match by dataset name
- responses:
- 200:
- description: A paginated list of datasets
- content:
- application/json:
- schema:
- type: object
- properties:
- next_cursor:
- type: string
- data:
- type: array
- items:
- type: object
- properties:
- id:
- type: string
- name:
- type: string
- description:
- type: string
- metadata:
- type: object
- created_at:
- type: string
- format: date-time
- updated_at:
- type: string
- format: date-time
- 403:
- description: Forbidden
- 404:
- description: No datasets found
- """
- name = request.query_params.get("name")
- cursor = request.query_params.get("cursor")
- limit = int(request.query_params.get("limit", 10))
- async with request.app.state.db() as session:
- query = select(models.Dataset).order_by(models.Dataset.id.desc())
-
- if cursor:
- try:
- cursor_id = GlobalID.from_id(cursor).node_id
- query = query.filter(models.Dataset.id <= int(cursor_id))
- except ValueError:
- return Response(
- content=f"Invalid cursor format: {cursor}",
- status_code=HTTP_422_UNPROCESSABLE_ENTITY,
- )
- if name:
- query = query.filter(models.Dataset.name.is_(name))
-
- query = query.limit(limit + 1)
- result = await session.execute(query)
- datasets = result.scalars().all()
-
- if not datasets:
- return JSONResponse(content={"next_cursor": None, "data": []}, status_code=200)
-
- next_cursor = None
- if len(datasets) == limit + 1:
- next_cursor = str(GlobalID(NODE_NAME, str(datasets[-1].id)))
- datasets = datasets[:-1]
-
- data = []
- for dataset in datasets:
- data.append(
- {
- "id": str(GlobalID(NODE_NAME, str(dataset.id))),
- "name": dataset.name,
- "description": dataset.description,
- "metadata": dataset.metadata_,
- "created_at": dataset.created_at.isoformat(),
- "updated_at": dataset.updated_at.isoformat(),
- }
- )
-
- return JSONResponse(content={"next_cursor": next_cursor, "data": data})
-
-
- async def get_dataset_by_id(request: Request) -> Response:
- """
- summary: Get dataset by ID
- operationId: getDatasetById
- tags:
- - datasets
- parameters:
- - in: path
- name: id
- required: true
- schema:
- type: string
- responses:
- 200:
- description: Success
- content:
- application/json:
- schema:
- type: object
- properties:
- id:
- type: string
- name:
- type: string
- description:
- type: string
- metadata:
- type: object
- created_at:
- type: string
- format: date-time
- updated_at:
- type: string
- format: date-time
- example_count:
- type: integer
- 403:
- description: Forbidden
- 404:
- description: Dataset not found
- """
- dataset_id = GlobalID.from_id(request.path_params["id"])
-
- if (type_name := dataset_id.type_name) != NODE_NAME:
- return Response(
- content=f"ID {dataset_id} refers to a f{type_name}", status_code=HTTP_404_NOT_FOUND
- )
- async with request.app.state.db() as session:
- result = await session.execute(
- select(models.Dataset, models.Dataset.example_count).filter(
- models.Dataset.id == int(dataset_id.node_id)
- )
- )
- dataset_query = result.first()
- dataset = dataset_query[0] if dataset_query else None
- example_count = dataset_query[1] if dataset_query else 0
- if dataset is None:
- return Response(
- content=f"Dataset with ID {dataset_id} not found", status_code=HTTP_404_NOT_FOUND
- )
-
- output_dict = {
- "id": str(dataset_id),
- "name": dataset.name,
- "description": dataset.description,
- "metadata": dataset.metadata_,
- "created_at": dataset.created_at.isoformat(),
- "updated_at": dataset.updated_at.isoformat(),
- "example_count": example_count,
- }
- return JSONResponse(content=output_dict)
-
-
- async def get_dataset_versions(request: Request) -> Response:
- """
- summary: Get dataset versions (sorted from latest to oldest)
- operationId: getDatasetVersionsByDatasetId
- tags:
- - datasets
- parameters:
- - in: path
- name: id
- required: true
- description: Dataset ID
- schema:
- type: string
- - in: query
- name: cursor
- description: Cursor for pagination.
- schema:
- type: string
- - in: query
- name: limit
- description: Maximum number versions to return.
- schema:
- type: integer
- default: 10
- responses:
- 200:
- description: Success
- content:
- application/json:
- schema:
- type: object
- properties:
- next_cursor:
- type: string
- data:
- type: array
- items:
- type: object
- properties:
- version_id:
- type: string
- description:
- type: string
- metadata:
- type: object
- created_at:
- type: string
- format: date-time
- 403:
- description: Forbidden
- 422:
- description: Dataset ID, cursor or limit is invalid.
- """
- if id_ := request.path_params.get("id"):
- try:
- dataset_id = from_global_id_with_expected_type(
- GlobalID.from_id(id_),
- Dataset.__name__,
- )
- except ValueError:
- return Response(
- content=f"Invalid Dataset ID: {id_}",
- status_code=HTTP_422_UNPROCESSABLE_ENTITY,
- )
- else:
- return Response(
- content="Missing Dataset ID",
- status_code=HTTP_422_UNPROCESSABLE_ENTITY,
- )
- try:
- limit = int(request.query_params.get("limit", 10))
- assert limit > 0
- except (ValueError, AssertionError):
- return Response(
- content="Invalid limit parameter",
- status_code=HTTP_422_UNPROCESSABLE_ENTITY,
- )
- stmt = (
- select(models.DatasetVersion)
- .where(models.DatasetVersion.dataset_id == dataset_id)
- .order_by(models.DatasetVersion.id.desc())
- .limit(limit + 1)
- )
- if cursor := request.query_params.get("cursor"):
- try:
- dataset_version_id = from_global_id_with_expected_type(
- GlobalID.from_id(cursor),
- DatasetVersion.__name__,
- )
- except ValueError:
- return Response(
- content=f"Invalid cursor: {cursor}",
- status_code=HTTP_422_UNPROCESSABLE_ENTITY,
- )
- max_dataset_version_id = (
- select(models.DatasetVersion.id)
- .where(models.DatasetVersion.id == dataset_version_id)
- .where(models.DatasetVersion.dataset_id == dataset_id)
- ).scalar_subquery()
- stmt = stmt.filter(models.DatasetVersion.id <= max_dataset_version_id)
- async with request.app.state.db() as session:
- data = [
- {
- "version_id": str(GlobalID(DatasetVersion.__name__, str(version.id))),
- "description": version.description,
- "metadata": version.metadata_,
- "created_at": version.created_at.isoformat(),
- }
- async for version in await session.stream_scalars(stmt)
- ]
- next_cursor = data.pop()["version_id"] if len(data) == limit + 1 else None
- return JSONResponse(content={"next_cursor": next_cursor, "data": data})
-
-
- async def post_datasets_upload(request: Request) -> Response:
- """
- summary: Upload dataset as either JSON or file (CSV or PyArrow)
- operationId: uploadDataset
- tags:
- - datasets
- parameters:
- - in: query
- name: sync
- description: If true, fulfill request synchronously and return JSON containing dataset_id
- schema:
- type: boolean
- requestBody:
- content:
- application/json:
- schema:
- type: object
- required:
- - name
- - inputs
- properties:
- action:
- type: string
- enum: [create, append]
- name:
- type: string
- description:
- type: string
- inputs:
- type: array
- items:
- type: object
- outputs:
- type: array
- items:
- type: object
- metadata:
- type: array
- items:
- type: object
- multipart/form-data:
- schema:
- type: object
- required:
- - name
- - input_keys[]
- - output_keys[]
- - file
- properties:
- action:
- type: string
- enum: [create, append]
- name:
- type: string
- description:
- type: string
- input_keys[]:
- type: array
- items:
- type: string
- uniqueItems: true
- output_keys[]:
- type: array
- items:
- type: string
- uniqueItems: true
- metadata_keys[]:
- type: array
- items:
- type: string
- uniqueItems: true
- file:
- type: string
- format: binary
- responses:
- 200:
- description: Success
- 403:
- description: Forbidden
- 409:
- description: Dataset of the same name already exists
- 422:
- description: Request body is invalid
- """
- request_content_type = request.headers["content-type"]
- examples: Union[Examples, Awaitable[Examples]]
- if request_content_type.startswith("application/json"):
- try:
- examples, action, name, description = await run_in_threadpool(
- _process_json, await request.json()
- )
- except ValueError as e:
- return Response(
- content=str(e),
- status_code=HTTP_422_UNPROCESSABLE_ENTITY,
- )
- if action is DatasetAction.CREATE:
- async with request.app.state.db() as session:
- if await _check_table_exists(session, name):
- return Response(
- content=f"Dataset with the same name already exists: {name=}",
- status_code=HTTP_409_CONFLICT,
- )
- elif request_content_type.startswith("multipart/form-data"):
- async with request.form() as form:
- try:
- (
- action,
- name,
- description,
- input_keys,
- output_keys,
- metadata_keys,
- file,
- ) = await _parse_form_data(form)
- except ValueError as e:
- return Response(
- content=str(e),
- status_code=HTTP_422_UNPROCESSABLE_ENTITY,
- )
- if action is DatasetAction.CREATE:
- async with request.app.state.db() as session:
- if await _check_table_exists(session, name):
- return Response(
- content=f"Dataset with the same name already exists: {name=}",
- status_code=HTTP_409_CONFLICT,
- )
- content = await file.read()
- try:
- file_content_type = FileContentType(file.content_type)
- if file_content_type is FileContentType.CSV:
- encoding = FileContentEncoding(file.headers.get("content-encoding"))
- examples = await _process_csv(
- content, encoding, input_keys, output_keys, metadata_keys
- )
- elif file_content_type is FileContentType.PYARROW:
- examples = await _process_pyarrow(content, input_keys, output_keys, metadata_keys)
- else:
- assert_never(file_content_type)
- except ValueError as e:
- return Response(
- content=str(e),
- status_code=HTTP_422_UNPROCESSABLE_ENTITY,
- )
- else:
- return Response(
- content=str("Invalid request Content-Type"),
- status_code=HTTP_422_UNPROCESSABLE_ENTITY,
- )
- operation = cast(
- Callable[[AsyncSession], Awaitable[DatasetExampleAdditionEvent]],
- partial(
- add_dataset_examples,
- examples=examples,
- action=action,
- name=name,
- description=description,
- ),
- )
- if request.query_params.get("sync") == "true":
- async with request.app.state.db() as session:
- dataset_id = (await operation(session)).dataset_id
- return JSONResponse(
- content={"data": {"dataset_id": str(GlobalID(Dataset.__name__, str(dataset_id)))}}
- )
- try:
- request.state.enqueue_operation(operation)
- except QueueFull:
- if isinstance(examples, Coroutine):
- examples.close()
- return Response(status_code=HTTP_429_TOO_MANY_REQUESTS)
- return Response()
-
-
- class FileContentType(Enum):
- CSV = "text/csv"
- PYARROW = "application/x-pandas-pyarrow"
-
- @classmethod
- def _missing_(cls, v: Any) -> "FileContentType":
- if isinstance(v, str) and v and v.isascii() and not v.islower():
- return cls(v.lower())
- raise ValueError(f"Invalid file content type: {v}")
-
-
- class FileContentEncoding(Enum):
- NONE = "none"
- GZIP = "gzip"
- DEFLATE = "deflate"
-
- @classmethod
- def _missing_(cls, v: Any) -> "FileContentEncoding":
- if v is None:
- return cls("none")
- if isinstance(v, str) and v and v.isascii() and not v.islower():
- return cls(v.lower())
- raise ValueError(f"Invalid file content encoding: {v}")
-
-
- Name: TypeAlias = str
- Description: TypeAlias = Optional[str]
- InputKeys: TypeAlias = FrozenSet[str]
- OutputKeys: TypeAlias = FrozenSet[str]
- MetadataKeys: TypeAlias = FrozenSet[str]
- DatasetId: TypeAlias = int
- Examples: TypeAlias = Iterator[ExampleContent]
-
-
- def _process_json(
- data: Mapping[str, Any],
- ) -> Tuple[Examples, DatasetAction, Name, Description]:
- name = data.get("name")
- if not name:
- raise ValueError("Dataset name is required")
- description = data.get("description") or ""
- inputs = data.get("inputs")
- if not inputs:
- raise ValueError("input is required")
- if not isinstance(inputs, list) or not _is_all_dict(inputs):
- raise ValueError("Input should be a list containing only dictionary objects")
- outputs, metadata = data.get("outputs"), data.get("metadata")
- for k, v in {"outputs": outputs, "metadata": metadata}.items():
- if v and not (isinstance(v, list) and len(v) == len(inputs) and _is_all_dict(v)):
- raise ValueError(
- f"{k} should be a list of same length as input containing only dictionary objects"
- )
- examples: List[ExampleContent] = []
- for i, obj in enumerate(inputs):
- example = ExampleContent(
- input=obj,
- output=outputs[i] if outputs else {},
- metadata=metadata[i] if metadata else {},
- )
- examples.append(example)
- action = DatasetAction(cast(Optional[str], data.get("action")) or "create")
- return iter(examples), action, name, description
-
-
- async def _process_csv(
- content: bytes,
- content_encoding: FileContentEncoding,
- input_keys: InputKeys,
- output_keys: OutputKeys,
- metadata_keys: MetadataKeys,
- ) -> Examples:
- if content_encoding is FileContentEncoding.GZIP:
- content = await run_in_threadpool(gzip.decompress, content)
- elif content_encoding is FileContentEncoding.DEFLATE:
- content = await run_in_threadpool(zlib.decompress, content)
- elif content_encoding is not FileContentEncoding.NONE:
- assert_never(content_encoding)
- reader = await run_in_threadpool(lambda c: csv.DictReader(io.StringIO(c.decode())), content)
- if reader.fieldnames is None:
- raise ValueError("Missing CSV column header")
- (header, freq), *_ = Counter(reader.fieldnames).most_common(1)
- if freq > 1:
- raise ValueError(f"Duplicated column header in CSV file: {header}")
- column_headers = frozenset(reader.fieldnames)
- _check_keys_exist(column_headers, input_keys, output_keys, metadata_keys)
- return (
- ExampleContent(
- input={k: row.get(k) for k in input_keys},
- output={k: row.get(k) for k in output_keys},
- metadata={k: row.get(k) for k in metadata_keys},
- )
- for row in iter(reader)
- )
-
-
- async def _process_pyarrow(
- content: bytes,
- input_keys: InputKeys,
- output_keys: OutputKeys,
- metadata_keys: MetadataKeys,
- ) -> Awaitable[Examples]:
- try:
- reader = pa.ipc.open_stream(content)
- except pa.ArrowInvalid as e:
- raise ValueError("File is not valid pyarrow") from e
- column_headers = frozenset(reader.schema.names)
- _check_keys_exist(column_headers, input_keys, output_keys, metadata_keys)
-
- def get_examples() -> Iterator[ExampleContent]:
- for row in reader.read_pandas().to_dict(orient="records"):
- yield ExampleContent(
- input={k: row.get(k) for k in input_keys},
- output={k: row.get(k) for k in output_keys},
- metadata={k: row.get(k) for k in metadata_keys},
- )
-
- return run_in_threadpool(get_examples)
-
-
- async def _check_table_exists(session: AsyncSession, name: str) -> bool:
- return bool(
- await session.scalar(
- select(1).select_from(models.Dataset).where(models.Dataset.name == name)
- )
- )
-
-
- def _check_keys_exist(
- column_headers: FrozenSet[str],
- input_keys: InputKeys,
- output_keys: OutputKeys,
- metadata_keys: MetadataKeys,
- ) -> None:
- for desc, keys in (
- ("input", input_keys),
- ("output", output_keys),
- ("metadata", metadata_keys),
- ):
- if keys and (diff := keys.difference(column_headers)):
- raise ValueError(f"{desc} keys not found in column headers: {diff}")
-
-
- async def _parse_form_data(
- form: FormData,
- ) -> Tuple[
- DatasetAction,
- Name,
- Description,
- InputKeys,
- OutputKeys,
- MetadataKeys,
- UploadFile,
- ]:
- name = cast(Optional[str], form.get("name"))
- if not name:
- raise ValueError("Dataset name must not be empty")
- action = DatasetAction(cast(Optional[str], form.get("action")) or "create")
- file = form["file"]
- if not isinstance(file, UploadFile):
- raise ValueError("Malformed file in form data.")
- description = cast(Optional[str], form.get("description")) or file.filename
- input_keys = frozenset(filter(bool, cast(List[str], form.getlist("input_keys[]"))))
- output_keys = frozenset(filter(bool, cast(List[str], form.getlist("output_keys[]"))))
- metadata_keys = frozenset(filter(bool, cast(List[str], form.getlist("metadata_keys[]"))))
- return (
- action,
- name,
- description,
- input_keys,
- output_keys,
- metadata_keys,
- file,
- )
-
-
- async def get_dataset_csv(request: Request) -> Response:
- """
- summary: Download dataset examples as CSV text file
- operationId: getDatasetCsv
- tags:
- - datasets
- parameters:
- - in: path
- name: id
- required: true
- schema:
- type: string
- description: Dataset ID
- - in: query
- name: version
- schema:
- type: string
- description: Dataset version ID. If omitted, returns the latest version.
- responses:
- 200:
- description: Success
- content:
- text/csv:
- schema:
- type: string
- contentMediaType: text/csv
- contentEncoding: gzip
- 403:
- description: Forbidden
- 404:
- description: Dataset does not exist.
- 422:
- description: Dataset ID or version ID is invalid.
- """
- try:
- dataset_name, examples = await _get_db_examples(request)
- except ValueError as e:
- return Response(content=str(e), status_code=HTTP_422_UNPROCESSABLE_ENTITY)
- content = await run_in_threadpool(_get_content_csv, examples)
- return Response(
- content=content,
- headers={
- "content-disposition": f'attachment; filename="{dataset_name}.csv"',
- "content-type": "text/csv",
- "content-encoding": "gzip",
- },
- )
-
-
- async def get_dataset_jsonl_openai_ft(request: Request) -> Response:
- """
- summary: Download dataset examples as OpenAI Fine-Tuning JSONL file
- operationId: getDatasetJSONLOpenAIFineTuning
- tags:
- - datasets
- parameters:
- - in: path
- name: id
- required: true
- schema:
- type: string
- description: Dataset ID
- - in: query
- name: version
- schema:
- type: string
- description: Dataset version ID. If omitted, returns the latest version.
- responses:
- 200:
- description: Success
- content:
- text/plain:
- schema:
- type: string
- contentMediaType: text/plain
- contentEncoding: gzip
- 403:
- description: Forbidden
- 404:
- description: Dataset does not exist.
- 422:
- description: Dataset ID or version ID is invalid.
- """
- try:
- dataset_name, examples = await _get_db_examples(request)
- except ValueError as e:
- return Response(content=str(e), status_code=HTTP_422_UNPROCESSABLE_ENTITY)
- content = await run_in_threadpool(_get_content_jsonl_openai_ft, examples)
- return Response(
- content=content,
- headers={
- "content-disposition": f'attachment; filename="{dataset_name}.jsonl"',
- "content-type": "text/plain",
- "content-encoding": "gzip",
- },
- )
-
-
- async def get_dataset_jsonl_openai_evals(request: Request) -> Response:
- """
- summary: Download dataset examples as OpenAI Evals JSONL file
- operationId: getDatasetJSONLOpenAIEvals
- tags:
- - datasets
- parameters:
- - in: path
- name: id
- required: true
- schema:
- type: string
- description: Dataset ID
- - in: query
- name: version
- schema:
- type: string
- description: Dataset version ID. If omitted, returns the latest version.
- responses:
- 200:
- description: Success
- content:
- text/plain:
- schema:
- type: string
- contentMediaType: text/plain
- contentEncoding: gzip
- 403:
- description: Forbidden
- 404:
- description: Dataset does not exist.
- 422:
- description: Dataset ID or version ID is invalid.
- """
- try:
- dataset_name, examples = await _get_db_examples(request)
- except ValueError as e:
- return Response(content=str(e), status_code=HTTP_422_UNPROCESSABLE_ENTITY)
- content = await run_in_threadpool(_get_content_jsonl_openai_evals, examples)
- return Response(
- content=content,
- headers={
- "content-disposition": f'attachment; filename="{dataset_name}.jsonl"',
- "content-type": "text/plain",
- "content-encoding": "gzip",
- },
- )
-
-
- def _get_content_csv(examples: List[models.DatasetExampleRevision]) -> bytes:
- records = [
- {
- "example_id": GlobalID(
- type_name=DatasetExample.__name__,
- node_id=str(ex.dataset_example_id),
- ),
- **{f"input_{k}": v for k, v in ex.input.items()},
- **{f"output_{k}": v for k, v in ex.output.items()},
- **{f"metadata_{k}": v for k, v in ex.metadata_.items()},
- }
- for ex in examples
- ]
- return gzip.compress(pd.DataFrame.from_records(records).to_csv(index=False).encode())
-
-
- def _get_content_jsonl_openai_ft(examples: List[models.DatasetExampleRevision]) -> bytes:
- records = io.BytesIO()
- for ex in examples:
- records.write(
- (
- json.dumps(
- {
- "messages": (
- ims if isinstance(ims := ex.input.get("messages"), list) else []
- )
- + (oms if isinstance(oms := ex.output.get("messages"), list) else [])
- },
- ensure_ascii=False,
- )
- + "\n"
- ).encode()
- )
- records.seek(0)
- return gzip.compress(records.read())
-
-
- def _get_content_jsonl_openai_evals(examples: List[models.DatasetExampleRevision]) -> bytes:
- records = io.BytesIO()
- for ex in examples:
- records.write(
- (
- json.dumps(
- {
- "messages": ims
- if isinstance(ims := ex.input.get("messages"), list)
- else [],
- "ideal": (
- ideal if isinstance(ideal := last_message.get("content"), str) else ""
- )
- if isinstance(oms := ex.output.get("messages"), list)
- and oms
- and hasattr(last_message := oms[-1], "get")
- else "",
- },
- ensure_ascii=False,
- )
- + "\n"
- ).encode()
- )
- records.seek(0)
- return gzip.compress(records.read())
-
-
- async def _get_db_examples(request: Request) -> Tuple[str, List[models.DatasetExampleRevision]]:
- if not (id_ := request.path_params.get("id")):
- raise ValueError("Missing Dataset ID")
- dataset_id = from_global_id_with_expected_type(GlobalID.from_id(id_), Dataset.__name__)
- dataset_version_id: Optional[int] = None
- if version := request.query_params.get("version"):
- dataset_version_id = from_global_id_with_expected_type(
- GlobalID.from_id(version),
- DatasetVersion.__name__,
- )
- latest_version = (
- select(
- models.DatasetExampleRevision.dataset_example_id,
- func.max(models.DatasetExampleRevision.dataset_version_id).label("dataset_version_id"),
- )
- .group_by(models.DatasetExampleRevision.dataset_example_id)
- .join(models.DatasetExample)
- .where(models.DatasetExample.dataset_id == dataset_id)
- )
- if dataset_version_id is not None:
- max_dataset_version_id = (
- select(models.DatasetVersion.id)
- .where(models.DatasetVersion.id == dataset_version_id)
- .where(models.DatasetVersion.dataset_id == dataset_id)
- ).scalar_subquery()
- latest_version = latest_version.where(
- models.DatasetExampleRevision.dataset_version_id <= max_dataset_version_id
- )
- subq = latest_version.subquery("latest_version")
- stmt = (
- select(models.DatasetExampleRevision)
- .join(
- subq,
- onclause=and_(
- models.DatasetExampleRevision.dataset_example_id == subq.c.dataset_example_id,
- models.DatasetExampleRevision.dataset_version_id == subq.c.dataset_version_id,
- ),
- )
- .where(models.DatasetExampleRevision.revision_kind != "DELETE")
- .order_by(models.DatasetExampleRevision.dataset_example_id)
- )
- async with request.app.state.db() as session:
- dataset_name: Optional[str] = await session.scalar(
- select(models.Dataset.name).where(models.Dataset.id == dataset_id)
- )
- if not dataset_name:
- raise ValueError("Dataset does not exist.")
- examples = [r async for r in await session.stream_scalars(stmt)]
- return dataset_name, examples
-
-
- def _is_all_dict(seq: Sequence[Any]) -> bool:
- return all(map(lambda obj: isinstance(obj, dict), seq))
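
Note on the removed list_datasets handler shown in the hunk above: it pages with a limit-plus-one scheme, fetching limit + 1 rows at or below the cursor id and, when the extra row comes back, returning its id as next_cursor for the following request. The sketch below is a minimal, self-contained illustration of that same pattern over an in-memory list; the paginate helper name is ours for illustration and is not part of the Phoenix API.

from typing import List, Optional, Tuple

def paginate(ids_desc: List[int], limit: int, cursor: Optional[int] = None) -> Tuple[List[int], Optional[int]]:
    # Keep rows at or below the cursor (ids are sorted descending), mirroring
    # the removed handler's filter models.Dataset.id <= cursor_id.
    rows = [i for i in ids_desc if cursor is None or i <= cursor]
    # Fetch one extra row; if present, its id becomes the cursor for the next page.
    page = rows[: limit + 1]
    next_cursor = page[-1] if len(page) == limit + 1 else None
    if next_cursor is not None:
        page = page[:-1]
    return page, next_cursor

# Walk ten ids three at a time: [10, 9, 8], [7, 6, 5], [4, 3, 2], [1].
ids = list(range(10, 0, -1))
cursor: Optional[int] = None
while True:
    page, cursor = paginate(ids, limit=3, cursor=cursor)
    print(page)
    if cursor is None:
        break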