arize-phoenix 4.12.1rc1__py3-none-any.whl → 4.15.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of arize-phoenix might be problematic. Click here for more details.
- {arize_phoenix-4.12.1rc1.dist-info → arize_phoenix-4.15.0.dist-info}/METADATA +10 -6
- {arize_phoenix-4.12.1rc1.dist-info → arize_phoenix-4.15.0.dist-info}/RECORD +70 -68
- phoenix/db/bulk_inserter.py +5 -4
- phoenix/db/engines.py +2 -1
- phoenix/experiments/evaluators/base.py +4 -0
- phoenix/experiments/evaluators/code_evaluators.py +80 -0
- phoenix/experiments/evaluators/llm_evaluators.py +77 -1
- phoenix/experiments/evaluators/utils.py +70 -21
- phoenix/experiments/functions.py +17 -16
- phoenix/server/api/context.py +5 -3
- phoenix/server/api/dataloaders/__init__.py +2 -0
- phoenix/server/api/dataloaders/average_experiment_run_latency.py +25 -25
- phoenix/server/api/dataloaders/dataset_example_revisions.py +2 -4
- phoenix/server/api/dataloaders/dataset_example_spans.py +2 -4
- phoenix/server/api/dataloaders/document_evaluation_summaries.py +2 -4
- phoenix/server/api/dataloaders/document_evaluations.py +2 -4
- phoenix/server/api/dataloaders/document_retrieval_metrics.py +2 -4
- phoenix/server/api/dataloaders/evaluation_summaries.py +2 -4
- phoenix/server/api/dataloaders/experiment_annotation_summaries.py +2 -4
- phoenix/server/api/dataloaders/experiment_error_rates.py +32 -14
- phoenix/server/api/dataloaders/experiment_run_counts.py +20 -9
- phoenix/server/api/dataloaders/experiment_sequence_number.py +2 -4
- phoenix/server/api/dataloaders/latency_ms_quantile.py +2 -3
- phoenix/server/api/dataloaders/min_start_or_max_end_times.py +2 -4
- phoenix/server/api/dataloaders/project_by_name.py +3 -3
- phoenix/server/api/dataloaders/record_counts.py +2 -4
- phoenix/server/api/dataloaders/span_annotations.py +2 -4
- phoenix/server/api/dataloaders/span_dataset_examples.py +36 -0
- phoenix/server/api/dataloaders/span_descendants.py +2 -4
- phoenix/server/api/dataloaders/span_evaluations.py +2 -4
- phoenix/server/api/dataloaders/span_projects.py +3 -3
- phoenix/server/api/dataloaders/token_counts.py +2 -4
- phoenix/server/api/dataloaders/trace_evaluations.py +2 -4
- phoenix/server/api/dataloaders/trace_row_ids.py +2 -4
- phoenix/server/api/input_types/{CreateSpanAnnotationsInput.py → CreateSpanAnnotationInput.py} +4 -2
- phoenix/server/api/input_types/{CreateTraceAnnotationsInput.py → CreateTraceAnnotationInput.py} +4 -2
- phoenix/server/api/input_types/{PatchAnnotationsInput.py → PatchAnnotationInput.py} +4 -2
- phoenix/server/api/mutations/span_annotations_mutations.py +20 -9
- phoenix/server/api/mutations/trace_annotations_mutations.py +20 -9
- phoenix/server/api/routers/v1/datasets.py +132 -10
- phoenix/server/api/routers/v1/evaluations.py +3 -5
- phoenix/server/api/routers/v1/experiments.py +1 -1
- phoenix/server/api/types/Experiment.py +2 -2
- phoenix/server/api/types/Inferences.py +1 -2
- phoenix/server/api/types/Model.py +1 -2
- phoenix/server/api/types/Span.py +5 -0
- phoenix/server/api/utils.py +4 -4
- phoenix/server/app.py +21 -18
- phoenix/server/grpc_server.py +2 -2
- phoenix/server/main.py +5 -9
- phoenix/server/static/.vite/manifest.json +31 -31
- phoenix/server/static/assets/{components-C8sm_r1F.js → components-kGgeFkHp.js} +150 -110
- phoenix/server/static/assets/index-BctFO6S7.js +100 -0
- phoenix/server/static/assets/{pages-bN7juCjh.js → pages-DabDCmVd.js} +432 -255
- phoenix/server/static/assets/{vendor-CUDAPm8e.js → vendor-CP0b0YG0.js} +2 -2
- phoenix/server/static/assets/{vendor-arizeai-Do2HOmcL.js → vendor-arizeai-B5Hti8OB.js} +27 -27
- phoenix/server/static/assets/vendor-codemirror-DtdPDzrv.js +15 -0
- phoenix/server/static/assets/{vendor-recharts-PKRvByVe.js → vendor-recharts-A0DA1O99.js} +1 -1
- phoenix/server/types.py +18 -0
- phoenix/session/client.py +9 -6
- phoenix/session/session.py +2 -2
- phoenix/trace/dsl/filter.py +40 -25
- phoenix/trace/fixtures.py +17 -23
- phoenix/trace/utils.py +23 -0
- phoenix/utilities/client.py +116 -0
- phoenix/utilities/project.py +1 -1
- phoenix/version.py +1 -1
- phoenix/server/api/routers/v1/dataset_examples.py +0 -157
- phoenix/server/static/assets/index-BEKPzgQs.js +0 -100
- phoenix/server/static/assets/vendor-codemirror-CrdxOlMs.js +0 -12
- {arize_phoenix-4.12.1rc1.dist-info → arize_phoenix-4.15.0.dist-info}/WHEEL +0 -0
- {arize_phoenix-4.12.1rc1.dist-info → arize_phoenix-4.15.0.dist-info}/licenses/IP_NOTICE +0 -0
- {arize_phoenix-4.12.1rc1.dist-info → arize_phoenix-4.15.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -7,10 +7,11 @@ from strawberry.types import Info
|
|
|
7
7
|
|
|
8
8
|
from phoenix.db import models
|
|
9
9
|
from phoenix.server.api.context import Context
|
|
10
|
-
from phoenix.server.api.input_types.
|
|
10
|
+
from phoenix.server.api.input_types.CreateTraceAnnotationInput import CreateTraceAnnotationInput
|
|
11
11
|
from phoenix.server.api.input_types.DeleteAnnotationsInput import DeleteAnnotationsInput
|
|
12
|
-
from phoenix.server.api.input_types.
|
|
12
|
+
from phoenix.server.api.input_types.PatchAnnotationInput import PatchAnnotationInput
|
|
13
13
|
from phoenix.server.api.mutations.auth import IsAuthenticated
|
|
14
|
+
from phoenix.server.api.queries import Query
|
|
14
15
|
from phoenix.server.api.types.node import from_global_id_with_expected_type
|
|
15
16
|
from phoenix.server.api.types.TraceAnnotation import TraceAnnotation, to_gql_trace_annotation
|
|
16
17
|
|
|
@@ -18,13 +19,14 @@ from phoenix.server.api.types.TraceAnnotation import TraceAnnotation, to_gql_tra
|
|
|
18
19
|
@strawberry.type
|
|
19
20
|
class TraceAnnotationMutationPayload:
|
|
20
21
|
trace_annotations: List[TraceAnnotation]
|
|
22
|
+
query: Query
|
|
21
23
|
|
|
22
24
|
|
|
23
25
|
@strawberry.type
|
|
24
26
|
class TraceAnnotationMutationMixin:
|
|
25
27
|
@strawberry.mutation(permission_classes=[IsAuthenticated]) # type: ignore
|
|
26
28
|
async def create_trace_annotations(
|
|
27
|
-
self, info: Info[Context, None], input: List[
|
|
29
|
+
self, info: Info[Context, None], input: List[CreateTraceAnnotationInput]
|
|
28
30
|
) -> TraceAnnotationMutationPayload:
|
|
29
31
|
inserted_annotations: Sequence[models.TraceAnnotation] = []
|
|
30
32
|
async with info.context.db() as session:
|
|
@@ -35,7 +37,7 @@ class TraceAnnotationMutationMixin:
|
|
|
35
37
|
label=annotation.label,
|
|
36
38
|
score=annotation.score,
|
|
37
39
|
explanation=annotation.explanation,
|
|
38
|
-
annotator_kind=annotation.annotator_kind,
|
|
40
|
+
annotator_kind=annotation.annotator_kind.value,
|
|
39
41
|
metadata_=annotation.metadata,
|
|
40
42
|
)
|
|
41
43
|
for annotation in input
|
|
@@ -49,12 +51,13 @@ class TraceAnnotationMutationMixin:
|
|
|
49
51
|
return TraceAnnotationMutationPayload(
|
|
50
52
|
trace_annotations=[
|
|
51
53
|
to_gql_trace_annotation(annotation) for annotation in inserted_annotations
|
|
52
|
-
]
|
|
54
|
+
],
|
|
55
|
+
query=Query(),
|
|
53
56
|
)
|
|
54
57
|
|
|
55
58
|
@strawberry.mutation(permission_classes=[IsAuthenticated]) # type: ignore
|
|
56
59
|
async def patch_trace_annotations(
|
|
57
|
-
self, info: Info[Context, None], input: List[
|
|
60
|
+
self, info: Info[Context, None], input: List[PatchAnnotationInput]
|
|
58
61
|
) -> TraceAnnotationMutationPayload:
|
|
59
62
|
patched_annotations = []
|
|
60
63
|
async with info.context.db() as session:
|
|
@@ -66,7 +69,13 @@ class TraceAnnotationMutationMixin:
|
|
|
66
69
|
column.key: patch_value
|
|
67
70
|
for column, patch_value, column_is_nullable in (
|
|
68
71
|
(models.TraceAnnotation.name, annotation.name, False),
|
|
69
|
-
(
|
|
72
|
+
(
|
|
73
|
+
models.TraceAnnotation.annotator_kind,
|
|
74
|
+
annotation.annotator_kind.value
|
|
75
|
+
if annotation.annotator_kind is not None
|
|
76
|
+
else None,
|
|
77
|
+
False,
|
|
78
|
+
),
|
|
70
79
|
(models.TraceAnnotation.label, annotation.label, True),
|
|
71
80
|
(models.TraceAnnotation.score, annotation.score, True),
|
|
72
81
|
(models.TraceAnnotation.explanation, annotation.explanation, True),
|
|
@@ -83,7 +92,7 @@ class TraceAnnotationMutationMixin:
|
|
|
83
92
|
if trace_annotation:
|
|
84
93
|
patched_annotations.append(to_gql_trace_annotation(trace_annotation))
|
|
85
94
|
|
|
86
|
-
return TraceAnnotationMutationPayload(trace_annotations=patched_annotations)
|
|
95
|
+
return TraceAnnotationMutationPayload(trace_annotations=patched_annotations, query=Query())
|
|
87
96
|
|
|
88
97
|
@strawberry.mutation(permission_classes=[IsAuthenticated]) # type: ignore
|
|
89
98
|
async def delete_trace_annotations(
|
|
@@ -105,4 +114,6 @@ class TraceAnnotationMutationMixin:
|
|
|
105
114
|
deleted_annotations_gql = [
|
|
106
115
|
to_gql_trace_annotation(annotation) for annotation in deleted_annotations
|
|
107
116
|
]
|
|
108
|
-
return TraceAnnotationMutationPayload(
|
|
117
|
+
return TraceAnnotationMutationPayload(
|
|
118
|
+
trace_annotations=deleted_annotations_gql, query=Query()
|
|
119
|
+
)
|
|
@@ -56,12 +56,11 @@ from phoenix.db.insertion.dataset import (
|
|
|
56
56
|
add_dataset_examples,
|
|
57
57
|
)
|
|
58
58
|
from phoenix.server.api.types.Dataset import Dataset as DatasetNodeType
|
|
59
|
-
from phoenix.server.api.types.DatasetExample import DatasetExample
|
|
59
|
+
from phoenix.server.api.types.DatasetExample import DatasetExample as DatasetExampleNodeType
|
|
60
60
|
from phoenix.server.api.types.DatasetVersion import DatasetVersion as DatasetVersionNodeType
|
|
61
61
|
from phoenix.server.api.types.node import from_global_id_with_expected_type
|
|
62
62
|
from phoenix.server.api.utils import delete_projects, delete_traces
|
|
63
63
|
|
|
64
|
-
from .dataset_examples import router as dataset_examples_router
|
|
65
64
|
from .pydantic_compat import V1RoutesBaseModel
|
|
66
65
|
from .utils import (
|
|
67
66
|
PaginatedResponseBody,
|
|
@@ -122,7 +121,7 @@ async def list_datasets(
|
|
|
122
121
|
status_code=HTTP_422_UNPROCESSABLE_ENTITY,
|
|
123
122
|
)
|
|
124
123
|
if name:
|
|
125
|
-
query = query.filter(models.Dataset.name
|
|
124
|
+
query = query.filter(models.Dataset.name == name)
|
|
126
125
|
|
|
127
126
|
query = query.limit(limit + 1)
|
|
128
127
|
result = await session.execute(query)
|
|
@@ -669,12 +668,135 @@ async def _parse_form_data(
|
|
|
669
668
|
)
|
|
670
669
|
|
|
671
670
|
|
|
672
|
-
|
|
673
|
-
|
|
674
|
-
|
|
675
|
-
|
|
676
|
-
|
|
677
|
-
|
|
671
|
+
class DatasetExample(V1RoutesBaseModel):
|
|
672
|
+
id: str
|
|
673
|
+
input: Dict[str, Any]
|
|
674
|
+
output: Dict[str, Any]
|
|
675
|
+
metadata: Dict[str, Any]
|
|
676
|
+
updated_at: datetime
|
|
677
|
+
|
|
678
|
+
|
|
679
|
+
class ListDatasetExamplesData(V1RoutesBaseModel):
|
|
680
|
+
dataset_id: str
|
|
681
|
+
version_id: str
|
|
682
|
+
examples: List[DatasetExample]
|
|
683
|
+
|
|
684
|
+
|
|
685
|
+
class ListDatasetExamplesResponseBody(ResponseBody[ListDatasetExamplesData]):
|
|
686
|
+
pass
|
|
687
|
+
|
|
688
|
+
|
|
689
|
+
@router.get(
|
|
690
|
+
"/datasets/{id}/examples",
|
|
691
|
+
operation_id="getDatasetExamples",
|
|
692
|
+
summary="Get examples from a dataset",
|
|
693
|
+
responses=add_errors_to_responses([HTTP_404_NOT_FOUND]),
|
|
694
|
+
)
|
|
695
|
+
async def get_dataset_examples(
|
|
696
|
+
request: Request,
|
|
697
|
+
id: str = Path(description="The ID of the dataset"),
|
|
698
|
+
version_id: Optional[str] = Query(
|
|
699
|
+
default=None,
|
|
700
|
+
description=(
|
|
701
|
+
"The ID of the dataset version " "(if omitted, returns data from the latest version)"
|
|
702
|
+
),
|
|
703
|
+
),
|
|
704
|
+
) -> ListDatasetExamplesResponseBody:
|
|
705
|
+
dataset_gid = GlobalID.from_id(id)
|
|
706
|
+
version_gid = GlobalID.from_id(version_id) if version_id else None
|
|
707
|
+
|
|
708
|
+
if (dataset_type := dataset_gid.type_name) != "Dataset":
|
|
709
|
+
raise HTTPException(
|
|
710
|
+
detail=f"ID {dataset_gid} refers to a {dataset_type}", status_code=HTTP_404_NOT_FOUND
|
|
711
|
+
)
|
|
712
|
+
|
|
713
|
+
if version_gid and (version_type := version_gid.type_name) != "DatasetVersion":
|
|
714
|
+
raise HTTPException(
|
|
715
|
+
detail=f"ID {version_gid} refers to a {version_type}", status_code=HTTP_404_NOT_FOUND
|
|
716
|
+
)
|
|
717
|
+
|
|
718
|
+
async with request.app.state.db() as session:
|
|
719
|
+
if (
|
|
720
|
+
resolved_dataset_id := await session.scalar(
|
|
721
|
+
select(models.Dataset.id).where(models.Dataset.id == int(dataset_gid.node_id))
|
|
722
|
+
)
|
|
723
|
+
) is None:
|
|
724
|
+
raise HTTPException(
|
|
725
|
+
detail=f"No dataset with id {dataset_gid} can be found.",
|
|
726
|
+
status_code=HTTP_404_NOT_FOUND,
|
|
727
|
+
)
|
|
728
|
+
|
|
729
|
+
# Subquery to find the most recent revision (maximum revision id) for each dataset_example_id
|
|
730
|
+
# timestamp tiebreaks are resolved by the largest id
|
|
731
|
+
partial_subquery = select(
|
|
732
|
+
func.max(models.DatasetExampleRevision.id).label("max_id"),
|
|
733
|
+
).group_by(models.DatasetExampleRevision.dataset_example_id)
|
|
734
|
+
|
|
735
|
+
if version_gid:
|
|
736
|
+
if (
|
|
737
|
+
resolved_version_id := await session.scalar(
|
|
738
|
+
select(models.DatasetVersion.id).where(
|
|
739
|
+
and_(
|
|
740
|
+
models.DatasetVersion.dataset_id == resolved_dataset_id,
|
|
741
|
+
models.DatasetVersion.id == int(version_gid.node_id),
|
|
742
|
+
)
|
|
743
|
+
)
|
|
744
|
+
)
|
|
745
|
+
) is None:
|
|
746
|
+
raise HTTPException(
|
|
747
|
+
detail=f"No dataset version with id {version_id} can be found.",
|
|
748
|
+
status_code=HTTP_404_NOT_FOUND,
|
|
749
|
+
)
|
|
750
|
+
# if a version_id is provided, filter the subquery to only include revisions from that version or earlier
|
|
751
|
+
partial_subquery = partial_subquery.filter(
|
|
752
|
+
models.DatasetExampleRevision.dataset_version_id <= resolved_version_id
|
|
753
|
+
)
|
|
754
|
+
else:
|
|
755
|
+
if (
|
|
756
|
+
resolved_version_id := await session.scalar(
|
|
757
|
+
select(func.max(models.DatasetVersion.id)).where(
|
|
758
|
+
models.DatasetVersion.dataset_id == resolved_dataset_id
|
|
759
|
+
)
|
|
760
|
+
)
|
|
761
|
+
) is None:
|
|
762
|
+
raise HTTPException(
|
|
763
|
+
detail="Dataset has no versions.",
|
|
764
|
+
status_code=HTTP_404_NOT_FOUND,
|
|
765
|
+
)
|
|
766
|
+
|
|
767
|
+
subquery = partial_subquery.subquery()
|
|
768
|
+
# Query for the most recent example revisions that are not deleted
|
|
769
|
+
query = (
|
|
770
|
+
select(models.DatasetExample, models.DatasetExampleRevision)
|
|
771
|
+
.join(
|
|
772
|
+
models.DatasetExampleRevision,
|
|
773
|
+
models.DatasetExample.id == models.DatasetExampleRevision.dataset_example_id,
|
|
774
|
+
)
|
|
775
|
+
.join(
|
|
776
|
+
subquery,
|
|
777
|
+
(subquery.c.max_id == models.DatasetExampleRevision.id),
|
|
778
|
+
)
|
|
779
|
+
.filter(models.DatasetExample.dataset_id == resolved_dataset_id)
|
|
780
|
+
.filter(models.DatasetExampleRevision.revision_kind != "DELETE")
|
|
781
|
+
.order_by(models.DatasetExample.id.asc())
|
|
782
|
+
)
|
|
783
|
+
examples = [
|
|
784
|
+
DatasetExample(
|
|
785
|
+
id=str(GlobalID("DatasetExample", str(example.id))),
|
|
786
|
+
input=revision.input,
|
|
787
|
+
output=revision.output,
|
|
788
|
+
metadata=revision.metadata_,
|
|
789
|
+
updated_at=revision.created_at,
|
|
790
|
+
)
|
|
791
|
+
async for example, revision in await session.stream(query)
|
|
792
|
+
]
|
|
793
|
+
return ListDatasetExamplesResponseBody(
|
|
794
|
+
data=ListDatasetExamplesData(
|
|
795
|
+
dataset_id=str(GlobalID("Dataset", str(resolved_dataset_id))),
|
|
796
|
+
version_id=str(GlobalID("DatasetVersion", str(resolved_version_id))),
|
|
797
|
+
examples=examples,
|
|
798
|
+
)
|
|
799
|
+
)
|
|
678
800
|
|
|
679
801
|
|
|
680
802
|
@router.get(
|
|
@@ -794,7 +916,7 @@ def _get_content_csv(examples: List[models.DatasetExampleRevision]) -> bytes:
|
|
|
794
916
|
records = [
|
|
795
917
|
{
|
|
796
918
|
"example_id": GlobalID(
|
|
797
|
-
type_name=
|
|
919
|
+
type_name=DatasetExampleNodeType.__name__,
|
|
798
920
|
node_id=str(ex.dataset_example_id),
|
|
799
921
|
),
|
|
800
922
|
**{f"input_{k}": v for k, v in ex.input.items()},
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
import gzip
|
|
2
2
|
from itertools import chain
|
|
3
|
-
from typing import
|
|
3
|
+
from typing import Iterator, Optional, Tuple
|
|
4
4
|
|
|
5
5
|
import pandas as pd
|
|
6
6
|
import pyarrow as pa
|
|
@@ -9,9 +9,6 @@ from google.protobuf.message import DecodeError
|
|
|
9
9
|
from pandas import DataFrame
|
|
10
10
|
from sqlalchemy import select
|
|
11
11
|
from sqlalchemy.engine import Connectable
|
|
12
|
-
from sqlalchemy.ext.asyncio import (
|
|
13
|
-
AsyncSession,
|
|
14
|
-
)
|
|
15
12
|
from starlette.background import BackgroundTask
|
|
16
13
|
from starlette.datastructures import State
|
|
17
14
|
from starlette.requests import Request
|
|
@@ -29,6 +26,7 @@ from phoenix.config import DEFAULT_PROJECT_NAME
|
|
|
29
26
|
from phoenix.db import models
|
|
30
27
|
from phoenix.exceptions import PhoenixEvaluationNameIsMissing
|
|
31
28
|
from phoenix.server.api.routers.utils import table_to_bytes
|
|
29
|
+
from phoenix.server.types import DbSessionFactory
|
|
32
30
|
from phoenix.session.evaluation import encode_evaluations
|
|
33
31
|
from phoenix.trace.span_evaluations import (
|
|
34
32
|
DocumentEvaluations,
|
|
@@ -128,7 +126,7 @@ async def get_evaluations(
|
|
|
128
126
|
or DEFAULT_PROJECT_NAME
|
|
129
127
|
)
|
|
130
128
|
|
|
131
|
-
db:
|
|
129
|
+
db: DbSessionFactory = request.app.state.db
|
|
132
130
|
async with db() as session:
|
|
133
131
|
connection = await session.connection()
|
|
134
132
|
trace_evals_dataframe = await connection.run_sync(
|
|
@@ -110,7 +110,7 @@ async def create_experiment(
|
|
|
110
110
|
)
|
|
111
111
|
except ValueError:
|
|
112
112
|
raise HTTPException(
|
|
113
|
-
detail="DatasetVersion with ID {
|
|
113
|
+
detail=f"DatasetVersion with ID {dataset_version_globalid_str} does not exist",
|
|
114
114
|
status_code=HTTP_404_NOT_FOUND,
|
|
115
115
|
)
|
|
116
116
|
|
|
@@ -104,11 +104,11 @@ class Experiment(Node):
|
|
|
104
104
|
return await info.context.data_loaders.experiment_error_rates.load(self.id_attr)
|
|
105
105
|
|
|
106
106
|
@strawberry.field
|
|
107
|
-
async def average_run_latency_ms(self, info: Info[Context, None]) -> float:
|
|
107
|
+
async def average_run_latency_ms(self, info: Info[Context, None]) -> Optional[float]:
|
|
108
108
|
latency_seconds = await info.context.data_loaders.average_experiment_run_latency.load(
|
|
109
109
|
self.id_attr
|
|
110
110
|
)
|
|
111
|
-
return latency_seconds * 1000
|
|
111
|
+
return latency_seconds * 1000 if latency_seconds is not None else None
|
|
112
112
|
|
|
113
113
|
@strawberry.field
|
|
114
114
|
async def project(self, info: Info[Context, None]) -> Optional[Project]:
|
|
@@ -2,8 +2,7 @@ from datetime import datetime
|
|
|
2
2
|
from typing import Iterable, List, Optional, Set, Union
|
|
3
3
|
|
|
4
4
|
import strawberry
|
|
5
|
-
from strawberry
|
|
6
|
-
from strawberry.unset import UNSET
|
|
5
|
+
from strawberry import ID, UNSET
|
|
7
6
|
|
|
8
7
|
import phoenix.core.model_schema as ms
|
|
9
8
|
from phoenix.core.model_schema import FEATURE, TAG, ScalarDimension
|
|
@@ -2,9 +2,8 @@ import asyncio
|
|
|
2
2
|
from typing import List, Optional
|
|
3
3
|
|
|
4
4
|
import strawberry
|
|
5
|
+
from strawberry import UNSET, Info
|
|
5
6
|
from strawberry.relay import Connection
|
|
6
|
-
from strawberry.types import Info
|
|
7
|
-
from strawberry.unset import UNSET
|
|
8
7
|
from typing_extensions import Annotated
|
|
9
8
|
|
|
10
9
|
from phoenix.config import get_exported_files
|
phoenix/server/api/types/Span.py
CHANGED
|
@@ -258,6 +258,11 @@ class Span(Node):
|
|
|
258
258
|
project = await info.context.data_loaders.span_projects.load(span_id)
|
|
259
259
|
return to_gql_project(project)
|
|
260
260
|
|
|
261
|
+
@strawberry.field(description="Indicates if the span is contained in any dataset") # type: ignore
|
|
262
|
+
async def contained_in_dataset(self, info: Info[Context, None]) -> bool:
|
|
263
|
+
examples = await info.context.data_loaders.span_dataset_examples.load(self.id_attr)
|
|
264
|
+
return bool(examples)
|
|
265
|
+
|
|
261
266
|
|
|
262
267
|
def to_gql_span(span: models.Span) -> Span:
|
|
263
268
|
events: List[SpanEvent] = list(map(SpanEvent.from_dict, span.events))
|
phoenix/server/api/utils.py
CHANGED
|
@@ -1,13 +1,13 @@
|
|
|
1
|
-
from typing import
|
|
1
|
+
from typing import List
|
|
2
2
|
|
|
3
3
|
from sqlalchemy import delete
|
|
4
|
-
from sqlalchemy.ext.asyncio import AsyncSession
|
|
5
4
|
|
|
6
5
|
from phoenix.db import models
|
|
6
|
+
from phoenix.server.types import DbSessionFactory
|
|
7
7
|
|
|
8
8
|
|
|
9
9
|
async def delete_projects(
|
|
10
|
-
db:
|
|
10
|
+
db: DbSessionFactory,
|
|
11
11
|
*project_names: str,
|
|
12
12
|
) -> List[int]:
|
|
13
13
|
if not project_names:
|
|
@@ -22,7 +22,7 @@ async def delete_projects(
|
|
|
22
22
|
|
|
23
23
|
|
|
24
24
|
async def delete_traces(
|
|
25
|
-
db:
|
|
25
|
+
db: DbSessionFactory,
|
|
26
26
|
*trace_ids: str,
|
|
27
27
|
) -> List[int]:
|
|
28
28
|
if not trace_ids:
|
phoenix/server/app.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
import asyncio
|
|
1
2
|
import contextlib
|
|
2
3
|
import json
|
|
3
4
|
import logging
|
|
@@ -74,6 +75,7 @@ from phoenix.server.api.dataloaders import (
|
|
|
74
75
|
ProjectByNameDataLoader,
|
|
75
76
|
RecordCountDataLoader,
|
|
76
77
|
SpanAnnotationsDataLoader,
|
|
78
|
+
SpanDatasetExamplesDataLoader,
|
|
77
79
|
SpanDescendantsDataLoader,
|
|
78
80
|
SpanEvaluationsDataLoader,
|
|
79
81
|
SpanProjectsDataLoader,
|
|
@@ -86,7 +88,9 @@ from phoenix.server.api.routers.v1 import router as v1_router
|
|
|
86
88
|
from phoenix.server.api.schema import schema
|
|
87
89
|
from phoenix.server.grpc_server import GrpcServer
|
|
88
90
|
from phoenix.server.telemetry import initialize_opentelemetry_tracer_provider
|
|
91
|
+
from phoenix.server.types import DbSessionFactory
|
|
89
92
|
from phoenix.trace.schemas import Span
|
|
93
|
+
from phoenix.utilities.client import PHOENIX_SERVER_VERSION_HEADER
|
|
90
94
|
|
|
91
95
|
if TYPE_CHECKING:
|
|
92
96
|
from opentelemetry.trace import TracerProvider
|
|
@@ -167,9 +171,11 @@ class HeadersMiddleware(BaseHTTPMiddleware):
|
|
|
167
171
|
request: Request,
|
|
168
172
|
call_next: RequestResponseEndpoint,
|
|
169
173
|
) -> Response:
|
|
174
|
+
from phoenix import __version__ as phoenix_version
|
|
175
|
+
|
|
170
176
|
response = await call_next(request)
|
|
171
177
|
response.headers["x-colab-notebook-cache-control"] = "no-cache"
|
|
172
|
-
response.headers[
|
|
178
|
+
response.headers[PHOENIX_SERVER_VERSION_HEADER] = phoenix_version
|
|
173
179
|
return response
|
|
174
180
|
|
|
175
181
|
|
|
@@ -193,19 +199,25 @@ async def version() -> PlainTextResponse:
|
|
|
193
199
|
return PlainTextResponse(f"{phoenix.__version__}")
|
|
194
200
|
|
|
195
201
|
|
|
202
|
+
DB_MUTEX: Optional[asyncio.Lock] = None
|
|
203
|
+
|
|
204
|
+
|
|
196
205
|
def _db(engine: AsyncEngine) -> Callable[[], AsyncContextManager[AsyncSession]]:
|
|
197
206
|
Session = async_sessionmaker(engine, expire_on_commit=False)
|
|
198
207
|
|
|
199
208
|
@contextlib.asynccontextmanager
|
|
200
209
|
async def factory() -> AsyncIterator[AsyncSession]:
|
|
201
|
-
async with
|
|
202
|
-
|
|
210
|
+
async with contextlib.AsyncExitStack() as stack:
|
|
211
|
+
if DB_MUTEX:
|
|
212
|
+
await stack.enter_async_context(DB_MUTEX)
|
|
213
|
+
yield await stack.enter_async_context(Session.begin())
|
|
203
214
|
|
|
204
215
|
return factory
|
|
205
216
|
|
|
206
217
|
|
|
207
218
|
def _lifespan(
|
|
208
219
|
*,
|
|
220
|
+
dialect: SupportedSQLDialect,
|
|
209
221
|
bulk_inserter: BulkInserter,
|
|
210
222
|
tracer_provider: Optional["TracerProvider"] = None,
|
|
211
223
|
enable_prometheus: bool = False,
|
|
@@ -214,6 +226,8 @@ def _lifespan(
|
|
|
214
226
|
) -> StatefulLifespan[FastAPI]:
|
|
215
227
|
@contextlib.asynccontextmanager
|
|
216
228
|
async def lifespan(_: FastAPI) -> AsyncIterator[Dict[str, Any]]:
|
|
229
|
+
global DB_MUTEX
|
|
230
|
+
DB_MUTEX = asyncio.Lock() if dialect is SupportedSQLDialect.SQLITE else None
|
|
217
231
|
async with bulk_inserter as (
|
|
218
232
|
queue_span,
|
|
219
233
|
queue_evaluation,
|
|
@@ -243,7 +257,7 @@ async def check_healthz(_: Request) -> PlainTextResponse:
|
|
|
243
257
|
def create_graphql_router(
|
|
244
258
|
*,
|
|
245
259
|
schema: BaseSchema,
|
|
246
|
-
db:
|
|
260
|
+
db: DbSessionFactory,
|
|
247
261
|
model: Model,
|
|
248
262
|
export_path: Path,
|
|
249
263
|
corpus: Optional[Model] = None,
|
|
@@ -297,6 +311,7 @@ def create_graphql_router(
|
|
|
297
311
|
cache_map=cache_for_dataloaders.record_count if cache_for_dataloaders else None,
|
|
298
312
|
),
|
|
299
313
|
span_annotations=SpanAnnotationsDataLoader(db),
|
|
314
|
+
span_dataset_examples=SpanDatasetExamplesDataLoader(db),
|
|
300
315
|
span_descendants=SpanDescendantsDataLoader(db),
|
|
301
316
|
span_evaluations=SpanEvaluationsDataLoader(db),
|
|
302
317
|
span_projects=SpanProjectsDataLoader(db),
|
|
@@ -321,19 +336,6 @@ def create_graphql_router(
|
|
|
321
336
|
)
|
|
322
337
|
|
|
323
338
|
|
|
324
|
-
class SessionFactory:
|
|
325
|
-
def __init__(
|
|
326
|
-
self,
|
|
327
|
-
session_factory: Callable[[], AsyncContextManager[AsyncSession]],
|
|
328
|
-
dialect: str,
|
|
329
|
-
):
|
|
330
|
-
self.session_factory = session_factory
|
|
331
|
-
self.dialect = SupportedSQLDialect(dialect)
|
|
332
|
-
|
|
333
|
-
def __call__(self) -> AsyncContextManager[AsyncSession]:
|
|
334
|
-
return self.session_factory()
|
|
335
|
-
|
|
336
|
-
|
|
337
339
|
def create_engine_and_run_migrations(
|
|
338
340
|
database_url: str,
|
|
339
341
|
) -> AsyncEngine:
|
|
@@ -382,7 +384,7 @@ async def plain_text_http_exception_handler(request: Request, exc: HTTPException
|
|
|
382
384
|
|
|
383
385
|
|
|
384
386
|
def create_app(
|
|
385
|
-
db:
|
|
387
|
+
db: DbSessionFactory,
|
|
386
388
|
export_path: Path,
|
|
387
389
|
model: Model,
|
|
388
390
|
umap_params: UMAPParameters,
|
|
@@ -463,6 +465,7 @@ def create_app(
|
|
|
463
465
|
title="Arize-Phoenix REST API",
|
|
464
466
|
version=REST_API_VERSION,
|
|
465
467
|
lifespan=_lifespan(
|
|
468
|
+
dialect=db.dialect,
|
|
466
469
|
read_only=read_only,
|
|
467
470
|
bulk_inserter=bulk_inserter,
|
|
468
471
|
tracer_provider=tracer_provider,
|
phoenix/server/grpc_server.py
CHANGED
|
@@ -23,7 +23,7 @@ if TYPE_CHECKING:
|
|
|
23
23
|
ProjectName: TypeAlias = str
|
|
24
24
|
|
|
25
25
|
|
|
26
|
-
class Servicer(TraceServiceServicer):
|
|
26
|
+
class Servicer(TraceServiceServicer): # type:ignore
|
|
27
27
|
def __init__(
|
|
28
28
|
self,
|
|
29
29
|
callback: Callable[[Span, ProjectName], Awaitable[None]],
|
|
@@ -78,7 +78,7 @@ class GrpcServer:
|
|
|
78
78
|
interceptors=interceptors,
|
|
79
79
|
)
|
|
80
80
|
server.add_insecure_port(f"[::]:{get_env_grpc_port()}")
|
|
81
|
-
add_TraceServiceServicer_to_server(Servicer(self._callback), server)
|
|
81
|
+
add_TraceServiceServicer_to_server(Servicer(self._callback), server)
|
|
82
82
|
await server.start()
|
|
83
83
|
self._server = server
|
|
84
84
|
|
phoenix/server/main.py
CHANGED
|
@@ -33,25 +33,23 @@ from phoenix.pointcloud.umap_parameters import (
|
|
|
33
33
|
UMAPParameters,
|
|
34
34
|
)
|
|
35
35
|
from phoenix.server.app import (
|
|
36
|
-
SessionFactory,
|
|
37
36
|
_db,
|
|
38
37
|
create_app,
|
|
39
38
|
create_engine_and_run_migrations,
|
|
40
39
|
instrument_engine_if_enabled,
|
|
41
40
|
)
|
|
41
|
+
from phoenix.server.types import DbSessionFactory
|
|
42
42
|
from phoenix.settings import Settings
|
|
43
43
|
from phoenix.trace.fixtures import (
|
|
44
44
|
TRACES_FIXTURES,
|
|
45
|
-
download_traces_fixture,
|
|
46
45
|
get_dataset_fixtures,
|
|
47
46
|
get_evals_from_fixture,
|
|
48
|
-
|
|
47
|
+
load_example_traces,
|
|
49
48
|
reset_fixture_span_ids_and_timestamps,
|
|
50
49
|
send_dataset_fixtures,
|
|
51
50
|
)
|
|
52
51
|
from phoenix.trace.otel import decode_otlp_span, encode_span_to_otlp
|
|
53
52
|
from phoenix.trace.schemas import Span
|
|
54
|
-
from phoenix.trace.span_json_decoder import json_string_to_span
|
|
55
53
|
|
|
56
54
|
logger = logging.getLogger(__name__)
|
|
57
55
|
|
|
@@ -221,10 +219,8 @@ if __name__ == "__main__":
|
|
|
221
219
|
(
|
|
222
220
|
# Apply `encode` here because legacy jsonl files contain UUIDs as strings.
|
|
223
221
|
# `encode` removes the hyphens in the UUIDs.
|
|
224
|
-
decode_otlp_span(encode_span_to_otlp(
|
|
225
|
-
for
|
|
226
|
-
get_trace_fixture_by_name(trace_dataset_name)
|
|
227
|
-
)
|
|
222
|
+
decode_otlp_span(encode_span_to_otlp(span))
|
|
223
|
+
for span in load_example_traces(trace_dataset_name).to_spans()
|
|
228
224
|
),
|
|
229
225
|
get_evals_from_fixture(trace_dataset_name),
|
|
230
226
|
)
|
|
@@ -250,7 +246,7 @@ if __name__ == "__main__":
|
|
|
250
246
|
working_dir = get_working_dir().resolve()
|
|
251
247
|
engine = create_engine_and_run_migrations(db_connection_str)
|
|
252
248
|
instrumentation_cleanups = instrument_engine_if_enabled(engine)
|
|
253
|
-
factory =
|
|
249
|
+
factory = DbSessionFactory(db=_db(engine), dialect=engine.dialect.name)
|
|
254
250
|
app = create_app(
|
|
255
251
|
db=factory,
|
|
256
252
|
export_path=export_path,
|