arize-phoenix 4.4.4rc4__py3-none-any.whl → 4.4.4rc6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of arize-phoenix might be problematic.

Files changed (52)
  1. {arize_phoenix-4.4.4rc4.dist-info → arize_phoenix-4.4.4rc6.dist-info}/METADATA +12 -6
  2. {arize_phoenix-4.4.4rc4.dist-info → arize_phoenix-4.4.4rc6.dist-info}/RECORD +47 -42
  3. phoenix/config.py +21 -0
  4. phoenix/datetime_utils.py +4 -0
  5. phoenix/db/insertion/dataset.py +19 -16
  6. phoenix/db/insertion/evaluation.py +4 -4
  7. phoenix/db/insertion/helpers.py +4 -12
  8. phoenix/db/insertion/span.py +3 -3
  9. phoenix/db/migrations/versions/10460e46d750_datasets.py +2 -2
  10. phoenix/db/models.py +8 -3
  11. phoenix/experiments/__init__.py +6 -0
  12. phoenix/experiments/evaluators/__init__.py +29 -0
  13. phoenix/experiments/evaluators/base.py +153 -0
  14. phoenix/{datasets → experiments}/evaluators/code_evaluators.py +25 -53
  15. phoenix/{datasets → experiments}/evaluators/llm_evaluators.py +62 -31
  16. phoenix/experiments/evaluators/utils.py +189 -0
  17. phoenix/experiments/functions.py +616 -0
  18. phoenix/{datasets → experiments}/tracing.py +19 -0
  19. phoenix/experiments/types.py +722 -0
  20. phoenix/experiments/utils.py +9 -0
  21. phoenix/server/api/context.py +4 -0
  22. phoenix/server/api/dataloaders/__init__.py +4 -0
  23. phoenix/server/api/dataloaders/average_experiment_run_latency.py +54 -0
  24. phoenix/server/api/dataloaders/experiment_run_counts.py +42 -0
  25. phoenix/server/api/helpers/dataset_helpers.py +8 -7
  26. phoenix/server/api/input_types/ClearProjectInput.py +15 -0
  27. phoenix/server/api/mutations/project_mutations.py +9 -4
  28. phoenix/server/api/routers/v1/__init__.py +1 -1
  29. phoenix/server/api/routers/v1/dataset_examples.py +10 -10
  30. phoenix/server/api/routers/v1/datasets.py +152 -48
  31. phoenix/server/api/routers/v1/evaluations.py +4 -11
  32. phoenix/server/api/routers/v1/experiment_evaluations.py +23 -23
  33. phoenix/server/api/routers/v1/experiment_runs.py +5 -17
  34. phoenix/server/api/routers/v1/experiments.py +5 -5
  35. phoenix/server/api/routers/v1/spans.py +6 -4
  36. phoenix/server/api/types/Experiment.py +12 -0
  37. phoenix/server/api/types/ExperimentRun.py +1 -1
  38. phoenix/server/api/types/ExperimentRunAnnotation.py +1 -1
  39. phoenix/server/app.py +4 -0
  40. phoenix/server/static/index.js +712 -588
  41. phoenix/session/client.py +321 -28
  42. phoenix/trace/fixtures.py +6 -6
  43. phoenix/utilities/json.py +8 -8
  44. phoenix/version.py +1 -1
  45. phoenix/datasets/__init__.py +0 -0
  46. phoenix/datasets/evaluators/__init__.py +0 -18
  47. phoenix/datasets/evaluators/_utils.py +0 -13
  48. phoenix/datasets/experiments.py +0 -485
  49. phoenix/datasets/types.py +0 -212
  50. {arize_phoenix-4.4.4rc4.dist-info → arize_phoenix-4.4.4rc6.dist-info}/WHEEL +0 -0
  51. {arize_phoenix-4.4.4rc4.dist-info → arize_phoenix-4.4.4rc6.dist-info}/licenses/IP_NOTICE +0 -0
  52. {arize_phoenix-4.4.4rc4.dist-info → arize_phoenix-4.4.4rc6.dist-info}/licenses/LICENSE +0 -0
phoenix/experiments/utils.py
@@ -0,0 +1,9 @@
+from phoenix.config import get_web_base_url
+
+
+def get_experiment_url(*, dataset_id: str, experiment_id: str) -> str:
+    return f"{get_web_base_url()}datasets/{dataset_id}/compare?experimentId={experiment_id}"
+
+
+def get_dataset_experiments_url(*, dataset_id: str) -> str:
+    return f"{get_web_base_url()}datasets/{dataset_id}/experiments"
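Note: for reference, a quick sketch of the URLs these helpers produce; the base URL below is an assumption (get_web_base_url() returns whatever root the Phoenix server is configured with), and the IDs are made-up example GlobalIDs.

    # Hypothetical output, assuming get_web_base_url() == "http://localhost:6006/"
    from phoenix.experiments.utils import get_dataset_experiments_url, get_experiment_url

    get_experiment_url(dataset_id="RGF0YXNldDox", experiment_id="RXhwZXJpbWVudDoy")
    # -> "http://localhost:6006/datasets/RGF0YXNldDox/compare?experimentId=RXhwZXJpbWVudDoy"
    get_dataset_experiments_url(dataset_id="RGF0YXNldDox")
    # -> "http://localhost:6006/datasets/RGF0YXNldDox/experiments"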
phoenix/server/api/context.py
@@ -11,6 +11,7 @@ from typing_extensions import TypeAlias
 
 from phoenix.core.model_schema import Model
 from phoenix.server.api.dataloaders import (
+    AverageExperimentRunLatencyDataLoader,
     CacheForDataLoaders,
     DatasetExampleRevisionsDataLoader,
     DatasetExampleSpansDataLoader,
@@ -20,6 +21,7 @@ from phoenix.server.api.dataloaders import (
     EvaluationSummaryDataLoader,
     ExperimentAnnotationSummaryDataLoader,
     ExperimentErrorRatesDataLoader,
+    ExperimentRunCountsDataLoader,
     ExperimentSequenceNumberDataLoader,
     LatencyMsQuantileDataLoader,
     MinStartOrMaxEndTimeDataLoader,
@@ -36,6 +38,7 @@ from phoenix.server.api.dataloaders import (
 
 @dataclass
 class DataLoaders:
+    average_experiment_run_latency: AverageExperimentRunLatencyDataLoader
     dataset_example_revisions: DatasetExampleRevisionsDataLoader
     dataset_example_spans: DatasetExampleSpansDataLoader
     document_evaluation_summaries: DocumentEvaluationSummaryDataLoader
@@ -44,6 +47,7 @@ class DataLoaders:
     evaluation_summaries: EvaluationSummaryDataLoader
     experiment_annotation_summaries: ExperimentAnnotationSummaryDataLoader
     experiment_error_rates: ExperimentErrorRatesDataLoader
+    experiment_run_counts: ExperimentRunCountsDataLoader
     experiment_sequence_number: ExperimentSequenceNumberDataLoader
     latency_ms_quantile: LatencyMsQuantileDataLoader
     min_start_or_max_end_times: MinStartOrMaxEndTimeDataLoader
phoenix/server/api/dataloaders/__init__.py
@@ -8,6 +8,7 @@ from phoenix.db.insertion.evaluation import (
 )
 from phoenix.db.insertion.span import ClearProjectSpansEvent, SpanInsertionEvent
 
+from .average_experiment_run_latency import AverageExperimentRunLatencyDataLoader
 from .dataset_example_revisions import DatasetExampleRevisionsDataLoader
 from .dataset_example_spans import DatasetExampleSpansDataLoader
 from .document_evaluation_summaries import (
@@ -19,6 +20,7 @@ from .document_retrieval_metrics import DocumentRetrievalMetricsDataLoader
 from .evaluation_summaries import EvaluationSummaryCache, EvaluationSummaryDataLoader
 from .experiment_annotation_summaries import ExperimentAnnotationSummaryDataLoader
 from .experiment_error_rates import ExperimentErrorRatesDataLoader
+from .experiment_run_counts import ExperimentRunCountsDataLoader
 from .experiment_sequence_number import ExperimentSequenceNumberDataLoader
 from .latency_ms_quantile import LatencyMsQuantileCache, LatencyMsQuantileDataLoader
 from .min_start_or_max_end_times import MinStartOrMaxEndTimeCache, MinStartOrMaxEndTimeDataLoader
@@ -33,6 +35,7 @@ from .trace_row_ids import TraceRowIdsDataLoader
 
 __all__ = [
     "CacheForDataLoaders",
+    "AverageExperimentRunLatencyDataLoader",
    "DatasetExampleRevisionsDataLoader",
    "DatasetExampleSpansDataLoader",
    "DocumentEvaluationSummaryDataLoader",
@@ -41,6 +44,7 @@ __all__ = [
    "EvaluationSummaryDataLoader",
    "ExperimentAnnotationSummaryDataLoader",
    "ExperimentErrorRatesDataLoader",
+    "ExperimentRunCountsDataLoader",
    "ExperimentSequenceNumberDataLoader",
    "LatencyMsQuantileDataLoader",
    "MinStartOrMaxEndTimeDataLoader",
phoenix/server/api/dataloaders/average_experiment_run_latency.py
@@ -0,0 +1,54 @@
+from typing import (
+    AsyncContextManager,
+    Callable,
+    List,
+)
+
+from sqlalchemy import func, select
+from sqlalchemy.ext.asyncio import AsyncSession
+from strawberry.dataloader import DataLoader
+from typing_extensions import TypeAlias
+
+from phoenix.db import models
+
+ExperimentID: TypeAlias = int
+RunLatency: TypeAlias = float
+Key: TypeAlias = ExperimentID
+Result: TypeAlias = RunLatency
+
+
+class AverageExperimentRunLatencyDataLoader(DataLoader[Key, Result]):
+    def __init__(
+        self,
+        db: Callable[[], AsyncContextManager[AsyncSession]],
+    ) -> None:
+        super().__init__(load_fn=self._load_fn)
+        self._db = db
+
+    async def _load_fn(self, keys: List[Key]) -> List[Result]:
+        experiment_ids = keys
+        async with self._db() as session:
+            avg_latencies = {
+                experiment_id: avg_latency
+                async for experiment_id, avg_latency in await session.stream(
+                    select(
+                        models.ExperimentRun.experiment_id,
+                        func.avg(
+                            func.extract(
+                                "epoch",
+                                models.ExperimentRun.end_time,
+                            )
+                            - func.extract(
+                                "epoch",
+                                models.ExperimentRun.start_time,
+                            )
+                        ),
+                    )
+                    .where(models.ExperimentRun.experiment_id.in_(set(experiment_ids)))
+                    .group_by(models.ExperimentRun.experiment_id)
+                )
+            }
+        return [
+            avg_latencies.get(experiment_id, ValueError(f"Unknown experiment: {experiment_id}"))
+            for experiment_id in experiment_ids
+        ]
phoenix/server/api/dataloaders/experiment_run_counts.py
@@ -0,0 +1,42 @@
+from typing import (
+    AsyncContextManager,
+    Callable,
+    List,
+)
+
+from sqlalchemy import func, select
+from sqlalchemy.ext.asyncio import AsyncSession
+from strawberry.dataloader import DataLoader
+from typing_extensions import TypeAlias
+
+from phoenix.db import models
+
+ExperimentID: TypeAlias = int
+RunCount: TypeAlias = int
+Key: TypeAlias = ExperimentID
+Result: TypeAlias = RunCount
+
+
+class ExperimentRunCountsDataLoader(DataLoader[Key, Result]):
+    def __init__(
+        self,
+        db: Callable[[], AsyncContextManager[AsyncSession]],
+    ) -> None:
+        super().__init__(load_fn=self._load_fn)
+        self._db = db
+
+    async def _load_fn(self, keys: List[Key]) -> List[Result]:
+        experiment_ids = keys
+        async with self._db() as session:
+            run_counts = {
+                experiment_id: run_count
+                async for experiment_id, run_count in await session.stream(
+                    select(models.ExperimentRun.experiment_id, func.count())
+                    .where(models.ExperimentRun.experiment_id.in_(set(experiment_ids)))
+                    .group_by(models.ExperimentRun.experiment_id)
+                )
+            }
+        return [
+            run_counts.get(experiment_id, ValueError(f"Unknown experiment: {experiment_id}"))
+            for experiment_id in experiment_ids
+        ]
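Note: both new loaders follow the same batch-and-group pattern, where DataLoader.load() queues keys and _load_fn resolves the whole batch with a single grouped SQL query. A minimal sketch of exercising one of them directly, outside Phoenix's actual GraphQL wiring; the db argument is assumed to be the same async-session factory the server passes in:

    # Minimal sketch, assuming `db` is a Callable[[], AsyncContextManager[AsyncSession]].
    import asyncio

    from phoenix.server.api.dataloaders import ExperimentRunCountsDataLoader

    async def demo(db):
        loader = ExperimentRunCountsDataLoader(db)
        # Both keys are gathered into one batch, so only one grouped query is issued.
        counts = await asyncio.gather(loader.load(1), loader.load(2))
        print(counts)  # e.g. [10, 25] -- run counts for experiments 1 and 2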
phoenix/server/api/helpers/dataset_helpers.py
@@ -1,3 +1,4 @@
+import json
 from typing import Any, Dict, Literal, Mapping, Optional, Protocol
 
 from openinference.semconv.trace import (
@@ -128,14 +129,14 @@ def _get_generic_io_value(
     Makes a best-effort attempt to extract the input or output value from a span
     and returns it as a dictionary.
     """
-    if isinstance(io_value, str) and (
-        mime_type == OpenInferenceMimeTypeValues.TEXT.value or mime_type is None
-    ):
+    if mime_type == OpenInferenceMimeTypeValues.JSON.value:
+        parsed_value = json.loads(io_value)
+        if isinstance(parsed_value, dict):
+            return parsed_value
+        else:
+            return {kind: parsed_value}
+    if isinstance(io_value, str):
         return {kind: io_value}
-    if isinstance(io_value, dict) and (
-        mime_type == OpenInferenceMimeTypeValues.JSON.value or mime_type is None
-    ):
-        return io_value
     return {}
 
 
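Note: JSON-typed span values are now parsed before being wrapped. A hedged illustration of the new branch; the helper is private and its full signature is truncated in the hunk header, so the keyword names below are taken from the body and should be treated as an assumption:

    # Hypothetical calls against the reworked helper above.
    from phoenix.server.api.helpers.dataset_helpers import _get_generic_io_value

    _get_generic_io_value(io_value='{"query": "hi"}', mime_type="application/json", kind="input")
    # -> {"query": "hi"}  (a JSON object is returned as-is)
    _get_generic_io_value(io_value="[1, 2, 3]", mime_type="application/json", kind="output")
    # -> {"output": [1, 2, 3]}  (non-dict JSON is wrapped under the kind)
    _get_generic_io_value(io_value="plain text", mime_type="text/plain", kind="input")
    # -> {"input": "plain text"}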
phoenix/server/api/input_types/ClearProjectInput.py
@@ -0,0 +1,15 @@
+from datetime import datetime
+from typing import Optional
+
+import strawberry
+from strawberry import UNSET
+from strawberry.relay import GlobalID
+
+
+@strawberry.input
+class ClearProjectInput:
+    id: GlobalID
+    end_time: Optional[datetime] = strawberry.field(
+        default=UNSET,
+        description="The time up to which to purge data. Time is right-open /non-inclusive.",
+    )
phoenix/server/api/mutations/project_mutations.py
@@ -8,6 +8,7 @@ from phoenix.config import DEFAULT_PROJECT_NAME
 from phoenix.db import models
 from phoenix.db.insertion.span import ClearProjectSpansEvent
 from phoenix.server.api.context import Context
+from phoenix.server.api.input_types.ClearProjectInput import ClearProjectInput
 from phoenix.server.api.mutations.auth import IsAuthenticated
 from phoenix.server.api.queries import Query
 from phoenix.server.api.types.node import from_global_id_with_expected_type
@@ -32,11 +33,15 @@ class ProjectMutationMixin:
         return Query()
 
     @strawberry.mutation(permission_classes=[IsAuthenticated])  # type: ignore
-    async def clear_project(self, info: Info[Context, None], id: GlobalID) -> Query:
-        project_id = from_global_id_with_expected_type(global_id=id, expected_type_name="Project")
+    async def clear_project(self, info: Info[Context, None], input: ClearProjectInput) -> Query:
+        project_id = from_global_id_with_expected_type(
+            global_id=input.id, expected_type_name="Project"
+        )
         delete_statement = delete(models.Trace).where(models.Trace.project_rowid == project_id)
+        if input.end_time is not None:
+            delete_statement = delete_statement.where(models.Trace.start_time < input.end_time)
         async with info.context.db() as session:
             await session.execute(delete_statement)
-            if cache := info.context.cache_for_dataloaders:
-                cache.invalidate(ClearProjectSpansEvent(project_rowid=project_id))
+        if cache := info.context.cache_for_dataloaders:
+            cache.invalidate(ClearProjectSpansEvent(project_rowid=project_id))
         return Query()
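Note: clearProject now takes a single input object, which is what gives callers the optional end_time cutoff. A hedged sketch of the corresponding call from Python; the /graphql path, the camelCase field names, and the example GlobalID are assumptions, not part of this diff:

    # Hypothetical client-side call; assumes the Phoenix server exposes GraphQL at /graphql.
    import httpx

    mutation = """
    mutation ClearProject($input: ClearProjectInput!) {
      clearProject(input: $input) { __typename }
    }
    """
    variables = {"input": {"id": "UHJvamVjdDox", "endTime": "2024-06-01T00:00:00+00:00"}}
    httpx.post("http://localhost:6006/graphql", json={"query": mutation, "variables": variables})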
phoenix/server/api/routers/v1/__init__.py
@@ -80,7 +80,7 @@ V1_ROUTES = [
     ),
     Route(
         "/v1/experiment_evaluations",
-        experiment_evaluations.create_experiment_evaluation,
+        experiment_evaluations.upsert_experiment_evaluation,
         methods=["POST"],
     ),
 ]
phoenix/server/api/routers/v1/dataset_examples.py
@@ -21,7 +21,7 @@ async def list_dataset_examples(request: Request) -> Response:
          type: string
        description: Dataset ID
      - in: query
-        name: version-id
+        name: version_id
        schema:
          type: string
        description: Dataset version ID. If omitted, returns the latest version.
@@ -79,7 +79,7 @@ async def list_dataset_examples(request: Request) -> Response:
        description: Dataset does not exist.
    """
    dataset_id = GlobalID.from_id(request.path_params["id"])
-    raw_version_id = request.query_params.get("version-id")
+    raw_version_id = request.query_params.get("version_id")
    version_id = GlobalID.from_id(raw_version_id) if raw_version_id else None
 
    if (dataset_type := dataset_id.type_name) != "Dataset":
@@ -167,12 +167,12 @@ async def list_dataset_examples(request: Request) -> Response:
            }
            async for example, revision in await session.stream(query)
        ]
-        return JSONResponse(
-            {
-                "data": {
-                    "dataset_id": str(GlobalID("Dataset", str(resolved_dataset_id))),
-                    "version_id": str(GlobalID("DatasetVersion", str(resolved_version_id))),
-                    "examples": examples,
-                }
+    return JSONResponse(
+        {
+            "data": {
+                "dataset_id": str(GlobalID("Dataset", str(resolved_dataset_id))),
+                "version_id": str(GlobalID("DatasetVersion", str(resolved_version_id))),
+                "examples": examples,
             }
-        )
+        }
+    )
phoenix/server/api/routers/v1/datasets.py
@@ -13,11 +13,12 @@ from typing import (
     Awaitable,
     Callable,
     Coroutine,
-    Dict,
     FrozenSet,
     Iterator,
     List,
+    Mapping,
     Optional,
+    Sequence,
     Tuple,
     Union,
     cast,
@@ -32,8 +33,8 @@ from starlette.datastructures import FormData, UploadFile
 from starlette.requests import Request
 from starlette.responses import JSONResponse, Response
 from starlette.status import (
-    HTTP_403_FORBIDDEN,
     HTTP_404_NOT_FOUND,
+    HTTP_409_CONFLICT,
     HTTP_422_UNPROCESSABLE_ENTITY,
     HTTP_429_TOO_MANY_REQUESTS,
 )
@@ -44,6 +45,7 @@ from phoenix.db import models
 from phoenix.db.insertion.dataset import (
     DatasetAction,
     DatasetExampleAdditionEvent,
+    ExampleContent,
     add_dataset_examples,
 )
 from phoenix.server.api.types.Dataset import Dataset
@@ -231,7 +233,7 @@ async def get_dataset_by_id(request: Request) -> Response:
         "updated_at": dataset.updated_at.isoformat(),
         "example_count": example_count,
     }
-    return JSONResponse(content=output_dict)
+    return JSONResponse(content={"data": output_dict})
 
 
 async def get_dataset_versions(request: Request) -> Response:
@@ -350,7 +352,7 @@ async def get_dataset_versions(request: Request) -> Response:
 
 async def post_datasets_upload(request: Request) -> Response:
     """
-    summary: Upload CSV or PyArrow file as dataset
+    summary: Upload dataset as either JSON or file (CSV or PyArrow)
     operationId: uploadDataset
     tags:
       - datasets
@@ -362,6 +364,32 @@ async def post_datasets_upload(request: Request) -> Response:
          type: boolean
    requestBody:
      content:
+        application/json:
+          schema:
+            type: object
+            required:
+              - name
+              - inputs
+            properties:
+              action:
+                type: string
+                enum: [create, append]
+              name:
+                type: string
+              description:
+                type: string
+              inputs:
+                type: array
+                items:
+                  type: object
+              outputs:
+                type: array
+                items:
+                  type: object
+              metadata:
+                type: array
+                items:
+                  type: object
        multipart/form-data:
          schema:
            type: object
@@ -401,22 +429,18 @@ async def post_datasets_upload(request: Request) -> Response:
        description: Success
      403:
        description: Forbidden
+      409:
+        description: Dataset of the same name already exists
      422:
        description: Request body is invalid
    """
-    if request.app.state.read_only:
-        return Response(status_code=HTTP_403_FORBIDDEN)
-    async with request.form() as form:
+    request_content_type = request.headers["content-type"]
+    examples: Union[Examples, Awaitable[Examples]]
+    if request_content_type.startswith("application/json"):
        try:
-            (
-                action,
-                name,
-                description,
-                input_keys,
-                output_keys,
-                metadata_keys,
-                file,
-            ) = await _parse_form_data(form)
+            examples, action, name, description = await run_in_threadpool(
+                _process_json, await request.json()
+            )
        except ValueError as e:
            return Response(
                content=str(e),
@@ -426,24 +450,53 @@ async def post_datasets_upload(request: Request) -> Response:
            async with request.app.state.db() as session:
                if await _check_table_exists(session, name):
                    return Response(
-                        content=f"Dataset already exists: {name=}",
-                        status_code=HTTP_422_UNPROCESSABLE_ENTITY,
+                        content=f"Dataset with the same name already exists: {name=}",
+                        status_code=HTTP_409_CONFLICT,
                    )
-        content = await file.read()
-        try:
-            examples: Union[Examples, Awaitable[Examples]]
-            content_type = FileContentType(file.content_type)
-            if content_type is FileContentType.CSV:
-                encoding = FileContentEncoding(file.headers.get("content-encoding"))
-                examples, column_headers = await _process_csv(content, encoding)
-            elif content_type is FileContentType.PYARROW:
-                examples, column_headers = await _process_pyarrow(content)
-            else:
-                assert_never(content_type)
-            _check_keys_exist(column_headers, input_keys, output_keys, metadata_keys)
-        except ValueError as e:
+    elif request_content_type.startswith("multipart/form-data"):
+        async with request.form() as form:
+            try:
+                (
+                    action,
+                    name,
+                    description,
+                    input_keys,
+                    output_keys,
+                    metadata_keys,
+                    file,
+                ) = await _parse_form_data(form)
+            except ValueError as e:
+                return Response(
+                    content=str(e),
+                    status_code=HTTP_422_UNPROCESSABLE_ENTITY,
+                )
+            if action is DatasetAction.CREATE:
+                async with request.app.state.db() as session:
+                    if await _check_table_exists(session, name):
+                        return Response(
+                            content=f"Dataset with the same name already exists: {name=}",
+                            status_code=HTTP_409_CONFLICT,
+                        )
+            content = await file.read()
+            try:
+                file_content_type = FileContentType(file.content_type)
+                if file_content_type is FileContentType.CSV:
+                    encoding = FileContentEncoding(file.headers.get("content-encoding"))
+                    examples = await _process_csv(
+                        content, encoding, input_keys, output_keys, metadata_keys
+                    )
+                elif file_content_type is FileContentType.PYARROW:
+                    examples = await _process_pyarrow(content, input_keys, output_keys, metadata_keys)
+                else:
+                    assert_never(file_content_type)
+            except ValueError as e:
+                return Response(
+                    content=str(e),
+                    status_code=HTTP_422_UNPROCESSABLE_ENTITY,
+                )
+    else:
        return Response(
-            content=str(e),
+            content=str("Invalid request Content-Type"),
            status_code=HTTP_422_UNPROCESSABLE_ENTITY,
        )
    operation = cast(
@@ -454,9 +507,6 @@ async def post_datasets_upload(request: Request) -> Response:
            action=action,
            name=name,
            description=description,
-            input_keys=input_keys,
-            output_keys=output_keys,
-            metadata_keys=metadata_keys,
        ),
    )
    if request.query_params.get("sync") == "true":
@@ -505,13 +555,46 @@ InputKeys: TypeAlias = FrozenSet[str]
 OutputKeys: TypeAlias = FrozenSet[str]
 MetadataKeys: TypeAlias = FrozenSet[str]
 DatasetId: TypeAlias = int
-Examples: TypeAlias = Iterator[Dict[str, Any]]
+Examples: TypeAlias = Iterator[ExampleContent]
+
+
+def _process_json(
+    data: Mapping[str, Any],
+) -> Tuple[Examples, DatasetAction, Name, Description]:
+    name = data.get("name")
+    if not name:
+        raise ValueError("Dataset name is required")
+    description = data.get("description") or ""
+    inputs = data.get("inputs")
+    if not inputs:
+        raise ValueError("input is required")
+    if not isinstance(inputs, list) or not _is_all_dict(inputs):
+        raise ValueError("Input should be a list containing only dictionary objects")
+    outputs, metadata = data.get("outputs"), data.get("metadata")
+    for k, v in {"outputs": outputs, "metadata": metadata}.items():
+        if v and not (isinstance(v, list) and len(v) == len(inputs) and _is_all_dict(v)):
+            raise ValueError(
+                f"{k} should be a list of same length as input containing only dictionary objects"
+            )
+    examples: List[ExampleContent] = []
+    for i, obj in enumerate(inputs):
+        example = ExampleContent(
+            input=obj,
+            output=outputs[i] if outputs else {},
+            metadata=metadata[i] if metadata else {},
+        )
+        examples.append(example)
+    action = DatasetAction(cast(Optional[str], data.get("action")) or "create")
+    return iter(examples), action, name, description
 
 
 async def _process_csv(
     content: bytes,
     content_encoding: FileContentEncoding,
-) -> Tuple[Examples, FrozenSet[str]]:
+    input_keys: InputKeys,
+    output_keys: OutputKeys,
+    metadata_keys: MetadataKeys,
+) -> Examples:
     if content_encoding is FileContentEncoding.GZIP:
         content = await run_in_threadpool(gzip.decompress, content)
     elif content_encoding is FileContentEncoding.DEFLATE:
@@ -525,22 +608,39 @@ async def _process_csv(
         if freq > 1:
             raise ValueError(f"Duplicated column header in CSV file: {header}")
     column_headers = frozenset(reader.fieldnames)
-    return reader, column_headers
+    _check_keys_exist(column_headers, input_keys, output_keys, metadata_keys)
+    return (
+        ExampleContent(
+            input={k: row.get(k) for k in input_keys},
+            output={k: row.get(k) for k in output_keys},
+            metadata={k: row.get(k) for k in metadata_keys},
+        )
+        for row in iter(reader)
+    )
 
 
 async def _process_pyarrow(
     content: bytes,
-) -> Tuple[Awaitable[Examples], FrozenSet[str]]:
+    input_keys: InputKeys,
+    output_keys: OutputKeys,
+    metadata_keys: MetadataKeys,
+) -> Awaitable[Examples]:
     try:
         reader = pa.ipc.open_stream(content)
     except pa.ArrowInvalid as e:
         raise ValueError("File is not valid pyarrow") from e
     column_headers = frozenset(reader.schema.names)
+    _check_keys_exist(column_headers, input_keys, output_keys, metadata_keys)
+
+    def get_examples() -> Iterator[ExampleContent]:
+        for row in reader.read_pandas().to_dict(orient="records"):
+            yield ExampleContent(
+                input={k: row.get(k) for k in input_keys},
+                output={k: row.get(k) for k in output_keys},
+                metadata={k: row.get(k) for k in metadata_keys},
+            )
 
-    def get_examples() -> Iterator[Dict[str, Any]]:
-        yield from reader.read_pandas().to_dict(orient="records")
-
-    return run_in_threadpool(get_examples), column_headers
+    return run_in_threadpool(get_examples)
 
 
 async def _check_table_exists(session: AsyncSession, name: str) -> bool:
@@ -613,7 +713,7 @@ async def get_dataset_csv(request: Request) -> Response:
          type: string
        description: Dataset ID
      - in: query
-        name: version
+        name: version_id
        schema:
          type: string
        description: Dataset version ID. If omitted, returns the latest version.
@@ -662,7 +762,7 @@ async def get_dataset_jsonl_openai_ft(request: Request) -> Response:
          type: string
        description: Dataset ID
      - in: query
-        name: version
+        name: version_id
        schema:
          type: string
        description: Dataset version ID. If omitted, returns the latest version.
@@ -711,7 +811,7 @@ async def get_dataset_jsonl_openai_evals(request: Request) -> Response:
          type: string
        description: Dataset ID
      - in: query
-        name: version
+        name: version_id
        schema:
          type: string
        description: Dataset version ID. If omitted, returns the latest version.
@@ -815,9 +915,9 @@ async def _get_db_examples(request: Request) -> Tuple[str, List[models.DatasetEx
        raise ValueError("Missing Dataset ID")
    dataset_id = from_global_id_with_expected_type(GlobalID.from_id(id_), Dataset.__name__)
    dataset_version_id: Optional[int] = None
-    if version := request.query_params.get("version"):
+    if version_id := request.query_params.get("version_id"):
        dataset_version_id = from_global_id_with_expected_type(
-            GlobalID.from_id(version),
+            GlobalID.from_id(version_id),
            DatasetVersion.__name__,
        )
    latest_version = (
@@ -859,3 +959,7 @@
        raise ValueError("Dataset does not exist.")
    examples = [r async for r in await session.stream_scalars(stmt)]
    return dataset_name, examples
+
+
+def _is_all_dict(seq: Sequence[Any]) -> bool:
+    return all(map(lambda obj: isinstance(obj, dict), seq))
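Note: since the upload route now accepts inline JSON, here is a hedged sketch of what such a request could look like. The upload path and port are assumptions (the route definition itself is not shown in this diff), while the body keys mirror the application/json schema documented above:

    # Hypothetical client call; the /v1/datasets/upload path is assumed.
    import httpx

    payload = {
        "action": "create",
        "name": "my-dataset",
        "description": "toy example",
        "inputs": [{"question": "What is Phoenix?"}],
        "outputs": [{"answer": "An observability library."}],
        "metadata": [{"source": "docs"}],
    }
    response = httpx.post("http://localhost:6006/v1/datasets/upload?sync=true", json=payload)
    response.raise_for_status()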
phoenix/server/api/routers/v1/evaluations.py
@@ -45,13 +45,6 @@ async def post_evaluations(request: Request) -> Response:
    operationId: addEvaluations
    tags:
      - private
-    parameters:
-      - name: project-name
-        in: query
-        schema:
-          type: string
-          default: default
-        description: The project name to add the evaluation to
    requestBody:
      required: true
      content:
@@ -107,7 +100,7 @@ async def get_evaluations(request: Request) -> Response:
    tags:
      - private
    parameters:
-      - name: project-name
+      - name: project_name
        in: query
        schema:
          type: string
@@ -122,9 +115,9 @@ async def get_evaluations(request: Request) -> Response:
        description: Not found
    """
    project_name = (
-        request.query_params.get("project-name")
-        # read from headers for backwards compatibility
-        or request.headers.get("project-name")
+        request.query_params.get("project_name")
+        or request.query_params.get("project-name")  # for backward compatibility
+        or request.headers.get("project-name")  # read from headers for backwards compatibility
        or DEFAULT_PROJECT_NAME
    )