midlearn-0.1.2.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- midlearn-0.1.2/LICENSE.md +21 -0
- midlearn-0.1.2/PKG-INFO +324 -0
- midlearn-0.1.2/README.md +304 -0
- midlearn-0.1.2/pyproject.toml +31 -0
- midlearn-0.1.2/setup.cfg +4 -0
- midlearn-0.1.2/src/midlearn/__init__.py +31 -0
- midlearn-0.1.2/src/midlearn/_r_interface.py +327 -0
- midlearn-0.1.2/src/midlearn/api.py +727 -0
- midlearn-0.1.2/src/midlearn/exceptions.py +9 -0
- midlearn-0.1.2/src/midlearn/plotting.py +401 -0
- midlearn-0.1.2/src/midlearn/plotting_theme.py +218 -0
- midlearn-0.1.2/src/midlearn/utils.py +9 -0
- midlearn-0.1.2/src/midlearn.egg-info/PKG-INFO +324 -0
- midlearn-0.1.2/src/midlearn.egg-info/SOURCES.txt +16 -0
- midlearn-0.1.2/src/midlearn.egg-info/dependency_links.txt +1 -0
- midlearn-0.1.2/src/midlearn.egg-info/requires.txt +6 -0
- midlearn-0.1.2/src/midlearn.egg-info/top_level.txt +1 -0
- midlearn-0.1.2/tests/test_functions.py +0 -0
midlearn-0.1.2/LICENSE.md
ADDED
@@ -0,0 +1,21 @@
# MIT License

Copyright (c) 2024 midr authors

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
midlearn-0.1.2/PKG-INFO
ADDED
@@ -0,0 +1,324 @@
Metadata-Version: 2.4
Name: midlearn
Version: 0.1.2
Summary: Python Wrapper of the 'midr' R package to interpret black-box models.
Author-email: Ryoichi Asashiba <ryoichi.asashiba@gmail.com>
Project-URL: Issues, https://github.com/ryo-asashi/midlearn/issues
Classifier: Programming Language :: Python :: 3
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: OS Independent
Requires-Python: >=3.8
Description-Content-Type: text/markdown
License-File: LICENSE.md
Requires-Dist: numpy
Requires-Dist: pandas
Requires-Dist: scikit-learn
Requires-Dist: rpy2>=3.5.0
Requires-Dist: plotnine
Requires-Dist: mizani
Dynamic: license-file

<!-- README.md is generated from README.ipynb. Please edit that file -->

# midlearn <img src="docs/logo/logo_hex.png" align="right" height="138"/>

An [{rpy2}](https://rpy2.github.io/doc/latest/html/)-based Python wrapper for the [{midr}](https://ryo-asashi.github.io/midr/) R package to explain black-box models, with a [{scikit-learn}](https://scikit-learn.org/stable/)-compatible API.

The goal of {midr} is to provide a model-agnostic method for interpreting and explaining black-box predictive models by creating a globally interpretable surrogate model.
The package implements 'Maximum Interpretation Decomposition' (MID), a functional decomposition technique that finds an optimal additive approximation of the original model.
The approximation is obtained by minimizing the squared error between the predictions of the black-box model and those of the surrogate model.
The theoretical foundations of MID are described in Iwasawa & Matsumori (2025) \[Forthcoming\], and the package itself is detailed in [Asashiba et al. (2025)](https://arxiv.org/abs/2506.08338).
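
Schematically (the notation below is illustrative, not taken from the references above), MID seeks an additive surrogate $g$ of a black-box model $f$ of the form

$$
g(x) = g_0 + \sum_{j} g_j(x_j) + \sum_{j<k} g_{jk}(x_j, x_k),
$$

where the intercept, main-effect, and interaction terms are chosen to minimize the squared error $\mathbb{E}\big[(f(x) - g(x))^2\big]$ over the supplied data.
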
## Installation

You can install the package directly from GitHub:

```bash
pip install git+https://github.com/ryo-asashi/midlearn.git
```

## Features

- **Scikit-learn Compatible API**: Fits seamlessly into your existing 'scikit-learn' workflows with a familiar `.fit()` and `.predict()` interface.

- **Model-Agnostic IML**: Explains any black-box model, from complex neural networks to gradient boosting machines.

- **Global Interpretability**: Generates a simple, additive surrogate model (MID) that provides a global understanding of the black-box model's behavior.

- **Direct Visualizations**: Easily creates plots for feature importance, component functions (dependence), prediction breakdowns, and conditional expectations using a plotnine-based interface.

## Requirements

This package is an {rpy2}-based Python wrapper and requires a working R installation on your system, as well as the {midr} R package.

You can install the R package from CRAN by running the following command in your R console:

```r
install.packages('midr')
```
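
If you prefer to check the R-side setup from Python, the sketch below (illustrative, not part of midlearn's API) uses rpy2 directly to install {midr} from CRAN when it is missing:

```python
# Optional: verify that R and the {midr} package are reachable through rpy2.
from rpy2.robjects.packages import importr, isinstalled

if not isinstalled("midr"):
    utils = importr("utils")
    utils.chooseCRANmirror(ind=1)   # pick the first CRAN mirror
    utils.install_packages("midr")

midr = importr("midr")  # raises an error if R or the package cannot be loaded
```
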
## Quick Start

Here’s a basic example of how to use **midlearn** to explain a trained LightGBM model.

```python
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import root_mean_squared_error
from sklearn.datasets import fetch_openml
from sklearn import set_config

import lightgbm as lgb
import midlearn as mid

# Set up plotnine theme for clean visualizations
import plotnine as p9  # requires plotnine >= 0.15.0
p9.theme_set(p9.theme_538(base_family='serif'))

# Configure scikit-learn display
set_config(display='text')
```

## 1. Train a Black-Box Model
We use the Bike Sharing dataset to train a LightGBM Regressor, which will serve as our black-box model.

```python
# Load and prepare data
bikeshare = fetch_openml(data_id=42712)
X = pd.DataFrame(bikeshare.data, columns=bikeshare.feature_names)
y = bikeshare.target
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

# Fit a LightGBM regression model
estimator = lgb.LGBMRegressor(
    force_col_wise=True,
    n_estimators=500,
    random_state=42
)
estimator.fit(X_train, y_train)
```

    [LightGBM] [Info] Total Bins 283
    [LightGBM] [Info] Number of data points in the train set: 13034, number of used features: 12
    [LightGBM] [Info] Start training from score 190.379623

    LGBMRegressor(force_col_wise=True, n_estimators=500, random_state=42)

```python
model_pred = estimator.predict(X_test)
rmse = root_mean_squared_error(y_test, model_pred)
print(f"RMSE: {round(rmse, 6)}")
```

    RMSE: 37.615267

## 2. Create an Explanation Model
We fit the `MIDExplainer` to the training data to create a globally faithful, interpretable surrogate model (MID).

```python
# Initialize and fit the MID model
explainer = mid.MIDExplainer(
    estimator=estimator,
    penalty=.05,
    singular_ok=True,
    interactions=True,
    encoding_frames={'hour': list(range(24))}
)
explainer.fit(X_train)
```

    Generating predictions from the estimator...

    R callback write-console: singular fit encountered

    MIDExplainer(encoding_frames={'hour': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
                                           13, 14, 15, 16, 17, 18, 19, 20, 21, 22,
                                           23]},
                 estimator=LGBMRegressor(force_col_wise=True, n_estimators=500,
                                         random_state=42),
                 penalty=0.05, singular_ok=True)

```python
# Check the fidelity of the surrogate model to the original model
p = p9.ggplot() \
    + p9.geom_abline(slope=1, color='gray') \
    + p9.geom_point(p9.aes(estimator.predict(X_test), explainer.predict(X_test)), alpha=0.5, shape=".") \
    + p9.labs(
        x='Prediction (LightGBM Regressor)',
        y='Prediction (Surrogate MID Regressor)',
        title='Surrogate Model Fidelity Check',
        subtitle=f'R-squared score: {round(explainer.fidelity_score(X_test), 6)}',
    )
display(p + p9.theme(figure_size=(5, 5)))
```

    Generating predictions from the estimator...

![png](README_files/README_10_2.png)

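
The fidelity can also be read off numerically, for example with the same `fidelity_score` method used for the plot subtitle above:

```python
# Numeric fidelity check: R-squared between the black-box and surrogate predictions.
r2 = explainer.fidelity_score(X_test)
print(f"Surrogate fidelity (R-squared): {round(r2, 6)}")
```
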
## 3. Visualize the Explanation Model
The MID model allows for clear visualization of feature importance, individual effects, and local prediction breakdowns.

```python
# Calculate and plot overall feature importance (default bar plot and heatmap)
imp = explainer.importance()
p1 = (
    imp.plot(max_nterms=20, theme="muted") +
    p9.labs(title="Feature Importance Plot", subtitle="colored by order")
)
p2 = (
    imp.plot(style='heatmap', color='black', linetype='dotted') +
    p9.labs(title="Feature Importance Map", subtitle="colored by importance")
)
display((p1 | p2) & p9.theme(figure_size=(8, 4), legend_position="bottom"))
```

![png](README_files/README_12_0.png)

```python
# Plot the main effects of the features (Component Functions)
plots = list()
for i, t in enumerate(imp.terms(interactions=False)):
    p = (
        explainer.plot(term=t) +
        p9.lims(y=[-180, 250]) +
        p9.labs(
            subtitle=f"Main Effect of {t.capitalize()}",
            x="",
            y="effect size"
        )
    )
    plots.append(p)

p1 = (
    (plots[0] | plots[1] | plots[2]) /
    (plots[3] | plots[4] | plots[5]) /
    (plots[6] | plots[7] | plots[8]) /
    (plots[9] | plots[10] | plots[11])
)
display(p1 + p9.theme(figure_size=(9, 12)))
```

![png](README_files/README_13_0.png)

```python
# Plot the interaction effects of pairs of variables (Component Functions)
p1 = (
    explainer.plot(
        "hour:workingday",
        theme='mako',
        main_effects=True
    ) +
    p9.labs(subtitle="Total Effect of Hour and Workingday")
)
p2 = (
    explainer.plot(
        "hour:feel_temp",
        style='data',
        theme='mako',
        data=X_train,
        main_effects=True,
        size=2
    ) +
    p9.labs(subtitle="Total Effect of Hour and Feel_temp")
)
display((p1 | p2) & p9.theme(figure_size=(8, 4), legend_position="bottom"))
```

![png](README_files/README_14_0.png)

```python
# Plot prediction breakdowns for the first four test samples (Local Interpretability)
plots = list()
for i in range(4):
    p = (
        explainer.breakdown(row=i, data=X_test).plot() +
        p9.labs(subtitle=f"Breakdown Plot for Row {i}")
    )
    plots.append(p)

p1 = (
    (plots[0] | plots[1]) /
    (plots[2] | plots[3])
)
display(p1 + p9.theme(figure_size=(8, 8)))
```

![png](README_files/README_15_0.png)

```python
# Plot individual conditional expectations (ICE) with color encoding
ice = explainer.conditional(
    variable='hour',
    data=X_train.head(500)
)
p1 = (
    ice.plot(alpha=.1) +
    p9.ggtitle("ICE Plot of Hour")
)
p2 = (
    ice.plot(
        style='centered',
        var_color='workingday',
        theme='muted'
    ) +
    p9.labs(
        title="Centered ICE Plot of Hour",
        subtitle="Colored by Workingday"
    ) +
    p9.theme(legend_position="bottom")
)
display((p1 | p2) & p9.theme(figure_size=(8, 4), legend_position="bottom"))
```

![png](README_files/README_16_0.png)

midlearn-0.1.2/README.md
ADDED
@@ -0,0 +1,304 @@
<!-- README.md is generated from README.ipynb. Please edit that file -->

# midlearn <img src="docs/logo/logo_hex.png" align="right" height="138"/>

An [{rpy2}](https://rpy2.github.io/doc/latest/html/)-based Python wrapper for the [{midr}](https://ryo-asashi.github.io/midr/) R package to explain black-box models, with a [{scikit-learn}](https://scikit-learn.org/stable/)-compatible API.

The goal of {midr} is to provide a model-agnostic method for interpreting and explaining black-box predictive models by creating a globally interpretable surrogate model.
The package implements 'Maximum Interpretation Decomposition' (MID), a functional decomposition technique that finds an optimal additive approximation of the original model.
The approximation is obtained by minimizing the squared error between the predictions of the black-box model and those of the surrogate model.
The theoretical foundations of MID are described in Iwasawa & Matsumori (2025) \[Forthcoming\], and the package itself is detailed in [Asashiba et al. (2025)](https://arxiv.org/abs/2506.08338).

## Installation

You can install the package directly from GitHub:

```bash
pip install git+https://github.com/ryo-asashi/midlearn.git
```

## Features

- **Scikit-learn Compatible API**: Fits seamlessly into your existing 'scikit-learn' workflows with a familiar `.fit()` and `.predict()` interface.

- **Model-Agnostic IML**: Explains any black-box model, from complex neural networks to gradient boosting machines.

- **Global Interpretability**: Generates a simple, additive surrogate model (MID) that provides a global understanding of the black-box model's behavior.

- **Direct Visualizations**: Easily creates plots for feature importance, component functions (dependence), prediction breakdowns, and conditional expectations using a plotnine-based interface.

## Requirements

This package is an {rpy2}-based Python wrapper and requires a working R installation on your system, as well as the {midr} R package.

You can install the R package from CRAN by running the following command in your R console:

```r
install.packages('midr')
```

## Quick Start

Here’s a basic example of how to use **midlearn** to explain a trained LightGBM model.

```python
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import root_mean_squared_error
from sklearn.datasets import fetch_openml
from sklearn import set_config

import lightgbm as lgb
import midlearn as mid

# Set up plotnine theme for clean visualizations
import plotnine as p9  # requires plotnine >= 0.15.0
p9.theme_set(p9.theme_538(base_family='serif'))

# Configure scikit-learn display
set_config(display='text')
```

## 1. Train a Black-Box Model
We use the Bike Sharing dataset to train a LightGBM Regressor, which will serve as our black-box model.

```python
# Load and prepare data
bikeshare = fetch_openml(data_id=42712)
X = pd.DataFrame(bikeshare.data, columns=bikeshare.feature_names)
y = bikeshare.target
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

# Fit a LightGBM regression model
estimator = lgb.LGBMRegressor(
    force_col_wise=True,
    n_estimators=500,
    random_state=42
)
estimator.fit(X_train, y_train)
```

    [LightGBM] [Info] Total Bins 283
    [LightGBM] [Info] Number of data points in the train set: 13034, number of used features: 12
    [LightGBM] [Info] Start training from score 190.379623

    LGBMRegressor(force_col_wise=True, n_estimators=500, random_state=42)

```python
model_pred = estimator.predict(X_test)
rmse = root_mean_squared_error(y_test, model_pred)
print(f"RMSE: {round(rmse, 6)}")
```

    RMSE: 37.615267

## 2. Create an Explanation Model
We fit the `MIDExplainer` to the training data to create a globally faithful, interpretable surrogate model (MID).

```python
# Initialize and fit the MID model
explainer = mid.MIDExplainer(
    estimator=estimator,
    penalty=.05,
    singular_ok=True,
    interactions=True,
    encoding_frames={'hour': list(range(24))}
)
explainer.fit(X_train)
```

    Generating predictions from the estimator...

    R callback write-console: singular fit encountered

    MIDExplainer(encoding_frames={'hour': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
                                           13, 14, 15, 16, 17, 18, 19, 20, 21, 22,
                                           23]},
                 estimator=LGBMRegressor(force_col_wise=True, n_estimators=500,
                                         random_state=42),
                 penalty=0.05, singular_ok=True)

```python
# Check the fidelity of the surrogate model to the original model
p = p9.ggplot() \
    + p9.geom_abline(slope=1, color='gray') \
    + p9.geom_point(p9.aes(estimator.predict(X_test), explainer.predict(X_test)), alpha=0.5, shape=".") \
    + p9.labs(
        x='Prediction (LightGBM Regressor)',
        y='Prediction (Surrogate MID Regressor)',
        title='Surrogate Model Fidelity Check',
        subtitle=f'R-squared score: {round(explainer.fidelity_score(X_test), 6)}',
    )
display(p + p9.theme(figure_size=(5, 5)))
```

    Generating predictions from the estimator...

![png](README_files/README_10_2.png)

## 3. Visualize the Explanation Model
The MID model allows for clear visualization of feature importance, individual effects, and local prediction breakdowns.

```python
# Calculate and plot overall feature importance (default bar plot and heatmap)
imp = explainer.importance()
p1 = (
    imp.plot(max_nterms=20, theme="muted") +
    p9.labs(title="Feature Importance Plot", subtitle="colored by order")
)
p2 = (
    imp.plot(style='heatmap', color='black', linetype='dotted') +
    p9.labs(title="Feature Importance Map", subtitle="colored by importance")
)
display((p1 | p2) & p9.theme(figure_size=(8, 4), legend_position="bottom"))
```

![png](README_files/README_12_0.png)

```python
# Plot the main effects of the features (Component Functions)
plots = list()
for i, t in enumerate(imp.terms(interactions=False)):
    p = (
        explainer.plot(term=t) +
        p9.lims(y=[-180, 250]) +
        p9.labs(
            subtitle=f"Main Effect of {t.capitalize()}",
            x="",
            y="effect size"
        )
    )
    plots.append(p)

p1 = (
    (plots[0] | plots[1] | plots[2]) /
    (plots[3] | plots[4] | plots[5]) /
    (plots[6] | plots[7] | plots[8]) /
    (plots[9] | plots[10] | plots[11])
)
display(p1 + p9.theme(figure_size=(9, 12)))
```

![png](README_files/README_13_0.png)

```python
# Plot the interaction effects of pairs of variables (Component Functions)
p1 = (
    explainer.plot(
        "hour:workingday",
        theme='mako',
        main_effects=True
    ) +
    p9.labs(subtitle="Total Effect of Hour and Workingday")
)
p2 = (
    explainer.plot(
        "hour:feel_temp",
        style='data',
        theme='mako',
        data=X_train,
        main_effects=True,
        size=2
    ) +
    p9.labs(subtitle="Total Effect of Hour and Feel_temp")
)
display((p1 | p2) & p9.theme(figure_size=(8, 4), legend_position="bottom"))
```

![png](README_files/README_14_0.png)

```python
# Plot prediction breakdowns for the first four test samples (Local Interpretability)
plots = list()
for i in range(4):
    p = (
        explainer.breakdown(row=i, data=X_test).plot() +
        p9.labs(subtitle=f"Breakdown Plot for Row {i}")
    )
    plots.append(p)

p1 = (
    (plots[0] | plots[1]) /
    (plots[2] | plots[3])
)
display(p1 + p9.theme(figure_size=(8, 8)))
```

![png](README_files/README_15_0.png)

```python
# Plot individual conditional expectations (ICE) with color encoding
ice = explainer.conditional(
    variable='hour',
    data=X_train.head(500)
)
p1 = (
    ice.plot(alpha=.1) +
    p9.ggtitle("ICE Plot of Hour")
)
p2 = (
    ice.plot(
        style='centered',
        var_color='workingday',
        theme='muted'
    ) +
    p9.labs(
        title="Centered ICE Plot of Hour",
        subtitle="Colored by Workingday"
    ) +
    p9.theme(legend_position="bottom")
)
display((p1 | p2) & p9.theme(figure_size=(8, 4), legend_position="bottom"))
```

![png](README_files/README_16_0.png)

midlearn-0.1.2/pyproject.toml
ADDED
@@ -0,0 +1,31 @@
# pyproject.toml

[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"

[project]
name = "midlearn"
version = "0.1.2"
authors = [
  { name="Ryoichi Asashiba", email="ryoichi.asashiba@gmail.com" },
]
description = "Python Wrapper of the 'midr' R package to interpret black-box models."
readme = "README.md"
requires-python = ">=3.8"
classifiers = [
    "Programming Language :: Python :: 3",
    "License :: OSI Approved :: MIT License",
    "Operating System :: OS Independent",
]
dependencies = [
    "numpy",
    "pandas",
    "scikit-learn",
    "rpy2>=3.5.0",
    "plotnine",
    "mizani"
]

[project.urls]
Issues = "https://github.com/ryo-asashi/midlearn/issues"
midlearn-0.1.2/setup.cfg
ADDED