arize 8.0.0a9__py3-none-any.whl → 8.0.0a10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arize/models/client.py +8 -11
- arize/models/surrogate_explainer/__init__.py +0 -0
- arize/models/surrogate_explainer/mimic.py +164 -0
- arize/version.py +1 -1
- {arize-8.0.0a9.dist-info → arize-8.0.0a10.dist-info}/METADATA +3 -1
- {arize-8.0.0a9.dist-info → arize-8.0.0a10.dist-info}/RECORD +8 -6
- {arize-8.0.0a9.dist-info → arize-8.0.0a10.dist-info}/WHEEL +0 -0
- {arize-8.0.0a9.dist-info → arize-8.0.0a10.dist-info}/licenses/LICENSE.md +0 -0
arize/models/client.py
CHANGED
|
@@ -80,6 +80,11 @@ _BATCH_DEPS = (
|
|
|
80
80
|
"tqdm",
|
|
81
81
|
)
|
|
82
82
|
_BATCH_EXTRA = "ml-batch"
|
|
83
|
+
_MIMIC_DEPS = (
|
|
84
|
+
"interpret_community.mimic",
|
|
85
|
+
"sklearn.preprocessing",
|
|
86
|
+
)
|
|
87
|
+
_MIMIC_EXTRA = "mimic-explainer"
|
|
83
88
|
|
|
84
89
|
|
|
85
90
|
class MLModelsClient:
|
|
@@ -116,7 +121,6 @@ class MLModelsClient:
|
|
|
116
121
|
timeout: float | None = None,
|
|
117
122
|
) -> cf.Future:
|
|
118
123
|
require(_STREAM_EXTRA, _STREAM_DEPS)
|
|
119
|
-
|
|
120
124
|
from arize._generated.protocol.rec import public_pb2 as pb2
|
|
121
125
|
from arize.utils.proto import (
|
|
122
126
|
get_pb_dictionary,
|
|
@@ -545,17 +549,10 @@ class MLModelsClient:
|
|
|
545
549
|
dataframe = dataframe.astype(cat_str_map)
|
|
546
550
|
|
|
547
551
|
if surrogate_explainability:
|
|
548
|
-
|
|
552
|
+
require(_MIMIC_EXTRA, _MIMIC_DEPS)
|
|
553
|
+
from arize.models.surrogate_explainer.mimic import Mimic
|
|
549
554
|
|
|
550
|
-
|
|
551
|
-
# WARNING: MIMIC EXPLAINER IS NOT DONE
|
|
552
|
-
from arize.pandas.surrogate_explainer.mimic import Mimic
|
|
553
|
-
except ImportError:
|
|
554
|
-
raise ImportError(
|
|
555
|
-
"To enable surrogate explainability, "
|
|
556
|
-
"the arize module must be installed with the MimicExplainer option: pip "
|
|
557
|
-
"install 'arize[MimicExplainer]'."
|
|
558
|
-
) from None
|
|
555
|
+
logger.debug("Running surrogate_explainability.")
|
|
559
556
|
if schema.shap_values_column_names:
|
|
560
557
|
logger.info(
|
|
561
558
|
"surrogate_explainability=True has no effect "
|
|
File without changes
|
|
@@ -0,0 +1,164 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import random
|
|
4
|
+
import string
|
|
5
|
+
from dataclasses import replace
|
|
6
|
+
from typing import TYPE_CHECKING, Callable, Tuple
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
import pandas as pd
|
|
10
|
+
from interpret_community.mimic.mimic_explainer import (
|
|
11
|
+
LGBMExplainableModel,
|
|
12
|
+
MimicExplainer,
|
|
13
|
+
)
|
|
14
|
+
from sklearn.preprocessing import LabelEncoder
|
|
15
|
+
|
|
16
|
+
from arize.types import (
|
|
17
|
+
CATEGORICAL_MODEL_TYPES,
|
|
18
|
+
NUMERIC_MODEL_TYPES,
|
|
19
|
+
ModelTypes,
|
|
20
|
+
)
|
|
21
|
+
|
|
22
|
+
if TYPE_CHECKING:
|
|
23
|
+
from arize.types import Schema
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class Mimic:
|
|
27
|
+
_testing = False
|
|
28
|
+
|
|
29
|
+
def __init__(self, X: pd.DataFrame, model_func: Callable):
|
|
30
|
+
self.explainer = MimicExplainer(
|
|
31
|
+
model_func,
|
|
32
|
+
X,
|
|
33
|
+
LGBMExplainableModel,
|
|
34
|
+
augment_data=False,
|
|
35
|
+
is_function=True,
|
|
36
|
+
)
|
|
37
|
+
|
|
38
|
+
def explain(self, X: pd.DataFrame) -> pd.DataFrame:
|
|
39
|
+
return pd.DataFrame(
|
|
40
|
+
self.explainer.explain_local(X).local_importance_values,
|
|
41
|
+
columns=X.columns,
|
|
42
|
+
index=X.index,
|
|
43
|
+
)
|
|
44
|
+
|
|
45
|
+
@staticmethod
|
|
46
|
+
def augment(
|
|
47
|
+
df: pd.DataFrame, schema: Schema, model_type: ModelTypes
|
|
48
|
+
) -> Tuple[pd.DataFrame, Schema]:
|
|
49
|
+
features = schema.feature_column_names
|
|
50
|
+
X = df[features]
|
|
51
|
+
|
|
52
|
+
if X.shape[1] == 0:
|
|
53
|
+
return df, schema
|
|
54
|
+
|
|
55
|
+
if model_type in CATEGORICAL_MODEL_TYPES:
|
|
56
|
+
if not schema.prediction_score_column_name:
|
|
57
|
+
raise ValueError(
|
|
58
|
+
"To calculate surrogate explainability, "
|
|
59
|
+
f"prediction_score_column_name must be specified in schema for {model_type}."
|
|
60
|
+
)
|
|
61
|
+
|
|
62
|
+
y_col_name = schema.prediction_score_column_name
|
|
63
|
+
y = df[y_col_name].to_numpy()
|
|
64
|
+
|
|
65
|
+
_min, _max = np.min(y), np.max(y)
|
|
66
|
+
if not 0 <= _min <= 1 or not 0 <= _max <= 1:
|
|
67
|
+
raise ValueError(
|
|
68
|
+
f"To calculate surrogate explainability for {model_type}, "
|
|
69
|
+
f"prediction scores must be between 0 and 1, but current "
|
|
70
|
+
f"prediction scores range from {_min} to {_max}."
|
|
71
|
+
)
|
|
72
|
+
|
|
73
|
+
# model func requires 1 positional argument
|
|
74
|
+
def model_func(_): # type: ignore
|
|
75
|
+
return np.column_stack((1 - y, y))
|
|
76
|
+
|
|
77
|
+
elif model_type in NUMERIC_MODEL_TYPES:
|
|
78
|
+
y_col_name = schema.prediction_label_column_name
|
|
79
|
+
if schema.prediction_score_column_name is not None:
|
|
80
|
+
y_col_name = schema.prediction_score_column_name
|
|
81
|
+
y = df[y_col_name].to_numpy()
|
|
82
|
+
|
|
83
|
+
_finite_count = np.isfinite(y).sum()
|
|
84
|
+
if len(y) - _finite_count:
|
|
85
|
+
raise ValueError(
|
|
86
|
+
f"To calculate surrogate explainability for {model_type}, "
|
|
87
|
+
f"predictions must not contain NaN or infinite values, but "
|
|
88
|
+
f"{len(y) - _finite_count} NaN or infinite value(s) are found in {y_col_name}."
|
|
89
|
+
)
|
|
90
|
+
|
|
91
|
+
# model func requires 1 positional argument
|
|
92
|
+
def model_func(_): # type: ignore
|
|
93
|
+
return y
|
|
94
|
+
|
|
95
|
+
else:
|
|
96
|
+
raise ValueError(
|
|
97
|
+
"Surrogate explainability is not supported for the specified "
|
|
98
|
+
f"model type {model_type}."
|
|
99
|
+
)
|
|
100
|
+
|
|
101
|
+
# Column name mapping between features and feature importance values.
|
|
102
|
+
# This is used to augment the schema.
|
|
103
|
+
col_map = {
|
|
104
|
+
ft: f"{''.join(random.choices(string.ascii_letters, k=8))}"
|
|
105
|
+
for ft in features
|
|
106
|
+
}
|
|
107
|
+
aug_schema = replace(schema, shap_values_column_names=col_map)
|
|
108
|
+
|
|
109
|
+
# Limit the total number of "cells" to 20M, unless it results in too few or
|
|
110
|
+
# too many rows. This is done to keep the runtime low. Records not sampled
|
|
111
|
+
# have feature importance values set to 0.
|
|
112
|
+
samp_size = min(
|
|
113
|
+
len(X), min(100_000, max(1_000, 20_000_000 // X.shape[1]))
|
|
114
|
+
)
|
|
115
|
+
|
|
116
|
+
if samp_size < len(X):
|
|
117
|
+
_mask = np.zeros(len(X), dtype=int)
|
|
118
|
+
_mask[:samp_size] = 1
|
|
119
|
+
np.random.shuffle(_mask)
|
|
120
|
+
_mask = _mask.astype(bool)
|
|
121
|
+
X = X[_mask]
|
|
122
|
+
y = y[_mask]
|
|
123
|
+
|
|
124
|
+
# Replace all pd.NA values with np.nan values
|
|
125
|
+
for col in X.columns:
|
|
126
|
+
if X[col].isna().any():
|
|
127
|
+
X[col] = X[col].astype(object).where(~X[col].isna(), np.nan)
|
|
128
|
+
|
|
129
|
+
# Apply integer encoding to non-numeric columns.
|
|
130
|
+
# Currently training and explaining detasets are the same, but
|
|
131
|
+
# this can be changed in the future. The student model can be
|
|
132
|
+
# fitted on a much larger dataset since it takes a lot less time.
|
|
133
|
+
X = pd.concat(
|
|
134
|
+
[
|
|
135
|
+
X.select_dtypes(exclude=[object, "string"]),
|
|
136
|
+
pd.DataFrame(
|
|
137
|
+
{
|
|
138
|
+
name: LabelEncoder().fit_transform(data)
|
|
139
|
+
for name, data in X.select_dtypes(
|
|
140
|
+
include=[object, "string"]
|
|
141
|
+
).items()
|
|
142
|
+
},
|
|
143
|
+
index=X.index,
|
|
144
|
+
),
|
|
145
|
+
],
|
|
146
|
+
axis=1,
|
|
147
|
+
)
|
|
148
|
+
|
|
149
|
+
aug_df = pd.concat(
|
|
150
|
+
[
|
|
151
|
+
df,
|
|
152
|
+
Mimic(X, model_func).explain(X).rename(col_map, axis=1),
|
|
153
|
+
],
|
|
154
|
+
axis=1,
|
|
155
|
+
)
|
|
156
|
+
|
|
157
|
+
# Fill null with zero so they're not counted as missing records by server
|
|
158
|
+
if not Mimic._testing:
|
|
159
|
+
aug_df.fillna({c: 0 for c in col_map.values()}, inplace=True)
|
|
160
|
+
|
|
161
|
+
return (
|
|
162
|
+
aug_df,
|
|
163
|
+
aug_schema,
|
|
164
|
+
)
|
arize/version.py
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
__version__ = "8.0.0a9"
|
|
1
|
+
__version__ = "8.0.0a10"
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: arize
|
|
3
|
-
Version: 8.0.0a9
|
|
3
|
+
Version: 8.0.0a10
|
|
4
4
|
Summary: A helper library to interact with Arize AI APIs
|
|
5
5
|
Project-URL: Homepage, https://arize.com
|
|
6
6
|
Project-URL: Documentation, https://docs.arize.com/arize
|
|
@@ -30,6 +30,8 @@ Requires-Dist: numpy>=2.0.0
|
|
|
30
30
|
Provides-Extra: dev
|
|
31
31
|
Requires-Dist: pytest==8.4.2; extra == 'dev'
|
|
32
32
|
Requires-Dist: ruff==0.13.2; extra == 'dev'
|
|
33
|
+
Provides-Extra: mimic-explainer
|
|
34
|
+
Requires-Dist: interpret-community[mimic]<1,>=0.22.0; extra == 'mimic-explainer'
|
|
33
35
|
Provides-Extra: ml-batch
|
|
34
36
|
Requires-Dist: pandas<3,>=1.0.0; extra == 'ml-batch'
|
|
35
37
|
Requires-Dist: protobuf<6,>=4.21.0; extra == 'ml-batch'
|
|
@@ -4,7 +4,7 @@ arize/client.py,sha256=0LtZU3WeEatGd1QgQsMrJOuI-tFmzM3y1AfO74BLJys,5716
|
|
|
4
4
|
arize/config.py,sha256=iynVEZhrOPdTNJTQ_KQmwKOPiwL0LfEP8AUIDYW86Xw,5801
|
|
5
5
|
arize/logging.py,sha256=2vwdta2-kR78GeBFGK2vpk51rQ2d06HoKzuARI9qFQk,7317
|
|
6
6
|
arize/types.py,sha256=z1yg5-brmTD4kVHDmmTVkYke53JpusXXeOOpdQw7rYg,69508
|
|
7
|
-
arize/version.py,sha256=
|
|
7
|
+
arize/version.py,sha256=Wv8B6KxzS2ThGtkzs_13OkvwSugf5HITHYMQsGk1gjg,25
|
|
8
8
|
arize/_exporter/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
9
9
|
arize/_exporter/client.py,sha256=eAxJX1sUfdpLrtaQ0ynMTd5jI37JOp9fbl3NWp4WFEA,15216
|
|
10
10
|
arize/_exporter/validation.py,sha256=6ROu5p7uaolxQ93lO_Eiwv9NVw_uyi3E5T--C5Klo5Q,1021
|
|
@@ -71,11 +71,13 @@ arize/experiments/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU
|
|
|
71
71
|
arize/experiments/client.py,sha256=2fDq0fr_h6Knn_9zgDAlAhSUCKUrKozGLOQRTInCr4c,344
|
|
72
72
|
arize/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
73
73
|
arize/models/bounded_executor.py,sha256=o-PJsDAXQdiJ9dc-jzGCHMhT0-QBY9bvl4Ckn1017Eo,1131
|
|
74
|
-
arize/models/client.py,sha256=
|
|
74
|
+
arize/models/client.py,sha256=ZHxGYmCKP5ZX001qVNQc96QoclP4jvYVkLW11Xfqo2M,31199
|
|
75
75
|
arize/models/stream_validation.py,sha256=PtmqWgRdCxVtTNkHHEHIM1S6ECbYLA1vuQQFBw_t3Lw,7118
|
|
76
76
|
arize/models/batch_validation/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
77
77
|
arize/models/batch_validation/errors.py,sha256=__I8l25zf4kGv6qgiwEm9LzGNgqmMSM8Fb88pBtyMxE,39990
|
|
78
78
|
arize/models/batch_validation/validator.py,sha256=acnGcMt-pETmPJUfYj5tIzIBvmBhWoXoWmDYi_Gkq6Y,146910
|
|
79
|
+
arize/models/surrogate_explainer/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
80
|
+
arize/models/surrogate_explainer/mimic.py,sha256=MsMfhU9IhQJWm0kK6jpFkcTW6kw5IGJE3Kv94oOzMo0,5517
|
|
79
81
|
arize/spans/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
80
82
|
arize/spans/client.py,sha256=5yODUaSqxH-dLAenjRZBKbpsK7XgewZKwJpXzHWPNf0,47248
|
|
81
83
|
arize/spans/columns.py,sha256=BbB11jF4YHYfjrKbSd1r3K2F0AGA8KULTj1W3e2rwhM,12912
|
|
@@ -107,7 +109,7 @@ arize/utils/arrow.py,sha256=4In1gQc0i4Rb8zuwI0w-Hv-10wiItu5opqqGrJ8tSzo,5277
|
|
|
107
109
|
arize/utils/casting.py,sha256=KUrPUQN6qJEVe39nxbr0T-0GjAJLHjf4xWuzV71QezI,12468
|
|
108
110
|
arize/utils/dataframe.py,sha256=I0FloPgNiqlKga32tMOvTE70598QA8Hhrgf-6zjYMAM,1120
|
|
109
111
|
arize/utils/proto.py,sha256=9vLo53INYjdF78ffjm3E48jFwK6LbPD2FfKei7VaDy8,35477
|
|
110
|
-
arize-8.0.
|
|
111
|
-
arize-8.0.0a9.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
112
|
-
arize-8.0.0a9.dist-info/licenses/LICENSE.md,sha256=8vLN8Gms62NCBorxIv9MUvuK7myueb6_-dhXHPmm4H0,1479
|
|
113
|
-
arize-8.0.0a9.dist-info/RECORD,,
|
|
112
|
+
arize-8.0.0a10.dist-info/METADATA,sha256=9u9UPm9jOeZp9pxLo9R5mDYvrACrOzbPET51mNyyXQU,12567
|
|
113
|
+
arize-8.0.0a10.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
114
|
+
arize-8.0.0a10.dist-info/licenses/LICENSE.md,sha256=8vLN8Gms62NCBorxIv9MUvuK7myueb6_-dhXHPmm4H0,1479
|
|
115
|
+
arize-8.0.0a10.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|