arize 8.0.0a9__py3-none-any.whl → 8.0.0a10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
arize/models/client.py CHANGED
@@ -80,6 +80,11 @@ _BATCH_DEPS = (
80
80
  "tqdm",
81
81
  )
82
82
  _BATCH_EXTRA = "ml-batch"
83
+ _MIMIC_DEPS = (
84
+ "interpret_community.mimic",
85
+ "sklearn.preprocessing",
86
+ )
87
+ _MIMIC_EXTRA = "mimic-explainer"
83
88
 
84
89
 
85
90
  class MLModelsClient:
@@ -116,7 +121,6 @@ class MLModelsClient:
116
121
  timeout: float | None = None,
117
122
  ) -> cf.Future:
118
123
  require(_STREAM_EXTRA, _STREAM_DEPS)
119
-
120
124
  from arize._generated.protocol.rec import public_pb2 as pb2
121
125
  from arize.utils.proto import (
122
126
  get_pb_dictionary,
@@ -545,17 +549,10 @@ class MLModelsClient:
545
549
  dataframe = dataframe.astype(cat_str_map)
546
550
 
547
551
  if surrogate_explainability:
548
- logger.debug("Running surrogate_explainability.")
552
+ require(_MIMIC_EXTRA, _MIMIC_DEPS)
553
+ from arize.models.surrogate_explainer.mimic import Mimic
549
554
 
550
- try:
551
- # WARNING: MIMIC EXPLAINER IS NOT DONE
552
- from arize.pandas.surrogate_explainer.mimic import Mimic
553
- except ImportError:
554
- raise ImportError(
555
- "To enable surrogate explainability, "
556
- "the arize module must be installed with the MimicExplainer option: pip "
557
- "install 'arize[MimicExplainer]'."
558
- ) from None
555
+ logger.debug("Running surrogate_explainability.")
559
556
  if schema.shap_values_column_names:
560
557
  logger.info(
561
558
  "surrogate_explainability=True has no effect "
File without changes
@@ -0,0 +1,164 @@
1
+ from __future__ import annotations
2
+
3
+ import random
4
+ import string
5
+ from dataclasses import replace
6
+ from typing import TYPE_CHECKING, Callable, Tuple
7
+
8
+ import numpy as np
9
+ import pandas as pd
10
+ from interpret_community.mimic.mimic_explainer import (
11
+ LGBMExplainableModel,
12
+ MimicExplainer,
13
+ )
14
+ from sklearn.preprocessing import LabelEncoder
15
+
16
+ from arize.types import (
17
+ CATEGORICAL_MODEL_TYPES,
18
+ NUMERIC_MODEL_TYPES,
19
+ ModelTypes,
20
+ )
21
+
22
+ if TYPE_CHECKING:
23
+ from arize.types import Schema
24
+
25
+
26
class Mimic:
    """Surrogate ("mimic") explainer producing per-row feature importances.

    An instance wraps ``interpret_community``'s ``MimicExplainer`` configured
    with ``LGBMExplainableModel`` as the explainable model. The static
    ``augment`` entry point validates the predictions found in a dataframe,
    fits the surrogate on (a sample of) the feature columns, and returns the
    dataframe extended with one importance column per feature together with
    a schema whose ``shap_values_column_names`` maps features to those
    columns.
    """

    # When True, ``augment`` leaves the importance values of rows that were
    # not sampled as NaN instead of filling them with 0.
    _testing = False

    # Sampling budget used by ``augment``: cap the number of feature "cells"
    # handed to the surrogate, but never go below / above these row counts
    # (unless the dataframe itself is smaller).
    _MAX_CELLS = 20_000_000
    _MIN_ROWS = 1_000
    _MAX_ROWS = 100_000

    def __init__(self, X: pd.DataFrame, model_func: Callable):
        """Build the underlying ``MimicExplainer``.

        Args:
            X: Feature dataframe handed to ``MimicExplainer`` as its
                initialization data.
            model_func: Callable taking one positional argument and returning
                the predictions to mimic (passed with ``is_function=True``).
        """
        self.explainer = MimicExplainer(
            model_func,
            X,
            LGBMExplainableModel,
            augment_data=False,
            is_function=True,
        )

    def explain(self, X: pd.DataFrame) -> pd.DataFrame:
        """Return the local importance values for ``X``.

        The result has the same columns and index as ``X``.
        """
        return pd.DataFrame(
            self.explainer.explain_local(X).local_importance_values,
            columns=X.columns,
            index=X.index,
        )

    @staticmethod
    def _unique_column_names(features, taken) -> dict:
        """Map each feature to a random 8-letter name not present in ``taken``.

        Names are also unique among themselves, so the generated importance
        columns can never shadow an existing column or each other.
        """
        used = set(taken)
        mapping = {}
        for ft in features:
            name = "".join(random.choices(string.ascii_letters, k=8))
            while name in used:
                name = "".join(random.choices(string.ascii_letters, k=8))
            used.add(name)
            mapping[ft] = name
        return mapping

    @staticmethod
    def augment(
        df: pd.DataFrame, schema: "Schema", model_type: "ModelTypes"
    ) -> Tuple[pd.DataFrame, "Schema"]:
        """Add surrogate feature-importance columns to ``df``.

        Args:
            df: Dataframe holding the feature and prediction columns named by
                ``schema``. It is not modified.
            schema: Schema naming the feature columns and the prediction
                score / label columns.
            model_type: Model type; must be one of ``CATEGORICAL_MODEL_TYPES``
                or ``NUMERIC_MODEL_TYPES``.

        Returns:
            ``(aug_df, aug_schema)``: ``df`` with one extra, randomly named
            column per feature, and a copy of ``schema`` whose
            ``shap_values_column_names`` maps each feature to its column.
            When the schema has no feature columns, ``df`` and ``schema`` are
            returned unchanged.

        Raises:
            ValueError: If the model type is unsupported, a categorical model
                has no prediction score column or has scores outside
                ``[0, 1]``, or a numeric model has NaN / infinite predictions.
        """
        features = schema.feature_column_names
        # Nothing to explain without features (None is treated like empty).
        if features is None:
            return df, schema

        X = df[features]
        if X.shape[1] == 0:
            return df, schema

        if model_type in CATEGORICAL_MODEL_TYPES:
            if not schema.prediction_score_column_name:
                raise ValueError(
                    "To calculate surrogate explainability, "
                    f"prediction_score_column_name must be specified in schema for {model_type}."
                )

            y_col_name = schema.prediction_score_column_name
            y = df[y_col_name].to_numpy()

            _min, _max = np.min(y), np.max(y)
            if not (0 <= _min <= 1 and 0 <= _max <= 1):
                raise ValueError(
                    f"To calculate surrogate explainability for {model_type}, "
                    f"prediction scores must be between 0 and 1, but current "
                    f"prediction scores range from {_min} to {_max}."
                )
            is_categorical = True

        elif model_type in NUMERIC_MODEL_TYPES:
            # Prefer the prediction score when present, else the label.
            y_col_name = schema.prediction_label_column_name
            if schema.prediction_score_column_name is not None:
                y_col_name = schema.prediction_score_column_name
            y = df[y_col_name].to_numpy()

            non_finite = len(y) - np.isfinite(y).sum()
            if non_finite:
                raise ValueError(
                    f"To calculate surrogate explainability for {model_type}, "
                    f"predictions must not contain NaN or infinite values, but "
                    f"{non_finite} NaN or infinite value(s) are found in {y_col_name}."
                )
            is_categorical = False

        else:
            raise ValueError(
                "Surrogate explainability is not supported for the specified "
                f"model type {model_type}."
            )

        # Column name mapping between features and feature importance values.
        # This is used to augment the schema.
        col_map = Mimic._unique_column_names(features, df.columns)
        aug_schema = replace(schema, shap_values_column_names=col_map)

        # Limit the total number of "cells" to 20M, unless it results in too
        # few or too many rows. This is done to keep the runtime low. Records
        # not sampled have feature importance values set to 0.
        samp_size = min(
            len(X),
            min(
                Mimic._MAX_ROWS,
                max(Mimic._MIN_ROWS, Mimic._MAX_CELLS // X.shape[1]),
            ),
        )

        if samp_size < len(X):
            _mask = np.zeros(len(X), dtype=int)
            _mask[:samp_size] = 1
            np.random.shuffle(_mask)
            _mask = _mask.astype(bool)
            X = X[_mask]
            y = y[_mask]

        # The prediction function ignores its argument and returns the
        # (sampled) predictions: the surrogate is fitted and explained on the
        # very rows these predictions belong to.
        if is_categorical:

            def model_func(_):  # type: ignore
                return np.column_stack((1 - y, y))

        else:

            def model_func(_):  # type: ignore
                return y

        # Work on an explicit copy so the column assignments below never
        # write through to (or warn about) the caller's dataframe.
        X = X.copy()

        # Replace all pd.NA values with np.nan values. Columns containing
        # missing values become object dtype and are therefore label-encoded
        # below.
        for col in X.columns:
            missing = X[col].isna()
            if missing.any():
                X[col] = X[col].astype(object).where(~missing, np.nan)

        # Apply integer encoding to non-numeric columns.
        # Currently training and explaining datasets are the same, but
        # this can be changed in the future. The student model can be
        # fitted on a much larger dataset since it takes a lot less time.
        non_numeric = X.select_dtypes(include=[object, "string"])
        X = pd.concat(
            [
                X.select_dtypes(exclude=[object, "string"]),
                pd.DataFrame(
                    {
                        name: LabelEncoder().fit_transform(data)
                        for name, data in non_numeric.items()
                    },
                    index=X.index,
                ),
            ],
            axis=1,
        )

        # NOTE(review): concat aligns on index labels, so when rows were
        # sampled this presumably requires df's index to be unique - confirm.
        importances = Mimic(X, model_func).explain(X).rename(col_map, axis=1)
        aug_df = pd.concat([df, importances], axis=1)

        # Fill null with zero so they're not counted as missing records by server
        if not Mimic._testing:
            aug_df = aug_df.fillna(dict.fromkeys(col_map.values(), 0))

        return aug_df, aug_schema
arize/version.py CHANGED
@@ -1 +1 @@
1
- __version__ = "8.0.0a9"
1
+ __version__ = "8.0.0a10"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: arize
3
- Version: 8.0.0a9
3
+ Version: 8.0.0a10
4
4
  Summary: A helper library to interact with Arize AI APIs
5
5
  Project-URL: Homepage, https://arize.com
6
6
  Project-URL: Documentation, https://docs.arize.com/arize
@@ -30,6 +30,8 @@ Requires-Dist: numpy>=2.0.0
30
30
  Provides-Extra: dev
31
31
  Requires-Dist: pytest==8.4.2; extra == 'dev'
32
32
  Requires-Dist: ruff==0.13.2; extra == 'dev'
33
+ Provides-Extra: mimic-explainer
34
+ Requires-Dist: interpret-community[mimic]<1,>=0.22.0; extra == 'mimic-explainer'
33
35
  Provides-Extra: ml-batch
34
36
  Requires-Dist: pandas<3,>=1.0.0; extra == 'ml-batch'
35
37
  Requires-Dist: protobuf<6,>=4.21.0; extra == 'ml-batch'
@@ -4,7 +4,7 @@ arize/client.py,sha256=0LtZU3WeEatGd1QgQsMrJOuI-tFmzM3y1AfO74BLJys,5716
4
4
  arize/config.py,sha256=iynVEZhrOPdTNJTQ_KQmwKOPiwL0LfEP8AUIDYW86Xw,5801
5
5
  arize/logging.py,sha256=2vwdta2-kR78GeBFGK2vpk51rQ2d06HoKzuARI9qFQk,7317
6
6
  arize/types.py,sha256=z1yg5-brmTD4kVHDmmTVkYke53JpusXXeOOpdQw7rYg,69508
7
- arize/version.py,sha256=BDT_fj8aJ4pvQMjTUrdbEd6vY71GkAguJScCzty97gA,24
7
+ arize/version.py,sha256=Wv8B6KxzS2ThGtkzs_13OkvwSugf5HITHYMQsGk1gjg,25
8
8
  arize/_exporter/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
9
9
  arize/_exporter/client.py,sha256=eAxJX1sUfdpLrtaQ0ynMTd5jI37JOp9fbl3NWp4WFEA,15216
10
10
  arize/_exporter/validation.py,sha256=6ROu5p7uaolxQ93lO_Eiwv9NVw_uyi3E5T--C5Klo5Q,1021
@@ -71,11 +71,13 @@ arize/experiments/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU
71
71
  arize/experiments/client.py,sha256=2fDq0fr_h6Knn_9zgDAlAhSUCKUrKozGLOQRTInCr4c,344
72
72
  arize/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
73
73
  arize/models/bounded_executor.py,sha256=o-PJsDAXQdiJ9dc-jzGCHMhT0-QBY9bvl4Ckn1017Eo,1131
74
- arize/models/client.py,sha256=aYXPv5Pq2Va2_aEEptw6-iD5zDEFV4UJz2bPnXvvIHw,31419
74
+ arize/models/client.py,sha256=ZHxGYmCKP5ZX001qVNQc96QoclP4jvYVkLW11Xfqo2M,31199
75
75
  arize/models/stream_validation.py,sha256=PtmqWgRdCxVtTNkHHEHIM1S6ECbYLA1vuQQFBw_t3Lw,7118
76
76
  arize/models/batch_validation/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
77
77
  arize/models/batch_validation/errors.py,sha256=__I8l25zf4kGv6qgiwEm9LzGNgqmMSM8Fb88pBtyMxE,39990
78
78
  arize/models/batch_validation/validator.py,sha256=acnGcMt-pETmPJUfYj5tIzIBvmBhWoXoWmDYi_Gkq6Y,146910
79
+ arize/models/surrogate_explainer/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
80
+ arize/models/surrogate_explainer/mimic.py,sha256=MsMfhU9IhQJWm0kK6jpFkcTW6kw5IGJE3Kv94oOzMo0,5517
79
81
  arize/spans/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
80
82
  arize/spans/client.py,sha256=5yODUaSqxH-dLAenjRZBKbpsK7XgewZKwJpXzHWPNf0,47248
81
83
  arize/spans/columns.py,sha256=BbB11jF4YHYfjrKbSd1r3K2F0AGA8KULTj1W3e2rwhM,12912
@@ -107,7 +109,7 @@ arize/utils/arrow.py,sha256=4In1gQc0i4Rb8zuwI0w-Hv-10wiItu5opqqGrJ8tSzo,5277
107
109
  arize/utils/casting.py,sha256=KUrPUQN6qJEVe39nxbr0T-0GjAJLHjf4xWuzV71QezI,12468
108
110
  arize/utils/dataframe.py,sha256=I0FloPgNiqlKga32tMOvTE70598QA8Hhrgf-6zjYMAM,1120
109
111
  arize/utils/proto.py,sha256=9vLo53INYjdF78ffjm3E48jFwK6LbPD2FfKei7VaDy8,35477
110
- arize-8.0.0a9.dist-info/METADATA,sha256=WGsj9jvMjlNY9Xy4c9GpIT6mkJ1551ITCUYEhhQESck,12453
111
- arize-8.0.0a9.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
112
- arize-8.0.0a9.dist-info/licenses/LICENSE.md,sha256=8vLN8Gms62NCBorxIv9MUvuK7myueb6_-dhXHPmm4H0,1479
113
- arize-8.0.0a9.dist-info/RECORD,,
112
+ arize-8.0.0a10.dist-info/METADATA,sha256=9u9UPm9jOeZp9pxLo9R5mDYvrACrOzbPET51mNyyXQU,12567
113
+ arize-8.0.0a10.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
114
+ arize-8.0.0a10.dist-info/licenses/LICENSE.md,sha256=8vLN8Gms62NCBorxIv9MUvuK7myueb6_-dhXHPmm4H0,1479
115
+ arize-8.0.0a10.dist-info/RECORD,,