antakia 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- antakia/__init__.py +5 -0
- antakia/antakia.py +160 -0
- antakia/assets/logo_ai-vidence.png +0 -0
- antakia/assets/logo_antakia.png +0 -0
- antakia/assets/logo_antakia_horizontal.png +0 -0
- antakia/config.py +17 -0
- antakia/explanation/__init__.py +0 -0
- antakia/explanation/explanation_method.py +66 -0
- antakia/explanation/explanations.py +108 -0
- antakia/gui/__init__.py +2 -0
- antakia/gui/colorTable.py +52 -0
- antakia/gui/data_store.py +8 -0
- antakia/gui/explanation_values.py +216 -0
- antakia/gui/gui.py +930 -0
- antakia/gui/high_dim_exp/__init__.py +0 -0
- antakia/gui/high_dim_exp/figure_display.py +565 -0
- antakia/gui/high_dim_exp/highdimexplorer.py +140 -0
- antakia/gui/high_dim_exp/projected_value_bank.py +12 -0
- antakia/gui/high_dim_exp/projected_values_selector.py +314 -0
- antakia/gui/progress_bar.py +133 -0
- antakia/gui/ruleswidget.py +661 -0
- antakia/gui/tabs/model_explorer.py +95 -0
- antakia/gui/tabs/tab1.py +390 -0
- antakia/gui/widget_utils.py +44 -0
- antakia/gui/widgets.py +1363 -0
- antakia/utils/__init__.py +0 -0
- antakia/utils/checks.py +5 -0
- antakia/utils/dummy_datasets.py +133 -0
- antakia/utils/logging.py +57 -0
- antakia-0.2.1.dist-info/LICENSE +13 -0
- antakia-0.2.1.dist-info/METADATA +114 -0
- antakia-0.2.1.dist-info/RECORD +33 -0
- antakia-0.2.1.dist-info/WHEEL +4 -0
antakia/__init__.py
ADDED
antakia/antakia.py
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import List, Dict, Any
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import pandas as pd
|
|
7
|
+
|
|
8
|
+
from dotenv import load_dotenv
|
|
9
|
+
|
|
10
|
+
from antakia_core.utils.utils import ProblemCategory
|
|
11
|
+
|
|
12
|
+
# Load variables from a .env file into os.environ (python-dotenv) before the
# antakia modules below are imported. Presumably this is so that settings read
# at import time (antakia/config.py reads os.environ when imported) can be
# provided through a .env file. AntakIA.__init__ calls load_dotenv() again.
load_dotenv()
|
|
13
|
+
|
|
14
|
+
from antakia.utils.checks import is_valid_model
|
|
15
|
+
from antakia_core.utils.variable import Variable, DataVariables
|
|
16
|
+
from antakia.gui.gui import GUI
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class AntakIA:
    """
    AntakIA class.

    Antakia instances provide data and methods to explain a ML model.

    Instance attributes
    -------------------
    X : pd.DataFrame the training dataset (column labels converted to str)
    y : pd.Series the target value (cast to float)
    model : Model
        the model to explain
    variables : DataVariables describing the columns of X
    X_test : pd.DataFrame the test dataset
    y_test : pd.Series the test target value
    X_exp : pd.DataFrame user provided explanation values, or None
    score : reference scoring function (callable or metric name)
    problem_category : ProblemCategory of the task
    gui : GUI built from the attributes above
    """

    def __init__(
            self,
            X: pd.DataFrame,
            y: pd.Series,
            model,
            variables: DataVariables | List[Dict[str, Any]] | pd.DataFrame | None = None,
            X_test: pd.DataFrame = None,
            y_test: pd.Series = None,
            X_exp: pd.DataFrame | None = None,
            score: callable | str = 'auto',
            problem_category: str = 'auto'
    ):
        """
        AntakIA constructor.

        Parameters:
            X : pd.DataFrame (or np.ndarray) the training dataset
            y : pd.Series (or np.ndarray) the target value
            model : Model
                the model to explain, must pass `is_valid_model`
            variables : description of the columns of X: a DataVariables,
                a list of dicts, a pd.DataFrame, or None to guess them from X
            X_test : pd.DataFrame the test dataset
            y_test : pd.Series the test target value
            X_exp : pd.DataFrame (or np.ndarray) user provided explanation values,
                must share X's index
            score : reference scoring function, a callable or a metric name;
                'auto' selects 'mse' for regression and 'accuracy' otherwise
            problem_category : name of a ProblemCategory member, 'auto' to infer it

        Raises:
            ValueError : invalid model, problem category or variables
            AssertionError : X and y (or X_exp) indexes differ (raised by pd.testing)
            IndexError : X and X_exp do not share the same index
        """

        load_dotenv()

        if not is_valid_model(model):
            raise ValueError(f"{model} should implement predict and score methods")
        X, y, X_exp = self._preprocess_data(X, y, X_exp)

        self.X = X
        self.X_test = X_test
        if y.ndim > 1:
            y = y.squeeze()
        self.y = y.astype(float)
        if y_test is not None and y_test.ndim > 1:
            y_test = y_test.squeeze()
        self.y_test = y_test
        self.model = model
        self.X_exp = X_exp
        self.problem_category = self._preprocess_problem_category(problem_category, model, X)
        self.score = self._preprocess_score(score, self.problem_category)

        self.set_variables(X, variables)

        self.gui = GUI(
            self.X,
            self.y,
            self.model,
            self.variables,
            self.X_test,
            self.y_test,
            self.X_exp,
            self.score,
            self.problem_category
        )

    def set_variables(self, X, variables):
        """
        Set self.variables from the user description, or guess it from X.

        Raises ValueError when `variables` has an unsupported type or, for a
        list / DataVariables, does not have one entry per column of X.
        """
        if variables is None:
            self.variables = Variable.guess_variables(X)
        elif isinstance(variables, DataVariables):
            # already in the target format (accepted by the constructor signature)
            self.variables = variables
            if len(self.variables) != len(X.columns):
                raise ValueError("Provided variable list must be the same length of the dataframe")
        elif isinstance(variables, list):
            self.variables: DataVariables = Variable.import_variable_list(variables)
            if len(self.variables) != len(X.columns):
                raise ValueError("Provided variable list must be the same length of the dataframe")
        elif isinstance(variables, pd.DataFrame):
            self.variables = Variable.import_variable_df(variables)
        else:
            raise ValueError("Provided variable list must be a list or a pandas DataFrame")

    def start_gui(self) -> GUI:
        return self.gui.show_splash_screen()

    def export_regions(self):
        return self.gui.region_set

    def _preprocess_data(self, X: pd.DataFrame, y, X_exp: pd.DataFrame):
        """
        Convert inputs to pandas objects and check that their indexes agree.

        The caller's DataFrames are never modified. Column labels of X are
        converted to str on a shallow copy, and X_exp receives X's column
        labels (also on a shallow copy).
        """
        if isinstance(X, np.ndarray):
            X = pd.DataFrame(X)
        if isinstance(X_exp, np.ndarray):
            X_exp = pd.DataFrame(X_exp)
        if isinstance(y, np.ndarray):
            y = pd.Series(y)

        # shallow copies: new column labels without duplicating the data
        # and without renaming the user's own DataFrame
        X = X.copy(deep=False)
        X.columns = [str(col) for col in X.columns]
        if X_exp is not None:
            X_exp = X_exp.copy(deep=False)
            X_exp.columns = X.columns

        if X_exp is not None:
            pd.testing.assert_index_equal(X.index, X_exp.index, check_names=False)
            if X.reindex(X_exp.index).iloc[:, 0].isna().sum() != X.iloc[:, 0].isna().sum():
                raise IndexError('X and X_exp must share the same index')
        pd.testing.assert_index_equal(X.index, y.index, check_names=False)
        return X, y, X_exp

    @staticmethod
    def _predicts_several_outputs(model, X: pd.DataFrame) -> bool:
        """Return True if model.predict on a sample of X yields a 2-D output with several columns."""
        pred_shape = np.shape(model.predict(X.sample(min(100, len(X)))))
        return len(pred_shape) > 1 and pred_shape[1] > 1

    def _preprocess_problem_category(self, problem_category: str, model, X: pd.DataFrame) -> ProblemCategory:
        """
        Resolve the user-provided category name into a ProblemCategory.

        'auto' and 'classification' are refined by inspecting the model:
        whether it has predict_proba, and the shape of its predictions.
        """
        if problem_category not in [e.name for e in ProblemCategory]:
            raise ValueError('Invalid problem category')
        if problem_category == 'auto':
            if hasattr(model, 'predict_proba'):
                return ProblemCategory['classification_with_proba']
            if self._predicts_several_outputs(model, X):
                return ProblemCategory['classification_proba']
            return ProblemCategory['regression']
        if problem_category == 'classification':
            if hasattr(model, 'predict_proba'):
                return ProblemCategory['classification_with_proba']
            if self._predicts_several_outputs(model, X):
                return ProblemCategory['classification_proba']
            return ProblemCategory['classification_label_only']
        return ProblemCategory[problem_category]

    def _preprocess_score(self, score, problem_category):
        """Return `score` unless it is 'auto': then 'mse' for regression, 'accuracy' otherwise."""
        if callable(score):
            return score
        if score != 'auto':
            return score
        if problem_category == ProblemCategory.regression:
            return 'mse'
        return 'accuracy'
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
antakia/config.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
import os
|
|
2
|
+
|
|
3
|
+
def _env_int(name: str, default: int) -> int:
    """Read an integer setting from the environment, falling back to `default`."""
    return int(os.environ.get(name, default))


def _env_bool(name: str, default: bool) -> bool:
    """
    Read a boolean setting from the environment.

    Unset -> `default`. Otherwise 'true', '1', 'yes' or 'on' (case-insensitive,
    surrounding spaces ignored) mean True; any other value means False.
    """
    value = os.environ.get(name)
    if value is None:
        return default
    return value.strip().lower() in ('true', '1', 'yes', 'on')


DEFAULT_EXPLANATION_METHOD = _env_int('DEFAULT_EXPLANATION_METHOD', 1)
# NOTE: read from the DEFAULT_VS_DIMENSION environment variable
DEFAULT_DIMENSION = _env_int('DEFAULT_VS_DIMENSION', 2)
DEFAULT_PROJECTION = os.environ.get('DEFAULT_PROJECTION', 'PaCMAP')

INIT_FIG_WIDTH = _env_int('INIT_FIG_WIDTH', 1800)
MAX_DOTS = _env_int('MAX_DOTS', 5000)

# Rule format
USE_INTERVALS_FOR_RULES = _env_bool('USE_INTERVALS_FOR_RULES', True)
MAX_RULES_DESCR_LENGTH = _env_int('MAX_RULES_DESCR_LENGTH', 200)

SHOW_LOG_MODULE_WIDGET = _env_bool('SHOW_LOG_MODULE_WIDGET', False)

# Auto cluster
MIN_POINTS_NUMBER = _env_int('MIN_POINTS_NUMBER', 100)
|
|
File without changes
|
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
import pandas as pd
|
|
2
|
+
|
|
3
|
+
from antakia_core.utils.long_task import LongTask
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class ExplanationMethod(LongTask):
    """
    Abstract base (see LongTask) computing explanation values for the
    Explanation Space (ES).

    Attributes
    ----------
    explanation_method : int code of the method (NONE, SHAP or LIME)
    model : the model to explain
    task_type : problem category of the model
    """

    # Class attributes: integer codes identifying each explanation method
    NONE = 0  # no explanation, ie: original values
    SHAP = 1
    LIME = 2

    def __init__(
            self,
            explanation_method: int,
            X: pd.DataFrame,
            model,
            task_type,
            progress_updated: callable = None,
    ):
        valid = ExplanationMethod.is_valid_explanation_method(explanation_method)
        if not valid:
            raise ValueError(explanation_method, " is a bad explanation method")
        self.explanation_method = explanation_method
        super().__init__(X, progress_updated)
        self.task_type = task_type
        self.model = model

    @staticmethod
    def is_valid_explanation_method(method: int) -> bool:
        """
        Tell whether `method` equals one of the known codes (SHAP, LIME or NONE).
        """
        known_codes = (ExplanationMethod.SHAP, ExplanationMethod.LIME, ExplanationMethod.NONE)
        return any(method == code for code in known_codes)

    @staticmethod
    def explanation_methods_as_list() -> list:
        """Codes of the methods that actually compute explanations (NONE excluded)."""
        computable = [ExplanationMethod.SHAP]
        computable.append(ExplanationMethod.LIME)
        return computable

    @staticmethod
    def explain_method_as_str(method: int) -> str:
        """Display name of a computable method code; ValueError for anything else."""
        if method == ExplanationMethod.SHAP:
            return "SHAP"
        if method == ExplanationMethod.LIME:
            return "LIME"
        raise ValueError(method, " is a bad explanation method")

    @staticmethod
    def explain_method_as_int(method: str) -> int:
        """Method code from its (case-insensitive) name; ValueError for unknown names."""
        normalized = method.upper()
        if normalized == "SHAP":
            return ExplanationMethod.SHAP
        if normalized == "LIME":
            return ExplanationMethod.LIME
        raise ValueError(method, " is a bad explanation method")
|
|
@@ -0,0 +1,108 @@
|
|
|
1
|
+
import lime
|
|
2
|
+
import numpy as np
|
|
3
|
+
import pandas as pd
|
|
4
|
+
import shap
|
|
5
|
+
from antakia_core.utils.utils import ProblemCategory
|
|
6
|
+
|
|
7
|
+
from antakia.explanation.explanation_method import ExplanationMethod
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
# ===========================================================
|
|
11
|
+
# Explanations implementations
|
|
12
|
+
# ===========================================================
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class SHAPExplanation(ExplanationMethod):
    """
    SHAP computation class.

    Uses shap.TreeExplainer when it accepts the model, and falls back to
    shap.KernelExplainer otherwise. Rows are explained in chunks so progress
    can be published. The result is a DataFrame with X's index and columns.
    """

    def __init__(self, X: pd.DataFrame, model, task_type, progress_updated: callable = None):
        super().__init__(ExplanationMethod.SHAP, X, model, task_type, progress_updated)

    @property
    def link(self):
        """KernelExplainer link function: 'identity' for regression, 'logit' otherwise."""
        if self.task_type == ProblemCategory.regression:
            return "identity"
        return "logit"

    @staticmethod
    def _single_output(values) -> np.ndarray:
        """
        Reduce a shap_values result to a 2-D (rows, features) array.

        Multi-output models may yield a list with one array per output, or a
        3-D array whose last axis indexes the outputs (this depends on the shap
        version). Output 1 is kept, which is the same class LIMExplanation
        explains for classification.
        """
        if isinstance(values, list):
            values = values[1] if len(values) > 1 else values[0]
        values = np.asarray(values)
        if values.ndim == 3:
            values = values[..., 1] if values.shape[-1] > 1 else values[..., 0]
        return values

    def compute(self) -> pd.DataFrame:
        self.publish_progress(0)
        n_rows = len(self.X)
        if n_rows == 0:
            self.publish_progress(100)
            return pd.DataFrame(index=self.X.index, columns=self.X.columns, dtype=float)
        try:
            explainer = shap.TreeExplainer(self.model)
        except Exception:
            # model not supported by TreeExplainer: use the model-agnostic
            # KernelExplainer with a sample of X as background data.
            # NOTE(review): with 'logit' link, model.predict is expected to
            # return probabilities; label-only predictors may give infinities.
            explainer = shap.KernelExplainer(
                self.model.predict, self.X.sample(min(200, n_rows)), link=self.link
            )
        chunk_size = 200
        chunks = []
        for start in range(0, n_rows, chunk_size):
            stop = min(start + chunk_size, n_rows)
            batch = self.X.iloc[start:stop]
            values = self._single_output(explainer.shap_values(batch))
            chunks.append(pd.DataFrame(values, columns=self.X.columns, index=batch.index))
            # `stop` rows are done: report that share, never above 100
            self.publish_progress(int(100 * stop / n_rows))
        shap_values = pd.concat(chunks)
        self.publish_progress(100)
        return shap_values
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class LIMExplanation(ExplanationMethod):
    """
    LIME computation class.

    Fits one LIME local surrogate per row of X and returns the surrogate
    weights as a DataFrame with the same index and columns as X.
    """

    def __init__(self, X: pd.DataFrame, model, task_type, progress_updated: callable = None):
        super().__init__(ExplanationMethod.LIME, X, model, task_type, progress_updated)

    @property
    def mode(self):
        """LIME mode: 'regression' for regression tasks, 'classification' otherwise."""
        if self.task_type == ProblemCategory.regression:
            return 'regression'
        else:
            return 'classification'

    def compute(self) -> pd.DataFrame:
        # `import lime` alone does not load the tabular submodule; import it
        # explicitly so `lime.lime_tabular` always resolves.
        import lime.lime_tabular

        self.publish_progress(0)
        n_rows, n_features = self.X.shape
        if n_rows == 0:
            self.publish_progress(100)
            return pd.DataFrame(index=self.X.index, columns=self.X.columns, dtype=float)

        mode = self.mode
        explainer = lime.lime_tabular.LimeTabularExplainer(
            self.X.sample(min(n_rows, 500)).values,
            feature_names=list(self.X.columns),
            verbose=False,
            mode=mode,
            discretize_continuous=False
        )

        if mode == 'regression':
            predict_fct = self.model.predict
            label = 0  # regression explanations are stored under label 0
        else:
            label = 1  # explain class 1
            if hasattr(self.model, 'predict_proba'):
                predict_fct = self.model.predict_proba
            else:
                predict_fct = self.model.predict

        weights = np.zeros((n_rows, n_features))
        for row_idx, row in enumerate(self.X.values):
            # request every feature: the default num_features would drop columns
            exp = explainer.explain_instance(row, predict_fct, num_features=n_features)
            # local_exp[label] holds (feature_index, weight) pairs ordered by
            # |weight|, not by column, so place each weight at its feature index
            for feature_idx, weight in exp.local_exp[label]:
                weights[row_idx, feature_idx] = weight
            self.publish_progress(int(100 * (row_idx + 1) / n_rows))

        values_lime = pd.DataFrame(weights, index=self.X.index, columns=self.X.columns)
        self.publish_progress(100)
        return values_lime
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def compute_explanations(X: pd.DataFrame, model, explanation_method: int, task_type,
                         callback: callable) -> pd.DataFrame:
    """
    Generic entry point computing explanations of `model` on `X`, SHAP or LIME.

    Parameters
    ----------
    X : data to explain
    model : model to explain
    explanation_method : ExplanationMethod.SHAP or ExplanationMethod.LIME
    task_type : problem category, forwarded to the explanation class
    callback : progress callback, forwarded to the explanation class

    Raises
    ------
    ValueError when explanation_method is neither SHAP nor LIME.
    """
    if explanation_method == ExplanationMethod.SHAP:
        explainer_cls = SHAPExplanation
    elif explanation_method == ExplanationMethod.LIME:
        explainer_cls = LIMExplanation
    else:
        raise ValueError(f"This explanation method {explanation_method} is not valid!")
    task = explainer_cls(X, model, task_type, callback)
    return task.compute()
|
antakia/gui/__init__.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
from traitlets import traitlets
|
|
2
|
+
import ipyvuetify as v
|
|
3
|
+
from antakia_core.utils.utils import colors
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class ColorTable(v.VuetifyTemplate):
    """
    Table to display regions.

    Each item's 'Region' value is shown in a chip colored by the item's
    'color' field. Rows are selectable, and 'item-selected' events are
    forwarded to the callback registered with set_callback.
    """
    headers = traitlets.List([]).tag(sync=True, allow_null=True)
    items = traitlets.List([]).tag(sync=True, allow_null=True)
    selected = traitlets.List([]).tag(sync=True, allow_null=True)
    colors = traitlets.List(colors).tag(sync=True)
    template = traitlets.Unicode('''
        <template>
            <v-data-table
                v-model="selected"
                :headers="headers"
                :items="items"
                item-key="Region"
                show-select
                :hide-default-footer="false"
                @item-selected="tableselect"
            >
                <template #header.data-table-select></template>
                <template v-slot:item.Region="{ item }">
                    <v-chip :color="item.color" >
                        {{ item.Region }}
                    </v-chip>
                </template>
            </v-data-table>
        </template>
        ''').tag(sync=True)  # type: ignore
    disable_sort = True

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.callback = None

    @staticmethod
    def get_color(item):
        return item.color

    def set_callback(self, callback: callable):  # type: ignore
        """Register the function called with the event data of each row (de)selection."""
        self.callback = callback

    def vue_tableselect(self, data):
        # selection events can arrive before a callback is registered: ignore them
        if self.callback is not None:
            self.callback(data)
|
|
@@ -0,0 +1,216 @@
|
|
|
1
|
+
import pandas as pd
|
|
2
|
+
import ipyvuetify as v
|
|
3
|
+
|
|
4
|
+
from antakia import config
|
|
5
|
+
from antakia.explanation.explanations import compute_explanations, ExplanationMethod
|
|
6
|
+
from antakia.gui.progress_bar import ProgressBar
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class ExplanationValues:
    """
    Widget to manage explanation values,
    in charge of computing them when necessary.

    Keeps one DataFrame (or None) per entry of `available_exp`. 'Imported'
    holds the user-provided values; 'SHAP' and 'LIME' are computed on demand.
    `self.widget` holds a select menu and a progress circle.
    """
    available_exp = ['Imported', 'SHAP', 'LIME']

    def __init__(self, X: pd.DataFrame, y: pd.Series, model, task_type, on_change_callback: callable,
                 disable_gui: callable, X_exp=None):
        """

        Parameters
        ----------
        X: original train DataFrame
        y: target variable
        model: customer model
        task_type: problem category, forwarded to the explanation computation
        on_change_callback: callback to notify explanation change, called with the selected DataFrame
        disable_gui: callable taking a bool, called to lock/unlock the GUI around computations
        X_exp: user provided explanations
        """
        self.widget = None
        self.X = X
        self.y = y
        self.model = model
        self.task_type = task_type
        self.on_change_callback = on_change_callback
        self.disable_gui = disable_gui
        self.initialized = False

        # init dict of explanations
        self.explanations: dict[str, pd.DataFrame | None] = {
            exp: None for exp in self.available_exp
        }

        if X_exp is not None:
            self.explanations[self.available_exp[0]] = X_exp

        # init selected explanation
        if X_exp is not None:
            self.current_exp = self.available_exp[0]
        else:
            self.current_exp = self.available_exp[1]

        self.build_widget()

    def build_widget(self):
        """Create the select menu + progress circle row and wire the select's change event."""
        self.widget = v.Row(children=[
            v.Select(  # Select of explanation method
                label="Explanation method",
                items=[
                    {"text": "Imported", "disabled": True},
                    {"text": "SHAP", "disabled": True},
                    {"text": "LIME", "disabled": True},
                ],
                class_="ml-2 mr-2",
                style_="width: 15%",
                disabled=False,
            ),
            v.ProgressCircular(  # exp menu progress bar
                class_="ml-2 mr-2 mt-2",
                indeterminate=False,
                color="grey",
                width="6",
                size="35",
            )
        ])
        # refresh select menu
        self.update_explanation_select()
        # set up callback
        self.get_explanation_select().on_event("change", self.explanation_select_changed)
        # start with an empty progress bar
        self.get_progress_bar().reset_progress_bar()

    def initialize(self, progress_callback):
        """
        initialize class (compute explanation if necessary)
        Parameters
        ----------
        progress_callback : callback to notify progress

        Returns
        -------

        """
        if not self.has_user_exp:
            # compute explanation if not provided
            self.compute_explanation(config.DEFAULT_EXPLANATION_METHOD, progress_callback)
        # ensure progress is at 100%
        progress_callback(100, 0)
        self.initialized = True

    @property
    def current_exp_df(self) -> pd.DataFrame:
        """
        currently selected explanation values (None if not computed yet)
        """
        return self.explanations[self.current_exp]

    @property
    def has_user_exp(self) -> bool:
        """
        has the user provided an explanation
        """
        return self.explanations[self.available_exp[0]] is not None

    def update_explanation_select(self):
        """
        refresh explanation select menu:
        'Imported' is disabled when absent, other methods are suffixed with
        ' (compute)' until their values exist, and the current one is selected.
        """
        exp_values = []
        for exp in self.available_exp:
            if exp == 'Imported':
                exp_values.append({
                    "text": exp,
                    'disabled': self.explanations[exp] is None
                })
            else:
                exp_values.append({
                    "text": exp + (' (compute)' if self.explanations[exp] is None else ''),
                    'disabled': False
                })
        self.get_explanation_select().items = exp_values
        self.get_explanation_select().v_model = self.current_exp

    def get_progress_bar(self):
        """ProgressBar wrapping the progress circle of the widget."""
        progress_widget = self.widget.children[1]
        progress_bar = ProgressBar(progress_widget)
        return progress_bar

    def get_explanation_select(self):
        """
        returns the explanation select menu
        """
        return self.widget.children[0]

    def compute_explanation(self, explanation_method: int, progress_bar: callable):
        """
        compute explanation and refresh widgets (select the new explanation method)

        The GUI is re-enabled even if the computation fails. The new
        explanation only becomes the current one once its values are stored,
        so a failure leaves the previous selection in place.

        Parameters
        ----------
        explanation_method: desired explanation (index into available_exp)
        progress_bar : progress bar to notify progress to
        """
        exp_name = self.available_exp[explanation_method]
        self.disable_gui(True)
        try:
            # We compute proj for this new PV :
            x_exp = compute_explanations(self.X, self.model, explanation_method, self.task_type, progress_bar)
            pd.testing.assert_index_equal(x_exp.columns, self.X.columns)

            # update explanation
            self.explanations[exp_name] = x_exp
            self.current_exp = exp_name
            # refresh front
            self.update_explanation_select()
        finally:
            self.disable_gui(False)

    def disable_selection(self, is_disabled: bool):
        """
        disable widgets
        Parameters
        ----------
        is_disabled = should disable ?
        """
        self.get_explanation_select().disabled = is_disabled

    def explanation_select_changed(self, widget, event, data):
        """
        triggered on selection of new explanation by user;
        computes the explanation first if its values are missing.
        Called when the user chooses another dataframe.

        Parameters
        ----------
        widget
        event
        data: explanation name, possibly suffixed with ' (compute)'

        Raises
        ------
        KeyError when data is not a string or not a known explanation name
        """
        if not isinstance(data, str):
            raise KeyError('invalid explanation')
        exp_name = data.replace(' ', '').replace('(compute)', '')

        if self.explanations[exp_name] is None:
            exp_method = ExplanationMethod.explain_method_as_int(exp_name)
            progress_bar = self.get_progress_bar()
            # selects exp_name once its values are computed
            self.compute_explanation(exp_method, progress_bar)
        else:
            self.current_exp = exp_name

        self.on_change_callback(self.current_exp_df)
|