arize-phoenix 4.4.4rc5__py3-none-any.whl → 4.5.0__py3-none-any.whl
This diff compares the contents of publicly available package versions as published to their respective registries. It is provided for informational purposes only.
Potentially problematic release: this version of arize-phoenix has been flagged as potentially problematic.
- {arize_phoenix-4.4.4rc5.dist-info → arize_phoenix-4.5.0.dist-info}/METADATA +5 -5
- {arize_phoenix-4.4.4rc5.dist-info → arize_phoenix-4.5.0.dist-info}/RECORD +56 -117
- {arize_phoenix-4.4.4rc5.dist-info → arize_phoenix-4.5.0.dist-info}/WHEEL +1 -1
- phoenix/__init__.py +27 -0
- phoenix/config.py +7 -21
- phoenix/core/model.py +25 -25
- phoenix/core/model_schema.py +62 -64
- phoenix/core/model_schema_adapter.py +25 -27
- phoenix/db/bulk_inserter.py +14 -54
- phoenix/db/insertion/evaluation.py +6 -6
- phoenix/db/insertion/helpers.py +2 -13
- phoenix/db/migrations/versions/cf03bd6bae1d_init.py +28 -2
- phoenix/db/models.py +4 -236
- phoenix/inferences/fixtures.py +23 -23
- phoenix/inferences/inferences.py +7 -7
- phoenix/inferences/validation.py +1 -1
- phoenix/server/api/context.py +0 -18
- phoenix/server/api/dataloaders/__init__.py +0 -18
- phoenix/server/api/dataloaders/span_descendants.py +3 -2
- phoenix/server/api/routers/v1/__init__.py +2 -77
- phoenix/server/api/routers/v1/evaluations.py +2 -4
- phoenix/server/api/routers/v1/spans.py +1 -3
- phoenix/server/api/routers/v1/traces.py +4 -1
- phoenix/server/api/schema.py +303 -2
- phoenix/server/api/types/Cluster.py +19 -19
- phoenix/server/api/types/Dataset.py +63 -282
- phoenix/server/api/types/DatasetRole.py +23 -0
- phoenix/server/api/types/Dimension.py +29 -30
- phoenix/server/api/types/EmbeddingDimension.py +34 -40
- phoenix/server/api/types/Event.py +16 -16
- phoenix/server/api/{mutations/export_events_mutations.py → types/ExportEventsMutation.py} +14 -17
- phoenix/server/api/types/Model.py +42 -43
- phoenix/server/api/types/Project.py +12 -26
- phoenix/server/api/types/Span.py +2 -79
- phoenix/server/api/types/TimeSeries.py +6 -6
- phoenix/server/api/types/Trace.py +4 -15
- phoenix/server/api/types/UMAPPoints.py +1 -1
- phoenix/server/api/types/node.py +111 -5
- phoenix/server/api/types/pagination.py +52 -10
- phoenix/server/app.py +49 -101
- phoenix/server/main.py +27 -49
- phoenix/server/openapi/docs.py +0 -3
- phoenix/server/static/index.js +2595 -3523
- phoenix/server/templates/index.html +0 -1
- phoenix/services.py +15 -15
- phoenix/session/client.py +21 -438
- phoenix/session/session.py +37 -47
- phoenix/trace/exporter.py +9 -14
- phoenix/trace/fixtures.py +7 -133
- phoenix/trace/schemas.py +2 -1
- phoenix/trace/span_evaluations.py +3 -3
- phoenix/trace/trace_dataset.py +6 -6
- phoenix/version.py +1 -1
- phoenix/datasets/__init__.py +0 -0
- phoenix/datasets/evaluators/__init__.py +0 -18
- phoenix/datasets/evaluators/code_evaluators.py +0 -99
- phoenix/datasets/evaluators/llm_evaluators.py +0 -244
- phoenix/datasets/evaluators/utils.py +0 -292
- phoenix/datasets/experiments.py +0 -550
- phoenix/datasets/tracing.py +0 -85
- phoenix/datasets/types.py +0 -178
- phoenix/db/insertion/dataset.py +0 -237
- phoenix/db/migrations/types.py +0 -29
- phoenix/db/migrations/versions/10460e46d750_datasets.py +0 -291
- phoenix/server/api/dataloaders/dataset_example_revisions.py +0 -100
- phoenix/server/api/dataloaders/dataset_example_spans.py +0 -43
- phoenix/server/api/dataloaders/experiment_annotation_summaries.py +0 -85
- phoenix/server/api/dataloaders/experiment_error_rates.py +0 -43
- phoenix/server/api/dataloaders/experiment_run_counts.py +0 -42
- phoenix/server/api/dataloaders/experiment_sequence_number.py +0 -49
- phoenix/server/api/dataloaders/project_by_name.py +0 -31
- phoenix/server/api/dataloaders/span_projects.py +0 -33
- phoenix/server/api/dataloaders/trace_row_ids.py +0 -39
- phoenix/server/api/helpers/dataset_helpers.py +0 -179
- phoenix/server/api/input_types/AddExamplesToDatasetInput.py +0 -16
- phoenix/server/api/input_types/AddSpansToDatasetInput.py +0 -14
- phoenix/server/api/input_types/ClearProjectInput.py +0 -15
- phoenix/server/api/input_types/CreateDatasetInput.py +0 -12
- phoenix/server/api/input_types/DatasetExampleInput.py +0 -14
- phoenix/server/api/input_types/DatasetSort.py +0 -17
- phoenix/server/api/input_types/DatasetVersionSort.py +0 -16
- phoenix/server/api/input_types/DeleteDatasetExamplesInput.py +0 -13
- phoenix/server/api/input_types/DeleteDatasetInput.py +0 -7
- phoenix/server/api/input_types/DeleteExperimentsInput.py +0 -9
- phoenix/server/api/input_types/PatchDatasetExamplesInput.py +0 -35
- phoenix/server/api/input_types/PatchDatasetInput.py +0 -14
- phoenix/server/api/mutations/__init__.py +0 -13
- phoenix/server/api/mutations/auth.py +0 -11
- phoenix/server/api/mutations/dataset_mutations.py +0 -520
- phoenix/server/api/mutations/experiment_mutations.py +0 -65
- phoenix/server/api/mutations/project_mutations.py +0 -47
- phoenix/server/api/openapi/__init__.py +0 -0
- phoenix/server/api/openapi/main.py +0 -6
- phoenix/server/api/openapi/schema.py +0 -16
- phoenix/server/api/queries.py +0 -503
- phoenix/server/api/routers/v1/dataset_examples.py +0 -178
- phoenix/server/api/routers/v1/datasets.py +0 -965
- phoenix/server/api/routers/v1/experiment_evaluations.py +0 -66
- phoenix/server/api/routers/v1/experiment_runs.py +0 -108
- phoenix/server/api/routers/v1/experiments.py +0 -174
- phoenix/server/api/types/AnnotatorKind.py +0 -10
- phoenix/server/api/types/CreateDatasetPayload.py +0 -8
- phoenix/server/api/types/DatasetExample.py +0 -85
- phoenix/server/api/types/DatasetExampleRevision.py +0 -34
- phoenix/server/api/types/DatasetVersion.py +0 -14
- phoenix/server/api/types/ExampleRevisionInterface.py +0 -14
- phoenix/server/api/types/Experiment.py +0 -140
- phoenix/server/api/types/ExperimentAnnotationSummary.py +0 -13
- phoenix/server/api/types/ExperimentComparison.py +0 -19
- phoenix/server/api/types/ExperimentRun.py +0 -91
- phoenix/server/api/types/ExperimentRunAnnotation.py +0 -57
- phoenix/server/api/types/Inferences.py +0 -80
- phoenix/server/api/types/InferencesRole.py +0 -23
- phoenix/utilities/json.py +0 -61
- phoenix/utilities/re.py +0 -50
- {arize_phoenix-4.4.4rc5.dist-info → arize_phoenix-4.5.0.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-4.4.4rc5.dist-info → arize_phoenix-4.5.0.dist-info}/licenses/LICENSE +0 -0
- /phoenix/server/api/{helpers/__init__.py → helpers.py} +0 -0
phoenix/server/api/routers/v1/__init__.py
CHANGED

@@ -1,40 +1,6 @@
-from typing import Any, Awaitable, Callable, Mapping, Tuple
-
-import wrapt
-from starlette import routing
-from starlette.requests import Request
-from starlette.responses import Response
-from starlette.status import HTTP_403_FORBIDDEN
-
-from . import (
-    datasets,
-    evaluations,
-    experiment_evaluations,
-    experiment_runs,
-    experiments,
-    spans,
-    traces,
-)
-from .dataset_examples import list_dataset_examples
-
-
-@wrapt.decorator  # type: ignore
-async def forbid_if_readonly(
-    wrapped: Callable[[Request], Awaitable[Response]],
-    _: Any,
-    args: Tuple[Request],
-    kwargs: Mapping[str, Any],
-) -> Response:
-    request, *_ = args
-    if request.app.state.read_only:
-        return Response(status_code=HTTP_403_FORBIDDEN)
-    return await wrapped(*args, **kwargs)
-
-
-class Route(routing.Route):
-    def __init__(self, path: str, endpoint: Callable[..., Any], **kwargs: Any) -> None:
-        super().__init__(path, forbid_if_readonly(endpoint), **kwargs)
+from starlette.routing import Route
 
+from . import evaluations, spans, traces
 
 V1_ROUTES = [
     Route("/v1/evaluations", evaluations.post_evaluations, methods=["POST"]),
@@ -42,45 +8,4 @@ V1_ROUTES = [
     Route("/v1/traces", traces.post_traces, methods=["POST"]),
     Route("/v1/spans", spans.query_spans_handler, methods=["POST"]),
     Route("/v1/spans", spans.get_spans_handler, methods=["GET"]),
-    Route("/v1/datasets/upload", datasets.post_datasets_upload, methods=["POST"]),
-    Route("/v1/datasets", datasets.list_datasets, methods=["GET"]),
-    Route("/v1/datasets/{id:str}", datasets.get_dataset_by_id, methods=["GET"]),
-    Route("/v1/datasets/{id:str}/csv", datasets.get_dataset_csv, methods=["GET"]),
-    Route(
-        "/v1/datasets/{id:str}/jsonl/openai_ft",
-        datasets.get_dataset_jsonl_openai_ft,
-        methods=["GET"],
-    ),
-    Route(
-        "/v1/datasets/{id:str}/jsonl/openai_evals",
-        datasets.get_dataset_jsonl_openai_evals,
-        methods=["GET"],
-    ),
-    Route("/v1/datasets/{id:str}/examples", list_dataset_examples, methods=["GET"]),
-    Route("/v1/datasets/{id:str}/versions", datasets.get_dataset_versions, methods=["GET"]),
-    Route(
-        "/v1/datasets/{dataset_id:str}/experiments",
-        experiments.create_experiment,
-        methods=["POST"],
-    ),
-    Route(
-        "/v1/experiments/{experiment_id:str}",
-        experiments.read_experiment,
-        methods=["GET"],
-    ),
-    Route(
-        "/v1/experiments/{experiment_id:str}/runs",
-        experiment_runs.create_experiment_run,
-        methods=["POST"],
-    ),
-    Route(
-        "/v1/experiments/{experiment_id:str}/runs",
-        experiment_runs.list_experiment_runs,
-        methods=["GET"],
-    ),
-    Route(
-        "/v1/experiment_evaluations",
-        experiment_evaluations.create_experiment_evaluation,
-        methods=["POST"],
-    ),
 ]
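Net effect of these two hunks: the wrapt-based Route subclass that returned 403 for every endpoint in read-only mode is gone, along with the dataset and experiment routes, so mutating handlers now guard themselves, as the traces.py hunk below shows. A minimal Starlette sketch of the inline pattern, with an illustrative endpoint name rather than Phoenix's actual wiring:

# Minimal sketch of the inline read-only guard; the endpoint and app
# setup here are illustrative, not Phoenix's actual configuration.
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import Response
from starlette.routing import Route
from starlette.status import HTTP_403_FORBIDDEN


async def post_example(request: Request) -> Response:
    # Each mutating handler performs the check itself instead of relying
    # on a Route subclass that wraps every endpoint.
    if request.app.state.read_only:
        return Response(status_code=HTTP_403_FORBIDDEN)
    return Response(status_code=200)


app = Starlette(routes=[Route("/v1/example", post_example, methods=["POST"])])
app.state.read_only = True  # POST /v1/example now returns 403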
phoenix/server/api/routers/v1/evaluations.py
CHANGED

@@ -44,7 +44,7 @@ async def post_evaluations(request: Request) -> Response:
    summary: Add evaluations to a span, trace, or document
    operationId: addEvaluations
    tags:
-    -
+    - evaluations
    parameters:
    - name: project-name
      in: query
@@ -105,7 +105,7 @@ async def get_evaluations(request: Request) -> Response:
    summary: Get evaluations from Phoenix
    operationId: getEvaluation
    tags:
-    -
+    - evaluations
    parameters:
    - name: project-name
      in: query
@@ -116,8 +116,6 @@ async def get_evaluations(request: Request) -> Response:
    responses:
      200:
        description: Success
-      403:
-        description: Forbidden
      404:
        description: Not found
    """
phoenix/server/api/routers/v1/spans.py
CHANGED

@@ -19,7 +19,7 @@ async def query_spans_handler(request: Request) -> Response:
    summary: Query spans using query DSL
    operationId: querySpans
    tags:
-    -
+    - spans
    parameters:
    - name: project-name
      in: query
@@ -68,8 +68,6 @@ async def query_spans_handler(request: Request) -> Response:
    responses:
      200:
        description: Success
-      403:
-        description: Forbidden
      404:
        description: Not found
      422:
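For context, the two /v1/spans routes back Phoenix's span-query DSL. A client-side usage sketch, assuming a running server and the public client API (the project name here is illustrative):

# Hedged sketch: querying spans over the /v1/spans routes via the
# public client. "default" is an assumed project name.
import phoenix as px
from phoenix.trace.dsl import SpanQuery

client = px.Client()  # defaults to the local Phoenix server
spans_df = client.query_spans(SpanQuery(), project_name="default")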
phoenix/server/api/routers/v1/traces.py
CHANGED

@@ -11,6 +11,7 @@ from starlette.datastructures import State
 from starlette.requests import Request
 from starlette.responses import Response
 from starlette.status import (
+    HTTP_403_FORBIDDEN,
     HTTP_415_UNSUPPORTED_MEDIA_TYPE,
     HTTP_422_UNPROCESSABLE_ENTITY,
 )
@@ -24,7 +25,7 @@ async def post_traces(request: Request) -> Response:
    summary: Send traces to Phoenix
    operationId: addTraces
    tags:
-    -
+    - traces
    requestBody:
      required: true
      content:
@@ -42,6 +43,8 @@ async def post_traces(request: Request) -> Response:
      422:
        description: Request body is invalid
    """
+    if request.app.state.read_only:
+        return Response(status_code=HTTP_403_FORBIDDEN)
     content_type = request.headers.get("content-type")
     if content_type != "application/x-protobuf":
         return Response(
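Net effect: post_traces now returns 403 up front when the server is read-only, ahead of the existing content-type (415) and body-validation (422) checks. A hedged client-side sketch of a request that passes the content-type gate (port 6006 is the usual local default; the protobuf payload is elided):

# Hedged sketch of a conforming request; the body would be an OTLP
# protobuf trace-export payload, elided here.
import urllib.request

request = urllib.request.Request(
    "http://localhost:6006/v1/traces",
    data=b"",  # protobuf payload elided
    headers={"content-type": "application/x-protobuf"},
    method="POST",
)
# The server responds 403 while read-only, 415 for a non-protobuf
# content type, and 422 for an invalid body.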
phoenix/server/api/schema.py
CHANGED

@@ -1,7 +1,308 @@
+from collections import defaultdict
+from typing import Dict, List, Optional, Set, Union
+
+import numpy as np
+import numpy.typing as npt
 import strawberry
+from sqlalchemy import delete, select
+from sqlalchemy.orm import contains_eager, load_only
+from strawberry import ID, UNSET
+from strawberry.types import Info
+from typing_extensions import Annotated
+
+from phoenix.config import DEFAULT_PROJECT_NAME
+from phoenix.db import models
+from phoenix.db.insertion.span import ClearProjectSpansEvent
+from phoenix.pointcloud.clustering import Hdbscan
+from phoenix.server.api.context import Context
+from phoenix.server.api.helpers import ensure_list
+from phoenix.server.api.input_types.ClusterInput import ClusterInput
+from phoenix.server.api.input_types.Coordinates import (
+    InputCoordinate2D,
+    InputCoordinate3D,
+)
+from phoenix.server.api.types.Cluster import Cluster, to_gql_clusters
+from phoenix.server.api.types.DatasetRole import AncillaryDatasetRole, DatasetRole
+from phoenix.server.api.types.Dimension import to_gql_dimension
+from phoenix.server.api.types.EmbeddingDimension import (
+    DEFAULT_CLUSTER_SELECTION_EPSILON,
+    DEFAULT_MIN_CLUSTER_SIZE,
+    DEFAULT_MIN_SAMPLES,
+    to_gql_embedding_dimension,
+)
+from phoenix.server.api.types.Event import create_event_id, unpack_event_id
+from phoenix.server.api.types.ExportEventsMutation import ExportEventsMutation
+from phoenix.server.api.types.Functionality import Functionality
+from phoenix.server.api.types.Model import Model
+from phoenix.server.api.types.node import (
+    GlobalID,
+    Node,
+    from_global_id,
+    from_global_id_with_expected_type,
+)
+from phoenix.server.api.types.pagination import (
+    Connection,
+    ConnectionArgs,
+    CursorString,
+    connection_from_list,
+)
+from phoenix.server.api.types.Project import Project
+from phoenix.server.api.types.Span import to_gql_span
+from phoenix.server.api.types.Trace import Trace
+
+
+@strawberry.type
+class Query:
+    @strawberry.field
+    async def projects(
+        self,
+        info: Info[Context, None],
+        first: Optional[int] = 50,
+        last: Optional[int] = UNSET,
+        after: Optional[CursorString] = UNSET,
+        before: Optional[CursorString] = UNSET,
+    ) -> Connection[Project]:
+        args = ConnectionArgs(
+            first=first,
+            after=after if isinstance(after, CursorString) else None,
+            last=last,
+            before=before if isinstance(before, CursorString) else None,
+        )
+        async with info.context.db() as session:
+            projects = await session.scalars(select(models.Project))
+        data = [
+            Project(
+                id_attr=project.id,
+                name=project.name,
+                gradient_start_color=project.gradient_start_color,
+                gradient_end_color=project.gradient_end_color,
+            )
+            for project in projects
+        ]
+        return connection_from_list(data=data, args=args)
+
+    @strawberry.field
+    async def functionality(self, info: Info[Context, None]) -> "Functionality":
+        has_model_inferences = not info.context.model.is_empty
+        async with info.context.db() as session:
+            has_traces = (await session.scalar(select(models.Trace).limit(1))) is not None
+        return Functionality(
+            model_inferences=has_model_inferences,
+            tracing=has_traces,
+        )
+
+    @strawberry.field
+    def model(self) -> Model:
+        return Model()
+
+    @strawberry.field
+    async def node(self, id: GlobalID, info: Info[Context, None]) -> Node:
+        type_name, node_id = from_global_id(str(id))
+        if type_name == "Dimension":
+            dimension = info.context.model.scalar_dimensions[node_id]
+            return to_gql_dimension(node_id, dimension)
+        elif type_name == "EmbeddingDimension":
+            embedding_dimension = info.context.model.embedding_dimensions[node_id]
+            return to_gql_embedding_dimension(node_id, embedding_dimension)
+        elif type_name == "Project":
+            project_stmt = select(
+                models.Project.id,
+                models.Project.name,
+                models.Project.gradient_start_color,
+                models.Project.gradient_end_color,
+            ).where(models.Project.id == node_id)
+            async with info.context.db() as session:
+                project = (await session.execute(project_stmt)).first()
+            if project is None:
+                raise ValueError(f"Unknown project: {id}")
+            return Project(
+                id_attr=project.id,
+                name=project.name,
+                gradient_start_color=project.gradient_start_color,
+                gradient_end_color=project.gradient_end_color,
+            )
+        elif type_name == "Trace":
+            trace_stmt = select(models.Trace.id).where(models.Trace.id == node_id)
+            async with info.context.db() as session:
+                id_attr = await session.scalar(trace_stmt)
+            if id_attr is None:
+                raise ValueError(f"Unknown trace: {id}")
+            return Trace(id_attr=id_attr)
+        elif type_name == "Span":
+            span_stmt = (
+                select(models.Span)
+                .join(models.Trace)
+                .options(contains_eager(models.Span.trace))
+                .where(models.Span.id == node_id)
+            )
+            async with info.context.db() as session:
+                span = await session.scalar(span_stmt)
+            if span is None:
+                raise ValueError(f"Unknown span: {id}")
+            return to_gql_span(span)
+        raise Exception(f"Unknown node type: {type_name}")
+
+    @strawberry.field
+    def clusters(
+        self,
+        clusters: List[ClusterInput],
+    ) -> List[Cluster]:
+        clustered_events: Dict[str, Set[ID]] = defaultdict(set)
+        for i, cluster in enumerate(clusters):
+            clustered_events[cluster.id or str(i)].update(cluster.event_ids)
+        return to_gql_clusters(
+            clustered_events=clustered_events,
+        )
+
+    @strawberry.field
+    def hdbscan_clustering(
+        self,
+        info: Info[Context, None],
+        event_ids: Annotated[
+            List[ID],
+            strawberry.argument(
+                description="Event ID of the coordinates",
+            ),
+        ],
+        coordinates_2d: Annotated[
+            Optional[List[InputCoordinate2D]],
+            strawberry.argument(
+                description="Point coordinates. Must be either 2D or 3D.",
+            ),
+        ] = UNSET,
+        coordinates_3d: Annotated[
+            Optional[List[InputCoordinate3D]],
+            strawberry.argument(
+                description="Point coordinates. Must be either 2D or 3D.",
+            ),
+        ] = UNSET,
+        min_cluster_size: Annotated[
+            int,
+            strawberry.argument(
+                description="HDBSCAN minimum cluster size",
+            ),
+        ] = DEFAULT_MIN_CLUSTER_SIZE,
+        cluster_min_samples: Annotated[
+            int,
+            strawberry.argument(
+                description="HDBSCAN minimum samples",
+            ),
+        ] = DEFAULT_MIN_SAMPLES,
+        cluster_selection_epsilon: Annotated[
+            float,
+            strawberry.argument(
+                description="HDBSCAN cluster selection epsilon",
+            ),
+        ] = DEFAULT_CLUSTER_SELECTION_EPSILON,
+    ) -> List[Cluster]:
+        coordinates_3d = ensure_list(coordinates_3d)
+        coordinates_2d = ensure_list(coordinates_2d)
+
+        if len(coordinates_3d) > 0 and len(coordinates_2d) > 0:
+            raise ValueError("must specify only one of 2D or 3D coordinates")
+
+        if len(coordinates_3d) > 0:
+            coordinates = list(
+                map(
+                    lambda coord: np.array(
+                        [coord.x, coord.y, coord.z],
+                    ),
+                    coordinates_3d,
+                )
+            )
+        else:
+            coordinates = list(
+                map(
+                    lambda coord: np.array(
+                        [coord.x, coord.y],
+                    ),
+                    coordinates_2d,
+                )
+            )
+
+        if len(event_ids) != len(coordinates):
+            raise ValueError(
+                f"length mismatch between "
+                f"event_ids ({len(event_ids)}) "
+                f"and coordinates ({len(coordinates)})"
+            )
+
+        if len(event_ids) == 0:
+            return []
+
+        grouped_event_ids: Dict[
+            Union[DatasetRole, AncillaryDatasetRole],
+            List[ID],
+        ] = defaultdict(list)
+        grouped_coordinates: Dict[
+            Union[DatasetRole, AncillaryDatasetRole],
+            List[npt.NDArray[np.float64]],
+        ] = defaultdict(list)
+
+        for event_id, coordinate in zip(event_ids, coordinates):
+            row_id, dataset_role = unpack_event_id(event_id)
+            grouped_coordinates[dataset_role].append(coordinate)
+            grouped_event_ids[dataset_role].append(create_event_id(row_id, dataset_role))
+
+        stacked_event_ids = (
+            grouped_event_ids[DatasetRole.primary]
+            + grouped_event_ids[DatasetRole.reference]
+            + grouped_event_ids[AncillaryDatasetRole.corpus]
+        )
+        stacked_coordinates = np.stack(
+            grouped_coordinates[DatasetRole.primary]
+            + grouped_coordinates[DatasetRole.reference]
+            + grouped_coordinates[AncillaryDatasetRole.corpus]
+        )
+
+        clusters = Hdbscan(
+            min_cluster_size=min_cluster_size,
+            min_samples=cluster_min_samples,
+            cluster_selection_epsilon=cluster_selection_epsilon,
+        ).find_clusters(stacked_coordinates)
+
+        clustered_events = {
+            str(i): {stacked_event_ids[row_idx] for row_idx in cluster}
+            for i, cluster in enumerate(clusters)
+        }
+
+        return to_gql_clusters(
+            clustered_events=clustered_events,
+        )
+
+
+@strawberry.type
+class Mutation(ExportEventsMutation):
+    @strawberry.mutation
+    async def delete_project(self, info: Info[Context, None], id: GlobalID) -> Query:
+        if info.context.read_only:
+            return Query()
+        node_id = from_global_id_with_expected_type(str(id), "Project")
+        async with info.context.db() as session:
+            project = await session.scalar(
+                select(models.Project)
+                .where(models.Project.id == node_id)
+                .options(load_only(models.Project.name))
+            )
+            if project is None:
+                raise ValueError(f"Unknown project: {id}")
+            if project.name == DEFAULT_PROJECT_NAME:
+                raise ValueError(f"Cannot delete the {DEFAULT_PROJECT_NAME} project")
+            await session.delete(project)
+        return Query()
+
+    @strawberry.mutation
+    async def clear_project(self, info: Info[Context, None], id: GlobalID) -> Query:
+        if info.context.read_only:
+            return Query()
+        project_id = from_global_id_with_expected_type(str(id), "Project")
+        delete_statement = delete(models.Trace).where(models.Trace.project_rowid == project_id)
+        async with info.context.db() as session:
+            await session.execute(delete_statement)
+        if cache := info.context.cache_for_dataloaders:
+            cache.invalidate(ClearProjectSpansEvent(project_rowid=project_id))
+        return Query()
 
-from phoenix.server.api.mutations import Mutation
-from phoenix.server.api.queries import Query
 
 # This is the schema for generating `schema.graphql`.
 # See https://strawberry.rocks/docs/guides/schema-export
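Query and Mutation now live in schema.py itself rather than being imported from the deleted queries.py and mutations package. Per the retained schema-export comment, they are assembled into the executable schema in the standard Strawberry way; a minimal sketch, assuming the two classes above are in scope:

# Minimal sketch: assembling the executable schema from the Query and
# Mutation types defined above.
import strawberry

schema = strawberry.Schema(query=Query, mutation=Mutation)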
phoenix/server/api/types/Cluster.py
CHANGED

@@ -9,9 +9,9 @@ from phoenix.core.model_schema import PRIMARY, REFERENCE
 from phoenix.server.api.context import Context
 from phoenix.server.api.input_types.DataQualityMetricInput import DataQualityMetricInput
 from phoenix.server.api.input_types.PerformanceMetricInput import PerformanceMetricInput
+from phoenix.server.api.types.DatasetRole import AncillaryDatasetRole, DatasetRole
 from phoenix.server.api.types.DatasetValues import DatasetValues
 from phoenix.server.api.types.Event import unpack_event_id
-from phoenix.server.api.types.InferencesRole import AncillaryInferencesRole, InferencesRole
 
 
 @strawberry.type
@@ -36,8 +36,8 @@ class Cluster:
         """
         Calculates the drift score of the cluster. The score will be a value
         representing the balance of points between the primary and the reference
-        inferences, and will be on a scale between 1 (all primary) and -1 (all
-        reference), with 0 being an even balance between the two
+        datasets, and will be on a scale between 1 (all primary) and -1 (all
+        reference), with 0 being an even balance between the two datasets.
 
         Returns
         -------
@@ -47,8 +47,8 @@ class Cluster:
         if model[REFERENCE].empty:
             return None
         count_by_role = Counter(unpack_event_id(event_id)[1] for event_id in self.event_ids)
-        primary_count = count_by_role[InferencesRole.primary]
-        reference_count = count_by_role[InferencesRole.reference]
+        primary_count = count_by_role[DatasetRole.primary]
+        reference_count = count_by_role[DatasetRole.reference]
         return (
             None
             if not (denominator := (primary_count + reference_count))
@@ -76,8 +76,8 @@ class Cluster:
         if corpus is None or corpus[PRIMARY].empty:
             return None
         count_by_role = Counter(unpack_event_id(event_id)[1] for event_id in self.event_ids)
-        primary_count = count_by_role[InferencesRole.primary]
-        corpus_count = count_by_role[AncillaryInferencesRole.corpus]
+        primary_count = count_by_role[DatasetRole.primary]
+        corpus_count = count_by_role[AncillaryDatasetRole.corpus]
         return (
             None
             if not (denominator := (primary_count + corpus_count))
@@ -94,19 +94,19 @@ class Cluster:
         metric: DataQualityMetricInput,
     ) -> DatasetValues:
         model = info.context.model
-        row_ids: Dict[InferencesRole, List[int]] = defaultdict(list)
-        for row_id, inferences_role in map(unpack_event_id, self.event_ids):
-            if not isinstance(inferences_role, InferencesRole):
+        row_ids: Dict[DatasetRole, List[int]] = defaultdict(list)
+        for row_id, dataset_role in map(unpack_event_id, self.event_ids):
+            if not isinstance(dataset_role, DatasetRole):
                 continue
-            row_ids[inferences_role].append(row_id)
+            row_ids[dataset_role].append(row_id)
         return DatasetValues(
             primary_value=metric.metric_instance(
                 model[PRIMARY],
-                subset_rows=row_ids[InferencesRole.primary],
+                subset_rows=row_ids[DatasetRole.primary],
             ),
             reference_value=metric.metric_instance(
                 model[REFERENCE],
-                subset_rows=row_ids[InferencesRole.reference],
+                subset_rows=row_ids[DatasetRole.reference],
             ),
         )
 
@@ -120,20 +120,20 @@ class Cluster:
         metric: PerformanceMetricInput,
     ) -> DatasetValues:
         model = info.context.model
-        row_ids: Dict[InferencesRole, List[int]] = defaultdict(list)
-        for row_id, inferences_role in map(unpack_event_id, self.event_ids):
-            if not isinstance(inferences_role, InferencesRole):
+        row_ids: Dict[DatasetRole, List[int]] = defaultdict(list)
+        for row_id, dataset_role in map(unpack_event_id, self.event_ids):
+            if not isinstance(dataset_role, DatasetRole):
                 continue
-            row_ids[inferences_role].append(row_id)
+            row_ids[dataset_role].append(row_id)
         metric_instance = metric.metric_instance(model)
         return DatasetValues(
             primary_value=metric_instance(
                 model[PRIMARY],
-                subset_rows=row_ids[InferencesRole.primary],
+                subset_rows=row_ids[DatasetRole.primary],
            ),
             reference_value=metric_instance(
                 model[REFERENCE],
-                subset_rows=row_ids[InferencesRole.reference],
+                subset_rows=row_ids[DatasetRole.reference],
             ),
         )
 
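The docstring above pins the drift score to a scale from 1 (all primary) to -1 (all reference), with 0 an even balance, and the surrounding hunk computes it from the two role counts. A worked sketch of the implied formula (inferred from the docstring, not copied from the source):

# Worked example of the drift score described in the docstring: the
# normalized balance of cluster points between primary and reference.
def drift_score(primary_count: int, reference_count: int) -> float:
    denominator = primary_count + reference_count
    return (primary_count - reference_count) / denominator

assert drift_score(10, 0) == 1.0   # all primary
assert drift_score(0, 10) == -1.0  # all reference
assert drift_score(5, 5) == 0.0    # even balance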