upgini 1.2.71a3832.dev9__py3-none-any.whl → 1.2.71a3832.dev11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- upgini/__about__.py +1 -1
- upgini/features_enricher.py +52 -6
- upgini/metrics.py +31 -20
- {upgini-1.2.71a3832.dev9.dist-info → upgini-1.2.71a3832.dev11.dist-info}/METADATA +1 -1
- {upgini-1.2.71a3832.dev9.dist-info → upgini-1.2.71a3832.dev11.dist-info}/RECORD +7 -7
- {upgini-1.2.71a3832.dev9.dist-info → upgini-1.2.71a3832.dev11.dist-info}/WHEEL +0 -0
- {upgini-1.2.71a3832.dev9.dist-info → upgini-1.2.71a3832.dev11.dist-info}/licenses/LICENSE +0 -0
upgini/__about__.py
CHANGED
@@ -1 +1 @@
|
|
1
|
-
__version__ = "1.2.71a3832.dev9"
|
1
|
+
__version__ = "1.2.71a3832.dev11"
|
upgini/features_enricher.py
CHANGED
@@ -1512,8 +1512,7 @@ class FeaturesEnricher(TransformerMixin):
|
|
1512
1512
|
self.logger.info(f"Client features column on prepare data for metrics: {client_features}")
|
1513
1513
|
|
1514
1514
|
filtered_enriched_features = self.__filtered_enriched_features(
|
1515
|
-
importance_threshold,
|
1516
|
-
max_features,
|
1515
|
+
importance_threshold, max_features, trace_id, validated_X
|
1517
1516
|
)
|
1518
1517
|
filtered_enriched_features = [c for c in filtered_enriched_features if c not in client_features]
|
1519
1518
|
|
@@ -2541,7 +2540,9 @@ if response.status_code == 200:
|
|
2541
2540
|
for c in itertools.chain(validated_Xy.columns.tolist(), generated_features)
|
2542
2541
|
if c not in self.dropped_client_feature_names_
|
2543
2542
|
]
|
2544
|
-
filtered_columns = self.__filtered_enriched_features(importance_threshold, max_features)
|
2543
|
+
filtered_columns = self.__filtered_enriched_features(
|
2544
|
+
importance_threshold, max_features, trace_id, validated_X
|
2545
|
+
)
|
2545
2546
|
selecting_columns.extend(
|
2546
2547
|
c for c in filtered_columns if c in result.columns and c not in validated_X.columns
|
2547
2548
|
)
|
@@ -3805,6 +3806,46 @@ if response.status_code == 200:
|
|
3805
3806
|
|
3806
3807
|
return result_features
|
3807
3808
|
|
3809
|
+
def __get_features_importance_from_server(self, trace_id: str, df: pd.DataFrame):
|
3810
|
+
if self._search_task is None:
|
3811
|
+
raise NotFittedError(self.bundle.get("transform_unfitted_enricher"))
|
3812
|
+
features_meta = self._search_task.get_all_features_metadata_v2()
|
3813
|
+
if features_meta is None:
|
3814
|
+
raise Exception(self.bundle.get("missing_features_meta"))
|
3815
|
+
|
3816
|
+
original_names_dict = {c.name: c.originalName for c in self._search_task.get_file_metadata(trace_id).columns}
|
3817
|
+
df = df.rename(columns=original_names_dict)
|
3818
|
+
|
3819
|
+
features_meta.sort(key=lambda m: (-m.shap_value, m.name))
|
3820
|
+
|
3821
|
+
importances = {}
|
3822
|
+
|
3823
|
+
for feature_meta in features_meta:
|
3824
|
+
if feature_meta.name in original_names_dict.keys():
|
3825
|
+
feature_meta.name = original_names_dict[feature_meta.name]
|
3826
|
+
|
3827
|
+
is_client_feature = feature_meta.name in df.columns
|
3828
|
+
|
3829
|
+
if feature_meta.shap_value == 0.0:
|
3830
|
+
continue
|
3831
|
+
|
3832
|
+
# Use only important features
|
3833
|
+
if (
|
3834
|
+
feature_meta.name == COUNTRY
|
3835
|
+
# In select_features mode we select also from etalon features and need to show them
|
3836
|
+
or (not self.fit_select_features and is_client_feature)
|
3837
|
+
):
|
3838
|
+
continue
|
3839
|
+
|
3840
|
+
# Temporary workaround for duplicate features metadata
|
3841
|
+
if feature_meta.name in importances:
|
3842
|
+
self.logger.warning(f"WARNING: Duplicate feature metadata: {feature_meta}")
|
3843
|
+
continue
|
3844
|
+
|
3845
|
+
importances[feature_meta.name] = feature_meta.shap_value
|
3846
|
+
|
3847
|
+
return importances
|
3848
|
+
|
3808
3849
|
def __prepare_feature_importances(
|
3809
3850
|
self, trace_id: str, df: pd.DataFrame, updated_shaps: Optional[Dict[str, float]] = None, silent=False
|
3810
3851
|
):
|
@@ -3990,9 +4031,12 @@ if response.status_code == 200:
|
|
3990
4031
|
)
|
3991
4032
|
|
3992
4033
|
def __filtered_importance_names(
|
3993
|
-
self, importance_threshold: Optional[float], max_features: Optional[int]
|
4034
|
+
self, importance_threshold: Optional[float], max_features: Optional[int], trace_id: str, df: pd.DataFrame
|
3994
4035
|
) -> List[str]:
|
3995
|
-
|
4036
|
+
# get features importance from server
|
4037
|
+
filtered_importances = self.__get_features_importance_from_server(trace_id, df)
|
4038
|
+
|
4039
|
+
if len(filtered_importances) == 0:
|
3996
4040
|
return []
|
3997
4041
|
|
3998
4042
|
filtered_importances = list(zip(self.feature_names_, self.feature_importances_))
|
@@ -4212,11 +4256,13 @@ if response.status_code == 200:
|
|
4212
4256
|
self,
|
4213
4257
|
importance_threshold: Optional[float],
|
4214
4258
|
max_features: Optional[int],
|
4259
|
+
trace_id: str,
|
4260
|
+
df: pd.DataFrame,
|
4215
4261
|
) -> List[str]:
|
4216
4262
|
importance_threshold = self.__validate_importance_threshold(importance_threshold)
|
4217
4263
|
max_features = self.__validate_max_features(max_features)
|
4218
4264
|
|
4219
|
-
return self.__filtered_importance_names(importance_threshold, max_features)
|
4265
|
+
return self.__filtered_importance_names(importance_threshold, max_features, trace_id, df)
|
4220
4266
|
|
4221
4267
|
def __detect_missing_search_keys(
|
4222
4268
|
self,
|
upgini/metrics.py
CHANGED
@@ -3,7 +3,6 @@ from __future__ import annotations
|
|
3
3
|
import inspect
|
4
4
|
import logging
|
5
5
|
import re
|
6
|
-
import warnings
|
7
6
|
from collections import defaultdict
|
8
7
|
from copy import deepcopy
|
9
8
|
from dataclasses import dataclass
|
@@ -755,9 +754,12 @@ class LightGBMWrapper(EstimatorWrapper):
|
|
755
754
|
logger=logger,
|
756
755
|
)
|
757
756
|
self.cat_features = None
|
757
|
+
self.n_classes = None
|
758
758
|
|
759
759
|
def _prepare_to_fit(self, x: pd.DataFrame, y: pd.Series) -> Tuple[pd.DataFrame, pd.Series, np.ndarray, dict]:
|
760
760
|
x, y_numpy, groups, params = super()._prepare_to_fit(x, y)
|
761
|
+
if self.target_type in [ModelTaskType.BINARY, ModelTaskType.MULTICLASS]:
|
762
|
+
self.n_classes = len(np.unique(y_numpy))
|
761
763
|
if LIGHTGBM_EARLY_STOPPING_ROUNDS is not None:
|
762
764
|
params["callbacks"] = [lgb.early_stopping(stopping_rounds=LIGHTGBM_EARLY_STOPPING_ROUNDS, verbose=False)]
|
763
765
|
self.cat_features = _get_cat_features(x)
|
@@ -783,31 +785,40 @@ class LightGBMWrapper(EstimatorWrapper):
|
|
783
785
|
|
784
786
|
def calculate_shap(self, x: pd.DataFrame, y: pd.Series, estimator) -> Optional[Dict[str, float]]:
|
785
787
|
try:
|
786
|
-
|
787
|
-
|
788
|
-
|
789
|
-
|
790
|
-
|
791
|
-
|
788
|
+
shap_matrix = estimator.predict(
|
789
|
+
x,
|
790
|
+
predict_disable_shape_check=True,
|
791
|
+
raw_score=True,
|
792
|
+
pred_leaf=False,
|
793
|
+
pred_early_stop=True,
|
794
|
+
pred_contrib=True,
|
792
795
|
)
|
793
|
-
from shap import TreeExplainer
|
794
|
-
|
795
|
-
if not isinstance(estimator, (LGBMRegressor, LGBMClassifier)):
|
796
|
-
return None
|
797
796
|
|
798
|
-
|
799
|
-
|
800
|
-
|
797
|
+
if self.target_type == ModelTaskType.MULTICLASS:
|
798
|
+
n_feat = x.shape[1]
|
799
|
+
shap_matrix.shape = (shap_matrix.shape[0], self.n_classes, n_feat + 1)
|
800
|
+
shap_matrix = np.mean(np.abs(shap_matrix), axis=1)
|
801
801
|
|
802
|
-
#
|
803
|
-
|
804
|
-
if isinstance(shap_values, list):
|
805
|
-
shap_values = shap_values[1]
|
802
|
+
# exclude base value
|
803
|
+
shap_matrix = shap_matrix[:, :-1]
|
806
804
|
|
807
|
-
# Calculate mean absolute SHAP value for each feature
|
808
805
|
feature_importance = {}
|
809
806
|
for i, col in enumerate(x.columns):
|
810
|
-
feature_importance[col] = np.mean(np.abs(shap_values[:, i]))
|
807
|
+
feature_importance[col] = np.mean(np.abs(shap_matrix[:, i]))
|
808
|
+
|
809
|
+
# # exclude last column (base value)
|
810
|
+
# shap_values_only = shap_values[:, :-1]
|
811
|
+
# mean_abs_shap = np.mean(np.abs(shap_values_only), axis=0)
|
812
|
+
|
813
|
+
# # For classification, shap_values is returned as a list for each class
|
814
|
+
# # Take values for the positive class
|
815
|
+
# if isinstance(shap_values, list):
|
816
|
+
# shap_values = shap_values[1]
|
817
|
+
|
818
|
+
# # Calculate mean absolute SHAP value for each feature
|
819
|
+
# feature_importance = {}
|
820
|
+
# for i, col in enumerate(x.columns):
|
821
|
+
# feature_importance[col] = np.mean(np.abs(shap_values[:, i]))
|
811
822
|
|
812
823
|
return feature_importance
|
813
824
|
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.3
|
2
2
|
Name: upgini
|
3
|
-
Version: 1.2.71a3832.dev9
|
3
|
+
Version: 1.2.71a3832.dev11
|
4
4
|
Summary: Intelligent data search & enrichment for Machine Learning
|
5
5
|
Project-URL: Bug Reports, https://github.com/upgini/upgini/issues
|
6
6
|
Project-URL: Homepage, https://upgini.com/
|
@@ -1,12 +1,12 @@
|
|
1
|
-
upgini/__about__.py,sha256=
|
1
|
+
upgini/__about__.py,sha256=MPYFg9v0SOhqTxe0IfYh4m6Nh3TlmyfHR9sua58WXBM,34
|
2
2
|
upgini/__init__.py,sha256=LXSfTNU0HnlOkE69VCxkgIKDhWP-JFo_eBQ71OxTr5Y,261
|
3
3
|
upgini/ads.py,sha256=nvuRxRx5MHDMgPr9SiU-fsqRdFaBv8p4_v1oqiysKpc,2714
|
4
4
|
upgini/dataset.py,sha256=aspri7ZAgwkNNUiIgQ1GRXvw8XQii3F4RfNXSrF4wrw,35365
|
5
5
|
upgini/errors.py,sha256=2b_Wbo0OYhLUbrZqdLIx5jBnAsiD1Mcenh-VjR4HCTw,950
|
6
|
-
upgini/features_enricher.py,sha256=
|
6
|
+
upgini/features_enricher.py,sha256=oYOBaHIyPjm-EEZvJT9pU35_DW8bArEQKymZyhW8LbE,206592
|
7
7
|
upgini/http.py,sha256=RvzcShpDXssLs6ycGN8xilkKi8ZV9XGUrrk8bwdUzbw,43607
|
8
8
|
upgini/metadata.py,sha256=Yd6iW2f7Wz6vUkg5uvR4xylN16ANnCKVKqAsAkap7p8,12354
|
9
|
-
upgini/metrics.py,sha256=
|
9
|
+
upgini/metrics.py,sha256=9AaQi7Yb22ZNnycUOAUpcP7TWF5Pfy_NGACcDj10aMs,38820
|
10
10
|
upgini/search_task.py,sha256=EuCGp0iCWz2fpuJgN6M47aP_CtIi3Oq9zw78w0mkKiU,17595
|
11
11
|
upgini/spinner.py,sha256=4iMd-eIe_BnkqFEMIliULTbj6rNI2HkN_VJ4qYe0cUc,1118
|
12
12
|
upgini/version_validator.py,sha256=DvbaAvuYFoJqYt0fitpsk6Xcv-H1BYDJYHUMxaKSH_Y,1509
|
@@ -70,7 +70,7 @@ upgini/utils/target_utils.py,sha256=b1GzO8_gMcwXSZ2v98CY50MJJBzKbWHId_BJGybXfkM,
|
|
70
70
|
upgini/utils/track_info.py,sha256=G5Lu1xxakg2_TQjKZk4b5SvrHsATTXNVV3NbvWtT8k8,5663
|
71
71
|
upgini/utils/ts_utils.py,sha256=26vhC0pN7vLXK6R09EEkMK3Lwb9IVPH7LRdqFIQ3kPs,1383
|
72
72
|
upgini/utils/warning_counter.py,sha256=-GRY8EUggEBKODPSuXAkHn9KnEQwAORC0mmz_tim-PM,254
|
73
|
-
upgini-1.2.71a3832.
|
74
|
-
upgini-1.2.71a3832.dev9.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
|
75
|
-
upgini-1.2.71a3832.dev9.dist-info/licenses/LICENSE,sha256=5RRzgvdJUu3BUDfv4bzVU6FqKgwHlIay63pPCSmSgzw,1514
|
76
|
-
upgini-1.2.71a3832.dev9.dist-info/RECORD,,
|
73
|
+
upgini-1.2.71a3832.dev11.dist-info/METADATA,sha256=QuI4m49RjcWmDJ74fXMWfNqBKPXGKDsKGhhO_wR1Kfw,49102
|
74
|
+
upgini-1.2.71a3832.dev11.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
|
75
|
+
upgini-1.2.71a3832.dev11.dist-info/licenses/LICENSE,sha256=5RRzgvdJUu3BUDfv4bzVU6FqKgwHlIay63pPCSmSgzw,1514
|
76
|
+
upgini-1.2.71a3832.dev11.dist-info/RECORD,,
|
File without changes
|
File without changes
|