arize-phoenix 4.4.4rc4__py3-none-any.whl → 4.4.4rc6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of arize-phoenix might be problematic. Click here for more details.
- {arize_phoenix-4.4.4rc4.dist-info → arize_phoenix-4.4.4rc6.dist-info}/METADATA +12 -6
- {arize_phoenix-4.4.4rc4.dist-info → arize_phoenix-4.4.4rc6.dist-info}/RECORD +47 -42
- phoenix/config.py +21 -0
- phoenix/datetime_utils.py +4 -0
- phoenix/db/insertion/dataset.py +19 -16
- phoenix/db/insertion/evaluation.py +4 -4
- phoenix/db/insertion/helpers.py +4 -12
- phoenix/db/insertion/span.py +3 -3
- phoenix/db/migrations/versions/10460e46d750_datasets.py +2 -2
- phoenix/db/models.py +8 -3
- phoenix/experiments/__init__.py +6 -0
- phoenix/experiments/evaluators/__init__.py +29 -0
- phoenix/experiments/evaluators/base.py +153 -0
- phoenix/{datasets → experiments}/evaluators/code_evaluators.py +25 -53
- phoenix/{datasets → experiments}/evaluators/llm_evaluators.py +62 -31
- phoenix/experiments/evaluators/utils.py +189 -0
- phoenix/experiments/functions.py +616 -0
- phoenix/{datasets → experiments}/tracing.py +19 -0
- phoenix/experiments/types.py +722 -0
- phoenix/experiments/utils.py +9 -0
- phoenix/server/api/context.py +4 -0
- phoenix/server/api/dataloaders/__init__.py +4 -0
- phoenix/server/api/dataloaders/average_experiment_run_latency.py +54 -0
- phoenix/server/api/dataloaders/experiment_run_counts.py +42 -0
- phoenix/server/api/helpers/dataset_helpers.py +8 -7
- phoenix/server/api/input_types/ClearProjectInput.py +15 -0
- phoenix/server/api/mutations/project_mutations.py +9 -4
- phoenix/server/api/routers/v1/__init__.py +1 -1
- phoenix/server/api/routers/v1/dataset_examples.py +10 -10
- phoenix/server/api/routers/v1/datasets.py +152 -48
- phoenix/server/api/routers/v1/evaluations.py +4 -11
- phoenix/server/api/routers/v1/experiment_evaluations.py +23 -23
- phoenix/server/api/routers/v1/experiment_runs.py +5 -17
- phoenix/server/api/routers/v1/experiments.py +5 -5
- phoenix/server/api/routers/v1/spans.py +6 -4
- phoenix/server/api/types/Experiment.py +12 -0
- phoenix/server/api/types/ExperimentRun.py +1 -1
- phoenix/server/api/types/ExperimentRunAnnotation.py +1 -1
- phoenix/server/app.py +4 -0
- phoenix/server/static/index.js +712 -588
- phoenix/session/client.py +321 -28
- phoenix/trace/fixtures.py +6 -6
- phoenix/utilities/json.py +8 -8
- phoenix/version.py +1 -1
- phoenix/datasets/__init__.py +0 -0
- phoenix/datasets/evaluators/__init__.py +0 -18
- phoenix/datasets/evaluators/_utils.py +0 -13
- phoenix/datasets/experiments.py +0 -485
- phoenix/datasets/types.py +0 -212
- {arize_phoenix-4.4.4rc4.dist-info → arize_phoenix-4.4.4rc6.dist-info}/WHEEL +0 -0
- {arize_phoenix-4.4.4rc4.dist-info → arize_phoenix-4.4.4rc6.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-4.4.4rc4.dist-info → arize_phoenix-4.4.4rc6.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
from phoenix.config import get_web_base_url
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def get_experiment_url(*, dataset_id: str, experiment_id: str) -> str:
    """Build the URL of the compare page for one experiment of a dataset.

    The path ``datasets/<dataset_id>/compare?experimentId=<experiment_id>`` is
    appended directly to the value returned by ``get_web_base_url()``.
    """
    # No separator is inserted here.
    # NOTE(review): assumes the base URL already ends with "/" — confirm in phoenix.config.
    base_url = get_web_base_url()
    return f"{base_url}datasets/{dataset_id}/compare?experimentId={experiment_id}"
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def get_dataset_experiments_url(*, dataset_id: str) -> str:
    """Build the URL of the experiments page of a dataset.

    The path ``datasets/<dataset_id>/experiments`` is appended directly to the
    value returned by ``get_web_base_url()``.
    """
    # No separator is inserted here.
    # NOTE(review): assumes the base URL already ends with "/" — confirm in phoenix.config.
    base_url = get_web_base_url()
    return f"{base_url}datasets/{dataset_id}/experiments"
|
phoenix/server/api/context.py
CHANGED
|
@@ -11,6 +11,7 @@ from typing_extensions import TypeAlias
|
|
|
11
11
|
|
|
12
12
|
from phoenix.core.model_schema import Model
|
|
13
13
|
from phoenix.server.api.dataloaders import (
|
|
14
|
+
AverageExperimentRunLatencyDataLoader,
|
|
14
15
|
CacheForDataLoaders,
|
|
15
16
|
DatasetExampleRevisionsDataLoader,
|
|
16
17
|
DatasetExampleSpansDataLoader,
|
|
@@ -20,6 +21,7 @@ from phoenix.server.api.dataloaders import (
|
|
|
20
21
|
EvaluationSummaryDataLoader,
|
|
21
22
|
ExperimentAnnotationSummaryDataLoader,
|
|
22
23
|
ExperimentErrorRatesDataLoader,
|
|
24
|
+
ExperimentRunCountsDataLoader,
|
|
23
25
|
ExperimentSequenceNumberDataLoader,
|
|
24
26
|
LatencyMsQuantileDataLoader,
|
|
25
27
|
MinStartOrMaxEndTimeDataLoader,
|
|
@@ -36,6 +38,7 @@ from phoenix.server.api.dataloaders import (
|
|
|
36
38
|
|
|
37
39
|
@dataclass
|
|
38
40
|
class DataLoaders:
|
|
41
|
+
average_experiment_run_latency: AverageExperimentRunLatencyDataLoader
|
|
39
42
|
dataset_example_revisions: DatasetExampleRevisionsDataLoader
|
|
40
43
|
dataset_example_spans: DatasetExampleSpansDataLoader
|
|
41
44
|
document_evaluation_summaries: DocumentEvaluationSummaryDataLoader
|
|
@@ -44,6 +47,7 @@ class DataLoaders:
|
|
|
44
47
|
evaluation_summaries: EvaluationSummaryDataLoader
|
|
45
48
|
experiment_annotation_summaries: ExperimentAnnotationSummaryDataLoader
|
|
46
49
|
experiment_error_rates: ExperimentErrorRatesDataLoader
|
|
50
|
+
experiment_run_counts: ExperimentRunCountsDataLoader
|
|
47
51
|
experiment_sequence_number: ExperimentSequenceNumberDataLoader
|
|
48
52
|
latency_ms_quantile: LatencyMsQuantileDataLoader
|
|
49
53
|
min_start_or_max_end_times: MinStartOrMaxEndTimeDataLoader
|
|
@@ -8,6 +8,7 @@ from phoenix.db.insertion.evaluation import (
|
|
|
8
8
|
)
|
|
9
9
|
from phoenix.db.insertion.span import ClearProjectSpansEvent, SpanInsertionEvent
|
|
10
10
|
|
|
11
|
+
from .average_experiment_run_latency import AverageExperimentRunLatencyDataLoader
|
|
11
12
|
from .dataset_example_revisions import DatasetExampleRevisionsDataLoader
|
|
12
13
|
from .dataset_example_spans import DatasetExampleSpansDataLoader
|
|
13
14
|
from .document_evaluation_summaries import (
|
|
@@ -19,6 +20,7 @@ from .document_retrieval_metrics import DocumentRetrievalMetricsDataLoader
|
|
|
19
20
|
from .evaluation_summaries import EvaluationSummaryCache, EvaluationSummaryDataLoader
|
|
20
21
|
from .experiment_annotation_summaries import ExperimentAnnotationSummaryDataLoader
|
|
21
22
|
from .experiment_error_rates import ExperimentErrorRatesDataLoader
|
|
23
|
+
from .experiment_run_counts import ExperimentRunCountsDataLoader
|
|
22
24
|
from .experiment_sequence_number import ExperimentSequenceNumberDataLoader
|
|
23
25
|
from .latency_ms_quantile import LatencyMsQuantileCache, LatencyMsQuantileDataLoader
|
|
24
26
|
from .min_start_or_max_end_times import MinStartOrMaxEndTimeCache, MinStartOrMaxEndTimeDataLoader
|
|
@@ -33,6 +35,7 @@ from .trace_row_ids import TraceRowIdsDataLoader
|
|
|
33
35
|
|
|
34
36
|
__all__ = [
|
|
35
37
|
"CacheForDataLoaders",
|
|
38
|
+
"AverageExperimentRunLatencyDataLoader",
|
|
36
39
|
"DatasetExampleRevisionsDataLoader",
|
|
37
40
|
"DatasetExampleSpansDataLoader",
|
|
38
41
|
"DocumentEvaluationSummaryDataLoader",
|
|
@@ -41,6 +44,7 @@ __all__ = [
|
|
|
41
44
|
"EvaluationSummaryDataLoader",
|
|
42
45
|
"ExperimentAnnotationSummaryDataLoader",
|
|
43
46
|
"ExperimentErrorRatesDataLoader",
|
|
47
|
+
"ExperimentRunCountsDataLoader",
|
|
44
48
|
"ExperimentSequenceNumberDataLoader",
|
|
45
49
|
"LatencyMsQuantileDataLoader",
|
|
46
50
|
"MinStartOrMaxEndTimeDataLoader",
|
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
from typing import (
|
|
2
|
+
AsyncContextManager,
|
|
3
|
+
Callable,
|
|
4
|
+
List,
|
|
5
|
+
)
|
|
6
|
+
|
|
7
|
+
from sqlalchemy import func, select
|
|
8
|
+
from sqlalchemy.ext.asyncio import AsyncSession
|
|
9
|
+
from strawberry.dataloader import DataLoader
|
|
10
|
+
from typing_extensions import TypeAlias
|
|
11
|
+
|
|
12
|
+
from phoenix.db import models
|
|
13
|
+
|
|
14
|
+
ExperimentID: TypeAlias = int
|
|
15
|
+
RunLatency: TypeAlias = float
|
|
16
|
+
Key: TypeAlias = ExperimentID
|
|
17
|
+
Result: TypeAlias = RunLatency
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class AverageExperimentRunLatencyDataLoader(DataLoader[Key, Result]):
    """Batch loader mapping experiment IDs to the average duration of their runs.

    The duration of a run is computed in SQL as the difference between the
    "epoch" extraction of its ``end_time`` and of its ``start_time``.
    """

    def __init__(
        self,
        db: Callable[[], AsyncContextManager[AsyncSession]],
    ) -> None:
        # `db` is a factory for async session context managers; one session is
        # opened per batch in `_load_fn`.
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: List[Key]) -> List[Result]:
        """Return one result per key, in key order.

        Keys without any row in the grouped query get a ``ValueError`` instance
        in their slot instead of a number.
        """
        run = models.ExperimentRun
        run_duration = func.extract("epoch", run.end_time) - func.extract("epoch", run.start_time)
        stmt = (
            select(run.experiment_id, func.avg(run_duration))
            .where(run.experiment_id.in_(set(keys)))
            .group_by(run.experiment_id)
        )
        async with self._db() as session:
            latency_by_experiment = {
                found_id: latency async for found_id, latency in await session.stream(stmt)
            }
        # The exception is returned, not raised.
        # NOTE(review): presumably the DataLoader surfaces it for that key only — confirm.
        return [
            latency_by_experiment.get(key, ValueError(f"Unknown experiment: {key}"))
            for key in keys
        ]
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
from typing import (
|
|
2
|
+
AsyncContextManager,
|
|
3
|
+
Callable,
|
|
4
|
+
List,
|
|
5
|
+
)
|
|
6
|
+
|
|
7
|
+
from sqlalchemy import func, select
|
|
8
|
+
from sqlalchemy.ext.asyncio import AsyncSession
|
|
9
|
+
from strawberry.dataloader import DataLoader
|
|
10
|
+
from typing_extensions import TypeAlias
|
|
11
|
+
|
|
12
|
+
from phoenix.db import models
|
|
13
|
+
|
|
14
|
+
ExperimentID: TypeAlias = int
|
|
15
|
+
RunCount: TypeAlias = int
|
|
16
|
+
Key: TypeAlias = ExperimentID
|
|
17
|
+
Result: TypeAlias = RunCount
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class ExperimentRunCountsDataLoader(DataLoader[Key, Result]):
    """Batch loader mapping experiment IDs to the number of their experiment runs."""

    def __init__(
        self,
        db: Callable[[], AsyncContextManager[AsyncSession]],
    ) -> None:
        # `db` is a factory for async session context managers; one session is
        # opened per batch in `_load_fn`.
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: List[Key]) -> List[Result]:
        """Return one result per key, in key order.

        Keys without any row in the grouped query get a ``ValueError`` instance
        in their slot instead of a count.
        """
        run = models.ExperimentRun
        stmt = (
            select(run.experiment_id, func.count())
            .where(run.experiment_id.in_(set(keys)))
            .group_by(run.experiment_id)
        )
        async with self._db() as session:
            count_by_experiment = {
                found_id: count async for found_id, count in await session.stream(stmt)
            }
        # NOTE(review): the query only sees ExperimentRun rows, so an experiment
        # with zero runs is indistinguishable from an unknown one and also gets
        # the error below — confirm that this is intended.
        return [
            count_by_experiment.get(key, ValueError(f"Unknown experiment: {key}"))
            for key in keys
        ]
|
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
import json
|
|
1
2
|
from typing import Any, Dict, Literal, Mapping, Optional, Protocol
|
|
2
3
|
|
|
3
4
|
from openinference.semconv.trace import (
|
|
@@ -128,14 +129,14 @@ def _get_generic_io_value(
|
|
|
128
129
|
Makes a best-effort attempt to extract the input or output value from a span
|
|
129
130
|
and returns it as a dictionary.
|
|
130
131
|
"""
|
|
131
|
-
if
|
|
132
|
-
|
|
133
|
-
|
|
132
|
+
if mime_type == OpenInferenceMimeTypeValues.JSON.value:
|
|
133
|
+
parsed_value = json.loads(io_value)
|
|
134
|
+
if isinstance(parsed_value, dict):
|
|
135
|
+
return parsed_value
|
|
136
|
+
else:
|
|
137
|
+
return {kind: parsed_value}
|
|
138
|
+
if isinstance(io_value, str):
|
|
134
139
|
return {kind: io_value}
|
|
135
|
-
if isinstance(io_value, dict) and (
|
|
136
|
-
mime_type == OpenInferenceMimeTypeValues.JSON.value or mime_type is None
|
|
137
|
-
):
|
|
138
|
-
return io_value
|
|
139
140
|
return {}
|
|
140
141
|
|
|
141
142
|
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
from datetime import datetime
|
|
2
|
+
from typing import Optional
|
|
3
|
+
|
|
4
|
+
import strawberry
|
|
5
|
+
from strawberry import UNSET
|
|
6
|
+
from strawberry.relay import GlobalID
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@strawberry.input
class ClearProjectInput:
    # Input for the clear-project mutation.

    # Global ID identifying the project.
    id: GlobalID
    # Optional cutoff time. When the caller omits it, the value is strawberry's
    # UNSET sentinel (the declared default), not None.
    end_time: Optional[datetime] = strawberry.field(
        description="The time up to which to purge data. Time is right-open /non-inclusive.",
        default=UNSET,
    )
|
|
@@ -8,6 +8,7 @@ from phoenix.config import DEFAULT_PROJECT_NAME
|
|
|
8
8
|
from phoenix.db import models
|
|
9
9
|
from phoenix.db.insertion.span import ClearProjectSpansEvent
|
|
10
10
|
from phoenix.server.api.context import Context
|
|
11
|
+
from phoenix.server.api.input_types.ClearProjectInput import ClearProjectInput
|
|
11
12
|
from phoenix.server.api.mutations.auth import IsAuthenticated
|
|
12
13
|
from phoenix.server.api.queries import Query
|
|
13
14
|
from phoenix.server.api.types.node import from_global_id_with_expected_type
|
|
@@ -32,11 +33,15 @@ class ProjectMutationMixin:
|
|
|
32
33
|
return Query()
|
|
33
34
|
|
|
34
35
|
@strawberry.mutation(permission_classes=[IsAuthenticated])  # type: ignore
async def clear_project(self, info: Info[Context, None], input: ClearProjectInput) -> Query:
    """Delete the traces of a project, optionally only those started before a cutoff.

    Args:
        info: GraphQL resolver info; its context supplies the database session
            factory and the optional dataloader cache.
        input: The project's global ID and an optional ``end_time`` cutoff.
            Only traces with ``start_time`` strictly before the cutoff are deleted.

    Returns:
        A fresh ``Query`` root.
    """
    project_id = from_global_id_with_expected_type(
        global_id=input.id, expected_type_name="Project"
    )
    delete_statement = delete(models.Trace).where(models.Trace.project_rowid == project_id)
    # `end_time` defaults to strawberry's UNSET sentinel, which is not None, so an
    # `is not None` test would apply the filter with the sentinel when the caller
    # omits the field. Truthiness skips both UNSET and an explicit null.
    if input.end_time:
        delete_statement = delete_statement.where(models.Trace.start_time < input.end_time)
    async with info.context.db() as session:
        await session.execute(delete_statement)
        # Drop cached dataloader results that were derived from this project's spans.
        if cache := info.context.cache_for_dataloaders:
            cache.invalidate(ClearProjectSpansEvent(project_rowid=project_id))
    return Query()
|
|
@@ -21,7 +21,7 @@ async def list_dataset_examples(request: Request) -> Response:
|
|
|
21
21
|
type: string
|
|
22
22
|
description: Dataset ID
|
|
23
23
|
- in: query
|
|
24
|
-
name:
|
|
24
|
+
name: version_id
|
|
25
25
|
schema:
|
|
26
26
|
type: string
|
|
27
27
|
description: Dataset version ID. If omitted, returns the latest version.
|
|
@@ -79,7 +79,7 @@ async def list_dataset_examples(request: Request) -> Response:
|
|
|
79
79
|
description: Dataset does not exist.
|
|
80
80
|
"""
|
|
81
81
|
dataset_id = GlobalID.from_id(request.path_params["id"])
|
|
82
|
-
raw_version_id = request.query_params.get("
|
|
82
|
+
raw_version_id = request.query_params.get("version_id")
|
|
83
83
|
version_id = GlobalID.from_id(raw_version_id) if raw_version_id else None
|
|
84
84
|
|
|
85
85
|
if (dataset_type := dataset_id.type_name) != "Dataset":
|
|
@@ -167,12 +167,12 @@ async def list_dataset_examples(request: Request) -> Response:
|
|
|
167
167
|
}
|
|
168
168
|
async for example, revision in await session.stream(query)
|
|
169
169
|
]
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
}
|
|
170
|
+
return JSONResponse(
|
|
171
|
+
{
|
|
172
|
+
"data": {
|
|
173
|
+
"dataset_id": str(GlobalID("Dataset", str(resolved_dataset_id))),
|
|
174
|
+
"version_id": str(GlobalID("DatasetVersion", str(resolved_version_id))),
|
|
175
|
+
"examples": examples,
|
|
177
176
|
}
|
|
178
|
-
|
|
177
|
+
}
|
|
178
|
+
)
|
|
@@ -13,11 +13,12 @@ from typing import (
|
|
|
13
13
|
Awaitable,
|
|
14
14
|
Callable,
|
|
15
15
|
Coroutine,
|
|
16
|
-
Dict,
|
|
17
16
|
FrozenSet,
|
|
18
17
|
Iterator,
|
|
19
18
|
List,
|
|
19
|
+
Mapping,
|
|
20
20
|
Optional,
|
|
21
|
+
Sequence,
|
|
21
22
|
Tuple,
|
|
22
23
|
Union,
|
|
23
24
|
cast,
|
|
@@ -32,8 +33,8 @@ from starlette.datastructures import FormData, UploadFile
|
|
|
32
33
|
from starlette.requests import Request
|
|
33
34
|
from starlette.responses import JSONResponse, Response
|
|
34
35
|
from starlette.status import (
|
|
35
|
-
HTTP_403_FORBIDDEN,
|
|
36
36
|
HTTP_404_NOT_FOUND,
|
|
37
|
+
HTTP_409_CONFLICT,
|
|
37
38
|
HTTP_422_UNPROCESSABLE_ENTITY,
|
|
38
39
|
HTTP_429_TOO_MANY_REQUESTS,
|
|
39
40
|
)
|
|
@@ -44,6 +45,7 @@ from phoenix.db import models
|
|
|
44
45
|
from phoenix.db.insertion.dataset import (
|
|
45
46
|
DatasetAction,
|
|
46
47
|
DatasetExampleAdditionEvent,
|
|
48
|
+
ExampleContent,
|
|
47
49
|
add_dataset_examples,
|
|
48
50
|
)
|
|
49
51
|
from phoenix.server.api.types.Dataset import Dataset
|
|
@@ -231,7 +233,7 @@ async def get_dataset_by_id(request: Request) -> Response:
|
|
|
231
233
|
"updated_at": dataset.updated_at.isoformat(),
|
|
232
234
|
"example_count": example_count,
|
|
233
235
|
}
|
|
234
|
-
return JSONResponse(content=output_dict)
|
|
236
|
+
return JSONResponse(content={"data": output_dict})
|
|
235
237
|
|
|
236
238
|
|
|
237
239
|
async def get_dataset_versions(request: Request) -> Response:
|
|
@@ -350,7 +352,7 @@ async def get_dataset_versions(request: Request) -> Response:
|
|
|
350
352
|
|
|
351
353
|
async def post_datasets_upload(request: Request) -> Response:
|
|
352
354
|
"""
|
|
353
|
-
summary: Upload
|
|
355
|
+
summary: Upload dataset as either JSON or file (CSV or PyArrow)
|
|
354
356
|
operationId: uploadDataset
|
|
355
357
|
tags:
|
|
356
358
|
- datasets
|
|
@@ -362,6 +364,32 @@ async def post_datasets_upload(request: Request) -> Response:
|
|
|
362
364
|
type: boolean
|
|
363
365
|
requestBody:
|
|
364
366
|
content:
|
|
367
|
+
application/json:
|
|
368
|
+
schema:
|
|
369
|
+
type: object
|
|
370
|
+
required:
|
|
371
|
+
- name
|
|
372
|
+
- inputs
|
|
373
|
+
properties:
|
|
374
|
+
action:
|
|
375
|
+
type: string
|
|
376
|
+
enum: [create, append]
|
|
377
|
+
name:
|
|
378
|
+
type: string
|
|
379
|
+
description:
|
|
380
|
+
type: string
|
|
381
|
+
inputs:
|
|
382
|
+
type: array
|
|
383
|
+
items:
|
|
384
|
+
type: object
|
|
385
|
+
outputs:
|
|
386
|
+
type: array
|
|
387
|
+
items:
|
|
388
|
+
type: object
|
|
389
|
+
metadata:
|
|
390
|
+
type: array
|
|
391
|
+
items:
|
|
392
|
+
type: object
|
|
365
393
|
multipart/form-data:
|
|
366
394
|
schema:
|
|
367
395
|
type: object
|
|
@@ -401,22 +429,18 @@ async def post_datasets_upload(request: Request) -> Response:
|
|
|
401
429
|
description: Success
|
|
402
430
|
403:
|
|
403
431
|
description: Forbidden
|
|
432
|
+
409:
|
|
433
|
+
description: Dataset of the same name already exists
|
|
404
434
|
422:
|
|
405
435
|
description: Request body is invalid
|
|
406
436
|
"""
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
437
|
+
request_content_type = request.headers["content-type"]
|
|
438
|
+
examples: Union[Examples, Awaitable[Examples]]
|
|
439
|
+
if request_content_type.startswith("application/json"):
|
|
410
440
|
try:
|
|
411
|
-
(
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
description,
|
|
415
|
-
input_keys,
|
|
416
|
-
output_keys,
|
|
417
|
-
metadata_keys,
|
|
418
|
-
file,
|
|
419
|
-
) = await _parse_form_data(form)
|
|
441
|
+
examples, action, name, description = await run_in_threadpool(
|
|
442
|
+
_process_json, await request.json()
|
|
443
|
+
)
|
|
420
444
|
except ValueError as e:
|
|
421
445
|
return Response(
|
|
422
446
|
content=str(e),
|
|
@@ -426,24 +450,53 @@ async def post_datasets_upload(request: Request) -> Response:
|
|
|
426
450
|
async with request.app.state.db() as session:
|
|
427
451
|
if await _check_table_exists(session, name):
|
|
428
452
|
return Response(
|
|
429
|
-
content=f"Dataset already exists: {name=}",
|
|
430
|
-
status_code=
|
|
453
|
+
content=f"Dataset with the same name already exists: {name=}",
|
|
454
|
+
status_code=HTTP_409_CONFLICT,
|
|
431
455
|
)
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
|
|
441
|
-
|
|
442
|
-
|
|
443
|
-
|
|
444
|
-
|
|
456
|
+
elif request_content_type.startswith("multipart/form-data"):
|
|
457
|
+
async with request.form() as form:
|
|
458
|
+
try:
|
|
459
|
+
(
|
|
460
|
+
action,
|
|
461
|
+
name,
|
|
462
|
+
description,
|
|
463
|
+
input_keys,
|
|
464
|
+
output_keys,
|
|
465
|
+
metadata_keys,
|
|
466
|
+
file,
|
|
467
|
+
) = await _parse_form_data(form)
|
|
468
|
+
except ValueError as e:
|
|
469
|
+
return Response(
|
|
470
|
+
content=str(e),
|
|
471
|
+
status_code=HTTP_422_UNPROCESSABLE_ENTITY,
|
|
472
|
+
)
|
|
473
|
+
if action is DatasetAction.CREATE:
|
|
474
|
+
async with request.app.state.db() as session:
|
|
475
|
+
if await _check_table_exists(session, name):
|
|
476
|
+
return Response(
|
|
477
|
+
content=f"Dataset with the same name already exists: {name=}",
|
|
478
|
+
status_code=HTTP_409_CONFLICT,
|
|
479
|
+
)
|
|
480
|
+
content = await file.read()
|
|
481
|
+
try:
|
|
482
|
+
file_content_type = FileContentType(file.content_type)
|
|
483
|
+
if file_content_type is FileContentType.CSV:
|
|
484
|
+
encoding = FileContentEncoding(file.headers.get("content-encoding"))
|
|
485
|
+
examples = await _process_csv(
|
|
486
|
+
content, encoding, input_keys, output_keys, metadata_keys
|
|
487
|
+
)
|
|
488
|
+
elif file_content_type is FileContentType.PYARROW:
|
|
489
|
+
examples = await _process_pyarrow(content, input_keys, output_keys, metadata_keys)
|
|
490
|
+
else:
|
|
491
|
+
assert_never(file_content_type)
|
|
492
|
+
except ValueError as e:
|
|
493
|
+
return Response(
|
|
494
|
+
content=str(e),
|
|
495
|
+
status_code=HTTP_422_UNPROCESSABLE_ENTITY,
|
|
496
|
+
)
|
|
497
|
+
else:
|
|
445
498
|
return Response(
|
|
446
|
-
content=str(
|
|
499
|
+
content=str("Invalid request Content-Type"),
|
|
447
500
|
status_code=HTTP_422_UNPROCESSABLE_ENTITY,
|
|
448
501
|
)
|
|
449
502
|
operation = cast(
|
|
@@ -454,9 +507,6 @@ async def post_datasets_upload(request: Request) -> Response:
|
|
|
454
507
|
action=action,
|
|
455
508
|
name=name,
|
|
456
509
|
description=description,
|
|
457
|
-
input_keys=input_keys,
|
|
458
|
-
output_keys=output_keys,
|
|
459
|
-
metadata_keys=metadata_keys,
|
|
460
510
|
),
|
|
461
511
|
)
|
|
462
512
|
if request.query_params.get("sync") == "true":
|
|
@@ -505,13 +555,46 @@ InputKeys: TypeAlias = FrozenSet[str]
|
|
|
505
555
|
OutputKeys: TypeAlias = FrozenSet[str]
|
|
506
556
|
MetadataKeys: TypeAlias = FrozenSet[str]
|
|
507
557
|
DatasetId: TypeAlias = int
|
|
508
|
-
Examples: TypeAlias = Iterator[
|
|
558
|
+
Examples: TypeAlias = Iterator[ExampleContent]
|
|
559
|
+
|
|
560
|
+
|
|
561
|
+
def _process_json(
    data: Mapping[str, Any],
) -> Tuple[Examples, DatasetAction, Name, Description]:
    """Validate a JSON dataset-upload payload and convert it into examples.

    Args:
        data: The decoded JSON request body. Recognized keys are ``name``,
            ``description``, ``inputs``, ``outputs``, ``metadata`` and ``action``.

    Returns:
        A tuple of (iterator over the examples, dataset action, dataset name,
        description). The action falls back to ``"create"`` and the description
        to ``""`` when absent or empty.

    Raises:
        ValueError: If the body is not a JSON object, the name is missing or not
            a string, ``inputs`` is missing or not a list of dictionaries, or
            ``outputs``/``metadata`` is present but not a list of dictionaries
            with the same length as ``inputs``.
    """
    # A top-level JSON array or scalar has no `.get`; report it as a validation
    # error instead of letting an AttributeError escape.
    if not isinstance(data, Mapping):
        raise ValueError("Request body should be a JSON object")
    name = data.get("name")
    if not name:
        raise ValueError("Dataset name is required")
    # JSON allows any value type here, but the name is later handed to
    # `_check_table_exists`, which declares `name: str`.
    if not isinstance(name, str):
        raise ValueError("Dataset name should be a string")
    description = data.get("description") or ""
    inputs = data.get("inputs")
    if not inputs:
        raise ValueError("input is required")
    if not isinstance(inputs, list) or not _is_all_dict(inputs):
        raise ValueError("Input should be a list containing only dictionary objects")
    outputs, metadata = data.get("outputs"), data.get("metadata")
    for k, v in {"outputs": outputs, "metadata": metadata}.items():
        # Falsy values (absent, null, empty list) mean "not provided".
        if v and not (isinstance(v, list) and len(v) == len(inputs) and _is_all_dict(v)):
            raise ValueError(
                f"{k} should be a list of same length as input containing only dictionary objects"
            )
    examples: List[ExampleContent] = [
        ExampleContent(
            input=obj,
            # A fresh empty dict per example when the optional list was not provided.
            output=outputs[i] if outputs else {},
            metadata=metadata[i] if metadata else {},
        )
        for i, obj in enumerate(inputs)
    ]
    action = DatasetAction(cast(Optional[str], data.get("action")) or "create")
    return iter(examples), action, name, description
|
|
509
589
|
|
|
510
590
|
|
|
511
591
|
async def _process_csv(
|
|
512
592
|
content: bytes,
|
|
513
593
|
content_encoding: FileContentEncoding,
|
|
514
|
-
|
|
594
|
+
input_keys: InputKeys,
|
|
595
|
+
output_keys: OutputKeys,
|
|
596
|
+
metadata_keys: MetadataKeys,
|
|
597
|
+
) -> Examples:
|
|
515
598
|
if content_encoding is FileContentEncoding.GZIP:
|
|
516
599
|
content = await run_in_threadpool(gzip.decompress, content)
|
|
517
600
|
elif content_encoding is FileContentEncoding.DEFLATE:
|
|
@@ -525,22 +608,39 @@ async def _process_csv(
|
|
|
525
608
|
if freq > 1:
|
|
526
609
|
raise ValueError(f"Duplicated column header in CSV file: {header}")
|
|
527
610
|
column_headers = frozenset(reader.fieldnames)
|
|
528
|
-
|
|
611
|
+
_check_keys_exist(column_headers, input_keys, output_keys, metadata_keys)
|
|
612
|
+
return (
|
|
613
|
+
ExampleContent(
|
|
614
|
+
input={k: row.get(k) for k in input_keys},
|
|
615
|
+
output={k: row.get(k) for k in output_keys},
|
|
616
|
+
metadata={k: row.get(k) for k in metadata_keys},
|
|
617
|
+
)
|
|
618
|
+
for row in iter(reader)
|
|
619
|
+
)
|
|
529
620
|
|
|
530
621
|
|
|
531
622
|
async def _process_pyarrow(
    content: bytes,
    input_keys: InputKeys,
    output_keys: OutputKeys,
    metadata_keys: MetadataKeys,
) -> Awaitable[Examples]:
    """Open an Arrow IPC stream and prepare its rows for conversion into examples.

    The requested keys are checked against the stream's schema up front.
    The return value is the not-yet-awaited ``run_in_threadpool`` call; the
    caller must await it to obtain the iterator of examples.

    Raises:
        ValueError: If ``content`` cannot be opened as an Arrow IPC stream.
    """
    try:
        stream_reader = pa.ipc.open_stream(content)
    except pa.ArrowInvalid as exc:
        raise ValueError("File is not valid pyarrow") from exc
    schema_columns = frozenset(stream_reader.schema.names)
    _check_keys_exist(schema_columns, input_keys, output_keys, metadata_keys)

    def get_examples() -> Iterator[ExampleContent]:
        # Generator: the stream is only read once iteration starts.
        records = stream_reader.read_pandas().to_dict(orient="records")
        for record in records:
            inputs = {key: record.get(key) for key in input_keys}
            outputs = {key: record.get(key) for key in output_keys}
            extras = {key: record.get(key) for key in metadata_keys}
            yield ExampleContent(input=inputs, output=outputs, metadata=extras)

    return run_in_threadpool(get_examples)
|
|
544
644
|
|
|
545
645
|
|
|
546
646
|
async def _check_table_exists(session: AsyncSession, name: str) -> bool:
|
|
@@ -613,7 +713,7 @@ async def get_dataset_csv(request: Request) -> Response:
|
|
|
613
713
|
type: string
|
|
614
714
|
description: Dataset ID
|
|
615
715
|
- in: query
|
|
616
|
-
name:
|
|
716
|
+
name: version_id
|
|
617
717
|
schema:
|
|
618
718
|
type: string
|
|
619
719
|
description: Dataset version ID. If omitted, returns the latest version.
|
|
@@ -662,7 +762,7 @@ async def get_dataset_jsonl_openai_ft(request: Request) -> Response:
|
|
|
662
762
|
type: string
|
|
663
763
|
description: Dataset ID
|
|
664
764
|
- in: query
|
|
665
|
-
name:
|
|
765
|
+
name: version_id
|
|
666
766
|
schema:
|
|
667
767
|
type: string
|
|
668
768
|
description: Dataset version ID. If omitted, returns the latest version.
|
|
@@ -711,7 +811,7 @@ async def get_dataset_jsonl_openai_evals(request: Request) -> Response:
|
|
|
711
811
|
type: string
|
|
712
812
|
description: Dataset ID
|
|
713
813
|
- in: query
|
|
714
|
-
name:
|
|
814
|
+
name: version_id
|
|
715
815
|
schema:
|
|
716
816
|
type: string
|
|
717
817
|
description: Dataset version ID. If omitted, returns the latest version.
|
|
@@ -815,9 +915,9 @@ async def _get_db_examples(request: Request) -> Tuple[str, List[models.DatasetEx
|
|
|
815
915
|
raise ValueError("Missing Dataset ID")
|
|
816
916
|
dataset_id = from_global_id_with_expected_type(GlobalID.from_id(id_), Dataset.__name__)
|
|
817
917
|
dataset_version_id: Optional[int] = None
|
|
818
|
-
if
|
|
918
|
+
if version_id := request.query_params.get("version_id"):
|
|
819
919
|
dataset_version_id = from_global_id_with_expected_type(
|
|
820
|
-
GlobalID.from_id(
|
|
920
|
+
GlobalID.from_id(version_id),
|
|
821
921
|
DatasetVersion.__name__,
|
|
822
922
|
)
|
|
823
923
|
latest_version = (
|
|
@@ -859,3 +959,7 @@ async def _get_db_examples(request: Request) -> Tuple[str, List[models.DatasetEx
|
|
|
859
959
|
raise ValueError("Dataset does not exist.")
|
|
860
960
|
examples = [r async for r in await session.stream_scalars(stmt)]
|
|
861
961
|
return dataset_name, examples
|
|
962
|
+
|
|
963
|
+
|
|
964
|
+
def _is_all_dict(seq: Sequence[Any]) -> bool:
|
|
965
|
+
return all(map(lambda obj: isinstance(obj, dict), seq))
|
|
@@ -45,13 +45,6 @@ async def post_evaluations(request: Request) -> Response:
|
|
|
45
45
|
operationId: addEvaluations
|
|
46
46
|
tags:
|
|
47
47
|
- private
|
|
48
|
-
parameters:
|
|
49
|
-
- name: project-name
|
|
50
|
-
in: query
|
|
51
|
-
schema:
|
|
52
|
-
type: string
|
|
53
|
-
default: default
|
|
54
|
-
description: The project name to add the evaluation to
|
|
55
48
|
requestBody:
|
|
56
49
|
required: true
|
|
57
50
|
content:
|
|
@@ -107,7 +100,7 @@ async def get_evaluations(request: Request) -> Response:
|
|
|
107
100
|
tags:
|
|
108
101
|
- private
|
|
109
102
|
parameters:
|
|
110
|
-
- name:
|
|
103
|
+
- name: project_name
|
|
111
104
|
in: query
|
|
112
105
|
schema:
|
|
113
106
|
type: string
|
|
@@ -122,9 +115,9 @@ async def get_evaluations(request: Request) -> Response:
|
|
|
122
115
|
description: Not found
|
|
123
116
|
"""
|
|
124
117
|
project_name = (
|
|
125
|
-
request.query_params.get("
|
|
126
|
-
#
|
|
127
|
-
or request.headers.get("project-name")
|
|
118
|
+
request.query_params.get("project_name")
|
|
119
|
+
or request.query_params.get("project-name") # for backward compatibility
|
|
120
|
+
or request.headers.get("project-name") # read from headers for backwards compatibility
|
|
128
121
|
or DEFAULT_PROJECT_NAME
|
|
129
122
|
)
|
|
130
123
|
|