arize-phoenix 0.0.18__py3-none-any.whl → 0.0.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of arize-phoenix might be problematic.

@@ -2,17 +2,21 @@ from typing import List, Optional
 
 import strawberry
 from strawberry.types import Info
+from typing_extensions import Annotated
 
-from phoenix.core.model_schema import REFERENCE, ScalarDimension
+from phoenix.core.model_schema import PRIMARY, REFERENCE, ScalarDimension
 
 from ..context import Context
 from ..input_types.Granularity import Granularity
 from ..input_types.TimeRange import TimeRange
 from .DataQualityMetric import DataQualityMetric
+from .DatasetRole import DatasetRole
 from .DimensionDataType import DimensionDataType
+from .DimensionShape import DimensionShape
 from .DimensionType import DimensionType
 from .node import Node
 from .ScalarDriftMetricEnum import ScalarDriftMetric
+from .Segments import DatasetValues, Segments
 from .TimeSeries import (
     DataQualityTimeSeries,
     DriftTimeSeries,
@@ -31,6 +35,9 @@ class Dimension(Node):
     dataType: DimensionDataType = strawberry.field(
         description="The data type of the column. Categorical or numeric."
     )
+    shape: DimensionShape = strawberry.field(
+        description="Whether the dimension data is continuous or discrete."
+    )
     dimension: strawberry.Private[ScalarDimension]
 
     @strawberry.field
@@ -50,8 +57,17 @@ class Dimension(Node):
         model = info.context.model
         if model[REFERENCE].empty:
             return None
-        time_range, granularity = ensure_timeseries_parameters(model, time_range)
-        data = get_drift_timeseries_data(self.dimension, metric, time_range, granularity)
+        dataset = model[PRIMARY]
+        time_range, granularity = ensure_timeseries_parameters(
+            dataset,
+            time_range,
+        )
+        data = get_drift_timeseries_data(
+            self.dimension,
+            metric,
+            time_range,
+            granularity,
+        )
         return data[0].value if len(data) else None
 
     @strawberry.field
@@ -60,9 +76,27 @@ class Dimension(Node):
         info: Info[Context, None],
         metric: DataQualityMetric,
         time_range: Optional[TimeRange] = None,
+        dataset_role: Annotated[
+            Optional[DatasetRole],
+            strawberry.argument(
+                description="The dataset (primary or reference) to query",
+            ),
+        ] = DatasetRole.primary,
     ) -> Optional[float]:
-        time_range, granularity = ensure_timeseries_parameters(info.context.model, time_range)
-        data = get_data_quality_timeseries_data(self.dimension, metric, time_range, granularity)
+        if dataset_role is None:
+            dataset_role = DatasetRole.primary
+        dataset = info.context.model[dataset_role.value]
+        time_range, granularity = ensure_timeseries_parameters(
+            dataset,
+            time_range,
+        )
+        data = get_data_quality_timeseries_data(
+            self.dimension,
+            metric,
+            time_range,
+            granularity,
+            dataset_role,
+        )
         return data[0].value if len(data) else None
 
     @strawberry.field(
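
The Annotated-argument pattern above deserves a note: strawberry exposes dataset_role to GraphQL clients as an optional datasetRole argument (snake_case names are converted to camelCase by default), and the in-resolver None guard is needed because a client can pass an explicit null, which bypasses the Python-level default. A self-contained sketch of the same pattern, using a hypothetical Query type and only public strawberry APIs:

from enum import Enum
from typing import Optional

import strawberry
from typing_extensions import Annotated


@strawberry.enum
class DatasetRole(Enum):
    primary = "primary"
    reference = "reference"


@strawberry.type
class Query:
    @strawberry.field
    def dataset_name(
        self,
        dataset_role: Annotated[
            Optional[DatasetRole],
            strawberry.argument(description="The dataset (primary or reference) to query"),
        ] = DatasetRole.primary,
    ) -> str:
        # An explicit null from the client arrives as None despite the default.
        if dataset_role is None:
            dataset_role = DatasetRole.primary
        return dataset_role.value


schema = strawberry.Schema(query=Query)
result = schema.execute_sync("{ datasetName(datasetRole: null) }")
assert result.data == {"datasetName": "primary"}
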
@@ -89,17 +123,34 @@ class Dimension(Node):
         metric: DataQualityMetric,
         time_range: TimeRange,
         granularity: Granularity,
+        dataset_role: Annotated[
+            Optional[DatasetRole],
+            strawberry.argument(
+                description="The dataset (primary or reference) to query",
+            ),
+        ] = DatasetRole.primary,
     ) -> DataQualityTimeSeries:
+        if dataset_role is None:
+            dataset_role = DatasetRole.primary
+        dataset = info.context.model[dataset_role.value]
         time_range, granularity = ensure_timeseries_parameters(
-            info.context.model, time_range, granularity
+            dataset,
+            time_range,
+            granularity,
         )
         return DataQualityTimeSeries(
-            data=get_data_quality_timeseries_data(self.dimension, metric, time_range, granularity)
+            data=get_data_quality_timeseries_data(
+                self.dimension,
+                metric,
+                time_range,
+                granularity,
+                dataset_role,
+            )
         )
 
     @strawberry.field(
         description=(
-            "Returns the time series of the specified metric for data within a time range. Data"
+            "The time series of the specified metric for data within a time range. Data"
             " points are generated starting at the end time and are separated by the sampling"
             " interval. Each data point is labeled by the end instant and contains data from their"
             " respective evaluation windows."
@@ -115,11 +166,32 @@ class Dimension(Node):
         model = info.context.model
         if model[REFERENCE].empty:
             return DriftTimeSeries(data=[])
-        time_range, granularity = ensure_timeseries_parameters(model, time_range, granularity)
+        dataset = model[PRIMARY]
+        time_range, granularity = ensure_timeseries_parameters(
+            dataset,
+            time_range,
+            granularity,
+        )
         return DriftTimeSeries(
-            data=get_drift_timeseries_data(self.dimension, metric, time_range, granularity)
+            data=get_drift_timeseries_data(
+                self.dimension,
+                metric,
+                time_range,
+                granularity,
+            )
         )
 
+    @strawberry.field(
+        description="Returns the segments across both datasets and the counts per segment",
+    ) # type: ignore
+    def segments_comparison(
+        self,
+        primary_time_range: Optional[TimeRange] = strawberry.UNSET,
+    ) -> Segments:
+        # TODO: Implement binning across primary and reference
+
+        return Segments(segments=[], total_counts=DatasetValues(primary_value=0, reference_value=0))
+
 
 def to_gql_dimension(id_attr: int, dimension: ScalarDimension) -> Dimension:
     """
@@ -128,7 +200,8 @@ def to_gql_dimension(id_attr: int, dimension: ScalarDimension) -> Dimension:
     return Dimension(
         id_attr=id_attr,
         name=dimension.name,
-        type=DimensionType.from_(dimension),
-        dataType=DimensionDataType.from_(dimension),
+        type=DimensionType.from_dimension(dimension),
+        dataType=DimensionDataType.from_dimension(dimension),
         dimension=dimension,
+        shape=DimensionShape.from_dimension(dimension),
     )
@@ -11,8 +11,8 @@ class DimensionDataType(Enum):
     numeric = "numeric"
 
     @classmethod
-    def from_(cls, dim: Dimension) -> "DimensionDataType":
-        data_type = dim.data_type
+    def from_dimension(cls, dimension: Dimension) -> "DimensionDataType":
+        data_type = dimension.data_type
         if data_type in (CONTINUOUS,):
             return cls.numeric
         return cls.categorical
@@ -0,0 +1,21 @@
+from enum import Enum
+
+import strawberry
+
+from phoenix.core.model_schema import CONTINUOUS, Dimension
+
+
+@strawberry.enum
+class DimensionShape(Enum):
+    continuous = "continuous"
+    discrete = "discrete"
+
+    @classmethod
+    def from_dimension(cls, dim: Dimension) -> "DimensionShape":
+        data_type = dim.data_type
+        if data_type in (CONTINUOUS,):
+            return cls.continuous
+
+        # For now we assume all non-continuous data is discrete
+        # E.g. floats are the only dimension data type that is continuous
+        return cls.discrete
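
Since from_dimension only reads dim.data_type, the classification is easy to exercise in isolation. A small duck-typed sketch (the import path for DimensionShape is inferred from the relative imports above, and the stand-in objects are hypothetical; the real Dimension carries much more):

from types import SimpleNamespace

from phoenix.core.model_schema import CONTINUOUS
from phoenix.server.api.types.DimensionShape import DimensionShape  # path assumed

# A continuous data type maps to "continuous"; anything else falls back to "discrete".
assert DimensionShape.from_dimension(SimpleNamespace(data_type=CONTINUOUS)) is DimensionShape.continuous
assert DimensionShape.from_dimension(SimpleNamespace(data_type=object())) is DimensionShape.discrete
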
@@ -21,7 +21,7 @@ class DimensionType(Enum):
     actual = "actual"
 
     @classmethod
-    def from_(cls, dim: Dimension) -> "DimensionType":
+    def from_dimension(cls, dim: Dimension) -> "DimensionType":
         role = dim.role
         if role in (FEATURE,):
             return cls.feature
@@ -20,7 +20,6 @@ from phoenix.core.model_schema import (
     PRIMARY,
     REFERENCE,
     Dataset,
-    DatasetRole,
     EventId,
 )
 from phoenix.metrics.timeseries import row_interval_from_sorted_time_index
@@ -29,6 +28,7 @@ from phoenix.pointcloud.pointcloud import PointCloud
 from phoenix.pointcloud.projectors import Umap
 from phoenix.server.api.context import Context
 from phoenix.server.api.input_types.TimeRange import TimeRange
+from phoenix.server.api.types.DatasetRole import DatasetRole
 from phoenix.server.api.types.VectorDriftMetricEnum import VectorDriftMetric
 
 from ..input_types.Granularity import Granularity
@@ -83,8 +83,17 @@ class EmbeddingDimension(Node):
         model = info.context.model
         if model[REFERENCE].empty:
             return None
-        time_range, granularity = ensure_timeseries_parameters(model, time_range)
-        data = get_drift_timeseries_data(self.dimension, metric, time_range, granularity)
+        dataset = model[PRIMARY]
+        time_range, granularity = ensure_timeseries_parameters(
+            dataset,
+            time_range,
+        )
+        data = get_drift_timeseries_data(
+            self.dimension,
+            metric,
+            time_range,
+            granularity,
+        )
         return data[0].value if len(data) else None
 
     @strawberry.field(
  @strawberry.field(
@@ -101,12 +110,29 @@ class EmbeddingDimension(Node):
101
110
  metric: DataQualityMetric,
102
111
  time_range: TimeRange,
103
112
  granularity: Granularity,
113
+ dataset_role: Annotated[
114
+ Optional[DatasetRole],
115
+ strawberry.argument(
116
+ description="The dataset (primary or reference) to query",
117
+ ),
118
+ ] = DatasetRole.primary,
104
119
  ) -> DataQualityTimeSeries:
120
+ if dataset_role is None:
121
+ dataset_role = DatasetRole.primary
122
+ dataset = info.context.model[dataset_role.value]
105
123
  time_range, granularity = ensure_timeseries_parameters(
106
- info.context.model, time_range, granularity
124
+ dataset,
125
+ time_range,
126
+ granularity,
107
127
  )
108
128
  return DataQualityTimeSeries(
109
- data=get_data_quality_timeseries_data(self.dimension, metric, time_range, granularity)
129
+ data=get_data_quality_timeseries_data(
130
+ self.dimension,
131
+ metric,
132
+ time_range,
133
+ granularity,
134
+ dataset_role,
135
+ )
110
136
  )
111
137
 
112
138
  @strawberry.field(
@@ -129,9 +155,19 @@ class EmbeddingDimension(Node):
         model = info.context.model
         if model[REFERENCE].empty:
             return DriftTimeSeries(data=[])
-        time_range, granularity = ensure_timeseries_parameters(model, time_range, granularity)
+        dataset = model[PRIMARY]
+        time_range, granularity = ensure_timeseries_parameters(
+            dataset,
+            time_range,
+            granularity,
+        )
         return DriftTimeSeries(
-            data=get_drift_timeseries_data(self.dimension, metric, time_range, granularity)
+            data=get_drift_timeseries_data(
+                self.dimension,
+                metric,
+                time_range,
+                granularity,
+            )
         )
 
     @strawberry.field
@@ -226,7 +262,7 @@ class EmbeddingDimension(Node):
             ),
         ).generate(data, n_components=n_components)
 
-        points: Dict[DatasetRole, List[UMAPPoint]] = defaultdict(list)
+        points: Dict[ms.DatasetRole, List[UMAPPoint]] = defaultdict(list)
         for event_id, vector in vectors.items():
             row_id = event_id.row_id
             dataset_id = event_id.dataset_id
@@ -0,0 +1,10 @@
+import strawberry
+
+
+@strawberry.type
+class NumericRange:
+    """A numeric range to denote a bin or domain"""
+
+    start: float
+    end: float
+    # TODO consider denoting right open or closed
@@ -0,0 +1,44 @@
+from typing import List, Optional
+
+import strawberry
+
+from .NumericRange import NumericRange
+
+
+@strawberry.type
+class NominalBin:
+    """A bin that contains a discrete value"""
+
+    name: str
+
+
+@strawberry.type
+class IntervalBin:
+    """A bin that contains a continuous range of values"""
+
+    # TODO figure out the empty case
+    range: NumericRange
+
+
+@strawberry.type
+class DatasetValues:
+    """Numeric values per dataset role"""
+
+    primary_value: Optional[float]
+    reference_value: Optional[float]
+
+
+@strawberry.type
+class Segment:
+    """A segment of the parent's data, split out using a heuristic"""
+
+    bin: strawberry.union("Bin", types=(NominalBin, IntervalBin)) # type: ignore
+    counts: DatasetValues
+    # TODO add support for a "z" metric list
+    # values: List[Optional[float]]
+
+
+@strawberry.type
+class Segments:
+    segments: List[Segment]
+    total_counts: DatasetValues
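
These types back the segments_comparison stub above. A hedged construction example (module paths are inferred from the relative imports, since the file paths are not shown in this diff; a real resolver would bin actual data rather than hard-code counts):

from phoenix.server.api.types.NumericRange import NumericRange  # path assumed
from phoenix.server.api.types.Segments import (  # path assumed
    DatasetValues,
    IntervalBin,
    NominalBin,
    Segment,
    Segments,
)

# One discrete bin and one numeric interval bin, with per-dataset counts.
segments = Segments(
    segments=[
        Segment(
            bin=NominalBin(name="cat_a"),
            counts=DatasetValues(primary_value=10.0, reference_value=12.0),
        ),
        Segment(
            bin=IntervalBin(range=NumericRange(start=0.0, end=1.0)),
            counts=DatasetValues(primary_value=3.0, reference_value=None),
        ),
    ],
    total_counts=DatasetValues(primary_value=13.0, reference_value=12.0),
)
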
@@ -5,7 +5,7 @@ from typing import Iterable, List, Optional, Tuple, Union, cast
 import pandas as pd
 import strawberry
 
-from phoenix.core.model_schema import CONTINUOUS, PRIMARY, REFERENCE, Dimension, Model
+from phoenix.core.model_schema import CONTINUOUS, REFERENCE, Dataset, Dimension
 from phoenix.metrics import Metric, binning
 from phoenix.metrics.mixins import DriftOperator
 from phoenix.metrics.timeseries import timeseries
@@ -14,6 +14,7 @@ from phoenix.server.api.input_types.TimeRange import TimeRange
 from phoenix.server.api.interceptor import NoneIfNan
 from phoenix.server.api.types import METRICS
 from phoenix.server.api.types.DataQualityMetric import DataQualityMetric
+from phoenix.server.api.types.DatasetRole import DatasetRole
 from phoenix.server.api.types.ScalarDriftMetricEnum import ScalarDriftMetric
 from phoenix.server.api.types.VectorDriftMetricEnum import VectorDriftMetric
 
@@ -63,10 +64,11 @@ def _get_timeseries_data(
     metric: Union[ScalarDriftMetric, VectorDriftMetric, DataQualityMetric],
     time_range: TimeRange,
     granularity: Granularity,
+    dataset_role: DatasetRole,
 ) -> List[TimeSeriesDataPoint]:
     if not (metric_cls := METRICS.get(metric.value, None)):
         raise NotImplementedError(f"Metric {metric} is not implemented.")
-    data = dimension[PRIMARY]
+    data = dimension[dataset_role.value]
     metric_instance = metric_cls(operand_column_name=dimension.name)
     if issubclass(metric_cls, DriftOperator):
         ref_data = dimension[REFERENCE]
@@ -105,8 +107,15 @@ def get_data_quality_timeseries_data(
     metric: DataQualityMetric,
     time_range: TimeRange,
     granularity: Granularity,
+    dataset_role: DatasetRole,
 ) -> List[TimeSeriesDataPoint]:
-    return _get_timeseries_data(dimension, metric, time_range, granularity)
+    return _get_timeseries_data(
+        dimension,
+        metric,
+        time_range,
+        granularity,
+        dataset_role,
+    )
 
 
 @strawberry.type
@@ -120,16 +129,22 @@ def get_drift_timeseries_data(
     time_range: TimeRange,
     granularity: Granularity,
 ) -> List[TimeSeriesDataPoint]:
-    return _get_timeseries_data(dimension, metric, time_range, granularity)
+    return _get_timeseries_data(
+        dimension,
+        metric,
+        time_range,
+        granularity,
+        DatasetRole.primary,
+    )
 
 
 def ensure_timeseries_parameters(
-    model: Model,
+    dataset: Dataset,
     time_range: Optional[TimeRange] = None,
     granularity: Optional[Granularity] = None,
 ) -> Tuple[TimeRange, Granularity]:
     if time_range is None:
-        start, end = model[PRIMARY].time_range
+        start, end = dataset.time_range
         time_range = TimeRange(start=start, end=end)
     if granularity is None:
         total_minutes = int((time_range.end - time_range.start).total_seconds()) // 60
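
The net effect of these signature changes: ensure_timeseries_parameters now derives its defaults from a single Dataset rather than always from the model's primary dataset, and data-quality queries thread the requested role all the way down, while drift queries still pin DatasetRole.primary. A sketch of the resulting call pattern, assuming it lives alongside the helpers in this module so they are in scope; the mapping of DatasetRole members to model keys via .value is inferred from the diff:

def latest_metric_value(model, dimension, metric, dataset_role):
    # Pick the dataset for the requested role, derive the default time range
    # and granularity from that dataset, then take the newest data point.
    dataset = model[dataset_role.value]
    time_range, granularity = ensure_timeseries_parameters(dataset)
    data = get_data_quality_timeseries_data(
        dimension,
        metric,
        time_range,
        granularity,
        dataset_role,
    )
    return data[0].value if len(data) else None
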
phoenix/server/main.py CHANGED
@@ -11,7 +11,7 @@ import uvicorn
 import phoenix.config as config
 from phoenix.core.model_schema_adapter import create_model_from_datasets
 from phoenix.datasets import Dataset
-from phoenix.datasets.fixtures import FIXTURES, download_fixture_if_missing
+from phoenix.datasets.fixtures import FIXTURES, get_datasets
 from phoenix.server.app import create_app
 
 logger = logging.getLogger(__name__)
@@ -48,6 +48,7 @@ if __name__ == "__main__":
     parser = ArgumentParser()
     parser.add_argument("--export_path")
     parser.add_argument("--port", type=int, default=config.PORT)
+    parser.add_argument("--no-internet", action="store_true")
     parser.add_argument("--debug", action="store_false") # TODO: Disable before public launch
     subparsers = parser.add_subparsers(dest="command", required=True)
     datasets_parser = subparsers.add_parser("datasets")
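
A small detail behind the new flag: argparse turns dashes in long option names into underscores on the parsed namespace, which is why the hunk below reads args.no_internet. A standalone check:

from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument("--no-internet", action="store_true")

# "--no-internet" is stored as the attribute "no_internet".
assert parser.parse_args(["--no-internet"]).no_internet is True
assert parser.parse_args([]).no_internet is False
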
@@ -70,7 +71,10 @@ if __name__ == "__main__":
     else:
         fixture_name = args.fixture
        primary_only = args.primary_only
-        primary_dataset, reference_dataset = download_fixture_if_missing(fixture_name)
+        primary_dataset, reference_dataset = get_datasets(
+            fixture_name,
+            args.no_internet,
+        )
         if primary_only:
             reference_dataset_name = None
             reference_dataset = None
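
Only the get_datasets signature is visible in this diff; presumably the second argument suppresses downloading, so fixtures must already be cached locally when it is true. A hedged usage sketch (the fixture name is a placeholder, not necessarily a member of FIXTURES):

from phoenix.datasets.fixtures import get_datasets

primary_dataset, reference_dataset = get_datasets(
    "my_fixture",  # placeholder fixture name
    True,  # no_internet: assumed to skip downloading and rely on cached files
)
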