antakia 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
antakia/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+
2
+ __version__ = "0.2.1"
3
+ __author__ = "AI-vidence "
4
+
5
+ from antakia.antakia import AntakIA
antakia/antakia.py ADDED
@@ -0,0 +1,160 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import List, Dict, Any
4
+
5
+ import numpy as np
6
+ import pandas as pd
7
+
8
+ from dotenv import load_dotenv
9
+
10
+ from antakia_core.utils.utils import ProblemCategory
11
+
12
+ load_dotenv()
13
+
14
+ from antakia.utils.checks import is_valid_model
15
+ from antakia_core.utils.variable import Variable, DataVariables
16
+ from antakia.gui.gui import GUI
17
+
18
+
19
class AntakIA:
    """
    AntakIA class.

    Antakia instances provide data and methods to explain a ML model.

    Instance attributes
    -------------------
    X : pd.DataFrame the training dataset (column names cast to str)
    y : pd.Series the target value (cast to float)
    model : Model
        the model to explain
    variables : DataVariables describing the columns of X
    X_test : pd.DataFrame the test dataset
    y_test : pd.Series the test target value
    X_exp : pd.DataFrame user provided explanation values, or None
    score : reference scoring function (callable or name)
    problem_category : ProblemCategory resolved from the user input
    gui : GUI built from all of the above
    """

    def __init__(
            self,
            X: pd.DataFrame,
            y: pd.Series,
            model,
            variables: DataVariables | List[Dict[str, Any]] | pd.DataFrame | None = None,
            X_test: pd.DataFrame | None = None,
            y_test: pd.Series | None = None,
            X_exp: pd.DataFrame | None = None,
            score: callable | str = 'auto',
            problem_category: str = 'auto'
    ):
        """
        AntakIA constructor.

        Parameters:
            X : pd.DataFrame the training dataset
            y : pd.Series the target value
            model : Model
                the model to explain
            variables : description of the columns of X (DataVariables,
                list of dicts or DataFrame); guessed from X when None
            X_test : pd.DataFrame the test dataset
            y_test : pd.Series the test target value
            X_exp : pd.DataFrame user provided explanations, same index as X
            score : reference scoring function, or 'auto'
            problem_category : name of a ProblemCategory member

        Raises:
            ValueError : if the model is rejected by is_valid_model, or the
                problem category / variables are invalid
        """

        load_dotenv()

        if not is_valid_model(model):
            raise ValueError(model, " should implement predict and score methods")
        X, y, X_exp = self._preprocess_data(X, y, X_exp)

        self.X = X
        self.X_test = X_test
        if y.ndim > 1:
            y = y.squeeze()
        self.y = y.astype(float)
        if y_test is not None and y_test.ndim > 1:
            y_test = y_test.squeeze()
        self.y_test = y_test
        self.model = model
        self.X_exp = X_exp
        self.problem_category = self._preprocess_problem_category(problem_category, model, X)
        self.score = self._preprocess_score(score, self.problem_category)

        self.set_variables(X, variables)

        self.gui = GUI(
            self.X,
            self.y,
            self.model,
            self.variables,
            self.X_test,
            self.y_test,
            self.X_exp,
            self.score,
            self.problem_category
        )

    def set_variables(self, X, variables):
        """
        Set self.variables from the user input, or guess it from X.

        Parameters:
            X : pd.DataFrame the dataset the variables describe
            variables : DataVariables, list of dicts, DataFrame or None

        Raises:
            ValueError : on an unsupported type, or when a list does not have
                one entry per column of X
        """
        if variables is None:
            self.variables = Variable.guess_variables(X)
        elif isinstance(variables, DataVariables):
            # the constructor signature advertises DataVariables: use it as is
            self.variables = variables
        elif isinstance(variables, list):
            self.variables: DataVariables = Variable.import_variable_list(variables)
            if len(self.variables) != len(X.columns):
                raise ValueError("Provided variable list must be the same length of the dataframe")
        elif isinstance(variables, pd.DataFrame):
            self.variables = Variable.import_variable_df(variables)
        else:
            raise ValueError("Provided variable list must be a list or a pandas DataFrame")

    def start_gui(self) -> GUI:
        """Show the GUI splash screen and return what the GUI call returns."""
        return self.gui.show_splash_screen()

    def export_regions(self):
        """Return the region set currently held by the GUI."""
        return self.gui.region_set

    def _preprocess_data(self, X: pd.DataFrame, y, X_exp: pd.DataFrame):
        """
        Normalise the inputs: numpy arrays are wrapped in pandas objects,
        column names of X become strings and X_exp takes the columns of X.

        Note: when X / X_exp are already DataFrames their columns are renamed
        in place.

        Raises:
            AssertionError : if X and X_exp, or X and y, have different indexes
            IndexError : if re-indexing X on the X_exp index creates missing rows
        """
        if isinstance(X, np.ndarray):
            X = pd.DataFrame(X)
        if isinstance(X_exp, np.ndarray):
            X_exp = pd.DataFrame(X_exp)
        if isinstance(y, np.ndarray):
            y = pd.Series(y)

        X.columns = [str(col) for col in X.columns]
        if X_exp is not None:
            X_exp.columns = X.columns

        if X_exp is not None:
            pd.testing.assert_index_equal(X.index, X_exp.index, check_names=False)
            if X.reindex(X_exp.index).iloc[:, 0].isna().sum() != X.iloc[:, 0].isna().sum():
                raise IndexError('X and X_exp must share the same index')
        pd.testing.assert_index_equal(X.index, y.index, check_names=False)
        return X, y, X_exp

    def _preprocess_problem_category(self, problem_category: str, model, X: pd.DataFrame) -> ProblemCategory:
        """
        Resolve the user supplied category name into a ProblemCategory.

        'auto' and 'classification' are refined by inspecting the model:
        a `predict_proba` attribute gives 'classification_with_proba', a 2-D
        prediction with several columns gives 'classification_proba';
        otherwise 'auto' falls back to 'regression' and 'classification' to
        'classification_label_only'. Any other valid name is returned as is.

        Raises:
            ValueError : if the name is not a ProblemCategory member
        """
        if problem_category not in [e.name for e in ProblemCategory]:
            raise ValueError('Invalid problem category')
        if problem_category not in ('auto', 'classification'):
            return ProblemCategory[problem_category]

        if hasattr(model, 'predict_proba'):
            return ProblemCategory['classification_with_proba']
        # probe the model on a small sample; only the arguments are used so the
        # method does not depend on the order of assignments in __init__
        pred = model.predict(X.sample(min(100, len(X))))
        if len(pred.shape) > 1 and pred.shape[1] > 1:
            return ProblemCategory['classification_proba']
        if problem_category == 'auto':
            return ProblemCategory['regression']
        return ProblemCategory['classification_label_only']

    def _preprocess_score(self, score, problem_category):
        """
        Return the scoring reference: a callable or an explicit name is kept,
        'auto' becomes 'mse' for regression and 'accuracy' otherwise.
        """
        if callable(score):
            return score
        if score != 'auto':
            return score
        if problem_category == ProblemCategory.regression:
            return 'mse'
        return 'accuracy'
Binary file
Binary file
antakia/config.py ADDED
@@ -0,0 +1,17 @@
1
+ import os
2
+
3
def _env_int(name, default):
    """Read environment variable *name* as an int, using *default* when unset."""
    return int(os.environ.get(name, default))


def _env_flag(name, default):
    """True only when the variable (or *default* when unset) is exactly 'True'."""
    return os.environ.get(name, default) == 'True'


# Defaults used when nothing else is selected
DEFAULT_EXPLANATION_METHOD = _env_int('DEFAULT_EXPLANATION_METHOD', 1)
DEFAULT_DIMENSION = _env_int('DEFAULT_VS_DIMENSION', 2)
DEFAULT_PROJECTION = 'PaCMAP'

# Figure settings
INIT_FIG_WIDTH = _env_int('INIT_FIG_WIDTH', 1800)
MAX_DOTS = _env_int('MAX_DOTS', 5000)

# Rule format
USE_INTERVALS_FOR_RULES = _env_flag('USE_INTERVALS_FOR_RULES', 'True')
MAX_RULES_DESCR_LENGTH = _env_int('MAX_RULES_DESCR_LENGTH', 200)

SHOW_LOG_MODULE_WIDGET = _env_flag('SHOW_LOG_MODULE_WIDGET', 'False')

# Auto cluster
MIN_POINTS_NUMBER = 100
File without changes
@@ -0,0 +1,66 @@
1
+ import pandas as pd
2
+
3
+ from antakia_core.utils.long_task import LongTask
4
+
5
+
6
class ExplanationMethod(LongTask):
    """
    Abstract class (see LongTask) to compute explanation values for the Explanation Space (ES)

    Attributes
    model : the model to explain
    explanation_method : integer code of the method (NONE, SHAP or LIME)
    task_type : stored as received from the caller
    """

    # Integer codes of the supported methods
    NONE = 0  # no explanation, ie: original values
    SHAP = 1
    LIME = 2

    def __init__(
            self,
            explanation_method: int,
            X: pd.DataFrame,
            model,
            task_type,
            progress_updated: callable = None,
    ):
        """
        Validate the method code, then initialise the LongTask part with X and
        the progress callback.

        Raises
        ------
        ValueError if explanation_method is not one of the class codes
        """
        if not ExplanationMethod.is_valid_explanation_method(explanation_method):
            raise ValueError(explanation_method, " is a bad explanation method")
        self.explanation_method = explanation_method
        super().__init__(X, progress_updated)
        self.model = model
        self.task_type = task_type

    @staticmethod
    def is_valid_explanation_method(method: int) -> bool:
        """
        Returns True if method equals one of SHAP, LIME or NONE.
        """
        known_codes = (ExplanationMethod.SHAP, ExplanationMethod.LIME, ExplanationMethod.NONE)
        return any(method == code for code in known_codes)

    @staticmethod
    def explanation_methods_as_list() -> list:
        """
        Returns the codes of the computable methods (SHAP and LIME).
        """
        return [ExplanationMethod.SHAP, ExplanationMethod.LIME]

    @staticmethod
    def explain_method_as_str(method: int) -> str:
        """
        Returns the name ("SHAP" / "LIME") of a method code.
        Raises ValueError for any other value.
        """
        if method == ExplanationMethod.SHAP:
            return "SHAP"
        if method == ExplanationMethod.LIME:
            return "LIME"
        raise ValueError(method, " is a bad explanation method")

    @staticmethod
    def explain_method_as_int(method: str) -> int:
        """
        Returns the code of a method name (case-insensitive).
        Raises ValueError for any other name.
        """
        name = method.upper()
        if name == "SHAP":
            return ExplanationMethod.SHAP
        if name == "LIME":
            return ExplanationMethod.LIME
        raise ValueError(method, " is a bad explanation method")
@@ -0,0 +1,108 @@
1
+ import lime
2
+ import numpy as np
3
+ import pandas as pd
4
+ import shap
5
+ from antakia_core.utils.utils import ProblemCategory
6
+
7
+ from antakia.explanation.explanation_method import ExplanationMethod
8
+
9
+
10
+ # ===========================================================
11
+ # Explanations implementations
12
+ # ===========================================================
13
+
14
+
15
class SHAPExplanation(ExplanationMethod):
    """
    SHAP computation class.
    """

    def __init__(self, X: pd.DataFrame, model, task_type, progress_updated: callable = None):
        super().__init__(ExplanationMethod.SHAP, X, model, task_type, progress_updated)

    @property
    def link(self):
        """Link name given to the kernel explainer: "identity" for regression, "logit" otherwise."""
        if self.task_type == ProblemCategory.regression:
            return "identity"
        return "logit"

    def compute(self) -> pd.DataFrame:
        """
        Compute the SHAP values of every row of X, by chunks of 200 rows.

        A tree explainer is tried first; if building it fails, a kernel
        explainer is built on a sample of at most 200 rows.

        Returns
        -------
        pd.DataFrame with the index and columns of X
        """
        self.publish_progress(0)
        try:
            explainer = shap.TreeExplainer(self.model)
        except Exception:
            # the model is not supported by the tree explainer: use the generic one.
            # `except Exception` lets KeyboardInterrupt / SystemExit propagate.
            explainer = shap.KernelExplainer(self.model.predict, self.X.sample(min(200, len(self.X))), link=self.link)
        chunk_size = 200
        n_rows = len(self.X)
        shap_val_list = []
        for start in range(0, n_rows, chunk_size):
            chunk = self.X.iloc[start:start + chunk_size]
            explanations = explainer.shap_values(chunk)
            shap_val_list.append(
                pd.DataFrame(explanations, columns=self.X.columns, index=chunk.index))
            # `start` is already a row offset: progress is the share of rows done
            rows_done = min(start + chunk_size, n_rows)
            self.publish_progress(int(100 * rows_done / n_rows))
        shap_values = pd.concat(shap_val_list)
        self.publish_progress(100)
        return shap_values
45
+
46
+
47
class LIMExplanation(ExplanationMethod):
    """
    LIME computation class.
    """

    def __init__(self, X: pd.DataFrame, model, task_type, progress_updated: callable = None):
        super().__init__(ExplanationMethod.LIME, X, model, task_type, progress_updated)

    @property
    def mode(self):
        """LIME mode: 'regression' for regression tasks, 'classification' otherwise."""
        if self.task_type == ProblemCategory.regression:
            return 'regression'
        return 'classification'

    def compute(self) -> pd.DataFrame:
        """
        Compute the LIME weights of every row of X.

        Returns
        -------
        pd.DataFrame with the index and columns of X; a feature for which LIME
        returns no weight keeps the value 0
        """
        # `import lime` alone does not load the submodule
        import lime.lime_tabular

        self.publish_progress(0)
        mode = self.mode
        n_features = self.X.shape[1]

        explainer = lime.lime_tabular.LimeTabularExplainer(
            self.X.sample(min(len(self.X), 500)).values,
            feature_names=self.X.columns,
            verbose=False,
            mode=mode,
            discretize_continuous=False
        )

        values_lime = pd.DataFrame(
            np.zeros(self.X.shape),
            index=self.X.index,
            columns=self.X.columns
        )
        if mode == 'regression':
            predict_fct = self.model.predict
            # NOTE(review): label 0 is kept from the original code; confirm the
            # sign convention of lime's regression explanations for this label
            label = 0
        else:
            label = 1
            if hasattr(self.model, 'predict_proba'):
                predict_fct = self.model.predict_proba
            else:
                predict_fct = self.model.predict

        progress = 0
        for index, row in self.X.iterrows():
            # ask for every feature: lime's default only returns the top 10
            exp = explainer.explain_instance(row.values, predict_fct, num_features=n_features)
            # local_exp holds (feature position, weight) pairs ordered by
            # importance, so weights must be mapped back by feature position
            weights = dict(exp.local_exp[label])
            values_lime.loc[index] = pd.Series(
                [weights.get(position, 0.0) for position in range(n_features)],
                index=self.X.columns
            )
            progress += 100 / len(self.X)
            self.publish_progress(int(progress))
        self.publish_progress(100)
        return values_lime
97
+
98
+
99
def compute_explanations(X: pd.DataFrame, model, explanation_method: int, task_type,
                         callback: callable) -> pd.DataFrame:
    """ Generic entry point: build the SHAP or LIME task matching the method
    code and return the result of its computation.

    Raises ValueError when the code is neither SHAP nor LIME.
    """
    if explanation_method == ExplanationMethod.SHAP:
        task_class = SHAPExplanation
    elif explanation_method == ExplanationMethod.LIME:
        task_class = LIMExplanation
    else:
        raise ValueError(f"This explanation method {explanation_method} is not valid!")
    task = task_class(X, model, task_type, callback)
    return task.compute()
@@ -0,0 +1,2 @@
1
+ __version__ = "0.1.1"
2
+ __author__ = "AI-vidence"
@@ -0,0 +1,52 @@
1
+ from traitlets import traitlets
2
+ import ipyvuetify as v
3
+ from antakia_core.utils.utils import colors
4
+
5
+
6
class ColorTable(v.VuetifyTemplate):
    """
    table to display regions

    Selection events coming from the front end are forwarded to the callback
    registered with `set_callback`.
    """
    headers = traitlets.List([]).tag(sync=True, allow_null=True)
    items = traitlets.List([]).tag(sync=True, allow_null=True)
    selected = traitlets.List([]).tag(sync=True, allow_null=True)
    colors = traitlets.List(colors).tag(sync=True)
    template = traitlets.Unicode('''
        <template>
            <v-data-table
                v-model="selected"
                :headers="headers"
                :items="items"
                item-key="Region"
                show-select
                :hide-default-footer="false"
                @item-selected="tableselect"
            >
            <template #header.data-table-select></template>
            <template v-slot:item.Region="{ item }">
              <v-chip :color="item.color" >
                {{ item.Region }}
              </v-chip>
            </template>
            </v-data-table>
        </template>
        ''').tag(sync=True)  # type: ignore
    disable_sort = True

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # no listener until set_callback is called
        self.callback = None

    @staticmethod
    def get_color(item):
        """Return the `color` attribute of an item."""
        return item.color

    # @click:row="tableclick"
    # def vue_tableclick(self, data):
    #     raise ValueError(f"click event data = {data}")

    def set_callback(self, callback: callable):  # type: ignore
        """Register the function called with the event data on each selection."""
        self.callback = callback

    def vue_tableselect(self, data):
        """Handler of the template's `item-selected` event: forward the data to the callback."""
        # a selection can happen before any callback is registered
        if self.callback is not None:
            self.callback(data)
@@ -0,0 +1,8 @@
1
class DataStore:
    """
    Holder for the train and test datasets.

    NOTE(review): the published file stopped right after the constructor
    signature (a syntax error); storing the arguments is the minimal body --
    confirm the intended behaviour.
    """

    def __init__(
            self,
            X,
            y,
            X_test,
            y_test,
    ):
        """Keep a reference to each of the four datasets."""
        self.X = X
        self.y = y
        self.X_test = X_test
        self.y_test = y_test
@@ -0,0 +1,216 @@
1
+ import pandas as pd
2
+ import ipyvuetify as v
3
+
4
+ from antakia import config
5
+ from antakia.explanation.explanations import compute_explanations, ExplanationMethod
6
+ from antakia.gui.progress_bar import ProgressBar
7
+
8
+
9
class ExplanationValues:
    """
    Widget to manage explanation values,
    in charge of computing them when necessary
    """
    available_exp = ['Imported', 'SHAP', 'LIME']

    def __init__(self, X: pd.DataFrame, y: pd.Series, model, task_type, on_change_callback: callable,
                 disable_gui: callable, X_exp=None):
        """

        Parameters
        ----------
        X: original train DataFrame
        y: target variable
        model: customer model
        task_type: forwarded to the explanation computation
        on_change_callback: callback to notify explanation change
        disable_gui: callback called with True / False around a computation
        X_exp: user provided explanations
        """
        self.widget = None
        self.X = X
        self.y = y
        self.model = model
        self.task_type = task_type
        self.on_change_callback = on_change_callback
        self.disable_gui = disable_gui
        self.initialized = False

        # init dict of explanations
        self.explanations: dict[str, pd.DataFrame | None] = {
            exp: None for exp in self.available_exp
        }

        # init selected explanation: imported values when provided, SHAP otherwise
        if X_exp is not None:
            self.explanations[self.available_exp[0]] = X_exp
            self.current_exp = self.available_exp[0]
        else:
            self.current_exp = self.available_exp[1]

        self.build_widget()

    def build_widget(self):
        """
        build the row holding the explanation select menu and its progress circle
        """
        self.widget = v.Row(children=[
            v.Select(  # Select of explanation method
                label="Explanation method",
                items=[
                    {"text": "Imported", "disabled": True},
                    {"text": "SHAP", "disabled": True},
                    {"text": "LIME", "disabled": True},
                ],
                class_="ml-2 mr-2",
                style_="width: 15%",
                disabled=False,
            ),
            v.ProgressCircular(  # exp menu progress bar
                class_="ml-2 mr-2 mt-2",
                indeterminate=False,
                color="grey",
                width="6",
                size="35",
            )
        ])
        # refresh select menu
        self.update_explanation_select()
        self.get_explanation_select().on_event("change", self.explanation_select_changed)
        # set up callback
        self.get_progress_bar().reset_progress_bar()

    def initialize(self, progress_callback):
        """
        initialize class (compute explanation if necessary)
        Parameters
        ----------
        progress_callback : callback to notify progress

        Returns
        -------

        """
        if not self.has_user_exp:
            # compute explanation if not provided
            self.compute_explanation(config.DEFAULT_EXPLANATION_METHOD, progress_callback)
        # ensure progress is at 100%
        progress_callback(100, 0)
        self.initialized = True

    @property
    def current_exp_df(self) -> pd.DataFrame:
        """
        values of the currently selected explanation (None if not computed yet)
        Returns
        -------

        """
        return self.explanations[self.current_exp]

    @property
    def has_user_exp(self) -> bool:
        """
        has the user provided an explanation
        Returns
        -------

        """
        return self.explanations[self.available_exp[0]] is not None

    def update_explanation_select(self):
        """
        refresh explanation select menu: 'Imported' is disabled when no values
        were provided, the other entries are suffixed with ' (compute)' until
        their values exist
        Returns
        -------

        """
        exp_values = []
        for exp in self.available_exp:
            missing = self.explanations[exp] is None
            if exp == 'Imported':
                exp_values.append({
                    "text": exp,
                    'disabled': missing
                })
            else:
                exp_values.append({
                    "text": exp + (' (compute)' if missing else ''),
                    'disabled': False
                })
        self.get_explanation_select().items = exp_values
        self.get_explanation_select().v_model = self.current_exp

    def get_progress_bar(self):
        """
        returns a new ProgressBar wrapping the progress circle widget
        """
        progress_widget = self.widget.children[1]
        progress_bar = ProgressBar(progress_widget)
        return progress_bar

    def get_explanation_select(self):
        """
        returns the explanation select menu
        Returns
        -------

        """
        return self.widget.children[0]

    def compute_explanation(self, explanation_method: int, progress_bar: callable):
        """
        compute explanation and refresh widgets (select the new explanation method)
        Parameters
        ----------
        explanation_method: desired explanation
        progress_bar : progress bar to notify progress to

        Returns
        -------

        """
        self.disable_gui(True)
        try:
            self.current_exp = self.available_exp[explanation_method]
            # compute the explanation values for the selected method
            x_exp = compute_explanations(self.X, self.model, explanation_method, self.task_type, progress_bar)
            pd.testing.assert_index_equal(x_exp.columns, self.X.columns)

            # update explanation
            self.explanations[self.current_exp] = x_exp
            # refresh front
            self.update_explanation_select()
        finally:
            # always give the GUI back, even when the computation fails
            self.disable_gui(False)

    def disable_selection(self, is_disabled: bool):
        """
        disable widgets
        Parameters
        ----------
        is_disabled = should disable ?

        Returns
        -------

        """
        self.get_explanation_select().disabled = is_disabled

    def explanation_select_changed(self, widget, event, data):
        """
        triggered on selection of new explanation by user;
        the explanation is computed first when its values do not exist yet
        Parameters
        ----------
        widget
        event
        data: explanation name, possibly followed by ' (compute)'

        Returns
        -------

        Raises KeyError when data is not a string
        """
        if not isinstance(data, str):
            raise KeyError('invalid explanation')
        # strip the ' (compute)' suffix added by update_explanation_select
        data = data.replace(' ', '').replace('(compute)', '')
        self.current_exp = data

        if self.explanations[self.current_exp] is None:
            exp_method = ExplanationMethod.explain_method_as_int(self.current_exp)
            progress_bar = self.get_progress_bar()
            self.compute_explanation(exp_method, progress_bar)

        self.on_change_callback(self.current_exp_df)