arize-phoenix 4.5.0__py3-none-any.whl → 4.6.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of arize-phoenix might be problematic. Click here for more details.
- {arize_phoenix-4.5.0.dist-info → arize_phoenix-4.6.2.dist-info}/METADATA +16 -8
- {arize_phoenix-4.5.0.dist-info → arize_phoenix-4.6.2.dist-info}/RECORD +122 -58
- {arize_phoenix-4.5.0.dist-info → arize_phoenix-4.6.2.dist-info}/WHEEL +1 -1
- phoenix/__init__.py +0 -27
- phoenix/config.py +42 -7
- phoenix/core/model.py +25 -25
- phoenix/core/model_schema.py +64 -62
- phoenix/core/model_schema_adapter.py +27 -25
- phoenix/datetime_utils.py +4 -0
- phoenix/db/bulk_inserter.py +54 -14
- phoenix/db/insertion/dataset.py +237 -0
- phoenix/db/insertion/evaluation.py +10 -10
- phoenix/db/insertion/helpers.py +17 -14
- phoenix/db/insertion/span.py +3 -3
- phoenix/db/migrations/types.py +29 -0
- phoenix/db/migrations/versions/10460e46d750_datasets.py +291 -0
- phoenix/db/migrations/versions/cf03bd6bae1d_init.py +2 -28
- phoenix/db/models.py +236 -4
- phoenix/experiments/__init__.py +6 -0
- phoenix/experiments/evaluators/__init__.py +29 -0
- phoenix/experiments/evaluators/base.py +153 -0
- phoenix/experiments/evaluators/code_evaluators.py +99 -0
- phoenix/experiments/evaluators/llm_evaluators.py +244 -0
- phoenix/experiments/evaluators/utils.py +186 -0
- phoenix/experiments/functions.py +757 -0
- phoenix/experiments/tracing.py +85 -0
- phoenix/experiments/types.py +753 -0
- phoenix/experiments/utils.py +24 -0
- phoenix/inferences/fixtures.py +23 -23
- phoenix/inferences/inferences.py +7 -7
- phoenix/inferences/validation.py +1 -1
- phoenix/server/api/context.py +20 -0
- phoenix/server/api/dataloaders/__init__.py +20 -0
- phoenix/server/api/dataloaders/average_experiment_run_latency.py +54 -0
- phoenix/server/api/dataloaders/dataset_example_revisions.py +100 -0
- phoenix/server/api/dataloaders/dataset_example_spans.py +43 -0
- phoenix/server/api/dataloaders/experiment_annotation_summaries.py +85 -0
- phoenix/server/api/dataloaders/experiment_error_rates.py +43 -0
- phoenix/server/api/dataloaders/experiment_run_counts.py +42 -0
- phoenix/server/api/dataloaders/experiment_sequence_number.py +49 -0
- phoenix/server/api/dataloaders/project_by_name.py +31 -0
- phoenix/server/api/dataloaders/span_descendants.py +2 -3
- phoenix/server/api/dataloaders/span_projects.py +33 -0
- phoenix/server/api/dataloaders/trace_row_ids.py +39 -0
- phoenix/server/api/helpers/dataset_helpers.py +179 -0
- phoenix/server/api/input_types/AddExamplesToDatasetInput.py +16 -0
- phoenix/server/api/input_types/AddSpansToDatasetInput.py +14 -0
- phoenix/server/api/input_types/ClearProjectInput.py +15 -0
- phoenix/server/api/input_types/CreateDatasetInput.py +12 -0
- phoenix/server/api/input_types/DatasetExampleInput.py +14 -0
- phoenix/server/api/input_types/DatasetSort.py +17 -0
- phoenix/server/api/input_types/DatasetVersionSort.py +16 -0
- phoenix/server/api/input_types/DeleteDatasetExamplesInput.py +13 -0
- phoenix/server/api/input_types/DeleteDatasetInput.py +7 -0
- phoenix/server/api/input_types/DeleteExperimentsInput.py +9 -0
- phoenix/server/api/input_types/PatchDatasetExamplesInput.py +35 -0
- phoenix/server/api/input_types/PatchDatasetInput.py +14 -0
- phoenix/server/api/mutations/__init__.py +13 -0
- phoenix/server/api/mutations/auth.py +11 -0
- phoenix/server/api/mutations/dataset_mutations.py +520 -0
- phoenix/server/api/mutations/experiment_mutations.py +65 -0
- phoenix/server/api/{types/ExportEventsMutation.py → mutations/export_events_mutations.py} +17 -14
- phoenix/server/api/mutations/project_mutations.py +47 -0
- phoenix/server/api/openapi/__init__.py +0 -0
- phoenix/server/api/openapi/main.py +6 -0
- phoenix/server/api/openapi/schema.py +16 -0
- phoenix/server/api/queries.py +503 -0
- phoenix/server/api/routers/v1/__init__.py +77 -2
- phoenix/server/api/routers/v1/dataset_examples.py +178 -0
- phoenix/server/api/routers/v1/datasets.py +965 -0
- phoenix/server/api/routers/v1/evaluations.py +8 -13
- phoenix/server/api/routers/v1/experiment_evaluations.py +143 -0
- phoenix/server/api/routers/v1/experiment_runs.py +220 -0
- phoenix/server/api/routers/v1/experiments.py +302 -0
- phoenix/server/api/routers/v1/spans.py +9 -5
- phoenix/server/api/routers/v1/traces.py +1 -4
- phoenix/server/api/schema.py +2 -303
- phoenix/server/api/types/AnnotatorKind.py +10 -0
- phoenix/server/api/types/Cluster.py +19 -19
- phoenix/server/api/types/CreateDatasetPayload.py +8 -0
- phoenix/server/api/types/Dataset.py +282 -63
- phoenix/server/api/types/DatasetExample.py +85 -0
- phoenix/server/api/types/DatasetExampleRevision.py +34 -0
- phoenix/server/api/types/DatasetVersion.py +14 -0
- phoenix/server/api/types/Dimension.py +30 -29
- phoenix/server/api/types/EmbeddingDimension.py +40 -34
- phoenix/server/api/types/Event.py +16 -16
- phoenix/server/api/types/ExampleRevisionInterface.py +14 -0
- phoenix/server/api/types/Experiment.py +147 -0
- phoenix/server/api/types/ExperimentAnnotationSummary.py +13 -0
- phoenix/server/api/types/ExperimentComparison.py +19 -0
- phoenix/server/api/types/ExperimentRun.py +91 -0
- phoenix/server/api/types/ExperimentRunAnnotation.py +57 -0
- phoenix/server/api/types/Inferences.py +80 -0
- phoenix/server/api/types/InferencesRole.py +23 -0
- phoenix/server/api/types/Model.py +43 -42
- phoenix/server/api/types/Project.py +26 -12
- phoenix/server/api/types/Span.py +79 -2
- phoenix/server/api/types/TimeSeries.py +6 -6
- phoenix/server/api/types/Trace.py +15 -4
- phoenix/server/api/types/UMAPPoints.py +1 -1
- phoenix/server/api/types/node.py +5 -111
- phoenix/server/api/types/pagination.py +10 -52
- phoenix/server/app.py +103 -49
- phoenix/server/main.py +49 -27
- phoenix/server/openapi/docs.py +3 -0
- phoenix/server/static/index.js +2300 -1294
- phoenix/server/templates/index.html +1 -0
- phoenix/services.py +15 -15
- phoenix/session/client.py +581 -22
- phoenix/session/session.py +47 -37
- phoenix/trace/exporter.py +14 -9
- phoenix/trace/fixtures.py +133 -7
- phoenix/trace/schemas.py +1 -2
- phoenix/trace/span_evaluations.py +3 -3
- phoenix/trace/trace_dataset.py +6 -6
- phoenix/utilities/json.py +61 -0
- phoenix/utilities/re.py +50 -0
- phoenix/version.py +1 -1
- phoenix/server/api/types/DatasetRole.py +0 -23
- {arize_phoenix-4.5.0.dist-info → arize_phoenix-4.6.2.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-4.5.0.dist-info → arize_phoenix-4.6.2.dist-info}/licenses/LICENSE +0 -0
- /phoenix/server/api/{helpers.py → helpers/__init__.py} +0 -0
|
@@ -9,9 +9,9 @@ from phoenix.core.model_schema import PRIMARY, REFERENCE
|
|
|
9
9
|
from phoenix.server.api.context import Context
|
|
10
10
|
from phoenix.server.api.input_types.DataQualityMetricInput import DataQualityMetricInput
|
|
11
11
|
from phoenix.server.api.input_types.PerformanceMetricInput import PerformanceMetricInput
|
|
12
|
-
from phoenix.server.api.types.DatasetRole import AncillaryDatasetRole, DatasetRole
|
|
13
12
|
from phoenix.server.api.types.DatasetValues import DatasetValues
|
|
14
13
|
from phoenix.server.api.types.Event import unpack_event_id
|
|
14
|
+
from phoenix.server.api.types.InferencesRole import AncillaryInferencesRole, InferencesRole
|
|
15
15
|
|
|
16
16
|
|
|
17
17
|
@strawberry.type
|
|
@@ -36,8 +36,8 @@ class Cluster:
|
|
|
36
36
|
"""
|
|
37
37
|
Calculates the drift score of the cluster. The score will be a value
|
|
38
38
|
representing the balance of points between the primary and the reference
|
|
39
|
-
|
|
40
|
-
reference), with 0 being an even balance between the two
|
|
39
|
+
inferences, and will be on a scale between 1 (all primary) and -1 (all
|
|
40
|
+
reference), with 0 being an even balance between the two inference sets.
|
|
41
41
|
|
|
42
42
|
Returns
|
|
43
43
|
-------
|
|
@@ -47,8 +47,8 @@ class Cluster:
|
|
|
47
47
|
if model[REFERENCE].empty:
|
|
48
48
|
return None
|
|
49
49
|
count_by_role = Counter(unpack_event_id(event_id)[1] for event_id in self.event_ids)
|
|
50
|
-
primary_count = count_by_role[
|
|
51
|
-
reference_count = count_by_role[
|
|
50
|
+
primary_count = count_by_role[InferencesRole.primary]
|
|
51
|
+
reference_count = count_by_role[InferencesRole.reference]
|
|
52
52
|
return (
|
|
53
53
|
None
|
|
54
54
|
if not (denominator := (primary_count + reference_count))
|
|
@@ -76,8 +76,8 @@ class Cluster:
|
|
|
76
76
|
if corpus is None or corpus[PRIMARY].empty:
|
|
77
77
|
return None
|
|
78
78
|
count_by_role = Counter(unpack_event_id(event_id)[1] for event_id in self.event_ids)
|
|
79
|
-
primary_count = count_by_role[
|
|
80
|
-
corpus_count = count_by_role[
|
|
79
|
+
primary_count = count_by_role[InferencesRole.primary]
|
|
80
|
+
corpus_count = count_by_role[AncillaryInferencesRole.corpus]
|
|
81
81
|
return (
|
|
82
82
|
None
|
|
83
83
|
if not (denominator := (primary_count + corpus_count))
|
|
@@ -94,19 +94,19 @@ class Cluster:
|
|
|
94
94
|
metric: DataQualityMetricInput,
|
|
95
95
|
) -> DatasetValues:
|
|
96
96
|
model = info.context.model
|
|
97
|
-
row_ids: Dict[
|
|
98
|
-
for row_id,
|
|
99
|
-
if not isinstance(
|
|
97
|
+
row_ids: Dict[InferencesRole, List[int]] = defaultdict(list)
|
|
98
|
+
for row_id, inferences_role in map(unpack_event_id, self.event_ids):
|
|
99
|
+
if not isinstance(inferences_role, InferencesRole):
|
|
100
100
|
continue
|
|
101
|
-
row_ids[
|
|
101
|
+
row_ids[inferences_role].append(row_id)
|
|
102
102
|
return DatasetValues(
|
|
103
103
|
primary_value=metric.metric_instance(
|
|
104
104
|
model[PRIMARY],
|
|
105
|
-
subset_rows=row_ids[
|
|
105
|
+
subset_rows=row_ids[InferencesRole.primary],
|
|
106
106
|
),
|
|
107
107
|
reference_value=metric.metric_instance(
|
|
108
108
|
model[REFERENCE],
|
|
109
|
-
subset_rows=row_ids[
|
|
109
|
+
subset_rows=row_ids[InferencesRole.reference],
|
|
110
110
|
),
|
|
111
111
|
)
|
|
112
112
|
|
|
@@ -120,20 +120,20 @@ class Cluster:
|
|
|
120
120
|
metric: PerformanceMetricInput,
|
|
121
121
|
) -> DatasetValues:
|
|
122
122
|
model = info.context.model
|
|
123
|
-
row_ids: Dict[
|
|
124
|
-
for row_id,
|
|
125
|
-
if not isinstance(
|
|
123
|
+
row_ids: Dict[InferencesRole, List[int]] = defaultdict(list)
|
|
124
|
+
for row_id, inferences_role in map(unpack_event_id, self.event_ids):
|
|
125
|
+
if not isinstance(inferences_role, InferencesRole):
|
|
126
126
|
continue
|
|
127
|
-
row_ids[
|
|
127
|
+
row_ids[inferences_role].append(row_id)
|
|
128
128
|
metric_instance = metric.metric_instance(model)
|
|
129
129
|
return DatasetValues(
|
|
130
130
|
primary_value=metric_instance(
|
|
131
131
|
model[PRIMARY],
|
|
132
|
-
subset_rows=row_ids[
|
|
132
|
+
subset_rows=row_ids[InferencesRole.primary],
|
|
133
133
|
),
|
|
134
134
|
reference_value=metric_instance(
|
|
135
135
|
model[REFERENCE],
|
|
136
|
-
subset_rows=row_ids[
|
|
136
|
+
subset_rows=row_ids[InferencesRole.reference],
|
|
137
137
|
),
|
|
138
138
|
)
|
|
139
139
|
|
|
@@ -1,80 +1,299 @@
|
|
|
1
1
|
from datetime import datetime
|
|
2
|
-
from typing import
|
|
2
|
+
from typing import AsyncIterable, List, Optional, Tuple, cast
|
|
3
3
|
|
|
4
4
|
import strawberry
|
|
5
|
-
from
|
|
6
|
-
from
|
|
5
|
+
from sqlalchemy import and_, func, select
|
|
6
|
+
from sqlalchemy.sql.functions import count
|
|
7
|
+
from strawberry import UNSET
|
|
8
|
+
from strawberry.relay import Connection, GlobalID, Node, NodeID
|
|
9
|
+
from strawberry.scalars import JSON
|
|
10
|
+
from strawberry.types import Info
|
|
7
11
|
|
|
8
|
-
|
|
9
|
-
from phoenix.
|
|
10
|
-
|
|
11
|
-
from
|
|
12
|
-
from .
|
|
13
|
-
from .
|
|
14
|
-
from .
|
|
12
|
+
from phoenix.db import models
|
|
13
|
+
from phoenix.server.api.context import Context
|
|
14
|
+
from phoenix.server.api.input_types.DatasetVersionSort import DatasetVersionSort
|
|
15
|
+
from phoenix.server.api.types.DatasetExample import DatasetExample
|
|
16
|
+
from phoenix.server.api.types.DatasetVersion import DatasetVersion
|
|
17
|
+
from phoenix.server.api.types.Experiment import Experiment, to_gql_experiment
|
|
18
|
+
from phoenix.server.api.types.ExperimentAnnotationSummary import ExperimentAnnotationSummary
|
|
19
|
+
from phoenix.server.api.types.node import from_global_id_with_expected_type
|
|
20
|
+
from phoenix.server.api.types.pagination import (
|
|
21
|
+
ConnectionArgs,
|
|
22
|
+
CursorString,
|
|
23
|
+
connection_from_list,
|
|
24
|
+
)
|
|
25
|
+
from phoenix.server.api.types.SortDir import SortDir
|
|
15
26
|
|
|
16
27
|
|
|
17
28
|
@strawberry.type
|
|
18
|
-
class Dataset:
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
# type ignored here to get around the following: https://github.com/strawberry-graphql/strawberry/issues/1929
|
|
27
|
-
@strawberry.field(description="Returns a human friendly name for the dataset.") # type: ignore
|
|
28
|
-
def name(self) -> str:
|
|
29
|
-
return self.dataset.display_name
|
|
29
|
+
class Dataset(Node):
|
|
30
|
+
id_attr: NodeID[int]
|
|
31
|
+
name: str
|
|
32
|
+
description: Optional[str]
|
|
33
|
+
metadata: JSON
|
|
34
|
+
created_at: datetime
|
|
35
|
+
updated_at: datetime
|
|
30
36
|
|
|
31
37
|
@strawberry.field
|
|
32
|
-
def
|
|
38
|
+
async def versions(
|
|
33
39
|
self,
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
events = self.dataset[row_ids[self.dataset_role]]
|
|
47
|
-
requested_gql_dimensions = _get_requested_features_and_tags(
|
|
48
|
-
core_dimensions=self.model.scalar_dimensions,
|
|
49
|
-
requested_dimension_names=set(dim.name for dim in dimensions)
|
|
50
|
-
if isinstance(dimensions, list)
|
|
51
|
-
else None,
|
|
40
|
+
info: Info[Context, None],
|
|
41
|
+
first: Optional[int] = 50,
|
|
42
|
+
last: Optional[int] = UNSET,
|
|
43
|
+
after: Optional[CursorString] = UNSET,
|
|
44
|
+
before: Optional[CursorString] = UNSET,
|
|
45
|
+
sort: Optional[DatasetVersionSort] = UNSET,
|
|
46
|
+
) -> Connection[DatasetVersion]:
|
|
47
|
+
args = ConnectionArgs(
|
|
48
|
+
first=first,
|
|
49
|
+
after=after if isinstance(after, CursorString) else None,
|
|
50
|
+
last=last,
|
|
51
|
+
before=before if isinstance(before, CursorString) else None,
|
|
52
52
|
)
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
53
|
+
async with info.context.db() as session:
|
|
54
|
+
stmt = select(models.DatasetVersion).filter_by(dataset_id=self.id_attr)
|
|
55
|
+
if sort:
|
|
56
|
+
# For now assume the the column names match 1:1 with the enum values
|
|
57
|
+
sort_col = getattr(models.DatasetVersion, sort.col.value)
|
|
58
|
+
if sort.dir is SortDir.desc:
|
|
59
|
+
stmt = stmt.order_by(sort_col.desc(), models.DatasetVersion.id.desc())
|
|
60
|
+
else:
|
|
61
|
+
stmt = stmt.order_by(sort_col.asc(), models.DatasetVersion.id.asc())
|
|
62
|
+
else:
|
|
63
|
+
stmt = stmt.order_by(models.DatasetVersion.created_at.desc())
|
|
64
|
+
versions = await session.scalars(stmt)
|
|
65
|
+
data = [
|
|
66
|
+
DatasetVersion(
|
|
67
|
+
id_attr=version.id,
|
|
68
|
+
description=version.description,
|
|
69
|
+
metadata=version.metadata_,
|
|
70
|
+
created_at=version.created_at,
|
|
59
71
|
)
|
|
60
|
-
for
|
|
72
|
+
for version in versions
|
|
61
73
|
]
|
|
74
|
+
return connection_from_list(data=data, args=args)
|
|
75
|
+
|
|
76
|
+
@strawberry.field(
|
|
77
|
+
description="Number of examples in a specific version if version is specified, or in the "
|
|
78
|
+
"latest version if version is not specified."
|
|
79
|
+
) # type: ignore
|
|
80
|
+
async def example_count(
|
|
81
|
+
self,
|
|
82
|
+
info: Info[Context, None],
|
|
83
|
+
dataset_version_id: Optional[GlobalID] = UNSET,
|
|
84
|
+
) -> int:
|
|
85
|
+
dataset_id = self.id_attr
|
|
86
|
+
version_id = (
|
|
87
|
+
from_global_id_with_expected_type(
|
|
88
|
+
global_id=dataset_version_id,
|
|
89
|
+
expected_type_name=DatasetVersion.__name__,
|
|
90
|
+
)
|
|
91
|
+
if dataset_version_id
|
|
92
|
+
else None
|
|
93
|
+
)
|
|
94
|
+
revision_ids = (
|
|
95
|
+
select(func.max(models.DatasetExampleRevision.id))
|
|
96
|
+
.join(models.DatasetExample)
|
|
97
|
+
.where(models.DatasetExample.dataset_id == dataset_id)
|
|
98
|
+
.group_by(models.DatasetExampleRevision.dataset_example_id)
|
|
99
|
+
)
|
|
100
|
+
if version_id:
|
|
101
|
+
version_id_subquery = (
|
|
102
|
+
select(models.DatasetVersion.id)
|
|
103
|
+
.where(models.DatasetVersion.dataset_id == dataset_id)
|
|
104
|
+
.where(models.DatasetVersion.id == version_id)
|
|
105
|
+
.scalar_subquery()
|
|
106
|
+
)
|
|
107
|
+
revision_ids = revision_ids.where(
|
|
108
|
+
models.DatasetExampleRevision.dataset_version_id <= version_id_subquery
|
|
109
|
+
)
|
|
110
|
+
stmt = (
|
|
111
|
+
select(count(models.DatasetExampleRevision.id))
|
|
112
|
+
.where(models.DatasetExampleRevision.id.in_(revision_ids))
|
|
113
|
+
.where(models.DatasetExampleRevision.revision_kind != "DELETE")
|
|
114
|
+
)
|
|
115
|
+
async with info.context.db() as session:
|
|
116
|
+
return (await session.scalar(stmt)) or 0
|
|
117
|
+
|
|
118
|
+
@strawberry.field
|
|
119
|
+
async def examples(
|
|
120
|
+
self,
|
|
121
|
+
info: Info[Context, None],
|
|
122
|
+
dataset_version_id: Optional[GlobalID] = UNSET,
|
|
123
|
+
first: Optional[int] = 50,
|
|
124
|
+
last: Optional[int] = UNSET,
|
|
125
|
+
after: Optional[CursorString] = UNSET,
|
|
126
|
+
before: Optional[CursorString] = UNSET,
|
|
127
|
+
) -> Connection[DatasetExample]:
|
|
128
|
+
args = ConnectionArgs(
|
|
129
|
+
first=first,
|
|
130
|
+
after=after if isinstance(after, CursorString) else None,
|
|
131
|
+
last=last,
|
|
132
|
+
before=before if isinstance(before, CursorString) else None,
|
|
133
|
+
)
|
|
134
|
+
dataset_id = self.id_attr
|
|
135
|
+
version_id = (
|
|
136
|
+
from_global_id_with_expected_type(
|
|
137
|
+
global_id=dataset_version_id, expected_type_name=DatasetVersion.__name__
|
|
138
|
+
)
|
|
139
|
+
if dataset_version_id
|
|
140
|
+
else None
|
|
141
|
+
)
|
|
142
|
+
revision_ids = (
|
|
143
|
+
select(func.max(models.DatasetExampleRevision.id))
|
|
144
|
+
.join(models.DatasetExample)
|
|
145
|
+
.where(models.DatasetExample.dataset_id == dataset_id)
|
|
146
|
+
.group_by(models.DatasetExampleRevision.dataset_example_id)
|
|
147
|
+
)
|
|
148
|
+
if version_id:
|
|
149
|
+
version_id_subquery = (
|
|
150
|
+
select(models.DatasetVersion.id)
|
|
151
|
+
.where(models.DatasetVersion.dataset_id == dataset_id)
|
|
152
|
+
.where(models.DatasetVersion.id == version_id)
|
|
153
|
+
.scalar_subquery()
|
|
154
|
+
)
|
|
155
|
+
revision_ids = revision_ids.where(
|
|
156
|
+
models.DatasetExampleRevision.dataset_version_id <= version_id_subquery
|
|
157
|
+
)
|
|
158
|
+
query = (
|
|
159
|
+
select(models.DatasetExample)
|
|
160
|
+
.join(
|
|
161
|
+
models.DatasetExampleRevision,
|
|
162
|
+
onclause=models.DatasetExample.id
|
|
163
|
+
== models.DatasetExampleRevision.dataset_example_id,
|
|
164
|
+
)
|
|
165
|
+
.where(
|
|
166
|
+
and_(
|
|
167
|
+
models.DatasetExampleRevision.id.in_(revision_ids),
|
|
168
|
+
models.DatasetExampleRevision.revision_kind != "DELETE",
|
|
169
|
+
)
|
|
170
|
+
)
|
|
171
|
+
.order_by(models.DatasetExampleRevision.dataset_example_id.desc())
|
|
172
|
+
)
|
|
173
|
+
async with info.context.db() as session:
|
|
174
|
+
dataset_examples = [
|
|
175
|
+
DatasetExample(
|
|
176
|
+
id_attr=example.id,
|
|
177
|
+
version_id=version_id,
|
|
178
|
+
created_at=example.created_at,
|
|
179
|
+
)
|
|
180
|
+
async for example in await session.stream_scalars(query)
|
|
181
|
+
]
|
|
182
|
+
return connection_from_list(data=dataset_examples, args=args)
|
|
183
|
+
|
|
184
|
+
@strawberry.field(
|
|
185
|
+
description="Number of experiments for a specific version if version is specified, "
|
|
186
|
+
"or for all versions if version is not specified."
|
|
187
|
+
) # type: ignore
|
|
188
|
+
async def experiment_count(
|
|
189
|
+
self,
|
|
190
|
+
info: Info[Context, None],
|
|
191
|
+
dataset_version_id: Optional[GlobalID] = UNSET,
|
|
192
|
+
) -> int:
|
|
193
|
+
stmt = select(count(models.Experiment.id)).where(
|
|
194
|
+
models.Experiment.dataset_id == self.id_attr
|
|
195
|
+
)
|
|
196
|
+
version_id = (
|
|
197
|
+
from_global_id_with_expected_type(
|
|
198
|
+
global_id=dataset_version_id,
|
|
199
|
+
expected_type_name=DatasetVersion.__name__,
|
|
200
|
+
)
|
|
201
|
+
if dataset_version_id
|
|
202
|
+
else None
|
|
203
|
+
)
|
|
204
|
+
if version_id is not None:
|
|
205
|
+
stmt = stmt.where(models.Experiment.dataset_version_id == version_id)
|
|
206
|
+
async with info.context.db() as session:
|
|
207
|
+
return (await session.scalar(stmt)) or 0
|
|
208
|
+
|
|
209
|
+
@strawberry.field
|
|
210
|
+
async def experiments(
|
|
211
|
+
self,
|
|
212
|
+
info: Info[Context, None],
|
|
213
|
+
first: Optional[int] = 50,
|
|
214
|
+
last: Optional[int] = UNSET,
|
|
215
|
+
after: Optional[CursorString] = UNSET,
|
|
216
|
+
before: Optional[CursorString] = UNSET,
|
|
217
|
+
) -> Connection[Experiment]:
|
|
218
|
+
args = ConnectionArgs(
|
|
219
|
+
first=first,
|
|
220
|
+
after=after if isinstance(after, CursorString) else None,
|
|
221
|
+
last=last,
|
|
222
|
+
before=before if isinstance(before, CursorString) else None,
|
|
223
|
+
)
|
|
224
|
+
dataset_id = self.id_attr
|
|
225
|
+
row_number = func.row_number().over(order_by=models.Experiment.id).label("row_number")
|
|
226
|
+
query = (
|
|
227
|
+
select(models.Experiment, row_number)
|
|
228
|
+
.where(models.Experiment.dataset_id == dataset_id)
|
|
229
|
+
.order_by(models.Experiment.id.desc())
|
|
230
|
+
)
|
|
231
|
+
async with info.context.db() as session:
|
|
232
|
+
experiments = [
|
|
233
|
+
to_gql_experiment(experiment, sequence_number)
|
|
234
|
+
async for experiment, sequence_number in cast(
|
|
235
|
+
AsyncIterable[Tuple[models.Experiment, int]],
|
|
236
|
+
await session.stream(query),
|
|
237
|
+
)
|
|
238
|
+
]
|
|
239
|
+
return connection_from_list(data=experiments, args=args)
|
|
240
|
+
|
|
241
|
+
@strawberry.field
|
|
242
|
+
async def experiment_annotation_summaries(
|
|
243
|
+
self, info: Info[Context, None]
|
|
244
|
+
) -> List[ExperimentAnnotationSummary]:
|
|
245
|
+
dataset_id = self.id_attr
|
|
246
|
+
query = (
|
|
247
|
+
select(
|
|
248
|
+
models.ExperimentRunAnnotation.name,
|
|
249
|
+
func.min(models.ExperimentRunAnnotation.score),
|
|
250
|
+
func.max(models.ExperimentRunAnnotation.score),
|
|
251
|
+
func.avg(models.ExperimentRunAnnotation.score),
|
|
252
|
+
func.count(),
|
|
253
|
+
func.count(models.ExperimentRunAnnotation.error),
|
|
254
|
+
)
|
|
255
|
+
.join(
|
|
256
|
+
models.ExperimentRun,
|
|
257
|
+
models.ExperimentRunAnnotation.experiment_run_id == models.ExperimentRun.id,
|
|
258
|
+
)
|
|
259
|
+
.join(
|
|
260
|
+
models.Experiment,
|
|
261
|
+
models.ExperimentRun.experiment_id == models.Experiment.id,
|
|
262
|
+
)
|
|
263
|
+
.where(models.Experiment.dataset_id == dataset_id)
|
|
264
|
+
.group_by(models.ExperimentRunAnnotation.name)
|
|
265
|
+
.order_by(models.ExperimentRunAnnotation.name)
|
|
266
|
+
)
|
|
267
|
+
async with info.context.db() as session:
|
|
268
|
+
return [
|
|
269
|
+
ExperimentAnnotationSummary(
|
|
270
|
+
annotation_name=annotation_name,
|
|
271
|
+
min_score=min_score,
|
|
272
|
+
max_score=max_score,
|
|
273
|
+
mean_score=mean_score,
|
|
274
|
+
count=count_,
|
|
275
|
+
error_count=error_count,
|
|
276
|
+
)
|
|
277
|
+
async for (
|
|
278
|
+
annotation_name,
|
|
279
|
+
min_score,
|
|
280
|
+
max_score,
|
|
281
|
+
mean_score,
|
|
282
|
+
count_,
|
|
283
|
+
error_count,
|
|
284
|
+
) in await session.stream(query)
|
|
285
|
+
]
|
|
62
286
|
|
|
63
287
|
|
|
64
|
-
def
|
|
65
|
-
core_dimensions: Iterable[ScalarDimension],
|
|
66
|
-
requested_dimension_names: Optional[Set[str]] = UNSET,
|
|
67
|
-
) -> List[Dimension]:
|
|
288
|
+
def to_gql_dataset(dataset: models.Dataset) -> Dataset:
|
|
68
289
|
"""
|
|
69
|
-
|
|
70
|
-
dimensions are explicitly requested, returns all features and tags.
|
|
290
|
+
Converts an ORM dataset to a GraphQL dataset.
|
|
71
291
|
"""
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
return requested_features_and_tags
|
|
292
|
+
return Dataset(
|
|
293
|
+
id_attr=dataset.id,
|
|
294
|
+
name=dataset.name,
|
|
295
|
+
description=dataset.description,
|
|
296
|
+
metadata=dataset.metadata_,
|
|
297
|
+
created_at=dataset.created_at,
|
|
298
|
+
updated_at=dataset.updated_at,
|
|
299
|
+
)
|
|
@@ -0,0 +1,85 @@
|
|
|
1
|
+
from datetime import datetime
|
|
2
|
+
from typing import Optional
|
|
3
|
+
|
|
4
|
+
import strawberry
|
|
5
|
+
from sqlalchemy import select
|
|
6
|
+
from sqlalchemy.orm import joinedload
|
|
7
|
+
from strawberry import UNSET
|
|
8
|
+
from strawberry.relay.types import Connection, GlobalID, Node, NodeID
|
|
9
|
+
from strawberry.types import Info
|
|
10
|
+
|
|
11
|
+
from phoenix.db import models
|
|
12
|
+
from phoenix.server.api.context import Context
|
|
13
|
+
from phoenix.server.api.types.DatasetExampleRevision import DatasetExampleRevision
|
|
14
|
+
from phoenix.server.api.types.DatasetVersion import DatasetVersion
|
|
15
|
+
from phoenix.server.api.types.ExperimentRun import ExperimentRun, to_gql_experiment_run
|
|
16
|
+
from phoenix.server.api.types.node import from_global_id_with_expected_type
|
|
17
|
+
from phoenix.server.api.types.pagination import (
|
|
18
|
+
ConnectionArgs,
|
|
19
|
+
CursorString,
|
|
20
|
+
connection_from_list,
|
|
21
|
+
)
|
|
22
|
+
from phoenix.server.api.types.Span import Span, to_gql_span
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@strawberry.type
|
|
26
|
+
class DatasetExample(Node):
|
|
27
|
+
id_attr: NodeID[int]
|
|
28
|
+
created_at: datetime
|
|
29
|
+
version_id: strawberry.Private[Optional[int]] = None
|
|
30
|
+
|
|
31
|
+
@strawberry.field
|
|
32
|
+
async def revision(
|
|
33
|
+
self,
|
|
34
|
+
info: Info[Context, None],
|
|
35
|
+
dataset_version_id: Optional[GlobalID] = UNSET,
|
|
36
|
+
) -> DatasetExampleRevision:
|
|
37
|
+
example_id = self.id_attr
|
|
38
|
+
version_id: Optional[int] = None
|
|
39
|
+
if dataset_version_id:
|
|
40
|
+
version_id = from_global_id_with_expected_type(
|
|
41
|
+
global_id=dataset_version_id, expected_type_name=DatasetVersion.__name__
|
|
42
|
+
)
|
|
43
|
+
elif self.version_id is not None:
|
|
44
|
+
version_id = self.version_id
|
|
45
|
+
return await info.context.data_loaders.dataset_example_revisions.load(
|
|
46
|
+
(example_id, version_id)
|
|
47
|
+
)
|
|
48
|
+
|
|
49
|
+
@strawberry.field
|
|
50
|
+
async def span(
|
|
51
|
+
self,
|
|
52
|
+
info: Info[Context, None],
|
|
53
|
+
) -> Optional[Span]:
|
|
54
|
+
return (
|
|
55
|
+
to_gql_span(span)
|
|
56
|
+
if (span := await info.context.data_loaders.dataset_example_spans.load(self.id_attr))
|
|
57
|
+
else None
|
|
58
|
+
)
|
|
59
|
+
|
|
60
|
+
@strawberry.field
|
|
61
|
+
async def experiment_runs(
|
|
62
|
+
self,
|
|
63
|
+
info: Info[Context, None],
|
|
64
|
+
first: Optional[int] = 50,
|
|
65
|
+
last: Optional[int] = UNSET,
|
|
66
|
+
after: Optional[CursorString] = UNSET,
|
|
67
|
+
before: Optional[CursorString] = UNSET,
|
|
68
|
+
) -> Connection[ExperimentRun]:
|
|
69
|
+
args = ConnectionArgs(
|
|
70
|
+
first=first,
|
|
71
|
+
after=after if isinstance(after, CursorString) else None,
|
|
72
|
+
last=last,
|
|
73
|
+
before=before if isinstance(before, CursorString) else None,
|
|
74
|
+
)
|
|
75
|
+
example_id = self.id_attr
|
|
76
|
+
query = (
|
|
77
|
+
select(models.ExperimentRun)
|
|
78
|
+
.options(joinedload(models.ExperimentRun.trace).load_only(models.Trace.trace_id))
|
|
79
|
+
.join(models.Experiment, models.Experiment.id == models.ExperimentRun.experiment_id)
|
|
80
|
+
.where(models.ExperimentRun.dataset_example_id == example_id)
|
|
81
|
+
.order_by(models.Experiment.id.desc())
|
|
82
|
+
)
|
|
83
|
+
async with info.context.db() as session:
|
|
84
|
+
runs = (await session.scalars(query)).all()
|
|
85
|
+
return connection_from_list([to_gql_experiment_run(run) for run in runs], args)
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
from datetime import datetime
|
|
2
|
+
from enum import Enum
|
|
3
|
+
|
|
4
|
+
import strawberry
|
|
5
|
+
|
|
6
|
+
from phoenix.db import models
|
|
7
|
+
from phoenix.server.api.types.ExampleRevisionInterface import ExampleRevision
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@strawberry.enum
|
|
11
|
+
class RevisionKind(Enum):
|
|
12
|
+
CREATE = "CREATE"
|
|
13
|
+
PATCH = "PATCH"
|
|
14
|
+
DELETE = "DELETE"
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@strawberry.type
|
|
18
|
+
class DatasetExampleRevision(ExampleRevision):
|
|
19
|
+
"""
|
|
20
|
+
Represents a revision (i.e., update or alteration) of a dataset example.
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
revision_kind: RevisionKind
|
|
24
|
+
created_at: datetime
|
|
25
|
+
|
|
26
|
+
@classmethod
|
|
27
|
+
def from_orm_revision(cls, revision: models.DatasetExampleRevision) -> "DatasetExampleRevision":
|
|
28
|
+
return cls(
|
|
29
|
+
input=revision.input,
|
|
30
|
+
output=revision.output,
|
|
31
|
+
metadata=revision.metadata_,
|
|
32
|
+
revision_kind=RevisionKind(revision.revision_kind),
|
|
33
|
+
created_at=revision.created_at,
|
|
34
|
+
)
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
from datetime import datetime
|
|
2
|
+
from typing import Optional
|
|
3
|
+
|
|
4
|
+
import strawberry
|
|
5
|
+
from strawberry.relay import Node, NodeID
|
|
6
|
+
from strawberry.scalars import JSON
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@strawberry.type
|
|
10
|
+
class DatasetVersion(Node):
|
|
11
|
+
id_attr: NodeID[int]
|
|
12
|
+
description: Optional[str]
|
|
13
|
+
metadata: JSON
|
|
14
|
+
created_at: datetime
|