arize-phoenix 4.5.0__py3-none-any.whl → 4.6.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of arize-phoenix might be problematic.
- {arize_phoenix-4.5.0.dist-info → arize_phoenix-4.6.2.dist-info}/METADATA +16 -8
- {arize_phoenix-4.5.0.dist-info → arize_phoenix-4.6.2.dist-info}/RECORD +122 -58
- {arize_phoenix-4.5.0.dist-info → arize_phoenix-4.6.2.dist-info}/WHEEL +1 -1
- phoenix/__init__.py +0 -27
- phoenix/config.py +42 -7
- phoenix/core/model.py +25 -25
- phoenix/core/model_schema.py +64 -62
- phoenix/core/model_schema_adapter.py +27 -25
- phoenix/datetime_utils.py +4 -0
- phoenix/db/bulk_inserter.py +54 -14
- phoenix/db/insertion/dataset.py +237 -0
- phoenix/db/insertion/evaluation.py +10 -10
- phoenix/db/insertion/helpers.py +17 -14
- phoenix/db/insertion/span.py +3 -3
- phoenix/db/migrations/types.py +29 -0
- phoenix/db/migrations/versions/10460e46d750_datasets.py +291 -0
- phoenix/db/migrations/versions/cf03bd6bae1d_init.py +2 -28
- phoenix/db/models.py +236 -4
- phoenix/experiments/__init__.py +6 -0
- phoenix/experiments/evaluators/__init__.py +29 -0
- phoenix/experiments/evaluators/base.py +153 -0
- phoenix/experiments/evaluators/code_evaluators.py +99 -0
- phoenix/experiments/evaluators/llm_evaluators.py +244 -0
- phoenix/experiments/evaluators/utils.py +186 -0
- phoenix/experiments/functions.py +757 -0
- phoenix/experiments/tracing.py +85 -0
- phoenix/experiments/types.py +753 -0
- phoenix/experiments/utils.py +24 -0
- phoenix/inferences/fixtures.py +23 -23
- phoenix/inferences/inferences.py +7 -7
- phoenix/inferences/validation.py +1 -1
- phoenix/server/api/context.py +20 -0
- phoenix/server/api/dataloaders/__init__.py +20 -0
- phoenix/server/api/dataloaders/average_experiment_run_latency.py +54 -0
- phoenix/server/api/dataloaders/dataset_example_revisions.py +100 -0
- phoenix/server/api/dataloaders/dataset_example_spans.py +43 -0
- phoenix/server/api/dataloaders/experiment_annotation_summaries.py +85 -0
- phoenix/server/api/dataloaders/experiment_error_rates.py +43 -0
- phoenix/server/api/dataloaders/experiment_run_counts.py +42 -0
- phoenix/server/api/dataloaders/experiment_sequence_number.py +49 -0
- phoenix/server/api/dataloaders/project_by_name.py +31 -0
- phoenix/server/api/dataloaders/span_descendants.py +2 -3
- phoenix/server/api/dataloaders/span_projects.py +33 -0
- phoenix/server/api/dataloaders/trace_row_ids.py +39 -0
- phoenix/server/api/helpers/dataset_helpers.py +179 -0
- phoenix/server/api/input_types/AddExamplesToDatasetInput.py +16 -0
- phoenix/server/api/input_types/AddSpansToDatasetInput.py +14 -0
- phoenix/server/api/input_types/ClearProjectInput.py +15 -0
- phoenix/server/api/input_types/CreateDatasetInput.py +12 -0
- phoenix/server/api/input_types/DatasetExampleInput.py +14 -0
- phoenix/server/api/input_types/DatasetSort.py +17 -0
- phoenix/server/api/input_types/DatasetVersionSort.py +16 -0
- phoenix/server/api/input_types/DeleteDatasetExamplesInput.py +13 -0
- phoenix/server/api/input_types/DeleteDatasetInput.py +7 -0
- phoenix/server/api/input_types/DeleteExperimentsInput.py +9 -0
- phoenix/server/api/input_types/PatchDatasetExamplesInput.py +35 -0
- phoenix/server/api/input_types/PatchDatasetInput.py +14 -0
- phoenix/server/api/mutations/__init__.py +13 -0
- phoenix/server/api/mutations/auth.py +11 -0
- phoenix/server/api/mutations/dataset_mutations.py +520 -0
- phoenix/server/api/mutations/experiment_mutations.py +65 -0
- phoenix/server/api/{types/ExportEventsMutation.py → mutations/export_events_mutations.py} +17 -14
- phoenix/server/api/mutations/project_mutations.py +47 -0
- phoenix/server/api/openapi/__init__.py +0 -0
- phoenix/server/api/openapi/main.py +6 -0
- phoenix/server/api/openapi/schema.py +16 -0
- phoenix/server/api/queries.py +503 -0
- phoenix/server/api/routers/v1/__init__.py +77 -2
- phoenix/server/api/routers/v1/dataset_examples.py +178 -0
- phoenix/server/api/routers/v1/datasets.py +965 -0
- phoenix/server/api/routers/v1/evaluations.py +8 -13
- phoenix/server/api/routers/v1/experiment_evaluations.py +143 -0
- phoenix/server/api/routers/v1/experiment_runs.py +220 -0
- phoenix/server/api/routers/v1/experiments.py +302 -0
- phoenix/server/api/routers/v1/spans.py +9 -5
- phoenix/server/api/routers/v1/traces.py +1 -4
- phoenix/server/api/schema.py +2 -303
- phoenix/server/api/types/AnnotatorKind.py +10 -0
- phoenix/server/api/types/Cluster.py +19 -19
- phoenix/server/api/types/CreateDatasetPayload.py +8 -0
- phoenix/server/api/types/Dataset.py +282 -63
- phoenix/server/api/types/DatasetExample.py +85 -0
- phoenix/server/api/types/DatasetExampleRevision.py +34 -0
- phoenix/server/api/types/DatasetVersion.py +14 -0
- phoenix/server/api/types/Dimension.py +30 -29
- phoenix/server/api/types/EmbeddingDimension.py +40 -34
- phoenix/server/api/types/Event.py +16 -16
- phoenix/server/api/types/ExampleRevisionInterface.py +14 -0
- phoenix/server/api/types/Experiment.py +147 -0
- phoenix/server/api/types/ExperimentAnnotationSummary.py +13 -0
- phoenix/server/api/types/ExperimentComparison.py +19 -0
- phoenix/server/api/types/ExperimentRun.py +91 -0
- phoenix/server/api/types/ExperimentRunAnnotation.py +57 -0
- phoenix/server/api/types/Inferences.py +80 -0
- phoenix/server/api/types/InferencesRole.py +23 -0
- phoenix/server/api/types/Model.py +43 -42
- phoenix/server/api/types/Project.py +26 -12
- phoenix/server/api/types/Span.py +79 -2
- phoenix/server/api/types/TimeSeries.py +6 -6
- phoenix/server/api/types/Trace.py +15 -4
- phoenix/server/api/types/UMAPPoints.py +1 -1
- phoenix/server/api/types/node.py +5 -111
- phoenix/server/api/types/pagination.py +10 -52
- phoenix/server/app.py +103 -49
- phoenix/server/main.py +49 -27
- phoenix/server/openapi/docs.py +3 -0
- phoenix/server/static/index.js +2300 -1294
- phoenix/server/templates/index.html +1 -0
- phoenix/services.py +15 -15
- phoenix/session/client.py +581 -22
- phoenix/session/session.py +47 -37
- phoenix/trace/exporter.py +14 -9
- phoenix/trace/fixtures.py +133 -7
- phoenix/trace/schemas.py +1 -2
- phoenix/trace/span_evaluations.py +3 -3
- phoenix/trace/trace_dataset.py +6 -6
- phoenix/utilities/json.py +61 -0
- phoenix/utilities/re.py +50 -0
- phoenix/version.py +1 -1
- phoenix/server/api/types/DatasetRole.py +0 -23
- {arize_phoenix-4.5.0.dist-info → arize_phoenix-4.6.2.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-4.5.0.dist-info → arize_phoenix-4.6.2.dist-info}/licenses/LICENSE +0 -0
- /phoenix/server/api/{helpers.py → helpers/__init__.py} +0 -0
--- /dev/null
+++ b/phoenix/server/api/mutations/project_mutations.py
@@ -0,0 +1,47 @@
+import strawberry
+from sqlalchemy import delete, select
+from sqlalchemy.orm import load_only
+from strawberry.relay import GlobalID
+from strawberry.types import Info
+
+from phoenix.config import DEFAULT_PROJECT_NAME
+from phoenix.db import models
+from phoenix.db.insertion.span import ClearProjectSpansEvent
+from phoenix.server.api.context import Context
+from phoenix.server.api.input_types.ClearProjectInput import ClearProjectInput
+from phoenix.server.api.mutations.auth import IsAuthenticated
+from phoenix.server.api.queries import Query
+from phoenix.server.api.types.node import from_global_id_with_expected_type
+
+
+@strawberry.type
+class ProjectMutationMixin:
+    @strawberry.mutation(permission_classes=[IsAuthenticated])  # type: ignore
+    async def delete_project(self, info: Info[Context, None], id: GlobalID) -> Query:
+        node_id = from_global_id_with_expected_type(global_id=id, expected_type_name="Project")
+        async with info.context.db() as session:
+            project = await session.scalar(
+                select(models.Project)
+                .where(models.Project.id == node_id)
+                .options(load_only(models.Project.name))
+            )
+            if project is None:
+                raise ValueError(f"Unknown project: {id}")
+            if project.name == DEFAULT_PROJECT_NAME:
+                raise ValueError(f"Cannot delete the {DEFAULT_PROJECT_NAME} project")
+            await session.delete(project)
+        return Query()
+
+    @strawberry.mutation(permission_classes=[IsAuthenticated])  # type: ignore
+    async def clear_project(self, info: Info[Context, None], input: ClearProjectInput) -> Query:
+        project_id = from_global_id_with_expected_type(
+            global_id=input.id, expected_type_name="Project"
+        )
+        delete_statement = delete(models.Trace).where(models.Trace.project_rowid == project_id)
+        if input.end_time is not None:
+            delete_statement = delete_statement.where(models.Trace.start_time < input.end_time)
+        async with info.context.db() as session:
+            await session.execute(delete_statement)
+        if cache := info.context.cache_for_dataloaders:
+            cache.invalidate(ClearProjectSpansEvent(project_rowid=project_id))
+        return Query()
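For orientation, here is a minimal client-side sketch of invoking the new deleteProject mutation. The endpoint URL, the GlobalID value, and the camelCased field name are illustrative assumptions, not taken from this diff:

# Hypothetical sketch: call the deleteProject mutation added above.
# The URL and the relay GlobalID value are placeholders.
import httpx

MUTATION = """
mutation DeleteProject($id: GlobalID!) {
  deleteProject(id: $id) {
    __typename
  }
}
"""

response = httpx.post(
    "http://localhost:6006/graphql",  # assumed local Phoenix address
    json={"query": MUTATION, "variables": {"id": "UHJvamVjdDox"}},
)
response.raise_for_status()
print(response.json())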
--- /dev/null
+++ b/phoenix/server/api/openapi/schema.py
@@ -0,0 +1,16 @@
+from typing import Any
+
+from starlette.schemas import SchemaGenerator
+
+from phoenix.server.api.routers.v1 import V1_ROUTES
+
+OPENAPI_SCHEMA_GENERATOR = SchemaGenerator(
+    {"openapi": "3.0.0", "info": {"title": "Arize-Phoenix API", "version": "1.0"}}
+)
+
+
+def get_openapi_schema() -> Any:
+    """
+    Exports an OpenAPI schema for the Phoenix REST API as a JSON object.
+    """
+    return OPENAPI_SCHEMA_GENERATOR.get_schema(V1_ROUTES)  # type: ignore
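The helper above can be exercised directly; a sketch that writes the generated schema to disk (only get_openapi_schema and its module path come from this diff; the output filename is arbitrary):

# Sketch: dump the OpenAPI schema produced by the new module to JSON.
import json

from phoenix.server.api.openapi.schema import get_openapi_schema

with open("phoenix_openapi.json", "w") as f:  # arbitrary output path
    json.dump(get_openapi_schema(), f, indent=2)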
--- /dev/null
+++ b/phoenix/server/api/queries.py
@@ -0,0 +1,503 @@
+from collections import defaultdict
+from typing import DefaultDict, Dict, List, Optional, Set, Union
+
+import numpy as np
+import numpy.typing as npt
+import strawberry
+from sqlalchemy import and_, distinct, func, select
+from sqlalchemy.orm import joinedload
+from strawberry import ID, UNSET
+from strawberry.relay import Connection, GlobalID, Node
+from strawberry.types import Info
+from typing_extensions import Annotated, TypeAlias
+
+from phoenix.db import models
+from phoenix.db.models import (
+    DatasetExample as OrmExample,
+)
+from phoenix.db.models import (
+    DatasetExampleRevision as OrmRevision,
+)
+from phoenix.db.models import (
+    DatasetVersion as OrmVersion,
+)
+from phoenix.db.models import (
+    Experiment as OrmExperiment,
+)
+from phoenix.db.models import (
+    ExperimentRun as OrmRun,
+)
+from phoenix.db.models import (
+    Trace as OrmTrace,
+)
+from phoenix.pointcloud.clustering import Hdbscan
+from phoenix.server.api.context import Context
+from phoenix.server.api.helpers import ensure_list
+from phoenix.server.api.input_types.ClusterInput import ClusterInput
+from phoenix.server.api.input_types.Coordinates import (
+    InputCoordinate2D,
+    InputCoordinate3D,
+)
+from phoenix.server.api.input_types.DatasetSort import DatasetSort
+from phoenix.server.api.types.Cluster import Cluster, to_gql_clusters
+from phoenix.server.api.types.Dataset import Dataset, to_gql_dataset
+from phoenix.server.api.types.DatasetExample import DatasetExample
+from phoenix.server.api.types.Dimension import to_gql_dimension
+from phoenix.server.api.types.EmbeddingDimension import (
+    DEFAULT_CLUSTER_SELECTION_EPSILON,
+    DEFAULT_MIN_CLUSTER_SIZE,
+    DEFAULT_MIN_SAMPLES,
+    to_gql_embedding_dimension,
+)
+from phoenix.server.api.types.Event import create_event_id, unpack_event_id
+from phoenix.server.api.types.Experiment import Experiment
+from phoenix.server.api.types.ExperimentComparison import ExperimentComparison, RunComparisonItem
+from phoenix.server.api.types.ExperimentRun import ExperimentRun, to_gql_experiment_run
+from phoenix.server.api.types.Functionality import Functionality
+from phoenix.server.api.types.InferencesRole import AncillaryInferencesRole, InferencesRole
+from phoenix.server.api.types.Model import Model
+from phoenix.server.api.types.node import from_global_id, from_global_id_with_expected_type
+from phoenix.server.api.types.pagination import (
+    ConnectionArgs,
+    CursorString,
+    connection_from_list,
+)
+from phoenix.server.api.types.Project import Project
+from phoenix.server.api.types.SortDir import SortDir
+from phoenix.server.api.types.Span import Span, to_gql_span
+from phoenix.server.api.types.Trace import Trace
+
+
+@strawberry.type
+class Query:
+    @strawberry.field
+    async def projects(
+        self,
+        info: Info[Context, None],
+        first: Optional[int] = 50,
+        last: Optional[int] = UNSET,
+        after: Optional[CursorString] = UNSET,
+        before: Optional[CursorString] = UNSET,
+    ) -> Connection[Project]:
+        args = ConnectionArgs(
+            first=first,
+            after=after if isinstance(after, CursorString) else None,
+            last=last,
+            before=before if isinstance(before, CursorString) else None,
+        )
+        stmt = (
+            select(models.Project)
+            .outerjoin(
+                models.Experiment,
+                models.Project.name == models.Experiment.project_name,
+            )
+            .where(models.Experiment.project_name.is_(None))
+        )
+        async with info.context.db() as session:
+            projects = await session.stream_scalars(stmt)
+            data = [
+                Project(
+                    id_attr=project.id,
+                    name=project.name,
+                    gradient_start_color=project.gradient_start_color,
+                    gradient_end_color=project.gradient_end_color,
+                )
+                async for project in projects
+            ]
+        return connection_from_list(data=data, args=args)
+
+    @strawberry.field
+    async def datasets(
+        self,
+        info: Info[Context, None],
+        first: Optional[int] = 50,
+        last: Optional[int] = UNSET,
+        after: Optional[CursorString] = UNSET,
+        before: Optional[CursorString] = UNSET,
+        sort: Optional[DatasetSort] = UNSET,
+    ) -> Connection[Dataset]:
+        args = ConnectionArgs(
+            first=first,
+            after=after if isinstance(after, CursorString) else None,
+            last=last,
+            before=before if isinstance(before, CursorString) else None,
+        )
+        stmt = select(models.Dataset)
+        if sort:
+            sort_col = getattr(models.Dataset, sort.col.value)
+            stmt = stmt.order_by(sort_col.desc() if sort.dir is SortDir.desc else sort_col.asc())
+        async with info.context.db() as session:
+            datasets = await session.scalars(stmt)
+            return connection_from_list(
+                data=[to_gql_dataset(dataset) for dataset in datasets], args=args
+            )
+
+    @strawberry.field
+    async def compare_experiments(
+        self,
+        info: Info[Context, None],
+        experiment_ids: List[GlobalID],
+    ) -> List[ExperimentComparison]:
+        experiment_ids_ = [
+            from_global_id_with_expected_type(experiment_id, OrmExperiment.__name__)
+            for experiment_id in experiment_ids
+        ]
+        if len(set(experiment_ids_)) != len(experiment_ids_):
+            raise ValueError("Experiment IDs must be unique.")
+
+        async with info.context.db() as session:
+            validation_result = (
+                await session.execute(
+                    select(
+                        func.count(distinct(OrmVersion.dataset_id)),
+                        func.max(OrmVersion.dataset_id),
+                        func.max(OrmVersion.id),
+                        func.count(OrmExperiment.id),
+                    )
+                    .select_from(OrmVersion)
+                    .join(
+                        OrmExperiment,
+                        OrmExperiment.dataset_version_id == OrmVersion.id,
+                    )
+                    .where(
+                        OrmExperiment.id.in_(experiment_ids_),
+                    )
+                )
+            ).first()
+            if validation_result is None:
+                raise ValueError("No experiments could be found for input IDs.")
+
+            num_datasets, dataset_id, version_id, num_resolved_experiment_ids = validation_result
+            if num_datasets != 1:
+                raise ValueError("Experiments must belong to the same dataset.")
+            if num_resolved_experiment_ids != len(experiment_ids_):
+                raise ValueError("Unable to resolve one or more experiment IDs.")
+
+            revision_ids = (
+                select(func.max(OrmRevision.id))
+                .join(OrmExample, OrmExample.id == OrmRevision.dataset_example_id)
+                .where(
+                    and_(
+                        OrmRevision.dataset_version_id <= version_id,
+                        OrmExample.dataset_id == dataset_id,
+                    )
+                )
+                .group_by(OrmRevision.dataset_example_id)
+                .scalar_subquery()
+            )
+            examples = (
+                await session.scalars(
+                    select(OrmExample)
+                    .join(OrmRevision, OrmExample.id == OrmRevision.dataset_example_id)
+                    .where(
+                        and_(
+                            OrmRevision.id.in_(revision_ids),
+                            OrmRevision.revision_kind != "DELETE",
+                        )
+                    )
+                    .order_by(OrmRevision.dataset_example_id.desc())
+                )
+            ).all()
+
+            ExampleID: TypeAlias = int
+            ExperimentID: TypeAlias = int
+            runs: DefaultDict[ExampleID, DefaultDict[ExperimentID, List[OrmRun]]] = defaultdict(
+                lambda: defaultdict(list)
+            )
+            async for run in await session.stream_scalars(
+                select(OrmRun)
+                .where(
+                    and_(
+                        OrmRun.dataset_example_id.in_(example.id for example in examples),
+                        OrmRun.experiment_id.in_(experiment_ids_),
+                    )
+                )
+                .options(joinedload(OrmRun.trace).load_only(OrmTrace.trace_id))
+            ):
+                runs[run.dataset_example_id][run.experiment_id].append(run)
+
+        experiment_comparisons = []
+        for example in examples:
+            run_comparison_items = []
+            for experiment_id in experiment_ids_:
+                run_comparison_items.append(
+                    RunComparisonItem(
+                        experiment_id=GlobalID(Experiment.__name__, str(experiment_id)),
+                        runs=[
+                            to_gql_experiment_run(run)
+                            for run in sorted(
+                                runs[example.id][experiment_id], key=lambda run: run.id
+                            )
+                        ],
+                    )
+                )
+            experiment_comparisons.append(
+                ExperimentComparison(
+                    example=DatasetExample(
+                        id_attr=example.id,
+                        created_at=example.created_at,
+                        version_id=version_id,
+                    ),
+                    run_comparison_items=run_comparison_items,
+                )
+            )
+        return experiment_comparisons
+
+    @strawberry.field
+    async def functionality(self, info: Info[Context, None]) -> "Functionality":
+        has_model_inferences = not info.context.model.is_empty
+        async with info.context.db() as session:
+            has_traces = (await session.scalar(select(models.Trace).limit(1))) is not None
+        return Functionality(
+            model_inferences=has_model_inferences,
+            tracing=has_traces,
+        )
+
+    @strawberry.field
+    def model(self) -> Model:
+        return Model()
+
+    @strawberry.field
+    async def node(self, id: GlobalID, info: Info[Context, None]) -> Node:
+        type_name, node_id = from_global_id(id)
+        if type_name == "Dimension":
+            dimension = info.context.model.scalar_dimensions[node_id]
+            return to_gql_dimension(node_id, dimension)
+        elif type_name == "EmbeddingDimension":
+            embedding_dimension = info.context.model.embedding_dimensions[node_id]
+            return to_gql_embedding_dimension(node_id, embedding_dimension)
+        elif type_name == "Project":
+            project_stmt = select(
+                models.Project.id,
+                models.Project.name,
+                models.Project.gradient_start_color,
+                models.Project.gradient_end_color,
+            ).where(models.Project.id == node_id)
+            async with info.context.db() as session:
+                project = (await session.execute(project_stmt)).first()
+            if project is None:
+                raise ValueError(f"Unknown project: {id}")
+            return Project(
+                id_attr=project.id,
+                name=project.name,
+                gradient_start_color=project.gradient_start_color,
+                gradient_end_color=project.gradient_end_color,
+            )
+        elif type_name == "Trace":
+            trace_stmt = select(
+                models.Trace.id,
+                models.Trace.trace_id,
+                models.Trace.project_rowid,
+            ).where(models.Trace.id == node_id)
+            async with info.context.db() as session:
+                trace = (await session.execute(trace_stmt)).first()
+            if trace is None:
+                raise ValueError(f"Unknown trace: {id}")
+            return Trace(
+                id_attr=trace.id, trace_id=trace.trace_id, project_rowid=trace.project_rowid
+            )
+        elif type_name == Span.__name__:
+            span_stmt = (
+                select(models.Span)
+                .options(
+                    joinedload(models.Span.trace, innerjoin=True).load_only(models.Trace.trace_id)
+                )
+                .where(models.Span.id == node_id)
+            )
+            async with info.context.db() as session:
+                span = await session.scalar(span_stmt)
+            if span is None:
+                raise ValueError(f"Unknown span: {id}")
+            return to_gql_span(span)
+        elif type_name == Dataset.__name__:
+            dataset_stmt = select(models.Dataset).where(models.Dataset.id == node_id)
+            async with info.context.db() as session:
+                if (dataset := await session.scalar(dataset_stmt)) is None:
+                    raise ValueError(f"Unknown dataset: {id}")
+            return to_gql_dataset(dataset)
+        elif type_name == DatasetExample.__name__:
+            example_id = node_id
+            latest_revision_id = (
+                select(func.max(models.DatasetExampleRevision.id))
+                .where(models.DatasetExampleRevision.dataset_example_id == example_id)
+                .scalar_subquery()
+            )
+            async with info.context.db() as session:
+                example = await session.scalar(
+                    select(models.DatasetExample)
+                    .join(
+                        models.DatasetExampleRevision,
+                        onclause=models.DatasetExampleRevision.dataset_example_id
+                        == models.DatasetExample.id,
+                    )
+                    .where(
+                        and_(
+                            models.DatasetExample.id == example_id,
+                            models.DatasetExampleRevision.id == latest_revision_id,
+                            models.DatasetExampleRevision.revision_kind != "DELETE",
+                        )
+                    )
+                )
+            if not example:
+                raise ValueError(f"Unknown dataset example: {id}")
+            return DatasetExample(
+                id_attr=example.id,
+                created_at=example.created_at,
+            )
+        elif type_name == Experiment.__name__:
+            async with info.context.db() as session:
+                experiment = await session.scalar(
+                    select(models.Experiment).where(models.Experiment.id == node_id)
+                )
+            if not experiment:
+                raise ValueError(f"Unknown experiment: {id}")
+            return Experiment(
+                id_attr=experiment.id,
+                name=experiment.name,
+                project_name=experiment.project_name,
+                description=experiment.description,
+                created_at=experiment.created_at,
+                updated_at=experiment.updated_at,
+                metadata=experiment.metadata_,
+            )
+        elif type_name == ExperimentRun.__name__:
+            async with info.context.db() as session:
+                if not (
+                    run := await session.scalar(
+                        select(models.ExperimentRun)
+                        .where(models.ExperimentRun.id == node_id)
+                        .options(
+                            joinedload(models.ExperimentRun.trace).load_only(models.Trace.trace_id)
+                        )
+                    )
+                ):
+                    raise ValueError(f"Unknown experiment run: {id}")
+            return to_gql_experiment_run(run)
+        raise Exception(f"Unknown node type: {type_name}")
+
+    @strawberry.field
+    def clusters(
+        self,
+        clusters: List[ClusterInput],
+    ) -> List[Cluster]:
+        clustered_events: Dict[str, Set[ID]] = defaultdict(set)
+        for i, cluster in enumerate(clusters):
+            clustered_events[cluster.id or str(i)].update(cluster.event_ids)
+        return to_gql_clusters(
+            clustered_events=clustered_events,
+        )
+
+    @strawberry.field
+    def hdbscan_clustering(
+        self,
+        info: Info[Context, None],
+        event_ids: Annotated[
+            List[ID],
+            strawberry.argument(
+                description="Event ID of the coordinates",
+            ),
+        ],
+        coordinates_2d: Annotated[
+            Optional[List[InputCoordinate2D]],
+            strawberry.argument(
+                description="Point coordinates. Must be either 2D or 3D.",
+            ),
+        ] = UNSET,
+        coordinates_3d: Annotated[
+            Optional[List[InputCoordinate3D]],
+            strawberry.argument(
+                description="Point coordinates. Must be either 2D or 3D.",
+            ),
+        ] = UNSET,
+        min_cluster_size: Annotated[
+            int,
+            strawberry.argument(
+                description="HDBSCAN minimum cluster size",
+            ),
+        ] = DEFAULT_MIN_CLUSTER_SIZE,
+        cluster_min_samples: Annotated[
+            int,
+            strawberry.argument(
+                description="HDBSCAN minimum samples",
+            ),
+        ] = DEFAULT_MIN_SAMPLES,
+        cluster_selection_epsilon: Annotated[
+            float,
+            strawberry.argument(
+                description="HDBSCAN cluster selection epsilon",
+            ),
+        ] = DEFAULT_CLUSTER_SELECTION_EPSILON,
+    ) -> List[Cluster]:
+        coordinates_3d = ensure_list(coordinates_3d)
+        coordinates_2d = ensure_list(coordinates_2d)
+
+        if len(coordinates_3d) > 0 and len(coordinates_2d) > 0:
+            raise ValueError("must specify only one of 2D or 3D coordinates")
+
+        if len(coordinates_3d) > 0:
+            coordinates = list(
+                map(
+                    lambda coord: np.array(
+                        [coord.x, coord.y, coord.z],
+                    ),
+                    coordinates_3d,
+                )
+            )
+        else:
+            coordinates = list(
+                map(
+                    lambda coord: np.array(
+                        [coord.x, coord.y],
+                    ),
+                    coordinates_2d,
+                )
+            )
+
+        if len(event_ids) != len(coordinates):
+            raise ValueError(
+                f"length mismatch between "
+                f"event_ids ({len(event_ids)}) "
+                f"and coordinates ({len(coordinates)})"
+            )
+
+        if len(event_ids) == 0:
+            return []
+
+        grouped_event_ids: Dict[
+            Union[InferencesRole, AncillaryInferencesRole],
+            List[ID],
+        ] = defaultdict(list)
+        grouped_coordinates: Dict[
+            Union[InferencesRole, AncillaryInferencesRole],
+            List[npt.NDArray[np.float64]],
+        ] = defaultdict(list)
+
+        for event_id, coordinate in zip(event_ids, coordinates):
+            row_id, inferences_role = unpack_event_id(event_id)
+            grouped_coordinates[inferences_role].append(coordinate)
+            grouped_event_ids[inferences_role].append(create_event_id(row_id, inferences_role))
+
+        stacked_event_ids = (
+            grouped_event_ids[InferencesRole.primary]
+            + grouped_event_ids[InferencesRole.reference]
+            + grouped_event_ids[AncillaryInferencesRole.corpus]
+        )
+        stacked_coordinates = np.stack(
+            grouped_coordinates[InferencesRole.primary]
+            + grouped_coordinates[InferencesRole.reference]
+            + grouped_coordinates[AncillaryInferencesRole.corpus]
+        )
+
+        clusters = Hdbscan(
+            min_cluster_size=min_cluster_size,
+            min_samples=cluster_min_samples,
+            cluster_selection_epsilon=cluster_selection_epsilon,
+        ).find_clusters(stacked_coordinates)
+
+        clustered_events = {
+            str(i): {stacked_event_ids[row_idx] for row_idx in cluster}
+            for i, cluster in enumerate(clusters)
+        }
+
+        return to_gql_clusters(
+            clustered_events=clustered_events,
+        )
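The compare_experiments resolver above pairs each dataset example with its runs across the selected experiments; a hedged sketch of querying it (strawberry's default camelCasing and the selection-set shape are assumptions, and the IDs are placeholders):

# Hypothetical sketch: query the new compareExperiments field.
import httpx

QUERY = """
query Compare($ids: [GlobalID!]!) {
  compareExperiments(experimentIds: $ids) {
    example { id }
    runComparisonItems {
      experimentId
      runs { id }
    }
  }
}
"""

response = httpx.post(
    "http://localhost:6006/graphql",  # assumed local Phoenix server
    json={"query": QUERY, "variables": {"ids": ["RXhwZXJpbWVudDox", "RXhwZXJpbWVudDoy"]}},
)
print(response.json())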
--- a/phoenix/server/api/routers/v1/__init__.py
+++ b/phoenix/server/api/routers/v1/__init__.py
@@ -1,6 +1,40 @@
-from starlette.routing import Route
+from typing import Any, Awaitable, Callable, Mapping, Tuple
+
+import wrapt
+from starlette import routing
+from starlette.requests import Request
+from starlette.responses import Response
+from starlette.status import HTTP_403_FORBIDDEN
+
+from . import (
+    datasets,
+    evaluations,
+    experiment_evaluations,
+    experiment_runs,
+    experiments,
+    spans,
+    traces,
+)
+from .dataset_examples import list_dataset_examples
+
+
+@wrapt.decorator  # type: ignore
+async def forbid_if_readonly(
+    wrapped: Callable[[Request], Awaitable[Response]],
+    _: Any,
+    args: Tuple[Request],
+    kwargs: Mapping[str, Any],
+) -> Response:
+    request, *_ = args
+    if request.app.state.read_only:
+        return Response(status_code=HTTP_403_FORBIDDEN)
+    return await wrapped(*args, **kwargs)
+
+
+class Route(routing.Route):
+    def __init__(self, path: str, endpoint: Callable[..., Any], **kwargs: Any) -> None:
+        super().__init__(path, forbid_if_readonly(endpoint), **kwargs)
 
-from . import evaluations, spans, traces
 
 V1_ROUTES = [
     Route("/v1/evaluations", evaluations.post_evaluations, methods=["POST"]),
@@ -8,4 +42,45 @@ V1_ROUTES = [
     Route("/v1/traces", traces.post_traces, methods=["POST"]),
     Route("/v1/spans", spans.query_spans_handler, methods=["POST"]),
     Route("/v1/spans", spans.get_spans_handler, methods=["GET"]),
+    Route("/v1/datasets/upload", datasets.post_datasets_upload, methods=["POST"]),
+    Route("/v1/datasets", datasets.list_datasets, methods=["GET"]),
+    Route("/v1/datasets/{id:str}", datasets.get_dataset_by_id, methods=["GET"]),
+    Route("/v1/datasets/{id:str}/csv", datasets.get_dataset_csv, methods=["GET"]),
+    Route(
+        "/v1/datasets/{id:str}/jsonl/openai_ft",
+        datasets.get_dataset_jsonl_openai_ft,
+        methods=["GET"],
+    ),
+    Route(
+        "/v1/datasets/{id:str}/jsonl/openai_evals",
+        datasets.get_dataset_jsonl_openai_evals,
+        methods=["GET"],
+    ),
+    Route("/v1/datasets/{id:str}/examples", list_dataset_examples, methods=["GET"]),
+    Route("/v1/datasets/{id:str}/versions", datasets.get_dataset_versions, methods=["GET"]),
+    Route(
+        "/v1/datasets/{dataset_id:str}/experiments",
+        experiments.create_experiment,
+        methods=["POST"],
+    ),
+    Route(
+        "/v1/experiments/{experiment_id:str}",
+        experiments.read_experiment,
+        methods=["GET"],
+    ),
+    Route(
+        "/v1/experiments/{experiment_id:str}/runs",
+        experiment_runs.create_experiment_run,
+        methods=["POST"],
+    ),
+    Route(
+        "/v1/experiments/{experiment_id:str}/runs",
+        experiment_runs.list_experiment_runs,
+        methods=["GET"],
+    ),
+    Route(
+        "/v1/experiment_evaluations",
+        experiment_evaluations.upsert_experiment_evaluation,
+        methods=["POST"],
+    ),
 ]
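Because every Route is now wrapped in forbid_if_readonly, write requests against a read-only server are rejected with 403. A sketch of exercising the new v1 REST routes (the base URL is an assumed local default; response shapes are not shown in this diff):

# Sketch: call the new v1 REST routes added above.
import httpx

BASE = "http://localhost:6006"  # assumed local Phoenix address

print(httpx.get(f"{BASE}/v1/datasets").json())  # list datasets

# On a server started read-only, any route returns HTTP 403 via the
# forbid_if_readonly wrapper defined above.
print(httpx.post(f"{BASE}/v1/experiment_evaluations", json={}).status_code)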