cherrypick-ml 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cherrypick_ml-0.1.0/LICENSE +21 -0
- cherrypick_ml-0.1.0/PKG-INFO +117 -0
- cherrypick_ml-0.1.0/README.md +82 -0
- cherrypick_ml-0.1.0/cherrypick/__init__.py +4 -0
- cherrypick_ml-0.1.0/cherrypick/anomaly.py +120 -0
- cherrypick_ml-0.1.0/cherrypick/explain.py +178 -0
- cherrypick_ml-0.1.0/cherrypick/orchestrator.py +797 -0
- cherrypick_ml-0.1.0/cherrypick/preprocessing.py +197 -0
- cherrypick_ml-0.1.0/cherrypick/splits.py +56 -0
- cherrypick_ml-0.1.0/cherrypick_ml.egg-info/PKG-INFO +117 -0
- cherrypick_ml-0.1.0/cherrypick_ml.egg-info/SOURCES.txt +14 -0
- cherrypick_ml-0.1.0/cherrypick_ml.egg-info/dependency_links.txt +1 -0
- cherrypick_ml-0.1.0/cherrypick_ml.egg-info/requires.txt +13 -0
- cherrypick_ml-0.1.0/cherrypick_ml.egg-info/top_level.txt +1 -0
- cherrypick_ml-0.1.0/setup.cfg +4 -0
- cherrypick_ml-0.1.0/setup.py +37 -0
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 Sujal Giri Sanyasi
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1,117 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: cherrypick-ml
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: A lightweight ML orchestration library with preprocessing, anomaly detection, and explainability tools
|
|
5
|
+
Author: Sujal G Sanyasi
|
|
6
|
+
Author-email: cherrypickml1@gmail.com
|
|
7
|
+
License: MIT
|
|
8
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
9
|
+
Classifier: Programming Language :: Python :: 3
|
|
10
|
+
Classifier: Operating System :: OS Independent
|
|
11
|
+
Description-Content-Type: text/markdown
|
|
12
|
+
License-File: LICENSE
|
|
13
|
+
Requires-Dist: pandas
|
|
14
|
+
Requires-Dist: numpy
|
|
15
|
+
Requires-Dist: matplotlib
|
|
16
|
+
Requires-Dist: shap
|
|
17
|
+
Requires-Dist: seaborn
|
|
18
|
+
Requires-Dist: joblib
|
|
19
|
+
Requires-Dist: plotly
|
|
20
|
+
Requires-Dist: xgboost
|
|
21
|
+
Requires-Dist: rich
|
|
22
|
+
Requires-Dist: lightgbm
|
|
23
|
+
Requires-Dist: catboost
|
|
24
|
+
Requires-Dist: pytest
|
|
25
|
+
Requires-Dist: imblearn
|
|
26
|
+
Dynamic: author
|
|
27
|
+
Dynamic: author-email
|
|
28
|
+
Dynamic: classifier
|
|
29
|
+
Dynamic: description
|
|
30
|
+
Dynamic: description-content-type
|
|
31
|
+
Dynamic: license
|
|
32
|
+
Dynamic: license-file
|
|
33
|
+
Dynamic: requires-dist
|
|
34
|
+
Dynamic: summary
|
|
35
|
+
|
|
36
|
+
<p align="center">
|
|
37
|
+
<img src="assets/cherrylogo.png" alt="cherrypick-ml logo" width="1100" height=450>
|
|
38
|
+
</p>
|
|
39
|
+
|
|
40
|
+
-----------------
|
|
41
|
+
|
|
42
|
+
# cherrypick-ml: A Machine Learning Orchestration and Pipeline Toolkit
|
|
43
|
+
|
|
44
|
+
| | |
|
|
45
|
+
| --- | --- |
|
|
46
|
+
| Testing | Structured validation of preprocessing, orchestration, and explainability components |
|
|
47
|
+
| Package | PyPI distribution for cherrypick-ml |
|
|
48
|
+
| Meta | MIT License, Python-based machine learning pipeline framework |
|
|
49
|
+
|
|
50
|
+
---
|
|
51
|
+
|
|
52
|
+
## What is it?
|
|
53
|
+
|
|
54
|
+
**cherrypick-ml** is a Python package that provides a unified interface for building, managing, and evaluating machine learning workflows. It integrates preprocessing, anomaly detection, model orchestration, and explainability into a single, modular framework.
|
|
55
|
+
|
|
56
|
+
The library is designed to simplify real-world machine learning development by reducing repetitive code while maintaining flexibility and transparency in model pipelines.
|
|
57
|
+
|
|
58
|
+
---
|
|
59
|
+
|
|
60
|
+
## Table of Contents
|
|
61
|
+
|
|
62
|
+
- [Main Features](#main-features)
|
|
63
|
+
- [Core Components](#core-components)
|
|
64
|
+
- [Where to get it](#where-to-get-it)
|
|
65
|
+
- [Dependencies](#dependencies)
|
|
66
|
+
- [Installation from sources](#installation-from-sources)
|
|
67
|
+
- [Basic Usage](#basic-usage)
|
|
68
|
+
- [License](#license)
|
|
69
|
+
- [Documentation](#documentation)
|
|
70
|
+
|
|
71
|
+
---
|
|
72
|
+
|
|
73
|
+
## Main Features
|
|
74
|
+
|
|
75
|
+
cherrypick-ml provides the following core capabilities:
|
|
76
|
+
|
|
77
|
+
- Automated model orchestration for classification and regression tasks
|
|
78
|
+
- Integrated preprocessing utilities including encoding and missing value handling
|
|
79
|
+
- Outlier detection and pruning using statistical and ML-based methods: interquartile range (IQR), Z-score, modified Z-score, Isolation Forest, and Local Outlier Factor (LOF)
|
|
80
|
+
- SHAP-based explainability for feature importance and model interpretation
|
|
81
|
+
- Flexible train-test splitting utilities
|
|
82
|
+
- Modular design allowing independent usage of components
|
|
83
|
+
- Designed for practical, real-world machine learning workflows
|
|
84
|
+
|
|
85
|
+
---
|
|
86
|
+
|
|
87
|
+
## Core Components
|
|
88
|
+
|
|
89
|
+
The library is structured into the following modules:
|
|
90
|
+
|
|
91
|
+
- **Orchestrator**
|
|
92
|
+
High-level interface for training, evaluating, and selecting models, with explainability visualisations
|
|
93
|
+
|
|
94
|
+
- **preprocessing**
|
|
95
|
+
Tools for encoding, imputation, and feature preparation
|
|
96
|
+
|
|
97
|
+
- **anomaly**
|
|
98
|
+
Outlier detection and data pruning utilities
|
|
99
|
+
|
|
100
|
+
- **explain**
|
|
101
|
+
Model explainability using SHAP-based analysis
|
|
102
|
+
|
|
103
|
+
- **splits**
|
|
104
|
+
Utilities for dataset partitioning
|
|
105
|
+
|
|
106
|
+
---
|
|
107
|
+
|
|
108
|
+
## Where to get it
|
|
109
|
+
|
|
110
|
+
The source code is currently hosted on GitHub at:
|
|
111
|
+
|
|
112
|
+
https://github.com/Sujal-G-Sanyasi/cherrypick-ml
|
|
113
|
+
|
|
114
|
+
Binary installers for the latest released version are available at the Python Package Index (PyPI):
|
|
115
|
+
|
|
116
|
+
```sh
|
|
117
|
+
pip install cherrypick-ml
|
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
<p align="center">
|
|
2
|
+
<img src="assets/cherrylogo.png" alt="cherrypick-ml logo" width="1100" height=450>
|
|
3
|
+
</p>
|
|
4
|
+
|
|
5
|
+
-----------------
|
|
6
|
+
|
|
7
|
+
# cherrypick-ml: A Machine Learning Orchestration and Pipeline Toolkit
|
|
8
|
+
|
|
9
|
+
| | |
|
|
10
|
+
| --- | --- |
|
|
11
|
+
| Testing | Structured validation of preprocessing, orchestration, and explainability components |
|
|
12
|
+
| Package | PyPI distribution for cherrypick-ml |
|
|
13
|
+
| Meta | MIT License, Python-based machine learning pipeline framework |
|
|
14
|
+
|
|
15
|
+
---
|
|
16
|
+
|
|
17
|
+
## What is it?
|
|
18
|
+
|
|
19
|
+
**cherrypick-ml** is a Python package that provides a unified interface for building, managing, and evaluating machine learning workflows. It integrates preprocessing, anomaly detection, model orchestration, and explainability into a single, modular framework.
|
|
20
|
+
|
|
21
|
+
The library is designed to simplify real-world machine learning development by reducing repetitive code while maintaining flexibility and transparency in model pipelines.
|
|
22
|
+
|
|
23
|
+
---
|
|
24
|
+
|
|
25
|
+
## Table of Contents
|
|
26
|
+
|
|
27
|
+
- [Main Features](#main-features)
|
|
28
|
+
- [Core Components](#core-components)
|
|
29
|
+
- [Where to get it](#where-to-get-it)
|
|
30
|
+
- [Dependencies](#dependencies)
|
|
31
|
+
- [Installation from sources](#installation-from-sources)
|
|
32
|
+
- [Basic Usage](#basic-usage)
|
|
33
|
+
- [License](#license)
|
|
34
|
+
- [Documentation](#documentation)
|
|
35
|
+
|
|
36
|
+
---
|
|
37
|
+
|
|
38
|
+
## Main Features
|
|
39
|
+
|
|
40
|
+
cherrypick-ml provides the following core capabilities:
|
|
41
|
+
|
|
42
|
+
- Automated model orchestration for classification and regression tasks
|
|
43
|
+
- Integrated preprocessing utilities including encoding and missing value handling
|
|
44
|
+
- Outlier detection and pruning using statistical and ML-based methods: interquartile range (IQR), Z-score, modified Z-score, Isolation Forest, and Local Outlier Factor (LOF)
|
|
45
|
+
- SHAP-based explainability for feature importance and model interpretation
|
|
46
|
+
- Flexible train-test splitting utilities
|
|
47
|
+
- Modular design allowing independent usage of components
|
|
48
|
+
- Designed for practical, real-world machine learning workflows
|
|
49
|
+
|
|
50
|
+
---
|
|
51
|
+
|
|
52
|
+
## Core Components
|
|
53
|
+
|
|
54
|
+
The library is structured into the following modules:
|
|
55
|
+
|
|
56
|
+
- **Orchestrator**
|
|
57
|
+
High-level interface for training, evaluating, and selecting models, with explainability visualisations
|
|
58
|
+
|
|
59
|
+
- **preprocessing**
|
|
60
|
+
Tools for encoding, imputation, and feature preparation
|
|
61
|
+
|
|
62
|
+
- **anomaly**
|
|
63
|
+
Outlier detection and data pruning utilities
|
|
64
|
+
|
|
65
|
+
- **explain**
|
|
66
|
+
Model explainability using SHAP-based analysis
|
|
67
|
+
|
|
68
|
+
- **splits**
|
|
69
|
+
Utilities for dataset partitioning
|
|
70
|
+
|
|
71
|
+
---
|
|
72
|
+
|
|
73
|
+
## Where to get it
|
|
74
|
+
|
|
75
|
+
The source code is currently hosted on GitHub at:
|
|
76
|
+
|
|
77
|
+
https://github.com/Sujal-G-Sanyasi/cherrypick-ml
|
|
78
|
+
|
|
79
|
+
Binary installers for the latest released version are available at the Python Package Index (PyPI):
|
|
80
|
+
|
|
81
|
+
```sh
|
|
82
|
+
pip install cherrypick-ml
|
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
import pandas as pd
|
|
2
|
+
import numpy as np
|
|
3
|
+
from scipy.stats import zscore
|
|
4
|
+
from sklearn.ensemble import IsolationForest
|
|
5
|
+
from sklearn.neighbors import LocalOutlierFactor
|
|
6
|
+
from typing import Literal
|
|
7
|
+
|
|
8
|
+
import warnings as war
|
|
9
|
+
# NOTE(review): this executes at import time and installs an unconditional
# 'ignore' filter, so *every* warning in the whole process is silenced, not just
# warnings coming from this module. Consider scoping it with
# warnings.catch_warnings() around the specific calls that need it.
war.filterwarnings('ignore')
|
|
10
|
+
|
|
11
|
+
class OutlierPruner:
    """
    OutlierPruner provides statistical and ML-based methods
    for detecting and removing outliers from a dataset.

    Parameters
    ----------
    method : {'iqr', 'zscore', 'mod_zscore', 'isoforest', 'lof'}
        Method used for outlier detection.

        - `'iqr'` - Interquartile Range method: keeps rows whose value in ``col``
          lies within ``[Q1 - 1.5 * IQR, Q3 + 1.5 * IQR]``
        - `'zscore'` - Standard Z-score (population standard deviation, ``ddof=0``)
        - `'mod_zscore'` - Modified Z-score:

            `modified_Zscore = 0.6745 * (X - median) / MAD`

            *Where,*
            **median** = *median of the sample data*
            **MAD** = *median absolute deviation*
            **X** = *sample data points(Xi)*
        - `'isoforest'` - Isolation Forest, an ensemble-based anomaly detection method
        - `'lof'` - Local Outlier Factor, detects outliers using local density

    df : pandas.DataFrame
        Input dataset on which outlier pruning will be applied.

    col : str
        Column name used for outlier detection in the statistical methods
        (``'iqr'``, ``'zscore'``, ``'mod_zscore'``).

    threshold : float, default=3.0
        Keyword-only. A row is kept when its absolute (modified) Z-score is
        strictly below this value. Used by ``'zscore'`` and ``'mod_zscore'``.

    contamination : float, default=0.3
        Keyword-only. Expected proportion of outliers, forwarded to
        ``sklearn.ensemble.IsolationForest``. Used by ``'isoforest'`` only.

    Notes
    -----
    - Statistical methods require a specific column (``col``). Missing values
      in ``col`` are ignored when computing the statistics, and rows holding
      them are dropped from the result.
    - If the spread of ``col`` is zero or undefined (standard deviation for
      ``'zscore'``, MAD for ``'mod_zscore'``), no score can be computed and the
      input DataFrame is returned unchanged.
    - ML-based methods (Isolation Forest, Local Outlier Factor) operate on all
      numerical columns of ``df``.
    - Modified Z-score is robust to extreme values as it uses the median instead of mean.
    """

    # Methods that score a single column and therefore need ``col`` to exist.
    _STATISTICAL_METHODS = ('iqr', 'zscore', 'mod_zscore')

    def __init__(self, method: Literal['iqr', 'zscore', 'mod_zscore', 'isoforest', 'lof'],
                 df: pd.DataFrame, col: str, *, threshold: float = 3.0,
                 contamination: float = 0.3):
        self.df = df
        self.col = col
        self.method = method
        self.threshold = threshold
        self.contamination = contamination

    def __iqr(self):
        """Keep rows of ``col`` inside the Tukey fences (1.5 * IQR beyond Q1/Q3)."""
        values = self.df[self.col]
        Q1 = values.quantile(0.25)
        Q3 = values.quantile(0.75)
        IQR = Q3 - Q1

        lower_fence = Q1 - 1.5 * IQR
        upper_fence = Q3 + 1.5 * IQR

        # NaN compares False on both sides, so rows with a missing value are dropped.
        return self.df[(values >= lower_fence) & (values <= upper_fence)]

    def __zscore(self):
        """Keep rows whose standard Z-score of ``col`` is below ``threshold``."""
        values = self.df[self.col]
        # pandas skips NaN here (scipy.stats.zscore would propagate it and turn
        # every score into NaN); ddof=0 matches scipy's population std.
        mean = values.mean()
        std = values.std(ddof=0)

        # Zero/undefined spread: nothing can be scored, so nothing is pruned
        # (mirrors the MAD == 0 handling in the modified Z-score).
        if pd.isna(std) or std == 0:
            return self.df

        z = (values - mean) / std
        return self.df[z.abs() < self.threshold]

    def __isoforest(self):
        """Drop rows that IsolationForest labels as outliers (-1) on numeric columns."""
        isolate = IsolationForest(contamination=self.contamination, n_jobs=-1, random_state=42)

        X = self.df.select_dtypes(include=np.number)
        labels_ = isolate.fit_predict(X)

        return self.df.iloc[labels_ != -1]

    def __lof(self):
        """Drop rows that LocalOutlierFactor labels as outliers (-1) on numeric columns."""
        lof = LocalOutlierFactor(n_jobs=-1, n_neighbors=20, algorithm='kd_tree')
        X = self.df.select_dtypes(include=np.number)
        labels = lof.fit_predict(X)

        return self.df.iloc[labels != -1]

    def __modded_zscore(self):
        """Keep rows whose modified Z-score of ``col`` is below ``threshold``."""
        values = self.df[self.col]
        # Series.median skips NaN; np.median would return NaN and prune every row.
        median = values.median()
        mad = (values - median).abs().median()

        # If MAD is 0 (or undefined), return the original DataFrame instead of
        # dividing by zero and producing garbage scores.
        if pd.isna(mad) or mad == 0:
            return self.df

        mod_zscore = 0.6745 * (values - median) / mad
        return self.df[mod_zscore.abs() < self.threshold]

    def remove_outlier(self):
        """
        Apply the configured outlier-removal method and return the pruned DataFrame.

        Returns
        -------
        pandas.DataFrame
            Subset of ``df`` (original index labels preserved) with outlier rows removed.

        Raises
        ------
        ValueError
            If ``method`` is not supported, if ``col`` is missing from ``df`` for a
            statistical method, or if the underlying method fails (the original
            exception is chained as ``__cause__``).
        """
        METHOD_CONFIG = {
            "iqr": self.__iqr,
            "zscore": self.__zscore,
            "mod_zscore": self.__modded_zscore,
            "isoforest": self.__isoforest,
            "lof": self.__lof,
        }

        # Resolve the method on its own so that a KeyError raised *inside* a
        # method (e.g. a missing column) is not misreported as a bad method name.
        try:
            prune = METHOD_CONFIG[self.method]
        except (KeyError, TypeError):
            raise ValueError(f"Provide an appropriate method : {self.method}") from None

        try:
            if self.method in self._STATISTICAL_METHODS and self.col not in self.df.columns:
                raise ValueError(f"Column not found in DataFrame : {self.col}")
            return prune()
        except ValueError:
            raise
        except Exception as err:
            raise ValueError(err) from err
|
|
@@ -0,0 +1,178 @@
|
|
|
1
|
+
import shap
|
|
2
|
+
import numpy as np
|
|
3
|
+
import pandas as pd
|
|
4
|
+
import matplotlib.pyplot as plt
|
|
5
|
+
from sklearn import tree
|
|
6
|
+
from cherrypick.orchestrator import Orchestrator
|
|
7
|
+
from typing import Literal
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def explainer(model, data, impact_type :Literal['pos', 'neg', 'all'] = 'all'):
    """
    Compute SHAP-based feature importance and return sorted impact values.

    This function uses SHAP's TreeExplainer to calculate feature contributions
    for a given model and dataset. It aggregates SHAP values across samples
    and (if applicable) across multiple classes.

    Parameters
    ----------
    model : object
        A trained tree-based model compatible with shap.TreeExplainer
        (e.g., XGBoost, LightGBM, RandomForest).

    data : pandas.DataFrame
        Input dataset for which SHAP values are to be computed.
        Must contain only feature columns (no target column).

    impact_type : {'pos', 'neg', 'all'}, default='all'
        Type of feature impact to return:
        - '**pos**' - Features whose mean signed SHAP value is >= 0.
        - '**neg**' - Features whose mean signed SHAP value is < 0.
        - '**all**' - All features with overall importance
          (mean absolute SHAP values).

    Returns
    -------
    result : pandas.DataFrame
        A DataFrame of feature impacts, strongest effect first:
        - For 'all' → columns: ['Features', 'Overall_Impact'], descending
        - For 'pos' → columns: ['Features', 'Positive_Impact'], descending
        - For 'neg' → columns: ['Features', 'Negative_Impact'], ascending
          (most negative first)

    shap_values : shap.Explanation
        Raw SHAP explanation object containing per-sample contributions.

    Notes
    -----
    - For multi-class models, SHAP values are averaged across classes. For
      'pos'/'neg' this average is signed, so opposing per-class effects can
      cancel out.
    - The function also stores SHAP values globally in `_shap_val`; this is what
      `summary_plot` and `bar_plot` read.

    Raises
    ------
    ValueError
        If `impact_type` is not one of {'pos', 'neg', 'all'} (checked before any
        SHAP computation), or if the SHAP values are not 2-D or 3-D.

    Example
    -------
    >>> result, shap_vals = explainer(model, X_test, impact_type='all')
    >>> print(result.head())

    """
    # Validate first: an invalid request should neither pay for the SHAP
    # computation nor overwrite the cached `_shap_val`.
    if impact_type not in ('pos', 'neg', 'all'):
        raise ValueError("Invalid Impact type : must be neg, pos or all")

    explain = shap.TreeExplainer(model = model)
    shap_values = explain(X = data)

    # Cached at module level for summary_plot() / bar_plot().
    global _shap_val
    _shap_val = shap_values

    vals = shap_values.values

    # Axis 1 is paired with data.columns below; axes 0 and (for 3-D) 2 are
    # averaged away (samples and classes, per the description above).
    if vals.ndim == 3:
        reduce_axes = (0, 2)
    elif vals.ndim == 2:
        reduce_axes = 0
    else:
        raise ValueError(f"Unsupported dimensions of shap values : {vals.ndim}")

    if impact_type == 'all':
        impact = np.abs(vals).mean(axis=reduce_axes)
    else:
        impact = vals.mean(axis=reduce_axes)

    pairs = list(zip(data.columns, impact))

    if impact_type == 'all':
        column, selected = "Overall_Impact", pairs
    elif impact_type == 'pos':
        # `not v < 0` (rather than `v >= 0`) keeps the original split, in which
        # anything that is not negative counts as positive.
        column, selected = "Positive_Impact", [(f, v) for f, v in pairs if not v < 0]
    else:
        column, selected = "Negative_Impact", [(f, v) for f, v in pairs if v < 0]

    result = pd.DataFrame({
        "Features": [f for f, _ in selected],
        column: [v for _, v in selected],
    })

    # Strongest effect first: largest values for 'all'/'pos', most negative for 'neg'.
    result = result.sort_values(by=column, ascending=(impact_type == 'neg'))

    return result, shap_values
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def summary_plot(data):
    '''
    Summary plot for feature contribution for all the classes.

    Draws ``shap.summary_plot`` using the SHAP values cached in the module-level
    ``_shap_val`` by the most recent call to ``explainer``.

    Parameters
    ----------
    data : pandas.DataFrame
        Feature matrix passed to ``shap.summary_plot``. Presumably this is the same
        data that was given to ``explainer`` — TODO confirm.

    Raises
    ------
    RuntimeError
        If ``explainer`` has not been called yet, so no SHAP values are cached.
    '''
    # `_shap_val` only exists after explainer() has run; without this guard
    # the caller gets an opaque NameError.
    shap_val = globals().get("_shap_val")
    if shap_val is None:
        raise RuntimeError(
            "No SHAP values available: call explainer(model, data) before summary_plot()."
        )

    shap.summary_plot(shap_val, data)
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
def bar_plot(n_classes):
    '''
    Bar plot analysis of feature contribution for each class.

    Draws one ``shap.plots.bar`` figure per class index ``0 .. n_classes - 1``
    from the SHAP values cached in the module-level ``_shap_val`` by ``explainer``.

    Parameters
    ----------
    n_classes : int
        Number of classes to plot. The values are sliced along their last axis
        (``_shap_val[..., class_id]``), which assumes multi-output SHAP values
        with classes on the last axis — TODO confirm for binary models.

    Raises
    ------
    RuntimeError
        If ``explainer`` has not been called yet, so no SHAP values are cached.
    '''
    shap_val = globals().get("_shap_val")
    if shap_val is None:
        raise RuntimeError(
            "No SHAP values available: call explainer(model, data) before bar_plot()."
        )

    for class_id in range(n_classes):
        # show=False: otherwise shap displays the figure itself, and the
        # title / tight_layout / show below would act on a fresh empty figure.
        shap.plots.bar(shap_val[..., class_id], show=False)
        plt.title(f"For class_id {class_id}")
        plt.tight_layout()
        plt.show()
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
# def force_plot(shap_values):
|
|
155
|
+
# pass
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
def tree_plot(model, feature_names, size: tuple):
    """
    Render a fitted scikit-learn decision tree with ``sklearn.tree.plot_tree``.

    Nodes are colour-filled and labelled with ``feature_names``; class names are
    shown symbolically (``class_names=True``).

    Parameters
    ----------
    model : object
        Fitted tree estimator accepted by ``sklearn.tree.plot_tree``.
    feature_names : list of str
        Names used to label the split features.
    size : tuple
        ``figsize`` (width, height) for the new matplotlib figure.
    """
    fig, ax = plt.subplots(figsize=size)
    tree.plot_tree(
        model,
        feature_names=feature_names,
        class_names=True,
        filled=True,
        ax=ax,
    )
    fig.tight_layout()
    plt.show()
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
|