arize-phoenix 0.0.2rc3__py3-none-any.whl → 0.0.2rc5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of arize-phoenix might be problematic.
- {arize_phoenix-0.0.2rc3.dist-info → arize_phoenix-0.0.2rc5.dist-info}/METADATA +25 -21
- {arize_phoenix-0.0.2rc3.dist-info → arize_phoenix-0.0.2rc5.dist-info}/RECORD +25 -26
- phoenix/__about__.py +1 -1
- phoenix/__init__.py +2 -2
- phoenix/core/embedding_dimension.py +33 -0
- phoenix/datasets/__init__.py +2 -1
- phoenix/datasets/dataset.py +31 -4
- phoenix/{server → datasets}/fixtures.py +47 -10
- phoenix/datasets/validation.py +1 -1
- phoenix/metrics/metrics.py +29 -5
- phoenix/metrics/mixins.py +11 -3
- phoenix/metrics/timeseries.py +11 -7
- phoenix/pointcloud/clustering.py +3 -3
- phoenix/pointcloud/pointcloud.py +9 -7
- phoenix/server/api/input_types/Granularity.py +2 -0
- phoenix/server/api/interceptor.py +28 -0
- phoenix/server/api/types/Dimension.py +23 -33
- phoenix/server/api/types/EmbeddingDimension.py +39 -111
- phoenix/server/api/types/TimeSeries.py +117 -3
- phoenix/server/api/types/UMAPPoints.py +62 -14
- phoenix/server/main.py +3 -3
- phoenix/server/static/index.js +720 -634
- phoenix/session/session.py +48 -6
- phoenix/server/api/types/DataQualityTimeSeries.py +0 -36
- phoenix/server/api/types/DriftTimeSeries.py +0 -10
- {arize_phoenix-0.0.2rc3.dist-info → arize_phoenix-0.0.2rc5.dist-info}/WHEEL +0 -0
- {arize_phoenix-0.0.2rc3.dist-info → arize_phoenix-0.0.2rc5.dist-info}/licenses/LICENSE +0 -0
phoenix/server/api/input_types/Granularity.py

@@ -40,6 +40,8 @@ class Granularity:
 def to_timestamps(
     time_range: TimeRange, granularity: Granularity
 ) -> Generator[datetime, None, None]:
+    if not granularity.sampling_interval_minutes:
+        return
     yield from (
         takewhile(
             lambda t: time_range.start < t,  # type: ignore
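The guard added here means a Granularity whose sampling_interval_minutes is zero now yields an empty timestamp sequence instead of stepping by a zero-length interval. A standalone sketch of the behavior, assuming to_timestamps walks backward from time_range.end in sampling-interval steps (a hypothetical mimic, not the module's full code):

# Hypothetical mimic of to_timestamps; the real function takes TimeRange and
# Granularity inputs, and the backward-stepping logic here is an assumption.
from datetime import datetime, timedelta
from itertools import count, takewhile
from typing import Iterator

def to_timestamps_sketch(start: datetime, end: datetime, sampling_interval_minutes: int) -> Iterator[datetime]:
    if not sampling_interval_minutes:
        return  # the new guard: a zero interval produces no timestamps
    step = timedelta(minutes=sampling_interval_minutes)
    # walk back from `end` while timestamps stay after `start`
    yield from takewhile(lambda t: start < t, (end - i * step for i in count()))

print(list(to_timestamps_sketch(datetime(2023, 1, 1), datetime(2023, 1, 2), 0)))    # []
print(list(to_timestamps_sketch(datetime(2023, 1, 1), datetime(2023, 1, 2), 720)))  # two points, 12 hours apart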
phoenix/server/api/interceptor.py

@@ -0,0 +1,28 @@
+import math
+from abc import ABC, abstractmethod
+from typing import Any
+
+
+class Interceptor(ABC):
+    """an abstract class making use of the descriptor protocol
+    see https://docs.python.org/3/howto/descriptor.html"""
+
+    private_name: str
+
+    def __set_name__(self, owner: Any, name: str) -> None:
+        self.private_name = "_" + name
+
+    def __get__(self, instance: Any, owner: Any) -> Any:
+        return self if instance is None else getattr(instance, self.private_name)
+
+    @abstractmethod
+    def __set__(self, instance: Any, value: Any) -> None:
+        ...
+
+
+class NoneIfNan(Interceptor):
+    """descriptor that converts NaN and Inf to None because NaN can't be
+    serialized to JSON by the graphql object"""
+
+    def __set__(self, instance: Any, value: float) -> None:
+        setattr(instance, self.private_name, value if math.isfinite(value) else None)
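This new module underpins the NaN handling elsewhere in the release: a field backed by NoneIfNan silently stores None whenever a non-finite float is assigned, keeping the GraphQL layer JSON-serializable. A minimal usage sketch on a plain class, assuming the package is importable:

# Minimal sketch: NoneIfNan used as a descriptor on an ordinary class.
from phoenix.server.api.interceptor import NoneIfNan

class Reading:
    value: float = NoneIfNan()  # __set_name__ stores the backing name "_value"

    def __init__(self, value: float) -> None:
        self.value = value  # assignment is routed through NoneIfNan.__set__

print(Reading(1.5).value)           # 1.5
print(Reading(float("nan")).value)  # None -- NaN is not JSON-serializable
print(Reading(float("inf")).value)  # None -- Inf is filtered the same way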
phoenix/server/api/types/Dimension.py

@@ -1,23 +1,18 @@
-import math
-from datetime import timedelta
 from typing import List, Optional

 import strawberry
 from strawberry.types import Info

 from phoenix.core import Dimension as CoreDimension
-from phoenix.metrics.mixins import UnaryOperator
-from phoenix.metrics.timeseries import timeseries
 from phoenix.server.api.context import Context

-from ..input_types.Granularity import Granularity
+from ..input_types.Granularity import Granularity
 from ..input_types.TimeRange import TimeRange
-from . import METRICS
 from .DataQualityMetric import DataQualityMetric
-from .DataQualityTimeSeries import DataQualityTimeSeries, to_gql_timeseries
 from .DimensionDataType import DimensionDataType
 from .DimensionType import DimensionType
 from .node import Node
+from .TimeSeries import DataQualityTimeSeries


 @strawberry.type
@@ -32,15 +27,22 @@ class Dimension(Node):
     )

     @strawberry.field
-    async def
-        self,
+    async def data_quality_metric(
+        self,
+        info: Info[Context, None],
+        metric: DataQualityMetric,
+        time_range: Optional[TimeRange] = None,
     ) -> Optional[float]:
-
-
-
-
-
-
+        if len(
+            data := DataQualityTimeSeries(
+                self.name,
+                info.context.model,
+                metric,
+                time_range,
+            ).data
+        ):
+            return data.pop().value
+        return None

     @strawberry.field(
         description=(
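The rewritten resolver builds a DataQualityTimeSeries with no explicit time_range or granularity, which (per TimeSeries.__init__ further down in this diff) collapses to a single bucket over the whole dataset; the walrus expression then returns that lone point's value, or None for an empty series. A toy sketch of just the control flow:

# Toy sketch of the `if len(data := ...): return data.pop().value` pattern.
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class Point:
    value: Optional[float]

def last_value(points: List[Point]) -> Optional[float]:
    if len(data := list(points)):  # bind and test in one expression
        return data.pop().value    # pop() takes the final (latest) point
    return None

print(last_value([Point(0.2), Point(0.7)]))  # 0.7
print(last_value([]))                        # None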
@@ -70,24 +72,12 @@ class Dimension(Node):
         time_range: TimeRange,
         granularity: Granularity,
     ) -> DataQualityTimeSeries:
-
-
-
-
-
-
-        return dataset.dataframe.pipe(
-            timeseries(
-                start_time=time_range.start,
-                end_time=time_range.end,
-                evaluation_window=timedelta(minutes=granularity.evaluation_window_minutes),
-                sampling_interval=timedelta(minutes=granularity.sampling_interval_minutes),
-            ),
-            metrics=(metric_instance,),
-        ).pipe(
-            to_gql_timeseries,
-            metric=metric_instance,
-            timestamps=to_timestamps(time_range, granularity),
+        return DataQualityTimeSeries(
+            self.name,
+            info.context.model,
+            metric,
+            time_range,
+            granularity,
         )

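data_quality_time_series now simply returns a constructed object whose __init__ does the computation. Because dataclass machinery leaves a user-defined __init__ alone, a strawberry type can populate its own fields this way and be returned directly from a resolver. A minimal self-contained sketch of the pattern with a toy type (not the phoenix schema):

# Toy strawberry type that computes its field in __init__, mirroring how
# DataQualityTimeSeries/DriftTimeSeries instances are returned from resolvers.
from typing import List
import strawberry

@strawberry.type
class Series:
    data: List[float]

    def __init__(self, scale: float) -> None:
        self.data = [scale * i for i in range(3)]

@strawberry.type
class Query:
    @strawberry.field
    def series(self) -> Series:
        return Series(scale=2.0)

schema = strawberry.Schema(query=Query)
print(schema.execute_sync("{ series { data } }").data)  # {'series': {'data': [0.0, 2.0, 4.0]}}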
phoenix/server/api/types/EmbeddingDimension.py

@@ -1,7 +1,7 @@
 from collections import defaultdict
 from datetime import datetime, timedelta
 from itertools import chain
-from typing import Any,
+from typing import Any, Optional

 import numpy as np
 import numpy.typing as npt
@@ -16,26 +16,21 @@ from phoenix.datasets import Dataset
 from phoenix.datasets.dataset import DatasetType
 from phoenix.datasets.errors import SchemaError
 from phoenix.datasets.event import EventId
-from phoenix.metrics.
-from phoenix.metrics.mixins import UnaryOperator
-from phoenix.metrics.timeseries import row_interval_from_sorted_time_index, timeseries
+from phoenix.metrics.timeseries import row_interval_from_sorted_time_index
 from phoenix.pointcloud.clustering import Hdbscan
 from phoenix.pointcloud.pointcloud import PointCloud
 from phoenix.pointcloud.projectors import Umap
 from phoenix.server.api.context import Context
 from phoenix.server.api.input_types.TimeRange import TimeRange

-from ..input_types.Granularity import Granularity
-from . import METRICS
+from ..input_types.Granularity import Granularity
 from .DataQualityMetric import DataQualityMetric
-from .DataQualityTimeSeries import DataQualityTimeSeries, to_gql_timeseries
 from .DriftMetric import DriftMetric
-from .DriftTimeSeries import DriftTimeSeries
 from .EmbeddingMetadata import EmbeddingMetadata
 from .EventMetadata import EventMetadata
 from .node import Node
-from .TimeSeries import
-from .UMAPPoints import
+from .TimeSeries import DataQualityTimeSeries, DriftTimeSeries
+from .UMAPPoints import UMAPPoint, UMAPPoints, to_gql_clusters, to_gql_coordinates

 # Default UMAP hyperparameters
 DEFAULT_N_COMPONENTS = 3
@@ -43,17 +38,6 @@ DEFAULT_MIN_DIST = 0
 DEFAULT_N_NEIGHBORS = 30
 DEFAULT_N_SAMPLES = 500

-
-def to_gql_clusters(clusters: Mapping[EventId, int]) -> List[Cluster]:
-    clusteredEvents = defaultdict(list)
-    for event_id, cluster_id in clusters.items():
-        clusteredEvents[ID(str(cluster_id))].append(ID(str(event_id)))
-    return [
-        Cluster(id=cluster_id, point_ids=event_ids)
-        for cluster_id, event_ids in clusteredEvents.items()
-    ]
-
-
 DRIFT_EVAL_WINDOW_NUM_INTERVALS = 72
 EVAL_INTERVAL_LENGTH = timedelta(hours=1)

@@ -66,7 +50,10 @@ class EmbeddingDimension(Node):

     @strawberry.field
     def drift_metric(
-        self,
+        self,
+        info: Info[Context, None],
+        metric: DriftMetric,
+        time_range: Optional[TimeRange] = None,
     ) -> Optional[float]:
         """
         Computes a drift metric between all reference data and the primary data
@@ -75,29 +62,16 @@
         exists, if no primary data exists in the input time range, or if the
         input time range is invalid.
         """
-
-
-
-
-
-
-
-
-
-
-        primary_embeddings = _get_embeddings_array_for_time_range(
-            dataset=primary_dataset,
-            embedding_feature_name=embedding_feature_name,
-            start=time_range.start,
-            end=time_range.end,
-        )
-        if primary_embeddings is None:
-            return None
-        primary_centroid = _compute_mean_vector(primary_embeddings)
-        reference_centroid = _compute_mean_vector(reference_embeddings)
-        if metric is DriftMetric.euclideanDistance:
-            return euclidean_distance(primary_centroid, reference_centroid)
-        raise NotImplementedError(f'Metric "{metric}" has not been implemented.')
+        if len(
+            data := DriftTimeSeries(
+                str(info.context.model.primary_dataset.get_embedding_vector_column(self.name).name),
+                info.context.model,
+                metric,
+                time_range,
+            ).data
+        ):
+            return data.pop().value
+        return None

     @strawberry.field(
         description=(
@@ -114,37 +88,22 @@
         time_range: TimeRange,
         granularity: Granularity,
     ) -> DataQualityTimeSeries:
-
-
-
-
-
-
-            timeseries(
-                start_time=time_range.start,
-                end_time=time_range.end,
-                evaluation_window=timedelta(minutes=granularity.evaluation_window_minutes),
-                sampling_interval=timedelta(minutes=granularity.sampling_interval_minutes),
-            ),
-            metrics=(metric_instance,),
-        ).pipe(
-            to_gql_timeseries,
-            metric=metric_instance,
-            timestamps=to_timestamps(time_range, granularity),
+        return DataQualityTimeSeries(
+            str(info.context.model.primary_dataset.get_embedding_vector_column(self.name).name),
+            info.context.model,
+            metric,
+            time_range,
+            granularity,
         )

     @strawberry.field
     def drift_time_series(
         self,
-        metric: DriftMetric,
-        time_range: Annotated[
-            TimeRange,
-            strawberry.argument(
-                description="The time range of the primary dataset",
-            ),
-        ],
         info: Info[Context, None],
-
+        metric: DriftMetric,
+        time_range: TimeRange,
+        granularity: Granularity,
+    ) -> DriftTimeSeries:
         """
         Computes a drift time-series between the primary and reference datasets.
         The output drift time-series contains one data point for each whole hour
@@ -156,46 +115,13 @@
         Returns None if no reference dataset exists or if the input time range
         is invalid.
         """
-
-
-
-
-
-
-        reference_embeddings_column = reference_dataset.get_embedding_vector_column(
-            embedding_feature_name
+        return DriftTimeSeries(
+            str(info.context.model.primary_dataset.get_embedding_vector_column(self.name).name),
+            info.context.model,
+            metric,
+            time_range,
+            granularity,
         )
-        reference_embeddings = _to_array(reference_embeddings_column)
-        reference_centroid = _compute_mean_vector(reference_embeddings)
-        time_series_data_points = []
-        if metric is DriftMetric.euclideanDistance:
-            eval_window_end = time_range.start
-            while eval_window_end < time_range.end:
-                eval_window_start = (
-                    eval_window_end - DRIFT_EVAL_WINDOW_NUM_INTERVALS * EVAL_INTERVAL_LENGTH
-                )
-                primary_embeddings = _get_embeddings_array_for_time_range(
-                    dataset=primary_dataset,
-                    embedding_feature_name=embedding_feature_name,
-                    start=eval_window_start,
-                    end=eval_window_end,
-                )
-                distance: Optional[float] = None
-                if primary_embeddings is not None:
-                    primary_centroid = _compute_mean_vector(primary_embeddings)
-                    distance = euclidean_distance(
-                        reference_centroid,
-                        primary_centroid,
-                    )
-                time_series_data_points.append(
-                    TimeSeriesDataPoint(
-                        timestamp=eval_window_end,
-                        value=distance,
-                    )
-                )
-                eval_window_end += EVAL_INTERVAL_LENGTH
-            return DriftTimeSeries(data=time_series_data_points)
-        raise NotImplementedError(f'Metric "{metric}" has not been implemented.')

     @strawberry.field
     def UMAPPoints(
@@ -276,7 +202,7 @@
         min_dist = DEFAULT_MIN_DIST if min_dist is None else min_dist
         n_neighbors = DEFAULT_N_NEIGHBORS if n_neighbors is None else n_neighbors

-        vectors,
+        vectors, cluster_membership = PointCloud(
             dimensionalityReducer=Umap(n_neighbors=n_neighbors, min_dist=min_dist),
             clustersFinder=Hdbscan(),
         ).generate(data, n_components=n_components)
@@ -341,10 +267,12 @@
             )
         )

+        has_reference_data = datasets[DatasetType.REFERENCE] is not None
+
         return UMAPPoints(
             data=points[DatasetType.PRIMARY],
             reference_data=points[DatasetType.REFERENCE],
-            clusters=to_gql_clusters(
+            clusters=to_gql_clusters(cluster_membership, has_reference_data=has_reference_data),
         )

phoenix/server/api/types/TimeSeries.py

@@ -1,9 +1,21 @@
-from datetime import datetime
+from datetime import datetime, timedelta
 from functools import total_ordering
-from typing import List, Optional
+from typing import Iterable, List, Optional, Union, cast

+import pandas as pd
 import strawberry

+from phoenix.core.model import Model
+from phoenix.metrics import Metric
+from phoenix.metrics.mixins import DriftOperator
+from phoenix.metrics.timeseries import timeseries
+from phoenix.server.api.input_types.Granularity import Granularity, to_timestamps
+from phoenix.server.api.input_types.TimeRange import TimeRange
+from phoenix.server.api.interceptor import NoneIfNan
+from phoenix.server.api.types import METRICS
+from phoenix.server.api.types.DataQualityMetric import DataQualityMetric
+from phoenix.server.api.types.DriftMetric import DriftMetric
+

 @strawberry.type
 @total_ordering
@@ -14,14 +26,116 @@ class TimeSeriesDataPoint:
     timestamp: datetime

     """The value of the data point"""
-    value: Optional[float]
+    value: Optional[float] = strawberry.field(default=NoneIfNan())

     def __lt__(self, other: "TimeSeriesDataPoint") -> bool:
         return self.timestamp < other.timestamp


+def to_gql_datapoints(
+    df: pd.DataFrame, metric: Metric, timestamps: Iterable[datetime]
+) -> List[TimeSeriesDataPoint]:
+    data = []
+    for timestamp in timestamps:
+        try:
+            row = df.iloc[cast(int, df.index.get_loc(timestamp)), :].to_dict()
+        except KeyError:
+            row = {}
+        data.append(
+            TimeSeriesDataPoint(
+                timestamp=timestamp,
+                value=metric.get_value(row),
+            )
+        )
+    return sorted(data)
+
+
 @strawberry.interface
 class TimeSeries:
     """A collection of data points over time"""

     data: List[TimeSeriesDataPoint]
+
+    def __init__(
+        self,
+        column_name: str,
+        model: Model,
+        metric: Union[DriftMetric, DataQualityMetric],
+        time_range: Optional[TimeRange] = None,
+        granularity: Optional[Granularity] = None,
+    ):
+        if not (metric_cls := METRICS.get(metric.value, None)):
+            raise NotImplementedError(f"Metric {metric} is not implemented.")
+        dataset = model.primary_dataset
+        metric_instance = metric_cls(column_name=column_name)
+        if (
+            issubclass(metric_cls, DriftOperator)
+            and (ref_dataset := model.reference_dataset) is not None
+        ):
+            metric_instance.reference_data = ref_dataset.dataframe
+        if time_range is None:
+            time_range = TimeRange(
+                start=dataset.start_time,
+                end=dataset.end_time,
+            )
+        if granularity is None:
+            total_minutes = int((time_range.end - time_range.start).total_seconds()) // 60
+            granularity = Granularity(
+                evaluation_window_minutes=total_minutes,
+                sampling_interval_minutes=total_minutes,
+            )
+        self.data = dataset.dataframe.pipe(
+            timeseries(
+                start_time=time_range.start,
+                end_time=time_range.end,
+                evaluation_window=timedelta(minutes=granularity.evaluation_window_minutes),
+                sampling_interval=timedelta(minutes=granularity.sampling_interval_minutes),
+            ),
+            metrics=(metric_instance,),
+        ).pipe(
+            to_gql_datapoints,
+            metric=metric_instance,
+            timestamps=to_timestamps(time_range, granularity),
+        )
+
+
+@strawberry.type
+class DataQualityTimeSeries(TimeSeries):
+    """A time series of data quality metrics"""
+
+    def __init__(
+        self,
+        column_name: str,
+        model: Model,
+        metric: DataQualityMetric,
+        time_range: Optional[TimeRange] = None,
+        granularity: Optional[Granularity] = None,
+    ):
+        super().__init__(
+            column_name,
+            model,
+            metric,
+            time_range,
+            granularity,
+        )
+
+
+@strawberry.type
+class DriftTimeSeries(TimeSeries):
+    """A time series of drift metrics"""
+
+    def __init__(
+        self,
+        column_name: str,
+        model: Model,
+        metric: DriftMetric,
+        time_range: Optional[TimeRange] = None,
+        granularity: Optional[Granularity] = None,
+    ):
+        super().__init__(
+            column_name,
+            model,
+            metric,
+            time_range,
+            granularity,
+        )
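Note how the defaults interact: with granularity omitted, the evaluation window and sampling interval both span the full time range, so to_timestamps produces a single timestamp at time_range.end and the series reduces to one aggregate point, which is what lets the scalar metric resolvers pop a single value. A standalone mimic of that arithmetic (hypothetical helper, assuming the backward-stepping behavior sketched earlier):

# Hypothetical mimic: one sampling interval covering [start, end] yields a
# single timestamp at `end`.
from datetime import datetime
from itertools import count, takewhile

def single_bucket_timestamps(start: datetime, end: datetime):
    interval = end - start
    if not interval:
        return iter(())  # degenerate range: no timestamps
    return takewhile(lambda t: start < t, (end - i * interval for i in count()))

print(list(single_bucket_timestamps(datetime(2023, 1, 1), datetime(2023, 1, 2))))
# [datetime.datetime(2023, 1, 2, 0, 0)] -- one point for the whole range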
phoenix/server/api/types/UMAPPoints.py

@@ -1,4 +1,4 @@
-from typing import List, Union
+from typing import Dict, List, Optional, Set, Union

 import numpy as np
 import numpy.typing as npt
@@ -6,11 +6,70 @@ import strawberry
 from strawberry.scalars import ID
 from typing_extensions import TypeAlias

+from phoenix.core.embedding_dimension import calculate_drift_ratio
+from phoenix.datasets.event import EventId
+from phoenix.server.api.interceptor import NoneIfNan
+
 from .EmbeddingMetadata import EmbeddingMetadata
 from .EventMetadata import EventMetadata

-EventId: TypeAlias = ID
 ClusterId: TypeAlias = ID
+PointId: TypeAlias = ID
+
+
+@strawberry.type
+class Cluster:
+    """A grouping of points in a UMAP plot"""
+
+    """The ID of the cluster"""
+    id: ClusterId
+
+    """A list of points that belong to the cluster"""
+    point_ids: List[PointId]
+
+    """A list of points that belong to the cluster"""
+    drift_ratio: Optional[float] = strawberry.field(
+        description="ratio of primary points over reference points",
+        default=NoneIfNan(),
+    )
+
+
+def to_gql_clusters(
+    cluster_membership: Dict[EventId, int],
+    has_reference_data: bool,
+) -> List[Cluster]:
+    """
+    Converts a dictionary of event IDs to cluster IDs to a list of clusters for the graphQL response
+
+    Parameters
+    ----------
+    cluster_membership: Dict[EventId, int]
+        A dictionary of event IDs to cluster IDs
+    has_reference_data: bool
+        Whether or not the model has reference data
+        Used to determine if drift ratio should be calculated
+    """
+
+    clusters: Dict[int, Set[EventId]] = {}
+    for event_id, cluster_id in cluster_membership.items():
+        if cluster_id in clusters:
+            clusters[cluster_id].add(event_id)
+        else:
+            clusters[cluster_id] = {event_id}
+
+    gql_clusters: List[Cluster] = []
+    for cluster_id, cluster_events in clusters.items():
+        gql_clusters.append(
+            Cluster(
+                id=ID(str(cluster_id)),
+                point_ids=[ID(str(event)) for event in cluster_events],
+                drift_ratio=calculate_drift_ratio(cluster_events)
+                if has_reference_data
+                else float("nan"),
+            )
+        )
+
+    return gql_clusters


 @strawberry.type
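to_gql_clusters first inverts the event-to-cluster mapping into sets of event IDs per cluster, then attaches a drift_ratio that is NaN when no reference data exists; the NoneIfNan default on Cluster.drift_ratio turns that NaN into None in the response. A toy run of the inversion step (plain strings stand in for EventId):

# Toy inversion of a cluster-membership mapping.
cluster_membership = {"pri-0": 0, "pri-1": 0, "ref-7": 1}
clusters: dict = {}
for event_id, cluster_id in cluster_membership.items():
    clusters.setdefault(cluster_id, set()).add(event_id)
print(clusters)  # {0: {'pri-0', 'pri-1'}, 1: {'ref-7'}}
# With has_reference_data=False each Cluster gets drift_ratio=float("nan"),
# which the NoneIfNan descriptor converts to None before serialization.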
@@ -39,7 +98,7 @@ class UMAPPoint:
     """point and metadata for a UMAP plot"""

     """A unique ID for the the point"""
-    id:
+    id: PointId

     """The coordinates of the point. Can be two or three dimensional"""
     coordinates: Union[Point2D, Point3D]
@@ -51,17 +110,6 @@ class UMAPPoint:
     event_metadata: EventMetadata


-@strawberry.type
-class Cluster:
-    """A grouping of points in a UMAP plot"""
-
-    """The ID of the cluster"""
-    id: ClusterId
-
-    """A list of points that belong to the cluster"""
-    point_ids: List[EventId]
-
-
 @strawberry.type
 class UMAPPoints:
     data: List[UMAPPoint]
phoenix/server/main.py

@@ -8,12 +8,12 @@ from typing import Optional
 import uvicorn

 import phoenix.config as config
-from phoenix.
-from phoenix.server.fixtures import (
+from phoenix.datasets.fixtures import (
     FIXTURES,
     download_fixture_if_missing,
     get_dataset_names_from_fixture_name,
 )
+from phoenix.server.app import create_app

 logger = logging.getLogger(__name__)

@@ -66,7 +66,7 @@ if __name__ == "__main__":
     primary_dataset_name, reference_dataset_name = get_dataset_names_from_fixture_name(
         fixture_name
     )
-    print(f'🌎
+    print(f'🌎 Initializing fixture: "{fixture_name}"')
     download_fixture_if_missing(fixture_name)

     print(f"1️⃣ primary dataset: {primary_dataset_name}")